Files
Gym_GPU/rl_game/get_up/train.py

92 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
import os
import argparse
# 确保能找到项目根目录下的模块
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from isaaclab.app import AppLauncher
# 1. 配置启动参数
parser = argparse.ArgumentParser(description="Train T1 robot to Get-Up with RL-Games.")
parser.add_argument("--num_envs", type=int, default=16384, help="起身任务建议并行 4096 即可")
parser.add_argument("--task", type=str, default="Isaac-T1-GetUp-v0", help="任务 ID")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
# 2. 启动仿真器(必须在导入其他 isaaclab 模块前)
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
import torch
import gymnasium as gym
import yaml
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
from rl_games.torch_runner import Runner
from rl_games.common import env_configurations, vecenv
# 导入你刚刚修改好的配置类
# 假设你的文件名是 t1_getup_cfg.py类名是 T1EnvCfg
from config.t1_env_cfg import T1EnvCfg
# 3. 注册环境
gym.register(
id="Isaac-T1-GetUp-v0",
entry_point="isaaclab.envs:ManagerBasedRLEnv",
kwargs={
"cfg": T1EnvCfg(), # 这里会加载你设置的随机旋转、时间惩罚等
},
)
def main():
# 4. 创建环境,显式传入命令行指定的 num_envs
env = gym.make("Isaac-T1-GetUp-v0", num_envs=args_cli.num_envs)
# 5. 包装环境
# 注意rl_device 必须设置为 args_cli.device (通常是 'cuda:0')
wrapped_env = RlGamesVecEnvWrapper(
env,
rl_device=args_cli.device,
clip_obs=5.0,
clip_actions=1.0 # 动作裁剪建议设小一点,防止电机输出瞬间爆表
)
# 注册给 rl_games 使用
vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
env_configurations.register('rlgym', {
'vecenv_type': 'as_is',
'env_creator': lambda **kwargs: wrapped_env
})
# 6. 加载 PPO 配置文件
# 提示:由于是起身任务,建议在 ppo_cfg.yaml 中调大 mini_batch 大数或提高学习率
config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
with open(config_path, "r") as f:
rl_config = yaml.safe_load(f)
# 设置日志和实验名称
rl_game_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
log_dir = os.path.join(rl_game_dir, "logs")
rl_config['params']['config']['train_dir'] = log_dir
rl_config['params']['config']['name'] = "T1_GetUp"
# 7. 运行训练
runner = Runner()
runner.load(rl_config)
print(f"[INFO]: 开始训练任务 {args_cli.task},环境数量: {args_cli.num_envs}")
runner.run({
"train": True,
"play": False,
"vec_env": wrapped_env
})
simulation_app.close()
if __name__ == "__main__":
main()