"""Train the T1 robot with RL-Games on Isaac Lab, optionally resuming a run.

This is the post-patch state of ``rl_game/demo/train.py`` (commit 02c06c2,
"add some codes to make retain come true"), reformatted and with the
resume-from-checkpoint logic repaired.

Usage:
    python rl_game/demo/train.py [--num_envs N] [--task ID] [--seed S]
                                 [--checkpoint PATH]

If ``--checkpoint`` is omitted, the newest ``*.pth`` file found under
``logs/T1_GetUp*/nn/`` (next to this script) is used; if none exists,
training starts from scratch.
"""
import argparse
import glob
import os
import sys
from typing import Optional

# Make this script's directory importable so ``from config...`` resolves.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from isaaclab.app import AppLauncher

# Experiment name handed to rl_games; also the prefix of the run directory
# that is searched for checkpoints.
EXPERIMENT_NAME = "T1_GetUp"

# 1. Command-line arguments.
parser = argparse.ArgumentParser(description="Train T1 robot to Get-Up with RL-Games.")
parser.add_argument("--num_envs", type=int, default=16384, help="起身任务建议并行 4096 即可")
parser.add_argument("--task", type=str, default="Isaac-T1-GetUp-v0", help="任务 ID")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
parser.add_argument(
    "--checkpoint",
    type=str,
    default=None,
    help="Path to a .pth checkpoint to resume from (default: newest one under logs/).",
)
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()

# 2. Launch the simulator. This must happen before any other isaaclab
#    module is imported.
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

import torch
import gymnasium as gym
import yaml
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
from rl_games.torch_runner import Runner
from rl_games.common import env_configurations, vecenv

# Environment configuration class for the T1 task.
from config.t1_env_cfg import T1EnvCfg

# 3. Register the environment under the id selected with --task, so the
#    CLI flag, the registration and gym.make() can never disagree.
env_cfg = T1EnvCfg()
# NOTE(review): assumes T1EnvCfg derives from Isaac Lab's env config, which
# exposes a ``seed`` field -- confirm against config/t1_env_cfg.py.
env_cfg.seed = args_cli.seed

gym.register(
    id=args_cli.task,
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    kwargs={
        "cfg": env_cfg,
    },
)


def _find_latest_checkpoint(log_dir: str) -> Optional[str]:
    """Return the most recently modified ``*.pth`` under ``log_dir``, or None.

    Both ``<log_dir>/T1_GetUp/nn`` and suffixed run directories such as
    ``<log_dir>/T1_GetUp_<suffix>/nn`` are searched, because rl_games may
    append a suffix to the experiment directory name.
    """
    pattern = os.path.join(log_dir, EXPERIMENT_NAME + "*", "nn", "*.pth")
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def _resolve_checkpoint(explicit_path: Optional[str], log_dir: str) -> Optional[str]:
    """Pick the checkpoint to resume from.

    Args:
        explicit_path: Value of ``--checkpoint``; takes precedence when given.
        log_dir: Directory searched for earlier runs when no path is given.

    Returns:
        Absolute path of the checkpoint, or None to train from scratch.

    Raises:
        FileNotFoundError: If ``explicit_path`` is given but is not a file.
    """
    if explicit_path:
        resolved = os.path.abspath(explicit_path)
        if not os.path.isfile(resolved):
            raise FileNotFoundError(f"Checkpoint not found: {resolved}")
        return resolved
    return _find_latest_checkpoint(log_dir)


def main():
    """Create the wrapped environment, load the PPO config and run training."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    log_dir = os.path.join(script_dir, "logs")

    # 4. Decide whether to resume. The original code passed a glob pattern
    #    to os.path.exists(), which never matched a real file.
    checkpoint_path = _resolve_checkpoint(args_cli.checkpoint, log_dir)

    # 5. Create and wrap the environment.
    # NOTE(review): confirm that ManagerBasedRLEnv honours the ``num_envs``
    # keyword; if it does not, set ``env_cfg.scene.num_envs`` instead.
    env = gym.make(args_cli.task, num_envs=args_cli.num_envs)

    # rl_device must be the simulation device from the CLI (e.g. 'cuda:0').
    wrapped_env = RlGamesVecEnvWrapper(
        env,
        rl_device=args_cli.device,
        clip_obs=5.0,
        clip_actions=1.0,
    )

    # Hand the already-built environment to rl_games instead of letting it
    # construct a new one.
    vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
    env_configurations.register('rlgym', {
        'vecenv_type': 'as_is',
        'env_creator': lambda **kwargs: wrapped_env,
    })

    # 6. Load the PPO configuration and inject run-specific settings.
    config_path = os.path.join(script_dir, "config", "ppo_cfg.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        rl_config = yaml.safe_load(f)

    rl_config['params']['seed'] = args_cli.seed
    rl_config['params']['config']['train_dir'] = log_dir
    rl_config['params']['config']['name'] = EXPERIMENT_NAME

    if checkpoint_path is not None:
        print(f"[INFO]: 检测到预训练模型,正在从 {checkpoint_path} 恢复训练...")
        # Same keys the Isaac Lab rl_games workflow sets when resuming.
        rl_config['params']['load_checkpoint'] = True
        rl_config['params']['load_path'] = checkpoint_path
    else:
        print("[INFO]: 未找到预训练模型,将从零开始训练。")

    # 7. Run training. rl_config holds only plain data at this point, so the
    #    runner can copy it; live objects are passed through run() instead.
    runner = Runner()
    runner.load(rl_config)
    runner.run({
        "train": True,
        "play": False,
        # None means "do not restore"; a path makes the agent restore first.
        "checkpoint": checkpoint_path,
        "vec_env": wrapped_env,
    })

    simulation_app.close()


if __name__ == "__main__":
    main()