from typing import List, Optional, NamedTuple
import itertools
import numpy as np
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents_envs.base_env import ActionTuple


class AgentAction(NamedTuple):
    """
    A NamedTuple containing the tensor for continuous actions and a list of tensors for
    discrete actions. Utility functions provide numpy <=> tensor conversions so that the
    actions can be sent to the environment manager as well as used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to continuous actions
    :param discrete_list: List of Torch tensors each corresponding to discrete actions
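
    Example (illustrative sketch; shapes assume a batch of 2, three continuous
    actions, and two discrete branches):
        action = AgentAction(
            continuous_tensor=torch.zeros(2, 3),
            discrete_list=[torch.zeros(2, dtype=torch.long),
                           torch.ones(2, dtype=torch.long)],
        )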
    """

    continuous_tensor: torch.Tensor
    discrete_list: Optional[List[torch.Tensor]]

    @property
    def discrete_tensor(self) -> torch.Tensor:
        """
        Returns the discrete action list as a single stacked tensor.
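
        Example (illustrative; assumes two discrete branches over a batch of 4):
            >>> a = AgentAction(torch.zeros(4, 2),
            ...                 [torch.zeros(4, dtype=torch.long),
            ...                  torch.ones(4, dtype=torch.long)])
            >>> a.discrete_tensor.shape
            torch.Size([4, 2])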
        """
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            return torch.stack(self.discrete_list, dim=-1)
        else:
            return torch.empty(0)

    def slice(self, start: int, end: int) -> "AgentAction":
        """
        Returns an AgentAction with the continuous and discrete tensors sliced
        from index start to index end.
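
        Example (illustrative): keeping rows 2 through 4 of a batch of 10:
            >>> a = AgentAction(torch.zeros(10, 3), [torch.zeros(10, dtype=torch.long)])
            >>> a.slice(2, 5).continuous_tensor.shape
            torch.Size([3, 3])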
        """
        _cont = None
        _disc_list = []
        if self.continuous_tensor is not None:
            _cont = self.continuous_tensor[start:end]
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            for _disc in self.discrete_list:
                _disc_list.append(_disc[start:end])
        return AgentAction(_cont, _disc_list)

    def to_action_tuple(self, clip: bool = False) -> ActionTuple:
        """
        Returns an ActionTuple with numpy copies of this action's tensors, ready to be
        sent to the environment. If clip is True, continuous actions are clamped to
        [-3, 3] and rescaled to [-1, 1].
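
        Example (illustrative; assumes each discrete branch tensor is shaped
        (batch, 1), as produced by the action model at inference time):
            >>> a = AgentAction(torch.ones(2, 3) * 6.0,
            ...                 [torch.zeros(2, 1, dtype=torch.long)])
            >>> t = a.to_action_tuple(clip=True)
            >>> t.continuous.shape, t.discrete.shape
            ((2, 3), (2, 1))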
        """
        action_tuple = ActionTuple()
        if self.continuous_tensor is not None:
            _continuous_tensor = self.continuous_tensor
            if clip:
                _continuous_tensor = torch.clamp(_continuous_tensor, -3, 3) / 3
            continuous = ModelUtils.to_numpy(_continuous_tensor)
            action_tuple.add_continuous(continuous)
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            # Each branch tensor is expected to be shaped (batch, 1), so the stacked
            # tensor is (batch, 1, num_branches); index out the singleton dimension.
            discrete = ModelUtils.to_numpy(self.discrete_tensor[:, 0, :])
            action_tuple.add_discrete(discrete)
        return action_tuple

    @staticmethod
    def from_buffer(buff: AgentBuffer) -> "AgentAction":
        """
        A static method that accesses continuous and discrete action fields in an AgentBuffer
        and constructs the corresponding AgentAction from the retrieved np arrays.
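
        Example (illustrative sketch; assumes the usual AgentBuffer interface in
        which indexing by a BufferKey yields an appendable field):
            buff = AgentBuffer()
            for _ in range(3):
                buff[BufferKey.CONTINUOUS_ACTION].append(np.zeros(2, dtype=np.float32))
            action = AgentAction.from_buffer(buff)  # continuous_tensor shape: (3, 2)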
        """
        continuous: Optional[torch.Tensor] = None
        discrete: Optional[List[torch.Tensor]] = None
        if BufferKey.CONTINUOUS_ACTION in buff:
            continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_ACTION])
        if BufferKey.DISCRETE_ACTION in buff:
            discrete_tensor = ModelUtils.list_to_tensor(
                buff[BufferKey.DISCRETE_ACTION], dtype=torch.long
            )
            discrete = [
                discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
            ]
        return AgentAction(continuous, discrete)

    @staticmethod
    def _group_agent_action_from_buffer(
        buff: AgentBuffer, cont_action_key: BufferKey, disc_action_key: BufferKey
    ) -> List["AgentAction"]:
        """
        Extracts continuous and discrete groupmate actions, as specified by BufferKey, and
        returns a List of AgentActions that correspond to the groupmates' actions. The List
        is of length equal to the maximum number of groupmates in the buffer; wherever there
        are fewer agents than that maximum, the actions are zero-padded.
        """
        continuous_tensors: List[torch.Tensor] = []
        discrete_tensors: List[torch.Tensor] = []
        if cont_action_key in buff:
            padded_batch = buff[cont_action_key].padded_to_batch()
            continuous_tensors = [
                ModelUtils.list_to_tensor(arr) for arr in padded_batch
            ]
        if disc_action_key in buff:
            padded_batch = buff[disc_action_key].padded_to_batch(dtype=np.int64)
            discrete_tensors = [
                ModelUtils.list_to_tensor(arr, dtype=torch.long) for arr in padded_batch
            ]

        actions_list = []
        # Zip the per-groupmate tensors together; zip_longest pads whichever
        # action type (continuous or discrete) is missing with None.
        for _cont, _disc in itertools.zip_longest(
            continuous_tensors, discrete_tensors, fillvalue=None
        ):
            if _disc is not None:
                _disc = [_disc[..., i] for i in range(_disc.shape[-1])]
            actions_list.append(AgentAction(_cont, _disc))
        return actions_list

    @staticmethod
    def group_from_buffer(buff: AgentBuffer) -> List["AgentAction"]:
        """
        A static method that accesses group continuous and discrete action fields in an AgentBuffer
        and constructs a padded List of AgentActions that represent the group agent actions.
        The List is of length equal to the max number of groupmate agents in the buffer, and each
        AgentAction is of the same length as the buffer. Empty spots (e.g. when agents die) are padded with 0.
        :param buff: AgentBuffer of a batch or trajectory
        :return: List of groupmates' AgentActions
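
        Example (illustrative sketch of the expected shapes):
            group_actions = AgentAction.group_from_buffer(buff)
            # len(group_actions) == max number of groupmates observed in buff;
            # each continuous tensor is (len(buff), continuous_size), zero-padded
            # at steps where that groupmate was absent.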
        """
        return AgentAction._group_agent_action_from_buffer(
            buff, BufferKey.GROUP_CONTINUOUS_ACTION, BufferKey.GROUP_DISCRETE_ACTION
        )

    @staticmethod
    def group_from_buffer_next(buff: AgentBuffer) -> List["AgentAction"]:
        """
        A static method that accesses next group continuous and discrete action fields in an AgentBuffer
        and constructs a padded List of AgentActions that represent the next group agent actions.
        The List is of length equal to the max number of groupmate agents in the buffer, and each
        AgentAction is of the same length as the buffer. Empty spots (e.g. when agents die) are padded with 0.
        :param buff: AgentBuffer of a batch or trajectory
        :return: List of groupmates' AgentActions
        """
        return AgentAction._group_agent_action_from_buffer(
            buff, BufferKey.GROUP_NEXT_CONT_ACTION, BufferKey.GROUP_NEXT_DISC_ACTION
        )

    def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
        """
        Flatten this AgentAction into a single torch Tensor of dimension (batch, num_continuous + num_one_hot_discrete).
        Discrete actions are converted into one-hot and concatenated with continuous actions.
        :param discrete_branches: List of sizes for discrete actions.
        :return: Tensor of flattened actions.
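
        Example (illustrative; 2 continuous actions plus discrete branches of
        sizes [3, 2] flatten to width 2 + 3 + 2 = 7):
            >>> a = AgentAction(torch.zeros(4, 2),
            ...                 [torch.zeros(4, dtype=torch.long),
            ...                  torch.ones(4, dtype=torch.long)])
            >>> a.to_flat([3, 2]).shape
            torch.Size([4, 7])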
        """
        # if there are any discrete actions, create one-hot
        if self.discrete_list is not None and len(self.discrete_list) > 0:
            discrete_oh = ModelUtils.actions_to_onehot(
                self.discrete_tensor, discrete_branches
            )
            discrete_oh = torch.cat(discrete_oh, dim=1)
        else:
            discrete_oh = torch.empty(0)
        return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)