|
from typing import Union, Dict, Optional |
|
import torch |
|
import torch.nn as nn |
|
|
|
from ding.torch_utils import get_lstm |
|
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY |
|
from ding.model.template.q_learning import parallel_wrapper |
|
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, \ |
|
FCEncoder, ConvEncoder |
|
|
|
|
|
class RNNLayer(nn.Module):
    """
    Overview:
        Thin wrapper around the recurrent module returned by ``get_lstm`` that can optionally add a
        residual connection from its input to its output.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, lstm_type, input_size, hidden_size, res_link: bool = False):
        """
        Arguments:
            - lstm_type: Recurrent cell type, passed to ``get_lstm`` unchanged.
            - input_size: Feature size of the input, passed to ``get_lstm``.
            - hidden_size: Hidden size of the recurrent cell, passed to ``get_lstm``.
            - res_link (:obj:`bool`): Whether ``forward`` adds its input to its output.
        """
        super().__init__()
        self.rnn = get_lstm(lstm_type, input_size=input_size, hidden_size=hidden_size)
        self.res_link = res_link

    def forward(self, x, prev_state, inference: bool = False):
        """
        Overview:
            Run the recurrent module on ``x``.
            With ``inference=True`` the input is treated as a single step: a leading length-1 dimension is
            added before the RNN call and removed afterwards.
            Otherwise the leading dimension of ``x`` is iterated one step at a time, threading the state
            through the steps and collecting the ``'h'`` entry of the state after every step.
            With ``res_link=True`` the original input is added to the output in both cases.
        Arguments:
            - x: Input embedding.
            - prev_state: Recurrent state handed to the first RNN call.
            - inference (:obj:`bool`): Single-step mode switch.
        Returns:
            - ret (:obj:`dict`): ``'output'`` and ``'next_state'``; additionally ``'hidden_state'`` when
              ``inference`` is False.
        """
        use_residual = self.res_link
        # Keep a handle on the untouched input for the residual addition.
        skip = x if use_residual else None

        if inference:
            # Give the RNN a sequence dimension of length 1, then strip it again for the head.
            step_out, next_state = self.rnn(x.unsqueeze(0), prev_state)
            step_out = step_out.squeeze(0)
            if use_residual:
                step_out = step_out + skip
            return {'output': step_out, 'next_state': next_state}

        state = prev_state
        step_outputs = []
        step_hiddens = []
        for t in range(x.shape[0]):
            # ``narrow`` keeps the leading dimension (length 1), exactly like ``x[t:t + 1]``.
            frame_out, state = self.rnn(x.narrow(0, t, 1), state)
            step_outputs.append(frame_out)
            # Only the 'h' part of every state entry is kept, joined along dim 1.
            step_hiddens.append(torch.cat([entry['h'] for entry in state], dim=1))

        sequence_out = torch.cat(step_outputs, dim=0)
        if use_residual:
            sequence_out = sequence_out + skip
        return {
            'output': sequence_out,
            'next_state': state,
            'hidden_state': torch.cat(step_hiddens, dim=0),
        }
|
|
|
|
|
@MODEL_REGISTRY.register('havac')
class HAVAC(nn.Module):
    """
    Overview:
        The HAVAC model for HAPPO: a container that owns one independent ``HAVACAgent`` network per agent
        and dispatches each forward call to the network of the requested agent.
    Interfaces:
        ``__init__``, ``forward``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            agent_num: int,
            use_lstm: bool = False,
            lstm_type: str = 'gru',
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 2,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            action_space: str = 'discrete',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'happo',
            bound_type: Optional[str] = None,
            res_link: bool = False,
    ) -> None:
        r"""
        Overview:
            Init the VAC Model for HAPPO according to arguments. ``agent_num`` ``HAVACAgent`` networks are
            created, and every configuration argument is forwarded to each of them.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for global agent
            - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
            - agent_num (:obj:`int`): Number of agents.
            - use_lstm (:obj:`bool`): Whether each agent inserts an RNN layer between encoder and head.
            - lstm_type (:obj:`str`): use lstm or gru, default to gru
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
            - actor_head_layer_num (:obj:`int`):
                The num of layers used in the network to compute Q value output for actor's nn.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
            - critic_head_layer_num (:obj:`int`):
                The num of layers used in the network to compute Q value output for critic's nn.
            - action_space (:obj:`str`): Action space type, forwarded to every ``HAVACAgent``.
            - activation (:obj:`Optional[nn.Module]`):
                The activation module forwarded to every ``HAVACAgent``, default to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`):
                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details`
            - sigma_type (:obj:`Optional[str]`): Sigma type forwarded to every ``HAVACAgent``. The default
                is ``'happo'``, i.e. the default of ``HAVACAgent``, which is the value the agents were
                effectively built with while this argument was not forwarded.
            - bound_type (:obj:`Optional[str]`): Bound type forwarded to every ``HAVACAgent``.
            - res_link (:obj:`bool`): use the residual link or not, default to False
        """
        super(HAVAC, self).__init__()
        self.agent_num = agent_num
        # Forward the complete configuration: passing only a subset would silently build every agent
        # with HAVACAgent's defaults regardless of what the caller asked for.
        # NOTE: the same ``activation`` instance is handed to every agent, which is harmless for
        # parameter-free activations such as the default ``nn.ReLU()``.
        self.agent_models = nn.ModuleList(
            [
                HAVACAgent(
                    agent_obs_shape=agent_obs_shape,
                    global_obs_shape=global_obs_shape,
                    action_shape=action_shape,
                    use_lstm=use_lstm,
                    lstm_type=lstm_type,
                    encoder_hidden_size_list=encoder_hidden_size_list,
                    actor_head_hidden_size=actor_head_hidden_size,
                    actor_head_layer_num=actor_head_layer_num,
                    critic_head_hidden_size=critic_head_hidden_size,
                    critic_head_layer_num=critic_head_layer_num,
                    action_space=action_space,
                    activation=activation,
                    norm_type=norm_type,
                    sigma_type=sigma_type,
                    bound_type=bound_type,
                    res_link=res_link,
                ) for _ in range(agent_num)
            ]
        )

    def forward(self, agent_idx, input_data, mode):
        """
        Overview:
            Run the network of one agent.
        Arguments:
            - agent_idx: Index into ``self.agent_models`` selecting the agent network.
            - input_data: Input passed unchanged to the selected agent network.
            - mode: Forward mode passed unchanged to the selected agent network.
        Returns:
            - output: Whatever the selected agent network returns for ``(input_data, mode)``.
        """
        selected_agent_model = self.agent_models[agent_idx]
        output = selected_agent_model(input_data, mode)
        return output
