"""Train the T1 robot get-up task with Isaac Lab and rl_games (PPO)."""

import argparse
import os
import sys

# Make modules that live next to this script (e.g. the local ``config`` package)
# importable regardless of the current working directory. Insert at the front
# (once) so a same-named module elsewhere on sys.path cannot shadow them.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.insert(0, _SCRIPT_DIR)

from isaaclab.app import AppLauncher

# 1. Command-line arguments.
parser = argparse.ArgumentParser(description="Train T1 robot to Get-Up with RL-Games.")
# NOTE(review): the default (16384) is much larger than the 4096 the help text
# recommends for this task -- confirm which value is intended.
parser.add_argument("--num_envs", type=int, default=16384, help="起身任务建议并行 4096 即可")
parser.add_argument("--task", type=str, default="Isaac-T1-GetUp-v0", help="任务 ID")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
# Adds the simulator options (e.g. --device, --headless) to the same parser.
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()

# 2. Launch the simulator. This must happen before importing any other
#    isaaclab module, which is why the remaining imports come later in the file.
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
import torch
import gymnasium as gym
import yaml
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
from rl_games.torch_runner import Runner
from rl_games.common import env_configurations, vecenv

# Environment config class for the get-up task (config/t1_env_cfg.py -> T1EnvCfg).
from config.t1_env_cfg import T1EnvCfg

# 3. Register the environment with gymnasium.
gym.register(
    id="Isaac-T1-GetUp-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    # Isaac Lab's own task registrations disable gymnasium's passive env checker:
    # these envs exchange torch tensors rather than the NumPy arrays it validates.
    disable_env_checker=True,
    kwargs={
        # Default config instance for this task; per the original author it sets
        # the random initial rotation, time penalty, etc. (defined in T1EnvCfg).
        "cfg": T1EnvCfg(),
    },
)
def _make_env():
    """Create the ``--task`` env with ``--num_envs``/``--seed``/``--device`` applied.

    ``ManagerBasedRLEnv`` takes its env count from ``cfg.scene.num_envs``; an extra
    ``num_envs=`` kwarg to ``gym.make`` is swallowed by its ``**kwargs`` and ignored,
    so the overrides must be written into the registered config instead.

    Raises:
        ValueError: if the task was registered without a ``cfg`` kwarg.
    """
    env_cfg = gym.spec(args_cli.task).kwargs.get("cfg")
    if env_cfg is None:
        raise ValueError(
            f"Task {args_cli.task!r} was not registered with a 'cfg' kwarg; "
            "cannot apply command-line overrides."
        )
    env_cfg.scene.num_envs = args_cli.num_envs
    env_cfg.seed = args_cli.seed
    env_cfg.sim.device = args_cli.device
    # gym.make merges these kwargs over the registered ones, so this cfg is used.
    return gym.make(args_cli.task, cfg=env_cfg)


def _load_agent_cfg(num_actors):
    """Load ``config/ppo_cfg.yaml`` (next to this script) and apply run overrides.

    Args:
        num_actors: number of parallel environments actually created; rl_games
            expects ``params.config.num_actors`` to match the vectorized env.

    Returns:
        The rl_games config dict, ready for ``Runner.load``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(script_dir, "config", "ppo_cfg.yaml")
    # Tip (original author): for a get-up task consider a larger mini-batch size
    # or a higher learning rate in ppo_cfg.yaml.
    with open(config_path, "r", encoding="utf-8") as f:
        rl_config = yaml.safe_load(f)

    params = rl_config["params"]
    agent_cfg = params["config"]
    # Logging directory and experiment name.
    agent_cfg["train_dir"] = os.path.join(script_dir, "logs")
    agent_cfg["name"] = "T1_GetUp"
    # rl_games looks up its env by config['env_name']; point it at the
    # "rlgym" entry registered in main() and size it to the real env count.
    agent_cfg["env_name"] = "rlgym"
    agent_cfg["num_actors"] = num_actors
    # Honour --seed on the rl_games side as well (the env cfg gets it in _make_env).
    params["seed"] = args_cli.seed
    return rl_config


def main():
    """Train the selected task with rl_games PPO, then shut down Isaac Sim.

    Reads the module-level ``args_cli`` (``--task``, ``--num_envs``, ``--seed``,
    ``--device``) and ``config/ppo_cfg.yaml``; logs go under ``logs/`` next to
    this script. The env and the simulation app are closed even if training raises.
    """
    env = None
    try:
        # 4. Create the environment with the command-line overrides applied.
        env = _make_env()

        # 5. Wrap it for rl_games. rl_device must be args_cli.device (usually 'cuda:0').
        wrapped_env = RlGamesVecEnvWrapper(
            env,
            rl_device=args_cli.device,
            clip_obs=5.0,
            # Keep action clipping small to avoid sudden spikes in motor output.
            clip_actions=1.0,
        )

        # Expose the already-built env to rl_games: the vecenv type "as_is" and the
        # env name "rlgym" both just hand back wrapped_env.
        vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
        env_configurations.register('rlgym', {
            'vecenv_type': 'as_is',
            'env_creator': lambda **kwargs: wrapped_env,
        })

        # 6. Load the PPO config and sync it with the env that was actually built.
        rl_config = _load_agent_cfg(env.unwrapped.num_envs)

        # 7. Train.
        runner = Runner()
        runner.load(rl_config)

        print(f"[INFO]: 开始训练任务 {args_cli.task},环境数量: {args_cli.num_envs}")

        runner.run({
            "train": True,
            "play": False,
            "vec_env": wrapped_env,
        })
    finally:
        # Release the env first, then shut down the simulator, even on failure.
        try:
            if env is not None:
                env.close()
        finally:
            simulation_app.close()


if __name__ == "__main__":
    main()