File size: 36,999 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 |
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from ding.model.common import ReparameterizationHead
from ding.torch_utils import MLP, ResBlock
from ding.utils import MODEL_REGISTRY, SequenceType
from .common import EZNetworkOutput, RepresentationNetwork
from .efficientzero_model import DynamicsNetwork
from .utils import renormalize, get_params_mean
# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document.
@MODEL_REGISTRY.register('SampledEfficientZeroModel')
class SampledEfficientZeroModel(nn.Module):

    def __init__(
            self,
            observation_shape: SequenceType = (12, 96, 96),
            action_space_size: int = 6,
            num_res_blocks: int = 1,
            num_channels: int = 64,
            lstm_hidden_size: int = 512,
            reward_head_channels: int = 16,
            value_head_channels: int = 16,
            policy_head_channels: int = 16,
            fc_reward_layers: SequenceType = [32],
            fc_value_layers: SequenceType = [32],
            fc_policy_layers: SequenceType = [32],
            reward_support_size: int = 601,
            value_support_size: int = 601,
            proj_hid: int = 1024,
            proj_out: int = 1024,
            pred_hid: int = 512,
            pred_out: int = 1024,
            self_supervised_learning_loss: bool = True,
            categorical_distribution: bool = True,
            activation: Optional[nn.Module] = nn.ReLU(inplace=True),
            last_linear_layer_init_zero: bool = True,
            state_norm: bool = False,
            downsample: bool = False,
            # ==============================================================
            # specific sampled related config
            # ==============================================================
            continuous_action_space: bool = False,
            num_of_sampled_actions: int = 6,
            sigma_type='conditioned',
            fixed_sigma_value: float = 0.3,
            bound_type: str = None,
            norm_type: str = 'BN',
            discrete_action_encoding_type: str = 'one_hot',
            *args,
            **kwargs,
    ):
        """
        Overview:
            The network model of Sampled EfficientZero for 2D image-like observations. It is mainly built on
            convolutional residual blocks and fully connected layers, and consists of three sub-networks:
                - representation network (``RepresentationNetwork``): maps the raw observation to a latent state.
                - dynamics network (``DynamicsNetwork``): given the current latent state concatenated with an
                  encoded action, and the LSTM hidden state ``reward_hidden_state``, predicts the next latent
                  state, the next ``reward_hidden_state`` and the ``value_prefix``.
                - prediction network (``PredictionNetwork``): predicts the value and the policy of a latent state.
            When ``self_supervised_learning_loss`` is True, an extra projection / prediction head pair (proposed
            in EfficientZero) is built for the self-supervised loss, see ``project``.
        Arguments:
            - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for \
                Atari. An ``int`` or a length-1 sequence ``[obs_dim]`` is treated as a vector observation and \
                reshaped to ``[1, obs_dim, 1]``.
            - action_space_size: (:obj:`int`): Action space size. For discrete action space, it is the number of \
                discrete actions, e.g. 4 for Lunarlander. For continuous action space, it is the dimension of the \
                continuous action, e.g. 4 for bipedalwalker.
            - num_res_blocks (:obj:`int`): The number of res blocks in Sampled EfficientZero model.
            - num_channels (:obj:`int`): The channels of hidden states.
            - lstm_hidden_size (:obj:`int`): Dim of the LSTM hidden state in the dynamics network.
            - reward_head_channels (:obj:`int`): The channels of reward head.
            - value_head_channels (:obj:`int`): The channels of value head.
            - policy_head_channels (:obj:`int`): The channels of policy head.
            - fc_reward_layers (:obj:`SequenceType`): The hidden layer sizes of the reward head (MLP head).
            - fc_value_layers (:obj:`SequenceType`): The hidden layer sizes of the value head (MLP head).
            - fc_policy_layers (:obj:`SequenceType`): The hidden layer sizes of the policy head (MLP head).
            - reward_support_size (:obj:`int`): The size of categorical reward output.
            - value_support_size (:obj:`int`): The size of categorical value output.
            - proj_hid (:obj:`int`): The size of projection hidden layer.
            - proj_out (:obj:`int`): The size of projection output layer.
            - pred_hid (:obj:`int`): The size of prediction hidden layer.
            - pred_out (:obj:`int`): The size of prediction output layer.
            - self_supervised_learning_loss (:obj:`bool`): Whether to build the self-supervised-learning related \
                networks (``projection`` and ``prediction_head``). Default sets it to True.
            - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \
                distribution for value and reward/value_prefix. If False, both support sizes are forced to 1.
            - activation (:obj:`Optional[nn.Module]`): Activation function used in the dynamics and prediction \
                networks and in the SSL heads, which often uses in-place operation to speedup, \
                e.g. ReLU(inplace=True).
            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
                value/policy mlp, default sets it to True.
            - state_norm (:obj:`bool`): Whether to renormalize latent states, default sets it to False.
            - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
                default sets it to False. This option is often used in video games like Atari. In board games like \
                go, we don't need this module.
            # ==============================================================
            # specific sampled related config
            # ==============================================================
            - continuous_action_space (:obj:`bool`): The type of action space. Default sets it to False.
            - num_of_sampled_actions (:obj:`int`): The number of sampled actions, i.e. the K in original Sampled \
                MuZero paper. It is only stored on the model here.
            # Please see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments.
            - sigma_type (:obj:`str`): The type of sigma in policy head of prediction network, \
                options={'conditioned', 'fixed'}.
            - fixed_sigma_value (:obj:`float`): The fixed sigma value in policy head of prediction network.
            - bound_type (:obj:`str`): The type of bound in networks, default sets it to None.
            - norm_type (:obj:`str`): The type of normalization in networks, default sets it to 'BN'.
            - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it \
                to 'one_hot'. options = {'one_hot', 'not_one_hot'}
            - args / kwargs: Accepted for config compatibility and ignored.
        """
        super(SampledEfficientZeroModel, self).__init__()
        if isinstance(observation_shape, int) or len(observation_shape) == 1:
            # For vector obs input, e.g. classical control and box2d environments. To be compatible with the
            # LightZero model/policy, transform it to the image-like shape [C, W, H] = [1, obs_dim, 1].
            # A length-1 sequence must be unwrapped first: otherwise observation_shape[1] would itself be a
            # sequence and the flatten-size arithmetic below would silently produce a list / raise.
            obs_dim = observation_shape if isinstance(observation_shape, int) else observation_shape[0]
            observation_shape = [1, obs_dim, 1]

        if not categorical_distribution:
            self.reward_support_size = 1
            self.value_support_size = 1
        else:
            self.reward_support_size = reward_support_size
            self.value_support_size = value_support_size

        self.continuous_action_space = continuous_action_space
        self.action_space_size = action_space_size
        # The dim of action space. For discrete action space, it's 1.
        # For continuous action space, it is the dim of action.
        self.action_space_dim = action_space_size if self.continuous_action_space else 1
        assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type
        self.discrete_action_encoding_type = discrete_action_encoding_type
        # Number of channels the encoded action adds to the latent state before the dynamics network.
        if self.continuous_action_space:
            self.action_encoding_dim = action_space_size
        elif self.discrete_action_encoding_type == 'one_hot':
            self.action_encoding_dim = action_space_size
        else:  # 'not_one_hot'
            self.action_encoding_dim = 1

        self.lstm_hidden_size = lstm_hidden_size
        self.proj_hid = proj_hid
        self.proj_out = proj_out
        self.pred_hid = pred_hid
        self.pred_out = pred_out
        self.last_linear_layer_init_zero = last_linear_layer_init_zero
        self.state_norm = state_norm
        self.downsample = downsample
        self.self_supervised_learning_loss = self_supervised_learning_loss
        self.sigma_type = sigma_type
        self.fixed_sigma_value = fixed_sigma_value
        self.bound_type = bound_type
        self.norm_type = norm_type
        self.num_of_sampled_actions = num_of_sampled_actions

        # Spatial size of the latent state. With ``downsample`` the representation network is expected to shrink
        # each spatial dim by a factor of 16 (e.g. 96 -> 6 for Atari); otherwise the observation size is kept.
        if downsample:
            latent_height = math.ceil(observation_shape[1] / 16)
            latent_width = math.ceil(observation_shape[2] / 16)
        else:
            latent_height = observation_shape[1]
            latent_width = observation_shape[2]
        flatten_output_size_for_reward_head = reward_head_channels * latent_height * latent_width
        flatten_output_size_for_value_head = value_head_channels * latent_height * latent_width
        flatten_output_size_for_policy_head = policy_head_channels * latent_height * latent_width

        self.representation_network = RepresentationNetwork(
            observation_shape,
            num_res_blocks,
            num_channels,
            downsample,
            norm_type=self.norm_type,
        )
        self.dynamics_network = DynamicsNetwork(
            observation_shape,
            self.action_encoding_dim,
            num_res_blocks,
            num_channels + self.action_encoding_dim,
            reward_head_channels,
            fc_reward_layers,
            self.reward_support_size,
            flatten_output_size_for_reward_head,
            downsample,
            lstm_hidden_size=self.lstm_hidden_size,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
            activation=activation,
            norm_type=norm_type
        )
        self.prediction_network = PredictionNetwork(
            observation_shape,
            self.continuous_action_space,
            action_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            self.value_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            downsample,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
            # Forward the configured activation so the prediction network does not silently fall back to its
            # own default while the dynamics network uses the user's choice.
            activation=activation,
            sigma_type=self.sigma_type,
            fixed_sigma_value=self.fixed_sigma_value,
            bound_type=self.bound_type,
            norm_type=self.norm_type,
        )

        if self.self_supervised_learning_loss:
            # self_supervised_learning_loss related network proposed in EfficientZero.
            # E.g. in Atari with observation_shape (12, 96, 96) and downsample, the latent state is (64, 6, 6),
            # so self.projection_input_dim = 64 * 6 * 6 = 2304.
            self.projection_input_dim = num_channels * latent_height * latent_width
            self.projection = nn.Sequential(
                nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
            )
            self.prediction_head = nn.Sequential(
                nn.Linear(self.proj_out, self.pred_hid),
                nn.BatchNorm1d(self.pred_hid),
                activation,
                nn.Linear(self.pred_hid, self.pred_out),
            )

    def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput:
        """
        Overview:
            Initial inference of SampledEfficientZero model, i.e. the first step of the search. The representation
            network encodes ``obs`` into a ``latent_state``, the prediction network predicts its ``value`` and
            ``policy_logits``, and a zero-initialized ``reward_hidden_state`` is prepared for the first
            recurrent step.
        Arguments:
            - obs (:obj:`torch.Tensor`): The 2D image observation data.
        Returns (EZNetworkOutput):
            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
            - value_prefix (:obj:`list`): In initial inference, a Python list of ``batch_size`` zeros.
            - policy_logits (:obj:`torch.Tensor`): The output of the policy head.
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - reward_hidden_state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The zero-initialized LSTM hidden \
                state (h, c).
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_space_size)` for discrete action space, \
                :math:`(B, 2 * action_space_size)` (mu and sigma) for continuous action space.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of latent channels, H_ / W_ are the height / width of the latent state.
            - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is \
                :math:`(1, B, lstm_hidden_size)`, where B is batch_size.
        """
        batch_size = obs.size(0)
        latent_state = self._representation(obs)
        policy_logits, value = self._prediction(latent_state)
        # Zero initialization for reward hidden states:
        # (h_0, c_0), each of shape (layer_num=1, batch_size, lstm_hidden_size).
        reward_hidden_state = (
            torch.zeros(1, batch_size, self.lstm_hidden_size, device=obs.device),
            torch.zeros(1, batch_size, self.lstm_hidden_size, device=obs.device),
        )
        return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state)

    def recurrent_inference(
            self, latent_state: torch.Tensor, reward_hidden_state: Tuple[torch.Tensor, torch.Tensor],
            action: torch.Tensor
    ) -> EZNetworkOutput:
        """
        Overview:
            Recurrent inference of Sampled EfficientZero model, i.e. one rollout step of the search. The dynamics
            network predicts ``next_latent_state``, ``reward_hidden_state`` and ``value_prefix`` from the current
            ``latent_state`` and ``action``; the prediction network then predicts ``value`` and
            ``policy_logits`` of ``next_latent_state``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - reward_hidden_state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The input LSTM hidden state (h, c).
            - action (:obj:`torch.Tensor`): The action to roll out.
        Returns (EZNetworkOutput):
            - value (:obj:`torch.Tensor`): The output value of the next latent state.
            - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of rewards.
            - policy_logits (:obj:`torch.Tensor`): The output of the policy head.
            - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state.
            - reward_hidden_state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The output LSTM hidden state.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of latent channels, H_ / W_ are the height / width of the latent state.
            - action (:obj:`torch.Tensor`): :math:`(B, )` or :math:`(B, 1)` for discrete action space, \
                :math:`(B, action_space_size)` for continuous action space, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_space_size)` for discrete action space, \
                :math:`(B, 2 * action_space_size)` for continuous action space.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, same as ``latent_state``.
            - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is \
                :math:`(1, B, lstm_hidden_size)`, where B is batch_size.
        """
        next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action)
        policy_logits, value = self._prediction(next_latent_state)
        return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state)

    def _representation(self, observation: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Use the representation network to encode the observations into latent state, optionally renormalized
            when ``state_norm`` is True.
        Arguments:
            - observation (:obj:`torch.Tensor`): The 2D image observation data.
        Returns:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
        Shapes:
            - observation (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is \
                batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of latent channels, H_ / W_ are the height / width of the latent state.
        """
        latent_state = self.representation_network(observation)
        if self.state_norm:
            latent_state = renormalize(latent_state)
        return latent_state

    def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the prediction network to predict the "policy_logits" and "value" of the "latent_state".
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input obs.
        Returns:
            - policy_logits (:obj:`torch.Tensor`): The output of the policy head.
            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_space_size)` for discrete action space, \
                :math:`(B, 2 * action_space_size)` for continuous action space.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
        """
        return self.prediction_network(latent_state)

    def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple[torch.Tensor],
                  action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]:
        """
        Overview:
            Encode ``action`` as extra channels broadcast over the latent spatial grid, concatenate it with
            ``latent_state`` and use the dynamics network to predict ``next_latent_state``,
            ``next_reward_hidden_state`` and ``value_prefix``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input LSTM hidden state (h, c).
            - action (:obj:`torch.Tensor`): The action to roll out.
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep \
                (renormalized when ``state_norm`` is True).
            - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output LSTM hidden state.
            - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of rewards.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size.
            - action (:obj:`torch.Tensor`): discrete: :math:`(B, )` or :math:`(B, 1)`; continuous: :math:`(B, )`, \
                :math:`(B, A)`, :math:`(B, A, 1)` or :math:`(B, A, 1, 1)` with A = action_space_size.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, same as ``latent_state``.
            - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
        """
        # NOTE: the discrete action encoding type is important for some environments.
        if not self.continuous_action_space:
            if self.discrete_action_encoding_type == 'one_hot':
                # One-hot encode the action and broadcast it over the spatial grid:
                # final shape (B, action_space_size, H_, W_), e.g. (8, 2, 4, 1).
                if len(action.shape) == 1:
                    # (B, ) -> (B, 1), as required by scatter_ along dim 1.
                    action = action.unsqueeze(-1)
                action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device)
                # scatter_ requires int64 indices.
                action_one_hot.scatter_(1, action.long(), 1)
                action_encoding = action_one_hot.unsqueeze(-1).unsqueeze(-1).expand(
                    latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3]
                )
            elif self.discrete_action_encoding_type == 'not_one_hot':
                # Encode the action index as a single channel scaled by action_space_size:
                # final shape (B, 1, H_, W_), e.g. (8, 1, 4, 1).
                if len(action.shape) == 2:
                    # (B, 1) -> (B, 1, 1, 1)
                    action = action.unsqueeze(-1).unsqueeze(-1)
                elif len(action.shape) == 1:
                    # (B, ) -> (B, 1, 1, 1)
                    action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                action_encoding = action.expand(
                    latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3]
                ) / self.action_space_size
        else:
            # Continuous action space: each action dimension becomes one channel.
            if len(action.shape) == 1:
                # (B, ) -> (B, action_dim=1, 1, 1)
                action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            elif len(action.shape) == 2:
                # (B, action_dim) -> (B, action_dim, 1, 1), e.g. [8, 2] -> [8, 2, 1, 1]
                action = action.unsqueeze(-1).unsqueeze(-1)
            elif len(action.shape) == 3:
                # (B, action_dim, 1) -> (B, action_dim, 1, 1), e.g. [8, 2, 1] -> [8, 2, 1, 1]
                action = action.unsqueeze(-1)
            action_encoding = action.expand(
                latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3]
            )

        action_encoding = action_encoding.to(latent_state.device).float()
        # state_action_encoding shape: (B, C_ + action_encoding_dim, H_, W_).
        state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)

        next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network(
            state_action_encoding, reward_hidden_state
        )
        if self.state_norm:
            next_latent_state = renormalize(next_latent_state)
        return next_latent_state, next_reward_hidden_state, value_prefix

    def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor:
        """
        Overview:
            Project the latent state to an embedding used by the self-supervised loss proposed in EfficientZero.
            For more details, please refer to paper ``Exploring Simple Siamese Representation Learning``.
            Requires the model to be built with ``self_supervised_learning_loss=True``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - with_grad (:obj:`bool`): If True, pass the projection through ``prediction_head`` and keep the graph; \
                if False, return the projection itself detached from the graph.
        Returns:
            - proj (:obj:`torch.Tensor`): The result embedding vector.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, flattened internally to \
                :math:`(B, projection_input_dim)`.
            - proj (:obj:`torch.Tensor`): :math:`(B, pred_out)` when ``with_grad`` is True, \
                :math:`(B, proj_out)` otherwise.
        Examples:
            >>> latent_state = torch.randn(256, 64, 6, 6)
            >>> output = self.project(latent_state)
            >>> output.shape # (256, 1024) with the default pred_out

        .. note::
            for Atari:
            observation_shape = (12, 96, 96),  # original shape is (3,96,96), frame_stack_num=4
            if downsample is True, latent_state.shape: (batch_size, num_channel, obs_shape[1] / 16, obs_shape[2] / 16)
            i.e., (256, 64, 96 / 16, 96 / 16) = (256, 64, 6, 6)
            latent_state reshape: (256, 64, 6, 6) -> (256, 64*6*6) = (256, 2304)
            # self.projection_input_dim = 64*6*6 = 2304
            # self.projection_output_dim = 1024
        """
        latent_state = latent_state.reshape(latent_state.shape[0], -1)
        proj = self.projection(latent_state)

        if with_grad:
            # with grad, use prediction_head
            return self.prediction_head(proj)
        else:
            return proj.detach()

    def get_params_mean(self):
        """
        Overview:
            Return the result of the module-level ``get_params_mean`` helper (imported from ``.utils``) applied
            to this model.
        """
        # Inside a method body, this name resolves to the imported module-level helper, not to this method.
        return get_params_mean(self)
