38 Commits

Author SHA1 Message Date
4bc205399c reward modification and change the get_up logic 2026-03-23 09:06:36 -04:00
af42087bd8 Amend tiny bug 2026-03-22 21:21:17 -04:00
7f7ec781c5 Add weighting function, change the reward logic 2026-03-22 21:11:46 -04:00
a642274fa6 Amend symbol to save video memory 2026-03-22 03:05:24 -04:00
20c961936d Amend symbol 2026-03-22 02:57:04 -04:00
0315b4cb99 prevent gradient explosion 2026-03-22 02:55:07 -04:00
616dd06e78 Amend success rewards 2026-03-22 02:32:58 -04:00
2e2d68a933 change the reward remove arm disturbance 2026-03-22 02:26:16 -04:00
f7c8e6e325 Amend bugs 2026-03-22 02:20:17 -04:00
a8199fd056 Amend arm push reward 2026-03-22 02:19:29 -04:00
0e70d34e81 Amend bugs 2026-03-22 00:01:21 -04:00
905e998596 change model 2026-03-21 23:46:59 -04:00
4833ba33c8 change parameter 2026-03-21 10:16:01 -04:00
fd8238dc41 Amend arm reward to get reward difficultly 2026-03-21 09:30:43 -04:00
72a22bd78a change arm to push the ground reward function 2026-03-21 08:38:17 -04:00
d78fdeda0d change reward function 2026-03-21 07:00:49 -04:00
6d2ad9846a change parameter 2026-03-20 10:51:07 -04:00
1fbc9dccac change parameter 2026-03-20 09:53:34 -04:00
49da77db51 change parameter 2026-03-20 08:55:29 -04:00
c0088ebac3 Amend tiny bug 2026-03-20 08:12:08 -04:00
00d3be8e7a Amend tiny bug 2026-03-20 08:00:51 -04:00
ad2255bc18 change parameter 2026-03-20 07:06:42 -04:00
14f2151014 Amend bugs 2026-03-20 07:03:41 -04:00
31a9fa9965 change T1EventCfg to add more initial state 2026-03-20 05:20:17 -04:00
2ae7210062 Amend for standing 2026-03-20 03:37:56 -04:00
9cfc127694 Amend bug 2026-03-19 09:36:32 -04:00
af3ba4704f Add feet_airtime loss 2026-03-19 09:25:20 -04:00
5df147b0b1 Add arm link rewards 2026-03-19 09:08:57 -04:00
6ca671dce5 change rewards 2026-03-19 06:29:30 -04:00
d4089b103e change init nums 2026-03-18 06:36:40 -04:00
118d39f4bc change env num 2026-03-18 06:32:06 -04:00
fdfd962fbc Amend a tiny bug 2026-03-18 06:18:29 -04:00
08d1bb539b Amend a tiny bug 2026-03-18 06:11:30 -04:00
9f3ec9d67a Amend some codes to init training for get up better 2026-03-18 06:05:30 -04:00
4933567ef8 change reward add punishment of joint_vel and root_vel_z_penalty 2026-03-17 05:54:20 -04:00
c1e3d9382f Add reward to maintain an upright and stable position 2026-03-16 09:23:22 -04:00
6510cb0bfc Amend some bugs and make it training 2026-03-16 05:46:49 -04:00
4b0b1fac8d The demo of get up 2026-03-16 05:00:20 -04:00
7 changed files with 562 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
import gymnasium as gym
# Import the environment configuration.
from rl_game.demo.config.t1_env_cfg import T1EnvCfg

