From 3a42120857dc17afe9ef6cc8848aaffc350db3c7 Mon Sep 17 00:00:00 2001 From: ChenXi Date: Fri, 13 Mar 2026 08:43:28 -0400 Subject: [PATCH] Revert "improve training speed and add speed constrain" This reverts commit 648cf32e9c4d8c019aba35373ac393fb7fb6498a. --- communication/server.py | 56 +++++++++++++++++------------------ scripts/commons/Server.py | 15 ++-------- scripts/commons/Train_Base.py | 2 +- scripts/gyms/Walk.py | 45 +++++----------------------- 4 files changed, 40 insertions(+), 78 deletions(-) diff --git a/communication/server.py b/communication/server.py index c956d0a..03eab86 100644 --- a/communication/server.py +++ b/communication/server.py @@ -72,37 +72,37 @@ class Server: self.commit(msg) self.send() - def receive(self): + def receive(self) -> None: + """ + Receive the next message from the TCP/IP socket and updates world + """ - while True: - - if ( - self.__socket.recv_into( - self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL - ) != 4 - ): - raise ConnectionResetError - - msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False) - - if msg_size > self.__rcv_buffer_size: - self.__rcv_buffer_size = msg_size - self.__rcv_buffer = bytearray(self.__rcv_buffer_size) - - if ( - self.__socket.recv_into( - self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL - ) != msg_size - ): - raise ConnectionResetError - - self.world_parser.parse( - message=self.__rcv_buffer[:msg_size].decode() + # Receive message length information + if ( + self.__socket.recv_into( + self.__rcv_buffer, nbytes=4, flags=socket.MSG_WAITALL ) + != 4 + ): + raise ConnectionResetError - # 如果socket没有更多数据就退出 - if len(select([self.__socket], [], [], 0.0)[0]) == 0: - break + msg_size = int.from_bytes(self.__rcv_buffer[:4], byteorder="big", signed=False) + + # Ensure receive buffer is large enough to hold the message + if msg_size > self.__rcv_buffer_size: + self.__rcv_buffer_size = msg_size + self.__rcv_buffer = bytearray(self.__rcv_buffer_size) + + # Receive message with the specified length + if ( + self.__socket.recv_into( + self.__rcv_buffer, nbytes=msg_size, flags=socket.MSG_WAITALL + ) + != msg_size + ): + raise ConnectionResetError + + self.world_parser.parse(message=self.__rcv_buffer[:msg_size].decode()) def commit_beam(self, pos2d: list, rotation: float) -> None: assert len(pos2d) == 2 diff --git a/scripts/commons/Server.py b/scripts/commons/Server.py index 8a7763a..17070e6 100644 --- a/scripts/commons/Server.py +++ b/scripts/commons/Server.py @@ -18,18 +18,9 @@ class Server(): # makes it easier to kill test servers without affecting train servers cmd = "rcssservermj" for i in range(n_servers): - port = first_server_p + i - mport = first_monitor_p + i - - server_cmd = f"{cmd} --aport {port} --mport {mport} --no-render --no-realtime" - self.rcss_processes.append( - subprocess.Popen( - server_cmd.split(), - stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT, - start_new_session=True - ) + subprocess.Popen((f"{cmd} --aport {first_server_p+i} --mport {first_monitor_p+i}").split(), + stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, start_new_session=True) ) def check_running_servers(self, psutil, first_server_p, first_monitor_p, n_servers): @@ -69,4 +60,4 @@ class Server(): def kill(self): for p in self.rcss_processes: p.kill() - print(f"Killed {self.n_servers} rcssservermj processes starting at {self.first_server_p}") + print(f"Killed {self.n_servers} rcssservermj processes starting at {self.first_server_p}") \ No newline at end of file diff --git a/scripts/commons/Train_Base.py b/scripts/commons/Train_Base.py index 94109c2..8f4203f 100644 --- a/scripts/commons/Train_Base.py +++ b/scripts/commons/Train_Base.py @@ -46,7 +46,7 @@ class Train_Base(): @staticmethod def prompt_user_for_model(self): - gyms_logs_path = "./scripts/gyms/logs/" + gyms_logs_path = "./mujococodebase/scripts/gyms/logs/" folders = [f for f in listdir(gyms_logs_path) if isdir(join(gyms_logs_path, f))] folders.sort(key=lambda f: os.path.getmtime(join(gyms_logs_path, f)), reverse=True) # sort by modification date diff --git a/scripts/gyms/Walk.py b/scripts/gyms/Walk.py index 4d96479..4447dde 100644 --- a/scripts/gyms/Walk.py +++ b/scripts/gyms/Walk.py @@ -1,7 +1,6 @@ import os import numpy as np import math -import time from time import sleep from random import random from random import uniform @@ -56,17 +55,6 @@ class WalkEnv(gym.Env): self.auto_calibrate_train_sim_flip = True self.nominal_calibrated_once = False self.flip_calibrated_once = False - self._target_hz = 0.0 - self._target_dt = 0.0 - self._last_sync_time = None - target_hz_env = 24 - if target_hz_env: - try: - self._target_hz = float(target_hz_env) - except ValueError: - self._target_hz = 0.0 - if self._target_hz > 0.0: - self._target_dt = 1.0 / self._target_hz # State space @@ -153,7 +141,7 @@ class WalkEnv(gym.Env): self.previous_action = np.zeros(len(self.Player.robot.ROBOT_MOTORS)) self.previous_pos = np.array([0.0, 0.0]) # Track previous position self.Player.server.connect() - # sleep(2.0) # Longer wait for connection to establish completely + sleep(2.0) # Longer wait for connection to establish completely self.Player.server.send_immediate( f"(init {self.Player.robot.name} {self.Player.world.team_name} {self.Player.world.number})" ) @@ -272,17 +260,6 @@ class WalkEnv(gym.Env): self.Player.world.update() self.Player.robot.commit_motor_targets_pd() self.Player.server.send() - if self._target_dt > 0.0: - now = time.time() - if self._last_sync_time is None: - self._last_sync_time = now - return - elapsed = now - self._last_sync_time - remaining = self._target_dt - elapsed - if remaining > 0.0: - time.sleep(remaining) - now = time.time() - self._last_sync_time = now def debug_joint_status(self): robot = self.Player.robot @@ -345,12 +322,12 @@ class WalkEnv(gym.Env): # 执行 Neutral 技能直到完成,给机器人足够时间在 beam 位置稳定站立 finished_count = 0 - for _ in range(10): + for _ in range(20): finished = self.Player.skills_manager.execute("Neutral") self.sync() if finished: finished_count += 1 - if finished_count >= 3: # 假设需要连续3次完成才算成功 + if finished_count >= 2: # 假设需要连续2次完成才算成功 break # neutral_joint_positions = np.deg2rad( @@ -515,11 +492,9 @@ class Train(Train_Base): def train(self, args): #--------------------------------------- Learning parameters - n_envs = 8 # Reduced from 8 to decrease CPU/network pressure during init - if n_envs < 1: - raise ValueError("GYM_CPU_N_ENVS must be >= 1") + n_envs = 8 # Reduced from 8 to decrease CPU/network pressure during init n_steps_per_env = 512 # RolloutBuffer is of size (n_steps_per_env * n_envs) - minibatch_size = 128 # should be a factor of (n_steps_per_env * n_envs) + minibatch_size = 64 # should be a factor of (n_steps_per_env * n_envs) total_steps = 30000000 learning_rate = 3e-4 folder_name = f'Walk_R{self.robot_type}' @@ -534,8 +509,6 @@ class Train(Train_Base): return WalkEnv( self.ip , self.server_p + i_env) return thunk - server_log_dir = os.path.join(model_path, "server_logs") - os.makedirs(server_log_dir, exist_ok=True) servers = Train_Server( self.server_p, self.monitor_p_1000, n_envs+1 ) #include 1 extra server for testing # Wait for servers to start @@ -574,7 +547,7 @@ class Train(Train_Base): gamma=0.99 # Discount factor ) - model_path = self.learn_model( model, total_steps, model_path, eval_env=eval_env, eval_freq=n_steps_per_env*10, save_freq=n_steps_per_env*10, backup_env_file=__file__ ) + model_path = self.learn_model( model, total_steps, model_path, eval_env=eval_env, eval_freq=n_steps_per_env*20, save_freq=n_steps_per_env*20, backup_env_file=__file__ ) except KeyboardInterrupt: sleep(1) # wait for child processes print("\nctrl+c pressed, aborting...\n") @@ -589,9 +562,7 @@ class Train(Train_Base): def test(self, args): # Uses different server and monitor ports - server_log_dir = os.path.join(args["folder_dir"], "server_logs") - os.makedirs(server_log_dir, exist_ok=True) - server = Train_Server( self.server_p-1, self.monitor_p, 1, log_dir=server_log_dir ) + server = Train_Server( self.server_p-1, self.monitor_p, 1 ) env = WalkEnv( self.ip, self.server_p-1 ) model = PPO.load( args["model_file"], env=env ) @@ -621,4 +592,4 @@ if __name__ == "__main__": ) trainer = Train(script_args) - trainer.train({"model_file": "scripts/gyms/logs/Walk_R0_000/model_245760_steps.zip"}) + trainer.train({}) \ No newline at end of file