# zjowowen's picture
# init space
# 079c32c
from typing import Union, Optional, Dict
import torch
import torch.nn as nn
from easydict import EasyDict
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \
MultiHead, RegressionHead, ReparameterizationHead
@MODEL_REGISTRY.register('discrete_bc')
class DiscreteBC(nn.Module):
    """
    Overview:
        The DiscreteBC network: an observation encoder followed by a discrete action head.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: Optional[SequenceType] = None,
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            strides: Optional[list] = None,
    ) -> None:
        """
        Overview:
            Init the DiscreteBC (encoder + head) Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
            - encoder_hidden_size_list (:obj:`Optional[SequenceType]`): Collection of ``hidden_size`` to pass to \
                ``Encoder``, the last element must match ``head_hidden_size``. ``None`` (the default) selects \
                ``[128, 128, 64]``.
            - dueling (:obj:`bool`): ``True`` (default) selects ``DuelingHead``, ``False`` selects ``DiscreteHead``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. ``None`` means the \
                last element of ``encoder_hidden_size_list`` is used.
            - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output
            - activation (:obj:`Optional[nn.Module]`): The activation module forwarded unchanged to the encoder \
                and the head, ``nn.ReLU()`` by default. NOTE(review): how ``None`` is treated is decided by \
                those components, not by this class - confirm before relying on it.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details.
            - strides (:obj:`Optional[list]`): The strides for each convolution layers, such as [2, 2, 2]. The \
                length of this argument should be the same as ``encoder_hidden_size_list``. Only used by the \
                convolutional encoder; an empty or ``None`` value keeps the encoder's own default.
        Raises:
            - RuntimeError: If ``obs_shape`` (after ``squeeze``) is neither an int, a length-1 nor a length-3 \
                sequence.
        """
        super(DiscreteBC, self).__init__()
        # A fresh list per instance: a list literal in the signature would be one object
        # shared by every model and by whatever the encoder does with it.
        if encoder_hidden_size_list is None:
            encoder_hidden_size_list = [128, 128, 64]
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            # FC Encoder
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        elif len(obs_shape) == 3:
            # Conv Encoder; ``stride`` is only forwarded when the caller supplied one.
            conv_kwargs = dict(activation=activation, norm_type=norm_type)
            if strides:
                conv_kwargs['stride'] = strides
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, **conv_kwargs)
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape)
            )

        # Head Type
        head_cls = DuelingHead if dueling else DiscreteHead
        if isinstance(action_shape, int):
            self.head = head_cls(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            # A non-int action shape gets the same head class wrapped in ``MultiHead``.
            self.head = MultiHead(
                head_cls,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            DiscreteBC forward computation graph, input observation tensor to predict q_value.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs
        Returns:
            - outputs (:obj:`Dict`): The output of the head applied to the encoded observation, returned as is.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
        Examples:
            >>> model = DiscreteBC(32, 6)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])
        """
        x = self.encoder(x)
        x = self.head(x)
        return x
@MODEL_REGISTRY.register('continuous_bc')
class ContinuousBC(nn.Module):
    """
    Overview:
        The ContinuousBC network: a linear layer, an activation and a continuous action head.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the ContinuousBC Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
                EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
            - action_space (:obj:`str`): The type of action space, \
                including [``regression``, ``reparameterization``].
            - actor_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to actor head.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for actor head.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
                after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \
                see ``ding.torch_utils.network`` for more details.
        Raises:
            - AssertionError: If ``action_space`` is not ``regression`` or ``reparameterization``.
        """
        super(ContinuousBC, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization'], \
            "unsupported action_space: {}".format(self.action_space)

        # Honour the documented ``None`` fallback. ``nn.Sequential`` registers a ``None``
        # entry without complaint and only fails later, when forward tries to call it.
        if activation is None:
            activation = nn.ReLU()

        # The linear layer is created before the head, which keeps the order of random
        # parameter initialisation the same for both action spaces.
        if self.action_space == 'regression':
            input_layer = nn.Linear(obs_shape, actor_head_hidden_size)
            head = RegressionHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                final_tanh=True,
                activation=activation,
                norm_type=norm_type
            )
            self.actor = nn.Sequential(input_layer, activation, head)
        elif self.action_space == 'reparameterization':
            input_layer = nn.Linear(obs_shape, actor_head_hidden_size)
            head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type='conditioned',
                activation=activation,
                norm_type=norm_type
            )
            self.actor = nn.Sequential(input_layer, activation, head)

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict:
        """
        Overview:
            The unique execution (forward) method of ContinuousBC.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
        Returns:
            - output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space.
        ReturnsKeys:
            - action (:obj:`torch.Tensor`): ``regression`` only, the ``pred`` entry of the actor output, \
                with shape :math:`(B, action_shape)`.
            - logit (:obj:`List[torch.Tensor]`): ``reparameterization`` only, the ``[mu, sigma]`` entries of \
                the actor output, each with shape :math:`(B, action_shape)`.
        Raises:
            - ValueError: If ``self.action_space`` is not one of the two supported values.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - action (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
            - logit (:obj:`List[torch.FloatTensor]`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
        Examples (Regression):
            >>> model = ContinuousBC(32, 6, action_space='regression')
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6])
        Examples (Reparameterization):
            >>> model = ContinuousBC(32, 6, action_space='reparameterization')
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])
            >>> assert outputs['logit'][1].shape == torch.Size([4, 6])
        """
        if self.action_space == 'regression':
            x = self.actor(inputs)
            return {'action': x['pred']}
        elif self.action_space == 'reparameterization':
            x = self.actor(inputs)
            return {'logit': [x['mu'], x['sigma']]}
        # Reachable only if the constructor check was stripped (``python -O``) or the
        # attribute was changed afterwards; fail loudly instead of returning ``None``.
        raise ValueError("unsupported action_space: {}".format(self.action_space))