Init
This commit is contained in:
69
utils/neural_network.py
Normal file
69
utils/neural_network.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
def export_model(model_class, weights_path, output_file):
|
||||
"""
|
||||
Export a PyTorch model to ONNX format automatically detecting input shape.
|
||||
"""
|
||||
import torch # imported only here
|
||||
|
||||
model = model_class()
|
||||
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
# Infer input size from first Linear layer
|
||||
first_linear = next(m for m in model.modules() if isinstance(m, torch.nn.Linear))
|
||||
input_size = first_linear.in_features
|
||||
dummy_input = torch.randn(1, input_size)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
dummy_input,
|
||||
output_file,
|
||||
input_names=["obs"],
|
||||
output_names=["action"],
|
||||
dynamic_axes={"obs": {0: "batch"}},
|
||||
opset_version=17,
|
||||
)
|
||||
print(f"Model exported to {output_file} (input size: {input_size})")
|
||||
|
||||
|
||||
def load_network(model_path):
|
||||
"""
|
||||
Load an ONNX model into memory for fast reuse.
|
||||
"""
|
||||
session = ort.InferenceSession(model_path)
|
||||
input_name = session.get_inputs()[0].name
|
||||
output_name = session.get_outputs()[0].name
|
||||
return {"session": session, "input_name": input_name, "output_name": output_name}
|
||||
|
||||
|
||||
def run_network(obs, model):
|
||||
"""
|
||||
Run a preloaded ONNX model and return a flat float array suitable for motor targets.
|
||||
|
||||
Args:
|
||||
obs (np.ndarray): Input observation array.
|
||||
model (dict): The loaded model from load_network().
|
||||
|
||||
Returns:
|
||||
np.ndarray: 1D array of floats.
|
||||
"""
|
||||
if not isinstance(obs, np.ndarray):
|
||||
obs = np.array(obs, dtype=np.float32)
|
||||
else:
|
||||
obs = obs.astype(np.float32) # ensure float32
|
||||
|
||||
if obs.ndim == 1:
|
||||
obs = obs[np.newaxis, :] # make batch dimension
|
||||
|
||||
session = model["session"]
|
||||
input_name = model["input_name"]
|
||||
output_name = model["output_name"]
|
||||
|
||||
result = session.run([output_name], {input_name: obs})[0]
|
||||
|
||||
# flatten to 1D and convert to float
|
||||
result = result.flatten().astype(np.float32)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user