网易首页 > 网易号 > 正文 申请入驻

深入解析图神经网络注意力机制:数学原理与可视化实现

0
分享至

在图神经网络(Graph Neural Networks, GNNs)的发展历程中,注意力机制扮演着至关重要的角色。通过赋予模型关注图中最相关节点和连接的能力,注意力机制显著提升了GNN在节点分类、链接预测和图分类等任务上的性能。尽管这一机制的重要性不言而喻,但其内部工作原理对许多研究者和工程师而言仍是一个"黑盒"。

本文旨在通过可视化方法和数学推导,揭示图神经网络自注意力层的内部运作机制。我们将采用"位置-转移图"的概念框架,结合NumPy编程实现,一步步拆解自注意力层的计算过程,使读者能够直观理解注意力权重是如何生成并应用于图结构数据的。

通过将复杂的数学表达式转化为易于理解的代码块和可视化图形,本文不仅适合已经熟悉图神经网络的研究人员,也为刚开始接触这一领域的学习者提供了一个清晰的学习路径。

本文详细解析了图神经网络自注意力层的可视化方法及其数学原理,通过代码实现展示其内部工作机制。

图神经网络自注意力层的数学表示

在采用自注意力机制的图神经网络中,一个典型层的计算可以通过以下张量乘法表示:

其中各元素定义如下:

包含自循环的邻接矩阵的转置

注意力张量

节点特征矩阵

常规(非注意力)权重张量的转置

"自注意力"机制的核心在于注意力张量实际上是由方程中其他元素通过线性函数与非线性函数组合生成的。这一概念可能较为抽象,但我们可以通过编程实现来展示这种组合关系,并从代码中推导出直观的图形表示。

选择NumPy实现而非解析PyTorch Geometric

我们选择使用NumPy的原因在于:

PyG的实际代码包含大量计算细节,且设计目标是扩展基础MessagePassing模块,这使得理解张量元素间的关系变得复杂。例如,GATv2Conv模块处理了以下复杂性:

  • 参数重置
  • forward()方法的多种变体
  • SparseTensors的特殊处理

而基本的MessagePassing模块则考虑了更多复杂因素,包括钩子、Jinja文本渲染、可解释性、推理分解、张量大小不匹配异常、"提升"和"收集"的子任务以及分解层等。

因此使用NumPy构建一个简洁明了的例子能够更有效地帮助我们理解注意力张量是如何从方程的其他元素构建而来的。

图注意力层的NumPy实现

为了绘制方程的位置-转移图,我们将Labonne的代码重构为四个类,这四个类对应于本文顶部图中的四个胶囊(GAL1GAL4)。

采用面向对象的方法使得我们可以通过构造函数(init方法)区分中间结果和在整个位置-转移图中四个类/胶囊间共享的结果。共享结果通过self.x = y赋值保存为实例数据成员。

为便于理解,下面是一个四节点图的示例:

我们假设每个节点都与自身连接。图中展示了入站和出站弧而非无向边,因为入站-出站关系在代码中被显式表示。

为简化起见,我们假设特征和权重初始化均在(-1, 1)范围内。

以下是GAL1的代码实现:

import numpy as np
np.random.seed(0)
class GAL1:
num_nodes = 4
num_features = 4
num_hidden_dimensions = 2 # We just choose this arbitrarily // 我们任意选择这个值
X = np.random.uniform(-1, 1, (num_nodes, num_features))
print('X\n', X, '\n')
def __init__(self):
W = np.random.uniform(-1, 1, (GAL1.num_hidden_dimensions, GAL1.num_nodes))
print('W\n', W, '\n')
self.XatWT = GAL1.X @ W.T
print('XatWT\n', self.XatWT, '\n')

执行该代码会产生以下输出:

X
[[ 0.09762701 0.43037873 0.20552675 0.08976637]
[-0.1526904 0.29178823 -0.12482558 0.783546 ]
[ 0.92732552 -0.23311696 0.58345008 0.05778984]
[ 0.13608912 0.85119328 -0.85792788 -0.8257414 ]]
W
[[-0.95956321 0.66523969 0.5563135 0.7400243 ]
[ 0.95723668 0.59831713 -0.07704128 0.56105835]]
XatWT
[[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]]

