>- **🍨 本文为[🔗365天深度学习训练营]中的学习记录博客**
>- **🍖 原作者:[K同学啊]**

 本人往期文章可查阅: 深度学习总结

🍺要求:

  1. 了解如何设置动态学习率(重点)
  2. 调整代码使测试集accuracy到达84%。

🍻拔高(可选):

  1. 保存训练过程中的最佳模型权重
  2. 调整代码使测试集accuracy到达86%。

🏡 我的环境:

  • 语言环境:Python3.11.7
  • 开发工具:Jupyter Lab
  • 深度学习环境:Pytorch

 一、 前期准备

1. 设置GPU

 如果设备上支持GPU就使用GPU,否则使用CPU

import torch

# Prefer the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

 运行结果:

device(type='cpu')

2. 导入数据

import pathlib

# Use a raw string so backslashes in the Windows path are not treated as
# escape sequences ("\T" / "\P" trigger SyntaxWarning on modern Python).
data_dir = pathlib.Path(r'D:\THE MNIST DATABASE\P5-data')

data_paths = list(data_dir.glob('*'))
# Take the last path component as the folder name instead of hard-coding
# split("\\")[3], which silently breaks if the dataset root ever moves.
classeNames = [path.name for path in data_paths]
classeNames

运行结果:

['test', 'train']

3. 测试获取到的图片

import matplotlib.pyplot as plt
from PIL import Image
import os

# Folder that holds the training images of one class.
image_folder=r'D:\THE MNIST DATABASE\P5-data\train\adidas'

# Collect every image file in the folder.
image_files=[name for name in os.listdir(image_folder)
             if name.endswith((".jpg",".png",".jpeg"))]

# Show the first 24 images on a 3x8 Matplotlib grid.
fig,axes=plt.subplots(3,8,figsize=(16,6))
for ax,file_name in zip(axes.flat,image_files):
    ax.imshow(Image.open(os.path.join(image_folder,file_name)))
    ax.axis('off')

plt.tight_layout()
plt.show()

运行结果:

4. 图像预处理

import torchvision
from torchvision import transforms,datasets

# ImageNet channel statistics used to standardise the RGB inputs, which
# helps the model converge.
normalize=transforms.Normalize(
    mean=[0.485,0.456,0.406],
    std=[0.229,0.224,0.225])

# Training pipeline: resize to a fixed 224x224, random horizontal flip as
# augmentation, convert PIL Image/ndarray to a [0,1] tensor, standardise.
train_transforms=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize])

# Test pipeline: identical, but without the augmentation step.
test_transforms=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    normalize])

train_dataset=datasets.ImageFolder(r"D:\THE MNIST DATABASE\P5-data\train",
                                   transform=train_transforms)
test_dataset=datasets.ImageFolder(r"D:\THE MNIST DATABASE\P5-data\test",
                                  transform=test_transforms)
train_dataset,test_dataset

运行结果:

