Pytorch中gpu的并行运算

常用的最多的就是,多块GPU训练同一个网络模型。Pytorch中的并行运算。

1. 多GPU输入数据并行运算

一般使用torch.nn.DataParallel,例如:

device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)
2. 推荐GPU设置方式:
  • 单卡
    使用CUDA_VISIBLE_DEVICES指定GPU,然后.cuda()不传入参数
    import os 
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids  # gpu_ids参数为int类型,如 0
    model.cuda()
    
  • 多GPU输入数据并行处理
    import os 
    gpu_list = '0,1,2,3'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
    
    device_ids = [0, 1, 2, 3]
    net = torch.nn.DataParallel(net, device_ids=device_ids)
    
3. 保存加载多GPU网络
net = torch.nn.Linear(10,1)  # 构造网络
net = torch.nn.DataParallel(net, device_ids=[0,1, 2, 3]) 
torch.save(net.module.state_dict(), './model/multiGPU.pth') #保存网络

# 加载网络
new_net = torch.nn.Linear(10,1)
new_net.load_state_dict(torch.load("./model/multiGPU.pth"))

更多推荐