在这一阶段,我们初始化了节点特征矩阵X和标准权重矩阵W。在实际训练场景中,X来自图结构,而W则源自初始化或前一轮训练。这在位置-转移图上表示为标记为"Graph"和"PyTorch Geo"的"云"位置。

GAL1的主要保留数据成员是self.XatWT,即我们方程的右侧部分("at"表示矩阵乘法的"@"中缀符号)。在后续代码中,这个中间结果将与邻接矩阵结合,形成注意力张量。

GAL2的代码实现如下:

class GAL2:
A = np.array([
[1, 1, 1, 1],
[1, 1, 0, 0],
[1, 0, 1, 1],
[1, 0, 1, 1]
])
def __init__(self, gal1: GAL1):
print('A\n', GAL2.A, '\n')
u = np.asarray(GAL2.A > 0)
print('u\n', u, '\n')
self.connections = u.nonzero()
print('connections\n', self.connections, '\n')
XatWTc0 = gal1.XatWT[self.connections[0]]
print('XatWTc0\n', XatWTc0, '\n')
XatWTc1 = gal1.XatWT[self.connections[1]]
print('XatWTc1\n', XatWTc1, '\n')
self.XatWT_concat = np.concatenate([XatWTc0, XatWTc1], axis=1)
print('XatWT_concat\n', self.XatWT_concat, '\n')
def reshape(self, e: np.ndarray) -> np.ndarray:
E = np.zeros(GAL2.A.shape)
E[self.connections[0], self.connections[1]] = e[0]
return E

邻接矩阵A由图的结构固定。connections计算的结果如下:

A
[[1 1 1 1]
[1 1 0 0]
[1 0 1 1]
[1 0 1 1]]
u
[[ True True True True]
[ True True False False]
[ True False True True]
[ True False True True]]
connections
(array([0, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3]), array([0, 1, 2, 3, 0, 1, 0, 2, 3, 0, 2, 3]))

我们选择的节点标签与邻接矩阵中的索引对应。第一个connections数组表示具有到节点j的出站连接的节点索引i

例如:

  • 节点0出现四次(出站连接到所有节点包括自身)
  • 节点1仅出现两次(出站连接到节点0和自身)
  • 节点2和节点3各出现三次(出站连接到节点0、彼此和自身)

第二个connections数组包含相同的值,但按入站顺序排列,这是因为该图实际上是非定向的。

使用connections数组作为gal1.XatWT的索引,产生以下输出:

XatWTc0
[[ 0.37339233 0.38548525]
[ 0.37339233 0.38548525]
[ 0.37339233 0.38548525]
[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[ 0.85102612 0.47765279]
[-0.67755906 0.73566587]
[-0.67755906 0.73566587]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]
[-0.65268413 0.24235977]
[-0.65268413 0.24235977]]
XatWTc1
[[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]
[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[ 0.37339233 0.38548525]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]
[ 0.37339233 0.38548525]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]]

此处,我们的十二元素入站和出站connections索引数组分别被转换为gal1.XatWT元素的十二元素数组。

将入站和出站数组连接,得到结果:

XatWT_concat
[[ 0.37339233 0.38548525 0.37339233 0.38548525]
[ 0.37339233 0.38548525 0.85102612 0.47765279]
[ 0.37339233 0.38548525 -0.67755906 0.73566587]
[ 0.37339233 0.38548525 -0.65268413 0.24235977]
[ 0.85102612 0.47765279 0.37339233 0.38548525]
[ 0.85102612 0.47765279 0.85102612 0.47765279]
[-0.67755906 0.73566587 0.37339233 0.38548525]
[-0.67755906 0.73566587 -0.67755906 0.73566587]
[-0.67755906 0.73566587 -0.65268413 0.24235977]
[-0.65268413 0.24235977 0.37339233 0.38548525]
[-0.65268413 0.24235977 -0.67755906 0.73566587]
[-0.65268413 0.24235977 -0.65268413 0.24235977]]

ndarrayconnections被赋值给self,但并非为了在GAL2外部使用(因此在图中用虚线椭圆表示)。相反,我们在reshape方法中使用connectionsreshape方法通过创建一个与A形状相同的零矩阵来生成ndarrayE,然后使用connections[0]作为E的行索引,connections[1]作为E的列索引,从输入ndarraye[0]分配值。此方法将被GAL3调用。

