Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 02c06c23ad |
@@ -1,83 +1,101 @@
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
# 关键:确保当前目录在 sys.path 中,这样才能直接 from config 导入
|
import argparse
|
||||||
|
|
||||||
|
# 确保能找到项目根目录下的模块
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
import argparse
|
|
||||||
from isaaclab.app import AppLauncher
|
from isaaclab.app import AppLauncher
|
||||||
|
|
||||||
# 添加启动参数
|
# 1. 配置启动参数
|
||||||
parser = argparse.ArgumentParser(description="Train T1 robot with rl_games.")
|
parser = argparse.ArgumentParser(description="Train T1 robot to Get-Up with RL-Games.")
|
||||||
parser.add_argument("--num_envs", type=int, default=16384, help="Number of envs to run.")
|
parser.add_argument("--num_envs", type=int, default=16384, help="起身任务建议并行 4096 即可")
|
||||||
|
parser.add_argument("--task", type=str, default="Isaac-T1-GetUp-v0", help="任务 ID")
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="随机种子")
|
||||||
AppLauncher.add_app_launcher_args(parser)
|
AppLauncher.add_app_launcher_args(parser)
|
||||||
args_cli = parser.parse_args()
|
args_cli = parser.parse_args()
|
||||||
|
|
||||||
# 启动仿真器
|
# 2. 启动仿真器(必须在导入其他 isaaclab 模块前)
|
||||||
app_launcher = AppLauncher(args_cli)
|
app_launcher = AppLauncher(args_cli)
|
||||||
simulation_app = app_launcher.app
|
simulation_app = app_launcher.app
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
import yaml
|
||||||
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
|
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
|
||||||
from rl_games.torch_runner import Runner
|
from rl_games.torch_runner import Runner
|
||||||
import yaml
|
|
||||||
from config.t1_env_cfg import T1EnvCfg
|
|
||||||
from rl_games.common import env_configurations, vecenv
|
from rl_games.common import env_configurations, vecenv
|
||||||
|
|
||||||
|
# 导入你刚刚修改好的配置类
|
||||||
|
# 假设你的文件名是 t1_getup_cfg.py,类名是 T1EnvCfg
|
||||||
|
from config.t1_env_cfg import T1EnvCfg
|
||||||
|
|
||||||
|
# 3. 注册环境
|
||||||
gym.register(
|
gym.register(
|
||||||
id="Isaac-T1-Walking-v0",
|
id="Isaac-T1-Walk-v0",
|
||||||
entry_point="isaaclab.envs:ManagerBasedRLEnv", # Isaac Lab 统一的强化学习环境入口
|
entry_point="isaaclab.envs:ManagerBasedRLEnv",
|
||||||
kwargs={
|
kwargs={
|
||||||
"cfg": T1EnvCfg(),
|
"cfg": T1EnvCfg(), # 这里会加载你设置的随机旋转、时间惩罚等
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def main():
|
|
||||||
# 1. 创建环境 (保持不变)
|
|
||||||
env = gym.make("Isaac-T1-Walking-v0", num_envs=args_cli.num_envs)
|
|
||||||
|
|
||||||
# 2. 包装环境 (保持不变)
|
def main():
|
||||||
|
# --- 新增:处理 Retrain 参数 ---
|
||||||
|
# 你可以手动指定路径,或者在 argparse 里增加一个 --checkpoint 参数
|
||||||
|
checkpoint_path = os.path.join(os.path.dirname(__file__), "logs/T1_GetUp/nn/**.pth")
|
||||||
|
# 检查模型文件是否存在
|
||||||
|
should_retrain = os.path.exists(checkpoint_path)
|
||||||
|
|
||||||
|
env = gym.make("Isaac-T1-Walk-v0", num_envs=args_cli.num_envs)
|
||||||
|
|
||||||
|
# 注意:rl_device 必须设置为 args_cli.device (通常是 'cuda:0')
|
||||||
wrapped_env = RlGamesVecEnvWrapper(
|
wrapped_env = RlGamesVecEnvWrapper(
|
||||||
env,
|
env,
|
||||||
rl_device=args_cli.device,
|
rl_device=args_cli.device,
|
||||||
clip_obs=5.0,
|
clip_obs=5.0,
|
||||||
clip_actions=100.0
|
clip_actions=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
|
vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
|
||||||
|
|
||||||
# 注册环境配置
|
|
||||||
env_configurations.register('rlgym', {
|
env_configurations.register('rlgym', {
|
||||||
'vecenv_type': 'as_is',
|
'vecenv_type': 'as_is',
|
||||||
'env_creator': lambda **kwargs: wrapped_env
|
'env_creator': lambda **kwargs: wrapped_env
|
||||||
})
|
})
|
||||||
|
|
||||||
# 3. 加载 PPO 配置 (保持不变)
|
|
||||||
config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
|
config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
|
||||||
with open(config_path, "r") as f:
|
with open(config_path, "r") as f:
|
||||||
rl_config = yaml.safe_load(f)
|
rl_config = yaml.safe_load(f)
|
||||||
|
|
||||||
# 设置日志路径
|
# 设置日志和实验名称
|
||||||
rl_game_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
|
rl_game_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
|
||||||
log_dir = os.path.join(rl_game_dir, "logs")
|
log_dir = os.path.join(rl_game_dir, "logs")
|
||||||
rl_config['params']['config']['train_dir'] = log_dir
|
rl_config['params']['config']['train_dir'] = log_dir
|
||||||
|
rl_config['params']['config']['name'] = "T1_GetUp"
|
||||||
|
|
||||||
# 4. 启动训练
|
# --- 关键修改:注入模型路径 ---
|
||||||
|
if should_retrain:
|
||||||
|
print(f"[INFO]: 检测到预训练模型,正在从 {checkpoint_path} 恢复训练...")
|
||||||
|
# rl_games 会读取 config 中的 load_path 进行续训
|
||||||
|
rl_config['params']['config']['load_path'] = checkpoint_path
|
||||||
|
else:
|
||||||
|
print("[INFO]: 未找到预训练模型,将从零开始训练。")
|
||||||
|
|
||||||
|
# 7. 运行训练
|
||||||
runner = Runner()
|
runner = Runner()
|
||||||
|
|
||||||
# 此时 rl_config 只有文本和数字,没有复杂对象,deepcopy 会成功
|
|
||||||
runner.load(rl_config)
|
runner.load(rl_config)
|
||||||
|
|
||||||
# 在 run 时传入对象是安全的
|
|
||||||
runner.run({
|
runner.run({
|
||||||
"train": True,
|
"train": True,
|
||||||
"play": False,
|
"play": False,
|
||||||
|
# 如果你想强制从某个 checkpoint 开始,也可以在这里传参
|
||||||
|
"checkpoint": checkpoint_path if should_retrain else None,
|
||||||
"vec_env": wrapped_env
|
"vec_env": wrapped_env
|
||||||
})
|
})
|
||||||
|
|
||||||
simulation_app.close()
|
simulation_app.close()
|
||||||
|
|
||||||
# PYTHONPATH=. python rl_game/your_file_name/train.py
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
Reference in New Issue
Block a user