(Dataset ImageFolder
     Number of datapoints: 502
     Root location: D:\THE MNIST DATABASE\P5-data\train
     StandardTransform
 Transform: Compose(
                Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
                RandomHorizontalFlip(p=0.5)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 Dataset ImageFolder
     Number of datapoints: 76
     Root location: D:\THE MNIST DATABASE\P5-data\test
     StandardTransform
 Transform: Compose(
                Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ))

 映射数据集类别

train_dataset.class_to_idx

 运行结果

{'adidas': 0, 'nike': 1}

5. 加载数据集

# Batch the datasets.  Shuffling the training set decorrelates batches;
# shuffling the test set only changes reporting order, not the metrics.
train_dl=torch.utils.data.DataLoader(train_dataset,
                                     batch_size=32,
                                     shuffle=True)
test_dl=torch.utils.data.DataLoader(test_dataset,
                                    batch_size=32,
                                    shuffle=True)

显示测试集的情况

# Peek at a single batch to confirm the tensor layout: images arrive as
# [N, C, H, W] float tensors, labels as int64.
for images,labels in test_dl:
    print("Shape of x [N,C,H,W]:",images.shape)
    print("Shape of y:",labels.shape,labels.dtype)
    break

 运行结果:

Shape of x [N,C,H,W]: torch.Size([32, 3, 224, 224])
Shape of y: torch.Size([32]) torch.int64

二、构建简单的CNN网络

网络结构图(可单击放大查看)

import torch.nn.functional as F
import torch.nn as nn

class Model(nn.Module):
    """Four-block CNN (Conv-BN-ReLU) for the 2-class shoe dataset.

    Input: (N, 3, 224, 224) images; output: (N, num_classes) raw logits.
    """
    def __init__(self):
        super(Model,self).__init__()
        # Convolution blocks: Conv2d -> BatchNorm2d -> ReLU.  No padding,
        # so each 5x5 conv shrinks the spatial size by 4 on each axis.
        self.conv1=nn.Sequential(
            nn.Conv2d(3,12,kernel_size=5),
            nn.BatchNorm2d(12),
            nn.ReLU())
        self.conv2=nn.Sequential(
            nn.Conv2d(12,12,kernel_size=5),
            nn.BatchNorm2d(12),
            nn.ReLU())
        self.conv3=nn.Sequential(
            nn.Conv2d(12,24,kernel_size=5),
            nn.BatchNorm2d(24),
            nn.ReLU())
        self.conv4=nn.Sequential(
            nn.Conv2d(24,24,kernel_size=5),
            nn.BatchNorm2d(24),
            nn.ReLU())

        self.maxpool=nn.MaxPool2d(2)
        self.avgpool=nn.AvgPool2d(2)   # NOTE(review): never used in forward()
        self.dropout=nn.Dropout(0.2)
        # NOTE(review): classeNames was built from the dataset ROOT and holds
        # ['test', 'train']; its length only coincidentally equals the real
        # class count (2).  len(train_dataset.classes) would be more robust.
        self.fc1=nn.Linear(24*50*50,len(classeNames))

    def forward(self,x):
        x=self.conv1(x)    # Conv-BN-ReLU -> 12*220*220
        x=self.conv2(x)    # Conv-BN-ReLU -> 12*216*216
        x=self.maxpool(x)  # pooling      -> 12*108*108
        x=self.conv3(x)    # Conv-BN-ReLU -> 24*104*104
        x=self.conv4(x)    # Conv-BN-ReLU -> 24*100*100
        x=self.maxpool(x)  # pooling      -> 24*50*50
        x=self.dropout(x)
        x=x.view(x.size(0),-1)  # flatten for the fully-connected layer
        # BUG FIX: the original called self.fc(x), but the layer is named
        # fc1 -- as written, forward() raised AttributeError.
        x=self.fc1(x)

        return x

# Re-resolve the compute device (string form this time) and move the
# freshly built model's parameters onto it.
device="cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model=Model().to(device)
model

运行结果:

Using cpu device
Model(
  (conv1): Sequential(
    (0): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv3): Sequential(
    (0): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv4): Sequential(
    (0): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc1): Linear(in_features=60000, out_features=2, bias=True)
)

三、 训练模型

1. 编写训练函数

# Training loop
def train(dataloader,model,loss_fn,optimizer):
    """Run one training epoch and return (accuracy, mean batch loss)."""
    size=len(dataloader.dataset)    # number of training samples
    num_batches=len(dataloader)     # ceil(size / batch_size)

    train_loss,train_acc=0,0        # running loss sum / correct-prediction count

    for x,y in dataloader:          # fetch a batch of images and labels
        # NOTE(review): relies on the module-level `device` global.
        x,y=x.to(device),y.to(device)

        # Forward pass and loss.
        pred=model(x)               # network output (logits)
        loss=loss_fn(pred,y)        # gap between prediction and ground truth

        # Backward pass: clear stale grads, backprop, apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate metrics.
        # BUG FIX: the original read `predd` here -- an undefined name that
        # raised NameError on the first batch.
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()

    train_acc/=size          # fraction correct over the whole epoch
    train_loss/=num_batches  # mean loss per batch

    return train_acc,train_loss

 2. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器 

def test(dataloader,model,loss_fn):
    """Evaluate the model on a dataloader; return (accuracy, mean batch loss).

    No optimizer is needed: evaluation never updates the weights.
    """
    total_samples=len(dataloader.dataset)   # size of the test set
    batch_count=len(dataloader)             # ceil(size / batch_size)
    running_loss,correct=0,0

    # Disable gradient tracking: saves memory and compute during evaluation.
    with torch.no_grad():
        for imgs,target in dataloader:
            # NOTE(review): relies on the module-level `device` global.
            imgs,target=imgs.to(device),target.to(device)

            logits=model(imgs)
            running_loss+=loss_fn(logits,target).item()
            correct+=(logits.argmax(1)==target).type(torch.float).sum().item()

    return correct/total_samples,running_loss/batch_count

 3. 设置动态学习率

def adjust_learning_rate(optimizer,epoch,start_lr):
    """Step-decay schedule: lr = start_lr * 0.92 ** (epoch // 2).

    Every 2 epochs the learning rate shrinks to 92% of its previous level;
    the new value is written into every parameter group of the optimizer.
    """
    decayed=start_lr*0.92**(epoch//2)
    for group in optimizer.param_groups:
        group['lr']=decayed
        
learn_rate=1e-4   # initial learning rate
optimizer=torch.optim.SGD(model.parameters(),lr=learn_rate)

调用官方动态学习率接口

与上面方法是等价的

# The block below (kept disabled as a string literal) is the equivalent
# schedule expressed with PyTorch's official API,
# torch.optim.lr_scheduler.LambdaLR.  Enable it together with the
# scheduler.step() call in the training loop.
'''#调用官方动态学习率接口时使用
lambda1=lambda epoch:(0.92**(epoch//2))
optimizer=torch.optim.SGD(model.parameters(),lr=learn_rate)
scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1)#选定调整方法'''

 4. 正式训练

loss_fn=nn.CrossEntropyLoss()  # classification loss
epochs=20

train_loss,train_acc,test_loss,test_acc=[],[],[],[]

for epoch in range(epochs):
    # Decay the learning rate before each epoch (custom schedule).
    adjust_learning_rate(optimizer,epoch,learn_rate)

    model.train()
    epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
    #scheduler.step()  # use instead when relying on the official scheduler API

    model.eval()
    epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)

    # BUG FIX: the original appended epoch_test_acc here, so the
    # "training accuracy" curve silently plotted test accuracy.
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Read back the lr actually held by the optimizer for logging.
    lr=optimizer.state_dict()['param_groups'][0]['lr']

    template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,
                          epoch_test_acc*100,epoch_test_loss,lr))
