from typing import List, Tuple, NamedTuple, Optional

from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch_entities.distributions import (
    DistInstance,
    DiscreteDistInstance,
    GaussianDistribution,
    MultiCategoricalDistribution,
)
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents_envs.base_env import ActionSpec

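# Small constant for numerical stability (e.g. to guard against division by zero).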
EPSILON = 1e-7


class DistInstances(NamedTuple):
    """
    A NamedTuple with fields corresponding to the DistInstance objects
    output by continuous and discrete distributions, respectively. Discrete distributions
    output a list of DistInstance objects, whereas continuous distributions output a single
    DistInstance object.
    """

    continuous: Optional[DistInstance]
    discrete: Optional[List[DiscreteDistInstance]]


class ActionModel(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
        deterministic: bool = False,
    ):
        """
        A torch module that represents the action space of a policy. The ActionModel may contain
        a continuous distribution, a discrete distribution, or both, depending on the
        action_spec. The ActionModel uses the encoded input of the network body to parameterize
        these distributions. The forward method of this module outputs the action, log probs,
        and entropies given the encoding from the network body.
        :param hidden_size: Size of the input to the ActionModel.
        :param action_spec: The ActionSpec defining the action space dimensions and distributions.
        :param conditional_sigma: Whether or not the std of a Gaussian is conditioned on state.
        :param tanh_squash: Whether to squash the output of a Gaussian with the tanh function.
        :param deterministic: Whether to select actions deterministically in policy.
        """
        super().__init__()
        self.encoding_size = hidden_size
        self.action_spec = action_spec
        self._continuous_distribution = None
        self._discrete_distribution = None

        if self.action_spec.continuous_size > 0:
            self._continuous_distribution = GaussianDistribution(
                self.encoding_size,
                self.action_spec.continuous_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )

        if self.action_spec.discrete_size > 0:
            self._discrete_distribution = MultiCategoricalDistribution(
                self.encoding_size, self.action_spec.discrete_branches
            )

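        # When tanh squashing is not used, exported continuous actions are clipped to
        # [-3, 3] and rescaled into [-1, 1] in get_action_out.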
        self.clip_action = not tanh_squash
        self._deterministic = deterministic

    def _sample_action(self, dists: DistInstances) -> AgentAction:
        """
        Samples actions from a DistInstances tuple.
        :param dists: The DistInstances tuple
        :return: An AgentAction corresponding to the actions sampled from the DistInstances
        """
        continuous_action: Optional[torch.Tensor] = None
        discrete_action: Optional[List[torch.Tensor]] = None

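        # Branch on self._deterministic: take each distribution's deterministic sample
        # when acting deterministically, otherwise draw a stochastic sample.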
        if dists.continuous is not None:
            if self._deterministic:
                continuous_action = dists.continuous.deterministic_sample()
            else:
                continuous_action = dists.continuous.sample()
        if dists.discrete is not None:
            discrete_action = []
            if self._deterministic:
                for discrete_dist in dists.discrete:
                    discrete_action.append(discrete_dist.deterministic_sample())
            else:
                for discrete_dist in dists.discrete:
                    discrete_action.append(discrete_dist.sample())
        return AgentAction(continuous_action, discrete_action)

    def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> DistInstances:
        """
        Creates a DistInstances tuple using the continuous and discrete distributions.
        :param inputs: The encoding from the network body
        :param masks: Action masks for discrete actions
        :return: A DistInstances tuple
        """
        continuous_dist: Optional[DistInstance] = None
        discrete_dist: Optional[List[DiscreteDistInstance]] = None

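        # Invoke whichever distribution modules were constructed for this action space.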
        if self._continuous_distribution is not None:
            continuous_dist = self._continuous_distribution(inputs)
        if self._discrete_distribution is not None:
            discrete_dist = self._discrete_distribution(inputs, masks)
        return DistInstances(continuous_dist, discrete_dist)

    def _get_probs_and_entropy(
        self, actions: AgentAction, dists: DistInstances
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Computes the log probabilities of the actions given the distributions, and the
        entropies of those distributions.
        :param actions: The AgentAction
        :param dists: The DistInstances tuple
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
        """
        entropies_list: List[torch.Tensor] = []
        continuous_log_prob: Optional[torch.Tensor] = None
        discrete_log_probs: Optional[List[torch.Tensor]] = None
        all_discrete_log_probs: Optional[List[torch.Tensor]] = None

        if dists.continuous is not None:
            continuous_log_prob = dists.continuous.log_prob(actions.continuous_tensor)
            entropies_list.append(dists.continuous.entropy())
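        # For each discrete branch, record the log-prob of the chosen action, the branch
        # entropy, and the full per-action log-prob distribution for that branch.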
        if dists.discrete is not None:
            discrete_log_probs = []
            all_discrete_log_probs = []
            for discrete_action, discrete_dist in zip(
                actions.discrete_list, dists.discrete
            ):
                discrete_log_prob = discrete_dist.log_prob(discrete_action)
                entropies_list.append(discrete_dist.entropy())
                discrete_log_probs.append(discrete_log_prob)
                all_discrete_log_probs.append(discrete_dist.all_log_prob())
        action_log_probs = ActionLogProbs(
            continuous_log_prob, discrete_log_probs, all_discrete_log_probs
        )
        entropies = torch.cat(entropies_list, dim=1)
        return action_log_probs, entropies

    def evaluate(
        self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Given actions and the encoding from the network body, gets the distributions and
        computes the log probabilities and entropies.
        :param inputs: The encoding from the network body
        :param masks: Action masks for discrete actions
        :param actions: The AgentAction
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
        """
        dists = self._get_dists(inputs, masks)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)

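        # Collapse per-dimension and per-branch entropies into a single entropy per batch entry.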
        entropy_sum = torch.sum(entropies, dim=1)
        return log_probs, entropy_sum

    def get_action_out(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """
        Gets the tensors corresponding to the output of the policy network to be used for
        inference. Called by the Actor's forward call.
        :param inputs: The encoding from the network body
        :param masks: Action masks for discrete actions
        :return: A tuple of torch tensors corresponding to the inference output
        """
        dists = self._get_dists(inputs, masks)
        continuous_out, discrete_out, action_out_deprecated = None, None, None
        deterministic_continuous_out, deterministic_discrete_out = None, None
        if self.action_spec.continuous_size > 0 and dists.continuous is not None:
            continuous_out = dists.continuous.exported_model_output()
            action_out_deprecated = continuous_out
            deterministic_continuous_out = dists.continuous.deterministic_sample()
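            # When tanh squashing is disabled, clamp exported continuous actions to
            # [-3, 3] and rescale them into [-1, 1].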
            if self.clip_action:
                continuous_out = torch.clamp(continuous_out, -3, 3) / 3
                action_out_deprecated = continuous_out
                deterministic_continuous_out = (
                    torch.clamp(deterministic_continuous_out, -3, 3) / 3
                )
        if self.action_spec.discrete_size > 0 and dists.discrete is not None:
            discrete_out_list = [
                discrete_dist.exported_model_output()
                for discrete_dist in dists.discrete
            ]
            discrete_out = torch.cat(discrete_out_list, dim=1)
            action_out_deprecated = torch.cat(discrete_out_list, dim=1)
            deterministic_discrete_out_list = [
                discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
            ]
            deterministic_discrete_out = torch.cat(
                deterministic_discrete_out_list, dim=1
            )

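        # The deprecated combined action output is cleared when both continuous and
        # discrete actions are present, since it only applies to a single action type.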
        if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
            action_out_deprecated = None
        return (
            continuous_out,
            discrete_out,
            action_out_deprecated,
            deterministic_continuous_out,
            deterministic_discrete_out,
        )

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
        """
        The forward method of this module. Outputs the action, log probs,
        and entropies given the encoding from the network body.
        :param inputs: The encoding from the network body
        :param masks: Action masks for discrete actions
        :return: Given the input, an AgentAction of the actions generated by the policy and
        the corresponding ActionLogProbs and entropies.
        """
        dists = self._get_dists(inputs, masks)
        actions = self._sample_action(dists)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)

        entropy_sum = torch.sum(entropies, dim=1)
        return (actions, log_probs, entropy_sum)