from typing import List, Optional, NamedTuple
import itertools

import numpy as np

from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents_envs.base_env import ActionTuple

class AgentAction(NamedTuple):
    """
    A NamedTuple containing the tensor for continuous actions and a list of tensors for
    discrete actions. Utility functions provide numpy <=> tensor conversions to be
    sent as actions to the environment manager, as well as used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to continuous actions
    :param discrete_list: List of Torch tensors, each corresponding to one discrete action branch
    """

    continuous_tensor: torch.Tensor
    discrete_list: Optional[List[torch.Tensor]]

    @property
    def discrete_tensor(self) -> torch.Tensor:
        """
        Returns the discrete action list as a single stacked tensor.
        """
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            return torch.stack(self.discrete_list, dim=-1)
        else:
            return torch.empty(0)
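
    # Illustrative example (values are hypothetical, not from the original
    # source): with two discrete branches stored buffer-style as (batch,)
    # tensors,
    #     act = AgentAction(torch.zeros(4, 2), [torch.zeros(4), torch.ones(4)])
    #     act.discrete_tensor.shape  # torch.Size([4, 2])
    # stacking happens along the last dimension, so policy-style (batch, 1)
    # branch tensors would instead yield shape (batch, 1, num_branches).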

    def slice(self, start: int, end: int) -> "AgentAction":
        """
        Returns an AgentAction with the continuous and discrete tensors sliced
        from index start to index end.
        """
        _cont = None
        _disc_list = []
        if self.continuous_tensor is not None:
            _cont = self.continuous_tensor[start:end]
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            for _disc in self.discrete_list:
                _disc_list.append(_disc[start:end])
        return AgentAction(_cont, _disc_list)
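
    # Illustrative example (hypothetical shapes): slicing keeps the continuous
    # and discrete parts aligned along the leading (batch) dimension, e.g.
    #     sub = act.slice(0, 2)
    #     sub.continuous_tensor.shape  # torch.Size([2, 2])
    #     sub.discrete_tensor.shape    # torch.Size([2, 2])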

    def to_action_tuple(self, clip: bool = False) -> ActionTuple:
        """
        Returns an ActionTuple. If clip is True, continuous actions are clamped
        to [-3, 3] and rescaled into [-1, 1] before conversion to numpy.
        """
        action_tuple = ActionTuple()
        if self.continuous_tensor is not None:
            _continuous_tensor = self.continuous_tensor
            if clip:
                _continuous_tensor = torch.clamp(_continuous_tensor, -3, 3) / 3
            continuous = ModelUtils.to_numpy(_continuous_tensor)
            action_tuple.add_continuous(continuous)
        # Guard against an empty discrete_list: discrete_tensor would be
        # torch.empty(0), which cannot be indexed below.
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            # Each branch tensor sampled by the policy has shape (batch, 1), so
            # the stacked tensor is (batch, 1, num_branches); drop the middle axis.
            discrete = ModelUtils.to_numpy(self.discrete_tensor[:, 0, :])
            action_tuple.add_discrete(discrete)
        return action_tuple
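
    # Illustrative example (hypothetical policy output): per-branch tensors of
    # shape (batch, 1), as sampled by the policy, convert cleanly:
    #     act = AgentAction(
    #         torch.rand(4, 2),
    #         [torch.zeros(4, 1, dtype=torch.long), torch.ones(4, 1, dtype=torch.long)],
    #     )
    #     at = act.to_action_tuple(clip=True)
    #     at.continuous.shape  # (4, 2), values rescaled into [-1, 1]
    #     at.discrete.shape    # (4, 2)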

    @staticmethod
    def from_buffer(buff: AgentBuffer) -> "AgentAction":
        """
        A static method that accesses continuous and discrete action fields in an AgentBuffer
        and constructs the corresponding AgentAction from the retrieved np arrays.
        """
        continuous: Optional[torch.Tensor] = None
        discrete: Optional[List[torch.Tensor]] = None
        if BufferKey.CONTINUOUS_ACTION in buff:
            continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_ACTION])
        if BufferKey.DISCRETE_ACTION in buff:
            discrete_tensor = ModelUtils.list_to_tensor(
                buff[BufferKey.DISCRETE_ACTION], dtype=torch.long
            )
            # Split the (batch, num_branches) tensor into one (batch,) tensor per branch.
            discrete = [
                discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
            ]
        return AgentAction(continuous, discrete)
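
    # Illustrative sketch (hypothetical buffer contents): given a buffer whose
    # BufferKey.DISCRETE_ACTION field holds (batch, num_branches) entries,
    #     act = AgentAction.from_buffer(buff)
    #     len(act.discrete_list)      # num_branches
    #     act.discrete_list[0].shape  # torch.Size([batch])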

    @staticmethod
    def _group_agent_action_from_buffer(
        buff: AgentBuffer, cont_action_key: BufferKey, disc_action_key: BufferKey
    ) -> List["AgentAction"]:
        """
        Extracts continuous and discrete groupmate actions, as specified by BufferKey, and
        returns a List of AgentActions that correspond to the groupmates' actions. The List
        will be of length equal to the maximum number of groupmates in the buffer. Wherever
        there are fewer agents than the maximum, the actions are padded with 0's.
        """
        continuous_tensors: List[torch.Tensor] = []
        discrete_tensors: List[torch.Tensor] = []
        if cont_action_key in buff:
            padded_batch = buff[cont_action_key].padded_to_batch()
            continuous_tensors = [
                ModelUtils.list_to_tensor(arr) for arr in padded_batch
            ]
        if disc_action_key in buff:
            # Use a fixed-width integer dtype; np.long is not available in
            # recent NumPy releases.
            padded_batch = buff[disc_action_key].padded_to_batch(dtype=np.int64)
            discrete_tensors = [
                ModelUtils.list_to_tensor(arr, dtype=torch.long) for arr in padded_batch
            ]

        # zip_longest pads the shorter list with None so that purely-continuous
        # and purely-discrete action spaces are both handled.
        actions_list = []
        for _cont, _disc in itertools.zip_longest(
            continuous_tensors, discrete_tensors, fillvalue=None
        ):
            if _disc is not None:
                # Split the stacked (batch, num_branches) tensor into per-branch tensors.
                _disc = [_disc[..., i] for i in range(_disc.shape[-1])]
            actions_list.append(AgentAction(_cont, _disc))
        return actions_list

    @staticmethod
    def group_from_buffer(buff: AgentBuffer) -> List["AgentAction"]:
        """
        A static method that accesses the group continuous and discrete action fields in an
        AgentBuffer and constructs a padded List of AgentActions that represent the group
        agent actions. The List is of length equal to the max number of groupmate agents in
        the buffer, and each AgentAction is of the same length as the buffer. Empty spots
        (e.g. when agents die) are padded with 0.
        :param buff: AgentBuffer of a batch or trajectory
        :return: List of groupmates' AgentActions
        """
        return AgentAction._group_agent_action_from_buffer(
            buff, BufferKey.GROUP_CONTINUOUS_ACTION, BufferKey.GROUP_DISCRETE_ACTION
        )
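
    # Illustrative sketch (hypothetical shapes): if at most 3 groupmates appear
    # anywhere in the buffer, the result has length 3, and each entry spans the
    # whole buffer, zero-padded at timesteps where that groupmate was absent:
    #     group_acts = AgentAction.group_from_buffer(buff)
    #     len(group_acts)                        # 3
    #     group_acts[0].continuous_tensor.shape  # (buffer_len, num_continuous)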

    @staticmethod
    def group_from_buffer_next(buff: AgentBuffer) -> List["AgentAction"]:
        """
        A static method that accesses the next-step group continuous and discrete action fields
        in an AgentBuffer and constructs a padded List of AgentActions that represent the next
        group agent actions. The List is of length equal to the max number of groupmate agents
        in the buffer, and each AgentAction is of the same length as the buffer. Empty spots
        (e.g. when agents die) are padded with 0.
        :param buff: AgentBuffer of a batch or trajectory
        :return: List of groupmates' next AgentActions
        """
        return AgentAction._group_agent_action_from_buffer(
            buff, BufferKey.GROUP_NEXT_CONT_ACTION, BufferKey.GROUP_NEXT_DISC_ACTION
        )

    def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
        """
        Flatten this AgentAction into a single torch Tensor of dimension
        (batch, num_continuous + num_one_hot_discrete). Discrete actions are
        converted to one-hot and concatenated with the continuous actions.
        :param discrete_branches: List of sizes for the discrete action branches.
        :return: Tensor of flattened actions.
        """
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            discrete_oh = ModelUtils.actions_to_onehot(
                self.discrete_tensor, discrete_branches
            )
            discrete_oh = torch.cat(discrete_oh, dim=1)
        else:
            discrete_oh = torch.empty(0)
        return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
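

# Minimal smoke-test sketch (not part of the original module): exercises the
# tensor <-> numpy utilities on synthetic data. Shapes and values here are
# assumptions for illustration, and running it requires a working mlagents
# install.
if __name__ == "__main__":
    _batch = 4
    # Policy-style action: each discrete branch sampled with shape (batch, 1),
    # as expected by to_action_tuple.
    _policy_act = AgentAction(
        torch.rand(_batch, 2),
        [
            torch.zeros(_batch, 1, dtype=torch.long),
            torch.ones(_batch, 1, dtype=torch.long),
        ],
    )
    _tuple = _policy_act.to_action_tuple(clip=True)
    print(_tuple.continuous.shape, _tuple.discrete.shape)  # (4, 2) (4, 2)

    # Buffer-style action: each discrete branch flat with shape (batch,),
    # as produced by from_buffer and expected by to_flat.
    _buffer_act = AgentAction(
        torch.rand(_batch, 2),
        [
            torch.zeros(_batch, dtype=torch.long),
            torch.ones(_batch, dtype=torch.long),
        ],
    )
    # Branch sizes 2 and 3 -> 2 continuous + 5 one-hot columns.
    print(_buffer_act.to_flat([2, 3]).shape)  # torch.Size([4, 7])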