improve training speed and add speed constraint

This commit is contained in:
2026-03-13 08:40:50 -04:00
parent 092fb521e1
commit 648cf32e9c
4 changed files with 78 additions and 40 deletions

View File

@@ -72,37 +72,37 @@ class Server:
self.commit(msg) self.commit(msg)
self.send() self.send()
def receive(self) -> None: def receive(self):
"""
Receive the next message from the TCP/IP socket and updates world while True:
"""
# Receive message length information
if ( if (
self.__socket.recv_into( self.__socket.recv_into(
self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL
) ) != 4
!= 4
): ):
raise ConnectionResetError raise ConnectionResetError
msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False) msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False)
# Ensure receive buffer is large enough to hold the message
if msg_size > self.__rcv_buffer_size: if msg_size > self.__rcv_buffer_size:
self.__rcv_buffer_size = msg_size self.__rcv_buffer_size = msg_size
self.__rcv_buffer = bytearray(self.__rcv_buffer_size) self.__rcv_buffer = bytearray(self.__rcv_buffer_size)
# Receive message with the specified length
if ( if (
self.__socket.recv_into( self.__socket.recv_into(
self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL
) ) != msg_size
!= msg_size
): ):
raise ConnectionResetError raise ConnectionResetError
self.world_parser.parse(message=self.__rcv_buffer[:msg_size].decode()) self.world_parser.parse(
message=self.__rcv_buffer[:msg_size].decode()
)
# 如果socket没有更多数据就退出
if len(select([self.__socket], [], [], 0.0)[0]) == 0:
break
def commit_beam(self, pos2d: list, rotation: float) -> None: def commit_beam(self, pos2d: list, rotation: float) -> None:
assert len(pos2d) == 2 assert len(pos2d) == 2

View File

@@ -18,9 +18,18 @@ class Server():
# makes it easier to kill test servers without affecting train servers # makes it easier to kill test servers without affecting train servers
cmd = "rcssservermj" cmd = "rcssservermj"
for i in range(n_servers): for i in range(n_servers):
port = first_server_p + i
mport = first_monitor_p + i
server_cmd = f"{cmd} --aport {port} --mport {mport} --no-render --no-realtime"
self.rcss_processes.append( self.rcss_processes.append(
subprocess.Popen((f"{cmd} --aport {first_server_p+i} --mport {first_monitor_p+i}").split(), subprocess.Popen(
stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, start_new_session=True) server_cmd.split(),
stdout=subprocess.DEVNULL,
stderr=subprocess.STDOUT,
start_new_session=True
)
) )
def check_running_servers(self, psutil, first_server_p, first_monitor_p, n_servers): def check_running_servers(self, psutil, first_server_p, first_monitor_p, n_servers):

View File

@@ -46,7 +46,7 @@ class Train_Base():
@staticmethod @staticmethod
def prompt_user_for_model(self): def prompt_user_for_model(self):
gyms_logs_path = "./mujococodebase/scripts/gyms/logs/" gyms_logs_path = "./scripts/gyms/logs/"
folders = [f for f in listdir(gyms_logs_path) if isdir(join(gyms_logs_path, f))] folders = [f for f in listdir(gyms_logs_path) if isdir(join(gyms_logs_path, f))]
folders.sort(key=lambda f: os.path.getmtime(join(gyms_logs_path, f)), reverse=True) # sort by modification date folders.sort(key=lambda f: os.path.getmtime(join(gyms_logs_path, f)), reverse=True) # sort by modification date

View File

