PyTorch 迁移学习多分类（ResNet-18）
下面给出完整代码，复制后即可直接运行。
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms, utils
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# train.csv is indexed by its 'filename' column, which holds the numeric id
# of each image in train/<id>.jpg.
# NOTE(review): assumes the only remaining column is the class label — confirm.
label = pd.read_csv('train.csv')
label = label.set_index('filename')


def _file_id(path):
    """Return the integer id encoded in an image file name ('train/12.jpg' -> 12).

    os.path is used instead of splitting on a hard-coded separator so this
    works for both Windows and POSIX paths.
    """
    return int(os.path.splitext(os.path.basename(path))[0])


# The image list must exist before labels can be looked up for it.
# glob() order is filesystem-dependent, so sort first and then shuffle with a
# fixed seed: the 80/20 split below is reproducible and not biased by any
# ordering of the file ids.
images = sorted(glob.glob('train/*.jpg'))
random.Random(0).shuffle(images)
# .iloc[0] takes the single remaining column of the row (calling int() on a
# one-element Series is deprecated in pandas).
labels = [int(label.loc[_file_id(path)].iloc[0]) for path in images]
num_train = int(len(labels) * 0.8)
class FoodDataset(Dataset):
    """Map-style dataset that pairs image file paths with class labels.

    Item ``i`` is ``(transform(rgb_image_i), labels[i])``.
    """

    def __init__(self, images, labels, transform):
        # images and labels are parallel sequences; transform is applied to
        # every loaded PIL image.
        self.images, self.labels, self.transform = images, labels, transform

    def __getitem__(self, index):
        # Force 3 channels so grayscale/RGBA files match the RGB pipeline.
        rgb = Image.open(self.images[index]).convert('RGB')
        sample = self.transform(rgb)
        return sample, self.labels[index]

    def __len__(self):
        # The dataset size is taken from the label list.
        return len(self.labels)
# torchvision's ImageNet-pretrained models (resnet18 here) expect inputs
# normalized with the ImageNet per-channel statistics; using 0.5/0.5 shifts
# the input distribution away from what the pretrained filters were trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed resize, random horizontal flip as augmentation,
# tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Evaluation pipeline: same as training but deterministic (no augmentation).
transform_val = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# First num_train samples train the model; the remainder is held out for validation.
train_dataset = FoodDataset(images[:num_train], labels[:num_train], transform_train)
val_dataset = FoodDataset(images[num_train:], labels[num_train:], transform_val)

# Only the training loader reshuffles each epoch; validation order is fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=128, shuffle=False)
def show_batch(images_batch):
    """Display a batch of image tensors as one grid with matplotlib.

    Args:
        images_batch: batched image tensor, presumably (batch, C, H, W) as
            yielded by the DataLoaders above — TODO confirm. It may live on
            any device.
    """
    # The tensors were Normalize()d, so their values fall outside [0, 1] and
    # imshow would clip them; normalize=True rescales the grid by its min/max.
    # detach().cpu() is required before .numpy() for GPU tensors.
    grid = utils.make_grid(images_batch.detach().cpu(), normalize=True)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.show()

def build_model(num_classes, freeze_backbone=True):
    """Build an ImageNet-pretrained ResNet-18 with a new classification head.

    Args:
        num_classes: number of logits produced by the replacement final layer.
        freeze_backbone: if True (default), set ``requires_grad = False`` on
            every pretrained parameter so only the new head is trained.

    Returns:
        The modified ``torchvision.models.resnet18`` module.
    """
    # NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
    # `weights=...`; kept here for compatibility with older torchvision.
    transfer_model = models.resnet18(pretrained=True)
    if freeze_backbone:
        for param in transfer_model.parameters():
            param.requires_grad = False
    # Replace the final fully connected layer with one that outputs
    # num_classes logits; a freshly created nn.Linear is trainable.
    in_features = transfer_model.fc.in_features
    transfer_model.fc = nn.Linear(in_features, num_classes)
    return transfer_model
