|
from typing import List |
|
from mlagents.torch_utils import torch |
|
|
|
from mlagents_envs.base_env import ActionSpec |
|
from mlagents.trainers.torch_entities.agent_action import AgentAction |
|
from mlagents.trainers.torch_entities.utils import ModelUtils |
|
|
|
|
|
class ActionFlattener: |
|
def __init__(self, action_spec: ActionSpec): |
|
""" |
|
A torch module that creates the flattened form of an AgentAction object. |
|
The flattened form is the continuous action concatenated with the |
|
concatenated one hot encodings of the discrete actions. |
|
:param action_spec: An ActionSpec that describes the action space dimensions |
|
""" |
|
self._specs = action_spec |
|
|
|
@property |
|
def flattened_size(self) -> int: |
|
""" |
|
The flattened size is the continuous size plus the sum of the branch sizes |
|
since discrete actions are encoded as one hots. |
|
""" |
|
return self._specs.continuous_size + sum(self._specs.discrete_branches) |
|
|
|
def forward(self, action: AgentAction) -> torch.Tensor: |
|
""" |
|
Returns a tensor corresponding the flattened action |
|
:param action: An AgentAction object |
|
""" |
|
action_list: List[torch.Tensor] = [] |
|
if self._specs.continuous_size > 0: |
|
action_list.append(action.continuous_tensor) |
|
if self._specs.discrete_size > 0: |
|
flat_discrete = torch.cat( |
|
ModelUtils.actions_to_onehot( |
|
torch.as_tensor(action.discrete_tensor, dtype=torch.long), |
|
self._specs.discrete_branches, |
|
), |
|
dim=1, |
|
) |
|
action_list.append(flat_discrete) |
|
return torch.cat(action_list, dim=1) |
|
|