"""
Overview:
    This file provides a collection of reusable model templates designed to streamline the development \
    process of various custom algorithms. By utilizing these pre-built model templates, users can quickly \
    adapt and customize their algorithms, ensuring efficient and effective development.
    Users can refer to the unit tests of these model templates to learn how to use them.
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from ding.torch_utils import MLP, ResBlock
from ding.utils import SequenceType


@dataclass
class EZNetworkOutput:
    """
    Overview:
        The network output of EfficientZero, including the predicted value, value prefix, policy logits, \
        latent state, and the reward hidden state (the LSTM state used to predict the value prefix).
    """
    value: torch.Tensor
    value_prefix: torch.Tensor
    policy_logits: torch.Tensor
    latent_state: torch.Tensor
    reward_hidden_state: Tuple[torch.Tensor, torch.Tensor]  # (hidden_state, cell_state) of the value-prefix LSTM


@dataclass
class MZNetworkOutput:
    """
    Overview:
        The network output of MuZero, including the predicted value, reward, policy logits, and latent state.
    """
    value: torch.Tensor
    reward: torch.Tensor
    policy_logits: torch.Tensor
    latent_state: torch.Tensor
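
# A minimal construction sketch for these output containers (hypothetical shapes: batch size 8,
# action space 6, categorical value support 601, latent state of shape (64, 6, 6)):
#
#     output = MZNetworkOutput(
#         value=torch.zeros(8, 601),
#         reward=torch.zeros(8, 601),
#         policy_logits=torch.zeros(8, 6),
#         latent_state=torch.zeros(8, 64, 6, 6),
#     )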


class DownSample(nn.Module):

    def __init__(
            self,
            observation_shape: SequenceType,
            out_channels: int,
            activation: nn.Module = nn.ReLU(inplace=True),
            norm_type: Optional[str] = 'BN',
    ) -> None:
        """
        Overview:
            Define the downsampling convolutional network, which encodes the observation into a hidden state.
            This network is often used in video games like Atari. In board games like Go and chess, \
            this module is not needed.
        Arguments:
            - observation_shape (:obj:`SequenceType`): The shape of the observation space, e.g. \
                [C, W, H] = [12, 96, 96] for Atari, i.e. 3 RGB channels times a stack of 4 frames.
            - out_channels (:obj:`int`): The number of channels of the output hidden state.
            - activation (:obj:`nn.Module`): The activation function used in the network, defaults to \
                nn.ReLU(inplace=True), which uses the in-place operation to speed up.
            - norm_type (:obj:`Optional[str]`): The normalization type used in the network, defaults to 'BN'.
        """
        super().__init__()
        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"

        self.conv1 = nn.Conv2d(
            observation_shape[0],
            out_channels // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        if norm_type == 'BN':
            self.norm1 = nn.BatchNorm2d(out_channels // 2)
        elif norm_type == 'LN':
            self.norm1 = nn.LayerNorm(
                [out_channels // 2, math.ceil(observation_shape[-2] / 2), math.ceil(observation_shape[-1] / 2)]
            )

        # NOTE: the residual blocks below always use BatchNorm internally; the generic LayerNorm used in
        # ResBlock would require explicit spatial shapes at every stage.
        self.resblocks1 = nn.ModuleList(
            [
                ResBlock(
                    in_channels=out_channels // 2,
                    activation=activation,
                    norm_type='BN',
                    res_type='basic',
                    bias=False
                ) for _ in range(1)
            ]
        )
        # a stride-2 residual block that also doubles the number of channels.
        self.downsample_block = ResBlock(
            in_channels=out_channels // 2,
            out_channels=out_channels,
            activation=activation,
            norm_type='BN',
            res_type='downsample',
            bias=False
        )
        self.resblocks2 = nn.ModuleList(
            [
                ResBlock(
                    in_channels=out_channels, activation=activation, norm_type='BN', res_type='basic', bias=False
                ) for _ in range(1)
            ]
        )
        self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.resblocks3 = nn.ModuleList(
            [
                ResBlock(
                    in_channels=out_channels, activation=activation, norm_type='BN', res_type='basic', bias=False
                ) for _ in range(1)
            ]
        )
        self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is the number of \
                input channels, W is width, and H is height.
            - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where C_out is the number of output \
                channels, and W_ = W / 16 and H_ = H / 16 are the output width and height (the total \
                downsampling factor is 16: two stride-2 convolutions and two stride-2 average poolings).
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)

        for block in self.resblocks1:
            x = block(x)
        x = self.downsample_block(x)
        for block in self.resblocks2:
            x = block(x)
        x = self.pooling1(x)
        for block in self.resblocks3:
            x = block(x)
        output = self.pooling2(x)
        return output
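
# A minimal usage sketch of ``DownSample`` (hypothetical shapes, assuming a stacked Atari-style observation);
# the spatial size shrinks by a factor of 16 overall:
#
#     down_sample = DownSample(observation_shape=(12, 96, 96), out_channels=64)
#     obs = torch.randn(8, 12, 96, 96)
#     hidden = down_sample(obs)  # -> torch.Size([8, 64, 6, 6])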


class RepresentationNetwork(nn.Module):

    def __init__(
            self,
            observation_shape: SequenceType = (12, 96, 96),
            num_res_blocks: int = 1,
            num_channels: int = 64,
            downsample: bool = True,
            activation: nn.Module = nn.ReLU(inplace=True),
            norm_type: str = 'BN',
    ) -> None:
        """
        Overview:
            Representation network used in MuZero and derived algorithms, which encodes the 2D image \
            observation into a latent state.
        Arguments:
            - observation_shape (:obj:`SequenceType`): The shape of the observation space, e.g. \
                [C, W, H] = [12, 96, 96] for Atari, i.e. 3 RGB channels times a stack of 4 frames.
            - num_res_blocks (:obj:`int`): The number of residual blocks.
            - num_channels (:obj:`int`): The number of channels of the output hidden state.
            - downsample (:obj:`bool`): Whether to downsample observations in ``representation_network``, \
                defaults to True. This option is often used in video games like Atari. In board games like Go, \
                this module is not needed.
            - activation (:obj:`nn.Module`): The activation function used in the network, defaults to \
                nn.ReLU(inplace=True), which uses the in-place operation to speed up.
            - norm_type (:obj:`str`): The type of normalization used in the network, defaults to 'BN'.
        """
        super().__init__()
        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"

        self.downsample = downsample
        if self.downsample:
            self.downsample_net = DownSample(
                observation_shape,
                num_channels,
                activation=activation,
                norm_type=norm_type,
            )
        else:
            self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)

        if norm_type == 'BN':
            self.norm = nn.BatchNorm2d(num_channels)
        elif norm_type == 'LN':
            if downsample:
                self.norm = nn.LayerNorm(
                    [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]
                )
            else:
                self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]])

        self.resblocks = nn.ModuleList(
            [
                ResBlock(
                    in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False
                ) for _ in range(num_res_blocks)
            ]
        )
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is the number of \
                input channels, W is width, and H is height.
            - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where C_out is the number of output \
                channels, and W_ and H_ are the output width and height (W / 16 and H / 16 if downsampling \
                is enabled, otherwise unchanged).
        """
        if self.downsample:
            x = self.downsample_net(x)
        else:
            x = self.conv(x)
            x = self.norm(x)
            x = self.activation(x)

        for block in self.resblocks:
            x = block(x)
        return x

    def get_param_mean(self) -> float:
        """
        Overview:
            Get the mean absolute value of the parameters in the network, for debugging and visualization.
        Returns:
            - mean (:obj:`float`): The mean absolute value of the parameters in the network.
        """
        abs_params = []
        for param in self.parameters():
            abs_params += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist()
        return sum(abs_params) / len(abs_params)
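
# A minimal usage sketch with the default configuration (hypothetical shapes, 96x96 Atari-style input):
#
#     repr_net = RepresentationNetwork(observation_shape=(12, 96, 96), num_channels=64, downsample=True)
#     obs = torch.randn(8, 12, 96, 96)
#     latent = repr_net(obs)  # -> torch.Size([8, 64, 6, 6])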


class RepresentationNetworkMLP(nn.Module):

    def __init__(
            self,
            observation_shape: int,
            hidden_channels: int = 64,
            layer_num: int = 2,
            activation: Optional[nn.Module] = nn.ReLU(inplace=True),
            last_linear_layer_init_zero: bool = True,
            norm_type: Optional[str] = 'BN',
    ) -> None:
        """
        Overview:
            Representation network used in MuZero and derived algorithms, which encodes the vector \
            observation into a latent state with a Multi-Layer Perceptron (MLP).
        Arguments:
            - observation_shape (:obj:`int`): The dimension of the vector observation space, e.g. N = 10.
            - hidden_channels (:obj:`int`): The dimension of the output latent state.
            - layer_num (:obj:`int`): The number of layers in the MLP, defaults to 2.
            - activation (:obj:`nn.Module`): The activation function used in the network, defaults to \
                nn.ReLU(inplace=True), which uses the in-place operation to speed up.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer with zeros, \
                which can provide stable zero outputs in the beginning, defaults to True.
            - norm_type (:obj:`str`): The type of normalization used in the network, defaults to 'BN'.
        """
        super().__init__()
        self.fc_representation = MLP(
            in_channels=observation_shape,
            hidden_channels=hidden_channels,
            out_channels=hidden_channels,
            layer_num=layer_num,
            activation=activation,
            norm_type=norm_type,
            # no activation or norm on the output layer of the representation network,
            # which is important for convergence.
            output_activation=False,
            output_norm=False,
            # zero-initializing the last linear layer provides stable zero outputs at the beginning of training.
            last_linear_layer_init_zero=last_linear_layer_init_zero,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the length of the \
                vector observation.
            - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
        """
        return self.fc_representation(x)
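
# A minimal usage sketch (hypothetical 10-dimensional vector observation):
#
#     repr_mlp = RepresentationNetworkMLP(observation_shape=10, hidden_channels=64)
#     obs = torch.randn(8, 10)
#     latent = repr_mlp(obs)  # -> torch.Size([8, 64])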


class PredictionNetwork(nn.Module):

    def __init__(
            self,
            observation_shape: SequenceType,
            action_space_size: int,
            num_res_blocks: int,
            num_channels: int,
            value_head_channels: int,
            policy_head_channels: int,
            fc_value_layers: SequenceType,
            fc_policy_layers: SequenceType,
            output_support_size: int,
            flatten_output_size_for_value_head: int,
            flatten_output_size_for_policy_head: int,
            downsample: bool = False,
            last_linear_layer_init_zero: bool = True,
            activation: nn.Module = nn.ReLU(inplace=True),
            norm_type: Optional[str] = 'BN',
    ) -> None:
        """
        Overview:
            Definition of the policy and value prediction network, which predicts the value and policy from the \
            given latent state.
        Arguments:
            - observation_shape (:obj:`SequenceType`): The shape of the observation space, e.g. (C, H, W) for images.
            - action_space_size (:obj:`int`): The action space size, usually an integer for discrete action spaces.
            - num_res_blocks (:obj:`int`): The number of residual blocks.
            - num_channels (:obj:`int`): The number of channels of the latent state.
            - value_head_channels (:obj:`int`): The number of channels of the value head.
            - policy_head_channels (:obj:`int`): The number of channels of the policy head.
            - fc_value_layers (:obj:`SequenceType`): The hidden layer sizes of the value head (MLP head).
            - fc_policy_layers (:obj:`SequenceType`): The hidden layer sizes of the policy head (MLP head).
            - output_support_size (:obj:`int`): The size of the categorical value output.
            - flatten_output_size_for_value_head (:obj:`int`): The size of the flattened hidden state, i.e. the \
                input size of the value head.
            - flatten_output_size_for_policy_head (:obj:`int`): The size of the flattened hidden state, i.e. the \
                input size of the policy head.
            - downsample (:obj:`bool`): Whether observations are downsampled in ``representation_network``.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \
                the value/policy MLP heads, defaults to True.
            - activation (:obj:`Optional[nn.Module]`): The activation function used in the network, which often \
                uses the in-place operation to speed up, e.g. ReLU(inplace=True).
            - norm_type (:obj:`str`): The type of normalization used in the network, defaults to 'BN'.
        """
        super().__init__()
        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"

        self.resblocks = nn.ModuleList(
            [
                ResBlock(
                    in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False
                ) for _ in range(num_res_blocks)
            ]
        )

        self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
        self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)

        if norm_type == 'BN':
            self.norm_value = nn.BatchNorm2d(value_head_channels)
            self.norm_policy = nn.BatchNorm2d(policy_head_channels)
        elif norm_type == 'LN':
            if downsample:
                self.norm_value = nn.LayerNorm(
                    [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]
                )
                self.norm_policy = nn.LayerNorm(
                    [policy_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]
                )
            else:
                self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]])
                self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]])

        self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
        self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
        self.activation = activation

        self.fc_value = MLP(
            in_channels=self.flatten_output_size_for_value_head,
            hidden_channels=fc_value_layers[0],
            out_channels=output_support_size,
            layer_num=len(fc_value_layers) + 1,
            activation=self.activation,
            norm_type=norm_type,
            output_activation=False,
            output_norm=False,
            # zero-initializing the last linear layer provides stable zero outputs at the beginning of training.
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.fc_policy = MLP(
            in_channels=self.flatten_output_size_for_policy_head,
            hidden_channels=fc_policy_layers[0],
            out_channels=action_space_size,
            layer_num=len(fc_policy_layers) + 1,
            activation=self.activation,
            norm_type=norm_type,
            output_activation=False,
            output_norm=False,
            # zero-initializing the last linear layer provides stable zero outputs at the beginning of training.
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )

    def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation of the prediction network.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, num_channels, H, W).
        Returns:
            - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
            - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
        """
        for res_block in self.resblocks:
            latent_state = res_block(latent_state)

        value = self.conv1x1_value(latent_state)
        value = self.norm_value(value)
        value = self.activation(value)

        policy = self.conv1x1_policy(latent_state)
        policy = self.norm_policy(policy)
        policy = self.activation(policy)

        value = value.reshape(-1, self.flatten_output_size_for_value_head)
        policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)

        value = self.fc_value(value)
        policy = self.fc_policy(policy)
        return policy, value
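
# A minimal usage sketch (hypothetical configuration for a downsampled 96x96 Atari-style latent state,
# where the 96x96 observation has been reduced to 6x6):
#
#     prediction_net = PredictionNetwork(
#         observation_shape=(12, 96, 96),
#         action_space_size=6,
#         num_res_blocks=1,
#         num_channels=64,
#         value_head_channels=16,
#         policy_head_channels=16,
#         fc_value_layers=[32],
#         fc_policy_layers=[32],
#         output_support_size=601,
#         flatten_output_size_for_value_head=16 * 6 * 6,
#         flatten_output_size_for_policy_head=16 * 6 * 6,
#         downsample=True,
#     )
#     latent = torch.randn(8, 64, 6, 6)
#     policy, value = prediction_net(latent)  # -> torch.Size([8, 6]), torch.Size([8, 601])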


class PredictionNetworkMLP(nn.Module):

    def __init__(
            self,
            action_space_size: int,
            num_channels: int,
            common_layer_num: int = 2,
            fc_value_layers: SequenceType = [32],
            fc_policy_layers: SequenceType = [32],
            output_support_size: int = 601,
            last_linear_layer_init_zero: bool = True,
            activation: Optional[nn.Module] = nn.ReLU(inplace=True),
            norm_type: Optional[str] = 'BN',
    ) -> None:
        """
        Overview:
            Definition of the policy and value prediction network with a Multi-Layer Perceptron (MLP), \
            which predicts the value and policy from the given latent state.
        Arguments:
            - action_space_size (:obj:`int`): The action space size, usually an integer. For a discrete action \
                space, it is the number of discrete actions.
            - num_channels (:obj:`int`): The number of channels of the latent state.
            - common_layer_num (:obj:`int`): The number of layers in the common trunk MLP, defaults to 2.
            - fc_value_layers (:obj:`SequenceType`): The hidden layer sizes of the value head (MLP head).
            - fc_policy_layers (:obj:`SequenceType`): The hidden layer sizes of the policy head (MLP head).
            - output_support_size (:obj:`int`): The size of the categorical value output.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \
                the value/policy MLP heads, defaults to True.
            - activation (:obj:`Optional[nn.Module]`): The activation function used in the network, which often \
                uses the in-place operation to speed up, e.g. ReLU(inplace=True).
            - norm_type (:obj:`str`): The type of normalization used in the network, defaults to 'BN'.
        """
        super().__init__()
        self.num_channels = num_channels

        self.fc_prediction_common = MLP(
            in_channels=self.num_channels,
            hidden_channels=self.num_channels,
            out_channels=self.num_channels,
            layer_num=common_layer_num,
            activation=activation,
            norm_type=norm_type,
            # the common trunk keeps activation and norm on its output layer.
            output_activation=True,
            output_norm=True,
            # the trunk's last linear layer is not zero-initialized; zero init is reserved for the output heads.
            last_linear_layer_init_zero=False,
        )

        self.fc_value_head = MLP(
            in_channels=self.num_channels,
            hidden_channels=fc_value_layers[0],
            out_channels=output_support_size,
            layer_num=len(fc_value_layers) + 1,
            activation=activation,
            norm_type=norm_type,
            output_activation=False,
            output_norm=False,
            # zero-initializing the last linear layer provides stable zero outputs at the beginning of training.
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.fc_policy_head = MLP(
            in_channels=self.num_channels,
            hidden_channels=fc_policy_layers[0],
            out_channels=action_space_size,
            layer_num=len(fc_policy_layers) + 1,
            activation=activation,
            norm_type=norm_type,
            output_activation=False,
            output_norm=False,
            # zero-initializing the last linear layer provides stable zero outputs at the beginning of training.
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )

    def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation of the prediction network.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim).
        Returns:
            - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
            - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
        """
        x_prediction_common = self.fc_prediction_common(latent_state)

        value = self.fc_value_head(x_prediction_common)
        policy = self.fc_policy_head(x_prediction_common)
        return policy, value
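
# A minimal usage sketch (hypothetical 64-dimensional latent state, default heads):
#
#     prediction_mlp = PredictionNetworkMLP(action_space_size=6, num_channels=64)
#     latent = torch.randn(8, 64)
#     policy, value = prediction_mlp(latent)  # -> torch.Size([8, 6]), torch.Size([8, 601])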