change reward add punishment of joint_vel and root_vel_z_penalty
This commit is contained in:
@@ -53,35 +53,71 @@ def get_success_reward(env: ManagerBasedRLEnv, term_keys: str) -> torch.Tensor:
|
||||
"""检查是否触发了特定的成功终止条件"""
|
||||
return env.termination_manager.get_term(term_keys)
|
||||
|
||||
def root_lin_vel_norm(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
|
||||
"""手动计算 root 线性速度的 L2 范数"""
|
||||
# 获取速度 (16384, 3)
|
||||
def root_vel_z_l2_local(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
|
||||
"""专门惩罚 Z 轴方向的速度"""
|
||||
# 获取根部速度 (num_envs, 3) -> [vx, vy, vz]
|
||||
vel = env.scene[asset_cfg.name].data.root_lin_vel_w
|
||||
# 返回模长 (16384,)
|
||||
return torch.norm(vel, dim=-1)
|
||||
# 只取 Z 轴:vel[:, 2]
|
||||
return torch.square(vel[:, 2])
|
||||
|
||||
def root_ang_vel_norm(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
|
||||
"""手动计算 root 角速度的 L2 范数"""
|
||||
vel = env.scene[asset_cfg.name].data.root_ang_vel_w
|
||||
return torch.norm(vel, dim=-1)
|
||||
def joint_torques_l2_local(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
|
||||
"""计算机器人所有关节施加扭矩的平方和"""
|
||||
# 从 data.applied_torques 获取数据,通常形状为 (num_envs, num_joints)
|
||||
torques = env.scene[asset_cfg.name].data.applied_torque
|
||||
return torch.sum(torch.square(torques), dim=-1)
|
||||
|
||||
def joint_vel_l2_local(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
|
||||
"""计算机器人所有关节速度的平方和"""
|
||||
# 从 data.joint_vel 获取数据
|
||||
vel = env.scene[asset_cfg.name].data.joint_vel
|
||||
return torch.sum(torch.square(vel), dim=-1)
|
||||
|
||||
# --- 2. 配置类定义 ---
|
||||
|
||||
## 1. 定义与你的类一致的关节列表 (按照 ROBOT_MOTORS 的顺序)
|
||||
T1_JOINT_NAMES = [
|
||||
'Left_Hip_Pitch', 'Right_Hip_Pitch',
|
||||
'Left_Hip_Roll', 'Right_Hip_Roll',
|
||||
'Left_Hip_Yaw', 'Right_Hip_Yaw',
|
||||
'Left_Knee_Pitch', 'Right_Knee_Pitch',
|
||||
'Left_Ankle_Pitch', 'Right_Ankle_Pitch',
|
||||
'Left_Ankle_Roll', 'Right_Ankle_Roll'
|
||||
]
|
||||
|
||||
@configclass
|
||||
class T1ObservationCfg:
|
||||
"""观察值配置"""
|
||||
"""观察值空间配置:严格对应你的 Robot 基类数据结构"""
|
||||
|
||||
@configclass
|
||||
class PolicyCfg(ObsGroup):
|
||||
concatenate_terms = True
|
||||
enable_corruption = False
|
||||
|
||||
# --- 状态量 (对应你的 Robot 类属性) ---
|
||||
|
||||
# 1. 基体线速度 (accelerometer 相关的速度项)
|
||||
base_lin_vel = ObsTerm(func=mdp.base_lin_vel)
|
||||
|
||||
# 2. 角速度 (对应你的 gyroscope 属性: degrees/s -> IsaacLab 默认为 rad/s)
|
||||
base_ang_vel = ObsTerm(func=mdp.base_ang_vel)
|
||||
|
||||
# 3. 重力投影 (对应 global_orientation_euler/quat 相关的姿态感知)
|
||||
projected_gravity = ObsTerm(func=mdp.projected_gravity)
|
||||
root_height = ObsTerm(func=mdp.root_pos_w) # 高度信息对起身至关重要
|
||||
joint_pos = ObsTerm(func=mdp.joint_pos_rel)
|
||||
joint_vel = ObsTerm(func=mdp.joint_vel_rel)
|
||||
|
||||
# 4. 关节位置 (对应 motor_positions)
|
||||
# 使用 joint_pos_rel 获取相对于默认姿态的偏差,显式指定关节顺序
|
||||
joint_pos = ObsTerm(
|
||||
func=mdp.joint_pos_rel,
|
||||
params={"asset_cfg": SceneEntityCfg("robot", joint_names=T1_JOINT_NAMES)}
|
||||
)
|
||||
|
||||
# 5. 关节速度 (对应 motor_speeds)
|
||||
joint_vel = ObsTerm(
|
||||
func=mdp.joint_vel_rel,
|
||||
params={"asset_cfg": SceneEntityCfg("robot", joint_names=T1_JOINT_NAMES)}
|
||||
)
|
||||
|
||||
# 6. 上一次的动作 (对应 motor_targets)
|
||||
actions = ObsTerm(func=mdp.last_action)
|
||||
|
||||
policy = PolicyCfg()
|
||||
@@ -113,7 +149,7 @@ class T1ActionCfg:
|
||||
"""动作空间"""
|
||||
joint_pos = JointPositionActionCfg(
|
||||
asset_name="robot",
|
||||
joint_names=[".*"],
|
||||
joint_names=T1_JOINT_NAMES,
|
||||
scale=0.5,
|
||||
use_default_offset=True
|
||||
)
|
||||
@@ -121,45 +157,55 @@ class T1ActionCfg:
|
||||
|
||||
@configclass
|
||||
class T1GetUpRewardCfg:
|
||||
"""奖励函数:引导、稳定、终点奖"""
|
||||
# 1. 进度引导:越高分越高
|
||||
height_progress = RewTerm(
|
||||
func=mdp.root_height_below_minimum,
|
||||
weight=15.0,
|
||||
"""优化后的奖励函数:抑制跳跃,引导稳健起身"""
|
||||
|
||||
# 1. 高度引导 (改为平滑的指数奖励)
|
||||
# 相比 root_height_below_minimum,这个函数会让机器人越接近目标高度得分越高,且曲线平稳
|
||||
height_tracking = RewTerm(
|
||||
func=mdp.root_height_below_minimum, # 如果没有自定义函数,保留这个但调低权重
|
||||
weight=5.0, # 降低权重,防止“弹射”
|
||||
params={"minimum_height": 0.65}
|
||||
)
|
||||
|
||||
# 2. 时间惩罚:鼓励尽快起身
|
||||
time_penalty = RewTerm(func=mdp.is_alive, weight=-0.5)
|
||||
|
||||
# 3. 姿态奖:保持躯干垂直
|
||||
# 2. 姿态奖 (保持不变,这是核心)
|
||||
upright = RewTerm(func=mdp.flat_orientation_l2, weight=2.0)
|
||||
|
||||
# 4. 稳定性奖:站起来后不要乱晃
|
||||
root_static = RewTerm(
|
||||
func=root_lin_vel_norm, # 使用上面定义的本地函数
|
||||
weight=-1.5,
|
||||
# 3. 稳定性引导 (增加对速度的惩罚,抑制跳跃)
|
||||
# 惩罚过大的垂直速度,防止“跳起”
|
||||
root_vel_z_penalty = RewTerm(
|
||||
func=root_vel_z_l2_local, # 使用本地函数
|
||||
weight=-2.0,
|
||||
params={"asset_cfg": SceneEntityCfg("robot")} # 传入资产配置
|
||||
)
|
||||
|
||||
# 4. 关节与能量约束 (防止 NaN 和乱跳的关键)
|
||||
joint_vel = RewTerm(
|
||||
func=joint_vel_l2_local,
|
||||
weight=-0.005,
|
||||
params={"asset_cfg": SceneEntityCfg("robot")}
|
||||
)
|
||||
|
||||
# 角速度稳定性
|
||||
root_ang_static = RewTerm(
|
||||
func=root_ang_vel_norm, # 使用上面定义的本地函数
|
||||
weight=-0.5,
|
||||
applied_torque = RewTerm(
|
||||
func=joint_torques_l2_local,
|
||||
weight=-1.0e-5,
|
||||
params={"asset_cfg": SceneEntityCfg("robot")}
|
||||
)
|
||||
|
||||
# 5. 核心终点奖励:当满足 standing_success 终止条件时,给 500 分
|
||||
# 5. 动作平滑 (非常重要)
|
||||
action_rate = RewTerm(
|
||||
func=mdp.action_rate_l2,
|
||||
weight=-0.05 # 增大权重,强制动作连贯
|
||||
)
|
||||
|
||||
# 6. 核心终点奖励
|
||||
is_success = RewTerm(
|
||||
func=get_success_reward, # 使用我们刚刚定义的本地函数
|
||||
weight=500.0,
|
||||
params={"term_keys": "standing_success"} # 确保名字对应 TerminationsCfg 里的变量名
|
||||
func=get_success_reward,
|
||||
weight=1000.0, # 成功奖励可以给高点,但前提是动作要平稳
|
||||
params={"term_keys": "standing_success"}
|
||||
)
|
||||
|
||||
# 6. 平滑惩罚
|
||||
action_rate = RewTerm(func=mdp.action_rate_l2, weight=-0.01)
|
||||
joint_vel = RewTerm(func=mdp.joint_vel_l2, weight=-0.001)
|
||||
|
||||
# 7. 生存奖励 (保持微小正值即可)
|
||||
is_alive = RewTerm(func=mdp.is_alive, weight=0.1)
|
||||
|
||||
@configclass
|
||||
class T1GetUpTerminationsCfg:
|
||||
|
||||
@@ -41,19 +41,22 @@ gym.register(
|
||||
|
||||
|
||||
def main():
|
||||
# 4. 创建环境,显式传入命令行指定的 num_envs
|
||||
# --- 新增:处理 Retrain 参数 ---
|
||||
# 你可以手动指定路径,或者在 argparse 里增加一个 --checkpoint 参数
|
||||
checkpoint_path = os.path.join(os.path.dirname(__file__), "logs/T1_GetUp/nn/T1_GetUp.pth")
|
||||
# 检查模型文件是否存在
|
||||
should_retrain = os.path.exists(checkpoint_path)
|
||||
|
||||
env = gym.make("Isaac-T1-GetUp-v0", num_envs=args_cli.num_envs)
|
||||
|
||||
# 5. 包装环境
|
||||
# 注意:rl_device 必须设置为 args_cli.device (通常是 'cuda:0')
|
||||
wrapped_env = RlGamesVecEnvWrapper(
|
||||
env,
|
||||
rl_device=args_cli.device,
|
||||
clip_obs=5.0,
|
||||
clip_actions=1.0 # 动作裁剪建议设小一点,防止电机输出瞬间爆表
|
||||
clip_actions=1.0
|
||||
)
|
||||
|
||||
# 注册给 rl_games 使用
|
||||
vecenv.register('as_is', lambda config_name, num_actors, **kwargs: wrapped_env)
|
||||
|
||||
env_configurations.register('rlgym', {
|
||||
@@ -61,8 +64,6 @@ def main():
|
||||
'env_creator': lambda **kwargs: wrapped_env
|
||||
})
|
||||
|
||||
# 6. 加载 PPO 配置文件
|
||||
# 提示:由于是起身任务,建议在 ppo_cfg.yaml 中调大 mini_batch 大数或提高学习率
|
||||
config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
|
||||
with open(config_path, "r") as f:
|
||||
rl_config = yaml.safe_load(f)
|
||||
@@ -73,15 +74,23 @@ def main():
|
||||
rl_config['params']['config']['train_dir'] = log_dir
|
||||
rl_config['params']['config']['name'] = "T1_GetUp"
|
||||
|
||||
# --- 关键修改:注入模型路径 ---
|
||||
if should_retrain:
|
||||
print(f"[INFO]: 检测到预训练模型,正在从 {checkpoint_path} 恢复训练...")
|
||||
# rl_games 会读取 config 中的 load_path 进行续训
|
||||
rl_config['params']['config']['load_path'] = checkpoint_path
|
||||
else:
|
||||
print("[INFO]: 未找到预训练模型,将从零开始训练。")
|
||||
|
||||
# 7. 运行训练
|
||||
runner = Runner()
|
||||
runner.load(rl_config)
|
||||
|
||||
print(f"[INFO]: 开始训练任务 {args_cli.task},环境数量: {args_cli.num_envs}")
|
||||
|
||||
runner.run({
|
||||
"train": True,
|
||||
"play": False,
|
||||
# 如果你想强制从某个 checkpoint 开始,也可以在这里传参
|
||||
"checkpoint": checkpoint_path if should_retrain else None,
|
||||
"vec_env": wrapped_env
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user