Files
Gym_GPU/rl_game/demo/train.py
2026-03-15 20:14:06 -04:00

83 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
import os
# Make packages that live next to this script (notably ``config``) importable by
# absolute name, so ``from config.t1_env_cfg import T1EnvCfg`` below works no
# matter the working directory. Insert at the FRONT so a same-named package
# elsewhere on sys.path cannot shadow the local one, and skip the insert when
# the directory is already listed to avoid accumulating duplicates.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.insert(0, _SCRIPT_DIR)
import argparse
from isaaclab.app import AppLauncher
# Command-line interface: this script declares only --num_envs; every other
# option (including the device later read as ``args_cli.device``) is
# contributed by AppLauncher.add_app_launcher_args.
parser = argparse.ArgumentParser(description="Train T1 robot with rl_games.")
parser.add_argument(
    "--num_envs",
    type=int,
    default=16384,
    help="Number of envs to run.",
)
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()

# Boot the simulator now. Isaac Lab requires the app to be launched before most
# isaaclab/omni modules are imported, which is why the remaining imports of this
# script come after this point rather than at the top of the file.
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
import torch
import gymnasium as gym
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
from rl_games.torch_runner import Runner
import yaml
from config.t1_env_cfg import T1EnvCfg
from rl_games.common import env_configurations, vecenv
# Register the T1 walking task with gymnasium so it can be built via gym.make().
# The entry point is Isaac Lab's generic manager-based RL environment class, and
# the default config instance is passed through ``kwargs`` (callers may override
# it with their own ``cfg=`` at gym.make() time).
gym.register(
    id="Isaac-T1-Walking-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    # Isaac Lab envs return batched torch tensors; gymnasium's passive env
    # checker is numpy-oriented and cannot validate them. Isaac Lab's bundled
    # task registrations disable it for the same reason.
    disable_env_checker=True,
    kwargs={
        "cfg": T1EnvCfg(),
    },
)
def _load_ppo_config(script_dir):
    """Load the rl_games PPO configuration stored at ``config/ppo_cfg.yaml``.

    Args:
        script_dir: Directory the ``config`` folder is resolved against
            (the directory containing this script).

    Returns:
        The parsed YAML document (plain dicts/lists/scalars, via ``safe_load``).
    """
    config_path = os.path.join(script_dir, "config", "ppo_cfg.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def main():
    """Build the T1 walking environment and train it with rl_games PPO.

    Uses the module-level ``args_cli`` (number of envs, device) and
    ``simulation_app``. Loads ``config/ppo_cfg.yaml`` and writes training output
    to ``logs/``, both resolved next to this script. The environment and the
    simulator app are always closed on exit, including when training raises.
    """
    import contextlib

    script_dir = os.path.dirname(os.path.abspath(__file__))

    with contextlib.ExitStack() as cleanup:
        # Callbacks unwind in LIFO order: registering the app first makes it
        # close last, after the environment that depends on it.
        cleanup.callback(simulation_app.close)

        # 1. Create the environment. ManagerBasedRLEnv sizes its scene from
        #    cfg.scene.num_envs; a bare num_envs= kwarg to gym.make() is not
        #    applied to the cfg, so write the CLI values into the cfg instead.
        env_cfg = T1EnvCfg()
        if args_cli.num_envs is not None:
            env_cfg.scene.num_envs = args_cli.num_envs
        # Keep the simulation on the same device the wrapper below uses.
        if args_cli.device is not None:
            env_cfg.sim.device = args_cli.device
        env = gym.make("Isaac-T1-Walking-v0", cfg=env_cfg)
        cleanup.callback(env.close)

        # 2. Wrap the environment for rl_games.
        # NOTE(review): the clip values are hard-coded rather than read from the
        # PPO YAML — confirm this is intended.
        wrapped_env = RlGamesVecEnvWrapper(
            env,
            rl_device=args_cli.device,
            clip_obs=5.0,
            clip_actions=100.0,
        )
        # Expose the already-built wrapper through rl_games' registries: the
        # 'as_is' vec-env type and the 'rlgym' env config both just hand back
        # wrapped_env. Presumably ppo_cfg.yaml uses env_name 'rlgym' — confirm.
        vecenv.register("as_is", lambda config_name, num_actors, **kwargs: wrapped_env)
        env_configurations.register(
            "rlgym",
            {
                "vecenv_type": "as_is",
                "env_creator": lambda **kwargs: wrapped_env,
            },
        )

        # 3. Load the PPO config and adapt it to this run.
        rl_config = _load_ppo_config(script_dir)
        train_cfg = rl_config["params"]["config"]
        train_cfg["train_dir"] = os.path.join(script_dir, "logs")
        # rl_games sizes its rollout batches from num_actors; keep it equal to
        # the number of envs that were actually created.
        train_cfg["num_actors"] = env.unwrapped.num_envs

        # 4. Train.
        runner = Runner()
        # rl_config holds only plain YAML data here (no live objects), so the
        # deepcopy performed during load() succeeds.
        runner.load(rl_config)
        # Live objects are passed at run() time instead.
        # NOTE(review): rl_games appears to obtain the env through the
        # registrations above; confirm whether the 'vec_env' entry is read.
        runner.run({
            "train": True,
            "play": False,
            "vec_env": wrapped_env,
        })
# Example invocation, from the repository root:
#   PYTHONPATH=. python rl_game/demo/train.py
if __name__ == "__main__":
    main()