显然,Econnectionse的排序应具有相同数量的元素。E的某些元素将保持未分配状态(零值),即那些对应于图中缺少入站或出站弧的节点对的元素。

除了GAL2.A之外,数组XatWT_concat也将在后续计算中使用,因此被赋值给self

GAL3的代码实现如下:

class GAL3:
@staticmethod
def leaky_relu(x, alpha=0.2) -> np.ndarray:
return np.maximum(alpha * x, x)
@staticmethod
def softmax2D(x, axis) -> np.ndarray:
e = np.exp(x - np.expand_dims(np.max(x, axis=axis), axis))
sum_ = np.expand_dims(np.sum(e, axis=axis), axis)
return e / sum_
def __init__(self, gal2: GAL2):
W_att = np.random.uniform(-1, 1, (1, GAL1.num_nodes))
print('W_att\n', W_att, '\n')
a = W_att @ gal2.XatWT_concat.T
print('a\n', a, '\n')
e = GAL3.leaky_relu(a)
print('e\n', e, '\n')
E = gal2.reshape(e)
print('E\n', E, '\n')
W_alpha = GAL3.softmax2D(E, 1)
print('W_alpha\n', W_alpha, '\n')
self.left = gal2.A.T @ W_alpha
print('left\n', self.left, '\n')

GAL3是我们引入非线性(leaky_relu)和归一化(softmax2D)操作的类。GAL3最终将生成原始方程的整个左侧,仅剩右侧gal1.XatWT未处理。GAL3的唯一"输出"是self.left

以下是GAL3中的前四个计算步骤:

  • W_att:初始化或来自前一轮训练
  • aW_attgal2.XatWT_concat的矩阵乘法
  • e:对a应用leaky_relu函数
  • E:调用gal2.reshape方法,传入e作为输入

这四个计算的结果如下:

W_att
[[-0.76345115 0.27984204 -0.71329343 0.88933783]]
a
[[-0.1007035 -0.35942847 0.96036209 0.50390318 -0.43956122 -0.69828618
0.79964181 1.8607074 1.40424849 0.64260322 1.70366881 1.2472099 ]]
e
[[-0.0201407 -0.07188569 0.96036209 0.50390318 -0.08791224 -0.13965724
0.79964181 1.8607074 1.40424849 0.64260322 1.70366881 1.2472099 ]]
E
[[-0.0201407 -0.07188569 0.96036209 0.50390318]
[-0.08791224 -0.13965724 0. 0. ]
[ 0.79964181 0. 1.8607074 1.40424849]
[ 0.64260322 0. 1.70366881 1.2472099 ]]

GAL3中的最后两个计算步骤:

  • W_alpha:对E应用softmax函数
  • self.leftgal2.A.TW_alpha的矩阵乘法

结果如下:

W_alpha
[[0.15862414 0.15062488 0.42285965 0.26789133]
[0.24193418 0.22973368 0.26416607 0.26416607]
[0.16208847 0.07285714 0.46834625 0.29670814]
[0.16010498 0.08420266 0.46261506 0.2930773 ]]
left
[[0.72275177 0.53741836 1.61798703 1.12184284]
[0.40055832 0.38035856 0.68702572 0.5320574 ]
[0.48081759 0.30768468 1.35382096 0.85767677]
[0.48081759 0.30768468 1.35382096 0.85767677]]

GAL3的唯一"输出"是left,因此它被赋值给self

至此,我们已经计算出原始方程的左侧和右侧(gal1.XatWT)。

GAL4的代码实现及主函数如下:

class GAL4:
def __init__(self, gal1: GAL1, gal3: GAL3):
self.H = gal3.left @ gal1.XatWT
print('H\n', self.H, '\n')
if __name__ == '__main__':
gal_1 = GAL1()
gal_2 = GAL2(gal_1)
gal_3 = GAL3(gal_2)
gal_4 = GAL4(gal_1, gal_3)

最终结果H为:

H
[[-1.10126376 1.99749693]
[-0.33950544 0.97045933]
[-1.03570438 1.53614075]
[-1.03570438 1.53614075]]

在这里,我们将原始方程的左侧和右侧进行矩阵乘法运算,得到最终结果。

