from copy import deepcopy
from typing import Union, Optional

import torch
import torch.nn as nn
import treetensor.torch as ttorch
from easydict import EasyDict

from ding.utils import SequenceType, squeeze
from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
    FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
from ding.torch_utils import MLP, fc_block


class DiscretePolicyHead(nn.Module):
    """
    Overview:
        Policy head for discrete action spaces: a multi-layer MLP followed by a final
        fully-connected block that outputs unnormalized action logits.
    """

    def __init__(
        self,
        hidden_size: int,
        output_size: int,
        layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
    ) -> None:
        super(DiscretePolicyHead, self).__init__()
        self.main = nn.Sequential(
            MLP(
                hidden_size,
                hidden_size,
                hidden_size,
                layer_num,
                layer_fn=nn.Linear,
                activation=activation,
                norm_type=norm_type
            ), fc_block(hidden_size, output_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, hidden_size) feature -> (B, output_size) logit
        return self.main(x)
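
# A minimal usage sketch for DiscretePolicyHead (the concrete sizes below are
# illustrative assumptions, not values fixed by this module):
#
#     head = DiscretePolicyHead(hidden_size=64, output_size=6, layer_num=1)
#     logit = head(torch.randn(4, 64))  # -> shape (4, 6), unnormalized logits
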
class PPOFModel(nn.Module):
    """
    Overview:
        Actor-critic model used by the PPOF policy. It supports discrete, continuous
        and hybrid action spaces, an optionally shared observation encoder, and an
        optional PopArt value head.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
        self,
        obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType, EasyDict],
        action_space: str = 'discrete',
        share_encoder: bool = True,
        encoder_hidden_size_list: SequenceType = [128, 128, 64],
        actor_head_hidden_size: int = 64,
        actor_head_layer_num: int = 1,
        critic_head_hidden_size: int = 64,
        critic_head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        sigma_type: Optional[str] = 'independent',
        fixed_sigma_value: Optional[float] = 0.3,
        bound_type: Optional[str] = None,
        encoder: Optional[torch.nn.Module] = None,
        popart_head: bool = False,
    ) -> None:
        super(PPOFModel, self).__init__()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.obs_shape, self.action_shape = obs_shape, action_shape
        self.share_encoder = share_encoder

        def new_encoder(outsize: int) -> nn.Module:
            # Build a default encoder for the given observation shape: an MLP encoder
            # for 1D observations, a convolutional encoder for 3D (image) observations.
            if isinstance(obs_shape, int) or len(obs_shape) == 1:
                return FCEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            elif len(obs_shape) == 3:
                return ConvEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                raise RuntimeError(
                    "unsupported obs_shape for the pre-defined encoders: {}, please customize your own encoder".
                    format(obs_shape)
                )

        # Observation encoder: either a single encoder shared by actor and critic,
        # or two independent encoders (a deep copy is made for a user-provided one).
        if self.share_encoder:
            assert actor_head_hidden_size == critic_head_hidden_size, \
                "when the encoder is shared, the actor and critic heads must have the same input size."
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.encoder = encoder
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.encoder = new_encoder(actor_head_hidden_size)
        else:
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.actor_encoder = encoder
                    self.critic_encoder = deepcopy(encoder)
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.actor_encoder = new_encoder(actor_head_hidden_size)
                self.critic_encoder = new_encoder(critic_head_hidden_size)

        # Value head: a plain regression head, or a PopArt head that adaptively
        # normalizes the value targets.
        if not popart_head:
            self.critic_head = RegressionHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            self.critic_head = PopArtVHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )

        # Policy head(s), built according to the action space type.
        self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
        if self.action_space == 'continuous':
            self.multi_head = False
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )
        elif self.action_space == 'discrete':
            actor_head_cls = DiscretePolicyHead
            multi_head = not isinstance(action_shape, int)
            self.multi_head = multi_head
            if multi_head:
                self.actor_head = MultiHead(
                    actor_head_cls,
                    actor_head_hidden_size,
                    action_shape,
                    layer_num=actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                self.actor_head = actor_head_cls(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
        elif self.action_space == 'hybrid':
            # Hybrid action space: a discrete action-type head plus a continuous
            # action-arguments head.
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
            )
            actor_action_type = DiscretePolicyHead(
                actor_head_hidden_size,
                action_shape.action_type_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
            )
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

        # Group encoder and head into actor/critic module lists so that calls such as
        # self.actor.parameters() and self.critic.parameters() work as expected.
        # With a shared encoder, both lists reference the same encoder module.
        if self.share_encoder:
            self.actor = [self.encoder, self.actor_head]
            self.critic = [self.encoder, self.critic_head]
        else:
            self.actor = [self.actor_encoder, self.actor_head]
            self.critic = [self.critic_encoder, self.critic_head]
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
        # Dispatch to one of the forward modes listed in ``self.mode``.
        assert mode in self.mode, "unsupported forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
        # Encode the observation with the shared or actor-specific encoder.
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.actor_encoder(x)

        if self.action_space == 'discrete':
            # Unnormalized action logits.
            return self.actor_head(x)
        elif self.action_space == 'continuous':
            # Gaussian distribution parameters (mu/sigma) from ReparameterizationHead.
            x = self.actor_head(x)
            return ttorch.as_tensor(x)
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](x)
            action_args = self.actor_head[1](x)
            return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})

    def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.critic_encoder(x)
        # The critic head returns a dict-like output whose 'pred' field is the value estimate.
        x = self.critic_head(x)
        return x

    def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        # Run both branches in one call; with a shared encoder the observation is encoded once.
        if self.share_encoder:
            actor_embedding = critic_embedding = self.encoder(x)
        else:
            actor_embedding = self.actor_encoder(x)
            critic_embedding = self.critic_encoder(x)

        value = self.critic_head(critic_embedding)

        if self.action_space == 'discrete':
            logit = self.actor_head(actor_embedding)
            return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
        elif self.action_space == 'continuous':
            x = self.actor_head(actor_embedding)
            return ttorch.as_tensor({'logit': x, 'value': value['pred']})
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](actor_embedding)
            action_args = self.actor_head[1](actor_embedding)
            return ttorch.as_tensor(
                {
                    'logit': {
                        'action_type': action_type,
                        'action_args': action_args
                    },
                    'value': value['pred']
                }
            )
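

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API): the observation and
    # action sizes below are illustrative assumptions, not fixed by this module.
    model = PPOFModel(obs_shape=4, action_shape=3, action_space='discrete')
    obs = torch.randn(2, 4)  # batch of 2 flat observations
    out = model(obs, mode='compute_actor_critic')
    # out.logit holds per-action logits and out.value the per-sample value prediction.
    print(out.logit.shape, out.value.shape)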