深度学习入门Day9:图神经网络原理与实战全解析
本文介绍了图神经网络(GNN)的核心原理与应用实践。主要内容包括:1)图数据表示与常见图数据集;2)GCN和GAT的原理及PyTorch实现;3)三大实战任务:节点分类(Cora数据集)、图分类(使用图池化)和链接预测(图自编码器);4)工业级应用如社交网络推荐系统和大图训练技巧;5)学习总结与资源推荐。文章通过代码示例详细展示了GNN模型的构建与训练过程,并讨论了实际应用中的关键技术与优化方法,
一、开篇:从规则数据到关系数据
前八天我们处理的都是规则网格数据(如图像)或序列数据(如文本),但现实世界中更多数据以图的形式存在:社交网络、分子结构、交通系统等。今天我们将学习图神经网络(GNN),这种专门处理关系数据的强大工具,它让深度学习能够理解复杂的关联和交互。

二、上午学习:GNN核心原理
2.1 图数据结构表示
图的数学定义:
import torch
from torch_geometric.data import Data
# 构建一个简单图数据
edge_index = torch.tensor([[0, 1, 1, 2], # 源节点
[1, 0, 2, 1]], dtype=torch.long) # 目标节点
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # 节点特征
data = Data(x=x, edge_index=edge_index)
print(f"节点数: {data.num_nodes}")
print(f"边数: {data.num_edges}")
print(f"平均度数: {data.num_edges / data.num_nodes}")
常见图数据集:
数据集信息表
| 数据集 | 类型 | 节点数 | 边数 | 任务 |
|---|---|---|---|---|
| Cora | 引文网络 | 2,708 | 5,429 | 节点分类 |
| MUTAG | 分子图 | 188 | - | 图分类 |
| 社交网络 | 4,039 | 88,234 | 社区发现 |
2.2 图卷积网络(GCN)实现
GCN层数学表达:
=σ(Â
)
其中 Â = D̃^{-1/2}ÃD̃^{-1/2}, Ã = A + I
PyTorch实现:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
def __init__(self, num_features, hidden_size, num_classes):
super().__init__()
self.conv1 = GCNConv(num_features, hidden_size)
self.conv2 = GCNConv(hidden_size, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
2.3 图注意力网络(GAT)
注意力机制计算:
from torch_geometric.nn import GATConv
class GAT(nn.Module):
def __init__(self, num_features, hidden_size, num_classes, heads=8):
super().__init__()
self.conv1 = GATConv(num_features, hidden_size, heads=heads)
self.conv2 = GATConv(hidden_size*heads, num_classes, heads=1)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
GCN vs GAT对比:
| 特性 | GCN | GAT |
|---|---|---|
| 邻居权重 | 固定 | 动态学习 |
| 计算复杂度 | O( | E |
| 可解释性 | 弱 | 注意力权重 |
| 适合场景 | 同质图 | 异质图 |
三、下午实战:图数据应用
3.1 Cora节点分类实战
数据加载与训练:
from torch_geometric.datasets import Planetoid
# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# 创建模型
model = GCN(dataset.num_features, 16, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# 训练函数
def train():
model.train()
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
# 测试函数
def test():
model.eval()
logits = model(data)
pred = logits.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
accs.append(pred[mask].eq(data.y[mask]).sum().item() / mask.sum().item())
return accs
# 训练循环
for epoch in range(200):
loss = train()
train_acc, val_acc, test_acc = test()
if epoch % 50 == 0:
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
3.2 图分类任务实践
图池化实现:
from torch_geometric.nn import global_mean_pool
class GraphClassifier(nn.Module):
def __init__(self, num_features, hidden_size, num_classes):
super().__init__()
self.conv1 = GCNConv(num_features, hidden_size)
self.conv2 = GCNConv(hidden_size, hidden_size)
self.lin = nn.Linear(hidden_size, num_classes)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
# 图级池化
x = global_mean_pool(x, batch)
x = self.lin(x)
return F.log_softmax(x, dim=1)
3.3 链接预测任务
from torch_geometric.nn import GAE
# 使用图自编码器:
encoder = GCN(dataset.num_features, 32, 16) # 输出16维嵌入
model = GAE(encoder)
# 负采样训练
def train():
model.train()
optimizer.zero_grad()
z = model.encode(data.x, data.edge_index)
loss = model.recon_loss(z, data.edge_index)
loss.backward()
optimizer.step()
return loss.item()
# 评估AUC
from sklearn.metrics import roc_auc_score
def test():
model.eval()
z = model.encode(data.x, data.edge_index)
pos_score = model.decoder(z, data.edge_index)
neg_edge_index = negative_sampling(data.edge_index)
neg_score = model.decoder(z, neg_edge_index)
scores = torch.cat([pos_score, neg_score]).detach().numpy()
labels = torch.cat([torch.ones(pos_score.size(0)),
torch.zeros(neg_score.size(0))]).numpy()
return roc_auc_score(labels, scores)
四、GNN高级应用与优化
4.1 工业级应用实现
社交网络推荐系统:
class SocialGNN(nn.Module):
def __init__(self, user_features, item_features, hidden_size):
super().__init__()
self.user_encoder = GATConv(user_features, hidden_size)
self.item_encoder = GATConv(item_features, hidden_size)
self.predictor = nn.Linear(2*hidden_size, 1)
def forward(self, user_data, item_data, edges):
user_emb = self.user_encoder(user_data.x, user_data.edge_index)
item_emb = self.item_encoder(item_data.x, item_data.edge_index)
src, dst = edges
pred = self.predictor(torch.cat([user_emb[src], item_emb[dst]], dim=1))
return torch.sigmoid(pred)
4.2 大图训练技巧
邻居采样:
from torch_geometric.loader import NeighborLoader
# 创建采样加载器
train_loader = NeighborLoader(
data,
num_neighbors=[10, 5], # 两层采样
batch_size=32,
input_nodes=data.train_mask
)
# 训练循环调整
for batch in train_loader:
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
loss.backward()
optimizer.step()
五、学习总结与明日计划
5.1 今日核心成果
✅ 掌握GCN和GAT的核心原理与实现
✅ 完成节点分类、图分类、链接预测三大任务
✅ 学习图池化与邻居采样等实用技术
✅ 理解GNN在推荐系统等场景的应用
5.2 待解决问题
❓ 超大规模图的分布式训练
❓ 动态图的时间演化建模
❓ 图与文本/图像的多模态融合
5.3 明日学习重点
- 强化学习基础概念(马尔可夫决策过程)
- Q-Learning算法与Deep Q Network
- 策略梯度方法实现
- Gym环境实战
六、资源推荐与延伸阅读
1. PyG官方教程:最全面的GNN实践指南
2. 图表示学习书籍:系统性学习图神经网络
3. OGB排行榜:了解最前沿的GNN性能
4. GraphGym:图神经网络的实验管理工具
七、关键经验总结
1. 图数据预处理要点:
# 标准化节点特征
data.x = (data.x - data.x.mean(dim=0)) / data.x.std(dim=0)
# 添加自循环
edge_index = add_self_loops(data.edge_index)[0]
2. 模型调试技巧:
- 先在小图上过拟合(验证模型容量)
- 监控训练/验证损失差距
- 可视化第一层权重分布
3. 工业部署建议:
- 使用`torch_geometric.compile`加速模型
- 对静态图进行预处理计算
- 考虑使用`GraphSAGE`处理动态图
下篇预告:《Day10:深度强化学习入门—从Q学习到策略梯度》
将探索AI如何通过试错学习,实现游戏控制、机器人决策等复杂任务!
更多推荐
所有评论(0)