Pytorch迁移学习——修改训练好的网络结构及不同层设置不同的学习速率
目录1. 修改最后一层的网络结构2. 不同网络层设置不同的学习速率3. Demo在使用深度学习的时候有时候需要用到迁移学习,但是由于不同的任务最终的输出可能不同因此需要修改最后的网络结构,并且由于最后一层前面的权重是已经训练好的,因此不用再花大量的精力集中在它们上面只需要对最后一层的权重进行重点训练即可。1. 修改最后一层的网络结构以为分类为题为例,如果在Imagenet上...
·
目录
在使用深度学习的时候有时候需要用到迁移学习,但是由于不同的任务最终的输出可能不同因此需要修改最后的网络结构,并且由于最后一层前面的权重是已经训练好的,因此不用再花大量的精力集中在它们上面只需要对最后一层的权重进行重点训练即可。
1. 修改最后一层的网络结构
以为分类为题为例,如果在Imagenet上训练,那么其最后一层有1000个结点。但是我们的数据集可能没有那么多种类,因此需要将其修改为合适的数目。可以构造如下的网络来将训练好的网络结构的最后一层替换为我们所需要的。值得注意的是,如果不知道倒数第二层的输出,可以先随便写个数,反正会报错,然后根据错误信息进行相应的修改。
class TransferNet(nn.Module):
def __init__(self, model, input_dim, output_dim):
super(TransferNet, self).__init__()
self.pre_layers = nn.Sequential(*list(model.children()))[:-1]
self.last_layer = nn.Linear(input_dim, output_dim)
def forward(self, x):
x = self.pre_layers(x)
x = x.view(x.size(0), -1)
x = self.last_layer(x)
return x
2. 不同网络层设置不同的学习速率
由于在迁移学习中,前面的网络层已经得到充分的训练,因此在fine-tune中需要对最后一层进行学习速率进行调整,代码如下:
def set_optimizer(model, lr_base, momentum, w_decay):
last_params = map(id, model.last_layer.parameters())
pre_params = filter(lambda addr: id(addr) not in last_params, model.parameters())
optimizer = torch.optim.SGD([
{'params': pre_params},
{'params': model.last_layer.parameters(), 'lr': 0.1}], lr=lr_base, momentum = momentum, weight_decay=w_decay)
return optimizer
3. Demo
# -*- coding: UTF-8 -*-
"""
@FileName: Demo.py
@Description: Implement Transfer learning
@Author: Lj
@CreateDate: 2019/11/28 14:01
@LastEditTime: 2019/11/28 14:47
@LastEditors: Please set LastEditors
@Version: v1.0
"""
import torch
import torch.nn as nn
from torchvision.models import resnet50
class TransferNet(nn.Module):
def __init__(self, model, input_dim, output_dim):
super(TransferNet, self).__init__()
self.pre_layers = nn.Sequential(*list(model.children()))[:-1]
self.last_layer = nn.Linear(input_dim, output_dim)
def forward(self, x):
x = self.pre_layers(x)
x = x.view(x.size(0), -1)
x = self.last_layer(x)
return x
def set_optimizer(model, lr_base, momentum, w_decay):
last_params = map(id, model.last_layer.parameters())
pre_params = filter(lambda addr: id(addr) not in last_params, model.parameters())
optimizer = torch.optim.SGD([
{'params': pre_params},
{'params': model.last_layer.parameters(), 'lr': 0.1}], lr=lr_base, momentum = momentum, weight_decay=w_decay)
return optimizer
if __name__ == '__main__':
model = resnet50()
model = TransferNet(model, 2048, 100)
x = torch.randn(1,3,224,224)
print(model(x))

更多推荐
所有评论(0)