From 648cf32e9c4d8c019aba35373ac393fb7fb6498a Mon Sep 17 00:00:00 2001 From: ChenXi Date: Fri, 13 Mar 2026 08:40:50 -0400 Subject: [PATCH] improve training speed and add speed constraint --- communication/server.py | 56 +++++++++++++++++------------------ scripts/commons/Server.py | 15 ++++++++-- scripts/commons/Train_Base.py | 2 +- scripts/gyms/Walk.py | 45 +++++++++++++++++++++++----- 4 files changed, 78 insertions(+), 40 deletions(-) diff --git a/communication/server.py b/communication/server.py index 03eab86..c956d0a 100644 --- a/communication/server.py +++ b/communication/server.py @@ -72,37 +72,37 @@ class Server: self.commit(msg) self.send() - def receive(self) -> None: - """ - Receive the next message from the TCP/IP socket and updates world - """ + def receive(self): - # Receive message length information - if ( - self.__socket.recv_into( - self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL + while True: + + if ( + self.__socket.recv_into( + self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL + ) != 4 + ): + raise ConnectionResetError + + msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False) + + if msg_size > self.__rcv_buffer_size: + self.__rcv_buffer_size = msg_size + self.__rcv_buffer = bytearray(self.__rcv_buffer_size) + + if ( + self.__socket.recv_into( + self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL + ) != msg_size + ): + raise ConnectionResetError + + self.world_parser.parse( + message=self.__rcv_buffer[:msg_size].decode() ) - != 4 - ): - raise ConnectionResetError - msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False) - - # Ensure receive buffer is large enough to hold the message - if msg_size > self.__rcv_buffer_size: - self.__rcv_buffer_size = msg_size - self.__rcv_buffer = bytearray(self.__rcv_buffer_size) - - # Receive message with the specified length - if ( - self.__socket.recv_into( - self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL - ) - != msg_size - ): - 
raise ConnectionResetError - - self.world_parser.parse(message=self.__rcv_buffer[:msg_size].decode()) + # 如果socket没有更多数据就退出 + if len(select([self.__socket], [], [], 0.0)[0]) == 0: + break def commit_beam(self, pos2d: list, rotation: float) -> None: assert len(pos2d) == 2 diff --git a/scripts/commons/Server.py b/scripts/commons/Server.py index 17070e6..8a7763a 100644 --- a/scripts/commons/Server.py +++ b/scripts/commons/Server.py @@ -18,9 +18,18 @@ class Server(): # makes it easier to kill test servers without affecting train servers cmd = "rcssservermj" for i in range(n_servers): + port = first_server_p + i + mport = first_monitor_p + i + + server_cmd = f"{cmd} --aport {port} --mport {mport} --no-render --no-realtime" + self.rcss_processes.append( - subprocess.Popen((f"{cmd} --aport {first_server_p+i} --mport {first_monitor_p+i}").split(), - stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, start_new_session=True) + subprocess.Popen( + server_cmd.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + start_new_session=True + ) ) def check_running_servers(self, psutil, first_server_p, first_monitor_p, n_servers): @@ -60,4 +69,4 @@ class Server(): def kill(self): for p in self.rcss_processes: p.kill() - print(f"Killed {self.n_servers} rcssservermj processes starting at {self.first_server_p}") \ No newline at end of file + print(f"Killed {self.n_servers} rcssservermj processes starting at {self.first_server_p}") diff --git a/scripts/commons/Train_Base.py b/scripts/commons/Train_Base.py index 8f4203f..94109c2 100644 --- a/scripts/commons/Train_Base.py +++ b/scripts/commons/Train_Base.py @@ -46,7 +46,7 @@ class Train_Base(): @staticmethod def prompt_user_for_model(self): - gyms_logs_path = "./mujococodebase/scripts/gyms/logs/" + gyms_logs_path = "./scripts/gyms/logs/" folders = [f for f in listdir(gyms_logs_path) if isdir(join(gyms_logs_path, f))] folders.sort(key=lambda f: os.path.getmtime(join(gyms_logs_path, f)), reverse=True) # sort by modification 
date diff --git a/scripts/gyms/Walk.py b/scripts/gyms/Walk.py index 4447dde..4d96479 100644 --- a/scripts/gyms/Walk.py +++ b/scripts/gyms/Walk.py @@ -1,6 +1,7 @@ import os import numpy as np import math +import time from time import sleep from random import random from random import uniform @@ -55,6 +56,17 @@ class WalkEnv(gym.Env): self.auto_calibrate_train_sim_flip = True self.nominal_calibrated_once = False self.flip_calibrated_once = False + self._target_hz = 0.0 + self._target_dt = 0.0 + self._last_sync_time = None + target_hz_env = 24 + if target_hz_env: + try: + self._target_hz = float(target_hz_env) + except ValueError: + self._target_hz = 0.0 + if self._target_hz > 0.0: + self._target_dt = 1.0 / self._target_hz # State space @@ -141,7 +153,7 @@ class WalkEnv(gym.Env): self.previous_action = np.zeros(len(self.Player.robot.ROBOT_MOTORS)) self.previous_pos = np.array([0.0, 0.0]) # Track previous position self.Player.server.connect() - sleep(2.0) # Longer wait for connection to establish completely + # sleep(2.0) # Longer wait for connection to establish completely self.Player.server.send_immediate( f"(init {self.Player.robot.name} {self.Player.world.team_name} {self.Player.world.number})" ) @@ -260,6 +272,17 @@ class WalkEnv(gym.Env): self.Player.world.update() self.Player.robot.commit_motor_targets_pd() self.Player.server.send() + if self._target_dt > 0.0: + now = time.time() + if self._last_sync_time is None: + self._last_sync_time = now + return + elapsed = now - self._last_sync_time + remaining = self._target_dt - elapsed + if remaining > 0.0: + time.sleep(remaining) + now = time.time() + self._last_sync_time = now def debug_joint_status(self): robot = self.Player.robot @@ -322,12 +345,12 @@ class WalkEnv(gym.Env): # 执行 Neutral 技能直到完成,给机器人足够时间在 beam 位置稳定站立 finished_count = 0 - for _ in range(20): + for _ in range(10): finished = self.Player.skills_manager.execute("Neutral") self.sync() if finished: finished_count += 1 - if finished_count >= 2: # 
假设需要连续2次完成才算成功 + if finished_count >= 3: # 假设需要连续3次完成才算成功 break # neutral_joint_positions = np.deg2rad( @@ -492,9 +515,11 @@ class Train(Train_Base): def train(self, args): #--------------------------------------- Learning parameters - n_envs = 8 # Reduced from 8 to decrease CPU/network pressure during init + n_envs = 8 # Reduced from 8 to decrease CPU/network pressure during init + if n_envs < 1: + raise ValueError("GYM_CPU_N_ENVS must be >= 1") n_steps_per_env = 512 # RolloutBuffer is of size (n_steps_per_env * n_envs) - minibatch_size = 64 # should be a factor of (n_steps_per_env * n_envs) + minibatch_size = 128 # should be a factor of (n_steps_per_env * n_envs) total_steps = 30000000 learning_rate = 3e-4 folder_name = f'Walk_R{self.robot_type}' @@ -509,6 +534,8 @@ class Train(Train_Base): return WalkEnv( self.ip , self.server_p + i_env) return thunk + server_log_dir = os.path.join(model_path, "server_logs") + os.makedirs(server_log_dir, exist_ok=True) servers = Train_Server( self.server_p, self.monitor_p_1000, n_envs+1 ) #include 1 extra server for testing # Wait for servers to start @@ -547,7 +574,7 @@ class Train(Train_Base): gamma=0.99 # Discount factor ) - model_path = self.learn_model( model, total_steps, model_path, eval_env=eval_env, eval_freq=n_steps_per_env*20, save_freq=n_steps_per_env*20, backup_env_file=__file__ ) + model_path = self.learn_model( model, total_steps, model_path, eval_env=eval_env, eval_freq=n_steps_per_env*10, save_freq=n_steps_per_env*10, backup_env_file=__file__ ) except KeyboardInterrupt: sleep(1) # wait for child processes print("\nctrl+c pressed, aborting...\n") @@ -562,7 +589,9 @@ class Train(Train_Base): def test(self, args): # Uses different server and monitor ports - server = Train_Server( self.server_p-1, self.monitor_p, 1 ) + server_log_dir = os.path.join(args["folder_dir"], "server_logs") + os.makedirs(server_log_dir, exist_ok=True) + server = Train_Server( self.server_p-1, self.monitor_p, 1, log_dir=server_log_dir ) env 
= WalkEnv( self.ip, self.server_p-1 ) model = PPO.load( args["model_file"], env=env ) @@ -592,4 +621,4 @@ if __name__ == "__main__": ) trainer = Train(script_args) - trainer.train({}) \ No newline at end of file + trainer.train({"model_file": "scripts/gyms/logs/Walk_R0_000/model_245760_steps.zip"})