Amend arm push reward
This commit is contained in:
@@ -49,39 +49,56 @@ def standing_with_feet_reward(
|
|||||||
return combined_reward
|
return combined_reward
|
||||||
|
|
||||||
|
|
||||||
def arm_push_up_reward(
|
def universal_arm_support_reward(
|
||||||
env: ManagerBasedRLEnv,
|
env: ManagerBasedRLEnv,
|
||||||
sensor_cfg: SceneEntityCfg,
|
sensor_cfg: SceneEntityCfg,
|
||||||
height_threshold: float = 0.65,
|
height_threshold: float = 0.60,
|
||||||
min_force: float = 3.0
|
min_force: float = 2.0
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
通用手臂支撑奖励:同时支持仰卧起坐支撑和俯卧撑起。
|
||||||
|
逻辑:只要手臂有向上的推力,且身体正在向上移动,就给奖。
|
||||||
|
"""
|
||||||
|
# 1. 获取传感器数据
|
||||||
contact_sensor = env.scene.sensors.get(sensor_cfg.name)
|
contact_sensor = env.scene.sensors.get(sensor_cfg.name)
|
||||||
if contact_sensor is None:
|
if contact_sensor is None:
|
||||||
return torch.zeros(env.num_envs, device=env.device)
|
return torch.zeros(env.num_envs, device=env.device)
|
||||||
|
|
||||||
# 1. 获取受力数据
|
# 获取所有定义的手臂/手部 link 的垂直总受力 (World Z)
|
||||||
|
# net_forces_w 形状: (num_envs, num_bodies, 3)
|
||||||
arm_forces_z = contact_sensor.data.net_forces_w[:, :, 2]
|
arm_forces_z = contact_sensor.data.net_forces_w[:, :, 2]
|
||||||
avg_arm_force = torch.mean(arm_forces_z, dim=-1)
|
# 取所有受力点的最大值或平均值,代表支撑强度
|
||||||
|
max_arm_force = torch.max(arm_forces_z, dim=-1)[0]
|
||||||
|
|
||||||
# 2. 几何限制:手臂必须在躯干下方 (修复了之前的 AttributeError)
|
# 2. 获取状态数据
|
||||||
arm_body_indices, _ = env.scene["robot"].find_bodies(sensor_cfg.body_names)
|
|
||||||
pelvis_idx, _ = env.scene["robot"].find_bodies("Trunk")
|
pelvis_idx, _ = env.scene["robot"].find_bodies("Trunk")
|
||||||
pelvis_pos_z = env.scene["robot"].data.body_state_w[:, pelvis_idx[0], 2]
|
pelvis_pos_z = env.scene["robot"].data.body_state_w[:, pelvis_idx[0], 2]
|
||||||
arm_pos_z = env.scene["robot"].data.body_state_w[:, arm_body_indices, 2]
|
|
||||||
|
|
||||||
# 手臂是否全部低于盆骨
|
|
||||||
is_below_pelvis = torch.all(arm_pos_z < pelvis_pos_z.unsqueeze(1), dim=-1).float()
|
|
||||||
|
|
||||||
# 3. 计算奖励
|
|
||||||
force_reward = torch.clamp((avg_arm_force - min_force) / 45.0, min=0.0, max=1.0)
|
|
||||||
root_vel_z = env.scene["robot"].data.root_lin_vel_w[:, 2]
|
root_vel_z = env.scene["robot"].data.root_lin_vel_w[:, 2]
|
||||||
velocity_factor = torch.clamp(root_vel_z * 3.0, min=0.0, max=1.5)
|
|
||||||
|
|
||||||
total_reward = force_reward * is_below_pelvis * (1.0 + velocity_factor)
|
# 3. 计算奖励项
|
||||||
|
# A. 受力奖励:鼓励手部与地面产生大于 min_force 的推力
|
||||||
|
# 使用 tanh 归一化,防止力矩过大导致奖励爆炸 (NaN 风险)
|
||||||
|
force_reward = torch.tanh(torch.clamp(max_arm_force - min_force, min=0.0) / 50.0)
|
||||||
|
|
||||||
# 高度越高,手臂奖励越低 (强迫切换到腿)
|
# B. 速度引导:只有当机器人正在“向上起”时,支撑奖励才翻倍
|
||||||
height_fade = torch.clamp((height_threshold - pelvis_pos_z) / 0.1, min=0.0, max=1.0)
|
# 这样可以防止它趴在地上乱按手骗分
|
||||||
return total_reward * height_fade
|
velocity_factor = torch.clamp(root_vel_z, min=0.0, max=2.0)
|
||||||
|
|
||||||
|
# C. 姿态惩罚回避:
|
||||||
|
# 不再检查手是否在盆骨下方,而是检查手是否“在干活”
|
||||||
|
# 只要受力足够大,就认为是在支撑
|
||||||
|
is_supporting = (max_arm_force > min_force).float()
|
||||||
|
|
||||||
|
# 4. 阶段性退出机制 (Curriculum)
|
||||||
|
# 当盆骨高度超过 height_threshold (0.6m) 时,奖励线性消失
|
||||||
|
# 强迫机器人最终依靠腿部力量平衡,而不是一直扶着地
|
||||||
|
height_fade = torch.clamp((height_threshold - pelvis_pos_z) / 0.15, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
# 最终组合
|
||||||
|
# 逻辑:受力 * (1 + 垂直速度) * 高度衰减
|
||||||
|
total_reward = force_reward * (1.0 + 2.0 * velocity_factor) * is_supporting * height_fade
|
||||||
|
|
||||||
|
return total_reward
|
||||||
|
|
||||||
def is_standing_still(
|
def is_standing_still(
|
||||||
env: ManagerBasedRLEnv,
|
env: ManagerBasedRLEnv,
|
||||||
@@ -210,7 +227,7 @@ class T1GetUpRewardCfg:
|
|||||||
# 3. 手臂撑地奖:辅助脱离地面阶段
|
# 3. 手臂撑地奖:辅助脱离地面阶段
|
||||||
arm_push_support = RewTerm(
|
arm_push_support = RewTerm(
|
||||||
func=arm_push_up_reward,
|
func=arm_push_up_reward,
|
||||||
weight=15.0, # 显著增加权重(从 3.0 提到 15.0),让它成为起步的关键
|
weight=20.0, # 显著增加权重(从 3.0 提到 15.0),让它成为起步的关键
|
||||||
params={
|
params={
|
||||||
"sensor_cfg": SceneEntityCfg("contact_sensor", body_names=[".*_hand_link", "AL3", "AR3"]),
|
"sensor_cfg": SceneEntityCfg("contact_sensor", body_names=[".*_hand_link", "AL3", "AR3"]),
|
||||||
"height_threshold": 0.65, # 躯干升到 0.6m 前都鼓励手臂用力
|
"height_threshold": 0.65, # 躯干升到 0.6m 前都鼓励手臂用力
|
||||||
|
|||||||
Reference in New Issue
Block a user