import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import gym
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

class PPONetwork(nn.Module):
    """
    PPO网络结构,包含策略网络(Actor)和价值网络(Critic)
    输入: 状态 (state_dim,)
    输出: 动作概率 (action_dim,), 状态价值 (1,)
    """
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(PPONetwork, self).__init__()
        # Shared feature-extraction layer
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        
        # Actor head - outputs action probabilities
        self.fc_actor = nn.Linear(hidden_dim, action_dim)
        
        # Critic head - outputs the state value
        self.fc_critic = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        # x shape: (batch_size, state_dim) or (state_dim,)
        if len(x.shape) == 1:
            x = x.unsqueeze(0)  # add a batch dimension so x is 2-D
        
        x = F.relu(self.fc1(x))  # (batch_size, hidden_dim)
        
        # Action probabilities, shape (batch_size, action_dim)
        action_probs = F.softmax(self.fc_actor(x), dim=-1)
        
        # State values, shape (batch_size, 1)
        state_values = self.fc_critic(x)
        
        return action_probs, state_values
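
# A minimal sanity-check sketch (not part of the original post): instantiate the
# network with assumed CartPole-like dimensions and confirm the output shapes.
# `_check_network_shapes` is a hypothetical helper added purely for illustration.
def _check_network_shapes(state_dim=4, action_dim=2):
    net = PPONetwork(state_dim, action_dim)
    probs, values = net(torch.randn(8, state_dim))            # batched input
    assert probs.shape == (8, action_dim) and values.shape == (8, 1)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(8))   # each row is a distribution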

def compute_gae(rewards, values, dones, next_values, gamma=0.99, lambda_=0.95):
    """
    计算广义优势估计(Generalized Advantage Estimation, GAE)
    
    参数:
        rewards: 奖励序列,形状 (T,)
        values: 状态价值序列,形状 (T,)
        dones: 终止标志序列,形状 (T,)
        next_values: 下一个状态价值序列,形状 (T,)
        gamma: 折扣因子
        lambda_: GAE参数
        
    返回:
        advantages: 优势函数序列,形状 (T,)
        returns: 回报序列,形状 (T,)
    """
    advantages = torch.zeros_like(rewards)  # (T,)
    gae = 0
    for t in reversed(range(len(rewards))):
        # TD error (delta)
        delta = rewards[t] + gamma * next_values[t] * (1 - dones[t]) - values[t]
        # Accumulate GAE recursively
        gae = delta + gamma * lambda_ * (1 - dones[t]) * gae
        advantages[t] = gae
    # Returns used as value-function targets
    returns = advantages + values  # (T,)
    return advantages, returns
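
# A quick hand-check sketch (toy numbers assumed for this note, not from the
# original post): for a single-step trajectory GAE reduces to the one-step TD
# error A_0 = r_0 + gamma * V(s_1) - V(s_0).
def _check_gae_single_step():
    adv, ret = compute_gae(
        rewards=torch.tensor([1.0]),
        values=torch.tensor([0.5]),
        dones=torch.tensor([0.0]),
        next_values=torch.tensor([0.4]),
        gamma=0.99,
        lambda_=0.95,
    )
    # delta = 1.0 + 0.99 * 0.4 - 0.5 = 0.896, so advantage = 0.896, return = 1.396
    assert torch.isclose(adv[0], torch.tensor(0.896), atol=1e-5)
    assert torch.isclose(ret[0], torch.tensor(1.396), atol=1e-5)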

def ppo_update(ppo_net, optimizer, states, actions, old_log_probs, 
               advantages, returns, clip_epsilon=0.2, c1=0.5, c2=0.01):
    """
    PPO算法更新步骤
    
    参数:
        ppo_net: PPO网络
        optimizer: 优化器
        states: 状态序列,形状 (B, state_dim)
        actions: 动作序列,形状 (B,)
        old_log_probs: 旧策略的对数概率,形状 (B,)
        advantages: 优势函数序列,形状 (B,)
        returns: 回报序列,形状 (B,)
        clip_epsilon: 裁剪参数
        c1: 价值损失系数
        c2: 熵系数
        
    返回:
        policy_loss: 策略损失值
        value_loss: 价值损失值
        entropy: 熵值
    """
    # Convert inputs to tensors
    states = torch.FloatTensor(states)  # (B, state_dim)
    actions = torch.LongTensor(actions)  # (B,)
    old_log_probs = torch.FloatTensor(old_log_probs)  # (B,)
    advantages = torch.FloatTensor(advantages)  # (B,)
    returns = torch.FloatTensor(returns)  # (B,)
    
    # Action probabilities and state values under the current (new) policy
    new_action_probs, state_values = ppo_net(states)  # (B, action_dim), (B, 1)
    dist = Categorical(new_action_probs)
    new_log_probs = dist.log_prob(actions)  # (B,)
    
    # Importance-sampling ratio
    ratio = (new_log_probs - old_log_probs).exp()  # (B,)
    
    # Clipped surrogate policy loss
    surr1 = ratio * advantages  # (B,)
    surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages  # (B,)
    policy_loss = -torch.min(surr1, surr2).mean()  # scalar
    
    # Value-function loss; squeeze(-1) keeps the shape (B,) even when B == 1
    value_loss = F.mse_loss(state_values.squeeze(-1), returns)  # scalar
    
    # Entropy bonus
    entropy = dist.entropy().mean()  # scalar
    
    # Total loss
    total_loss = policy_loss + c1 * value_loss - c2 * entropy
    
    # Backpropagation and parameter update
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    return policy_loss.item(), value_loss.item(), entropy.item()
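
# A small illustration sketch (toy values, assumed for this note): the clipped
# surrogate caps the ratio at [1 - eps, 1 + eps], so once the new policy moves
# too far from the old one, the objective stops rewarding further movement.
def _clip_demo(eps=0.2):
    ratio = torch.tensor([0.5, 1.0, 1.5])                      # new_prob / old_prob
    adv = torch.tensor([1.0, 1.0, 1.0])                        # positive advantages
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv
    return -torch.min(surr1, surr2)                            # tensor([-0.5, -1.0, -1.2])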

def train_ppo(env_name="CartPole-v1", num_episodes=500, max_steps=1000, 
              gamma=0.99, lambda_=0.95, clip_epsilon=0.2, 
              update_epochs=4, batch_size=64, hidden_dim=64):
    """
    PPO训练函数
    
    参数:
        env_name: 环境名称
        num_episodes: 训练的总回合数
        max_steps: 每个回合的最大步数
        gamma: 折扣因子
        lambda_: GAE参数
        clip_epsilon: 裁剪参数
        update_epochs: 每次数据收集后的更新轮数
        batch_size: 小批量大小
        hidden_dim: 网络隐藏层维度
    """
    # Create the environment (this code assumes the classic gym API, gym < 0.26,
    # where reset() returns only the state and step() returns 4 values)
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]  # state dimensionality
    action_dim = env.action_space.n  # number of discrete actions
    
    # Initialize the PPO agent
    ppo_net = PPONetwork(state_dim, action_dim, hidden_dim)
    optimizer = optim.Adam(ppo_net.parameters(), lr=0.001)
    
    # Training statistics
    episode_rewards = []
    moving_avg_rewards = []
    
    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        states = []
        actions = []
        rewards = []
        dones = []
        values = []
        log_probs = []
        
        # Collect a trajectory
        for step in range(max_steps):
            # Select an action
            state_tensor = torch.FloatTensor(state)  # (state_dim,)
            with torch.no_grad():
                action_probs, value = ppo_net(state_tensor)  # (action_dim,), (1,)
                dist = Categorical(action_probs)
                action = dist.sample().item()  # scalar
                log_prob = dist.log_prob(torch.tensor(action))  # scalar
            
            # Step the environment
            next_state, reward, done, _ = env.step(action)
            
            # Store the transition
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            values.append(value.item())
            log_probs.append(log_prob.item())
            
            state = next_state
            episode_reward += reward
            
            if done:
                break
        
        # Bootstrap the value of the final state
        with torch.no_grad():
            _, next_value = ppo_net(torch.FloatTensor(state))  # (1,)
            next_values = values[1:] + [next_value.item()]  # next_values[t] = V(s_{t+1}); last step is bootstrapped
        
        # Compute GAE advantages and returns
        advantages, returns = compute_gae(
            torch.FloatTensor(rewards),  # (T,)
            torch.FloatTensor(values),  # (T,)
            torch.FloatTensor(dones),  # (T,)
            torch.FloatTensor(next_values),  # (T,)
            gamma, 
            lambda_
        )
        
        # Convert to numpy arrays
        states = np.array(states)  # (T, state_dim)
        actions = np.array(actions)  # (T,)
        old_log_probs = np.array(log_probs)  # (T,)
        advantages = advantages.numpy()  # (T,)
        returns = returns.numpy()  # (T,)
        
        # Normalize the advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Update the network for several epochs
        for _ in range(update_epochs):
            # Shuffle the data
            indices = np.arange(len(states))
            np.random.shuffle(indices)
            
            # Mini-batch updates
            for start in range(0, len(states), batch_size):
                end = start + batch_size
                batch_indices = indices[start:end]
                
                policy_loss, value_loss, entropy = ppo_update(
                    ppo_net, optimizer,
                    states[batch_indices],
                    actions[batch_indices],
                    old_log_probs[batch_indices],
                    advantages[batch_indices],
                    returns[batch_indices],
                    clip_epsilon
                )
        
        # Record rewards
        episode_rewards.append(episode_reward)
        moving_avg = np.mean(episode_rewards[-100:])  # average reward over the last 100 episodes
        moving_avg_rewards.append(moving_avg)
        
        # Print training progress
        if (episode + 1) % 10 == 0:
            print(f"Episode {episode+1}, Reward: {episode_reward}, Moving Avg: {moving_avg:.1f}, "
                  f"Policy Loss: {policy_loss:.3f}, Value Loss: {value_loss:.3f}, Entropy: {entropy:.3f}")
    
    # Plot the training curves
    plt.figure(figsize=(10, 5))
    plt.plot(episode_rewards, label='Episode Reward')
    plt.plot(moving_avg_rewards, label='Moving Avg (100)')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('PPO Training Performance')
    plt.legend()
    plt.show()
    
    return ppo_net

def test_agent(agent, env_name="CartPole-v1", num_episodes=10, render=False):
    """
    测试训练好的智能体
    
    参数:
        agent: 训练好的PPO网络
        env_name: 环境名称
        num_episodes: 测试回合数
        render: 是否渲染环境
    """
    env = gym.make(env_name)
    total_rewards = []
    
    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        
        while not done:
            if render:
                env.render()
            
            state_tensor = torch.FloatTensor(state)
            with torch.no_grad():
                action_probs, _ = agent(state_tensor)
                action = torch.argmax(action_probs).item()
            
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            state = next_state
        
        total_rewards.append(total_reward)
        print(f"Test Episode {episode+1}: Total Reward = {total_reward}")
    
    env.close()
    print(f"Average Reward over {num_episodes} episodes: {np.mean(total_rewards):.1f}")

# Train the PPO agent
ppo_net = train_ppo()

# Evaluate the trained agent
test_agent(ppo_net, render=True)
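
# Optional sketch (not in the original post): persist and reload the trained
# weights with the standard torch state_dict API. The file name and the
# CartPole-v1 dimensions (state_dim=4, action_dim=2) are assumptions.
torch.save(ppo_net.state_dict(), "ppo_cartpole.pt")
loaded_net = PPONetwork(state_dim=4, action_dim=2)
loaded_net.load_state_dict(torch.load("ppo_cartpole.pt"))
loaded_net.eval()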
