Deep Learning Week P5: Sneaker Recognition in PyTorch
>- **🍨 This post is a learning-record blog from the [🔗365天深度学习训练营]**
>- **🍖 Original author: [K同学啊]**
My earlier posts in this series can be found at: 深度学习总结 (Deep Learning Summary)
🍺 Requirements:
- Understand how to set up a dynamic learning rate (key point)
- Tune the code so that test-set accuracy reaches 84%.
🍻 Bonus (optional):
- Save the best model weights during training (a sketch appears at the end of Section III)
- Tune the code so that test-set accuracy reaches 86%.
🏡 My environment:
- Language: Python 3.11.7
- Editor: Jupyter Lab
- Deep learning framework: PyTorch
I. Preliminary Preparation
1. Set up the GPU
Use the GPU if the device supports one; otherwise fall back to the CPU.
import torch
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Output:
device(type='cpu')
2. Import the data
import pathlib
data_dir = r'D:\THE MNIST DATABASE\P5-data'   # raw string so backslashes are taken literally
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob('*'))
classeNames = [path.name for path in data_paths]   # last path component; more robust than splitting on "\\" with a fixed index
classeNames
Output:
['test', 'train']
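Note that `classeNames` here holds the names of the split folders (`test`, `train`), not the actual shoe classes; it only works as the output size of the final layer later because there happen to be two splits and two classes. A minimal sketch (assuming the folder layout above) that reads the real class names from the `train` subfolders instead:
train_dir = pathlib.Path(r'D:\THE MNIST DATABASE\P5-data\train')
# Each subfolder of train/ is one class; sorted to match ImageFolder's alphabetical ordering
real_classes = sorted(p.name for p in train_dir.iterdir() if p.is_dir())
print(real_classes)   # expected: ['adidas', 'nike']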
3. Preview the loaded images
import matplotlib.pyplot as plt
from PIL import Image
import os
# Path to the image folder
image_folder = r'D:\THE MNIST DATABASE\P5-data\train\adidas'
# Collect all image files in the folder
image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]
# Create the Matplotlib figure
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
# Load and display the images
for ax, img_file in zip(axes.flat, image_files):
    img_path = os.path.join(image_folder, img_file)
    img = Image.open(img_path)
    ax.imshow(img)
    ax.axis('off')
# Show the figure
plt.tight_layout()
plt.show()
Output:

4. Image preprocessing
import torchvision
from torchvision import transforms, datasets
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),       # resize all input images to a uniform size
    transforms.RandomHorizontalFlip(),   # random horizontal flip
    transforms.ToTensor(),               # convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(                # standardize (approximately zero mean, unit variance) so the model converges more easily
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
])
test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.ImageFolder(r"D:\THE MNIST DATABASE\P5-data\train",
                                     transform=train_transforms)
test_dataset = datasets.ImageFolder(r"D:\THE MNIST DATABASE\P5-data\test",
                                    transform=test_transforms)
train_dataset, test_dataset
Output:
(Dataset ImageFolder
     Number of datapoints: 502
     Root location: D:\THE MNIST DATABASE\P5-data\train
     StandardTransform
 Transform: Compose(
                Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
                RandomHorizontalFlip(p=0.5)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 Dataset ImageFolder
     Number of datapoints: 76
     Root location: D:\THE MNIST DATABASE\P5-data\test
     StandardTransform
 Transform: Compose(
                Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ))
Map the dataset classes to their indices:
train_dataset.class_to_idx
Output:
{'adidas': 0, 'nike': 1}
5. Load the datasets
train_dl = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True)
test_dl = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=True)
Inspect one batch from the test set:
for x, y in test_dl:
    print("Shape of x [N,C,H,W]:", x.shape)
    print("Shape of y:", y.shape, y.dtype)
    break
Output:
Shape of x [N,C,H,W]: torch.Size([32, 3, 224, 224])
Shape of y: torch.Size([32]) torch.int64
II. Build a Simple CNN
Network structure diagram:

import torch.nn.functional as F
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Convolutional blocks
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=5),
            nn.BatchNorm2d(12),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(12, 12, kernel_size=5),
            nn.BatchNorm2d(12),
            nn.ReLU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(12, 24, kernel_size=5),
            nn.BatchNorm2d(24),
            nn.ReLU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(24, 24, kernel_size=5),
            nn.BatchNorm2d(24),
            nn.ReLU())
        self.maxpool = nn.MaxPool2d(2)
        self.avgpool = nn.AvgPool2d(2)   # defined but not used in forward()
        self.dropout = nn.Dropout(0.2)
        # len(classeNames) is 2 here; note it comes from the split folders ('test', 'train'),
        # which happens to equal the actual number of classes
        self.fc1 = nn.Linear(24*50*50, len(classeNames))
    def forward(self, x):
        x = self.conv1(x)     # conv-BN-ReLU -> 12*220*220
        x = self.conv2(x)     # conv-BN-ReLU -> 12*216*216
        x = self.maxpool(x)   # pooling      -> 12*108*108
        x = self.conv3(x)     # conv-BN-ReLU -> 24*104*104
        x = self.conv4(x)     # conv-BN-ReLU -> 24*100*100
        x = self.maxpool(x)   # pooling      -> 24*50*50
        x = self.dropout(x)
        x = x.view(x.size(0), -1)   # flatten for the fully connected layer
        x = self.fc1(x)             # was self.fc, which does not exist; the layer is named fc1
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
model = Model().to(device)
model
Output:
Using cpu device
Model(
  (conv1): Sequential(
    (0): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv3): Sequential(
    (0): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv4): Sequential(
    (0): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc1): Linear(in_features=60000, out_features=2, bias=True)
)
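The 24*50*50 flatten size can be sanity-checked with a dummy forward pass through the convolutional part (a minimal sketch; the shapes follow from four unpadded 5*5 convolutions and two 2*2 max-pools applied to a 224*224 input):
# Trace the feature-map shape that feeds fc1
dummy = torch.randn(1, 3, 224, 224).to(device)
feats = model.maxpool(model.conv2(model.conv1(dummy)))   # -> [1, 12, 108, 108]
feats = model.maxpool(model.conv4(model.conv3(feats)))   # -> [1, 24, 50, 50]
print(feats.shape)   # 24*50*50 = 60000, matching fc1's in_features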
III. Train the Model
1. Write the training function
# Training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)   # size of the training set
    num_batches = len(dataloader)    # number of batches (size/batch_size, rounded up)
    train_loss, train_acc = 0, 0     # initialize loss and accuracy
    for x, y in dataloader:          # fetch images and their labels
        x, y = x.to(device), y.to(device)
        # Compute the prediction error
        pred = model(x)              # network output
        loss = loss_fn(pred, y)      # gap between the network output and the ground truth
        # Backpropagation
        optimizer.zero_grad()        # reset gradients
        loss.backward()              # backpropagate
        optimizer.step()             # update the weights
        # Record accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()   # was "predd", a typo
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss
2. Write the test function
The test function is much like the training function, but since no gradient descent is performed to update the network weights, no optimizer needs to be passed in.
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)   # size of the test set
    num_batches = len(dataloader)    # number of batches (size/batch_size, rounded up)
    test_loss, test_acc = 0, 0
    # Stop tracking gradients when not training, to save memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)
            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
3. Set up a dynamic learning rate
def adjust_learning_rate(optimizer, epoch, start_lr):
    # Decay the learning rate to 0.92 of its value every 2 epochs
    lr = start_lr * (0.92 ** (epoch // 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

learn_rate = 1e-4   # initial learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
✨ Using the official dynamic learning-rate interface
This is equivalent to the method above:
'''# Use when calling the official scheduler interface
lambda1 = lambda epoch: (0.92 ** (epoch // 2))
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)   # choose the adjustment method'''
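For this particular schedule (multiply the learning rate by 0.92 once every 2 epochs), the built-in `StepLR` scheduler expresses the same decay more directly; a minimal sketch:
# Equivalent built-in scheduler: multiply lr by gamma once every step_size epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.92)
# As with LambdaLR, call scheduler.step() once per epoch after the training step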
4. Train the model
loss_fn = nn.CrossEntropyLoss()   # create the loss function
epochs = 20
train_loss, train_acc, test_loss, test_acc = [], [], [], []
for epoch in range(epochs):
    # Update the learning rate (when using the custom schedule)
    adjust_learning_rate(optimizer, epoch, learn_rate)
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    # scheduler.step()   # update the learning rate (when using the official scheduler interface)
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    train_acc.append(epoch_train_acc)   # was epoch_test_acc, which would have recorded test accuracy twice
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    # Get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss, lr))
print('Done')
Output:
Epoch: 1,Train_acc:64.7%,Train_loss:0.652,Test_acc:64.5%,Test_loss:0.682,Lr:1.00E-04
Epoch: 2,Train_acc:67.3%,Train_loss:0.621,Test_acc:61.8%,Test_loss:0.634,Lr:1.00E-04
Epoch: 3,Train_acc:69.9%,Train_loss:0.572,Test_acc:63.2%,Test_loss:0.586,Lr:9.20E-05
Epoch: 4,Train_acc:69.9%,Train_loss:0.580,Test_acc:69.7%,Test_loss:0.543,Lr:9.20E-05
Epoch: 5,Train_acc:74.5%,Train_loss:0.516,Test_acc:65.8%,Test_loss:0.543,Lr:8.46E-05
Epoch: 6,Train_acc:73.9%,Train_loss:0.525,Test_acc:73.7%,Test_loss:0.584,Lr:8.46E-05
Epoch: 7,Train_acc:74.3%,Train_loss:0.516,Test_acc:71.1%,Test_loss:0.549,Lr:7.79E-05
Epoch: 8,Train_acc:78.5%,Train_loss:0.494,Test_acc:72.4%,Test_loss:0.536,Lr:7.79E-05
Epoch: 9,Train_acc:78.3%,Train_loss:0.490,Test_acc:68.4%,Test_loss:0.569,Lr:7.16E-05
Epoch:10,Train_acc:79.1%,Train_loss:0.492,Test_acc:61.8%,Test_loss:0.574,Lr:7.16E-05
Epoch:11,Train_acc:80.9%,Train_loss:0.468,Test_acc:73.7%,Test_loss:0.548,Lr:6.59E-05
Epoch:12,Train_acc:80.7%,Train_loss:0.468,Test_acc:76.3%,Test_loss:0.555,Lr:6.59E-05
Epoch:13,Train_acc:81.1%,Train_loss:0.450,Test_acc:68.4%,Test_loss:0.563,Lr:6.06E-05
Epoch:14,Train_acc:82.9%,Train_loss:0.450,Test_acc:75.0%,Test_loss:0.562,Lr:6.06E-05
Epoch:15,Train_acc:84.1%,Train_loss:0.426,Test_acc:77.6%,Test_loss:0.563,Lr:5.58E-05
Epoch:16,Train_acc:83.7%,Train_loss:0.434,Test_acc:77.6%,Test_loss:0.538,Lr:5.58E-05
Epoch:17,Train_acc:82.3%,Train_loss:0.420,Test_acc:68.4%,Test_loss:0.531,Lr:5.13E-05
Epoch:18,Train_acc:83.7%,Train_loss:0.425,Test_acc:76.3%,Test_loss:0.496,Lr:5.13E-05
Epoch:19,Train_acc:82.7%,Train_loss:0.424,Test_acc:76.3%,Test_loss:0.524,Lr:4.72E-05
Epoch:20,Train_acc:84.1%,Train_loss:0.410,Test_acc:77.6%,Test_loss:0.501,Lr:4.72E-05
Done
Since these results fell short of the goal, the model was re-tuned: the kernel size was changed to 3*3, but the results were still unsatisfactory. The kernels were then switched back to 5*5 and the optimizer changed to Adam; after 60 epochs the test accuracy reached the 86% target.
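A minimal sketch of the optimizer change for the second run (assuming the initial learning rate stays at 1e-4; only these lines differ from the SGD setup above):
# Second run: swap SGD for Adam, keeping the 1e-4 initial learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
epochs = 60   # train for 60 epochs, as in the log below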
Epoch: 1,Train_acc:50.8%,Train_loss:1.307,Test_acc:51.3%,Test_loss:0.911,Lr:1.00E-04
Epoch: 2,Train_acc:66.1%,Train_loss:0.648,Test_acc:77.6%,Test_loss:0.565,Lr:1.00E-04
Epoch: 3,Train_acc:70.5%,Train_loss:0.581,Test_acc:67.1%,Test_loss:0.477,Lr:9.20E-05
Epoch: 4,Train_acc:79.7%,Train_loss:0.446,Test_acc:80.3%,Test_loss:0.460,Lr:9.20E-05
Epoch: 5,Train_acc:84.7%,Train_loss:0.394,Test_acc:75.0%,Test_loss:0.515,Lr:8.46E-05
Epoch: 6,Train_acc:87.8%,Train_loss:0.338,Test_acc:73.7%,Test_loss:0.491,Lr:8.46E-05
Epoch: 7,Train_acc:87.1%,Train_loss:0.324,Test_acc:82.9%,Test_loss:0.410,Lr:7.79E-05
Epoch: 8,Train_acc:89.8%,Train_loss:0.287,Test_acc:81.6%,Test_loss:0.430,Lr:7.79E-05
Epoch: 9,Train_acc:92.8%,Train_loss:0.247,Test_acc:80.3%,Test_loss:0.447,Lr:7.16E-05
Epoch:10,Train_acc:93.4%,Train_loss:0.242,Test_acc:76.3%,Test_loss:0.576,Lr:7.16E-05
Epoch:11,Train_acc:96.4%,Train_loss:0.190,Test_acc:82.9%,Test_loss:0.345,Lr:6.59E-05
Epoch:12,Train_acc:94.0%,Train_loss:0.186,Test_acc:81.6%,Test_loss:0.350,Lr:6.59E-05
Epoch:13,Train_acc:96.4%,Train_loss:0.170,Test_acc:77.6%,Test_loss:0.445,Lr:6.06E-05
Epoch:14,Train_acc:95.4%,Train_loss:0.167,Test_acc:86.8%,Test_loss:0.328,Lr:6.06E-05
Epoch:15,Train_acc:97.0%,Train_loss:0.136,Test_acc:85.5%,Test_loss:0.372,Lr:5.58E-05
Epoch:16,Train_acc:98.6%,Train_loss:0.125,Test_acc:85.5%,Test_loss:0.333,Lr:5.58E-05
Epoch:17,Train_acc:98.8%,Train_loss:0.121,Test_acc:85.5%,Test_loss:0.398,Lr:5.13E-05
Epoch:18,Train_acc:99.2%,Train_loss:0.113,Test_acc:85.5%,Test_loss:0.359,Lr:5.13E-05
Epoch:19,Train_acc:99.4%,Train_loss:0.100,Test_acc:82.9%,Test_loss:0.357,Lr:4.72E-05
Epoch:20,Train_acc:99.0%,Train_loss:0.095,Test_acc:85.5%,Test_loss:0.353,Lr:4.72E-05
Epoch:21,Train_acc:99.2%,Train_loss:0.091,Test_acc:82.9%,Test_loss:0.367,Lr:4.34E-05
Epoch:22,Train_acc:99.4%,Train_loss:0.082,Test_acc:85.5%,Test_loss:0.407,Lr:4.34E-05
Epoch:23,Train_acc:99.4%,Train_loss:0.086,Test_acc:86.8%,Test_loss:0.306,Lr:4.00E-05
Epoch:24,Train_acc:99.8%,Train_loss:0.082,Test_acc:86.8%,Test_loss:0.294,Lr:4.00E-05
Epoch:25,Train_acc:99.6%,Train_loss:0.077,Test_acc:86.8%,Test_loss:0.337,Lr:3.68E-05
Epoch:26,Train_acc:99.2%,Train_loss:0.075,Test_acc:85.5%,Test_loss:0.386,Lr:3.68E-05
Epoch:27,Train_acc:99.8%,Train_loss:0.068,Test_acc:85.5%,Test_loss:0.312,Lr:3.38E-05
Epoch:28,Train_acc:99.8%,Train_loss:0.066,Test_acc:85.5%,Test_loss:0.419,Lr:3.38E-05
Epoch:29,Train_acc:100.0%,Train_loss:0.059,Test_acc:82.9%,Test_loss:0.338,Lr:3.11E-05
Epoch:30,Train_acc:100.0%,Train_loss:0.063,Test_acc:85.5%,Test_loss:0.311,Lr:3.11E-05
Epoch:31,Train_acc:99.6%,Train_loss:0.059,Test_acc:86.8%,Test_loss:0.324,Lr:2.86E-05
Epoch:32,Train_acc:99.8%,Train_loss:0.060,Test_acc:84.2%,Test_loss:0.336,Lr:2.86E-05
Epoch:33,Train_acc:99.8%,Train_loss:0.054,Test_acc:85.5%,Test_loss:0.302,Lr:2.63E-05
Epoch:34,Train_acc:100.0%,Train_loss:0.053,Test_acc:82.9%,Test_loss:0.282,Lr:2.63E-05
Epoch:35,Train_acc:100.0%,Train_loss:0.052,Test_acc:86.8%,Test_loss:0.303,Lr:2.42E-05
Epoch:36,Train_acc:100.0%,Train_loss:0.052,Test_acc:86.8%,Test_loss:0.329,Lr:2.42E-05
Epoch:37,Train_acc:99.8%,Train_loss:0.051,Test_acc:85.5%,Test_loss:0.372,Lr:2.23E-05
Epoch:38,Train_acc:99.8%,Train_loss:0.048,Test_acc:85.5%,Test_loss:0.376,Lr:2.23E-05
Epoch:39,Train_acc:100.0%,Train_loss:0.049,Test_acc:86.8%,Test_loss:0.321,Lr:2.05E-05
Epoch:40,Train_acc:100.0%,Train_loss:0.045,Test_acc:85.5%,Test_loss:0.395,Lr:2.05E-05
Epoch:41,Train_acc:99.8%,Train_loss:0.046,Test_acc:85.5%,Test_loss:0.288,Lr:1.89E-05
Epoch:42,Train_acc:100.0%,Train_loss:0.044,Test_acc:86.8%,Test_loss:0.349,Lr:1.89E-05
Epoch:43,Train_acc:99.8%,Train_loss:0.045,Test_acc:85.5%,Test_loss:0.400,Lr:1.74E-05
Epoch:44,Train_acc:100.0%,Train_loss:0.041,Test_acc:85.5%,Test_loss:0.310,Lr:1.74E-05
Epoch:45,Train_acc:100.0%,Train_loss:0.042,Test_acc:85.5%,Test_loss:0.306,Lr:1.60E-05
Epoch:46,Train_acc:99.8%,Train_loss:0.041,Test_acc:85.5%,Test_loss:0.341,Lr:1.60E-05
Epoch:47,Train_acc:100.0%,Train_loss:0.038,Test_acc:85.5%,Test_loss:0.338,Lr:1.47E-05
Epoch:48,Train_acc:100.0%,Train_loss:0.040,Test_acc:84.2%,Test_loss:0.329,Lr:1.47E-05
Epoch:49,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.328,Lr:1.35E-05
Epoch:50,Train_acc:100.0%,Train_loss:0.038,Test_acc:82.9%,Test_loss:0.314,Lr:1.35E-05
Epoch:51,Train_acc:100.0%,Train_loss:0.039,Test_acc:86.8%,Test_loss:0.321,Lr:1.24E-05
Epoch:52,Train_acc:100.0%,Train_loss:0.036,Test_acc:86.8%,Test_loss:0.319,Lr:1.24E-05
Epoch:53,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.344,Lr:1.14E-05
Epoch:54,Train_acc:100.0%,Train_loss:0.035,Test_acc:86.8%,Test_loss:0.322,Lr:1.14E-05
Epoch:55,Train_acc:100.0%,Train_loss:0.037,Test_acc:86.8%,Test_loss:0.294,Lr:1.05E-05
Epoch:56,Train_acc:100.0%,Train_loss:0.036,Test_acc:86.8%,Test_loss:0.302,Lr:1.05E-05
Epoch:57,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.284,Lr:9.68E-06
Epoch:58,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.325,Lr:9.68E-06
Epoch:59,Train_acc:100.0%,Train_loss:0.033,Test_acc:85.5%,Test_loss:0.383,Lr:8.91E-06
Epoch:60,Train_acc:100.0%,Train_loss:0.034,Test_acc:85.5%,Test_loss:0.332,Lr:8.91E-06
Done
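The bonus goal asks for saving the best model weights during training, which the loop above does not do (it keeps only the final weights). A minimal sketch of how the training loop could track them (the `best_acc` bookkeeping and the `best-model.pth` filename are illustrative, not from the original code):
import copy
best_acc = 0.0
for epoch in range(epochs):
    adjust_learning_rate(optimizer, epoch, learn_rate)
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    # Snapshot the weights whenever test accuracy improves
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model_wts = copy.deepcopy(model.state_dict())
torch.save(best_model_wts, 'best-model.pth')   # save the best weights rather than the last ones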
IV. Visualize the Results
1. Loss and accuracy curves
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")              # suppress warning messages
plt.rcParams['font.sans-serif'] = ['SimHei']   # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False     # display minus signs correctly
plt.rcParams['figure.dpi'] = 300               # resolution

epochs_range = range(epochs)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
2. Predict a specified image
from PIL import Image

classes = list(train_dataset.class_to_idx)   # was "calsses", a typo

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)   # display the image being predicted
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    model.eval()
    output = model(img)
    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'Predicted class: {pred_class}')

# Predict one image from the test set
predict_one_image(image_path=r'D:\THE MNIST DATABASE\P5-data\test\adidas\1.jpg',
                  model=model,
                  transform=train_transforms, classes=classes)   # test_transforms would avoid the random flip at inference time
Output:

V. Save and Load the Model
# Save the model weights
PATH = r'C:\Users\Administrator\PycharmProjects\pytorchProject1\P5周:Pytorch实现运动鞋识别\model-p5.pth'
torch.save(model.state_dict(), PATH)
# Load the parameters back into the model
model.load_state_dict(torch.load(PATH, map_location=device))
Output:
<All keys matched successfully>
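After loading the weights, switch the model to evaluation mode before running inference, so that dropout is disabled and batch norm uses its running statistics; a minimal sketch:
model.eval()   # disable dropout; use running batch-norm statistics
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224).to(device))   # dummy input for illustration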
VI. Personal Summary
Most of the time on this project went into model tuning. Since the initial results were poor, I suspected the 5*5 kernels might be too large to capture the discriminative details, so I switched to 3*3 kernels, but the results got worse; my guess is that the smaller per-layer receptive field (one 5*5 convolution covers what would take two stacked 3*3 convolutions) made it harder to capture the relevant features, so I switched back to 5*5 kernels. At the same time, I changed the optimizer to Adam. Along the way I also tried raising the learning rate to 1e-3, which did not work well; reverting to 1e-4 gave satisfactory results.