class PredictionNetwork(nn.Module):

    def __init__(
            self,
            observation_shape: SequenceType,
            continuous_action_space,
            action_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            output_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            downsample: bool = False,
            last_linear_layer_init_zero: bool = True,
            activation: Optional[nn.Module] = nn.ReLU(inplace=True),
            # ==============================================================
            # specific sampled related config
            # ==============================================================
            sigma_type='conditioned',
            fixed_sigma_value: float = 0.3,
            bound_type: str = None,
            norm_type: str = 'BN',
    ):
        """
        Overview:
            The definition of the policy and value prediction network, which predicts the value and the policy of
            a given latent state. It is built on residual conv blocks, 1x1 conv heads and fully connected layers.
            For continuous action spaces the policy head is a ``ReparameterizationHead`` (mu / sigma); for
            discrete action spaces it is an ``MLP`` producing logits.
        Arguments:
            - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
            - continuous_action_space (:obj:`bool`): The type of action space.
            - action_space_size: (:obj:`int`): Action space size. For discrete action space, it is the number of \
                discrete actions. For continuous action space, it is the dimension of continuous action.
            - num_res_blocks (:obj:`int`): Number of res blocks in model.
            - num_channels (:obj:`int`): Channels of hidden states.
            - value_head_channels (:obj:`int`): Channels of value head.
            - policy_head_channels (:obj:`int`): Channels of policy head.
            - fc_value_layers (:obj:`SequenceType`): Hidden layers of the value prediction head (MLP head).
            - fc_policy_layers (:obj:`SequenceType`): Hidden layers of the policy prediction head (MLP head).
            - output_support_size (:obj:`int`): Dim of value output.
            - flatten_output_size_for_value_head (:obj:`int`): Dim of the flattened value-head conv features.
            - flatten_output_size_for_policy_head (:obj:`int`): Dim of the flattened policy-head conv features.
            - downsample (:obj:`bool`): Whether observations were downsampled in ``representation_network``; only \
                used to size the LayerNorm when ``norm_type='LN'``.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
                value/policy mlp, default sets it to True.
            - activation (:obj:`Optional[nn.Module]`): Activation used in the res blocks, after the 1x1 convs and in \
                the MLP heads.
            # ==============================================================
            # specific sampled related config
            # ==============================================================
            # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments.
            - sigma_type (:obj:`str`): The type of sigma in policy head of prediction network, \
                options={'conditioned', 'fixed'}.
            - fixed_sigma_value (:obj:`float`): The fixed sigma value in policy head of prediction network.
            - bound_type (:obj:`str`): The type of bound in networks. Default sets it to None.
            - norm_type (:obj:`str`): The type of normalization for the conv heads and the MLP heads, \
                options={'BN', 'LN'}. Default sets it to 'BN'.
        Raises:
            - ValueError: If ``norm_type`` is not 'BN' or 'LN' (otherwise the head norms would be undefined and \
                ``forward`` would fail).
        """
        super().__init__()
        if norm_type not in ('BN', 'LN'):
            raise ValueError(f"PredictionNetwork only supports norm_type in ('BN', 'LN'), got {norm_type!r}.")
        self.continuous_action_space = continuous_action_space
        self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
        self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
        self.norm_type = norm_type
        self.sigma_type = sigma_type
        self.fixed_sigma_value = fixed_sigma_value
        self.bound_type = bound_type
        self.activation = activation

        # NOTE(review): the residual blocks always use 'BN', independent of ``norm_type`` — confirm intentional.
        self.resblocks = nn.ModuleList(
            [
                ResBlock(
                    in_channels=num_channels,
                    activation=activation,
                    norm_type='BN',
                    res_type='basic',
                    bias=False
                ) for _ in range(num_res_blocks)
            ]
        )

        self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
        self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)

        if norm_type == 'BN':
            self.norm_value = nn.BatchNorm2d(value_head_channels)
            self.norm_policy = nn.BatchNorm2d(policy_head_channels)
        else:  # 'LN'
            # LayerNorm normalizes over (C, H, W), so it needs the latent spatial size; with downsample each
            # spatial dim is divided by 16 (rounded up).
            if downsample:
                latent_spatial_shape = [
                    math.ceil(observation_shape[-2] / 16),
                    math.ceil(observation_shape[-1] / 16),
                ]
            else:
                latent_spatial_shape = [observation_shape[-2], observation_shape[-1]]
            self.norm_value = nn.LayerNorm([value_head_channels, *latent_spatial_shape])
            self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_spatial_shape])

        self.fc_value_head = MLP(
            in_channels=self.flatten_output_size_for_value_head,
            hidden_channels=fc_value_layers[0],
            out_channels=output_support_size,
            layer_num=len(fc_value_layers) + 1,
            activation=activation,
            norm_type=self.norm_type,
            output_activation=False,
            output_norm=False,
            # last_linear_layer_init_zero=True is beneficial for convergence speed.
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )

        # sampled related core code
        if self.continuous_action_space:
            # NOTE(review): this head uses a fixed nn.ReLU() and no normalization rather than ``activation`` /
            # ``norm_type``, and does not receive ``fc_policy_layers[0]`` as a hidden size — kept as-is.
            self.fc_policy_head = ReparameterizationHead(
                input_size=self.flatten_output_size_for_policy_head,
                output_size=action_space_size,
                layer_num=len(fc_policy_layers) + 1,
                sigma_type=self.sigma_type,
                fixed_sigma_value=self.fixed_sigma_value,
                activation=nn.ReLU(),
                norm_type=None,
                bound_type=self.bound_type
            )
        else:
            self.fc_policy_head = MLP(
                in_channels=self.flatten_output_size_for_policy_head,
                hidden_channels=fc_policy_layers[0],
                out_channels=action_space_size,
                layer_num=len(fc_policy_layers) + 1,
                activation=activation,
                norm_type=self.norm_type,
                output_activation=False,
                output_norm=False,
                # last_linear_layer_init_zero=True is beneficial for convergence speed.
                last_linear_layer_init_zero=last_linear_layer_init_zero
            )

    def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation of the prediction network: shared res blocks, then separate
            1x1 conv -> norm -> activation -> flatten -> head pipelines for value and policy.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): Input latent state with shape (B, num_channels, H_, W_).
        Returns:
            - policy (:obj:`torch.Tensor`): Policy tensor. If action space is discrete, shape is \
                (B, action_space_size). If action space is continuous, it is the concatenation of mu and sigma \
                along the last dim, shape (B, action_space_size * 2).
            - value (:obj:`torch.Tensor`): Value tensor with shape (B, output_support_size).
        """
        for res_block in self.resblocks:
            latent_state = res_block(latent_state)

        value = self.activation(self.norm_value(self.conv1x1_value(latent_state)))
        policy = self.activation(self.norm_policy(self.conv1x1_policy(latent_state)))

        value = value.reshape(-1, self.flatten_output_size_for_value_head)
        policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)

        value = self.fc_value_head(value)

        # sampled related core code
        policy = self.fc_policy_head(policy)
        if self.continuous_action_space:
            # ReparameterizationHead returns a dict; flatten it into a single tensor [mu, sigma].
            policy = torch.cat([policy['mu'], policy['sigma']], dim=-1)

        return policy, value
|