# Register the environment with Gymnasium.
gym.register(
    id="Isaac-T1-GetUp-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",  # Isaac Lab's generic manager-based RL env entry point
    kwargs={
        "cfg": T1EnvCfg(),
    },
)

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,60 @@
# rl_games PPO configuration for the T1 get-up task.
# Loaded by train.py, which overrides `config.train_dir` and `config.name`.
params:
  seed: 42

  algo:
    name: a2c_continuous

  model:
    name: continuous_a2c_logstd

  network:
    name: actor_critic
    separate: False
    space:
      continuous:
        mu_activation: None
        sigma_activation: None
        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0.5
        fixed_sigma: False
    mlp:
      units: [512, 256, 128]
      activation: relu
      d2rl: False
      initializer:
        name: default

  config:
    name: T1_Walking      # overridden to "T1_GetUp" by train.py
    env_name: rlgym       # Isaac Lab vec-env wrapper registered in train.py
    multi_gpu: False
    ppo: True
    mixed_precision: True
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    num_actors: 8192      # number of robots trained in parallel
    reward_shaper:
      scale_value: 1.0
    normalize_advantage: True
    gamma: 0.98
    tau: 0.95
    learning_rate: 3e-4
    lr_schedule: adaptive
    kl_threshold: 0.015
    score_to_win: 20000
    max_epochs: 500
    save_best_after: 50
    save_frequency: 100
    grad_norm: 1.0
    entropy_coef: 0.005
    truncate_grads: True
    bounds_loss_coef: 0.001
    e_clip: 0.2
    horizon_length: 256
    minibatch_size: 65536
    mini_epochs: 4
    critic_coef: 1
    clip_value: True

View File

@@ -0,0 +1,314 @@
import torch
import random
import numpy as np
import isaaclab.envs.mdp as mdp
from isaaclab.assets import ArticulationCfg
from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.managers import RewardTermCfg as RewTerm
from isaaclab.managers import TerminationTermCfg as DoneTerm
from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.envs.mdp import JointPositionActionCfg
from isaaclab.managers import SceneEntityCfg
from isaaclab.utils import configclass
from rl_game.get_up.env.t1_env import T1SceneCfg
# --- 1. Custom MDP logic functions ---
def standing_with_feet_reward(
    env: ManagerBasedRLEnv,
    min_head_height: float,
    min_pelvis_height: float,
    sensor_cfg: SceneEntityCfg,
    force_threshold: float = 20.0,
    max_v_z: float = 0.5
) -> torch.Tensor:
    """Reward standing tall with the feet loaded and a calm vertical velocity.

    Combines (a) normalized head and pelvis height scores, (b) a sigmoid gate
    on the vertical contact force of the bodies selected by ``sensor_cfg``
    (intended: the foot links) and (c) an exponential penalty on the vertical
    root velocity to discourage unstable hopping.

    Args:
        env: The manager-based RL environment.
        min_head_height: Head height (m) that yields a full height score.
        min_pelvis_height: Pelvis height (m) that yields a full height score.
        sensor_cfg: Contact-sensor entity config; its ``body_names``/``body_ids``
            select which links contribute to the support force.
        force_threshold: Vertical force (N) at the midpoint of the support gate.
        max_v_z: Scale (m/s) of the vertical-velocity penalty.

    Returns:
        Per-env reward tensor of shape ``(num_envs,)``.
    """
    robot = env.scene["robot"]
    head_idx, _ = robot.find_bodies("H2")
    pelvis_idx, _ = robot.find_bodies("Trunk")
    curr_head_h = robot.data.body_state_w[:, head_idx[0], 2]
    curr_pelvis_h = robot.data.body_state_w[:, pelvis_idx[0], 2]
    # Normalized height scores (slight over-shoot allowed, up to 1.2).
    head_score = torch.clamp(curr_head_h / min_head_height, 0.0, 1.2)
    pelvis_score = torch.clamp(curr_pelvis_h / min_pelvis_height, 0.0, 1.2)
    height_reward = (head_score + pelvis_score) / 2.0
    # Foot-contact gating.
    contact_sensor = env.scene.sensors.get(sensor_cfg.name)
    if contact_sensor is None:
        return torch.zeros(env.num_envs, device=env.device)
    # BUGFIX: restrict the force sum to the bodies selected by ``sensor_cfg``.
    # Previously the z-force of *every* sensorized link was summed, so hand or
    # trunk contact counted as "foot" support and the body_names filter that
    # callers configure (e.g. ".*_foot_link") had no effect.
    foot_forces_z = torch.sum(
        contact_sensor.data.net_forces_w[:, sensor_cfg.body_ids, 2], dim=-1
    )
    force_weight = torch.sigmoid((foot_forces_z - force_threshold) / 5.0)
    # Vertical-velocity penalty (prevents rewarding unstable jumps).
    root_vel_z = robot.data.root_lin_vel_w[:, 2]
    vel_penalty = torch.exp(-torch.abs(root_vel_z) / max_v_z)
    return height_reward * (0.5 + 0.5 * force_weight * vel_penalty)