图注意力层的结构分析

从文章开头的图和上面"main"中的代码可以看出,每个GALx仅依赖于前一个GAL(x-1),除了GAL4,它同时依赖于GAL1GAL3。通过对代码进行分类和封装,我们使其结构更加清晰,从而更易于理解。

该图由位置(椭圆)和转移(矩形)组成,因此被称为位置-转移图。在本文中,我们仅针对GAL特定实现的位置-转移图进行直观分析。有关位置-转移图的更详细信息,请参考我之前的文章(参考文献[PT-GNN-TD])中的"位置-转移图基础"部分。

下面我们将详细分析GAL位置-转移图的各个组成部分。

GAL1结构相对简单,仅执行一次矩阵乘法运算。但其结果是原始方程的整个右侧,也是GAL2GAL4的主要非邻接相关输入。

将这两个组件合并分析是因为它们之间的连接较为紧密。GAL3利用了GAL2的值AXatWT_concat,以及GAL2的方法reshape。我们通过标记来自输入引用gal2的弧线来突出每个值或方法的使用位置。

同样,GAL2connections使用虚线表示,因为它仅在公开方法reshape中使用。

GAL2专注于矩阵操作,是邻接矩阵A"注入"到原始方程的关键点。因此,GAL2是以图结构为中心的组件。

GAL3同样执行矩阵操作,但其核心功能是应用非线性函数(leaky_relu)和归一化操作(softmax)。注意力权重矩阵W_att的引入对GAL3的功能也至关重要。GAL3是以注意力机制为中心的组件。

GAL1类似,GAL4的结构也相对简单,仅执行一次矩阵乘法。它将方程的左侧gal3.left与右侧gal1.XatWT结合。GAL4是唯一一个接收来自多个组件输入的类,因此它扮演着"混合器"的角色,在"串联"和"并联"模式下连接节点特征、邻接关系和注意力机制。

核心代码

以下是实际PyG库中GATv2Conv的核心代码,涵盖了我们使用NumPy模拟的大部分功能:

def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
index: Tensor, ptr: OptTensor,
dim_size: Optional[int]) -> Tensor:
x = x_i + x_j
# some conditional edge code removed... // 删除了一些条件边缘代码...
x = F.leaky_relu(x, self.negative_slope)
alpha = (x * self.att).sum(dim=-1)
alpha = softmax(alpha, index, ptr, dim_size)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
return alpha
def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
return x_j * alpha.unsqueeze(-1)

忽略MessagePassing的部分复杂性,我们可以看到实际的PyG代码与我们的NumPy实现在核心逻辑上非常相似。

总结

通过本文的分析,我们已经深入剖析了图神经网络自注意力机制的内部工作原理。从数学表达式到代码实现再到可视化图形,我们提供了一个全方位的视角来理解注意力权重如何在图结构数据中生成和应用。

通过位置-转移图的概念框架,我们不仅展示了计算流程,还揭示了各组件之间的依赖关系,为图神经网络的可解释性研究提供了新的思路。

https://avoid.overfit.cn/post/1b68891a54a543da8d4f72fb2491d7c8

作者:John Baumgarten

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
毛主席后代现状:无心进入政坛,大多从商,从不言是伟人后代

毛主席后代现状:无心进入政坛,大多从商,从不言是伟人后代

瞻史
2026-05-04 14:56:22
杨尚昆晚年回忆道:党内那么多人,山头也多,只有毛主席能拢得住

杨尚昆晚年回忆道:党内那么多人,山头也多,只有毛主席能拢得住

史韵流转
2026-04-08 09:42:46
湖南浏阳一烟花厂发生爆炸事故

湖南浏阳一烟花厂发生爆炸事故

澎湃新闻
2026-05-04 19:18:09
醪糟再次被关注!医生发现:高血脂患者喝醪糟,不用多久4大变化

醪糟再次被关注!医生发现:高血脂患者喝醪糟,不用多久4大变化

芹姐说生活
2026-04-19 15:52:53
俄罗斯打破承诺,一艘满载油轮送到日本港口,高市终于松了一口气

俄罗斯打破承诺,一艘满载油轮送到日本港口,高市终于松了一口气

