Implementing an LSTM in PyTorch

This post does not cover the theoretical background of LSTMs; it provides a simple reference implementation.
To build a more hands-on understanding, the LSTM cell is implemented from scratch in PyTorch together with a training loop. The data is generated synthetically: a random sequence X, with the sum of X over the time dimension used as the target Y.
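
For reference, the cell below implements the standard LSTM update equations, where $\sigma$ is the sigmoid function, $\odot$ is element-wise multiplication, and $[x_t, h_{t-1}]$ denotes concatenating the current input with the previous hidden state (each nn.Linear below includes the corresponding bias term):

$$
\begin{aligned}
f_t &= \sigma\big(W_f\,[x_t, h_{t-1}] + b_f\big) \\
i_t &= \sigma\big(W_i\,[x_t, h_{t-1}] + b_i\big) \\
\tilde{c}_t &= \tanh\big(W_c\,[x_t, h_{t-1}] + b_c\big) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \sigma\big(W_o\,[x_t, h_{t-1}] + b_o\big) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

The implementation and training loop: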

import torch
import torch.nn as nn
import torch.optim as optim


class CustomLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CustomLSTMCell, self).__init__()
        self.hidden_size = hidden_size

        # Weights and biases for the LSTM gates
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)  # forget gate
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)  # input gate
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)  # candidate cell state
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)  # output gate

    def forward(self, x, hidden):
        # Unpack the previous time step's hidden state and cell state
        h_prev, c_prev = hidden

        # Concatenate the current input with the previous hidden state
        combined = torch.cat((x, h_prev), dim=1)  # [batch_size, input_size + hidden_size]

        # 1. Forget gate
        f_t = torch.sigmoid(self.W_f(combined))  # [batch_size, hidden_size]

        # 2. Input gate
        i_t = torch.sigmoid(self.W_i(combined))  # [batch_size, hidden_size]

        # 3. Candidate cell state
        c_tilde_t = torch.tanh(self.W_c(combined))  # [batch_size, hidden_size]

        # 4. Update the cell state
        c_t = f_t * c_prev + i_t * c_tilde_t  # [batch_size, hidden_size]

        # 5. Output gate
        o_t = torch.sigmoid(self.W_o(combined))  # [batch_size, hidden_size]

        # 6. Update the hidden state
        h_t = o_t * torch.tanh(c_t)  # [batch_size, hidden_size]

        # Return the new hidden state and cell state
        return h_t, c_t

    def init_hidden(self, batch_size):
        # Initialize the hidden state and cell state to zeros
        return (torch.zeros(batch_size, self.hidden_size),
                torch.zeros(batch_size, self.hidden_size))


class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CustomLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = CustomLSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Initialize the hidden and cell states
        hidden = self.lstm_cell.init_hidden(batch_size)

        # Collect the output at every time step
        outputs = []
        for t in range(seq_len):
            hidden = self.lstm_cell(x[:, t, :], hidden)  # update the hidden and cell states for this time step
            outputs.append(hidden[0])  # keep only the hidden state

        outputs = torch.stack(outputs, dim=1)
        outputs = self.fc(outputs)
        # Return the projected outputs for all time steps
        return outputs


# Hyperparameters
input_size = 10  # dimensionality of the input features
# hidden_size = 20  # dimensionality of the hidden state
hidden_size = 50  # dimensionality of the hidden state
seq_length = 5  # sequence length
# batch_size = 3  # batch size
batch_size = 512  # batch size
num_epochs = 1000*3  # number of training epochs
learning_rate = 0.01  # learning rate
output_size = input_size
# Create the model, loss function and optimizer
model = CustomLSTM(input_size, hidden_size, output_size)
criterion = nn.MSELoss()  # mean squared error loss
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam optimizer

# Generate training data: the target is the sum of the input sequence
def generate_data(batch_size, seq_length, input_size):
    X = torch.randn(batch_size, seq_length, input_size)  # random inputs
    Y = torch.sum(X, dim=1)  # target: sum of the inputs over the time dimension
    return X, Y
import time
time1 = time.time()
# Training loop
for epoch in range(num_epochs):
    model.train()  # set the model to training mode
    optimizer.zero_grad()  # clear accumulated gradients

    # Generate a fresh batch of training data
    inputs, targets = generate_data(batch_size, seq_length, input_size)

    # Forward pass
    outputs = model(inputs)

    # Compute the loss on the last time step's output
    loss = criterion(outputs[:, -1, :], targets)  # use the output at the last time step
    loss.backward()  # backward pass
    optimizer.step()  # update parameters

    if (epoch + 1) % 100 == 0:  # print the loss every 100 epochs
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
print(time.time() - time1)

Output:

Epoch [100/3000], Loss: 0.1212
Epoch [200/3000], Loss: 0.0443
Epoch [300/3000], Loss: 0.0263
Epoch [400/3000], Loss: 0.0147
Epoch [500/3000], Loss: 0.0134
Epoch [600/3000], Loss: 0.0156
Epoch [700/3000], Loss: 0.0090
Epoch [800/3000], Loss: 0.0111
Epoch [900/3000], Loss: 0.0067
Epoch [1000/3000], Loss: 0.0072
Epoch [1100/3000], Loss: 0.0057
Epoch [1200/3000], Loss: 0.0060
Epoch [1300/3000], Loss: 0.0057
Epoch [1400/3000], Loss: 0.0048
Epoch [1500/3000], Loss: 0.0044
Epoch [1600/3000], Loss: 0.0059
Epoch [1700/3000], Loss: 0.0044
Epoch [1800/3000], Loss: 0.0044
Epoch [1900/3000], Loss: 0.0049
Epoch [2000/3000], Loss: 0.0046
Epoch [2100/3000], Loss: 0.0040
Epoch [2200/3000], Loss: 0.0044
Epoch [2300/3000], Loss: 0.0039
Epoch [2400/3000], Loss: 0.0043
Epoch [2500/3000], Loss: 0.0043
Epoch [2600/3000], Loss: 0.0046
Epoch [2700/3000], Loss: 0.0040
Epoch [2800/3000], Loss: 0.0037
Epoch [2900/3000], Loss: 0.0027
Epoch [3000/3000], Loss: 0.0035
54.482550621032715

GPU version:

import torch
import torch.nn as nn
import torch.optim as optim


class CustomLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CustomLSTMCell, self).__init__()
        self.hidden_size = hidden_size

        # Weights and biases for the LSTM gates
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)  # forget gate
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)  # input gate
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)  # candidate cell state
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)  # output gate

    def forward(self, x, hidden):
        # Unpack the previous time step's hidden state and cell state
        h_prev, c_prev = hidden

        # Concatenate the current input with the previous hidden state
        combined = torch.cat((x, h_prev), dim=1)  # [batch_size, input_size + hidden_size]

        # 1. Forget gate
        f_t = torch.sigmoid(self.W_f(combined))  # [batch_size, hidden_size]

        # 2. Input gate
        i_t = torch.sigmoid(self.W_i(combined))  # [batch_size, hidden_size]

        # 3. Candidate cell state
        c_tilde_t = torch.tanh(self.W_c(combined))  # [batch_size, hidden_size]

        # 4. Update the cell state
        c_t = f_t * c_prev + i_t * c_tilde_t  # [batch_size, hidden_size]

        # 5. Output gate
        o_t = torch.sigmoid(self.W_o(combined))  # [batch_size, hidden_size]

        # 6. Update the hidden state
        h_t = o_t * torch.tanh(c_t)  # [batch_size, hidden_size]

        # Return the new hidden state and cell state
        return h_t, c_t

    def init_hidden(self, batch_size, device):
        return (torch.zeros(batch_size, self.hidden_size, device=device),
                torch.zeros(batch_size, self.hidden_size, device=device))


class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, device):
        super(CustomLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.device = device  # device on which the hidden states are created
        self.lstm_cell = CustomLSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Initialize the hidden and cell states on the correct device
        hidden = self.lstm_cell.init_hidden(batch_size, self.device)

        # Collect the output at every time step
        outputs = []
        for t in range(seq_len):
            hidden = self.lstm_cell(x[:, t, :], hidden)  # update the hidden and cell states for this time step
            outputs.append(hidden[0])  # keep only the hidden state

        outputs = torch.stack(outputs, dim=1)
        outputs = self.fc(outputs)
        # Return the projected outputs for all time steps
        return outputs


# Hyperparameters
input_size = 10  # dimensionality of the input features
# hidden_size = 20  # dimensionality of the hidden state
hidden_size = 50  # dimensionality of the hidden state
seq_length = 5  # sequence length
# batch_size = 3  # batch size
batch_size = 512  # batch size
num_epochs = 1000*3  # number of training epochs
learning_rate = 0.01  # learning rate
output_size = input_size

# Generate training data: the target is the sum of the input sequence
def generate_data(batch_size, seq_length, input_size):
    X = torch.randn(batch_size, seq_length, input_size)  # random inputs
    Y = torch.sum(X, dim=1)  # target: sum of the inputs over the time dimension
    return X, Y

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = CustomLSTM(input_size, hidden_size, output_size,device).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
import time
time1 = time.time()
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Generate training data and move it to the GPU
    inputs, targets = generate_data(batch_size, seq_length, input_size)
    inputs, targets = inputs.to(device), targets.to(device)

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs[:, -1, :], targets)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
print(time.time() - time1)

Time:

27.915155172348022 s

At first it was puzzling why the CPU version ran faster than the GPU version. It turned out the batch size was set too small, so the GPU's parallel computing capability was never fully utilized.
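
To make the batch-size effect visible directly, here is a minimal timing sketch (not part of the original experiment; the helper name time_steps is ours). It reuses the model, optimizer, criterion, generate_data, and device defined above and times a fixed number of forward/backward passes at a small and a large batch size:

import time

def time_steps(bs, n_steps=100):
    xs, ys = generate_data(bs, seq_length, input_size)
    xs, ys = xs.to(device), ys.to(device)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
    start = time.time()
    for _ in range(n_steps):
        model.zero_grad()
        out = model(xs)
        criterion(out[:, -1, :], ys).backward()  # forward + backward only; no parameter update
    if device.type == 'cuda':
        torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    return time.time() - start

for bs in (3, 512):
    print(f'batch_size={bs}: {time_steps(bs):.2f}s')

With a tiny batch, per-step overhead (kernel launches, data transfer) dominates and the GPU cannot show its throughput advantage; with a large batch the per-step cost grows only slightly on the GPU.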

Test:

# Evaluate the model
model.eval()  # set the model to evaluation mode
loss = 0
test_num = 10
for _ in range(test_num):
    with torch.no_grad():
        test_inputs, test_targets = generate_data(batch_size, seq_length, input_size)
        print("test input:", test_inputs)
        print("test target:", test_targets)

        test_outputs = model(test_inputs)
        print("output:", test_outputs[:, -1, :])
        test_loss = criterion(test_outputs[:, -1, :], test_targets)
        print(f'Test Loss: {test_loss.item():.4f}')
        loss += test_loss.item()
print(loss / test_num)
test input: tensor([[[-0.1689,  1.2587, -0.2325,  ...,  1.1661, -0.3952,  1.3948],
         [ 0.5925, -0.2150, -0.3206,  ..., -0.1345,  0.3898, -0.1269],
         [-1.0967,  2.3059, -1.8854,  ..., -0.5927, -0.1536, -0.8220],
         [ 0.2587,  0.5456, -0.0286,  ..., -0.8626, -1.2658,  0.1615],
         [ 0.3446,  0.1702, -0.2319,  ..., -0.7598, -0.6328,  0.2581]],

        [[-1.7417,  1.0593,  1.0788,  ..., -1.6702,  1.0734,  2.0313],
         [ 1.6153, -0.2084, -0.4093,  ..., -0.5111, -0.5246, -0.4017],
         [ 2.0313, -1.3608, -0.9690,  ...,  0.0230, -1.2582, -1.1674],
         [ 1.4228, -0.9935, -1.3398,  ..., -0.1919,  0.4006, -0.9374],
         [ 0.7134, -0.9957, -0.0641,  ...,  0.0968,  1.0851,  0.0461]],

        [[-0.2515,  0.8441, -0.1456,  ..., -0.5415,  0.3201, -0.1809],
         [ 0.3923,  0.7804, -0.4870,  ...,  0.5387,  1.1762, -0.5460],
         [-0.0847, -0.5531,  0.8429,  ...,  1.0130,  0.7171,  1.1560],
         [ 0.3533, -0.4585,  1.4239,  ..., -0.0759, -0.9672, -0.3518],
         [ 0.5108, -2.0705,  1.1390,  ..., -1.5381,  0.1113, -0.2444]],

        ...,

        [[ 0.8122, -0.2227,  0.5449,  ...,  0.6472, -0.9025,  0.0963],
         [ 1.3878, -1.7086, -0.9966,  ...,  0.3764, -1.4767, -0.6054],
         [-0.1911,  1.2958,  1.5702,  ...,  0.4829, -1.1859,  0.8539],
         [-1.1988, -0.5388,  1.5019,  ..., -0.8905,  0.3147, -0.4322],
         [-0.0756,  1.1308, -1.6030,  ..., -0.6542,  0.1091, -0.5346]],

        [[ 1.9590, -0.6846,  1.4058,  ..., -1.3380,  0.7594, -0.8432],
         [-1.2274, -0.3200,  0.8222,  ...,  0.4213, -0.0495, -0.3112],
         [-0.1699, -0.3631, -0.1957,  ..., -0.4369, -1.1797, -0.9325],
         [-0.2505, -0.7296,  0.4274,  ..., -0.3892,  0.7507,  0.7931],
         [ 0.5962,  0.1331, -1.4419,  ...,  0.2037,  0.0616,  0.3307]],

        [[-2.1913, -0.0430, -1.0888,  ..., -0.5875, -0.3072, -1.0826],
         [-0.1817,  0.2678, -1.0504,  ...,  2.8191, -1.4597,  0.8648],
         [-0.0483, -0.4606, -0.2773,  ...,  1.2193,  0.4859, -1.0354],
         [ 1.3227,  0.1481,  0.8500,  ...,  0.1011, -0.4418, -0.5827],
         [-1.4760, -1.0485,  1.9185,  ..., -0.4341,  0.4932, -0.3094]]])
test target: tensor([[-0.0698,  4.0654, -2.6990,  ..., -1.1835, -2.0576,  0.8656],
        [ 4.0411, -2.4991, -1.7034,  ..., -2.2535,  0.7763, -0.4292],
        [ 0.9201, -1.4576,  2.7733,  ..., -0.6038,  1.3575, -0.1672],
        ...,
        [ 0.7344, -0.0435,  1.0175,  ..., -0.0382, -3.1412, -0.6220],
        [ 0.9074, -1.9642,  1.0178,  ..., -1.5390,  0.3424, -0.9630],
        [-2.5746, -1.1362,  0.3520,  ...,  3.1179, -1.2296, -2.1453]])
output: tensor([[-0.0845,  4.0455, -2.7100,  ..., -1.1685, -1.9360,  0.9110],
        [ 4.1641, -2.5004, -1.6384,  ..., -2.0929,  0.7764, -0.4944],
        [ 0.9523, -1.4426,  2.8670,  ..., -0.6102,  1.3349, -0.1365],
        ...,
        [ 0.8122, -0.0629,  1.0134,  ..., -0.0479, -3.1589, -0.5850],
        [ 0.9086, -2.0162,  0.9303,  ..., -1.5335,  0.3185, -0.9241],
        [-2.5714, -1.1214,  0.3246,  ...,  3.1120, -1.2731, -2.1153]])
Test Loss: 0.0034
test input: tensor([[[ 2.5929,  0.1978,  1.5500,  ..., -0.3944,  0.3159,  0.8025],
         [-0.8422, -0.1749,  1.3274,  ...,  0.2725,  2.3107, -0.5987],
         [ 0.0493,  0.4596, -0.2925,  ...,  0.3553, -0.0602, -0.3874],
         [-0.1100, -0.8498, -1.5901,  ...,  1.1603,  0.3895, -1.2063],
         [ 2.1411,  1.2079, -1.2222,  ...,  0.9933,  0.9691,  1.1165]],

        [[ 0.4969,  0.2073,  0.6287,  ..., -2.1063, -1.4328, -0.8577],
         [-0.0828, -1.4124,  0.4884,  ...,  1.9454, -0.1868, -0.7624],
         [-0.4180, -0.4282, -0.9868,  ...,  0.5704,  1.1960,  1.5377],
         [-0.2925,  1.4883,  1.0165,  ..., -1.2878, -0.8021,  0.4670],
         [ 0.5026,  1.7186, -0.3610,  ...,  0.5500,  0.8758, -0.4046]],

        [[-1.1402,  0.6832,  1.1964,  ..., -0.3867, -0.5885,  0.2191],
         [-1.2404,  1.7585, -3.0369,  ...,  0.1347,  0.9355, -0.1765],
         [-0.1398, -0.0176,  0.2003,  ..., -1.2592, -0.3214, -0.6245],
         [ 0.2821,  0.6852, -2.8414,  ...,  0.4883,  0.1187, -1.1225],
         [-0.0315, -0.9387, -0.8465,  ..., -0.1643, -0.8266,  1.2745]],

        ...,

        [[ 0.4313, -1.2559, -1.7052,  ...,  0.0113, -0.3625,  0.2509],
         [-0.1438,  0.2636,  1.3472,  ...,  0.1038, -0.8130, -1.1147],
         [ 0.6887,  1.0949, -0.1862,  ..., -0.3261,  1.8613,  0.6097],
         [ 0.0900,  0.6062,  0.2841,  ..., -1.4473,  0.3375, -1.1272],
         [-1.7171, -0.3127,  1.2086,  ...,  0.4012,  0.5875,  1.5758]],

        [[-0.7768, -0.1033,  1.7666,  ..., -0.4122,  0.0196,  1.1098],
         [ 1.0152, -0.0262,  0.4504,  ..., -0.3686, -1.5060,  0.5859],
         [ 0.2382,  0.8173,  1.5905,  ...,  0.8005,  0.4086, -1.1080],
         [ 0.7557,  0.5851, -0.2964,  ..., -0.0260,  0.0689, -0.8201],
         [ 0.2370, -2.1839,  0.5807,  ...,  1.0797, -0.5843, -0.0906]],

        [[ 0.4415, -1.0362,  0.8147,  ...,  0.2035, -0.6759,  0.8420],
         [ 0.1552,  0.0619, -0.6936,  ...,  0.2109, -1.1705, -1.1882],
         [ 1.2606,  0.0583, -0.7055,  ..., -0.6487, -1.2045, -0.3974],
         [ 0.3006, -0.3409, -0.2717,  ..., -0.0874, -1.0602, -0.4099],
         [-2.6576,  0.9069,  0.0059,  ...,  1.8999, -1.9768,  2.5610]]])
test target: tensor([[ 3.8311,  0.8406, -0.2274,  ...,  2.3870,  3.9249, -0.2734],
        [ 0.2063,  1.5737,  0.7858,  ..., -0.3283, -0.3498, -0.0200],
        [-2.2698,  2.1706, -5.3281,  ..., -1.1872, -0.6822, -0.4299],
        ...,
        [-0.6510,  0.3960,  0.9485,  ..., -1.2571,  1.6108,  0.1945],
        [ 1.4693, -0.9111,  4.0918,  ...,  1.0733, -1.5932, -0.3230],
        [-0.4997, -0.3500, -0.8503,  ...,  1.5781, -6.0879,  1.4075]])
output: tensor([[ 3.8252,  0.8959, -0.2285,  ...,  2.4706,  3.9936, -0.2670],
        [ 0.2188,  1.6810,  0.8445,  ..., -0.2575, -0.3024,  0.0307],
        [-2.2365,  2.0328, -5.1632,  ..., -1.2063, -0.6941, -0.3935],
        ...,
        [-0.6253,  0.4530,  0.9328,  ..., -1.2882,  1.6239,  0.2207],
        [ 1.5137, -0.8888,  4.0510,  ...,  1.1423, -1.6135, -0.2541],
        [-0.3872, -0.4619, -0.8793,  ...,  1.6329, -6.0644,  1.3579]])
Test Loss: 0.0033
test input: tensor([[[-0.2673,  1.0215,  0.7980,  ..., -0.2239, -1.5850, -0.7097],
         [ 0.3957,  0.7162, -1.0691,  ...,  0.2742, -0.2141, -1.1319],
         [-1.8952,  0.8279, -1.6393,  ...,  0.5658, -0.2553,  1.6003],
         [-0.1973, -0.0216, -0.0057,  ..., -0.1565, -1.3231, -0.1084],
         [-0.5937, -0.6538, -0.6966,  ..., -0.4609, -0.3213, -0.8327]],

        [[ 0.3793, -0.4352, -0.1368,  ...,  1.8472,  0.0512, -0.3820],
         [-0.2910,  1.4615, -0.9674,  ..., -0.4545, -2.4213, -0.0293],
         [ 0.0919,  0.0434, -1.3971,  ..., -1.2369,  0.3955,  0.0068],
         [-0.7757,  0.2856, -0.1693,  ..., -0.6219, -1.1484, -0.2100],
         [-0.5417,  1.2803, -0.5744,  ..., -0.1622, -0.2365, -0.4435]],

        [[-0.4802, -0.0777, -1.2657,  ..., -0.4536,  0.8761, -0.8684],
         [ 1.1426, -1.2437,  0.7745,  ..., -1.7280,  0.5387,  0.7279],
         [-0.6429, -0.5482, -0.1431,  ...,  0.5178, -0.1188, -1.3464],
         [-0.1955, -1.3221, -0.3209,  ..., -0.4542, -1.6140, -0.1808],
         [-1.0118,  0.5614,  0.0052,  ..., -0.5933, -1.5626, -0.5758]],

        ...,

        [[-0.2545,  1.0486,  1.3051,  ...,  0.6358,  0.3123,  0.0643],
         [-0.6198,  1.0159,  0.7555,  ..., -1.2070,  0.3430,  0.2573],
         [ 0.4471,  0.0880,  1.0887,  ..., -0.1341,  1.2339, -0.8415],
         [ 0.1029,  0.2686,  0.7183,  ..., -1.1201,  0.7749, -0.3663],
         [ 0.8686,  1.1721,  0.5444,  ..., -0.1482, -0.3150, -0.0523]],

        [[ 0.9771,  0.0061, -1.9452,  ..., -0.9623, -0.8364, -2.0293],
         [-0.4265,  0.5097, -0.2101,  ..., -1.8497,  1.5712,  1.6646],
         [ 1.0189, -0.7612,  1.9286,  ...,  0.6257, -1.1704,  1.2700],
         [-0.9839, -0.5303, -0.1938,  ..., -0.1502,  0.3275,  0.5457],
         [ 0.9025, -0.4543,  1.2847,  ...,  1.0576, -0.8189, -0.9613]],

        [[-0.8464, -0.9731,  0.1420,  ..., -0.4450,  0.9471,  1.2521],
         [-1.2145, -0.4421, -1.0015,  ...,  1.1457, -0.1939, -1.5541],
         [-0.9773, -0.3849, -1.9078,  ...,  0.7324,  0.1202, -1.7172],
         [ 1.6711, -1.2261,  1.1563,  ...,  1.4700, -0.3114,  1.5038],
         [-0.1961,  1.1472, -0.5389,  ...,  0.7941,  0.8100, -0.0857]]])
test target: tensor([[-2.5578e+00,  1.8902e+00, -2.6127e+00,  ..., -1.2444e-03,
         -3.6988e+00, -1.1824e+00],
        [-1.1373e+00,  2.6356e+00, -3.2449e+00,  ..., -6.2832e-01,
         -3.3594e+00, -1.0580e+00],
        [-1.1879e+00, -2.6303e+00, -9.5005e-01,  ..., -2.7113e+00,
         -1.8805e+00, -2.2436e+00],
        ...,
        [ 5.4419e-01,  3.5932e+00,  4.4119e+00,  ..., -1.9736e+00,
          2.3491e+00, -9.3851e-01],
        [ 1.4880e+00, -1.2301e+00,  8.6410e-01,  ..., -1.2789e+00,
         -9.2703e-01,  4.8969e-01],
        [-1.5633e+00, -1.8789e+00, -2.1499e+00,  ...,  3.6971e+00,
          1.3721e+00, -6.0111e-01]])
output: tensor([[-2.5923,  1.8687, -2.6293,  ...,  0.0315, -3.7254, -1.1840],
        [-1.1362,  2.5613, -3.1926,  ..., -0.5891, -3.3555, -1.0394],
        [-1.1493, -2.6889, -0.9138,  ..., -2.7508, -1.9224, -2.2159],
        ...,
        [ 0.5671,  3.5559,  4.3077,  ..., -1.8404,  2.3559, -0.9059],
        [ 1.5430, -1.1827,  0.8804,  ..., -1.2427, -0.9392,  0.4725],
        [-1.5820, -1.8845, -2.2152,  ...,  3.7788,  1.3970, -0.5757]])
Test Loss: 0.0031
Average loss: 0.003249840810894966

Code improvements:

Add validation

Validation:

  • Validation reflects how well the trained model generalizes, and the validation loss can be used to drive an early-stopping mechanism.

  • Adding a validation set does not by itself improve the model's generalization. Its role is to evaluate the model on unseen data and reveal over- or under-fitting; it does not directly affect training.

  • For the data in this experiment, a single fixed validation set is sufficient:

    # Generate the validation set once and reuse it for all epochs
    validation_batch_size = 256
    val_inputs, val_targets = generate_data(validation_batch_size, seq_length, input_size)
    val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
    

Early stopping

  • If the validation loss stops decreasing, training can be stopped early.
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        """
        Early-stopping helper.
        :param patience: number of validation checks with no improvement to wait before stopping
        :param min_delta: minimum decrease in loss that counts as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        # Initialize the best loss on the first call
        if self.best_loss is None:
            self.best_loss = val_loss
        # No sufficient improvement: increment the counter and possibly stop
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
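
The loop below assumes an EarlyStopping instance already exists. The exact arguments used for the logged run are not shown in the original; a plausible instantiation is:

early_stopping = EarlyStopping(patience=10, min_delta=0)

Note that patience is counted in validation evaluations (here, one every 100 epochs), not in epochs.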

Training loop:

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Generate training data and move it to the GPU
    inputs, targets = generate_data(batch_size, seq_length, input_size)
    inputs, targets = inputs.to(device), targets.to(device)

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs[:, -1, :], targets)
    loss.backward()
    optimizer.step()
    # Evaluate on the validation set every 100 epochs
    if (epoch + 1) % 100 == 0:
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs[:, -1, :], val_targets)
        # Early-stopping check
        early_stopping(val_loss.item())
        print(early_stopping.best_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered; stopping training.")
            break

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')

Output:

cuda
0.11180789768695831
Epoch [100/3000], Loss: 0.1160, Val Loss: 0.1118
0.038831677287817
Epoch [200/3000], Loss: 0.0419, Val Loss: 0.0388
0.02261391095817089
Epoch [300/3000], Loss: 0.0228, Val Loss: 0.0226
0.015386641025543213
Epoch [400/3000], Loss: 0.0162, Val Loss: 0.0154
0.011308404617011547
Epoch [500/3000], Loss: 0.0139, Val Loss: 0.0113
0.009909652173519135
Epoch [600/3000], Loss: 0.0101, Val Loss: 0.0099
0.00933060236275196
Epoch [700/3000], Loss: 0.0090, Val Loss: 0.0093
0.007593141403049231
Epoch [800/3000], Loss: 0.0076, Val Loss: 0.0076
0.007593141403049231
Epoch [900/3000], Loss: 0.0083, Val Loss: 0.0076
0.006765467580407858
Epoch [1000/3000], Loss: 0.0073, Val Loss: 0.0068
0.006765467580407858
Epoch [1100/3000], Loss: 0.0069, Val Loss: 0.0078
0.005835698451846838
Epoch [1200/3000], Loss: 0.0062, Val Loss: 0.0058
0.00552908843383193
Epoch [1300/3000], Loss: 0.0071, Val Loss: 0.0055
0.005006270948797464
Epoch [1400/3000], Loss: 0.0059, Val Loss: 0.0050
0.005006270948797464
Epoch [1500/3000], Loss: 0.0054, Val Loss: 0.0053
0.004516806453466415
Epoch [1600/3000], Loss: 0.0057, Val Loss: 0.0045
0.004516806453466415
Epoch [1700/3000], Loss: 0.0046, Val Loss: 0.0048
0.004516806453466415
Epoch [1800/3000], Loss: 0.0055, Val Loss: 0.0051
0.003946827724575996
Epoch [1900/3000], Loss: 0.0043, Val Loss: 0.0039
0.003946827724575996
Epoch [2000/3000], Loss: 0.0060, Val Loss: 0.0050
0.003946827724575996
Epoch [2100/3000], Loss: 0.0061, Val Loss: 0.0060
0.003946827724575996
Early stopping triggered; stopping training.
17.569655656814575

For GPU testing, remember to move the test data to the device with to(device):

test_inputs=test_inputs.to(device)
test_targets=test_targets.to(device)

Full code with a learning-rate scheduler and model checkpointing:

import torch
import torch.nn as nn
import torch.optim as optim


class CustomLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CustomLSTMCell, self).__init__()
        self.hidden_size = hidden_size

        # Weights and biases for the LSTM gates
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)  # forget gate
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)  # input gate
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)  # candidate cell state
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)  # output gate

    def forward(self, x, hidden):
        # Unpack the previous time step's hidden state and cell state
        h_prev, c_prev = hidden

        # Concatenate the current input with the previous hidden state
        combined = torch.cat((x, h_prev), dim=1)  # [batch_size, input_size + hidden_size]

        # 1. Forget gate
        f_t = torch.sigmoid(self.W_f(combined))  # [batch_size, hidden_size]

        # 2. Input gate
        i_t = torch.sigmoid(self.W_i(combined))  # [batch_size, hidden_size]

        # 3. Candidate cell state
        c_tilde_t = torch.tanh(self.W_c(combined))  # [batch_size, hidden_size]

        # 4. Update the cell state
        c_t = f_t * c_prev + i_t * c_tilde_t  # [batch_size, hidden_size]

        # 5. Output gate
        o_t = torch.sigmoid(self.W_o(combined))  # [batch_size, hidden_size]

        # 6. Update the hidden state
        h_t = o_t * torch.tanh(c_t)  # [batch_size, hidden_size]

        # Return the new hidden state and cell state
        return h_t, c_t

    def init_hidden(self, batch_size, device):
        return (torch.zeros(batch_size, self.hidden_size, device=device),
                torch.zeros(batch_size, self.hidden_size, device=device))


class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, device):
        super(CustomLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.device = device  # device on which the hidden states are created
        self.lstm_cell = CustomLSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Initialize the hidden and cell states on the correct device
        hidden = self.lstm_cell.init_hidden(batch_size, self.device)

        # Collect the output at every time step
        outputs = []
        for t in range(seq_len):
            hidden = self.lstm_cell(x[:, t, :], hidden)  # update the hidden and cell states for this time step
            outputs.append(hidden[0])  # keep only the hidden state

        outputs = torch.stack(outputs, dim=1)
        outputs = self.fc(outputs)
        # Return the projected outputs for all time steps
        return outputs
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0, save_path='best_model.pth'):
        """
        Early-stopping helper that also checkpoints the best model.
        :param patience: number of validation checks with no improvement to wait before stopping
        :param min_delta: minimum decrease in loss that counts as an improvement
        :param save_path: where to save the best model's state_dict
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.save_path = save_path

    def __call__(self, val_loss, model):
        # Initialize the best loss on the first call and checkpoint the model
        if self.best_loss is None:
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.save_path)
        # No sufficient improvement: increment the counter and possibly stop
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)  # only checkpoint when the loss improves

# Hyperparameters
input_size = 10  # dimensionality of the input features
# hidden_size = 20  # dimensionality of the hidden state
hidden_size = 50  # dimensionality of the hidden state
seq_length = 5  # sequence length
# batch_size = 3  # batch size
batch_size = 512  # batch size
num_epochs = 1000*3  # number of training epochs
learning_rate = 0.01  # learning rate
output_size = input_size
patience = 30
early_stopping = EarlyStopping(patience=patience, min_delta=0, save_path='best_model.pth')
# Generate training data: the target is the sum of the input sequence
def generate_data(batch_size, seq_length, input_size):
    X = torch.randn(batch_size, seq_length, input_size)  # random inputs
    Y = torch.sum(X, dim=1)  # target: sum of the inputs over the time dimension
    return X, Y

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = CustomLSTM(input_size, hidden_size, output_size,device).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Learning-rate scheduler: halve the LR when the validation loss plateaus
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

# Generate the validation set once and reuse it for all epochs
validation_batch_size = 256
val_inputs, val_targets = generate_data(validation_batch_size, seq_length, input_size)
val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)

import time
time1 = time.time()
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Generate training data and move it to the GPU
    inputs, targets = generate_data(batch_size, seq_length, input_size)
    inputs, targets = inputs.to(device), targets.to(device)

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs[:, -1, :], targets)
    loss.backward()
    optimizer.step()
    # Evaluate on the validation set every 10 epochs
    if (epoch + 1) % 10 == 0:
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs[:, -1, :], val_targets)
        # Early-stopping check (also checkpoints the best model)
        early_stopping(val_loss.item(), model)
        # print(early_stopping.best_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered; stopping training.")
            break

        # print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')
        scheduler.step(val_loss)
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')

print(time.time() - time1)

model.load_state_dict(torch.load('best_model.pth'))
# Evaluate the model using the best checkpoint
model.eval()  # set the model to evaluation mode
loss = 0
test_num = 3
for _ in range(test_num):
    with torch.no_grad():
        test_inputs, test_targets = generate_data(batch_size, seq_length, input_size)
        test_inputs = test_inputs.to(device)
        test_targets = test_targets.to(device)
        print("test input:", test_inputs)
        print("test target:", test_targets)
        test_outputs = model(test_inputs)
        print("output:", test_outputs[:, -1, :])
        test_loss = criterion(test_outputs[:, -1, :], test_targets)
        print(f'Test Loss: {test_loss.item():.4f}')
        loss += test_loss.item()
print(loss / test_num)

Output:

cuda
Epoch [100/3000], Loss: 0.1178, Val Loss: 0.1067
Epoch [200/3000], Loss: 0.0410, Val Loss: 0.0354
Epoch [300/3000], Loss: 0.0242, Val Loss: 0.0206
Epoch [400/3000], Loss: 0.0171, Val Loss: 0.0158
Epoch 00043: reducing learning rate of group 0 to 5.0000e-03.
Epoch [500/3000], Loss: 0.0116, Val Loss: 0.0095
Epoch [600/3000], Loss: 0.0097, Val Loss: 0.0078
Epoch [700/3000], Loss: 0.0083, Val Loss: 0.0068
Epoch [800/3000], Loss: 0.0078, Val Loss: 0.0059
Epoch 00085: reducing learning rate of group 0 to 2.5000e-03.
Epoch [900/3000], Loss: 0.0062, Val Loss: 0.0049
Epoch [1000/3000], Loss: 0.0054, Val Loss: 0.0047
Epoch [1100/3000], Loss: 0.0051, Val Loss: 0.0044
Epoch [1200/3000], Loss: 0.0054, Val Loss: 0.0042
Epoch [1300/3000], Loss: 0.0050, Val Loss: 0.0041
Epoch [1400/3000], Loss: 0.0053, Val Loss: 0.0038
Epoch [1500/3000], Loss: 0.0044, Val Loss: 0.0041
Epoch [1600/3000], Loss: 0.0040, Val Loss: 0.0035
Epoch 00163: reducing learning rate of group 0 to 1.2500e-03.
Epoch [1700/3000], Loss: 0.0041, Val Loss: 0.0030
Epoch [1800/3000], Loss: 0.0038, Val Loss: 0.0030
Epoch 00181: reducing learning rate of group 0 to 6.2500e-04.
Epoch [1900/3000], Loss: 0.0035, Val Loss: 0.0028
Epoch [2000/3000], Loss: 0.0036, Val Loss: 0.0027
Epoch [2100/3000], Loss: 0.0031, Val Loss: 0.0027
Epoch [2200/3000], Loss: 0.0035, Val Loss: 0.0026
Epoch [2300/3000], Loss: 0.0031, Val Loss: 0.0026
Epoch 00239: reducing learning rate of group 0 to 3.1250e-04.
Epoch [2400/3000], Loss: 0.0032, Val Loss: 0.0025
Epoch 00249: reducing learning rate of group 0 to 1.5625e-04.
Epoch [2500/3000], Loss: 0.0031, Val Loss: 0.0024
Epoch 00259: reducing learning rate of group 0 to 7.8125e-05.
Epoch [2600/3000], Loss: 0.0039, Val Loss: 0.0024
Epoch [2700/3000], Loss: 0.0030, Val Loss: 0.0024
Epoch 00271: reducing learning rate of group 0 to 3.9063e-05.
Epoch 00277: reducing learning rate of group 0 to 1.9531e-05.
Epoch [2800/3000], Loss: 0.0030, Val Loss: 0.0024
Epoch 00283: reducing learning rate of group 0 to 9.7656e-06.
Epoch 00289: reducing learning rate of group 0 to 4.8828e-06.
Epoch [2900/3000], Loss: 0.0034, Val Loss: 0.0024
Early stopping triggered; stopping training.
24.15766954421997
