|
from typing import Union, Optional, Dict |
|
from easydict import EasyDict |
|
|
|
import torch |
|
import torch.nn as nn |
|
from ding.model.common import ReparameterizationHead, EnsembleHead |
|
from ding.utils import SequenceType, squeeze |
|
|
|
from ding.utils import MODEL_REGISTRY |
|
|
|
|
|
@MODEL_REGISTRY.register('edac') |
|
class EDAC(nn.Module): |
|
""" |
|
Overview: |
|
The Q-value Actor-Critic network with the ensemble mechanism, which is used in EDAC. |
|
Interfaces: |
|
``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` |
|
""" |
|
mode = ['compute_actor', 'compute_critic'] |
|
|
|
def __init__( |
|
self, |
|
obs_shape: Union[int, SequenceType], |
|
action_shape: Union[int, SequenceType, EasyDict], |
|
ensemble_num: int = 2, |
|
actor_head_hidden_size: int = 64, |
|
actor_head_layer_num: int = 1, |
|
critic_head_hidden_size: int = 64, |
|
critic_head_layer_num: int = 1, |
|
activation: Optional[nn.Module] = nn.ReLU(), |
|
norm_type: Optional[str] = None, |
|
**kwargs |
|
) -> None: |
|
""" |
|
Overview: |
|
            Initialize the EDAC model according to the input arguments.
|
Arguments: |
|
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ). |
|
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \ |
|
EasyDict({'action_type_shape': 3, 'action_args_shape': 4}). |
|
            - ensemble_num (:obj:`int`): The number of Q networks in the ensemble.
|
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head. |
|
            - actor_head_layer_num (:obj:`int`): The number of layers used in the actor head network to compute \
                the action distribution parameters.
|
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head. |
|
            - critic_head_layer_num (:obj:`int`): The number of layers used in the critic head network to compute \
                the Q value output.
|
- activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \ |
|
after each FC layer, if ``None`` then default set to ``nn.ReLU()``. |
|
            - norm_type (:obj:`Optional[str]`): The type of normalization used after each network layer (FC, Conv), \
                see ``ding.torch_utils.network`` for more details.
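        Examples:
            >>> # A minimal illustrative instantiation; the shapes below are chosen arbitrarily.
            >>> model = EDAC(obs_shape=8, action_shape=2, ensemble_num=3)
            >>> obs = torch.randn(4, 8)
            >>> mu, sigma = model(obs, mode='compute_actor')['logit']
            >>> assert mu.shape == torch.Size([4, 2]) and sigma.shape == torch.Size([4, 2])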
|
""" |
|
super(EDAC, self).__init__() |
|
obs_shape: int = squeeze(obs_shape) |
|
action_shape = squeeze(action_shape) |
|
self.action_shape = action_shape |
|
self.ensemble_num = ensemble_num |
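        # Actor: an MLP encoder followed by a ReparameterizationHead that outputs the mean (mu) and
        # standard deviation (sigma) of a Gaussian action distribution.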
|
self.actor = nn.Sequential( |
|
nn.Linear(obs_shape, actor_head_hidden_size), activation, |
|
ReparameterizationHead( |
|
actor_head_hidden_size, |
|
action_shape, |
|
actor_head_layer_num, |
|
sigma_type='conditioned', |
|
activation=activation, |
|
norm_type=norm_type |
|
) |
|
) |
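        # Critic: an EnsembleHead that evaluates ensemble_num Q-functions in parallel on the
        # concatenated (obs, action) input.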
|
|
|
critic_input_size = obs_shape + action_shape |
|
self.critic = EnsembleHead( |
|
critic_input_size, |
|
1, |
|
critic_head_hidden_size, |
|
critic_head_layer_num, |
|
self.ensemble_num, |
|
activation=activation, |
|
norm_type=norm_type |
|
) |
|
|
|
def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]: |
|
""" |
|
Overview: |
|
            The unique execution (forward) method of EDAC. One can indicate different modes to implement \
            different computation graphs, including ``compute_actor`` and ``compute_critic``.
|
Mode compute_actor: |
|
Arguments: |
|
                - inputs (:obj:`torch.Tensor`): Observation tensor data.
|
Returns: |
|
                - output (:obj:`Dict`): Output dict data, whose key-values differ across distinct action spaces.
|
Mode compute_critic: |
|
Arguments: |
|
- inputs (:obj:`Dict`): Input dict data, including obs and action tensor. |
|
Returns: |
|
- output (:obj:`Dict`): Output dict data, including q_value tensor. |
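        Examples:
            >>> # A brief sketch of both modes (shapes chosen arbitrarily); see the API docs of
            >>> # ``compute_actor`` and ``compute_critic`` for detailed examples.
            >>> model = EDAC(obs_shape=8, action_shape=2)
            >>> obs = torch.randn(4, 8)
            >>> actor_outputs = model(obs, mode='compute_actor')
            >>> critic_outputs = model({'obs': obs, 'action': torch.randn(4, 2)}, mode='compute_critic')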
|
|
|
.. note:: |
|
For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively. |
|
""" |
|
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) |
|
return getattr(self, mode)(inputs) |
|
|
|
def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: |
|
""" |
|
Overview: |
|
            The forward computation graph of the compute_actor mode, which uses the observation tensor to produce \
            actor outputs such as ``action``, ``logit`` and so on.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \ |
|
i.e. ``(B, obs_shape)``. |
|
Returns: |
|
- outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output varying \ |
|
from action_space: ``reparameterization``. |
|
ReturnsKeys (either): |
|
            - logit (:obj:`List[torch.Tensor]`): The reparameterization logit as a ``[mu, sigma]`` pair, \
                usually used in SAC.
            - mu (:obj:`torch.Tensor`): Mean of the reparameterized Gaussian distribution.
            - sigma (:obj:`torch.Tensor`): Standard deviation of the reparameterized Gaussian distribution.
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``. |
|
- action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``. |
|
- logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``. |
|
            - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to \
                ``action_shape``.
|
- logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \ |
|
``action_shape.action_type_shape``. |
|
- action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \ |
|
``action_shape.action_args_shape``. |
|
Examples: |
|
            >>> model = EDAC(obs_shape=64, action_shape=64)
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs, mode='compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
|
""" |
|
        # The ReparameterizationHead returns a dict with 'mu' and 'sigma'; expose them as the logit pair.
        x = self.actor(obs)
        return {'logit': [x['mu'], x['sigma']]}
|
|
|
def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
|
""" |
|
Overview: |
|
            The forward computation graph of the compute_critic mode, which uses the observation and action tensors \
            to produce critic outputs such as ``q_value``.
|
Arguments: |
|
            - inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and \
                ``action`` tensors.
|
Returns: |
|
- outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``. |
|
ArgumentsKeys: |
|
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
|
- action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``. |
|
ReturnKeys: |
|
            - q_value (:obj:`torch.Tensor`): Q value tensor of shape ``(ensemble_num, batch_size)``.
|
Shapes: |
|
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(Ensemble_num, B, N1)`, where B is batch size \
                and N1 corresponds to ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)` or :math:`(Ensemble_num, B, N2)`, where B is batch \
                size and N2 corresponds to ``action_shape``.
|
- q_value (:obj:`torch.Tensor`): :math:`(Ensemble_num, B)`, where B is batch size. |
|
Examples: |
|
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = EDAC(obs_shape=(8, ), action_shape=1, ensemble_num=2)
            >>> q_value = model(inputs, mode='compute_critic')['q_value']
            >>> assert q_value.shape == torch.Size([2, 4])  # (ensemble_num, B)
|
""" |
|
|
|
obs, action = inputs['obs'], inputs['action'] |
|
if len(action.shape) == 1: |
|
action = action.unsqueeze(1) |
|
x = torch.cat([obs, action], dim=-1) |
|
        if len(obs.shape) < 3:
            # (B, obs_shape + action_shape) -> (B, ensemble_num * (obs_shape + action_shape), 1)
            x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
|
        else:
            # (ensemble_num, B, dim) -> (B, ensemble_num, dim) -> (B, ensemble_num * dim, 1),
            # where dim = obs_shape + action_shape.
            x = x.transpose(0, 1)
            batch_size = obs.shape[1]
            x = x.reshape(batch_size, -1, 1)
|
|
|
        # The EnsembleHead prediction has shape (B, ensemble_num); transpose it to (ensemble_num, B).
        x = self.critic(x)['pred']
        x = x.permute(1, 0)
|
return {'q_value': x} |
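

# A minimal usage sketch: instantiate EDAC and run both forward modes as a quick smoke test.
# The shapes and ensemble size below are chosen arbitrarily for illustration.
if __name__ == '__main__':
    model = EDAC(obs_shape=8, action_shape=2, ensemble_num=5)
    obs = torch.randn(4, 8)
    # Actor mode: (mu, sigma) parameters of the Gaussian policy distribution.
    mu, sigma = model(obs, mode='compute_actor')['logit']
    assert mu.shape == torch.Size([4, 2]) and sigma.shape == torch.Size([4, 2])
    # Critic mode: one Q value per ensemble member per sample, i.e. shape (ensemble_num, B).
    q_value = model({'obs': obs, 'action': torch.randn(4, 2)}, mode='compute_critic')['q_value']
    assert q_value.shape == torch.Size([5, 4])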
|
|