花小猫的美食日常
2026-05-04 19:38:45
火箭拒新控卫原因曝光:自认夺冠无望,再忍1年,下赛季卷土重来

火箭拒新控卫原因曝光:自认夺冠无望,再忍1年,下赛季卷土重来

熊哥爱篮球
2026-05-04 10:55:27
允许一切,自在随心

允许一切,自在随心

青苹果sht
2026-05-04 05:03:52
退休后才发现,大多数老人不旅游,去旅游的是这几类人

退休后才发现,大多数老人不旅游,去旅游的是这几类人

十点读书
2026-05-03 18:38:35
男人搞定50岁女人最好方法,喂饱了她两个需求,她就会主动依你

男人搞定50岁女人最好方法,喂饱了她两个需求,她就会主动依你

心理观察局
2026-05-04 08:20:08
感谢特朗普!我国投入巨资建设的瓜达尔港,终于等来了大订单

感谢特朗普!我国投入巨资建设的瓜达尔港,终于等来了大订单

南生今世说
2026-05-02 17:56:11
中央明确!高校薪酬制度改革,来了

中央明确!高校薪酬制度改革,来了

麦可思研究
2026-05-04 17:04:28
牛!《消失的人》逆袭冠军,预测暴涨到5.6亿,陈思诚有对手了

牛!《消失的人》逆袭冠军,预测暴涨到5.6亿,陈思诚有对手了

得得电影
2026-05-03 13:28:45
张柏芝大儿子终于“长开”了!穿西装比谢霆锋还帅,网友:像爷爷

张柏芝大儿子终于“长开”了!穿西装比谢霆锋还帅,网友:像爷爷

木子爱娱乐大号
2026-01-07 21:47:13
莫雷加德:中国队不必太过担忧,这周还有很多硬仗要打

莫雷加德:中国队不必太过担忧,这周还有很多硬仗要打

懂球帝
2026-05-04 09:44:16
中原大地明星谱——30名河南籍电影演员名录

中原大地明星谱——30名河南籍电影演员名录

陈意小可爱
2026-05-03 16:02:33
“高净值家庭”标准出炉,全中国共有512.8万户,你家达标了吗?

“高净值家庭”标准出炉,全中国共有512.8万户,你家达标了吗?

毒sir财经
2026-04-26 21:11:44
女子给男主播刷4万礼物,私下见面想亲热被拒绝,气得要求退钱

女子给男主播刷4万礼物,私下见面想亲热被拒绝,气得要求退钱

新游戏大妹子
2026-04-27 10:57:55
18亿+凌晨档+无国足!FIFA想割韭菜?亚洲多国说不,中国不惯着

18亿+凌晨档+无国足!FIFA想割韭菜?亚洲多国说不,中国不惯着

曹老师评球
2026-05-04 18:36:46
孙杨反击!瓜越吃越大,他和张豆豆怎么有勇气上真人秀的

孙杨反击!瓜越吃越大,他和张豆豆怎么有勇气上真人秀的

东方不败然多多
2026-05-04 16:52:08
3-1击败法国队!国羽男队卫冕汤姆斯杯,队史第12次夺冠

3-1击败法国队!国羽男队卫冕汤姆斯杯,队史第12次夺冠

全景体育V
2026-05-04 05:24:07
2026-05-04 21:07:00
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1986文章数 1461关注度
往期回顾 全部

科技要闻

OpenAI“复活”了QQ宠物,网友直接玩疯

头条要闻

英媒:伊朗革命卫队要求特朗普“二选一”

头条要闻

英媒:伊朗革命卫队要求特朗普“二选一”

体育要闻

骑士破猛龙:加雷特·阿伦的活力

娱乐要闻

张敬轩还是站上了英皇25周年舞台

财经要闻

魔幻的韩国股市,父母给婴儿开户买股票

汽车要闻

同比大涨190% 方程豹4月销量29138台

态度原创

房产
健康
游戏
公开课
军事航空

房产要闻

五一楼市彻底明牌!塔尖人群都在重仓凯旋新世界

干细胞治烧烫伤面临这些“瓶颈”

PS5破解最新工具即将发布!开发者喊话

公开课

李玫瑾:为什么性格比能力更重要?

军事要闻

特朗普回绝伊朗新方案

无障碍浏览 进入关怀版