|
from typing import Union, Optional, Dict |
|
import torch |
|
import torch.nn as nn |
|
from easydict import EasyDict |
|
|
|
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze |
|
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \ |
|
MultiHead, RegressionHead, ReparameterizationHead |
|
|
|
|
|
@MODEL_REGISTRY.register('discrete_bc')
class DiscreteBC(nn.Module):
    """
    Overview:
        The DiscreteBC network: an observation encoder followed by a discrete action head.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            strides: Optional[list] = None,
    ) -> None:
        """
        Overview:
            Init the DiscreteBC (encoder + head) Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. \
                A non-int shape (after ``squeeze``) selects a ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element must match ``head_hidden_size``. The default list is only read here, never mutated.
            - dueling (:obj:`bool`): Whether to use ``DuelingHead`` (``True``, the default) or ``DiscreteHead``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. If ``None``, the last \
                element of ``encoder_hidden_size_list`` is used.
            - head_layer_num (:obj:`int`): The number of layers used in the head network to compute the output.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks, \
                if ``None`` then default set it to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details.
            - strides (:obj:`Optional[list]`): The strides for each convolution layers, such as [2, 2, 2]. The length \
                of this argument should be the same as ``encoder_hidden_size_list``. If ``None`` or empty, \
                ``ConvEncoder`` is built without an explicit ``stride`` argument (its own default applies).
        Raises:
            - RuntimeError: If ``obs_shape`` is neither an int, a length-1 nor a length-3 sequence.
        """
        super().__init__()

        # Honor the documented contract: ``None`` means "use the default ReLU activation".
        if activation is None:
            activation = nn.ReLU()

        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # Encoder selection is driven purely by the observation shape.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        elif len(obs_shape) == 3:
            conv_kwargs = {'activation': activation, 'norm_type': norm_type}
            # Only forward ``stride`` when the caller supplied one, so ConvEncoder's default stays in effect.
            if strides:
                conv_kwargs['stride'] = strides
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, **conv_kwargs)
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape)
            )

        head_cls = DuelingHead if dueling else DiscreteHead
        if isinstance(action_shape, int):
            self.head = head_cls(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            # Multi-dimensional discrete action space: one sub-head per action dimension.
            self.head = MultiHead(
                head_cls,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            DiscreteBC forward computation graph, input observation tensor to predict the action logit.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs
        Returns:
            - outputs (:obj:`Dict`): DiscreteBC forward outputs, i.e. whatever the head returns for the \
                encoded observation.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete output logit of each action dimension.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
        Examples:
            >>> model = DiscreteBC(32, 6)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])
        """
        embedding = self.encoder(x)
        return self.head(embedding)
|
|
|
|
|
@MODEL_REGISTRY.register('continuous_bc')
class ContinuousBC(nn.Module):
    """
    Overview:
        The ContinuousBC network: a linear layer, an activation and a continuous action head.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the ContinuousBC Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
                EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
            - action_space (:obj:`str`): The type of action space, \
                including [``regression``, ``reparameterization``].
            - actor_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to actor head.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the actor head to compute \
                the action output.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
                after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \
                see ``ding.torch_utils.network`` for more details.
        Raises:
            - AssertionError: If ``action_space`` is not ``regression`` or ``reparameterization``.
        """
        super().__init__()

        # Honor the documented contract. Without this, ``None`` would be registered inside
        # ``nn.Sequential`` and fail with a TypeError on the first forward pass.
        if activation is None:
            activation = nn.ReLU()

        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.action_space = action_space

        # Head class plus the head-specific keyword arguments for every supported action space.
        head_specs = {
            'regression': (RegressionHead, {'final_tanh': True}),
            'reparameterization': (ReparameterizationHead, {'sigma_type': 'conditioned'}),
        }
        assert self.action_space in head_specs, \
            "unsupported action_space: {}, expected one of {}".format(self.action_space, list(head_specs))
        head_cls, head_kwargs = head_specs[self.action_space]

        # The Linear layer is constructed before the head so that parameter initialization order
        # (and therefore seeded initial weights) stays the same as constructing them inline.
        self.actor = nn.Sequential(
            nn.Linear(obs_shape, actor_head_hidden_size),
            activation,
            head_cls(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
                **head_kwargs
            ),
        )

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict:
        """
        Overview:
            The unique execution (forward) method of ContinuousBC.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Observation data, fed directly to the actor network.
        Returns:
            - output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space.
        ReturnsKeys:
            - action (:obj:`torch.Tensor`): action output of actor network (``regression`` only), \
                with shape :math:`(B, action_shape)`.
            - logit (:obj:`List[torch.Tensor]`): ``[mu, sigma]`` output of actor network \
                (``reparameterization`` only), each with shape :math:`(B, action_shape)`.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - action (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
            - logit (:obj:`List[torch.FloatTensor]`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
        Examples (Regression):
            >>> model = ContinuousBC(32, 6, action_space='regression')
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6])
        Examples (Reparameterization):
            >>> model = ContinuousBC(32, 6, action_space='reparameterization')
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])
            >>> assert outputs['logit'][1].shape == torch.Size([4, 6])
        """
        if self.action_space == 'regression':
            head_output = self.actor(inputs)
            return {'action': head_output['pred']}
        elif self.action_space == 'reparameterization':
            head_output = self.actor(inputs)
            return {'logit': [head_output['mu'], head_output['sigma']]}
|
|