|
from typing import Tuple |
|
import threading |
|
from mlagents.torch_utils import torch |
|
|
|
from mlagents_envs.logging_util import get_logger |
|
from mlagents.trainers.settings import SerializationSettings |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
class exporting_to_onnx:
    """
    Context manager that marks the current thread as performing an ONNX export.

    Usage:
    ```
    with exporting_to_onnx():
        ...
    ```
    While the body of the `with` statement runs, exporting_to_onnx.is_exporting()
    returns True on the thread that entered the context.

    The flag is kept in thread-local storage, so each thread only sees its own
    value. A class-level lock is held for the whole duration of the context, so
    only one thread can be inside the context at a time.
    NOTE: the lock is a plain threading.Lock (non-reentrant), so entering the
    context again from the thread that already holds it would block.
    """

    # Per-thread storage for the flag. The assignment below only initializes
    # the flag for the thread that executes this class body; any other thread
    # falls back to the default handled in is_exporting().
    _local_data = threading.local()
    _local_data._is_exporting = False

    # Shared by every instance: serializes entry into the context.
    _lock = threading.Lock()

    def __enter__(self):
        """Take the shared lock, then raise the flag for the current thread."""
        self._lock.acquire()
        self._local_data._is_exporting = True

    def __exit__(self, *args):
        """Lower the flag for the current thread, then release the shared lock."""
        self._local_data._is_exporting = False
        self._lock.release()

    @staticmethod
    def is_exporting():
        """
        Returns the export flag of the calling thread, or False when that
        thread has never had the flag set.
        """
        return getattr(exporting_to_onnx._local_data, "_is_exporting", False)
|
|
|
|
|
class TensorNames:
    """
    String constants used to name tensors, plus helpers that build the
    indexed observation names.
    """

    # Names with the "_placeholder" suffix.
    batch_size_placeholder = "batch_size"
    sequence_length_placeholder = "sequence_length"
    vector_observation_placeholder = "vector_observation"
    recurrent_in_placeholder = "recurrent_in"
    visual_observation_placeholder_prefix = "visual_observation_"
    observation_placeholder_prefix = "obs_"
    previous_action_placeholder = "prev_action"
    action_mask_placeholder = "action_masks"
    random_normal_epsilon_placeholder = "epsilon"

    # Output and constant names.
    value_estimate_output = "value_estimate"
    recurrent_output = "recurrent_out"
    memory_size = "memory_size"
    version_number = "version_number"

    # Action output names and their shape counterparts.
    continuous_action_output_shape = "continuous_action_output_shape"
    discrete_action_output_shape = "discrete_action_output_shape"
    continuous_action_output = "continuous_actions"
    discrete_action_output = "discrete_actions"
    deterministic_continuous_action_output = "deterministic_continuous_actions"
    deterministic_discrete_action_output = "deterministic_discrete_actions"

    # Names marked as deprecated by their "_deprecated" suffix.
    is_continuous_control_deprecated = "is_continuous_control"
    action_output_deprecated = "action"
    action_output_shape_deprecated = "action_output_shape"

    @staticmethod
    def get_visual_observation_name(index: int) -> str:
        """
        Builds the visual observation name for the given index by appending
        the index to the visual observation prefix.
        """
        prefix = TensorNames.visual_observation_placeholder_prefix
        return f"{prefix}{index!s}"

    @staticmethod
    def get_observation_name(index: int) -> str:
        """
        Builds the observation name for the given index by appending the
        index to the observation prefix.
        """
        prefix = TensorNames.observation_placeholder_prefix
        return f"{prefix}{index!s}"
|
|
|
|
|
class ModelSerializer:
    """
    Prepares the dummy inputs and tensor names required to export a policy's
    actor, and writes the actor out as an .onnx file.
    """

    def __init__(self, policy):
        """
        Builds the dummy input tuple, the input/output name lists and the
        dynamic axes mapping from the policy's behavior spec.

        :param policy: object exposing `behavior_spec`, `export_memory_size`
            and `actor`; it is stored as `self.policy`.
        """
        self.policy = policy
        observation_specs = self.policy.behavior_spec.observation_specs
        batch_dim = [1]
        seq_len_dim = [1]
        num_obs = len(observation_specs)

        # One all-zero tensor per observation, with a leading batch dimension
        # of 1 followed by the observation shape converted for ONNX.
        dummy_obs = []
        for spec in observation_specs:
            onnx_shape = ModelSerializer._get_onnx_shape(spec.shape)
            dummy_obs.append(torch.zeros([*batch_dim, *onnx_shape]))

        # All-ones mask whose width is the sum of the discrete branch sizes.
        action_spec = self.policy.behavior_spec.action_spec
        mask_width = sum(action_spec.discrete_branches)
        dummy_masks = torch.ones([*batch_dim, mask_width])

        # All-zero memories of shape [1, 1, export_memory_size].
        memory_shape = [*batch_dim, *seq_len_dim, self.policy.export_memory_size]
        dummy_memories = torch.zeros(memory_shape)

        self.dummy_input = (dummy_obs, dummy_masks, dummy_memories)

        # Inputs: every observation, then the action mask and recurrent input.
        self.input_names = [
            *(TensorNames.get_observation_name(i) for i in range(num_obs)),
            TensorNames.action_mask_placeholder,
            TensorNames.recurrent_in_placeholder,
        ]

        # Axis 0 of every input is exported as a dynamic axis named "batch".
        dynamic_axes = {}
        for input_name in self.input_names:
            dynamic_axes[input_name] = {0: "batch"}

        output_names = [TensorNames.version_number, TensorNames.memory_size]
        if action_spec.continuous_size > 0:
            output_names.extend(
                (
                    TensorNames.continuous_action_output,
                    TensorNames.continuous_action_output_shape,
                    TensorNames.deterministic_continuous_action_output,
                )
            )
            dynamic_axes[TensorNames.continuous_action_output] = {0: "batch"}
        if action_spec.discrete_size > 0:
            output_names.extend(
                (
                    TensorNames.discrete_action_output,
                    TensorNames.discrete_action_output_shape,
                    TensorNames.deterministic_discrete_action_output,
                )
            )
            dynamic_axes[TensorNames.discrete_action_output] = {0: "batch"}

        # The recurrent output is only listed when the policy exports memories.
        if self.policy.export_memory_size > 0:
            output_names.append(TensorNames.recurrent_output)

        self.dynamic_axes = dynamic_axes
        self.output_names = output_names

    @staticmethod
    def _get_onnx_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """
        Converts an observation shape so that it is compatible with the NCHW
        format of ONNX: a three-dimensional shape gets its last dimension
        moved to the front, any other shape is returned unchanged.
        """
        if len(shape) != 3:
            return shape
        first, second, last = shape
        return last, first, second

    def export_policy_model(self, output_filepath: str) -> None:
        """
        Exports the Torch actor of the policy to .onnx format for Unity embedding.

        :param output_filepath: file path to output the model (without file suffix)
        """
        onnx_output_path = f"{output_filepath}.onnx"
        logger.debug(f"Converting to {onnx_output_path}")

        # The export runs inside the exporting_to_onnx context so that
        # exporting_to_onnx.is_exporting() is true while the actor is exported.
        with exporting_to_onnx():
            export_options = {
                "opset_version": SerializationSettings.onnx_opset,
                "input_names": self.input_names,
                "output_names": self.output_names,
                "dynamic_axes": self.dynamic_axes,
            }
            torch.onnx.export(
                self.policy.actor,
                self.dummy_input,
                onnx_output_path,
                **export_options,
            )
        logger.info(f"Exported {onnx_output_path}")
|
|