@@ -1,6 +1,7 @@
import os import os
import numpy as np import numpy as np
import math import math
import time
from time import sleep from time import sleep
from random import random from random import random
from random import uniform from random import uniform
@@ -55,6 +56,17 @@ class WalkEnv(gym.Env):
self.auto_calibrate_train_sim_flip = True self.auto_calibrate_train_sim_flip = True
self.nominal_calibrated_once = False self.nominal_calibrated_once = False
self.flip_calibrated_once = False self.flip_calibrated_once = False
self._target_hz = 0.0
self._target_dt = 0.0
self._last_sync_time = None
target_hz_env = 24
if target_hz_env:
try:
self._target_hz = float(target_hz_env)
except ValueError:
self._target_hz = 0.0
if self._target_hz > 0.0:
self._target_dt = 1.0 / self._target_hz
# State space # State space
@@ -141,7 +153,7 @@ class WalkEnv(gym.Env):
self.previous_action = np.zeros(len(self.Player.robot.ROBOT_MOTORS)) self.previous_action = np.zeros(len(self.Player.robot.ROBOT_MOTORS))
self.previous_pos = np.array([0.0, 0.0]) # Track previous position self.previous_pos = np.array([0.0, 0.0]) # Track previous position
self.Player.server.connect() self.Player.server.connect()
sleep(2.0) # Longer wait for connection to establish completely # sleep(2.0) # Longer wait for connection to establish completely
self.Player.server.send_immediate( self.Player.server.send_immediate(
f"(init {self.Player.robot.name} {self.Player.world.team_name} {self.Player.world.number})" f"(init {self.Player.robot.name} {self.Player.world.team_name} {self.Player.world.number})"
) )
@@ -260,6 +272,17 @@ class WalkEnv(gym.Env):
self.Player.world.update() self.Player.world.update()
self.Player.robot.commit_motor_targets_pd() self.Player.robot.commit_motor_targets_pd()
self.Player.server.send() self.Player.server.send()
if self._target_dt > 0.0:
now = time.time()
if self._last_sync_time is None:
self._last_sync_time = now
return
elapsed = now - self._last_sync_time
remaining = self._target_dt - elapsed
if remaining > 0.0:
time.sleep(remaining)
now = time.time()
self._last_sync_time = now
def debug_joint_status(self): def debug_joint_status(self):
robot = self.Player.robot robot = self.Player.robot
@@ -322,12 +345,12 @@ class WalkEnv(gym.Env):
# 执行 Neutral 技能直到完成,给机器人足够时间在 beam 位置稳定站立 # 执行 Neutral 技能直到完成,给机器人足够时间在 beam 位置稳定站立
finished_count = 0 finished_count = 0
for _ in range(20): for _ in range(10):
finished = self.Player.skills_manager.execute("Neutral") finished = self.Player.skills_manager.execute("Neutral")
self.sync() self.sync()
if finished: if finished:
finished_count += 1 finished_count += 1
if finished_count >= 2: # 假设需要连续2次完成才算成功 if finished_count >= 3: # 假设需要连续3次完成才算成功
break break
# neutral_joint_positions = np.deg2rad( # neutral_joint_positions = np.deg2rad(
@@ -493,8 +516,10 @@ class Train(Train_Base):
#--------------------------------------- Learning parameters #--------------------------------------- Learning parameters
n_envs = 8 # Reduced from 8 to decrease CPU/network pressure during init n_envs = 8 # Reduced from 8 to decrease CPU/network pressure during init
if n_envs < 1:
raise ValueError("GYM_CPU_N_ENVS must be >= 1")
n_steps_per_env = 512 # RolloutBuffer is of size (n_steps_per_env * n_envs) n_steps_per_env = 512 # RolloutBuffer is of size (n_steps_per_env * n_envs)
minibatch_size = 64 # should be a factor of (n_steps_per_env * n_envs) minibatch_size = 128 # should be a factor of (n_steps_per_env * n_envs)
total_steps = 30000000 total_steps = 30000000
learning_rate = 3e-4 learning_rate = 3e-4
folder_name = f'Walk_R{self.robot_type}' folder_name = f'Walk_R{self.robot_type}'
@@ -509,6 +534,8 @@ class Train(Train_Base):
return WalkEnv( self.ip , self.server_p + i_env) return WalkEnv( self.ip , self.server_p + i_env)
return thunk return thunk
server_log_dir = os.path.join(model_path, "server_logs")
os.makedirs(server_log_dir, exist_ok=True)
servers = Train_Server( self.server_p, self.monitor_p_1000, n_envs+1 ) #include 1 extra server for testing servers = Train_Server( self.server_p, self.monitor_p_1000, n_envs+1 ) #include 1 extra server for testing
# Wait for servers to start # Wait for servers to start
@@ -547,7 +574,7 @@ class Train(Train_Base):
gamma=0.99 # Discount factor gamma=0.99 # Discount factor
) )
model_path = self.learn_model( model, total_steps, model_path, eval_env=eval_env, eval_freq=n_steps_per_env*20, save_freq=n_steps_per_env*20, backup_env_file=__file__ ) model_path = self.learn_model( model, total_steps, model_path, eval_env=eval_env, eval_freq=n_steps_per_env*10, save_freq=n_steps_per_env*10, backup_env_file=__file__ )
except KeyboardInterrupt: except KeyboardInterrupt:
sleep(1) # wait for child processes sleep(1) # wait for child processes
print("\nctrl+c pressed, aborting...\n") print("\nctrl+c pressed, aborting...\n")
@@ -562,7 +589,9 @@ class Train(Train_Base):
def test(self, args): def test(self, args):
# Uses different server and monitor ports # Uses different server and monitor ports
server = Train_Server( self.server_p-1, self.monitor_p, 1 ) server_log_dir = os.path.join(args["folder_dir"], "server_logs")
os.makedirs(server_log_dir, exist_ok=True)
server = Train_Server( self.server_p-1, self.monitor_p, 1, log_dir=server_log_dir )
env = WalkEnv( self.ip, self.server_p-1 ) env = WalkEnv( self.ip, self.server_p-1 )
model = PPO.load( args["model_file"], env=env ) model = PPO.load( args["model_file"], env=env )
@@ -592,4 +621,4 @@ if __name__ == "__main__":
) )
trainer = Train(script_args) trainer = Train(script_args)
trainer.train({}) trainer.train({"model_file": "scripts/gyms/logs/Walk_R0_000/model_245760_steps.zip"})