作者 | Boris Knyazev
编译 | 栗峰
来源 | 深度学习这件小事
最近,Graph Neural Network(GNN)在很多领域日益普及,包括社交网络、知识图谱、推荐系统甚至于生命科学。GNN在对节点关系建模方面表现十分突出,使得相关的研究领域取得了一定突破。本文将就“为什么图有用”、“为什么很难在图上定义卷积”、“是什么使神经网络成为了图神经网络”这些问题进行讨论。
首先,让我们简单回顾一下什么是图?图 G 是由有向或无向边连接的一组节点(顶点)。节点和边通常是由专家依据知识经验或是对问题的直觉进行设置的。因此,它可以是分子中的原子,社交网络中的用户,交通系统中的城市,团队运动中的运动员,大脑中的神经元,动态物理系统中的交互对象,图像中的像素、图像边界框或是图像分割掩模。
换句话说,在很多情况下,实际上是由你来决定图的节点和边。
这是一种很灵活的数据结构,它囊括了很多其他的数据结构。例如,如果没有边,那么它就会变成一个集合;如果只有“垂直”边,其中任意两个节点都相连,那么我们就有了一个数据树。当然正如我们接下来将要讨论的,这种灵活性有利也有弊。
两个分别有5和6个节点的无向图,节点的顺序是任意的。
一.为什么图有用
在计算机视觉(CV)和机器学习(ML)的背景下,研究图以及学习当中的模型至少可以给我们带来以下四个好处:
1. 我们可以有机会解决以前解决不了的难题,例如:癌症药物发现(Veselkov等人,Nature,2019年);更好地理解人脑结构(Diez&Sepulre,Nature,2019);能源和环境友好材料的发现(Xie等人,Nature Communications,2019年)。
2. 在大多数CV/ML应用程序中,你可能曾经把它们看成是另一种数据结构,但数据实际上可以被看作是图。将数据表示成图可以提供很大的灵活性,并能在你处理问题的时候为你提供截然不同的视角。例如,你可以直接从“超像素”中学习,而不必从图像像素中学习,在Liang等人2016年在ECCV发表的论文,以及我们即将发表的BMVC论文都可以找到依据。图还允许你对数据施加关系归纳偏差,能使你在处理问题时具备一些先验知识。例如,如果你想对人体的姿势进行推理,你的关系偏差就可以是人体骨架关节的图 (Yen等人,AAAI,2018);或者如果你想对视频进行推理,你的关系偏差可以是移动边框的图 (Wang&Gupta,ECCV,2018)。另一个例子是可以将面部标志表示为图 (Antonakos等人,CVPR,2015),以便对面部特征和身份进行识别。
3. 神经网络本身可以看作是一个图,其中节点是神经元,边是权重,或者节点是层,边表示向前/向后传递的流程(在这种情况下,我们讨论的是在TensorFlow中使用计算图、PyTorch和其他DL框架)。应用程序可以是计算图的优化、神经结构搜索和训练行为分析等。
4. 最后一点,你可以更高效的解决很多问题,在这些问题中数据可以更自然地表示成图。 这包括但又不限于分子和社会网络分类(Knyazev等人,NeurIPS-W,2018),3D Mesh的分类及对应(Fey等人,CVPR 2018),动态交互对象的建模行为(Kipf等人,ICML,2018),视景图建模(详见即将到来的ICCV研讨会)和问答(Narasimhan, NeurIPS,2018),程序综合(Allamanis等人,ICLR,2018),不同的强化学习任务(Bapst等人,ICML,2019)和许多其他问题。
我之前的研究是关于人脸识别和分析面部情绪,所以我很欣赏下面这个图。
来自(Antonakos等人,CVPR,2015)的图,将脸部标志提取出来。这是一种有趣的方法,但在很多情况下它并不能全面的表示出一个人的面部特征,因此可以通过卷积网络从面部纹理中出捕捉到更多信息。相反,与2D标志相比,基于人脸的3D网格的推理看起来更合理(Ranjan等人,ECCV,2018)。
二.为什么很难在图上定义卷积
要回答这个问题,首先要理清一般使用卷积的动机,然后用图术语描述“图像上的卷积”,这将使“图卷积”的过渡更加流畅。
1. 为什么卷积有用
我们应该理解为什么我们要注意到卷积,以及为什么我们要用它来处理图?与完全连接的神经网络(NNS或MLP)相比,卷积网络(CNN或Convnet)具有一定的优势。
首先,Convnet利用图像中的一种自然先验,在Bronstein等人在2016年发布的论文中有了更正式的描述,例如:
(1)平移不变性,如果我们将上面图像上的汽车平移到左/右/上/下,我们仍然能够认识到它是一辆汽车。这是通过在所有位置共享滤波器来实现的,也就是应用卷积。
(2)局域性,附近的像素是密切相关的,通常表示一些语义概念,如车轮或车窗。这是通过使用相对较大的滤波器来实现的,它可以捕捉到局部空间邻域中的图像特征。
(3)组合性(或层次结构),图像中较大的区域通常都包含了较小区域的语义父级。例如,汽车是车门、车窗、车轮、驾驶员等的母体,而司机则是头部、手臂等的母体。这是通过叠加卷积层和应用池进行的隐含表达。
其次,卷积层中可训练参数(即滤波器)的数目并不取决于输入维数,因此在技术上我们可以在28×28和512×512图像上训练完全相同的模型。换句话说,模型是参数化的。
理想情况下,我们的目标是开发一个像图神经网络一样灵活的模型,它可以消化和学习任何数据,但同时我们希望通过打开或关闭某些先验来控制(正则)这种灵活性的元素。
所有这些良好的特性使得ConvNet不太容易过度拟合(训练集的高精度和验证/测试集的低精度),在不同的视觉任务中更精确,并且易于扩展到大型图像和数据集。因此,当我们想要解决输入数据是图结构的重要任务时,将这些属性全部转移到图神经网络(GNN)上,以规范它们的灵活性并使它们具有可扩展性。理想情况下,我们的目标是开发一个像GNN一样灵活的模型,可以消化和学习任何数据,但同时我们希望通过打开或关闭某些先验来控制(正则化)这种灵活性的元素。这可以在很多创新的方向上进行研究。然而,想要控制它并且达到一种平衡状态还是很有挑战性的。
2. 根据图进行图像卷积
我们先来考虑一下具有N个节点的无向图G,边E表示节点之间的无向连接。节点和边通常是由你自己设定的。关于图像,我们认为节点应该是像素或超像素(一组形状怪异的像素),边是它们之间的空间距离。例如,左下方的MNIST图像通常表示为28×28维矩阵。我们也可以用一组N=28*28=784像素来表示它。因此,我们的图G应该有N=784个节点,而对于位置较近的像素,边会有一个较大的值(下图中较厚的边),对于较远的像素,则相应的有较小的值(较薄的边)。
左侧是MNIST数据集的图像,右侧是图的示范。右侧较暗和较大的节点对应较高的像素强度。右图的灵感来自图6(Fey等人,CVPR,2018)
当我们在图像上训练神经网络或Convnet时,潜意识里我们在图上就已经将图像定义成了一个规则的2D网格,如下图所示。这个网格对于所有的训练和测试图像是相同且规则的,也就是说,网格的所有像素都以完全相同的方式在所有图像之间连接(即具有相同的连接数、边缘长度等),所以这个规则的网格图没办法帮我们从一幅图像中分辨出另一幅图像。下面我可视化了一些2D和3D规则网格,其中节点的顺序是彩色编码的。顺便说一句,我是在Python代码中使用了NetworkX来实现的,例如G=networkx.Grid_Graph([4,4])。
规则的2D和3D网格的例子。图像在2D网格上的表现,视频在3D网格上的表现。
考虑到这是个4×4的规则网格,我们可以简单地看看2D卷积是如何工作的,就可以理解为什么很难将算子转换成图。规则网格上的滤波器具有相同的节点级,但现代卷积网络通常有小滤波器,例如下面的例子中的3×3。这个滤波器有9个值:W₁,W₂,…,W₉,这是由于我们在训练过程中使用了backprop工具进行更新以尽量减少损失和解决下游任务的问题。在下面的例子中,我们只是受到启发将滤波器初始化成了边缘检测器(请参阅这里的其他可能的滤波器):
在规则2D网格上的3×3滤波器的例子,左侧是任意权值w,右侧是边缘检测器。
当我们进行卷积的时候,要从两个方向滑动这个滤波器:向右和向下,可以从底角开始,重要的是要滑过所有可能的位置。在每个位置,计算网格上值之间的点积(表示为X)和滤波器的值W:X₁W₁+X₂W₂+…+X₉W₉,并将结果存储在输出图像中。在我们的可视化过程中,改变节点在滑动过程中的颜色,以匹配网格中节点的颜色。在常规网格中,我们始终将滤波器的节点与网格的节点相匹配。但这并不适用于图,我将在下面进行解释。
规则网格上2D卷积的2个步骤。如果我们不应用填充的话,一共会有4个步骤,因此结果是2×2图像。为了使得到的图像更大,我们需要应用填充。在这里,请参阅关于深度学习中卷积的全面指南。
上面使用的点积就是所谓的“聚合算子”之一。广义上来讲,聚合算子的目标是将数据归纳成简单的形式。在上面的例子中,点积将一个3×3矩阵概括为单个值。另一个例子是在Convnet中进行数据汇总。请记住,诸如最大值或和总计值的位置是不变的,也就是说,即使随机地移动该区域内的所有像素,它们还是会在空间区域内汇总成相同的值。为了说明这一点,点积不是置换不变的,因为在一般情况下:X₁W₁+X₂W₂≠X₂W₁+X₁W₂。
现在,让我们使用MNIST图像,来定义规则网格、滤波器和卷积。考虑到我们的图术语,这个规则的28×28网格将是我们的图G,因此这个网格中的每个单元都是一个节点,节点特征是一个实际的图像X,也就是说每个节点只有一个特征,像素强度从0(黑色)到1(白色)。
规则28×28网格(左)和该网格上的图像(右)。
接下来,我们要定义滤波器,并让它成为具有 (几乎)任意参数的著名Gabor滤波器。一旦我们有了图像和滤波器,我们就可以通过在图像上滑动滤波器 (在我们的例子中是数字7),并在每一步之后将点积的结果放到输出矩阵中来执行卷积。
一个28×28滤波器(左)和该滤波器与数字7图像的2D卷积结果。(右)
正如我前面提到的,当你尝试将卷积应用到图时,就会遇到很多问题。
节点是一个集合,该集合的任何排列都不会改变它。因此,人们应用的聚合算子应该是置换不变的.
正如我前面提到的,用于计算每一步卷积的点积对顺序是敏感的。这种灵敏度使我们能够学习与Gabor滤波器相似的边缘检测器,这对于捕获图像特征非常重要。问题在于,在图中没有明确定义的节点顺序,除非你学会给它们排序,或者想出其他一些启发式的方法,能在图与图之间形成规范的顺序。简而言之,节点是一个集合,该集合的任何排列都不会改变它。因此,人们应用的聚合算子应该是置换不变的.最受欢迎的选择是平均值(GCN、Kipf&Wling、ICLR,2017)和对所有相邻数值求和(GIN、XU等人、ICLR,2019),也就是求和或均值池,然后由可训练向量W进行推测,其他聚合器参见Hamilton等人,NIPS, 2017。
说明节点特征X的“图卷积”,滤波器W以节点1(深蓝色)为中心。
例如,左侧的图,节点1的求和聚合器的输出为X₁=(X₁+X₂+X₃+X₄)W₁,节点2:X₂=(X₁+X₂+X₃+X₅)W₁等,即我们需要对所有节点应用此聚合器。因此,我们将得到具有相同结构的图,节点现在包含了所有邻值的功能。我们可以用同样的方法处理右边的图。
通俗地说,人们称这种平均或求和为“卷积”,因为我们也是从一个节点“滑动”到另一个节点,并在每一步中应用聚合算子。但是,重要的一点,这是一种非常特殊的卷积形式,在这里,滤波器没有方向感。下面我将展示这些滤波器的外观,并给出如何使它们更好的建议。
你应该知道典型的神经网络是怎么工作的,我们将C维特征X作为网络的输入。用我们正在运行的MNIST举例,X将是我们的C=784维像素特征(即“扁平”图像)。这些特征乘以我们在训练过程中更新的C×F维度权值W,使输出能更接近我们预期的结果。这个结果可以直接用于解决任务(例如,在回归的情况下),也可以进一步反馈到一些非线性(激活),如relu,或其他可微分(或更准确地说,是次微分)函数,从而形成多层网络。一般来说,l 层的输出是:
全连通层具有可学习权值W。“完全连接”是指X⁽ˡ⁺1 1⁾中的每个输出值取决于或“连接到”所有输入X⁽ˡ⁾。通常情况下,虽然也不总是这样,但我们在输出中添加了一个偏差项。
MNIST中的信号非常强,只要使用上面的公式和交叉熵损失,精准度就可以达到91%以上,而且不需要任何非线性和其他技巧(我是使用了一个略微修改过的PyTorch代码做到了这一点)。这种模型称为多项式(或多类,因为我们有10类数字)Logistic回归。
现在,如何将我们的神经网络转换成图神经网络?正如你已经知道的,GNN背后的核心思想是聚合“邻值”。在这里,重点是要理解,在很多情况下,实际上是你指定了“邻值”。
让我们先来考虑一个简单的情况,当你得到一些图。例如,这可以是5人的社交网络的一个片段(子图),节点之间的边缘表示两个人是否是朋友(或者他们中至少有一个人这样认为)。右边图中的邻接矩阵(通常表示为A)是一种以矩阵形式表示这些边色代表边缘的缺失。
图及其邻接矩阵的例子。我们在这两种情况下定义的节点顺序都是随机的,而图仍然是相同的。
现在,让我们根据像素的坐标为我们的MNIST示例创建一个邻接矩阵A(文章末尾提供了完整的代码):
这是定义视觉任务中邻接矩阵的典型方法但并非是唯一的方法(Defferrard等人,2016年;Bronstein等人,2016年)。这个邻接矩阵是我们的先验,或者说是我们的归纳偏差,我们根据经验施加到模型上,应该连接附近的像素,远程像素不应该有边缘,即使有也应该是非常薄的边缘(小值的边缘)。这是因为我们观察到,在自然图像中的邻近像素通常对应于同一个或多个经常交互的对象(我们前面提到的局部性原则),因此连接这些像素很有意义。
邻接矩阵(NxN)的所有节点对之间的距离(左)和相邻矩阵(中间) (右) 具有16个相邻像素的子图,其对应于中间的邻接矩阵。既然它是一个完整的子图,它也被称为“集团”。
所以,现在不是只有特征X,还有一些值在[0,1]范围内的奇特的矩阵A。需要注意的是,一旦我们知道输入是一个图,我们就假设在数据集中的所有其他图节点的顺序都是一致的。就图像而言,这意味着假定像素被随机调整。在实践中,想要找到节点的规范顺序是根本无法解决的。尽管对于MNIST来说,我们可以通过假定这个顺序来进行操作(因为数据最初来自一个常规网格),但它不适用于实际的图数据集。
记住,我们的特征矩阵X有行和C列。因此,就图而言,每一行对应于一个节点,C是节点特征的维度。但现在的问题是,我们不知道节点的顺序,所以我们不知道应该在哪一行中放置特定节点的特征。
如果我们直接忽略这个问题,并像以前一样直接将X提供给MLP,效果与将每个图像随机打乱像素进行重新组合形成的图像相同,令人惊讶的是,神经网络在原则上是可以拟合这样的随机数据的(Zhang等人,ICLR,2017),但是测试性能将接近随机预测。其中一个解决方案是简单地使用前面创建的邻接矩阵A,方法如下:
图神经层具有邻接矩阵A,输入或输出特征X,可学习权值W。
我们只需要确保A中的第一行对应于X的第一行中节点的特征。这里,我使用的是而不是普通的A,因为你想将A规范化,如果=A,矩阵乘法X⁽ˡ⁾将等价于邻值的求和特征,这在许多任务中都是有用的(Xu等人,ICLR,2019)。最常见的情况是,你将其规范化,使X⁽ˡ⁾具有平均邻值的特性,即=A/ΣᵢAᵢ。规范矩阵A的更好方法可在(Kipf&Wling,ICLR,2017)中找到。
以下是NN和GNN在PyTorch代码方面的比较:
这里有完整的PyTorch代码训练上面的两个模型:Pythonmnist_fc.py-model fc训练NN模型;python mnist_fc.py-模型图训练GNN模型。作为一个练习,可以尝试在模型图中随机打乱代码中的像素(不要忘记以同样的方式对A进行调整),并确保它不会影响结果。对-FC模式的模型来说会是可行的吗?
运行代码后,你可能会注意到在分类的准确性上实际上是相同的。那还有什么问题吗?图形网络不应该运行得更好吗?其实在大多数情况下,它们都是可以正常运行,但是在这个例子中出现了特殊情况,因为我们添加的X⁽ˡ⁾运算符实际上就是一个高斯滤波器:
图神经网络中滤波器的2D可视化及其对图像的影响。
我们的图神经网络被证明是等同于具有单个高斯滤波器的卷积神经网络,在训练过程中我们从不更新,然后是完全连接的层。这个滤波器基本上显示模糊或是清晰的图像,这并不是一件特别有用的事情(见上图右边)。然而,这是图神经网络的最简单的变体,尽管如此,它在图结构的数据上仍然运行得很好。为了使GNN更好地在规则图上工作,(比如图像),我们需要应用一些技巧。例如,我们可以通过使用如下可微函数来学习预测任意一对像素之间的边,而不是使用预定义的高斯滤波器:
为了使GNN更好地在规则图上工作,(比如图像),我们需要应用一些技巧。例如,我们可以通过使用如下可微函数来学习预测任意一对像素之间的边,而不是使用预定义的高斯滤波器。
这一想法类似于动态滤波器网络(Brabander等人,NIP,2016年)、边缘条件图网络(ECC、Simonovsky&Komodakis、CVPR,2017)和(Knyazev等人,NeurIPS-W,2018)。如果想用我们的代码进行尝试,只需要添加-pred_Edge标志,所以完整的指令就是python mnist_fc.py --model graph --pred_edge。下面我展示了预定义的高斯滤波器和学习滤波器的动画。你可能会注意到,我们刚刚学到的滤波器(在中间)看起来很奇怪。这是因为任务相当复杂,我们同时优化了两个模型:预测边缘的模型和预测数字类的模型。为了更好的学习滤波器(如右图所示),我们需要从BMVC论文中应用一些其他技巧,这已经超出了这个教程范畴。
以红点为中心的2D神经网络滤波器。平均(左92.2 4%),坐标学习(中91.05%),坐标学习(右92.39%)。
生成这些GIF的代码非常简单:
我还分享了一个IPython代码笔记,它用Gabor滤波器显示了图像的2D卷积(使用邻接矩阵),而不是使用循环矩阵,循环矩阵通常用于信号处理。
在本教程的下一部分中,我将详解更高级的图层,这些图层可以对图进行更好的筛选。
四、总结
图神经网络是一个非常灵活且有趣的神经网络家族,可以应用于非常复杂的数据。当然,这种灵活性也要付出一定的代价。在GNN的情况下,难以通过将这样的运算符定义为卷积来使模型正规化。但这方面的研究进展很快,相信不久会得到完善的解决,GNN将会在机器学习和计算机视觉领域得到越来越广泛的应用。
数据分析咨询请扫描二维码
在当今以数据为导向的商业环境中,数据分析师的角色变得越来越重要。无论是揭示消费者行为的趋势,还是优化企业运营的效率,数据 ...
2024-11-17在当今以数据为导向的商业环境中,数据分析师的角色变得越来越重要。无论是揭示消费者行为的趋势,还是优化企业运营的效率,数据 ...
2024-11-17金融数学是一门充满挑战和机遇的专业,它将数学、统计学和金融学的知识有机结合,旨在培养能够运用数学和统计方法解决复杂金融市 ...
2024-11-16在信息时代的浪潮中,大数据已成为推动创新的重要力量。无论是在商业、医疗、金融,还是在日常生活中,大数据扮演的角色都愈发举 ...
2024-11-16随着大数据技术的迅猛发展,数据已经成为现代商业、科技乃至生活各个方面的重要资产。大数据专业的毕业生在这一变革背景下,拥有 ...
2024-11-15随着大数据技术的迅猛发展,数据已经成为现代商业、科技乃至生活各个方面的重要资产。大数据专业的毕业生在这一变革背景下,拥有 ...
2024-11-15在快速演变的数字时代,数据分析已成为多个行业的核心驱动力。无论你是刚刚踏入数据分析领域,还是寻求进一步发展的专业人士,理 ...
2024-11-15Python作为一种通用编程语言,以其简单易学、功能强大等特点,成为众多领域的核心技术驱动者。无论是初学者还是有经验的编程人员 ...
2024-11-15在当今数据驱动的世界中,数据分析已成为许多行业的基础。无论是商业决策,产品开发,还是市场策略优化,数据分析都扮演着至关重 ...
2024-11-15数据分析作为现代商业和研究领域不可或缺的一部分,吸引了越来越多的初学者。然而,自学数据分析的过程中,初学者常常会遇到许多 ...
2024-11-15在当今的数据驱动世界中,机器学习方法在数据挖掘与分析中扮演着核心角色。这些方法通过从数据中学习模式和规律来构建模型,实现 ...
2024-11-15随着数据在各个行业的重要性日益增加,数据分析师在商业和技术领域的角色变得至关重要。其核心职责之一便是通过数据可视化,将复 ...
2024-11-15数据分析师的职责不仅仅局限于解析数据和得出结论,更在于将这些复杂的信息转换为清晰、易懂且具有影响力的沟通。良好的沟通能力 ...
2024-11-15数字化转型是企业提升竞争力和实现可持续发展的关键路径。面对快速变化的市场环境,以及技术的飞速发展,企业在数字化转型过程中 ...
2024-11-15CDA数据分析师认证:CDA认证分为三个等级:Level Ⅰ、Level Ⅱ和Level Ⅲ,每个等级的报考条件如下: Le ...
2024-11-14自学数据分析可能是一条充满挑战却又令人兴奋的道路。随着数据在现代社会中的重要性日益增长,掌握数据分析技能不仅能提升你的就 ...
2024-11-14数据分析相关职业选择 数据分析领域正在蓬勃发展,为各种专业背景的人才提供了丰富的职业机会。从初学者到有经验的专家,每个人 ...
2024-11-14数据挖掘与分析在金融行业的使用 在当今快速发展的金融行业中,数据挖掘与分析的应用愈发重要,成为驱动行业变革和提升竞争力的 ...
2024-11-14学习数据挖掘需要掌握哪些技能 数据挖掘是一个不断发展的领域,它结合了统计学、计算机科学和领域专业知识,旨在从数据中提取有 ...
2024-11-14统计学作为一门基于数据的学科,其广泛的应用领域和多样的职业选择,使得毕业生拥有丰厚的就业前景。无论是在政府还是企业,统计 ...
2024-11-14