Files
Gym_GPU/rl_game/get_up/train.py

101 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
import os
import argparse
# Ensure modules under the project root directory are importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from isaaclab.app import AppLauncher
# 1. Configure launch arguments (parsed before the simulator is started).
parser = argparse.ArgumentParser(description="Train T1 robot to Get-Up with RL-Games.")
parser.add_argument("--num_envs", type=int, default=8192, help="起身任务建议并行 4096 即可")
parser.add_argument("--task", type=str, default="Isaac-T1-GetUp-v0", help="任务 ID")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
# 2. Launch the simulator (must happen before importing any other isaaclab modules).
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
import torch
import gymnasium as gym
import yaml
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
from rl_games.torch_runner import Runner
from rl_games.common import env_configurations, vecenv
# Import the environment configuration class.
# NOTE(review): assumes the config module is config/t1_env_cfg.py with class T1EnvCfg.
from config.t1_env_cfg import T1EnvCfg
# 3. Register the environment with Gymnasium.
gym.register(
id="Isaac-T1-GetUp-v0",
entry_point="isaaclab.envs:ManagerBasedRLEnv",
kwargs={
"cfg": T1EnvCfg(),  # loads the configured settings (random rotation, time penalty, etc.)
},
)
def main():
    """Train the T1 get-up task with RL-Games, resuming from a checkpoint if one exists.

    Reads PPO hyper-parameters from ``config/ppo_cfg.yaml`` next to this script,
    wraps the registered Gymnasium environment for rl_games, and runs training.
    Closes the simulation app when training finishes.
    """
    # --- Retrain handling ---
    # The checkpoint path is hard-coded here; a --checkpoint CLI argument could
    # be added to the argparse setup instead.
    checkpoint_path = os.path.join(os.path.dirname(__file__), "logs/T1_GetUp/nn/T1_GetUp.pth")
    # Resume training only if a previously saved model file exists.
    should_retrain = os.path.exists(checkpoint_path)

    env = gym.make("Isaac-T1-GetUp-v0", num_envs=args_cli.num_envs)
    # NOTE: rl_device must be set to args_cli.device (usually 'cuda:0').
    wrapped_env = RlGamesVecEnvWrapper(
        env,
        rl_device=args_cli.device,
        clip_obs=5.0,
        clip_actions=1.0,
    )

    # Register the already-constructed wrapped env so rl_games resolves it by name.
    vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
    env_configurations.register('rlgym', {
        'vecenv_type': 'as_is',
        'env_creator': lambda **kwargs: wrapped_env
    })

    # Load the rl_games PPO configuration from the YAML file next to this script.
    config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
    with open(config_path, "r") as f:
        rl_config = yaml.safe_load(f)

    # Set the log directory and experiment name.
    rl_game_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
    log_dir = os.path.join(rl_game_dir, "logs")
    rl_config['params']['config']['train_dir'] = log_dir
    rl_config['params']['config']['name'] = "T1_GetUp"

    # --- Key step: inject the model path so training resumes from it ---
    if should_retrain:
        print(f"[INFO]: 检测到预训练模型,正在从 {checkpoint_path} 恢复训练...")
        # Presumably rl_games reads load_path from the config to resume
        # training — verify against the rl_games version in use.
        rl_config['params']['config']['load_path'] = checkpoint_path
    else:
        print("[INFO]: 未找到预训练模型,将从零开始训练。")

    # Run training.
    runner = Runner()
    runner.load(rl_config)
    runner.run({
        "train": True,
        "play": False,
        # A checkpoint can also be forced here via run-args.
        "checkpoint": checkpoint_path if should_retrain else None,
        "vec_env": wrapped_env
    })
    simulation_app.close()


if __name__ == "__main__":
    main()