print('Done')

 运行结果:

Epoch: 1,Train_acc:64.7%,Train_loss:0.652,Test_acc:64.5%,Test_loss:0.682,Lr:1.00E-04
Epoch: 2,Train_acc:67.3%,Train_loss:0.621,Test_acc:61.8%,Test_loss:0.634,Lr:1.00E-04
Epoch: 3,Train_acc:69.9%,Train_loss:0.572,Test_acc:63.2%,Test_loss:0.586,Lr:9.20E-05
Epoch: 4,Train_acc:69.9%,Train_loss:0.580,Test_acc:69.7%,Test_loss:0.543,Lr:9.20E-05
Epoch: 5,Train_acc:74.5%,Train_loss:0.516,Test_acc:65.8%,Test_loss:0.543,Lr:8.46E-05
Epoch: 6,Train_acc:73.9%,Train_loss:0.525,Test_acc:73.7%,Test_loss:0.584,Lr:8.46E-05
Epoch: 7,Train_acc:74.3%,Train_loss:0.516,Test_acc:71.1%,Test_loss:0.549,Lr:7.79E-05
Epoch: 8,Train_acc:78.5%,Train_loss:0.494,Test_acc:72.4%,Test_loss:0.536,Lr:7.79E-05
Epoch: 9,Train_acc:78.3%,Train_loss:0.490,Test_acc:68.4%,Test_loss:0.569,Lr:7.16E-05
Epoch:10,Train_acc:79.1%,Train_loss:0.492,Test_acc:61.8%,Test_loss:0.574,Lr:7.16E-05
Epoch:11,Train_acc:80.9%,Train_loss:0.468,Test_acc:73.7%,Test_loss:0.548,Lr:6.59E-05
Epoch:12,Train_acc:80.7%,Train_loss:0.468,Test_acc:76.3%,Test_loss:0.555,Lr:6.59E-05
Epoch:13,Train_acc:81.1%,Train_loss:0.450,Test_acc:68.4%,Test_loss:0.563,Lr:6.06E-05
Epoch:14,Train_acc:82.9%,Train_loss:0.450,Test_acc:75.0%,Test_loss:0.562,Lr:6.06E-05
Epoch:15,Train_acc:84.1%,Train_loss:0.426,Test_acc:77.6%,Test_loss:0.563,Lr:5.58E-05
Epoch:16,Train_acc:83.7%,Train_loss:0.434,Test_acc:77.6%,Test_loss:0.538,Lr:5.58E-05
Epoch:17,Train_acc:82.3%,Train_loss:0.420,Test_acc:68.4%,Test_loss:0.531,Lr:5.13E-05
Epoch:18,Train_acc:83.7%,Train_loss:0.425,Test_acc:76.3%,Test_loss:0.496,Lr:5.13E-05
Epoch:19,Train_acc:82.7%,Train_loss:0.424,Test_acc:76.3%,Test_loss:0.524,Lr:4.72E-05
Epoch:20,Train_acc:84.1%,Train_loss:0.410,Test_acc:77.6%,Test_loss:0.501,Lr:4.72E-05
Done