# Four-class classifier; the model is moved to the selected device once here.
net = build_model(4).to(device)
criterion = nn.CrossEntropyLoss()
# Stage 1 optimizes only the new fc head, which is the only part left
# trainable by build_model.
optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3)
def train(log_every=20):
    """Run one training epoch of the global ``net`` over ``train_loader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``. The
    ``optimizer`` global is looked up at call time, so rebinding it later
    takes effect.

    Args:
        log_every: every ``log_every`` batches, print the mean loss of those
            batches (default 20). One value controls both the print period and
            the averaging divisor, so the two cannot drift apart.

    Raises:
        ValueError: if ``log_every`` is not positive.
    """
    if log_every < 1:
        raise ValueError('log_every must be a positive integer')
    net.train()
    batch_num = len(train_loader)
    running_loss = 0.0
    for i, (inputs, targets) in enumerate(train_loader, start=1):
        # Move the batch to the same device as the model.
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()
        # Accumulate the loss and report the window average.
        running_loss += loss.item()
        if i % log_every == 0:
            print(
                'batch:{}/{} loss:{:.3f}'.format(i, batch_num, running_loss / log_every))
            running_loss = 0.0
# Validation function.
def validate():
    """Evaluate the global ``net`` on ``val_loader`` and print its accuracy.

    Returns:
        Accuracy as a percentage (float), or None when the loader yields no
        samples.
    """
    # eval() puts layers such as BatchNorm into inference mode; forgetting it
    # skews validation results.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # An empty validation split would otherwise raise ZeroDivisionError.
        print('No validation samples; accuracy undefined.')
        return None
    accuracy = 100 * correct / total
    print('Accuracy of the network on the test images: %d %%' % accuracy)
    return accuracy
n_epoch = 10
# torch.save() does not create missing directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs('params', exist_ok=True)
# Stage 1: train only the fc head (the optimizer covers net.fc only),
# saving a checkpoint after every epoch.
for epoch in range(n_epoch):
    print('epoch {}'.format(epoch+1))
    train()
    validate()
    save_path = 'params/param_{}.pkl'.format(epoch)
    torch.save(net.state_dict(), save_path)
# Sample output from one run of the loop above:
'''
epoch 1
batch:20/39 loss:1.348
Accuracy of the network on the test images: 48 %
epoch 2
batch:20/39 loss:1.183
Accuracy of the network on the test images: 55 %
epoch 3
batch:20/39 loss:1.088
Accuracy of the network on the test images: 64 %
epoch 4
batch:20/39 loss:1.005
Accuracy of the network on the test images: 68 %
epoch 5
batch:20/39 loss:0.953
Accuracy of the network on the test images: 71 %
epoch 6
batch:20/39 loss:0.896
Accuracy of the network on the test images: 73 %
epoch 7
batch:20/39 loss:0.840
Accuracy of the network on the test images: 75 %
epoch 8
batch:20/39 loss:0.797
Accuracy of the network on the test images: 77 %
epoch 9
batch:20/39 loss:0.770
Accuracy of the network on the test images: 78 %
epoch 10
batch:20/39 loss:0.729
Accuracy of the network on the test images: 78 %
'''
# To resume from the last checkpoint in a fresh process, uncomment:
# net.load_state_dict(torch.load(save_path))

# Stage 1, continued: ten more head-only epochs (numbered 11-20).
for epoch in range(10, 20):
    print(f'epoch {epoch + 1}')
    train()
    validate()
    ckpt = f'params/param_{epoch}.pkl'
    save_path = ckpt
    torch.save(net.state_dict(), save_path)
