Compare commits

...

3 Commits

Author: xxh

6ab356a947  improve train speed and add speed constrain               2026-03-13 08:51:49 -04:00
3a42120857  Revert "improve training speed and add speed constrain"   2026-03-13 08:43:28 -04:00
            This reverts commit 648cf32e9c.
648cf32e9c  improve training speed and add speed constrain            2026-03-13 08:40:50 -04:00
4 changed files with 317 additions and 281 deletions

View File

@@ -72,37 +72,37 @@ class Server:
         self.commit(msg)
         self.send()
 
-    def receive(self) -> None:
-        """
-        Receive the next message from the TCP/IP socket and update the world
-        """
-        # Receive message length information
-        if (
-            self.__socket.recv_into(
-                self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL
-            )
-            != 4
-        ):
-            raise ConnectionResetError
-        msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False)
-        # Ensure receive buffer is large enough to hold the message
-        if msg_size > self.__rcv_buffer_size:
-            self.__rcv_buffer_size = msg_size
-            self.__rcv_buffer = bytearray(self.__rcv_buffer_size)
-        # Receive message with the specified length
-        if (
-            self.__socket.recv_into(
-                self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL
-            )
-            != msg_size
-        ):
-            raise ConnectionResetError
-        self.world_parser.parse(message=self.__rcv_buffer[:msg_size].decode())
+    def receive(self):
+        while True:
+            if (
+                self.__socket.recv_into(
+                    self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL
+                ) != 4
+            ):
+                raise ConnectionResetError
+            msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False)
+            if msg_size > self.__rcv_buffer_size:
+                self.__rcv_buffer_size = msg_size
+                self.__rcv_buffer = bytearray(self.__rcv_buffer_size)
+            if (
+                self.__socket.recv_into(
+                    self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL
+                ) != msg_size
+            ):
+                raise ConnectionResetError
+            self.world_parser.parse(
+                message=self.__rcv_buffer[:msg_size].decode()
+            )
+            # exit when the socket has no more data to read
+            if len(select([self.__socket], [], [], 0.0)[0]) == 0:
+                break
 
     def commit_beam(self, pos2d: list, rotation: float) -> None:
         assert len(pos2d) == 2

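The rewritten receive() above no longer parses a single message per call: it loops, draining every message already queued on the socket, and only returns once select() reports nothing left to read, so the world model is built from the freshest server state. Below is a minimal, self-contained sketch of the same drain pattern; recv_exact and drain_messages are illustrative names, not functions from this repository.

import socket
from select import select

def recv_exact(sock: socket.socket, n: int) -> bytes:
    # Read exactly n bytes or raise, mirroring the recv_into/MSG_WAITALL checks above
    buf = bytearray(n)
    if sock.recv_into(buf, nbytes=n, flags=socket.MSG_WAITALL) != n:
        raise ConnectionResetError
    return bytes(buf)

def drain_messages(sock: socket.socket) -> list[bytes]:
    # Read length-prefixed messages until no more data is already buffered
    messages = []
    while True:
        size = int.from_bytes(recv_exact(sock, 4), byteorder="big", signed=False)
        messages.append(recv_exact(sock, size))
        # select() with a zero timeout polls without blocking
        if not select([sock], [], [], 0.0)[0]:
            break
    return messages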
View File

@@ -18,9 +18,18 @@ class Server():
         # makes it easier to kill test servers without affecting train servers
         cmd = "rcssservermj"
         for i in range(n_servers):
+            port = first_server_p + i
+            mport = first_monitor_p + i
+            server_cmd = f"{cmd} --aport {port} --mport {mport} --no-render --no-realtime"
             self.rcss_processes.append(
-                subprocess.Popen((f"{cmd} --aport {first_server_p+i} --mport {first_monitor_p+i}").split(),
-                stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, start_new_session=True)
+                subprocess.Popen(
+                    server_cmd.split(),
+                    stdout=subprocess.DEVNULL,
+                    stderr=subprocess.STDOUT,
+                    start_new_session=True
+                )
             )
 
     def check_running_servers(self, psutil, first_server_p, first_monitor_p, n_servers):
@@ -56,7 +65,6 @@ class Server():
             p.kill()
         return
 
-
     def kill(self):
         for p in self.rcss_processes:
            p.kill()

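Both versions launch each rcssservermj instance with start_new_session=True, which places the child in its own session so signals aimed at the training script's process group (for example ctrl+c) do not reach the servers, and kill() can terminate exactly the processes this object spawned. A small sketch of that lifecycle, using a generic sleep command as a stand-in for the real server binary:

import subprocess

# spawn detached workers, one per port ("sleep" stands in for rcssservermj)
procs = [
    subprocess.Popen(
        ["sleep", "3600"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
        start_new_session=True,  # own session: our ctrl+c won't propagate to it
    )
    for _ in range(3)
]

# tear down only the processes we spawned, as Server.kill() does
for p in procs:
    p.kill()
    p.wait()  # reap the child to avoid zombies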
View File

@@ -41,12 +41,10 @@ class Train_Base():
         self.cf_delay = 0
         # self.cf_target_period = World.STEPTIME # target simulation speed while testing (default: real-time)
 
-
     @staticmethod
     def prompt_user_for_model(self):
-        gyms_logs_path = "./mujococodebase/scripts/gyms/logs/"
+        gyms_logs_path = "./scripts/gyms/logs/"
         folders = [f for f in listdir(gyms_logs_path) if isdir(join(gyms_logs_path, f))]
         folders.sort(key=lambda f: os.path.getmtime(join(gyms_logs_path, f)), reverse=True)  # sort by modification date
@@ -64,7 +62,8 @@ class Train_Base():
                 print("The chosen folder does not contain any .zip file!")
                 continue
-            models.sort(key=lambda m: os.path.getmtime(join(folder_dir, m+".zip")), reverse=True)  # sort by modification date
+            models.sort(key=lambda m: os.path.getmtime(join(folder_dir, m + ".zip")),
+                        reverse=True)  # sort by modification date
 
             try:
                 model_name = UI.print_list(models, prompt="Choose model (ctrl+c to return): ")[1]
@@ -72,8 +71,8 @@ class Train_Base():
             except KeyboardInterrupt:
                 print()
 
-        return {"folder_dir":folder_dir, "folder_name":folder_name, "model_file":os.path.join(folder_dir, model_name+".zip")}
+        return {"folder_dir": folder_dir, "folder_name": folder_name,
+                "model_file": os.path.join(folder_dir, model_name + ".zip")}
 
     # def control_fps(self, read_input = False):
     #     ''' Add delay to control simulation speed '''
@@ -108,8 +107,8 @@ class Train_Base():
     #     else:
     #         self.cf_delay = 0
 
-    def test_model(self, model:BaseAlgorithm, env, log_path:str=None, model_path:str=None, max_episodes=0, enable_FPS_control=True, verbose=1):
+    def test_model(self, model: BaseAlgorithm, env, log_path: str = None, model_path: str = None, max_episodes=0,
+                   enable_FPS_control=True, verbose=1):
         '''
         Test model and log results
@@ -186,8 +185,10 @@ class Train_Base():
                 avg_rewards = rewards_sum / ep_no
 
                 if verbose > 0:
-                    print( f"\rEpisode: {ep_no:<3} Ep.Length: {ep_length:<4.0f} Reward: {ep_reward:<6.2f} \n",
-                        end=f"--AVERAGE-- Ep.Length: {avg_ep_lengths:<4.0f} Reward: {avg_rewards:<6.2f} (Min: {reward_min:<6.2f} Max: {reward_max:<6.2f})", flush=True)
+                    print(
+                        f"\rEpisode: {ep_no:<3} Ep.Length: {ep_length:<4.0f} Reward: {ep_reward:<6.2f} \n",
+                        end=f"--AVERAGE-- Ep.Length: {avg_ep_lengths:<4.0f} Reward: {avg_rewards:<6.2f} (Min: {reward_min:<6.2f} Max: {reward_max:<6.2f})",
+                        flush=True)
 
             if log_path is not None:
                 with open(log_path, 'a') as f:
@@ -200,7 +201,8 @@ class Train_Base():
             ep_reward = 0
             ep_length = 0
 
-    def learn_model(self, model:BaseAlgorithm, total_steps:int, path:str, eval_env=None, eval_freq=None, eval_eps=5, save_freq=None, backup_env_file=None, export_name=None):
+    def learn_model(self, model: BaseAlgorithm, total_steps: int, path: str, eval_env=None, eval_freq=None, eval_eps=5,
+                    save_freq=None, backup_env_file=None, export_name=None):
         '''
         Learn Model for a specific number of time steps
@@ -265,19 +267,25 @@ class Train_Base():
         evaluate = bool(eval_env is not None and eval_freq is not None)
 
         # Create evaluation callback
-        eval_callback = None if not evaluate else EvalCallback(eval_env, n_eval_episodes=eval_eps, eval_freq=eval_freq, log_path=path,
-            best_model_save_path=path, deterministic=True, render=False)
+        eval_callback = None if not evaluate else EvalCallback(eval_env, n_eval_episodes=eval_eps, eval_freq=eval_freq,
+                                                               log_path=path,
+                                                               best_model_save_path=path, deterministic=True,
+                                                               render=False)
 
         # Create custom callback to display evaluations
-        custom_callback = None if not evaluate else Cyclic_Callback(eval_freq, lambda:self.display_evaluations(path,True))
+        custom_callback = None if not evaluate else Cyclic_Callback(eval_freq,
+                                                                    lambda: self.display_evaluations(path, True))
 
         # Create checkpoint callback
-        checkpoint_callback = None if save_freq is None else CheckpointCallback(save_freq=save_freq, save_path=path, name_prefix="model", verbose=1)
+        checkpoint_callback = None if save_freq is None else CheckpointCallback(save_freq=save_freq, save_path=path,
+                                                                                name_prefix="model", verbose=1)
 
         # Create custom callback to export checkpoint models
-        export_callback = None if save_freq is None or export_name is None else Export_Callback(save_freq, path, export_name)
+        export_callback = None if save_freq is None or export_name is None else Export_Callback(save_freq, path,
+                                                                                                export_name)
 
-        callbacks = CallbackList([c for c in [eval_callback, custom_callback, checkpoint_callback, export_callback] if c is not None])
+        callbacks = CallbackList(
+            [c for c in [eval_callback, custom_callback, checkpoint_callback, export_callback] if c is not None])
 
         model.learn(total_timesteps=total_steps, callback=callbacks)
         model.save(os.path.join(path, "last_model"))
@@ -329,8 +337,10 @@ class Train_Base():
         results_limits = np.min(results), np.max(results)
         ep_lengths_limits = np.min(ep_lengths), np.max(ep_lengths)
 
-        results_discrete = np.digitize(results, np.linspace(results_limits[0]-1e-5, results_limits[1]+1e-5, console_height+1))-1
-        ep_lengths_discrete = np.digitize(ep_lengths, np.linspace(0, ep_lengths_limits[1]+1e-5, console_height+1))-1
+        results_discrete = np.digitize(results, np.linspace(results_limits[0] - 1e-5, results_limits[1] + 1e-5,
+                                                            console_height + 1)) - 1
+        ep_lengths_discrete = np.digitize(ep_lengths,
+                                          np.linspace(0, ep_lengths_limits[1] + 1e-5, console_height + 1)) - 1
 
         matrix = np.zeros((console_height, console_width, 2), int)
         matrix[results_discrete[0]][0][0] = 1  # draw 1st column
@@ -353,14 +363,19 @@ class Train_Base():
         print(f'{"-" * console_width}')
         for l in reversed(range(console_height)):
             for c in range(console_width):
-                if np.all(matrix[l][c] == 0): print(end=" ")
-                elif np.all(matrix[l][c] == 1): print(end=symb_xo)
-                elif matrix[l][c][0] == 1: print(end=symb_x)
-                else: print(end=symb_o)
+                if np.all(matrix[l][c] == 0):
+                    print(end=" ")
+                elif np.all(matrix[l][c] == 1):
+                    print(end=symb_xo)
+                elif matrix[l][c][0] == 1:
+                    print(end=symb_x)
+                else:
+                    print(end=symb_o)
             print()
         print(f'{"-" * console_width}')
         print(f"({symb_x})-reward min:{results_limits[0]:11.2f} max:{results_limits[1]:11.2f}")
-        print(f"({symb_o})-ep. length min:{ep_lengths_limits[0]:11.0f} max:{ep_lengths_limits[1]:11.0f} {time_steps[-1]/1000:15.0f}k steps")
+        print(
+            f"({symb_o})-ep. length min:{ep_lengths_limits[0]:11.0f} max:{ep_lengths_limits[1]:11.0f} {time_steps[-1] / 1000:15.0f}k steps")
         print(f'{"-" * console_width}')
 
         # save CSV
@@ -372,7 +387,6 @@ class Train_Base():
             writer.writerow(["time_steps", "reward ep.", "length"])
             writer.writerow([time_steps[-1], results_raw[-1], ep_lengths_raw[-1]])
 
-
    # def generate_slot_behavior(self, path, slots, auto_head:bool, XML_name):
    #     '''
    #     Function that generates the XML file for the optimized slot behavior, overwriting previous files
@@ -462,14 +476,14 @@ class Train_Base():
         for i in count(0, 2):  # add hidden layers (step=2 because that's how SB3 works)
             if f"mlp_extractor.policy_net.{i}.bias" not in weights:
                 break
-            var_list.append([w(f"mlp_extractor.policy_net.{i}.bias"), w(f"mlp_extractor.policy_net.{i}.weight"), "tanh"])
+            var_list.append(
+                [w(f"mlp_extractor.policy_net.{i}.bias"), w(f"mlp_extractor.policy_net.{i}.weight"), "tanh"])
 
         var_list.append([w("action_net.bias"), w("action_net.weight"), "none"])  # add final layer
 
         with open(output_file, "wb") as f:
             pickle.dump(var_list, f, protocol=4)  # protocol 4 is backward compatible with Python 3.4
 
-
     def print_list(data, numbering=True, prompt=None, divider=" | ", alignment="<", min_per_col=6):
         '''
         Print list - prints list, using as many columns as possible
@@ -509,7 +523,8 @@ class Train_Base():
             items.append(f"{divider}{number}{data[i]}")
             items_len.append(len(items[-1]))
 
-        max_cols = np.clip((WIDTH+len(divider)) // min(items_len),1,math.ceil(data_size/max(min_per_col,1)))  # width + len(divider) because it is not needed in last col
+        max_cols = np.clip((WIDTH + len(divider)) // min(items_len), 1, math.ceil(
+            data_size / max(min_per_col, 1)))  # width + len(divider) because it is not needed in last col
 
         # --------------------------------------------- Check maximum number of columns, considering content width (min:1)
         for i in range(max_cols, 0, -1):
@@ -532,7 +547,8 @@ class Train_Base():
         print("=" * table_width)
         for row in range(math.ceil(data_size / i)):
             for col in range(i):
-                content = cols_items[col][row] if len(cols_items[col]) > row else divider  # print divider when there are no items
+                content = cols_items[col][row] if len(
+                    cols_items[col]) > row else divider  # print divider when there are no items
                 if col == 0:
                     l = len(divider)
                     print(end=f"{content[l:]:{alignment}{cols_width[col] - l}}")  # remove divider from 1st col
@@ -552,9 +568,9 @@ class Train_Base():
         return idx, data[idx]
 
+
 class Cyclic_Callback(BaseCallback):
     ''' Stable baselines custom callback '''
-
     def __init__(self, freq, function):
         super(Cyclic_Callback, self).__init__(1)
         self.freq = freq
@@ -565,8 +581,10 @@ class Cyclic_Callback(BaseCallback):
             self.function()
         return True  # If the callback returns False, training is aborted early
 
+
 class Export_Callback(BaseCallback):
     ''' Stable baselines custom callback '''
+
     def __init__(self, freq, load_path, export_name):
         super(Export_Callback, self).__init__(1)
         self.freq = freq
@@ -581,4 +599,3 @@ class Export_Callback(BaseCallback):
-

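Aside from the logs-path fix in prompt_user_for_model, this file's diff is a pure re-wrap of long lines; the callback wiring in learn_model is unchanged. The pattern it uses, building each optional Stable-Baselines3 callback or None and then filtering into a CallbackList, can be reproduced standalone roughly as follows (CartPole and the frequencies here are placeholder choices, not values from this repo):

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback

env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

# each callback is either configured or None, exactly like learn_model
eval_cb = EvalCallback(eval_env, n_eval_episodes=5, eval_freq=1_000, log_path="./logs",
                       best_model_save_path="./logs", deterministic=True, render=False)
ckpt_cb = CheckpointCallback(save_freq=1_000, save_path="./logs", name_prefix="model", verbose=1)
export_cb = None  # optional callback that was not requested

callbacks = CallbackList([c for c in [eval_cb, ckpt_cb, export_cb] if c is not None])

model = PPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=5_000, callback=callbacks)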
View File

@@ -1,11 +1,11 @@
 import os
 import numpy as np
 import math
+import time
 from time import sleep
 from random import random
 from random import uniform
-
 from stable_baselines3 import PPO
 from stable_baselines3.common.vec_env import SubprocVecEnv
@@ -28,11 +28,10 @@ Learn how to run forward using step primitive
 - class Train: implements algorithms to train a new model or test an existing model
 '''
 
-
 class WalkEnv(gym.Env):
     def __init__(self, ip, server_p) -> None:
         # Args: Server IP, Agent Port, Monitor Port, Uniform No., Robot Type, Team Name, Enable Log, Enable Draw
         self.Player = player = Base_Agent(
             team_name="Gym",
@@ -55,7 +54,17 @@ class WalkEnv(gym.Env):
         self.auto_calibrate_train_sim_flip = True
         self.nominal_calibrated_once = False
         self.flip_calibrated_once = False
+        self._target_hz = 0.0
+        self._target_dt = 0.0
+        self._last_sync_time = None
+        target_hz_env = 24
+        if target_hz_env:
+            try:
+                self._target_hz = float(target_hz_env)
+            except ValueError:
+                self._target_hz = 0.0
+        if self._target_hz > 0.0:
+            self._target_dt = 1.0 / self._target_hz
 
         # State space
         # original observation size: 78
@@ -77,7 +86,6 @@ class WalkEnv(gym.Env):
             dtype=np.float32
         )
 
-
         # neutral pose
         self.joint_nominal_position = np.array(
             [
@@ -141,7 +149,7 @@ class WalkEnv(gym.Env):
         self.previous_action = np.zeros(len(self.Player.robot.ROBOT_MOTORS))
         self.previous_pos = np.array([0.0, 0.0])  # Track previous position
         self.Player.server.connect()
-        sleep(2.0)  # Longer wait for connection to establish completely
+        # sleep(2.0)  # Longer wait for connection to establish completely
         self.Player.server.send_immediate(
             f"(init {self.Player.robot.name} {self.Player.world.team_name} {self.Player.world.number})"
         )
@@ -194,7 +202,6 @@ class WalkEnv(gym.Env):
         return reliable, leg_norm, leg_max, height
 
-
     def observe(self, init=False):
         """Get the current observation."""
@@ -240,7 +247,6 @@ class WalkEnv(gym.Env):
         orientation_quat_inv = R.from_quat(robot._global_cheat_orientation).inv()
         projected_gravity = orientation_quat_inv.apply(np.array([0.0, 0.0, -1.0]))
 
-
         # combine the observations
         observation = np.concatenate([
             qpos_qvel_previous_action,
@@ -249,8 +255,6 @@ class WalkEnv(gym.Env):
             projected_gravity,
         ])
 
-
         observation = np.clip(observation, -10.0, 10.0)
-
         return observation.astype(np.float32)
@@ -260,6 +264,17 @@
         self.Player.world.update()
         self.Player.robot.commit_motor_targets_pd()
         self.Player.server.send()
+        if self._target_dt > 0.0:
+            now = time.time()
+            if self._last_sync_time is None:
+                self._last_sync_time = now
+                return
+            elapsed = now - self._last_sync_time
+            remaining = self._target_dt - elapsed
+            if remaining > 0.0:
+                time.sleep(remaining)
+                now = time.time()
+            self._last_sync_time = now
 
     def debug_joint_status(self):
         robot = self.Player.robot
@@ -301,7 +316,6 @@ class WalkEnv(gym.Env):
         angle2 = np.random.uniform(-30, 30)  # randomize initial orientation
         angle3 = np.random.uniform(-30, 30)  # randomize target direction
 
-
         self.step_counter = 0
         self.waypoint_index = 0
         self.route_completed = False
@@ -322,12 +336,12 @@
         # run the Neutral skill until it finishes, giving the robot enough time to stand stably at the beam position
         finished_count = 0
-        for _ in range(20):
+        for _ in range(10):
             finished = self.Player.skills_manager.execute("Neutral")
             self.sync()
             if finished:
                 finished_count += 1
-                if finished_count >= 2:  # assume 2 consecutive completions are needed to count as success
+                if finished_count >= 3:  # assume 3 consecutive completions are needed to count as success
                     break
 
         # neutral_joint_positions = np.deg2rad(
@@ -356,13 +370,11 @@ class WalkEnv(gym.Env):
         # reset_action_noise = np.random.uniform(-0.015, 0.015, size=(len(self.Player.robot.ROBOT_MOTORS),))
         # self.target_joint_positions = (self.joint_nominal_position + reset_action_noise) * self.train_sim_flip
         # for idx, target in enumerate(self.target_joint_positions):
         #     r.set_motor_target_position(
         #         r.ROBOT_MOTORS[idx], target*180/math.pi, kp=25, kd=0.6
         #     )
 
-
-
         # memory variables
         self.initial_position = np.array(self.Player.world.global_position[:2])
         self.previous_pos = self.initial_position.copy()  # Critical: set to actual position
@@ -438,7 +450,6 @@
         r = self.Player.robot
 
-
         self.previous_action = action
         self.target_joint_positions = (
@@ -447,15 +458,11 @@
         )
         self.target_joint_positions *= self.train_sim_flip
 
-
         for idx, target in enumerate(self.target_joint_positions):
             r.set_motor_target_position(
                 r.ROBOT_MOTORS[idx], target * 180 / math.pi, kp=25, kd=0.6
             )
 
-
-
-
         self.sync()  # run simulation step
         self.step_counter += 1
@@ -467,7 +474,6 @@ class WalkEnv(gym.Env):
         # Compute reward based on movement from previous step
         reward = self.compute_reward(self.previous_pos, current_pos, action)
 
-
         # Update previous position
         self.previous_pos = current_pos.copy()
@@ -481,20 +487,18 @@
         return self.observe(), reward, terminated, truncated, {}
 
 
 class Train(Train_Base):
     def __init__(self, script) -> None:
         super().__init__(script)
 
     def train(self, args):
         # --------------------------------------- Learning parameters
         n_envs = 8  # Reduced from 8 to decrease CPU/network pressure during init
-        if n_envs < 1:
-            raise ValueError("GYM_CPU_N_ENVS must be >= 1")
         n_steps_per_env = 512  # RolloutBuffer is of size (n_steps_per_env * n_envs)
-        minibatch_size = 64  # should be a factor of (n_steps_per_env * n_envs)
+        minibatch_size = 128  # should be a factor of (n_steps_per_env * n_envs)
         total_steps = 30000000
         learning_rate = 3e-4
         folder_name = f'Walk_R{self.robot_type}'
@@ -507,8 +511,11 @@ class Train(Train_Base):
         def init_env(i_env):
             def thunk():
                 return WalkEnv(self.ip, self.server_p + i_env)
 
             return thunk
 
+        server_log_dir = os.path.join(model_path, "server_logs")
+        os.makedirs(server_log_dir, exist_ok=True)
+
         servers = Train_Server(self.server_p, self.monitor_p_1000, n_envs + 1)  # include 1 extra server for testing
 
         # Wait for servers to start
@@ -518,7 +525,6 @@ class Train(Train_Base):
         env = SubprocVecEnv([init_env(i) for i in range(n_envs)])
         eval_env = SubprocVecEnv([init_env(n_envs)])
 
         try:
-
             # Custom policy network architecture
             policy_kwargs = dict(
@@ -530,7 +536,8 @@
             )
 
             if "model_file" in args:  # retrain
-                model = PPO.load( args["model_file"], env=env, device="cpu", n_envs=n_envs, n_steps=n_steps_per_env, batch_size=minibatch_size, learning_rate=learning_rate )
+                model = PPO.load(args["model_file"], env=env, device="cpu", n_envs=n_envs, n_steps=n_steps_per_env,
+                                 batch_size=minibatch_size, learning_rate=learning_rate)
             else:  # train new model
                 model = PPO(
                     "MlpPolicy",
@@ -547,7 +554,9 @@
                     gamma=0.99  # Discount factor
                 )
 
-            model_path = self.learn_model( model, total_steps, model_path, eval_env=eval_env, eval_freq=n_steps_per_env*20, save_freq=n_steps_per_env*20, backup_env_file=__file__ )
+            model_path = self.learn_model(model, total_steps, model_path, eval_env=eval_env,
+                                          eval_freq=n_steps_per_env * 10, save_freq=n_steps_per_env * 10,
+                                          backup_env_file=__file__)
         except KeyboardInterrupt:
             sleep(1)  # wait for child processes
             print("\nctrl+c pressed, aborting...\n")
@@ -558,16 +567,18 @@
         eval_env.close()
         servers.kill()
 
     def test(self, args):
         # Uses different server and monitor ports
-        server = Train_Server( self.server_p-1, self.monitor_p, 1 )
+        server_log_dir = os.path.join(args["folder_dir"], "server_logs")
+        os.makedirs(server_log_dir, exist_ok=True)
+        server = Train_Server(self.server_p - 1, self.monitor_p, 1, log_dir=server_log_dir)
         env = WalkEnv(self.ip, self.server_p - 1)
         model = PPO.load(args["model_file"], env=env)
 
         try:
-            self.export_model( args["model_file"], args["model_file"]+".pkl", False )  # Export to pkl to create custom behavior
+            self.export_model(args["model_file"], args["model_file"] + ".pkl",
+                              False)  # Export to pkl to create custom behavior
             self.test_model(model, env, log_path=args["folder_dir"], model_path=args["folder_dir"])
         except KeyboardInterrupt:
             print()
@@ -592,4 +603,4 @@ if __name__ == "__main__":
     )
 
     trainer = Train(script_args)
-    trainer.train({})
+    trainer.train({"model_file": "scripts/gyms/logs/Walk_R0_000/model_245760_steps.zip"})