由于结果不尽人意,重新调整模型:先将卷积核修改为3*3,但结果依旧不理想;随后将卷积核改回5*5,并把优化器修改为Adam,训练60轮后实现了测试集accuracy达到86%以上的目标。

Epoch: 1,Train_acc:50.8%,Train_loss:1.307,Test_acc:51.3%,Test_loss:0.911,Lr:1.00E-04
Epoch: 2,Train_acc:66.1%,Train_loss:0.648,Test_acc:77.6%,Test_loss:0.565,Lr:1.00E-04
Epoch: 3,Train_acc:70.5%,Train_loss:0.581,Test_acc:67.1%,Test_loss:0.477,Lr:9.20E-05
Epoch: 4,Train_acc:79.7%,Train_loss:0.446,Test_acc:80.3%,Test_loss:0.460,Lr:9.20E-05
Epoch: 5,Train_acc:84.7%,Train_loss:0.394,Test_acc:75.0%,Test_loss:0.515,Lr:8.46E-05
Epoch: 6,Train_acc:87.8%,Train_loss:0.338,Test_acc:73.7%,Test_loss:0.491,Lr:8.46E-05
Epoch: 7,Train_acc:87.1%,Train_loss:0.324,Test_acc:82.9%,Test_loss:0.410,Lr:7.79E-05
Epoch: 8,Train_acc:89.8%,Train_loss:0.287,Test_acc:81.6%,Test_loss:0.430,Lr:7.79E-05
Epoch: 9,Train_acc:92.8%,Train_loss:0.247,Test_acc:80.3%,Test_loss:0.447,Lr:7.16E-05
Epoch:10,Train_acc:93.4%,Train_loss:0.242,Test_acc:76.3%,Test_loss:0.576,Lr:7.16E-05
Epoch:11,Train_acc:96.4%,Train_loss:0.190,Test_acc:82.9%,Test_loss:0.345,Lr:6.59E-05
Epoch:12,Train_acc:94.0%,Train_loss:0.186,Test_acc:81.6%,Test_loss:0.350,Lr:6.59E-05
Epoch:13,Train_acc:96.4%,Train_loss:0.170,Test_acc:77.6%,Test_loss:0.445,Lr:6.06E-05
Epoch:14,Train_acc:95.4%,Train_loss:0.167,Test_acc:86.8%,Test_loss:0.328,Lr:6.06E-05
Epoch:15,Train_acc:97.0%,Train_loss:0.136,Test_acc:85.5%,Test_loss:0.372,Lr:5.58E-05
Epoch:16,Train_acc:98.6%,Train_loss:0.125,Test_acc:85.5%,Test_loss:0.333,Lr:5.58E-05
Epoch:17,Train_acc:98.8%,Train_loss:0.121,Test_acc:85.5%,Test_loss:0.398,Lr:5.13E-05
Epoch:18,Train_acc:99.2%,Train_loss:0.113,Test_acc:85.5%,Test_loss:0.359,Lr:5.13E-05
Epoch:19,Train_acc:99.4%,Train_loss:0.100,Test_acc:82.9%,Test_loss:0.357,Lr:4.72E-05
Epoch:20,Train_acc:99.0%,Train_loss:0.095,Test_acc:85.5%,Test_loss:0.353,Lr:4.72E-05
Epoch:21,Train_acc:99.2%,Train_loss:0.091,Test_acc:82.9%,Test_loss:0.367,Lr:4.34E-05
Epoch:22,Train_acc:99.4%,Train_loss:0.082,Test_acc:85.5%,Test_loss:0.407,Lr:4.34E-05
Epoch:23,Train_acc:99.4%,Train_loss:0.086,Test_acc:86.8%,Test_loss:0.306,Lr:4.00E-05
Epoch:24,Train_acc:99.8%,Train_loss:0.082,Test_acc:86.8%,Test_loss:0.294,Lr:4.00E-05
Epoch:25,Train_acc:99.6%,Train_loss:0.077,Test_acc:86.8%,Test_loss:0.337,Lr:3.68E-05
Epoch:26,Train_acc:99.2%,Train_loss:0.075,Test_acc:85.5%,Test_loss:0.386,Lr:3.68E-05
Epoch:27,Train_acc:99.8%,Train_loss:0.068,Test_acc:85.5%,Test_loss:0.312,Lr:3.38E-05
Epoch:28,Train_acc:99.8%,Train_loss:0.066,Test_acc:85.5%,Test_loss:0.419,Lr:3.38E-05
Epoch:29,Train_acc:100.0%,Train_loss:0.059,Test_acc:82.9%,Test_loss:0.338,Lr:3.11E-05
Epoch:30,Train_acc:100.0%,Train_loss:0.063,Test_acc:85.5%,Test_loss:0.311,Lr:3.11E-05
Epoch:31,Train_acc:99.6%,Train_loss:0.059,Test_acc:86.8%,Test_loss:0.324,Lr:2.86E-05
Epoch:32,Train_acc:99.8%,Train_loss:0.060,Test_acc:84.2%,Test_loss:0.336,Lr:2.86E-05
Epoch:33,Train_acc:99.8%,Train_loss:0.054,Test_acc:85.5%,Test_loss:0.302,Lr:2.63E-05
Epoch:34,Train_acc:100.0%,Train_loss:0.053,Test_acc:82.9%,Test_loss:0.282,Lr:2.63E-05
Epoch:35,Train_acc:100.0%,Train_loss:0.052,Test_acc:86.8%,Test_loss:0.303,Lr:2.42E-05
Epoch:36,Train_acc:100.0%,Train_loss:0.052,Test_acc:86.8%,Test_loss:0.329,Lr:2.42E-05
Epoch:37,Train_acc:99.8%,Train_loss:0.051,Test_acc:85.5%,Test_loss:0.372,Lr:2.23E-05
Epoch:38,Train_acc:99.8%,Train_loss:0.048,Test_acc:85.5%,Test_loss:0.376,Lr:2.23E-05
Epoch:39,Train_acc:100.0%,Train_loss:0.049,Test_acc:86.8%,Test_loss:0.321,Lr:2.05E-05
Epoch:40,Train_acc:100.0%,Train_loss:0.045,Test_acc:85.5%,Test_loss:0.395,Lr:2.05E-05
Epoch:41,Train_acc:99.8%,Train_loss:0.046,Test_acc:85.5%,Test_loss:0.288,Lr:1.89E-05
Epoch:42,Train_acc:100.0%,Train_loss:0.044,Test_acc:86.8%,Test_loss:0.349,Lr:1.89E-05
Epoch:43,Train_acc:99.8%,Train_loss:0.045,Test_acc:85.5%,Test_loss:0.400,Lr:1.74E-05
Epoch:44,Train_acc:100.0%,Train_loss:0.041,Test_acc:85.5%,Test_loss:0.310,Lr:1.74E-05
Epoch:45,Train_acc:100.0%,Train_loss:0.042,Test_acc:85.5%,Test_loss:0.306,Lr:1.60E-05
Epoch:46,Train_acc:99.8%,Train_loss:0.041,Test_acc:85.5%,Test_loss:0.341,Lr:1.60E-05
Epoch:47,Train_acc:100.0%,Train_loss:0.038,Test_acc:85.5%,Test_loss:0.338,Lr:1.47E-05
Epoch:48,Train_acc:100.0%,Train_loss:0.040,Test_acc:84.2%,Test_loss:0.329,Lr:1.47E-05
Epoch:49,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.328,Lr:1.35E-05
Epoch:50,Train_acc:100.0%,Train_loss:0.038,Test_acc:82.9%,Test_loss:0.314,Lr:1.35E-05
Epoch:51,Train_acc:100.0%,Train_loss:0.039,Test_acc:86.8%,Test_loss:0.321,Lr:1.24E-05
Epoch:52,Train_acc:100.0%,Train_loss:0.036,Test_acc:86.8%,Test_loss:0.319,Lr:1.24E-05
Epoch:53,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.344,Lr:1.14E-05
Epoch:54,Train_acc:100.0%,Train_loss:0.035,Test_acc:86.8%,Test_loss:0.322,Lr:1.14E-05
Epoch:55,Train_acc:100.0%,Train_loss:0.037,Test_acc:86.8%,Test_loss:0.294,Lr:1.05E-05
Epoch:56,Train_acc:100.0%,Train_loss:0.036,Test_acc:86.8%,Test_loss:0.302,Lr:1.05E-05
Epoch:57,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.284,Lr:9.68E-06
Epoch:58,Train_acc:100.0%,Train_loss:0.037,Test_acc:85.5%,Test_loss:0.325,Lr:9.68E-06
Epoch:59,Train_acc:100.0%,Train_loss:0.033,Test_acc:85.5%,Test_loss:0.383,Lr:8.91E-06
Epoch:60,Train_acc:100.0%,Train_loss:0.034,Test_acc:85.5%,Test_loss:0.332,Lr:8.91E-06
Done

