The example of train-T1
This commit is contained in:
83
rl_game/demo/train.py
Normal file
83
rl_game/demo/train.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import sys
|
||||
import os
|
||||
# 关键:确保当前目录在 sys.path 中,这样才能直接 from config 导入
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import argparse
|
||||
from isaaclab.app import AppLauncher
|
||||
|
||||
# 添加启动参数
|
||||
parser = argparse.ArgumentParser(description="Train T1 robot with rl_games.")
|
||||
parser.add_argument("--num_envs", type=int, default=16384, help="Number of envs to run.")
|
||||
AppLauncher.add_app_launcher_args(parser)
|
||||
args_cli = parser.parse_args()
|
||||
|
||||
# 启动仿真器
|
||||
app_launcher = AppLauncher(args_cli)
|
||||
simulation_app = app_launcher.app
|
||||
|
||||
import torch
|
||||
import gymnasium as gym
|
||||
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
|
||||
from rl_games.torch_runner import Runner
|
||||
import yaml
|
||||
from config.t1_env_cfg import T1EnvCfg
|
||||
from rl_games.common import env_configurations, vecenv
|
||||
|
||||
gym.register(
|
||||
id="Isaac-T1-Walking-v0",
|
||||
entry_point="isaaclab.envs:ManagerBasedRLEnv", # Isaac Lab 统一的强化学习环境入口
|
||||
kwargs={
|
||||
"cfg": T1EnvCfg(),
|
||||
},
|
||||
)
|
||||
|
||||
def main():
|
||||
# 1. 创建环境 (保持不变)
|
||||
env = gym.make("Isaac-T1-Walking-v0", num_envs=args_cli.num_envs)
|
||||
|
||||
# 2. 包装环境 (保持不变)
|
||||
wrapped_env = RlGamesVecEnvWrapper(
|
||||
env,
|
||||
rl_device=args_cli.device,
|
||||
clip_obs=5.0,
|
||||
clip_actions=100.0
|
||||
)
|
||||
|
||||
vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
|
||||
|
||||
# 注册环境配置
|
||||
env_configurations.register('rlgym', {
|
||||
'vecenv_type': 'as_is',
|
||||
'env_creator': lambda **kwargs: wrapped_env
|
||||
})
|
||||
|
||||
# 3. 加载 PPO 配置 (保持不变)
|
||||
config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
|
||||
with open(config_path, "r") as f:
|
||||
rl_config = yaml.safe_load(f)
|
||||
|
||||
# 设置日志路径
|
||||
rl_game_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
|
||||
log_dir = os.path.join(rl_game_dir, "logs")
|
||||
rl_config['params']['config']['train_dir'] = log_dir
|
||||
|
||||
# 4. 启动训练
|
||||
runner = Runner()
|
||||
|
||||
# 此时 rl_config 只有文本和数字,没有复杂对象,deepcopy 会成功
|
||||
runner.load(rl_config)
|
||||
|
||||
# 在 run 时传入对象是安全的
|
||||
runner.run({
|
||||
"train": True,
|
||||
"play": False,
|
||||
"vec_env": wrapped_env
|
||||
})
|
||||
|
||||
simulation_app.close()
|
||||
|
||||
# PYTHONPATH=. python rl_game/your_file_name/train.py
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user