jiku log

データサイエンスの核心を掴む : 学びと発見の記録

「データのつながりを活かす技術」を読む ~第6章 グラフニューラルネットワーク ④GNNの実装~

6.7 GNNの実装

本節では,GNN(Graph Neural Network)を使ったノード分類タスク(論文の技術領域の分類)とリンク予測タスク(論文の引用関係の予測)の実装方法について説明している。


本節で使う代表的なライブラリは以下の通りである。

  • PyTorch Geometric : 所与のノード特徴を用いたGNNによる技術領域の分類・引用関係の予測を行なう。
  • scikit-learn : 所与のノード特徴を用いたロジスティック回帰による技術領域の分類を行なう。


また本節では,以下の2つのパターンを比較する。

  1. 所与のノード特徴を用いたGNNによる技術領域の分類
  2. 所与のノード特徴を用いたロジスティック回帰による技術領域の分類


サンプルコードをGoogle Colab上で動かしながら挙動を確認した。
github.com

データはCoraデータセットであり,学習データとテストデータを6 : 4に分けて利用している。

GNNによる論文の技術領域の分類

本項では,GCN(Graph Convolutional Network)を用いたノード分類を説明している。

GCNは隣接関係を利用し,ノードの特徴を伝播させながら新しい特徴ベクトルを学習するモデルである。PyTorch Geometricでは,GCNConvクラスを用いる。

主要なパラメータは以下の通り。

  • projection_dim : 中間層の次元数を示す。

なお,num_node_features(入力する特徴量の次元数)や,num_class(分類するクラス数)は,学習データに応じて決定される。

損失関数は,予測したラベル間の項さエントロピーを用いる。

計算結果は以下の通りである。

  • 出力
              precision    recall  f1-score   support

           0       0.79      0.75      0.77       154
           1       0.77      0.78      0.78        83
           2       0.94      0.95      0.94       158
           3       0.89      0.88      0.88       332
           4       0.87      0.89      0.88       169
           5       0.80      0.85      0.82       120
           6       0.83      0.82      0.83        67

    accuracy                           0.86      1083
   macro avg       0.84      0.85      0.84      1083
weighted avg       0.86      0.86      0.86      1083

「データのつながりを活かす技術」を読む ~第5章 ノード埋め込み ⑥ノード埋め込みの実装~ - jiku log で紹介した,「ノード埋め込み+ロジスティック回帰」を用いたときのaccuracyは0.82,f1-scoreの重み付き平均も0.82だったのに比べると,accuracyおよびf1-scoreの重み付き平均は0.86なので,GCNの方が高い精度を実現していることが確認できた。
これはGCNの方が,論文同士の引用関係を上手く活用して,よりよい特徴を得ているためだと考えられる。

フィルタの差し替え

本項では,GraphSAGEモデルによる分類を説明している。これにはSAGEConvクラスを用いる。

計算結果は以下の通りである。

  • 出力
              precision    recall  f1-score   support

           0       0.82      0.75      0.78       154
           1       0.83      0.82      0.82        83
           2       0.95      0.98      0.96       158
           3       0.90      0.90      0.90       332
           4       0.90      0.89      0.90       169
           5       0.88      0.88      0.88       120
           6       0.79      0.90      0.84        67

    accuracy                           0.88      1083
   macro avg       0.87      0.87      0.87      1083
weighted avg       0.88      0.88      0.88      1083

accuracyおよびf1-scoreの重み付き平均は0.88なので,GCNよりも良くなっていることが確認できた。

GraphSAGEは,「データのつながりを活かす技術」を読む ~第6章 グラフニューラルネットワーク ②グラフ畳み込みネットワーク~ - jiku logで紹介したとおり,ネットワーク構造が完全には分かっていない場合でも扱える手法である。今回精度が高くなった理由は,この手法の特徴が出たためであるとは考えにくいが,様々な手法を試してみることは精度を向上させるためには重要な取り組みであると言える。

GNNによる論文の引用関係の予測

本項では,GNNによる引用関係の予測,すなわち「どの論文同士にリンク(引用関係)が生じるか」ということを予測するタスクについて紹介している。

このタスクにおいても,学習データとテストデータは,6 : 4に分割する。

リンク予測を行なうためのモデル

リンク予測を行なうためのモデルは,次のようにして構築する。

  1. GCNでノードの特徴を変換する。
  2. その後ノード同士のベクトル(特徴)の内積を取ることでリンクの有無を推定する。
モデルの評価

今回は,AUCにより評価する。

  • 出力
AUC : 0.8461

AUCは0.8461であるため,ランダムに予測するよりもリンクの有無を予測できている。

またROC曲線のプロット結果は以下の通りである。

AUC曲線

まとめと感想

今回は,「第6章 グラフニューラルネットワーク」における,GNNの実装についてまとめた。

技術領域の分類問題はGCN・GraphSAGEの実装例が紹介されていた。今回の問題設定では,GraphSAGEの性能の方が良かったが,今回のように複数の手法を手軽に試せることがPyTorch Geometricの利点であると考えられる。

本記事を最後まで読んでくださり,どうもありがとうございました。