File size: 12,221 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
from typing import Union, Dict, Optional
import torch
import torch.nn as nn
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead, DistributionHead
@MODEL_REGISTRY.register('qac_dist')
class QACDIST(nn.Module):
    """
    Overview:
        The QAC (Q-value Actor-Critic) model with a distributional Q-value critic.
        The actor maps an observation to an action (``regression``) or to the ``mu``/``sigma``
        of an action distribution (``reparameterization``); the critic maps the concatenated
        observation and action to a Q-value together with its distribution.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    # Names of the methods that ``forward`` is allowed to dispatch to.
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            action_space: str = "regression",
            critic_head_type: str = "categorical",
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            v_min: Optional[float] = -10,
            v_max: Optional[float] = 10,
            n_atom: Optional[int] = 51,
    ) -> None:
        """
        Overview:
            Init the QAC Distributional Model according to arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
            - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
            - action_space (:obj:`str`): Either ``regression`` or ``reparameterization``.
            - critic_head_type (:obj:`str`): Only ``categorical`` is supported.
            - actor_head_hidden_size (:obj:`int`): Width of the actor's first linear layer, also passed \
                as the input size of the actor head.
            - actor_head_layer_num (:obj:`int`): The layer number passed to the actor head.
            - critic_head_hidden_size (:obj:`int`): Width of the critic's first linear layer, also passed \
                as the input size of the critic head.
            - critic_head_layer_num (:obj:`int`): The layer number passed to the critic head.
            - activation (:obj:`Optional[nn.Module]`): The activation module used after the first linear \
                layer and passed to the heads; if ``None`` then ``nn.ReLU()`` is used.
            - norm_type (:obj:`Optional[str]`): The type of normalization passed to the heads, \
                see ``ding.torch_utils.fc_block`` for more details.
            - v_min (:obj:`float`): Value of the smallest atom.
            - v_max (:obj:`float`): Value of the largest atom.
            - n_atom (:obj:`int`): Number of atoms in the support.
        """
        super(QACDIST, self).__init__()
        # Honour the documented contract: ``None`` means "use ReLU". Without this fallback the
        # ``nn.Sequential`` trunks below would receive a ``None`` entry, which cannot be run.
        if activation is None:
            activation = nn.ReLU()
        obs_shape: int = squeeze(obs_shape)
        action_shape: int = squeeze(action_shape)

        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization'], self.action_space
        if self.action_space == 'regression':
            # Deterministic actor; ``compute_actor`` reads the head's ``pred`` output.
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'reparameterization':
            # Stochastic actor; ``compute_actor`` reads the head's ``mu`` and ``sigma`` outputs.
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )

        self.critic_head_type = critic_head_type
        assert self.critic_head_type in ['categorical'], self.critic_head_type
        if self.critic_head_type == 'categorical':
            # The critic consumes the concatenated (obs, action) vector; the head is built with
            # an output size of 1 and ``n_atom`` atoms spanning ``[v_min, v_max]``.
            self.critic = nn.Sequential(
                nn.Linear(obs_shape + action_shape, critic_head_hidden_size), activation,
                DistributionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    n_atom=n_atom,
                    v_min=v_min,
                    v_max=v_max,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Dispatch to ``compute_actor`` or ``compute_critic`` according to ``mode``.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict]`):
                With ``'compute_actor'``: the observation tensor fed to the actor network.
                With ``'compute_critic'``: a dict with the necessary keys ``obs`` and ``action``.
            - mode (:obj:`str`): Name of the forward mode, one of ``QACDIST.mode``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
                Forward with ``'compute_actor'``, keys (either):
                    - action (:obj:`torch.Tensor`): Action tensor (``regression``).
                    - logit (:obj:`list`): ``[mu, sigma]`` (``reparameterization``).
                Forward with ``'compute_critic'``, keys:
                    - q_value (:obj:`torch.Tensor`): Q value tensor.
                    - distribution (:obj:`torch.Tensor`): Q value distribution tensor.
        Actor Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 is ``obs_shape``
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, where N1 is ``action_shape`` \
                (the output size the actor head is built with)
        Critic Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 is ``obs_shape``
            - action (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(B, )`, where N1 is ``action_shape``
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)` as in the example below
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where N3 is ``n_atom``
        Actor Examples:
            >>> # Regression mode
            >>> model = QACDIST(64, 64, 'regression')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
            >>> # Reparameterization Mode
            >>> model = QACDIST(64, 64, 'reparameterization')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
        Critic Examples:
            >>> # Categorical mode
            >>> N = 8
            >>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
            >>> model = QACDIST(obs_shape=(N, ), action_shape=1, action_space='regression',
            ...                 critic_head_type='categorical', n_atom=51)
            >>> q_value = model(inputs, mode='compute_critic')  # q value
            >>> assert q_value['q_value'].shape == torch.Size([4, 1])
            >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: torch.Tensor) -> Dict:
        """
        Overview:
            Run the actor network on the observation tensor and return either the predicted
            action or the parameters of the action distribution.
        Arguments:
            - inputs (:obj:`torch.Tensor`): The observation tensor fed to the actor network.
        Returns:
            - outputs (:obj:`Dict`): Outputs of the actor's linear layer and head.
        ReturnsKeys (either):
            - action (:obj:`torch.Tensor`): The head's ``pred`` output, returned when \
                ``action_space`` is ``regression``.
            - logit (:obj:`list`): ``[mu, sigma]`` from the head, returned when \
                ``action_space`` is ``reparameterization``.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 is ``obs_shape``
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, where N1 is ``action_shape`` \
                (the output size the actor head is built with)
            - logit (:obj:`list`): 2 elements, mu and sigma, each of shape :math:`(B, N1)` \
                as in the example below
        Examples:
            >>> # Regression mode
            >>> model = QACDIST(64, 64, 'regression')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
            >>> # Reparameterization Mode
            >>> model = QACDIST(64, 64, 'reparameterization')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
        """
        x = self.actor(inputs)
        if self.action_space == 'regression':
            return {'action': x['pred']}
        elif self.action_space == 'reparameterization':
            return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Concatenate observation and action, run the critic network and return the Q-value
            together with its distribution.
        Arguments:
            - inputs (:obj:`Dict`): Dict with the necessary keys ``obs`` and ``action``.
        Returns:
            - outputs (:obj:`Dict`): Q-value output and distribution.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): The critic head's ``logit`` output.
            - distribution (:obj:`torch.Tensor`): The critic head's ``distribution`` output.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 is ``obs_shape``; \
                must be 2-dimensional
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, where N1 is ``action_shape``; a 1-dimensional \
                :math:`(B, )` action is expanded to :math:`(B, 1)`
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)` as in the example below
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where N3 is ``n_atom``
        Examples:
            >>> # Categorical mode
            >>> N = 8
            >>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
            >>> model = QACDIST(obs_shape=(N, ), action_shape=1, action_space='regression',
            ...                 critic_head_type='categorical', n_atom=51)
            >>> q_value = model(inputs, mode='compute_critic')  # q value
            >>> assert q_value['q_value'].shape == torch.Size([4, 1])
            >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
        """
        obs, action = inputs['obs'], inputs['action']
        assert len(obs.shape) == 2
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        # Join along the feature axis so the critic sees one (obs, action) vector per sample.
        x = torch.cat([obs, action], dim=1)
        x = self.critic(x)
        return {'q_value': x['logit'], 'distribution': x['distribution']}
|