四、 结果可视化

1. Loss与Accuracy图

import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")            # suppress warning output
plt.rcParams['font.sans-serif']=['SimHei']   # render CJK labels correctly
plt.rcParams['axes.unicode_minus']=False     # render the minus sign correctly
plt.rcParams['figure.dpi']=300               # figure resolution

epochs_range=range(epochs)

plt.figure(figsize=(12,3))

# Left panel: accuracy curves.
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Right panel: loss curves.
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

2. 指定图片进行预测

from PIL import Image

# Class names ordered by their integer label.
# BUG FIX: the original bound this to the misspelled name `calsses`.
classes=list(train_dataset.class_to_idx)

def predict_one_image(image_path,model,transform,classes):
    """Predict and print the class of a single image file.

    image_path: path of the image to classify.
    model: trained network (relies on the module-level `device` global).
    transform: preprocessing pipeline applied before inference.
    classes: list mapping label index -> class name.
    """
    test_img=Image.open(image_path).convert('RGB')
    plt.imshow(test_img)   # display the image being classified

    test_img=transform(test_img)
    img=test_img.to(device).unsqueeze(0)   # add the batch dimension

    model.eval()
    output=model(img)

    _,pred=torch.max(output,1)
    pred_class=classes[pred]
    print(f'预测结果是:{pred_class}')