|
|
|
|
|
class HAVACAgent(nn.Module):
    """
    Overview:
        The HAVAC model of each agent for HAPPO. The actor branch reads ``obs['agent_state']`` and the
        critic branch reads ``obs['global_state']``; each branch consists of its own encoder, an optional
        ``RNNLayer`` and its own head.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            use_lstm: bool = False,
            lstm_type: str = 'gru',
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 2,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            action_space: str = 'discrete',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'happo',
            bound_type: Optional[str] = None,
            res_link: bool = False,
    ) -> None:
        r"""
        Overview:
            Init the VAC Model for HAPPO according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for global agent
            - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
            - use_lstm (:obj:`bool`): Whether to insert an ``RNNLayer`` between each encoder and its head.
            - lstm_type (:obj:`str`): use lstm or gru, default to gru
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
            - actor_head_layer_num (:obj:`int`):
                The num of layers used in the network to compute Q value output for actor's nn.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
            - critic_head_layer_num (:obj:`int`):
                The num of layers used in the network to compute Q value output for critic's nn.
            - action_space (:obj:`str`): ``'discrete'`` (actor head is a ``DiscreteHead``) or
                ``'continuous'`` (actor head is a ``ReparameterizationHead``).
            - activation (:obj:`Optional[nn.Module]`):
                The activation module passed to the encoders and heads, default to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`):
                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details`
            - sigma_type (:obj:`Optional[str]`): Passed to ``ReparameterizationHead`` (continuous only).
            - bound_type (:obj:`Optional[str]`): Passed to ``ReparameterizationHead`` (continuous only).
            - res_link (:obj:`bool`): use the residual link or not, default to False
        """
        super(HAVACAgent, self).__init__()
        agent_obs_shape = squeeze(agent_obs_shape)
        global_obs_shape = squeeze(global_obs_shape)
        action_shape = squeeze(action_shape)
        self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
        self.action_space = action_space

        # Actor and critic see different observations, so their encoder types are chosen independently.
        actor_encoder_cls = self._select_encoder_cls(agent_obs_shape)
        critic_encoder_cls = self._select_encoder_cls(global_obs_shape)

        self.actor_encoder = actor_encoder_cls(
            obs_shape=agent_obs_shape,
            hidden_size_list=encoder_hidden_size_list,
            activation=activation,
            norm_type=norm_type
        )
        self.critic_encoder = critic_encoder_cls(
            obs_shape=global_obs_shape,
            hidden_size_list=encoder_hidden_size_list,
            activation=activation,
            norm_type=norm_type
        )

        self.use_lstm = use_lstm
        if self.use_lstm:
            # The RNN maps the last encoder width to the width expected by the corresponding head.
            self.actor_rnn = RNNLayer(
                lstm_type,
                input_size=encoder_hidden_size_list[-1],
                hidden_size=actor_head_hidden_size,
                res_link=res_link
            )
            self.critic_rnn = RNNLayer(
                lstm_type,
                input_size=encoder_hidden_size_list[-1],
                hidden_size=critic_head_hidden_size,
                res_link=res_link
            )

        self.critic_head = RegressionHead(
            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
        )
        assert self.action_space in ['discrete', 'continuous'], self.action_space
        if self.action_space == 'discrete':
            self.actor_head = DiscreteHead(
                actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
            )
        elif self.action_space == 'continuous':
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )

        # Group the modules of each branch so that ``self.actor`` / ``self.critic`` expose them together.
        if self.use_lstm:
            actor_modules = [self.actor_encoder, self.actor_rnn, self.actor_head]
            critic_modules = [self.critic_encoder, self.critic_rnn, self.critic_head]
        else:
            actor_modules = [self.actor_encoder, self.actor_head]
            critic_modules = [self.critic_encoder, self.critic_head]
        self.actor = nn.ModuleList(actor_modules)
        self.critic = nn.ModuleList(critic_modules)

    @staticmethod
    def _select_encoder_cls(obs_shape):
        # Pick the pre-defined encoder class for an (already squeezed) observation shape:
        # an int or 1-element shape uses FCEncoder, a 3-element shape uses ConvEncoder.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            return FCEncoder
        elif len(obs_shape) == 3:
            return ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own VAC".format(obs_shape)
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        r"""
        Overview:
            Dispatch to the method named by ``mode`` (``compute_actor``, ``compute_critic`` or
            ``compute_actor_critic``) with ``inputs`` as its only argument, so the selected method runs
            with its default ``inference=False``.
        Arguments:
            - inputs (:obj:`Dict`): Input data dict, see the selected method for the keys it reads.
            - mode (:obj:`str`): Name of the forward mode, must be one of ``self.mode``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the selected method.

        Examples:
            >>> model = HAVACAgent(
                    agent_obs_shape=obs_dim,
                    global_obs_shape=global_obs_dim,
                    action_shape=action_dim,
                )
            >>> inputs = {
                    'obs': {
                        'agent_state': torch.randn(bs, obs_dim),
                        'global_state': torch.randn(bs, global_obs_dim),
                        'action_mask': torch.randint(0, 2, size=(bs, action_dim))
                    },
                }
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> assert set(outputs.keys()) == {'logit', 'value'}
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict, inference: bool = False) -> Dict:
        r"""
        Overview:
            Run the actor branch: encoder, optional RNN layer and head on ``inputs['obs']['agent_state']``.
            For the discrete action space, logits of actions whose mask entry equals 0 are overwritten
            with a large negative constant.
        Arguments:
            - inputs (:obj:`Dict`):
                input data dict with keys ['obs'(with keys ['agent_state', 'global_state', 'action_mask']),
                'actor_prev_state'(when you are using rnn)]. ``'action_mask'`` is only read for the
                discrete action space.
            - inference (:obj:`bool`): Only used with rnn. If True, the input is treated as a single step;
                otherwise the agent state must be a 3-dim or 5-dim tensor whose first dim is iterated
                step by step by the RNN layer.
        Returns:
            - outputs (:obj:`Dict`):
                Run with encoder RNN(optional) and head.

        ReturnsKeys:
            - logit: For the discrete action space, the masked ``'logit'`` entry of the head output;
                for the continuous action space, the whole head output.
            - next_state: RNN state after the step (only with rnn and ``inference=True``).
            - actor_next_state: RNN state after the last step (only with rnn and ``inference=False``).
            - actor_hidden_state: Hidden states collected by the RNN layer over all steps
                (only with rnn and ``inference=False``).

        Examples:
            >>> model = HAVACAgent(
                    agent_obs_shape=obs_dim,
                    global_obs_shape=global_obs_dim,
                    action_shape=action_dim,
                    use_lstm = True,
                )
            >>> inputs = {
                    'obs': {
                        'agent_state': torch.randn(T, bs, obs_dim),
                        'global_state': torch.randn(T, bs, global_obs_dim),
                        'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
                    },
                    'actor_prev_state': [None for _ in range(bs)],
                }
            >>> actor_outputs = model(inputs,'compute_actor')
            >>> assert actor_outputs['logit'].shape == (T, bs, action_dim)
        """
        x = inputs['obs']['agent_state']
        output = {}
        if self.use_lstm:
            rnn_actor_prev_state = inputs['actor_prev_state']
            if inference:
                x = self.actor_encoder(x)
                rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference)
                x = rnn_output['output']
                x = self.actor_head(x)
                output['next_state'] = rnn_output['next_state']
            else:
                assert len(x.shape) in [3, 5], x.shape
                # Encoder and head are wrapped with ``parallel_wrapper`` because the input carries a
                # leading time dim in addition to the batch dim.
                x = parallel_wrapper(self.actor_encoder)(x)
                rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference)
                x = rnn_output['output']
                x = parallel_wrapper(self.actor_head)(x)
                output['actor_next_state'] = rnn_output['next_state']
                output['actor_hidden_state'] = rnn_output['hidden_state']
        else:
            x = self.actor_encoder(x)
            x = self.actor_head(x)

        if self.action_space == 'discrete':
            action_mask = inputs['obs']['action_mask']
            logit = x['logit']
            # Push the logits of masked-out actions to a very large negative value (in place).
            logit[action_mask == 0.0] = -99999999
        elif self.action_space == 'continuous':
            logit = x
        output['logit'] = logit
        return output

    def compute_critic(self, inputs: Dict, inference: bool = False) -> Dict:
        r"""
        Overview:
            Run the critic branch: encoder, optional RNN layer and head on ``inputs['obs']['global_state']``.
        Arguments:
            - inputs (:obj:`Dict`):
                input data dict with keys ['obs'(with keys ['agent_state', 'global_state', 'action_mask']),
                'critic_prev_state'(when you are using rnn)]
            - inference (:obj:`bool`): Only used with rnn. If True, the input is treated as a single step;
                otherwise the global state must be a 3-dim or 5-dim tensor whose first dim is iterated
                step by step by the RNN layer.
        Returns:
            - outputs (:obj:`Dict`):
                Run with encoder [rnn] and head.

        ReturnsKeys:
            - value: The ``'pred'`` entry of the critic head output.
            - next_state: RNN state after the step (only with rnn and ``inference=True``).
            - critic_next_state: RNN state after the last step (only with rnn and ``inference=False``).
            - critic_hidden_state: Hidden states collected by the RNN layer over all steps
                (only with rnn and ``inference=False``).

        Examples:
            >>> model = HAVACAgent(
                    agent_obs_shape=obs_dim,
                    global_obs_shape=global_obs_dim,
                    action_shape=action_dim,
                    use_lstm = True,
                )
            >>> inputs = {
                    'obs': {
                        'agent_state': torch.randn(T, bs, obs_dim),
                        'global_state': torch.randn(T, bs, global_obs_dim),
                        'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
                    },
                    'critic_prev_state': [None for _ in range(bs)],
                }
            >>> critic_outputs = model(inputs,'compute_critic')
            >>> assert critic_outputs['value'].shape == (T, bs)
        """
        global_obs = inputs['obs']['global_state']
        output = {}
        if self.use_lstm:
            rnn_critic_prev_state = inputs['critic_prev_state']
            if inference:
                x = self.critic_encoder(global_obs)
                rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference)
                x = rnn_output['output']
                x = self.critic_head(x)
                output['next_state'] = rnn_output['next_state']
            else:
                assert len(global_obs.shape) in [3, 5], global_obs.shape
                # Same time-dim handling as in the actor branch.
                x = parallel_wrapper(self.critic_encoder)(global_obs)
                rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference)
                x = rnn_output['output']
                x = parallel_wrapper(self.critic_head)(x)
                output['critic_next_state'] = rnn_output['next_state']
                output['critic_hidden_state'] = rnn_output['hidden_state']
        else:
            x = self.critic_encoder(global_obs)
            x = self.critic_head(x)

        output['value'] = x['pred']
        return output

    def compute_actor_critic(self, inputs: Dict, inference: bool = False) -> Dict:
        r"""
        Overview:
            Run ``compute_actor`` and then ``compute_critic`` on the same ``inputs`` and merge their
            outputs into one dict. The two branches have separate encoders, so nothing is shared
            between the two calls.
        Arguments:
            - inputs (:dict): input data dict with keys
                ['obs'(with keys ['agent_state', 'global_state', 'action_mask']),
                'actor_prev_state', 'critic_prev_state'(when you are using rnn)]
            - inference (:obj:`bool`): Passed to both ``compute_actor`` and ``compute_critic``.
        Returns:
            - outputs (:obj:`Dict`):
                Combined output of the actor and critic branches.

        ReturnsKeys:
            - logit: ``'logit'`` of ``compute_actor``.
            - value: ``'value'`` of ``compute_critic``.
            - actor_next_state, critic_next_state: RNN states of the two branches (only with rnn).
            - actor_hidden_state, critic_hidden_state: Collected hidden states of the two branches
                (only with rnn and ``inference=False``).

        Examples:
            >>> model = HAVACAgent(
                    agent_obs_shape=obs_dim,
                    global_obs_shape=global_obs_dim,
                    action_shape=action_dim,
                )
            >>> inputs = {
                    'obs': {
                        'agent_state': torch.randn(bs, obs_dim),
                        'global_state': torch.randn(bs, global_obs_dim),
                        'action_mask': torch.randint(0, 2, size=(bs, action_dim))
                    },
                }
            >>> outputs = model(inputs,'compute_actor_critic')
            >>> assert set(outputs.keys()) == {'logit', 'value'}
        """
        actor_output = self.compute_actor(inputs, inference)
        critic_output = self.compute_critic(inputs, inference)
        outputs = {
            'logit': actor_output['logit'],
            'value': critic_output['value'],
        }
        if self.use_lstm:
            if inference:
                # In inference mode both branches store their state under the un-prefixed key
                # 'next_state' and produce no hidden-state entries; reading the prefixed keys here
                # would raise KeyError, so re-key them per branch instead.
                outputs['actor_next_state'] = actor_output['next_state']
                outputs['critic_next_state'] = critic_output['next_state']
            else:
                outputs['actor_next_state'] = actor_output['actor_next_state']
                outputs['actor_hidden_state'] = actor_output['actor_hidden_state']
                outputs['critic_next_state'] = critic_output['critic_next_state']
                outputs['critic_hidden_state'] = critic_output['critic_hidden_state']
        return outputs
|
|