图神经网络(Graph Neural Network, GNN)是一种建立在神经网络理论之上对图数据进行分析、学习的模型。其核心是图信号处理,包括对图信号的卷积滤波等等。

从任务的角度,对图数据的学习可以实现:

  • 节点层面:节点标签的分类、引文网络分类、恶意账户检测;
  • 边层面:社交网络中边预测、推荐系统;
  • 图层面:对图的整体结构进行分类、表示、生成,药物分子的分类、酶的分类;

本文希望对图神经网络需求的计算模式和目前实际使用的GNN框架内的计算方式进行总结。

一些历史

图神经网络算是一种图分析和深度学习的杂交衍生物吧,理论的完善算是比较晚的。

2005年首次由Marco Gori等人提出图神经网络的概念;

2009年两篇论文介绍如何使用监督的方法对图神经网络进行训练;

2013年Bruna等人首次将卷积引入图神经网络中,并基于频域卷积的概念开发出了图卷积神经网络;

2016年,Kipf等人简化了频域上的图卷积操作,使得能够在空域上进行,极大地提升了图卷积模型的计算效率。而这篇文章所提出的简化版本的实现,现在也被广泛使用,称为GCN层。GCN层进行堆叠组成的神经网络模型称为图卷积模型GCN。

这里频域上的卷积需要对图的Laplace矩阵进行特征分解,计算开销很大,而转到空域上就等价于进行矩阵向量乘法。

之后,各种GNN的变体被提出,考虑的角度有计算复杂度、任务需求、学习效率、节点异构等等。

GCN层

GCN层来源于 “Semi-supervised Classification with Graph Convolutional Networks” 。其计算公式为:

$$ X' = \hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}X\Theta $$

其中$X$是图信号矩阵,每个分量代表某个节点处的特征向量;A为邻接矩阵,$\hat{A}=A+I$(加上节点自己),邻接矩阵的值可以为权重$e_{ij}$;$\hat{D}_{ii}=\sum_{j}\hat{A}_{ij}$节点的度对角线矩阵;$\Theta$表示一个线性变换操作。

在单个节点层面,公式可以写成:

$$ x_i' = \Theta^T\sum_{j\in N(i)\cup\{i\}}\frac{e_{ji}}{\sqrt{\hat{d}_j{\hat{d}_i}}}x_{j} $$

其中$\hat{d}_i=1+\sum_{j\in N(i)}e_{ji}$,$e_{ji}$表示从源节点$j$到目标节点$i$的边权重。

PyG框架

PyG(Pytorch Geometric)是基于Pytorch框架的一个图神经网络框架。PyG对GNN的建模采用消息传播机制(Message Passing Neural Network,MPNN)。框架内的两个重要的扩展组件库为torch_scattertorch_sparse,新增SparseTensor来实现Memory-Efficient的Aggregation操作。

PyG中实现模型对图的操作主要通过消息传播函数propagate()的调用,可以分为三个步骤:message()aggregate()update()。用公式表示如下:

$$ \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in N(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), $$

其中$\phi$对应message(),表示当前中心节点与邻居节点、邻接关系共同生成消息的方法;$\square$对应aggregate(),表示将所有邻居节点生成的消息进行聚合的方法;$\gamma$对应update(),表示聚合后生成的总消息是如何与当前中心节点之前的特征向量交互,从而生成新的特征向量。这里的三个函数,均要求可差分(differentiable)以实现梯度的计算和反向传导,同时aggregate()也要具有排列不变性(permutation invariant)。

值得一提的是,这里计算message()的时候,需要将x在节点数量维度展开成边数量的维度。这种方式会产生很大的存储开销,具体解释可以看如下代码:
from torch_geometric.utils import scatter

x = ...           # Node features of shape [num_nodes, num_features]
edge_index = ...  # Edge indices of shape [2, num_edges]

x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]

msg = MLP(x_j - x_i)  # Compute message for each edge

# Aggregate messages based on target node indices
out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')

例子:PyG中实现的GCN层

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """"""

        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        x = self.lin(x)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=None)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)

为了实现更好的存储利用率,更高的计算效率,PyG引入了SparseTensor优化稀疏的情况,同时可以使用融合形式的message_and_aggregate()(必须邻接矩阵为SparseTensor类型),而在这种情况下,使用加法聚合,计算其实就是稀疏矩阵乘(SpMM)。

