69 lines
2.0 KiB
Python
69 lines
2.0 KiB
Python
import numpy as np
|
|
import onnxruntime as ort
|
|
|
|
def export_model(model_class, weights_path, output_file):
    """
    Export a PyTorch model to ONNX, inferring the input width from its first Linear layer.

    The model is built with ``model_class()``, loaded from the state dict stored
    at ``weights_path`` (mapped to CPU), put in eval mode and traced with a random
    ``(1, in_features)`` tensor. The graph input is named ``"obs"`` and the output
    ``"action"``; both get a dynamic batch axis so the exported model accepts any
    batch size.

    Args:
        model_class: Zero-argument callable returning a ``torch.nn.Module``.
        weights_path: Path to a state dict saved with ``torch.save``.
        output_file: Destination path for the ``.onnx`` file.

    Raises:
        ValueError: If the model contains no ``torch.nn.Linear`` layer, so the
            input size cannot be inferred.
    """
    import torch  # only export needs torch, so it is imported here rather than at module level

    model = model_class()
    # weights_only=True restricts unpickling to tensors and plain containers, which
    # is all a state dict needs, and refuses arbitrary pickled objects.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Use in_features of the first Linear in module registration order.
    # NOTE(review): assumes registration order matches the forward pass and that
    # the model takes a flat (batch, features) input.
    first_linear = next(
        (m for m in model.modules() if isinstance(m, torch.nn.Linear)), None
    )
    if first_linear is None:
        raise ValueError(
            f"{type(model).__name__} has no torch.nn.Linear layer; "
            "cannot infer the input size for ONNX export"
        )
    input_size = first_linear.in_features
    dummy_input = torch.randn(1, input_size)

    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        input_names=["obs"],
        output_names=["action"],
        # Mark the batch axis dynamic on the output as well; otherwise the graph
        # declares the output batch dimension as fixed at 1.
        dynamic_axes={"obs": {0: "batch"}, "action": {0: "batch"}},
        opset_version=17,
    )

    print(f"Model exported to {output_file} (input size: {input_size})")
|
|
|
|
|
|
def load_network(model_path, providers=None):
    """
    Load an ONNX model into an inference session for repeated use.

    Args:
        model_path: Path to the ``.onnx`` file (anything ``ort.InferenceSession``
            accepts as its first argument).
        providers: Optional list of execution providers in priority order. When
            omitted, every provider available in the installed onnxruntime build
            is used. Setting providers explicitly is required by onnxruntime
            (>= 1.9) on builds that ship more than one provider, e.g.
            onnxruntime-gpu; without it session creation raises ValueError.

    Returns:
        dict: ``{"session": ..., "input_name": ..., "output_name": ...}``. The two
        names are those of the model's first input and first output, which is
        what ``run_network`` reads.
    """
    if providers is None:
        providers = ort.get_available_providers()
    session = ort.InferenceSession(model_path, providers=providers)

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    return {"session": session, "input_name": input_name, "output_name": output_name}
|
|
|
|
|
|
def run_network(obs, model):
    """
    Run a preloaded ONNX model and return its output as a flat float32 array.

    Args:
        obs (array-like): Observation(s). A scalar or 1-D input is treated as a
            single sample and given a leading batch axis. Inputs with two or
            more dimensions are passed through unchanged, e.g. as
            ``(batch, features)``.
        model (dict): The loaded model from load_network(), i.e. a dict with
            ``"session"``, ``"input_name"`` and ``"output_name"`` keys.

    Returns:
        np.ndarray: 1-D float32 array holding the requested output flattened in
        C order. For a batch of several samples, the per-sample outputs are
        concatenated.
    """
    # Single conversion that copies only when needed (wrong dtype, non-contiguous,
    # or not an ndarray), then ensure a batch axis: () -> (1, 1), (n,) -> (1, n).
    batch = np.atleast_2d(np.ascontiguousarray(obs, dtype=np.float32))

    session = model["session"]
    outputs = session.run([model["output_name"]], {model["input_name"]: batch})

    # asarray is a no-op when the output is already float32; ravel avoids the
    # unconditional copy that flatten() makes.
    return np.asarray(outputs[0], dtype=np.float32).ravel()