Files
Gym_CPU/utils/neural_network.py
2026-03-10 09:31:39 -04:00

69 lines
2.0 KiB
Python

import numpy as np
import onnxruntime as ort
def export_model(model_class, weights_path, output_file):
    """
    Export a PyTorch model to ONNX format, inferring the input size automatically.

    The model is built by calling ``model_class()`` with no arguments, its state
    dict is loaded from ``weights_path`` onto the CPU, and it is exported using a
    random dummy input of shape ``(1, in_features)``. ``in_features`` is taken
    from the first ``torch.nn.Linear`` module yielded by ``model.modules()``.

    Args:
        model_class: Zero-argument callable (typically an ``nn.Module``
            subclass) that constructs the network architecture.
        weights_path: Path to a state dict loadable with ``torch.load``.
        output_file: Destination path for the exported ONNX model.

    Raises:
        ValueError: If the model contains no ``torch.nn.Linear`` layer, so the
            input size cannot be inferred.
    """
    import torch  # imported only here so inference-only users don't need torch

    model = model_class()
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()

    # Infer the flat input size from the first Linear layer. Use a default so a
    # model without one fails with a clear message instead of a bare StopIteration.
    first_linear = next(
        (m for m in model.modules() if isinstance(m, torch.nn.Linear)), None
    )
    if first_linear is None:
        raise ValueError(
            f"Cannot infer input size: {type(model).__name__} "
            "has no torch.nn.Linear layer"
        )
    input_size = first_linear.in_features
    dummy_input = torch.randn(1, input_size)

    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        input_names=["obs"],
        output_names=["action"],
        # Mark the batch axis dynamic on the output as well as the input;
        # otherwise the output may be recorded with a fixed batch size of 1
        # even though "obs" accepts any batch size.
        dynamic_axes={"obs": {0: "batch"}, "action": {0: "batch"}},
        opset_version=17,
    )
    print(f"Model exported to {output_file} (input size: {input_size})")
def load_network(model_path, providers=None):
    """
    Load an ONNX model into an onnxruntime session for fast repeated inference.

    Args:
        model_path: Path to the ONNX model file.
        providers: Execution providers to use, in priority order. Defaults to
            ``["CPUExecutionProvider"]``. Passing an explicit list matters on
            onnxruntime builds with more than one provider (e.g. GPU builds),
            where omitting it makes ``InferenceSession`` raise ``ValueError``.

    Returns:
        dict: ``{"session": ..., "input_name": ..., "output_name": ...}``.
        The names are those of the model's first input and first output.
        They are cached here so ``run_network`` does not look them up per call.
    """
    if providers is None:
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(model_path, providers=providers)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    return {"session": session, "input_name": input_name, "output_name": output_name}
def run_network(obs, model):
    """
    Run a preloaded ONNX model and return a flat float32 array for motor targets.

    Args:
        obs (array-like): Observation(s). Input with fewer than two dimensions
            is treated as a single sample and given a leading batch axis.
            Inputs with two or more dimensions are passed through unchanged,
            on the assumption that they are already batched.
        model (dict): Model handle returned by ``load_network()``.

    Returns:
        np.ndarray: 1-D float32 array holding the first requested output,
        flattened in C order (all batch rows concatenated).
    """
    # ascontiguousarray converts to float32 and makes the array C-contiguous,
    # but skips the copy when obs already satisfies both. That matters on a
    # per-step control path.
    obs = np.ascontiguousarray(obs, dtype=np.float32)
    if obs.ndim < 2:
        obs = obs.reshape(1, -1)  # (n,) -> (1, n); scalar -> (1, 1)

    session = model["session"]
    result = session.run([model["output_name"]], {model["input_name"]: obs})[0]

    # asarray + ravel copy at most once (only if a dtype conversion is needed),
    # instead of flatten() followed by astype().
    return np.asarray(result, dtype=np.float32).ravel()