よくわかる Gated Graph Neural Network (GGNN)

はじめに

こんなタイトルだけど、もしかしたらよくわからないかもしれない。
冗談です。世の中には、Graph Neural Network (GNN) [1] や、それをちょっと拡張させた Gated Graph Neural Network (GGNN) [2] というものがあります。 これらは名前から分かるように Neurarl Network の一種で、グラフ構造を直接取り扱うことができるというものです。

本記事では、主に機械学習などをやっている人を対象に、GNN や GGNN がどういったものなのかを解説したいと思います。

前提知識

  • Neural Network の基本的な知識。
  • RNN や LSTM に関する基本的な知識。
  • その他、機械学習についての基本的な知識。

GNN って何?

一言で言ってしまえば、「グラフ構造を直接取り扱うことのできるネットワークモデル」です。 グラフ構造とは、葉や節とも呼ばれるノードと、それらノードを繋ぐエッジからなる構造のことです。

f:id:hilinker:20190309095246p:plain
図1. [1] より、グラフ構造の例。

図 1 にとってもシンプルなグラフの例を挙げました。 ①~④がそれぞれノードであり、その間を繋いでいる矢印がエッジです。 図中のエッジに黒色と灰色がありますが、このようにノード間を繋ぐエッジの種類が複数ある場合もあります。

ここで、ノードの間を方向を持つ矢印で繋いでいますが、このように方向のあるエッジを有向エッジと呼び、有効エッジを持つようなグラフを有向グラフと呼びます。 一方で、ただノード間を方向のないエッジ(線)で繋いだものについても同様に、無向エッジとか無向グラフとか呼びます。

GNN では、基本的には有向グラフを扱うのですが、無向グラフも取り扱えます。 無向グラフは、単純に無向エッジを双方向の有向エッジに変換することで有向グラフにして取り扱います。

さて、ではグラフについての軽い解説が終わったところで、実際の GNN の解説に移りたいと思います。

GNN の解説としては、 Microsoft Research の上げている下の動画がはちゃめちゃに分かりやすいです。 なので、英語が分かる人はこっちを見ることをオススメします(僕のブログが読まれないのは悲しいけど)。

www.youtube.com

それはそれとして、日本語の解説も書きます。
GNN は大雑把にいってしまうと、「各ノードがメッセージを作成し、隣に伝言ゲームで伝えていくことで学習を行うモデル」となっています。 順を追って見て行きましょう。

まず、グラフはノード v ∈ V、エッジ e ∈ E からなります。ここで、V や E はグラフに含まれるノードやエッジの集合です。 それぞれのエッジはあるタイプ k ∈ K に分類されます。つまり、「このエッジは黒色!」とか「このエッジは灰色!」みたいな分類があるということです。より具体的に言えば、例えば化学の「単結合」「二重結合」「三重結合」のような分類が考えられます。このように、エッジは 1 種類ではなく、複数種類存在することがありえます。

そして、それぞれのノード v は、現在の自身の状態 h(v) (ベクトル)を持ちます。RNN の隠れ層のようなものですね。 各時刻 t において(急に時刻が出てきて混乱するかもしれませんが、あとで分かるのでそういうものだと思ってください)、各ノード v はタイプ k のエッジで繋がった隣のノードに、メッセージ m_k^v を送ります。 ここで、メッセージ m_k^v は、次のような式で計算されます。

\begin{align} m_k^v = f_k(h(v)) \end{align}

ここで、  f_k はエッジのタイプ k ごとに設定した、何かしらの関数です。これは何でもよく、一例としては、 Linear Layer などが挙げられます。  f_k は現在のノードの状態 h(v) から、隣のノードに伝達すべきメッセージを作成します。 このメッセージは、式からも分かるように現在の自分の状態を少し加工して隣に伝えるようなものになります。

メッセージを作成した後、各ノードはタイプ k のエッジで繋がった自分の隣のノードにメッセージ m_k^v を一斉に渡します。 メッセージを受け取った各ノードは、それを元に自分の状態 h(v) を更新します。 ここで、ノードによっては複数のメッセージを受け取る場合もあるでしょう。その場合には受け取ったメッセージを全て足し合わせた  \tilde{m}^v を受け取ったメッセージとします。 実際に次の式で自分の状態を更新します。

\begin{align} h'(v) = RNN(\tilde{m}^v, h(v)) \end{align}

ここで、RNN はそのまま RNN の状態更新を行う関数だとしてください。厳密には RNN でなくとも良いのですが、実質やっていることは RNN なのでこう記述しました。 RNN として simple RNN ではなく、 LSTM や GRU を用いたものを、 Gated Graph Neural Network (GGNN) と言います。違いはそれだけです。

以上のように、「自分の状態からメッセージを作成する」「隣にメッセージを伝達する」「受け取ったメッセージを元に自分の状態を更新する」という一連の流れを、各ノードで一斉に繰り返すことで学習を行うのが、 GNN および GGNN のアルゴリズムになります。

さて、各タイムステップごとに一斉に値を伝播するのですが、そのタイムステップはあらかじめ決めた値になります。例えば、5ステップで学習を終了させると予め決めていれば、上記の操作を5回繰り返して終わりになります。そのため、場合によっては遠くのノードに自分の情報が伝わらないこともあります。

GNN のいいところ

GNN の利点としては、何と言ってもグラフ構造をそのままニューラルネットワークのモデルに落とし込めることです。 化合物の構造や、構文木など、グラフ構造を持っているものはたくさんあります。 そうしたものを、そのままニューラルネットワークを通じて利用できるというのが、GNN の強みであると言えるでしょう。 実際に、ソースコードの解析などに用いている先行研究 [3] もあったりして、そこそこにいい性能が出ています。

また、 GNN は複数種類のエッジにも対応しているという強みもあります。 これにより、ノードとノードの様々な関係性を強くモデルに反映させることができます。

まとめ

以上が、 GNN や GGNN の解説になります。 今後、グラフを対象にした研究は(たぶん)増えていくので、理解しておいて損はないかなと思います。 グラフ構造をニューラルネットで扱いたい場合の選択肢の一つとしてどうぞ。

何か間違いなどありましたら、ご指摘頂けると幸いです。

参考

[1] Franco Scarselli, Marco Gori, Ah Chung Tsoi, Gabriele Monfardini. The Graph Neural Network Model. In IEEE, 2007.
[2] Yujia Li, Richard Zemel, Marc Brockschmidt, Daniel Tarlow. GATED GRAPH SEQUENCE NEURAL NETWORKS. In ICLR 2016.
[3] Miltiadis Allamanis, Marc Brockschmidt, and Mahmoud Khademi. Learning to represent programs with graphs. In ICLR, 2018.