# Sample output from one run of the loop above:
'''
epoch 11
batch:20/39 loss:0.704
Accuracy of the network on the test images: 80 %
epoch 12
batch:20/39 loss:0.675
Accuracy of the network on the test images: 81 %
epoch 13
batch:20/39 loss:0.666
Accuracy of the network on the test images: 81 %
epoch 14
batch:20/39 loss:0.655
Accuracy of the network on the test images: 82 %
epoch 15
batch:20/39 loss:0.633
Accuracy of the network on the test images: 83 %
epoch 16
batch:20/39 loss:0.608
Accuracy of the network on the test images: 84 %
epoch 17
batch:20/39 loss:0.588
Accuracy of the network on the test images: 84 %
epoch 18
batch:20/39 loss:0.586
Accuracy of the network on the test images: 84 %
epoch 19
batch:20/39 loss:0.575
Accuracy of the network on the test images: 84 %
epoch 20
batch:20/39 loss:0.561
Accuracy of the network on the test images: 85 %
'''
# Stage 2: unfreeze the whole network and fine-tune every parameter.
# (The optimizer does not inspect requires_grad at construction, so the order
# of these two steps does not matter.)
for param in net.parameters():
    param.requires_grad = True
# Note: the optimizer now receives net.parameters() rather than net.fc's.
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
for epoch in range(20, 30):
    print(f'epoch {epoch + 1}')
    train()
    validate()
    save_path = f'params/param_{epoch}.pkl'
    torch.save(net.state_dict(), save_path)
# Sample output from one run of the loop above:
'''
epoch 21
batch:20/39 loss:0.509
Accuracy of the network on the test images: 87 %
epoch 22
batch:20/39 loss:0.467
Accuracy of the network on the test images: 88 %
epoch 23
batch:20/39 loss:0.395
Accuracy of the network on the test images: 88 %
epoch 24
batch:20/39 loss:0.395
Accuracy of the network on the test images: 89 %
epoch 25
batch:20/39 loss:0.366
Accuracy of the network on the test images: 89 %
epoch 26
batch:20/39 loss:0.337
Accuracy of the network on the test images: 90 %
epoch 27
batch:20/39 loss:0.329
Accuracy of the network on the test images: 91 %
epoch 28
batch:20/39 loss:0.293
Accuracy of the network on the test images: 91 %
epoch 29
batch:20/39 loss:0.282
Accuracy of the network on the test images: 91 %
epoch 30
batch:20/39 loss:0.267
Accuracy of the network on the test images: 92 %
'''
class TestDataset(Dataset):
    """Unlabelled image dataset for inference.

    Item ``i`` is ``(transform(rgb_image_i), images[i])``: the source path is
    returned in place of a label so predictions can be traced back to files.
    """

    def __init__(self, images, transform):
        self.images, self.transform = images, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        path = self.images[index]
        # Force 3 channels, matching the training pipeline.
        tensor = self.transform(Image.open(path).convert('RGB'))
        return tensor, path
# Inference must be deterministic, so reuse the validation pipeline (resize +
# normalize, no augmentation). transform_train must not be used here: its
# RandomHorizontalFlip would randomly flip test images at prediction time.
transform_test = transform_val
test_images = glob.glob('test/*.jpg')
test_dataset = TestDataset(test_images, transform_test)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
def test():
    """Predict a class for every image served by ``test_loader``.

    Uses the global ``net`` and ``device``.

    Returns:
        dict mapping each image path to its predicted class index (Python int).
    """
    net.eval()  # inference mode; required before predicting
    predictions = {}
    with torch.no_grad():
        for batch, paths in test_loader:
            scores = net(batch.to(device))
            _, top = torch.max(scores.data, 1)
            # tolist() yields plain Python ints, one per image in the batch.
            predictions.update(zip(paths, top.to('cpu').tolist()))
    return predictions
result = test()
# Recover the numeric id from each path with os.path (portable across Windows
# and POSIX separators, unlike splitting on '\\') and order rows by id.
rows = sorted(
    (int(os.path.splitext(os.path.basename(path))[0]), pred)
    for path, pred in result.items()
)
df = pd.DataFrame(rows, columns=['filename', 'label'])
df = df.set_index('filename')
# Written without a header row: the index and label give one "<id>,<label>" line per image.
df.to_csv('test.csv', header=False, encoding="UTF8")
更多推荐



所有评论(0)