1 Installation

(1) Install MuJoCo 210
Installing MuJoCo 210 on Ubuntu 20.04
(2) Install gymnasium-robotics

pip install gymnasium-robotics[mujoco-py]

(3) Install Stable-Baselines3
Install the PyTorch build that matches your GPU driver / CUDA version:

pip install --index-url https://download.pytorch.org/whl/cu126 torch torchvision torchaudio --upgrade
pip install stable-baselines3
pip install sb3-contrib
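To confirm that the CUDA build of torch is actually picked up, a quick sanity check (plain PyTorch, nothing specific to this setup):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"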

(4) Other packages

pip install tensorboard
pip install h5py
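A minimal smoke test of the installation (a sketch; it uses the FetchReach-v3 ID that the SAC script below also uses, and the version suffix depends on your gymnasium-robotics release):

import gymnasium as gym
import gymnasium_robotics

gym.register_envs(gymnasium_robotics)

env = gym.make("FetchReach-v3")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs["observation"].shape, reward)
env.close()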

2 Code

2.1 SAC algorithm

import gymnasium as gym
import os
import signal
import sys

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.her import HerReplayBuffer

# Custom CheckpointCallback that also saves the replay buffer
class CustomCheckpointCallback(CheckpointCallback):
    def _on_step(self) -> bool:
        if self.n_calls % self.save_freq == 0:
            model_path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
            self.model.save(model_path)
            self.model.save_replay_buffer(model_path + "_replay_buffer")
            if self.verbose > 1:
                print(f"Saving checkpoint to {model_path}")
        return super()._on_step()

# Log directory
log_dir = "./fetch_logs/"
os.makedirs(log_dir, exist_ok=True)

# Environment ID (v3, per the Gymnasium-Robotics docs; replace with v4 if needed)
env_id = "FetchReach-v3"

# Training and evaluation environments (no rendering)
train_env = gym.make(env_id)
eval_env = gym.make(env_id)

# HER parameters
goal_selection_strategy = "future"

# Interrupt handler: save the current model and replay buffer
def signal_handler(sig, frame):
    print("Interrupt received! Saving current model and replay buffer...")
    if 'model' in globals():
        model.save(os.path.join(log_dir, "latest_model"))
        model.save_replay_buffer(os.path.join(log_dir, "latest_replay_buffer"))
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

# Model factory
def make_model(env):
    return SAC(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(
            n_sampled_goal=4,
            goal_selection_strategy=goal_selection_strategy,
        ),
        verbose=1,
        tensorboard_log=log_dir,
        buffer_size=1000000,  # adjustable
    )

# User choice
choice = input("Enter 1 to train from scratch, 2 to continue training, 3 to infer: ").strip()

if choice == '1':
    # Train from scratch
    model = make_model(train_env)
    
    # Callbacks: evaluate and save the best model; the custom checkpoint also saves the replay buffer
    eval_callback = EvalCallback(eval_env, best_model_save_path=log_dir, log_path=log_dir, eval_freq=10000, n_eval_episodes=10, deterministic=True)
    checkpoint_callback = CustomCheckpointCallback(save_freq=10000, save_path=log_dir, name_prefix="sac_checkpoint")
    
    # Train (total timesteps adjustable)
    model.learn(total_timesteps=100000, callback=[eval_callback, checkpoint_callback], log_interval=4)
    
    # Final evaluation
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
    print(f"Final mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
    
    # Save the final model and replay buffer
    model.save(os.path.join(log_dir, "final_model"))
    model.save_replay_buffer(os.path.join(log_dir, "final_replay_buffer"))

elif choice == '2':
    # Continue training
    model_path = os.path.join(log_dir, "latest_model")
    replay_path = os.path.join(log_dir, "latest_replay_buffer")
    
    if not os.path.exists(model_path + ".zip"):
        print("No latest model found, starting from scratch.")
        model = make_model(train_env)
    else:
        model = SAC.load(model_path, env=train_env)
        if os.path.exists(replay_path):
            model.load_replay_buffer(replay_path)
        else:
            print("No replay buffer found, continuing with empty buffer.")
    
    # Same callbacks as above
    eval_callback = EvalCallback(eval_env, best_model_save_path=log_dir, log_path=log_dir, eval_freq=10000, n_eval_episodes=10, deterministic=True)
    checkpoint_callback = CustomCheckpointCallback(save_freq=10000, save_path=log_dir, name_prefix="sac_checkpoint")
    
    # Continue training (do not reset the timestep counter)
    model.learn(total_timesteps=100000, callback=[eval_callback, checkpoint_callback], log_interval=4, reset_num_timesteps=False)
    
    # Final evaluation
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
    print(f"Final mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
    
    # Save the final model and replay buffer
    model.save(os.path.join(log_dir, "final_model"))
    model.save_replay_buffer(os.path.join(log_dir, "final_replay_buffer"))

elif choice == '3':
    # Load the best model for inference with rendering
    best_path = os.path.join(log_dir, "best_model")
    if not os.path.exists(best_path + ".zip"):
        print("No best model found.")
        sys.exit(1)
    
    # Inference environment (with rendering)
    infer_env = gym.make(env_id, render_mode="human")
    model = SAC.load(best_path, env=infer_env)  # no replay buffer needed for inference
    
    # Run 10 rendered episodes
    for episode in range(10):
        obs, _ = infer_env.reset()
        done = False
        total_reward = 0
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = infer_env.step(action)
            infer_env.render()
            total_reward += reward
            done = terminated or truncated
        print(f"Episode {episode + 1} reward: {total_reward}")
    
    infer_env.close()

else:
    print("Invalid choice.")

2.2 TQC algorithm

Below are the best training hyperparameters for each task; save the following four files under a config folder:

FetchReach-v4.txt
OrderedDict([('batch_size', 256),
             ('buffer_size', 1000000),
             ('ent_coef', 'auto'),
             ('env_wrapper', 'sb3_contrib.common.wrappers.TimeFeatureWrapper'),
             ('gamma', 0.95),
             ('learning_rate', 0.001),
             ('learning_starts', 1000),
             ('n_timesteps', 20000.0),
             ('normalize', True),
             ('policy', 'MultiInputPolicy'),
             ('policy_kwargs', 'dict(net_arch=[64, 64], n_critics=1)'),
             ('replay_buffer_class', 'HerReplayBuffer'),
             ('replay_buffer_kwargs',
              "dict( online_sampling=True, goal_selection_strategy='future', "
              'n_sampled_goal=4 )'),
             ('normalize_kwargs', {'norm_obs': True, 'norm_reward': False})])

FetchPush-v4.txt
OrderedDict([('batch_size', 512),
             ('buffer_size', 1000000),
             ('env_wrapper', 'sb3_contrib.common.wrappers.TimeFeatureWrapper'),
             ('gamma', 0.98),
             ('learning_rate', 0.001),
             ('n_timesteps', 1000000.0),
             ('policy', 'MultiInputPolicy'),
             ('policy_kwargs', 'dict(net_arch=[512, 512, 512], n_critics=2)'),
             ('replay_buffer_class', 'HerReplayBuffer'),
             ('replay_buffer_kwargs',
              "dict( online_sampling=True, goal_selection_strategy='future', "
              'n_sampled_goal=4, max_episode_length=100 )'),
             ('tau', 0.005),
             ('normalize', False)])

FetchSlide-v4.txt
OrderedDict([('batch_size', 512),
             ('buffer_size', 1000000),
             ('env_wrapper', 'sb3_contrib.common.wrappers.TimeFeatureWrapper'),
             ('gamma', 0.98),
             ('learning_rate', 0.001),
             ('n_timesteps', 3000000.0),
             ('policy', 'MultiInputPolicy'),
             ('policy_kwargs', 'dict(net_arch=[512, 512, 512], n_critics=2)'),
             ('replay_buffer_class', 'HerReplayBuffer'),
             ('replay_buffer_kwargs',
              "dict( online_sampling=True, goal_selection_strategy='future', "
              'n_sampled_goal=4, max_episode_length=100 )'),
             ('tau', 0.005),
             ('normalize', False)])

FetchPickAndPlace-v4.txt
OrderedDict([('batch_size', 512),
             ('buffer_size', 1000000),
             ('env_wrapper', 'sb3_contrib.common.wrappers.TimeFeatureWrapper'),
             ('gamma', 0.98),
             ('learning_rate', 0.001),
             ('n_timesteps', 1000000.0),
             ('policy', 'MultiInputPolicy'),
             ('policy_kwargs', 'dict(net_arch=[512, 512, 512], n_critics=2)'),
             ('replay_buffer_class', 'HerReplayBuffer'),
             ('replay_buffer_kwargs',
              "dict( online_sampling=True, goal_selection_strategy='future', "
              'n_sampled_goal=4, max_episode_length=100 )'),
             ('tau', 0.005),
             ('normalize', False)])
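With those files in place, the layout next to the training script looks like this (file names matching the file_map in load_params below):

config/
├── FetchReach-v4.txt
├── FetchPush-v4.txt
├── FetchSlide-v4.txt
└── FetchPickAndPlace-v4.txt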


import gymnasium as gym
import gymnasium_robotics
import os
import sys
import argparse
import h5py
import numpy as np
import collections
import importlib
import warnings
from pathlib import Path

from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.her import HerReplayBuffer
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib import TQC  # TQC from sb3-contrib

gym.register_envs(gymnasium_robotics)

# Load the hyperparameter file for a given environment
def load_params(env_id):
    file_map = {
        "FetchReach-v4": "config/FetchReach-v4.txt",
        "FetchPush-v4": "config/FetchPush-v4.txt",
        "FetchSlide-v4": "config/FetchSlide-v4.txt",  # 注意文件名为v1,但用于v4环境
        "FetchPickAndPlace-v4": "config/FetchPickAndPlace-v4.txt"
    }
    file_path = file_map.get(env_id)
    if not file_path:
        raise ValueError(f"No parameter file for environment {env_id}")
    with open(file_path, 'r') as f:
        content = f.read()
    try:
        # Parse the OrderedDict literal with eval
        params = eval(content, {"OrderedDict": collections.OrderedDict})
    except Exception as e:
        raise ValueError(f"Error parsing params file {file_path}: {e}")
    return params

# Custom CheckpointCallback that also saves the replay buffer and VecNormalize statistics (if present)
class CustomCheckpointCallback(CheckpointCallback):
    def _on_step(self) -> bool:
        if self.n_calls % self.save_freq == 0:
            model_path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
            self.model.save(model_path)
            self.model.save_replay_buffer(model_path + "_replay_buffer")
            if self.model._vec_normalize_env is not None:
                self.model._vec_normalize_env.save(model_path + "_vec_normalize.pkl")
            if self.verbose > 1:
                print(f"Saving checkpoint to {model_path}")
        return super()._on_step()

# Custom EvalCallback that also saves best_vec_normalize.pkl (if present)
class CustomEvalCallback(EvalCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.log_path is not None:
            self.log_path = Path(self.log_path)
            self.log_path.mkdir(parents=True, exist_ok=True)

    def _sync_env_norms(self, training_env, eval_env):
        if hasattr(training_env, 'obs_rms'):
            eval_env.obs_rms = training_env.obs_rms.copy()
            eval_env.clip_obs = training_env.clip_obs
            # Note: ret_rms is not synced because it is not backward compatible

    def _do_rollout(self, is_deterministic):
        obs = self.eval_env.reset()
        episode_reward = 0.0
        episode_length = 0
        done = False
        while not done:
            action, _ = self.model.predict(obs, deterministic=is_deterministic)
            new_obs, rewards, dones, infos = self.eval_env.step(action)
            episode_reward += rewards[0]
            episode_length += 1
            done = dones[0]
            info = infos[0]
            obs = new_obs
        if "is_success" in info:
            self._is_success_buffer.append(info["is_success"])
        return episode_reward, episode_length

    def _on_step(self) -> bool:
        continue_training = True
        if self.n_calls % self.eval_freq == 0:
            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                self._sync_env_norms(self.training_env, self.eval_env)

            # Reset success rate buffer
            self._is_success_buffer = []

            episode_rewards, episode_lengths = [], []

            is_deterministic = self.deterministic
            if not self.deterministic and self.warn:
                warnings.warn(
                    "evaluate_policy will use 'deterministic=True' instead of 'deterministic=False' for evaluation "
                    "in a future version of Stable Baselines3. Please pass `deterministic` to `evaluate_policy` instead.",
                    FutureWarning,
                )
                is_deterministic = True

            for _ in range(self.n_eval_episodes):
                # Do rollouts
                episode_reward, episode_length = self._do_rollout(is_deterministic)
                episode_rewards.append(episode_reward)
                episode_lengths.append(episode_length)

            if self.log_path is not None:
                np.save(self.log_path / "episode_rewards.npy", episode_rewards)
                np.save(self.log_path / "episode_lengths.npy", episode_lengths)

            mean_reward = np.mean(episode_rewards)
            std_reward = np.std(episode_rewards)
            if self.verbose >= 1:
                print(f"Eval mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
            # Add to current Logger
            self.logger.record("eval/mean_reward", float(mean_reward))
            self.logger.record("eval/mean_ep_length", np.mean(episode_lengths))
            self.logger.record("eval/std_reward", float(std_reward))
            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose >= 1:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total_timesteps", self.model.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.model.num_timesteps)

            if mean_reward > self.best_mean_reward:
                if self.verbose >= 1:
                    print("New best mean reward!")
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                    if self.model._vec_normalize_env is not None:
                        self.model._vec_normalize_env.save(os.path.join(self.best_model_save_path, "best_vec_normalize.pkl"))
                self.best_mean_reward = mean_reward
                # Trigger callback on new best model, if needed
                if self.callback_on_new_best is not None:
                    continue_training = self.callback_on_new_best.on_step()

            # Trigger callback if needed
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training

# Create the environment (VecEnv; applies env_wrapper from the params, but not normalization)
def create_env(env_id, render_mode=None, n_envs=1):
    params = load_params(env_id)

    def make_single_env():
        env = gym.make(env_id, render_mode=render_mode)
        if 'env_wrapper' in params:
            wrapper_str = params['env_wrapper']
            module_name, class_name = wrapper_str.rsplit('.', 1)
            module = importlib.import_module(module_name)
            wrapper_class = getattr(module, class_name)
            env = wrapper_class(env)
        return env

    vec_env_cls = DummyVecEnv if n_envs == 1 else SubprocVecEnv
    vec_env = vec_env_cls([make_single_env for _ in range(n_envs)])
    return vec_env

# Create the model (configure TQC from the loaded hyperparameters)
def create_model(env, log_dir, env_id):
    params = load_params(env_id)
    policy = params['policy']
    replay_buffer_class = HerReplayBuffer  # assume HerReplayBuffer is always used
    replay_buffer_kwargs = eval(params['replay_buffer_kwargs'])
    # Drop HER kwargs that newer SB3 versions (>= 2.0) no longer accept
    for deprecated_key in ('online_sampling', 'max_episode_length'):
        replay_buffer_kwargs.pop(deprecated_key, None)
    policy_kwargs = eval(params['policy_kwargs'])

    tqc_kwargs = {
        "policy": policy,
        "env": env,
        "replay_buffer_class": replay_buffer_class,
        "replay_buffer_kwargs": replay_buffer_kwargs,
        "policy_kwargs": policy_kwargs,
        "verbose": 1,
        "tensorboard_log": log_dir,
    }
    for key in ['batch_size', 'buffer_size', 'ent_coef', 'gamma', 'learning_rate', 'learning_starts', 'tau']:
        if key in params:
            tqc_kwargs[key] = params[key]

    return TQC(**tqc_kwargs)

# Save a dataset to HDF5 (sampled from the replay buffer)
def save_dataset(buffer, file_path):
    # Sample from the buffer (cap the number of transitions to avoid OOM)
    batch_size = min(buffer.buffer_size if buffer.full else buffer.pos, 10000)
    data = buffer.sample(batch_size)
    with h5py.File(file_path, 'w') as f:
        # Save the flat (non-dict) fields
        for key in ['actions', 'rewards', 'dones']:
            f.create_dataset(key, data=getattr(data, key).cpu().numpy())
        
        # Save observations and next_observations as HDF5 groups
        obs_group = f.create_group('observations')
        for subkey, value in data.observations.items():
            obs_group.create_dataset(subkey, data=value.cpu().numpy())
        
        next_obs_group = f.create_group('next_observations')
        for subkey, value in data.next_observations.items():
            next_obs_group.create_dataset(subkey, data=value.cpu().numpy())
        
        # Also store the HER goals separately (taken from observations)
        f.create_dataset('achieved_goals', data=data.observations['achieved_goal'].cpu().numpy())
        f.create_dataset('desired_goals', data=data.observations['desired_goal'].cpu().numpy())

# Training routine (shared by train-from-scratch and continue-training)
def train_model(model, eval_env, log_dir, total_timesteps=100000, reset_num_timesteps=True):
    # Callbacks (custom EvalCallback for richer evaluation)
    eval_callback = CustomEvalCallback(eval_env, best_model_save_path=log_dir, log_path=log_dir, eval_freq=50000, n_eval_episodes=50, deterministic=True, warn=True)
    checkpoint_callback = CustomCheckpointCallback(save_freq=10000, save_path=log_dir, name_prefix="tqc_checkpoint")
    
    try:
        # Train
        model.learn(total_timesteps=total_timesteps, callback=[eval_callback, checkpoint_callback], log_interval=4, reset_num_timesteps=reset_num_timesteps)
    except KeyboardInterrupt:
        print("Interrupt received! Saving current model and replay buffer...")
        model.save(os.path.join(log_dir, "latest_model"))
        model.save_replay_buffer(os.path.join(log_dir, "latest_replay_buffer"))
        if model._vec_normalize_env is not None:
            model._vec_normalize_env.save(os.path.join(log_dir, "latest_vec_normalize.pkl"))
        sys.exit(0)
    
    # Final evaluation
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=50, deterministic=True)
    print(f"Final mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
    
    # Save the final model and replay buffer
    model.save(os.path.join(log_dir, "final_model"))
    model.save_replay_buffer(os.path.join(log_dir, "final_replay_buffer"))
    if model._vec_normalize_env is not None:
        model._vec_normalize_env.save(os.path.join(log_dir, "final_vec_normalize.pkl"))
    
    # Save the dataset as HDF5
    dataset_path = os.path.join(log_dir, "dataset.h5")
    save_dataset(model.replay_buffer, dataset_path)
    print(f"Dataset saved to {dataset_path}")

# Train from scratch (choice 1)
def train_from_scratch(env_id, log_dir, steps=None):
    params = load_params(env_id)
    train_steps = steps if steps is not None else params['n_timesteps']
    
    # Create the environment and add normalization if the params require it
    train_env = create_env(env_id, n_envs=1)
    if params.get('normalize', False):
        normalize_kwargs = params.get('normalize_kwargs', {})
        train_env = VecNormalize(train_env, **normalize_kwargs)
    
    eval_env = create_env(env_id, n_envs=1)
    if params.get('normalize', False):
        normalize_kwargs = params.get('normalize_kwargs', {})
        eval_env = VecNormalize(eval_env, training=False, **normalize_kwargs)
    
    model = create_model(train_env, log_dir, env_id)
    train_model(model, eval_env, log_dir, total_timesteps=train_steps)
    train_env.close()
    eval_env.close()

# Continue training (choice 2)
def continue_training(env_id, log_dir, steps=None):
    params = load_params(env_id)
    total_timesteps = steps if steps is not None else params['n_timesteps']
    
    # Create the environments and add normalization if the params require it
    train_env = create_env(env_id, n_envs=1)
    eval_env = create_env(env_id, n_envs=1)
    if params.get('normalize', False):
        normalize_kwargs = params.get('normalize_kwargs', {})
        vec_path = os.path.join(log_dir, "latest_vec_normalize.pkl")
        if os.path.exists(vec_path):
            train_env = VecNormalize.load(vec_path, train_env)
        else:
            train_env = VecNormalize(train_env, **normalize_kwargs)
        eval_env = VecNormalize(eval_env, training=False, **normalize_kwargs)
    
    model_path = os.path.join(log_dir, "latest_model.zip")
    replay_path = os.path.join(log_dir, "latest_replay_buffer")
    
    if not os.path.exists(model_path):
        print("No latest model found, starting from scratch.")
        model = create_model(train_env, log_dir, env_id)
    else:
        model = TQC.load(model_path, env=train_env)
        if os.path.exists(replay_path):
            model.load_replay_buffer(replay_path)
        else:
            print("No replay buffer found, continuing with empty buffer.")
    
    train_model(model, eval_env, log_dir, total_timesteps=total_timesteps, reset_num_timesteps=False)
    train_env.close()
    eval_env.close()

# Inference (choice 3)
def perform_inference(env_id, log_dir):
    best_path = os.path.join(log_dir, "best_model.zip")
    if not os.path.exists(best_path):
        print("No best model found.")
        sys.exit(1)
    
    params = load_params(env_id)
    infer_env = create_env(env_id, render_mode="human", n_envs=1)
    vec_path = os.path.join(log_dir, "best_vec_normalize.pkl")
    if params.get('normalize', False):
        normalize_kwargs = params.get('normalize_kwargs', {})
        if os.path.exists(vec_path):
            infer_env = VecNormalize.load(vec_path, infer_env)
        else:
            infer_env = VecNormalize(infer_env, training=False, **normalize_kwargs)
        infer_env.training = False
    
    model = TQC.load(best_path, env=infer_env)  # no replay buffer needed for inference
    
    # Run 10 rendered episodes
    for episode in range(10):
        obs = infer_env.reset()
        done = False
        total_reward = 0
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, dones, _ = infer_env.step(action)
            infer_env.render()
            total_reward += rewards[0]
            done = dones[0]
        print(f"Episode {episode + 1} reward: {total_reward}")
    
    infer_env.close()

# Main entry point
def main():
    parser = argparse.ArgumentParser(description="Train or infer TQC on Fetch environment.")
    parser.add_argument('--choice', type=int, choices=[1, 2, 3], required=True,
                        help='1: train from scratch, 2: continue training, 3: infer')
    parser.add_argument('--env_id', type=str, default="FetchReach-v4",
                        choices=["FetchReach-v4", "FetchPush-v4", "FetchSlide-v4", "FetchPickAndPlace-v4"],
                        help='The environment ID to use.')
    parser.add_argument('--steps', type=float, default=None, required=False,
                        help='Number of training steps (only for training options). If not provided, use from params file.')
    args = parser.parse_args()
    
    log_dir = f"./tqc_fetch_logs/{args.env_id}/"
    os.makedirs(log_dir, exist_ok=True)
    
    if args.choice == 1:
        train_from_scratch(args.env_id, log_dir, args.steps)
    elif args.choice == 2:
        continue_training(args.env_id, log_dir, args.steps)
    elif args.choice == 3:
        perform_inference(args.env_id, log_dir)

if __name__ == "__main__":
    main()
Training is launched per task with:

python tqc_fetch.py --choice 1 --env_id FetchReach-v4

python tqc_fetch.py --choice 1 --env_id FetchPush-v4

python tqc_fetch.py --choice 1 --env_id FetchSlide-v4

python tqc_fetch.py --choice 1 --env_id FetchPickAndPlace-v4
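The same script resumes training or runs inference through --choice 2 and --choice 3 (arguments as defined in main), for example:

python tqc_fetch.py --choice 2 --env_id FetchPush-v4

python tqc_fetch.py --choice 3 --env_id FetchReach-v4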

Training FetchReach-v4 takes roughly 6 minutes.
Training FetchPush-v4 takes roughly 6.8 hours.
Training FetchSlide-v4 takes roughly 20 hours.
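The dataset.h5 written by save_dataset can be inspected afterwards with h5py; a minimal reading sketch (the path and keys follow the script above):

import h5py

with h5py.File("./tqc_fetch_logs/FetchReach-v4/dataset.h5", "r") as f:
    actions = f["actions"][:]              # transitions sampled from the replay buffer
    rewards = f["rewards"][:]
    obs = f["observations/observation"][:]
    achieved_goals = f["achieved_goals"][:]
    desired_goals = f["desired_goals"][:]
    print(actions.shape, rewards.shape, obs.shape)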
