11.6 GNN
11.6.1 原理
图的结构
图由边与节点构成,即\(G=\{V,E\}\),有时还会附带全局信息\(\text{Master Node(U)}\)。无论是节点、边还是全局信息,都可通过向量来存储数据。而对于图的连接信息,则可通过邻接列表(二维列表,每个子列表代表谁与谁连接)来存储。
信息聚合
边、节点亦或是全局信息都可通过信息聚合的方式从邻居(三者都可以)处获取信息。可通过求和、求平均、取最大值的形式完成信息聚合。
sum、mean、max的操作没有显著差异
输入与输出
图神经网络的输入与输出都是图,每个GNN层都是一次信息聚合,从而完成节点、边或全局信息的更新。堆叠层数越多就是让元素能够逐步整合更大范围的图结构信息。
11.6.2 图神经网络的类型
GCN(图卷积神经网络)
GCN以节点为研究单位,根据连接关系从邻居节点处聚合特征。
GAT(图注意力网络)
引入注意力机制,使得GCN能够根据注意力得分对邻居特征进行加权。
ST-GNN(时空图神经网络)
利用图神经网络从空间角度建模,也就是说可以用图神经网络对具有网络结构的数据进行特征提取。之后可将提取后的特征代入到时序模型,例如LSTM、GRU等。通过同时捕捉节点间拓扑依赖和时间动态变化,实现对时空关联数据的精准预测。
11.6.3 示例
图神经网络建模可通过torch_geometric实现。基本建模过程就是定义图数据结构(确定每个节点的特征、构建边关系),之后再定义图神经网络模型即可。
注意图神经网络的视角是空间视角,抓住数据中的结构关系即可
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
# ===============================
# 1️⃣ 定义图结构
# ===============================
# 图的边 (source, target),采用 COO 格式
# 例如:0↔1, 1↔2, 2↔3, 3↔4, 4↔0
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 1, 2, 3, 4, 0], # source
[1, 2, 3, 4, 0, 0, 1, 2, 3, 4] # target
], dtype=torch.long)
# 每个节点的特征(这里每个节点3维)
x = torch.randn((5, 3))
# 如果是节点分类任务,可加上节点标签
y = torch.tensor([0, 1, 0, 1, 0], dtype=torch.long)
# 构建图数据对象
data = Data(x=x, edge_index=edge_index, y=y)
print("图信息:")
print(data)
print("节点特征形状:", data.x.shape)
print("边数量:", data.edge_index.shape[1])
# ===============================
# 2️⃣ 定义 GCN 模型
# ===============================
class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GCN, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
# 第一层:图卷积 + ReLU
x = self.conv1(x, edge_index)
x = F.relu(x)
# 第二层:图卷积 + Softmax 输出(用于分类)
x = self.conv2(x, edge_index)
return x
# ===============================
# 3️⃣ 实例化模型并前向传播
# ===============================
model = GCN(in_channels=3, hidden_channels=4, out_channels=2)
out = model(data.x, data.edge_index)
print("\n输出特征形状:", out.shape)
print("输出节点嵌入:\n", out)
# 若是分类任务:
pred = out.argmax(dim=1)
print("\n节点类别预测:", pred)