当前位置: 首页 > 科技观察

GraphSAGE图神经网络算法详解

时间:2023-03-12 19:34:16 科技观察

GraphSAGE是一篇17年的文章,却一直被业界重视。最重要的是其论文名称中的两个关键词:inductive和largegraph。今天我们就来梳理一下这篇文章的核心思想和一些容易被忽略的细节。为什么要使用GraphSAGE我们先来思考一下为什么图这么受欢迎。有几个原因。图的数据来源丰富,图包含的信息量非常大。所以现在我们在思考如何更好的利用图中的信息。那么我们需要用图表做什么呢?核心是利用图的结构信息为每个节点学习一个合适的嵌入向量。只要我们有一个合适的embedding结果,无论我们接下来做什么工作,我们都可以直接用它来设置模型。在GraphSAGE之前,主要的方法有DeepWalk、GCN等,但缺点是需要学习整张图。并且主要是基于transductivelearning,也就是说在训练的时候,图中已经包含了要预测的节点。考虑到在实际应用中,图的结构会经常发生变化,在最后的预测阶段可能会向图中添加一些新的节点。那该怎么办呢?GraphSAGE就是为此而提出的,其核心思想其实就是它的名字GraphSAGE=GraphSampleAggregate。也就是说,对图进行采样和聚合。我们在GraphSAGE的思想中提到了sample和aggregate,它们到底指的是什么?这一步如何运作?为什么可以应用于大规模图?下面我就用通俗易懂的语言给大家描述清楚。顾名思义,sample就是选取一些点,aggregate就是聚合它们的信息。那么整个过程是如何进行的呢?看下图:我们先在第一张图上学习sample的过程。如果我有这样一个图,我需要更新最中心节点的mebedding,首先从它的邻居中选择S1(本例中选择3)个节点,如果K=2,那么我们将更新第二个采样进行再次在下一层,即选择刚刚选择的S1邻居的邻居。在第二张图中,我们可以看到聚合操作,也就是说先用邻居的邻居更新邻居的信息,然后用更新后的邻居的信息更新目标节点(也就是中间的红点)信息。听起来可能有点啰嗦,但是思路并不曲折,大家仔细梳理一下就明白了。第三张图,如果我们要预测一个未知节点的信息,只需要利用它的邻居来做预测即可。我们再梳理一下这个思路:如果我想知道小明是个什么样的性格,我会先去他的几个好朋友那里去观察,然后再选择他朋友的其他朋友做进一步的确认。朋友们,再看看。也就是说,通过小明的朋友的好友,我可以判断出小明的朋友是一个什么样的人,然后根据他的朋友,我可以大致知道小明是一个什么样性格的人。GraphSAGE补充既然知道了GraphSAGE的基本思路,大家可能还有一些困惑:单个节点的思路是这样的,那么整体的训练过程应该如何进行呢?到现在为止,还没有告诉我们为什么GraphSAGE可以应用于大规模图,为什么是归纳的?接下来,我们将补充GraphSAGE的训练过程,以及它在这个过程中有哪些优势。首先,考虑到我们需要从初始特征开始,逐层更新embedding,那么我们如何知道需要聚合哪些点呢?套用上面说的sample的思想,我们具体来看一下算法:先看算法的第2-7行,其实就是一个sample的过程,将sample的结果保存到B。下9-15行是一个聚合过程。根据前一个样本的结果,将相应的邻居信息聚合到目标节点。细心的朋友一定发现了,sample的过程是从K到1(见第2行),aggregate的过程是从1到K(第9行)。原因很明显。采样时,我们先从整个图中选择我们要embed的节点,然后对这些节点的邻居进行采样,逐渐采样到距离较远的邻居。但是聚合时必须先从最远的邻居开始聚合,聚合只能聚合到第K层的目标节点。这就是GraphSAGE的整体思路。那么需要思考的是,这么简单的想法有什么玄机呢?GraphSAGE的精妙之处在于,为什么一开始就提出了GraphSAGE?其实最重要的是归纳学习。这两天看到有同学同时在几个讨论组讨论transductivelearning和inductivelearning。一般来说,归纳学习无疑可以对测试时新增的内容进行推理。因此,GraphSAGE的优势之一就是经过训练后,可以对图网络中新增的节点进行推理,这在实际场景应用中非常重要。另一方面,在图网络的应用中,数据集往往非常庞大,因此minibatch的能力非常重要。但是因为GraphSAGE的思想,我们只需要聚合我们采样的数据,而不需要考虑其他节点。每批可以是一批样品结果的组合。再考虑聚合函数的部分。在这里的训练结果中,聚合函数起到了非常重要的作用。聚合函数的选取有两个条件:一是必须可导,因为训练目标的聚合函数参数需要反向传递;第二,它是对称的,这里的对称是指对输入不敏感,因为我们在聚合的时候,图中的节点关系是没有顺序特征的。因此,在作者的原文中,选择了Mean和maxpooling等聚合器。虽然作者也使用了LSTM,但是在输入之前节点会被打乱。也就是说,LSTM无法从序列顺序中学习。什么知识。此外,论文中还有一个小细节。我第一次看论文的时候没有仔细看。经朋友指点后才知道是真的。先贴一下原文:这里是算法1中的第4行和第5行,也就是我们前面说的。给定算法的第11和12行。也就是说,文中作者提到的GraphSAGE-GCN其实就是用上面的聚合函数来代替其他方法中先聚合再concat的操作,作者指出这种方法是局部谱的线性逼近卷积,所以称之为GCNAggregator。让我们做一些善后工作。最后再简单补充一些通俗的、比较简单的东西。GraphSAGE通常用于什么?首先,作者提出它既可以用于无监督学习,也可以用于监督学习。通过监督学习,我们可以直接将最终预测的损失函数作为目标,反向传播进行训练。无监督学习呢?其实无论是哪种用途,我们都需要关注图本身。我们还是主要用它来完成嵌入操作。即得到一个节点的embedding后,得到一个更有效的特征向量。那么在做无监督的时候,怎么知道它的embedding结果是对还是错呢?作者选择了一个通俗易懂的思路,就是邻里关系。默认情况下,当两个节点之间的距离很近时,它们的嵌入结果会比较相似。如果距离远了,embedding的结果自然应该大不相同。这样,下面的损失函数就很容易理解了:z_v表示是目标节点u的邻居,而v_n表示不是,P_n(v)是负样本的分布,Q是个数负样本。那么现在剩下的唯一问题就是如何定义邻居?作者选择了一个很简单的思路:直接用DeepWalk进行随机游走,步长为5,测试50次,所有游走都是neighbors。我们不会显示实验结果的摘要。其实我们可以看到作者在很多地方都使用了一些baseline的思路。您可以更换和调整相应的地方以满足您的业务需求。稍后,我们将继续分享一些关于GNN和嵌入的经典和鼓舞人心的论文。欢迎继续关注~~~