11.6 GNN

零基础多图详解图神经网络(GNN/GCN)【论文精读】

11.6.1 原理

  1. 图的结构

    图由边与节点构成,即\(G=\{V,E\}\),有时还会附带全局信息\(\text{Master Node(U)}\)。无论是节点、边还是全局信息,都可通过向量来存储数据。而对于图的连接信息,则可通过邻接列表(二维列表,每个子列表代表谁与谁连接)来存储。

  2. 信息聚合

    边、节点亦或是全局信息都可通过信息聚合的方式从邻居(三者都可以)处获取信息。可通过求和、求平均、取最大值的形式完成信息聚合。

    sum、mean、max的操作没有显著差异

  3. 输入与输出

    图神经网络的输入与输出都是图,每个GNN层都是一次信息聚合,从而完成节点、边或全局信息的更新。堆叠层数越多就是让元素能够逐步整合更大范围的图结构信息。

11.6.2 图神经网络的类型

  1. GCN(图卷积神经网络)

    GCN以节点为研究单位,根据连接关系从邻居节点处聚合特征。

  2. GAT(图注意力网络)

    引入注意力机制,使得GCN能够根据注意力得分对邻居特征进行加权。

  3. ST-GNN(时空图神经网络)

    利用图神经网络从空间角度建模,也就是说可以用图神经网络对具有网络结构的数据进行特征提取。之后可将提取后的特征代入到时序模型,例如LSTM、GRU等。通过同时捕捉节点间拓扑依赖和时间动态变化,实现对时空关联数据的精准预测。

11.6.3 示例

图神经网络建模可通过torch_geometric实现。基本建模过程就是定义图数据结构(确定每个节点的特征、构建边关系),之后再定义图神经网络模型即可。

注意图神经网络的视角是空间视角,抓住数据中的结构关系即可

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# ===============================
# 1️⃣ 定义图结构
# ===============================

# 图的边 (source, target),采用 COO 格式
# 例如:0↔1, 1↔2, 2↔3, 3↔4, 4↔0
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 1, 2, 3, 4, 0],  # source
    [1, 2, 3, 4, 0, 0, 1, 2, 3, 4]   # target
], dtype=torch.long)

# 每个节点的特征(这里每个节点3维)
x = torch.randn((5, 3))

# 如果是节点分类任务,可加上节点标签
y = torch.tensor([0, 1, 0, 1, 0], dtype=torch.long)

# 构建图数据对象
data = Data(x=x, edge_index=edge_index, y=y)

print("图信息:")
print(data)
print("节点特征形状:", data.x.shape)
print("边数量:", data.edge_index.shape[1])

# ===============================
# 2️⃣ 定义 GCN 模型
# ===============================
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # 第一层:图卷积 + ReLU
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # 第二层:图卷积 + Softmax 输出(用于分类)
        x = self.conv2(x, edge_index)
        return x

# ===============================
# 3️⃣ 实例化模型并前向传播
# ===============================
model = GCN(in_channels=3, hidden_channels=4, out_channels=2)
out = model(data.x, data.edge_index)

print("\n输出特征形状:", out.shape)
print("输出节点嵌入:\n", out)

# 若是分类任务:
pred = out.argmax(dim=1)
print("\n节点类别预测:", pred)