GNN详解
-
Introduction
既然你点进了这篇博客,我就默认你了解图的最基本概念,包括但不限于有向图、无向图的定义,这里我不再多做赘述,下面我只阐述一些比较重要的部分
图神经网络是一种直接对图结构进行操作的神经网络。GNN的一个典型应用是节点分类。本质上,图中的每一个节点都与一个标签相关联。如下图所示,a节点的向量表示为
[0.3, 0.02, 7, 4, ...]
,将该向量送入下游继续做分类即可得到a节点的类别a节点的向量表示究竟是如何得到的,等会儿我会详细阐述,不过在这我们可以简单的思考一下,a节点的向量表示一定应该与其相邻节点和相邻边有关系,假设每个节点vvv的one-hot向量表示为XvX_vXv,则
$$
h_v=f(X_v, X_{co[v]}, h_{ne[v]},X_{ne[v]})
$$
其中Xco[v]X_{co[v]}Xco[v]表示与vvv相连边的one-hot表示,hne[v]h_{ne[v]}hne[v]表示与vvv相邻节点的embedding,Xne[v]X_{ne[v]}Xne[v]表示与vvv相邻节点的one-hot表示。最终再将hvh_vhv向量和真实标签计算loss,不过也有的地方是将hvh_vhv向量和XvX_vXv经过一轮融合之后再计算loss,例如
$$
\begin{aligned}
O_v&=g(h_v,X_v)\
loss &= \sum_{i=1}^p (t_i-o_i)
\end{aligned}
$$上述公式以及符号可能不好理解,这里其实可以类比Word2Vec。在Word2Vec中,输入的是每个词对应的one-hot,即XvX_vXv。输出的是这个词对应的embedding,即hvh_vhv
DeepWalk:第一个无监督学习节点嵌入的算法
DeepWalk用一句话概述就是:随机生成图节点序列,然后对该序列进行Word2Vec
具体来说,给定一个图,我们随机选择一个节点作为起始,然后随机"步行"到邻居节点,直到节点序列的长度达到给定的最大值。例如下图,分别选择d,e,f作为起点进行游走,得到三条节点序列
在这种情况下,我们可以将节点和节点序列分别看作是"单词"和"句子",然后利用skip-gram或者cbow算法训练得到每个节点的embedding
node2vec:bias random walk
一般的随机游走公式
$$
{P}\left(c_{i}=x \mid c_{i-1}=v\right)=\left{\begin{array}{cc}
\frac{1}{|N(v)|}, & \text { if }(v, x) \in E \
0, & \text { otherwise }
\end{array}\right.
$$
其中,∣N(v)∣|N(v)|∣N(v)∣表示vvv节点的邻居节点数量,ci−1c_{i-1}ci−1表示当前节点,cic_ici表示下一个选择的节点一般的随机游走存在以下几个问题:
- 如果是带权图,没考虑到边权值的影响
- 太过于随机,不能由模型自行学习以何种方式游走更好
实际上图的游走分两大类,即DFS和BFS
为了引入边权,以及依概率选择DFS或BFS,我们首先要将一般的随机游走公式进行修改
$$
{P}\left(c_{i}=x \mid c_{i-1}=v\right)=\left{\begin{array}{cc}
\frac{\pi_{vx}}{Z}, & \text { if }(v, x) \in E \
0, & \text { otherwise }
\end{array}\right.
$$
其中,πvxZ\frac{\pi_{vx}}{Z}Zπvx在值上与1∣N(v)∣\frac{1}{|N(v)|}∣N(v)∣1是相等的,ZZZ可以理解为归一化的缩放因子设节点vvv和xxx之间的边权为wvxw_{vx}wvx,则可以将πvx\pi_{vx}πvx改写为αpq(t,x)⋅wvx\alpha_{pq}(t,x)\cdot w_{vx}αpq(t,x)⋅wvx
$$
\alpha_{{pq}}(t, x)=\left{\begin{array}{l}
\frac{1}{p}, &\quad{ \text { if } d_{t x}=0} \
1, &\quad{\text { if } d_{t x}=1} \
\frac{1}{q}, &\quad{ \text { if } d_{t x}=2}
\end{array}\right.
$$
其中,dtxd_{tx}dtx表示当前节点vvv的一阶邻居节点到节点ttt的距离- ppp:控制随机游走以多大的概率"回头"
- qqq:控制随机游走偏向DFS还是BFS
- qqq较大时$(q>1)$,倾向于BFS
- qqq较小时$(q<1)$,倾向于DFS
- p=q=1p=q=1p=q=1时,πvx=wvx\pi_{vx}=w_{vx}πvx=wvx
到此为止,GNN中节点Embedding算法我就不再多叙述,其实还有很多更好的算法,例如LINE,Struct2vec等,不过我个人感觉这些embedding算法并不重要,或者说它们并不是GNN的核心部分,只要当作工具来用就行了,类比Transformer中Word Embedding的地位
References
- An introduction to Graph Neural Networks
- Random Walk in Node Embeddings (DeepWalk, node2vec, LINE, and GraphSAGE)
- How to do Deep Learning on Graphs with Graph Convolutional Networks
- A Gentle Introduction to Graph Neural Networks (Basics, DeepWalk, and GraphSage)
- Hands-on Graph Neural Networks with PyTorch & PyTorch Geometric
- PGL全球冠军团队带你攻破图神经网络
- 台大李宏毅助教讲解GNN图神经网络
- 图神经网络详解