All code remains the same as before, except for the data transform via T.ToSparseTensor(). As an additional advantage, MessagePassing implementations that utilize the SparseTensor class are deterministic on the GPU since aggregations no longer rely on atomic operations. 原子操作为什么会导致结果是不确定的呢?

这里PyG调用的spmm()来自于torch_sparse;在非稀疏的条件下,aggregate()最终调用的是torch_scatter中由原子操作组成的scatter()实现reduce操作。

torch_sparse中的spmm()给出了两种实现,CPUCUDA的版本,输入均为CSR的压缩格式。其中CPU版本的计算采用三层循环的形式,直接进行计算;CUDA版本的计算采用 Merge-based SpMM形式。

DGL框架

DGL全程是Deep Graph Library,建立之初的定位是一个在现有深度学习框架之上的Python库,用以方便得实现不同类型的图神经网络。DGL同样使用消息传递机制对图操作进行建模,并且提供性能调优后的稀疏矩阵计算内核,后端支持包括CPU、GPU和对应的集群。

It offers a versatile control of message passing, speed optimization via auto-batching and highly tuned sparse matrix kernels, and multi-GPU/CPU training to scale to graphs of hundreds of millions of nodes and edges.

DGL使用和PyG类似的消息传递计算范式:

$$ \text{边上计算: } m_{e}^{(t+1)} = \phi \left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \right) , ({u}, {v},{e}) \in E. $$

$$ \text{点上计算: } x_v^{(t+1)} = \psi \left(x_v^{(t)}, \rho\left(\left\lbrace m_{e}^{(t+1)} : ({u}, {v},{e}) \in E \right\rbrace \right) \right). $$

$\phi$是定义在每条边上的消息函数,它通过将边上特征与其两端节点的特征相结合来生成消息。聚合函数 $\rho$会聚合节点接受到的消息。更新函数 $\psi$会结合聚合后的消息和节点本身的特征来更新节点的特征。

这三个函数可以使用内置函数或者用户自定义函数,作为参数传给update_all()进行调用,这样做可以在整体上进行系统优化。但是一般DGL不建议在update_all()中指定更新函数,可以直接通过纯张量实现接在后面调用。此外,DGL还提供了apply_edges()计算在边上保存的消息,这种情况下,边上保存的高维消息对于内存的消耗很大。

def update_all_example(graph):
    # 在graph.ndata['ft']中存储结果
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # 在update_all外调用更新函数
    final_ft = graph.ndata['ft'] * 2
    return final_ft

DGL对图神经网络模型的构建,主要是基于后端框架的方式。以Pytorch后端为例,设置选项,注册可学习的参数或者子模块,初始化参数,即完成模块的构造函数。

import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

之后,再定义foward()函数前向传播,这里面的主要操作是消息传递和聚合之后,更新特征作为输出。对消息进行计算之前,还需要对图的类型和格式进行处理。

DGL同样提供了优化的稀疏代数计算(如GSpMM)对消息的计算和聚合进行融合。聚合的类型支持加法型(sum、add、mean)和比较型(min、max)。

在使用Pytorch作为后端时,具体的Kernel实现在Pytorch::ATen命名空间下,DGL实现了CPU和CUDA两个版本的SpMM。CPU下有自己实现的Nativa版本和LIBXSMM库(AVX)版本,分别支持单双精度的计算。CUDA下有自己实现的朴素版本和cuSparse计算库的版本(聚合为sum),同时整合了GE-SpMM(dgsparse)的cuda版本作为可选。

值得一提的是,PyG和DGL都有整合cuGraph作为图计算支持库的接口。看起来cuGraph(NV)还是很有面子的

一些其他的碎碎念

图神经网络可以视为需要反复进行图数据分析操作的一种任务,在这种任务之下,对高性能图计算的需求被放大。

一方面对图神经网络之中图计算的建模也大大影响着最终可以进行性能调优的空间。

新型图算子的设计需要依赖图的频谱分析,执行的效率则依赖图结构的表达和图数据的存储方式。另一方面,图算子在图神经网络中不应该是孤立的,可以适当和其他神经网络计算层进行融合,通过消除或者优化中间表达,提升整个模型的计算效率。

图是一种非结构的数据,而图像可以被认为是一种特殊的图。在CV领域效果显著的卷积神经网络,在图这里可以被扩展成对局部子图的某种信息聚合,在这个含义下,节点的连接关系的多样性既实现了图信息表达的丰富也增加了适合规则计算的计算机处理的难度。

这可能也是为什么图分析一直在计算机领域是老大难的问题。