def dynamic_getup_strategy_reward(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Full-pose symmetric get-up strategy reward.

    1. Core crouch (spring loading): whether supine or prone, when the body is
       low the legs must be pulled in.
    2. Supine support (back-pushing): when lying on the back, encourage the
       arms to push and the pelvis to rise.
    3. Coordinated burst (explosive jump): upward momentum generated from a
       crouched pose receives the highest multiplier.
    """
    # --- 1. Read physical state ---
    gravity_z = env.scene["robot"].data.projected_gravity_b[:, 2]  # +1: supine, -1: prone
    pelvis_idx, _ = env.scene["robot"].find_bodies("Trunk")
    curr_pelvis_h = env.scene["robot"].data.body_state_w[:, pelvis_idx[0], 2]
    root_vel_z = env.scene["robot"].data.root_lin_vel_w[:, 2]
    # Joint indices: 11,12 = hip pitch, 17,18 = knee pitch.
    # NOTE(review): hard-coded indices — confirm they match the T1 joint order
    # in the USD asset; `find_joints` by name would be safer.
    knee_joints = [17, 18]
    hip_pitch_joints = [11, 12]
    joint_pos = env.scene["robot"].data.joint_pos
    # --- 2. Crouch score ---
    # Regardless of orientation, crouching is the prerequisite for standing:
    # bring the feet close to the center of mass. A deep knee-flexion target
    # (1.5 rad) guides a tighter fold.
    knee_flex_err = torch.abs(joint_pos[:, knee_joints] - 1.5).sum(dim=-1)
    hip_flex_err = torch.abs(joint_pos[:, hip_pitch_joints] - 1.2).sum(dim=-1)
    crouch_score = torch.exp(-(knee_flex_err + hip_flex_err) * 0.6)
    # Base crouch reward ("spring base"), active only while the pelvis is low.
    crouch_trigger = torch.clamp(0.6 - curr_pelvis_h, min=0.0)
    base_crouch_reward = crouch_trigger * crouch_score * 40.0
    # --- 3. Support-force reward ---
    push_reward = torch.zeros_like(curr_pelvis_h)
    contact_sensor = env.scene.sensors.get("contact_sensor")
    if contact_sensor is not None:
        # Any link providing upward push force counts as support, face up or down.
        # NOTE(review): this takes the max over *all* sensorized links, not just
        # hands/arms — foot contact also triggers it; confirm that is intended.
        arm_forces_z = contact_sensor.data.net_forces_w[:, :, 2]
        push_reward = torch.tanh(torch.max(arm_forces_z, dim=-1)[0] / 30.0)
    # --- 4. Orientation-specific shaping ---
    is_back = torch.clamp(gravity_z, min=0.0)    # degree of lying supine
    is_belly = torch.clamp(-gravity_z, min=0.0)  # degree of lying prone
    # A. Supine direct-stand logic: reward raising the pelvis while crouched
    #    (drives the "tuck legs -> push ground -> lift hips" chain).
    back_lift_reward = is_back * torch.clamp(curr_pelvis_h - 0.15, min=0.0) * crouch_score * 50.0
    # B. Roll-over assist, kept as a low-weight fallback path.
    # NOTE(review): `is_belly` is computed but never used, and this term gates
    # on `is_back`, so a prone robot receives no roll-over incentive despite
    # the stated intent — possibly meant `is_belly * (1.0 + gravity_z)`; confirm.
    flip_reward = is_back * (1.0 - gravity_z) * 5.0  # reduced weight, fallback only
    # --- 5. Explosive term ---
    # crouch tightness * upward velocity * support sensing: a universal "jump"
    # bonus — tight tuck + fast rise + hand support maximizes the reward.
    explosion_reward = crouch_score * torch.clamp(root_vel_z, min=0.0) * (0.5 + 0.5 * push_reward) * 80.0
    # --- 6. Sum ---
    total_reward = (
        base_crouch_reward +  # must tuck the legs
        back_lift_reward +    # supine hip lift
        flip_reward +         # roll-over attempt
        explosion_reward      # final burst
    )
    return total_reward
def is_standing_still(
    env: ManagerBasedRLEnv,
    min_head_height: float,
    min_pelvis_height: float,
    max_angle_error: float,
    standing_time: float,
    velocity_threshold: float = 0.15
) -> torch.Tensor:
    """Return True per env once an upright, stable pose has been held long enough.

    Stability requires head and pelvis above their thresholds, the projected
    gravity's horizontal component (tilt) below ``max_angle_error`` and the root
    speed below ``velocity_threshold``. A per-env timer in ``env.extras``
    accumulates the elapsed time while stable and resets to zero otherwise.

    Args:
        env: The manager-based RL environment.
        min_head_height: Minimum head height (m).
        min_pelvis_height: Minimum pelvis height (m).
        max_angle_error: Maximum norm of the horizontal projected gravity.
        standing_time: Required continuous stable duration (s).
        velocity_threshold: Maximum root linear speed (m/s).

    Returns:
        Boolean tensor of shape ``(num_envs,)`` — True where held long enough.
    """
    robot = env.scene["robot"]
    head_idx, _ = robot.find_bodies("H2")
    pelvis_idx, _ = robot.find_bodies("Trunk")
    current_head_h = robot.data.body_state_w[:, head_idx[0], 2]
    current_pelvis_h = robot.data.body_state_w[:, pelvis_idx[0], 2]
    gravity_error = torch.norm(robot.data.projected_gravity_b[:, :2], dim=-1)
    root_vel_norm = torch.norm(robot.data.root_lin_vel_w, dim=-1)
    is_stable_now = (
        (current_head_h > min_head_height) &
        (current_pelvis_h > min_pelvis_height) &
        (gravity_error < max_angle_error) &
        (root_vel_norm < velocity_threshold)
    )
    # BUGFIX: key the timer by this call-site's parameters. The function is
    # used by BOTH a reward term and a termination term; with a single shared
    # "stable_timer" key each term's call advanced the same timer, so it
    # accumulated ~2*dt per step and the two thresholds interfered.
    timer_key = (
        f"stable_timer_{min_head_height}_{min_pelvis_height}"
        f"_{max_angle_error}_{standing_time}_{velocity_threshold}"
    )
    if timer_key not in env.extras:
        env.extras[timer_key] = torch.zeros(env.num_envs, device=env.device)
    # Elapsed wall time per policy step (physics dt * control decimation).
    dt = env.physics_dt * env.cfg.decimation
    env.extras[timer_key] = torch.where(
        is_stable_now,
        env.extras[timer_key] + dt,
        torch.zeros_like(env.extras[timer_key]),
    )
    return env.extras[timer_key] > standing_time
# --- 2. Configuration classes ---
# Joint names of the T1 robot, in the order used by the observation terms below.
T1_JOINT_NAMES = [
    'AAHead_yaw', 'Head_pitch',
    'Left_Shoulder_Pitch', 'Left_Shoulder_Roll', 'Left_Elbow_Pitch', 'Left_Elbow_Yaw',
    'Right_Shoulder_Pitch', 'Right_Shoulder_Roll', 'Right_Elbow_Pitch', 'Right_Elbow_Yaw',
    'Waist',
    'Left_Hip_Pitch', 'Right_Hip_Pitch', 'Left_Hip_Roll', 'Right_Hip_Roll',
    'Left_Hip_Yaw', 'Right_Hip_Yaw', 'Left_Knee_Pitch', 'Right_Knee_Pitch',
    'Left_Ankle_Pitch', 'Right_Ankle_Pitch', 'Left_Ankle_Roll', 'Right_Ankle_Roll'
]
@configclass
class T1ObservationCfg:
    """Observation specification for the get-up task."""

    @configclass
    class PolicyCfg(ObsGroup):
        """Policy observations: base state, gravity, joint state, last action."""
        # Concatenate all terms into one flat observation vector.
        concatenate_terms = True
        base_lin_vel = ObsTerm(func=mdp.base_lin_vel)
        base_ang_vel = ObsTerm(func=mdp.base_ang_vel)
        projected_gravity = ObsTerm(func=mdp.projected_gravity)
        root_pos = ObsTerm(func=mdp.root_pos_w)
        joint_pos = ObsTerm(func=mdp.joint_pos_rel,
                            params={"asset_cfg": SceneEntityCfg("robot", joint_names=T1_JOINT_NAMES)})
        joint_vel = ObsTerm(func=mdp.joint_vel_rel,
                            params={"asset_cfg": SceneEntityCfg("robot", joint_names=T1_JOINT_NAMES)})
        actions = ObsTerm(func=mdp.last_action)

    policy = PolicyCfg()
@configclass
class T1EventCfg:
    """Reset events: drop the robot lying down in a randomized orientation."""

    reset_robot_rotation = EventTerm(
        func=mdp.reset_root_state_uniform,
        params={
            "asset_cfg": SceneEntityCfg("robot"),
            "pose_range": {
                "roll": (-1.57, 1.57),
                # NOTE(review): `random.choice` is evaluated ONCE, when this
                # config is instantiated — so every reset of a given run uses
                # the same pitch sign (always face-up or always face-down),
                # not a fresh draw per reset. If both faces are intended
                # within one run, a custom reset function is required.
                "pitch": tuple(np.array([1.4, 1.6], dtype=np.float32) * random.choice([-1 , 1])),
                "yaw": (-3.14, 3.14),
                "x": (0.0, 0.0),
                "y": (0.0, 0.0),
                "z": (0.3, 0.4),  # drop height above the ground
            },
            "velocity_range": {},
        },
        mode="reset",
    )
@configclass
class T1ActionCfg:
    """Joint-position actions, split into groups to prevent twitching.

    Since no particular motion is prescribed, each body part gets a fairly
    balanced exploration range via its action scale.
    """
    arm_action = JointPositionActionCfg(
        asset_name="robot",
        joint_names=[
            'Left_Shoulder_Pitch', 'Left_Shoulder_Roll', 'Left_Elbow_Pitch', 'Left_Elbow_Yaw',
            'Right_Shoulder_Pitch', 'Right_Shoulder_Roll', 'Right_Elbow_Pitch', 'Right_Elbow_Yaw'
        ],
        scale=1.2,  # relatively generous freedom for the arms to explore
        use_default_offset=True
    )
    torso_action = JointPositionActionCfg(
        asset_name="robot",
        joint_names=['Waist', 'AAHead_yaw', 'Head_pitch'],
        scale=0.8,
        use_default_offset=True
    )
    leg_action = JointPositionActionCfg(
        asset_name="robot",
        joint_names=[
            'Left_Hip_Pitch', 'Right_Hip_Pitch', 'Left_Hip_Roll', 'Right_Hip_Roll',
            'Left_Hip_Yaw', 'Right_Hip_Yaw', 'Left_Knee_Pitch', 'Right_Knee_Pitch',
            'Left_Ankle_Pitch', 'Right_Ankle_Pitch', 'Left_Ankle_Roll', 'Right_Ankle_Roll'
        ],
        scale=0.6,
        use_default_offset=True
    )
@configclass
class T1GetUpRewardCfg:
    """Reward terms for the get-up task."""
    # 1. Core staged guidance (roll over -> crouch -> push up).
    dynamic_strategy = RewTerm(
        func=dynamic_getup_strategy_reward,
        weight=1.5
    )
    # 2. Standing-quality reward (reinforces loading both feet).
    height_with_feet = RewTerm(
        func=standing_with_feet_reward,
        weight=40.0,  # large weight
        params={
            "min_head_height": 1.1,
            "min_pelvis_height": 0.7,
            "sensor_cfg": SceneEntityCfg("contact_sensor", body_names=[".*_foot_link"]),
            "force_threshold": 40.0,  # require real pressure — blocks tip-toe cheating
            "max_v_z": 0.2
        }
    )
    # 3. Anti-exploit penalty: heavy deduction when the trunk ("Trunk") or
    #    head ("H2") touches the ground directly.
    body_contact_penalty = RewTerm(
        func=mdp.contact_forces,
        weight=-20.0,
        params={
            "sensor_cfg": SceneEntityCfg("contact_sensor", body_names=["Trunk", "H2"]),
            "threshold": 1.0
        }
    )
    # 4. Action-rate penalty (suppresses high-frequency twitching).
    action_rate = RewTerm(
        func=mdp.action_rate_l2,
        weight=-0.01
    )
    # 5. Success-maintenance reward.
    is_success_maintain = RewTerm(
        func=is_standing_still,
        weight=1000.0,  # huge success bonus
        params={
            "min_head_height": 1.08,
            "min_pelvis_height": 0.72,
            "max_angle_error": 0.2,
            "standing_time": 0.4,  # must hold the stand for 0.4 s
            "velocity_threshold": 0.3
        }
    )
@configclass
class T1GetUpTerminationsCfg:
    """Episode terminations: timeout, or standing held long enough (success)."""
    time_out = DoneTerm(func=mdp.time_out)
    # Slightly looser thresholds than the success reward, held for a shorter time.
    standing_success = DoneTerm(
        func=is_standing_still,
        params={
            "min_head_height": 1.05,
            "min_pelvis_height": 0.75,
            "max_angle_error": 0.3,
            "standing_time": 0.2,
            "velocity_threshold": 0.5
        }
    )
@configclass
class T1EnvCfg(ManagerBasedRLEnvCfg):
    """Top-level manager-based RL environment config for the T1 get-up task."""
    scene = T1SceneCfg(num_envs=8192, env_spacing=2.5)
    observations = T1ObservationCfg()
    rewards = T1GetUpRewardCfg()
    terminations = T1GetUpTerminationsCfg()
    events = T1EventCfg()
    actions = T1ActionCfg()
    episode_length_s = 10.0  # episode length in seconds
    decimation = 4  # physics steps per control step

74
rl_game/get_up/env/t1_env.py vendored Normal file
View File

@@ -0,0 +1,74 @@
from isaaclab.assets import ArticulationCfg, AssetBaseCfg
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg
from isaaclab.utils import configclass
from isaaclab.actuators import ImplicitActuatorCfg
from isaaclab import sim as sim_utils
import os
# Directory one level above this file (rl_game/get_up) and the robot USD asset path.
_DEMO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
T1_USD_PATH = os.path.join(_DEMO_DIR, "asset", "t1", "T1_locomotion_physics_lab.usd")
@configclass
class T1SceneCfg(InteractiveSceneCfg):
    """Scene: ground plane, T1 robot, full-body contact sensor and lighting.

    (Final revision — resolves the "Unknown asset config type" error.)
    """
    # 1. Ground: define the physics material directly inside the spawn config.
    ground = AssetBaseCfg(
        prim_path="/World/ground",
        spawn=sim_utils.GroundPlaneCfg(
            physics_material=sim_utils.RigidBodyMaterialCfg(
                static_friction=1.0,
                dynamic_friction=1.0,
                restitution=0.3,
                friction_combine_mode="average",
                restitution_combine_mode="average",
            )
        ),
    )
    # 2. Robot articulation.
    robot = ArticulationCfg(
        prim_path="{ENV_REGEX_NS}/Robot",
        spawn=sim_utils.UsdFileCfg(
            usd_path=T1_USD_PATH,
            activate_contact_sensors=True,
            rigid_props=sim_utils.RigidBodyPropertiesCfg(
                disable_gravity=False,
                max_depenetration_velocity=10.0,
            ),
            articulation_props=sim_utils.ArticulationRootPropertiesCfg(
                enabled_self_collisions=True,
                solver_position_iteration_count=8,
                solver_velocity_iteration_count=4,
            ),
        ),
        init_state=ArticulationCfg.InitialStateCfg(
            pos=(0.0, 0.0, 0.4),  # drop height
            joint_pos={".*": 0.0},
        ),
        actuators={
            "t1_joints": ImplicitActuatorCfg(
                joint_names_expr=[".*"],
                effort_limit=800.0,  # doubled so the motors have enough strength
                velocity_limit=20.0,
                stiffness=500.0,  # key: raised from 150 into the 500-800 range
                damping=40.0,  # key: raised from 5 into the 30-50 range to damp jitter
            ),
        },
    )
    # Contact sensor on every robot link; reward terms filter by body name.
    contact_sensor = ContactSensorCfg(
        prim_path="{ENV_REGEX_NS}/Robot/.*",
        update_period=0.0,
        history_length=3,
    )
    # 3. Lighting.
    light = AssetBaseCfg(
        prim_path="/World/light",
        spawn=sim_utils.DistantLightCfg(color=(0.75, 0.75, 0.75), intensity=3000.0),
    )
# Body (link) names of the T1 articulation, for reference when writing rewards:
# ['Trunk', 'H1', 'H2', 'AL1', 'AL2', 'AL3', 'left_hand_link', 'AR1', 'AR2', 'AR3', 'right_hand_link', 'Waist', 'Hip_Pitch_Left', 'Hip_Roll_Left', 'Hip_Yaw_Left', 'Shank_Left', 'Ankle_Cross_Left', 'left_foot_link', 'Hip_Pitch_Right', 'Hip_Roll_Right', 'Hip_Yaw_Right', 'Shank_Right', 'Ankle_Cross_Right', 'right_foot_link']

101
rl_game/get_up/train.py Normal file
View File

@@ -0,0 +1,101 @@
import sys
import os
import argparse

# Make modules next to this script importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from isaaclab.app import AppLauncher

# 1. Launch arguments.
parser = argparse.ArgumentParser(description="Train T1 robot to Get-Up with RL-Games.")
parser.add_argument("--num_envs", type=int, default=8192, help="起身任务建议并行 4096 即可")
parser.add_argument("--task", type=str, default="Isaac-T1-GetUp-v0", help="任务 ID")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()

# 2. Start the simulator (must happen BEFORE importing other isaaclab modules).
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

import torch
import gymnasium as gym
import yaml
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
from rl_games.torch_runner import Runner
from rl_games.common import env_configurations, vecenv
# Import the environment configuration (file: config/t1_env_cfg.py, class T1EnvCfg).
from config.t1_env_cfg import T1EnvCfg

# 3. Register the environment.
gym.register(
    id="Isaac-T1-GetUp-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    kwargs={
        "cfg": T1EnvCfg(),  # loads the randomized-rotation resets, rewards, etc.
    },
)
def main():
    """Build, wrap and train the T1 get-up task with rl_games PPO.

    Automatically resumes from ``logs/T1_GetUp/nn/T1_GetUp.pth`` when that
    checkpoint exists; otherwise training starts from scratch.
    """
    script_dir = os.path.dirname(__file__)
    # Resume support: probe for a checkpoint from a previous run.
    ckpt = os.path.join(script_dir, "logs/T1_GetUp/nn/T1_GetUp.pth")
    resume = os.path.exists(ckpt)

    # Create the Isaac Lab env and adapt it to the rl_games vec-env API.
    # Note: rl_device must be args_cli.device (usually 'cuda:0').
    base_env = gym.make("Isaac-T1-GetUp-v0", num_envs=args_cli.num_envs)
    vec_env = RlGamesVecEnvWrapper(
        base_env,
        rl_device=args_cli.device,
        clip_obs=5.0,
        clip_actions=1.0
    )
    vecenv.register('as_is', lambda config_name, num_actors, **kwargs: vec_env)
    env_configurations.register('rlgym', {
        'vecenv_type': 'as_is',
        'env_creator': lambda **kwargs: vec_env
    })

    # Load PPO hyper-parameters and point logging/naming at this run.
    with open(os.path.join(script_dir, "config", "ppo_cfg.yaml"), "r") as fh:
        rl_config = yaml.safe_load(fh)
    train_cfg = rl_config['params']['config']
    train_cfg['train_dir'] = os.path.join(os.path.abspath(os.path.join(script_dir, ".")), "logs")
    train_cfg['name'] = "T1_GetUp"

    # Inject the checkpoint path so rl_games continues a previous run.
    if resume:
        print(f"[INFO]: 检测到预训练模型,正在从 {ckpt} 恢复训练...")
        train_cfg['load_path'] = ckpt
    else:
        print("[INFO]: 未找到预训练模型,将从零开始训练。")

    # Run training.
    runner = Runner()
    runner.load(rl_config)
    runner.run({
        "train": True,
        "play": False,
        "checkpoint": ckpt if resume else None,
        "vec_env": vec_env
    })
    simulation_app.close()


if __name__ == "__main__":
    main()