from typing import List, Tuple, NamedTuple, Optional
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch_entities.distributions import (
    DistInstance,
    DiscreteDistInstance,
    GaussianDistribution,
    MultiCategoricalDistribution,
)
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents_envs.base_env import ActionSpec


EPSILON = 1e-7  # Small value to avoid divide by zero


class DistInstances(NamedTuple):
    """
    A NamedTuple with fields corresponding to the DistInstance objects
    output by continuous and discrete distributions, respectively. Discrete distributions
    output a list of DistInstance objects whereas continuous distributions output a single
    DistInstance object.
    """

    continuous: Optional[DistInstance]
    discrete: Optional[List[DiscreteDistInstance]]


class ActionModel(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
        deterministic: bool = False,
    ):
        """
        A torch module that represents the action space of a policy. The ActionModel may contain
        a continuous distribution, a discrete distribution, or both, depending on the action_spec.
        The ActionModel uses the encoded input of the network body to parameterize these
        distributions. The forward method of this module outputs the action, log probs,
        and entropies given the encoding from the network body.
        :params hidden_size: Size of the input to the ActionModel.
        :params action_spec: The ActionSpec defining the action space dimensions and distributions.
        :params conditional_sigma: Whether or not the std of a Gaussian is conditioned on state.
        :params tanh_squash: Whether to squash the output of a Gaussian with the tanh function.
        :params deterministic: Whether to select actions deterministically in policy.
        """
        super().__init__()
        self.encoding_size = hidden_size
        self.action_spec = action_spec
        self._continuous_distribution = None
        self._discrete_distribution = None

        if self.action_spec.continuous_size > 0:
            self._continuous_distribution = GaussianDistribution(
                self.encoding_size,
                self.action_spec.continuous_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )

        if self.action_spec.discrete_size > 0:
            self._discrete_distribution = MultiCategoricalDistribution(
                self.encoding_size, self.action_spec.discrete_branches
            )

        # During training, clipping is done in TorchPolicy, but we need to clip before ONNX
        # export as well.
        self.clip_action = not tanh_squash
        self._deterministic = deterministic

    def _sample_action(self, dists: DistInstances) -> AgentAction:
        """
        Samples actions from a DistInstances tuple
        :params dists: The DistInstances tuple
        :return: An AgentAction corresponding to the actions sampled from the DistInstances
        """

        continuous_action: Optional[torch.Tensor] = None
        discrete_action: Optional[List[torch.Tensor]] = None
        # This checks None because mypy complains otherwise
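        # When deterministic, take the distribution's mode via deterministic_sample()
        # (typically the Gaussian mean or the categorical argmax) instead of a random draw.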
        if dists.continuous is not None:
            if self._deterministic:
                continuous_action = dists.continuous.deterministic_sample()
            else:
                continuous_action = dists.continuous.sample()
        if dists.discrete is not None:
            discrete_action = []
            if self._deterministic:
                for discrete_dist in dists.discrete:
                    discrete_action.append(discrete_dist.deterministic_sample())
            else:
                for discrete_dist in dists.discrete:
                    discrete_action.append(discrete_dist.sample())
        return AgentAction(continuous_action, discrete_action)

    def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> DistInstances:
        """
        Creates a DistInstances tuple using the continuous and discrete distributions
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :return: A DistInstances tuple
        """
        continuous_dist: Optional[DistInstance] = None
        discrete_dist: Optional[List[DiscreteDistInstance]] = None
        # This checks None because mypy complains otherwise
        if self._continuous_distribution is not None:
            continuous_dist = self._continuous_distribution(inputs)
        if self._discrete_distribution is not None:
            discrete_dist = self._discrete_distribution(inputs, masks)
        return DistInstances(continuous_dist, discrete_dist)

    def _get_probs_and_entropy(
        self, actions: AgentAction, dists: DistInstances
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Computes the log probabilities of the actions under the given distributions, along
        with the entropies of those distributions.
        :params actions: The AgentAction
        :params dists: The DistInstances tuple
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
        """
        entropies_list: List[torch.Tensor] = []
        continuous_log_prob: Optional[torch.Tensor] = None
        discrete_log_probs: Optional[List[torch.Tensor]] = None
        all_discrete_log_probs: Optional[List[torch.Tensor]] = None
        # This checks None because mypy complains otherwise
        if dists.continuous is not None:
            continuous_log_prob = dists.continuous.log_prob(actions.continuous_tensor)
            entropies_list.append(dists.continuous.entropy())
        if dists.discrete is not None:
            discrete_log_probs = []
            all_discrete_log_probs = []
            for discrete_action, discrete_dist in zip(
                actions.discrete_list, dists.discrete  # type: ignore
            ):
                discrete_log_prob = discrete_dist.log_prob(discrete_action)
                entropies_list.append(discrete_dist.entropy())
                discrete_log_probs.append(discrete_log_prob)
                all_discrete_log_probs.append(discrete_dist.all_log_prob())
        action_log_probs = ActionLogProbs(
            continuous_log_prob, discrete_log_probs, all_discrete_log_probs
        )
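        # Each entry in entropies_list is 2-D (batch, per-distribution terms); concatenate
        # along dim 1 so callers (evaluate/forward) can reduce across action dimensions.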
        entropies = torch.cat(entropies_list, dim=1)
        return action_log_probs, entropies

    def evaluate(
        self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Given actions and encoding from the network body, gets the distributions and
        computes the log probabilities and entropies.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :params actions: The AgentAction
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
        """
        dists = self._get_dists(inputs, masks)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return log_probs, entropy_sum

    def get_action_out(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], ...]:
        """
        Gets the tensors corresponding to the output of the policy network to be used for
        inference. Called by the Actor's forward call.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :return: A tuple of torch tensors corresponding to the inference output
        """
        dists = self._get_dists(inputs, masks)
        continuous_out, discrete_out, action_out_deprecated = None, None, None
        deterministic_continuous_out, deterministic_discrete_out = (
            None,
            None,
        )  # deterministic actions
        if self.action_spec.continuous_size > 0 and dists.continuous is not None:
            continuous_out = dists.continuous.exported_model_output()
            action_out_deprecated = continuous_out
            deterministic_continuous_out = dists.continuous.deterministic_sample()
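            # Clamp to [-3, 3] and rescale by 1/3 so exported continuous actions lie in
            # [-1, 1]; when tanh_squash is enabled the distribution already bounds them.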
            if self.clip_action:
                continuous_out = torch.clamp(continuous_out, -3, 3) / 3
                action_out_deprecated = continuous_out
                deterministic_continuous_out = (
                    torch.clamp(deterministic_continuous_out, -3, 3) / 3
                )
        if self.action_spec.discrete_size > 0 and dists.discrete is not None:
            discrete_out_list = [
                discrete_dist.exported_model_output()
                for discrete_dist in dists.discrete
            ]
            discrete_out = torch.cat(discrete_out_list, dim=1)
            action_out_deprecated = torch.cat(discrete_out_list, dim=1)
            deterministic_discrete_out_list = [
                discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
            ]
            deterministic_discrete_out = torch.cat(
                deterministic_discrete_out_list, dim=1
            )

        # The deprecated action output does not support hybrid action spaces.
        if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
            action_out_deprecated = None
        return (
            continuous_out,
            discrete_out,
            action_out_deprecated,
            deterministic_continuous_out,
            deterministic_discrete_out,
        )

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
        """
        The forward method of this module. Outputs the action, log probs,
        and entropies given the encoding from the network body.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :return: Given the input, an AgentAction of the actions generated by the policy and the corresponding
        ActionLogProbs and entropies.
        """
        dists = self._get_dists(inputs, masks)
        actions = self._sample_action(dists)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return (actions, log_probs, entropy_sum)
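

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the trainer code): a hybrid
    # action space with 2 continuous actions and two discrete branches of 3 and 2 options.
    # The encoding size, batch size, and all-ones (unmasked) discrete action masks are
    # assumptions made for this example.
    spec = ActionSpec(continuous_size=2, discrete_branches=(3, 2))
    action_model = ActionModel(hidden_size=64, action_spec=spec)
    encoding = torch.randn(4, 64)
    masks = torch.ones(4, sum(spec.discrete_branches))
    actions, log_probs, entropy = action_model(encoding, masks)
    print(actions.continuous_tensor.shape)  # (4, 2) continuous actions
    print(len(actions.discrete_list))       # 2 discrete branches
    print(entropy.shape)                    # (4,) summed entropy per batch element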