# Predict one image from the test set.
# BUG FIX: the original passed train_transforms, whose RandomHorizontalFlip
# randomly augments the image at inference time; use the deterministic
# test pipeline instead.
predict_one_image(image_path=r'D:\THE MNIST DATABASE\P5-data\test\adidas\1.jpg',
                  model=model,
                  transform=test_transforms,classes=classes)

 运行结果:

五、保存并加载模型 

# Save only the learned parameters (state_dict), not the whole module object.
PATH=r'C:\Users\Administrator\PycharmProjects\pytorchProject1\P5周:Pytorch实现运动鞋识别\model-p5.pth'
torch.save(model.state_dict(),PATH)

# Load the parameters back into the model; map_location keeps the load
# working regardless of whether the weights were saved on GPU or CPU.
model.load_state_dict(torch.load(PATH,map_location=device))

运行结果:

<All keys matched successfully>

六、个人总结

本次项目耗费很长时间在模型调整上,由于初始结果不理想,想到可能是由于卷积核5*5是否过大导致没能捕捉到敏感信息的缘故,将卷积核调整为3*3,但是结果更差,猜测可能是感受野过小不利于特征信息的捕捉,于是又将卷积核重新调整为5*5。同时,优化器修改为Adam。在此过程中,也曾将学习率修改为1e-3,但效果不理想,改回1e-4后结果达到满意状态。 

更多推荐