|
from typing import Dict, Union |
|
import torch |
|
import torch.nn as nn |
|
|
|
from functools import reduce |
|
from ding.torch_utils import one_hot, MLP |
|
from ding.utils import squeeze, list_split, MODEL_REGISTRY, SequenceType |
|
from .q_learning import DRQN |
|
|
|
|
|
class COMAActorNetwork(nn.Module): |
|
""" |
|
Overview: |
|
Decentralized actor network in COMA algorithm. |
|
Interface: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
obs_shape: int, |
|
action_shape: int, |
|
hidden_size_list: SequenceType = [128, 128, 64], |
|
): |
|
""" |
|
Overview: |
|
Initialize COMA actor network |
|
Arguments: |
|
- obs_shape (:obj:`int`): the dimension of each agent's observation state |
|
- action_shape (:obj:`int`): the dimension of action shape |
|
- hidden_size_list (:obj:`list`): the list of hidden size, default to [128, 128, 64] |
|
""" |
|
super(COMAActorNetwork, self).__init__() |
|
self.main = DRQN(obs_shape, action_shape, hidden_size_list) |
|
|
|
def forward(self, inputs: Dict) -> Dict: |
|
""" |
|
Overview: |
|
The forward computation graph of COMA actor network |
|
Arguments: |
|
- inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state'] |
|
- agent_state (:obj:`torch.Tensor`): each agent local state(obs) |
|
- action_mask (:obj:`torch.Tensor`): the masked action |
|
- prev_state (:obj:`torch.Tensor`): the previous hidden state |
|
Returns: |
|
- output (:obj:`dict`): output data dict with keys ['logit', 'next_state', 'action_mask'] |
|
ArgumentsKeys: |
|
- necessary: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state`` |
|
ReturnsKeys: |
|
- necessary: ``logit``, ``next_state``, ``action_mask`` |
|
Examples: |
|
>>> T, B, A, N = 4, 8, 3, 32 |
|
>>> embedding_dim = 64 |
|
>>> action_dim = 6 |
|
>>> data = torch.randn(T, B, A, N) |
|
>>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim]) |
|
>>> prev_state = [[None for _ in range(A)] for _ in range(B)] |
|
>>> for t in range(T): |
|
>>> inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state} |
|
>>> outputs = model(inputs) |
|
>>> logit, prev_state = outputs['logit'], outputs['next_state'] |
|
""" |
|
agent_state = inputs['obs']['agent_state'] |
|
prev_state = inputs['prev_state'] |
|
if len(agent_state.shape) == 3: |
|
agent_state = agent_state.unsqueeze(0) |
|
unsqueeze_flag = True |
|
else: |
|
unsqueeze_flag = False |
|
T, B, A = agent_state.shape[:3] |
|
agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:]) |
|
prev_state = reduce(lambda x, y: x + y, prev_state) |
|
output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True}) |
|
logit, next_state = output['logit'], output['next_state'] |
|
next_state, _ = list_split(next_state, step=A) |
|
logit = logit.reshape(T, B, A, -1) |
|
if unsqueeze_flag: |
|
logit = logit.squeeze(0) |
|
return {'logit': logit, 'next_state': next_state, 'action_mask': inputs['obs']['action_mask']} |
|
|
|
|
|
class COMACriticNetwork(nn.Module): |
|
""" |
|
Overview: |
|
Centralized critic network in COMA algorithm. |
|
Interface: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_size: int, |
|
action_shape: int, |
|
hidden_size: int = 128, |
|
): |
|
""" |
|
Overview: |
|
initialize COMA critic network |
|
Arguments: |
|
- input_size (:obj:`int`): the size of input global observation |
|
- action_shape (:obj:`int`): the dimension of action shape |
|
- hidden_size_list (:obj:`list`): the list of hidden size, default to 128 |
|
Returns: |
|
- output (:obj:`dict`): output data dict with keys ['q_value'] |
|
Shapes: |
|
- obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)` |
|
- prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` |
|
- logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` |
|
- next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` |
|
- action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` |
|
""" |
|
super(COMACriticNetwork, self).__init__() |
|
self.action_shape = action_shape |
|
self.act = nn.ReLU() |
|
self.mlp = nn.Sequential( |
|
MLP(input_size, hidden_size, hidden_size, 2, activation=self.act), nn.Linear(hidden_size, action_shape) |
|
) |
|
|
|
def forward(self, data: Dict) -> Dict: |
|
""" |
|
Overview: |
|
forward computation graph of qmix network |
|
Arguments: |
|
- data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] |
|
- agent_state (:obj:`torch.Tensor`): each agent local state(obs) |
|
- global_state (:obj:`torch.Tensor`): global state(obs) |
|
- action (:obj:`torch.Tensor`): the masked action |
|
ArgumentsKeys: |
|
- necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``, ``prev_state`` |
|
ReturnsKeys: |
|
- necessary: ``q_value`` |
|
Examples: |
|
>>> agent_num, bs, T = 4, 3, 8 |
|
>>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 |
|
>>> coma_model = COMACriticNetwork( |
|
>>> obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim) |
|
>>> data = { |
|
>>> 'obs': { |
|
>>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), |
|
>>> 'global_state': torch.randn(T, bs, global_obs_dim), |
|
>>> }, |
|
>>> 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)), |
|
>>> } |
|
>>> output = coma_model(data) |
|
""" |
|
x = self._preprocess_data(data) |
|
q = self.mlp(x) |
|
return {'q_value': q} |
|
|
|
def _preprocess_data(self, data: Dict) -> torch.Tensor: |
|
""" |
|
Overview: |
|
preprocess data to make it can be used by MLP net |
|
Arguments: |
|
- data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] |
|
- agent_state (:obj:`torch.Tensor`): each agent local state(obs) |
|
- global_state (:obj:`torch.Tensor`): global state(obs) |
|
- action (:obj:`torch.Tensor`): the masked action |
|
ArgumentsKeys: |
|
- necessary: ``obs`` { ``agent_state``, ``global_state``} , ``action``, ``prev_state`` |
|
Return: |
|
- x (:obj:`torch.Tensor`): the data can be used by MLP net, including \ |
|
``global_state``, ``agent_state``, ``last_action``, ``action``, ``agent_id`` |
|
""" |
|
t_size, batch_size, agent_num = data['obs']['agent_state'].shape[:3] |
|
agent_state_ori, global_state = data['obs']['agent_state'], data['obs']['global_state'] |
|
|
|
|
|
agent_state = agent_state_ori[..., :-self.action_shape - agent_num] |
|
last_action = agent_state_ori[..., -self.action_shape - agent_num:-agent_num] |
|
last_action = last_action.reshape(t_size, batch_size, 1, -1).repeat(1, 1, agent_num, 1) |
|
agent_id = agent_state_ori[..., -agent_num:] |
|
|
|
action = one_hot(data['action'], self.action_shape) |
|
action = action.reshape(t_size, batch_size, -1, agent_num * self.action_shape).repeat(1, 1, agent_num, 1) |
|
action_mask = (1 - torch.eye(agent_num).to(action.device)) |
|
action_mask = action_mask.view(-1, 1).repeat(1, self.action_shape).view(agent_num, -1) |
|
action = (action_mask.unsqueeze(0).unsqueeze(0)) * action |
|
global_state = global_state.unsqueeze(2).repeat(1, 1, agent_num, 1) |
|
|
|
x = torch.cat([global_state, agent_state, last_action, action, agent_id], -1) |
|
return x |
|
|
|
|
|
@MODEL_REGISTRY.register('coma') |
|
class COMA(nn.Module): |
|
""" |
|
Overview: |
|
The network of COMA algorithm, which is QAC-type actor-critic. |
|
Interface: |
|
``__init__``, ``forward`` |
|
Properties: |
|
- mode (:obj:`list`): The list of forward mode, including ``compute_actor`` and ``compute_critic`` |
|
""" |
|
|
|
mode = ['compute_actor', 'compute_critic'] |
|
|
|
def __init__( |
|
self, agent_num: int, obs_shape: Dict, action_shape: Union[int, SequenceType], |
|
actor_hidden_size_list: SequenceType |
|
) -> None: |
|
""" |
|
Overview: |
|
initialize COMA network |
|
Arguments: |
|
- agent_num (:obj:`int`): the number of agent |
|
- obs_shape (:obj:`Dict`): the observation information, including agent_state and \ |
|
global_state |
|
- action_shape (:obj:`Union[int, SequenceType]`): the dimension of action shape |
|
- actor_hidden_size_list (:obj:`SequenceType`): the list of hidden size |
|
""" |
|
super(COMA, self).__init__() |
|
action_shape = squeeze(action_shape) |
|
actor_input_size = squeeze(obs_shape['agent_state']) |
|
critic_input_size = squeeze(obs_shape['agent_state']) + squeeze(obs_shape['global_state']) + \ |
|
agent_num * action_shape + (agent_num - 1) * action_shape |
|
critic_hidden_size = actor_hidden_size_list[-1] |
|
self.actor = COMAActorNetwork(actor_input_size, action_shape, actor_hidden_size_list) |
|
self.critic = COMACriticNetwork(critic_input_size, action_shape, critic_hidden_size) |
|
|
|
def forward(self, inputs: Dict, mode: str) -> Dict: |
|
""" |
|
Overview: |
|
forward computation graph of COMA network |
|
Arguments: |
|
- inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] |
|
- agent_state (:obj:`torch.Tensor`): each agent local state(obs) |
|
- global_state (:obj:`torch.Tensor`): global state(obs) |
|
- action (:obj:`torch.Tensor`): the masked action |
|
ArgumentsKeys: |
|
- necessary: ``obs`` { ``agent_state``, ``global_state``, ``action_mask`` }, ``action``, ``prev_state`` |
|
ReturnsKeys: |
|
- necessary: |
|
- compute_critic: ``q_value`` |
|
- compute_actor: ``logit``, ``next_state``, ``action_mask`` |
|
Shapes: |
|
- obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)` |
|
- prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` |
|
- logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` |
|
- next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` |
|
- action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` |
|
- q_value (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` |
|
Examples: |
|
>>> agent_num, bs, T = 4, 3, 8 |
|
>>> agent_num, bs, T = 4, 3, 8 |
|
>>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 |
|
>>> coma_model = COMA( |
|
>>> agent_num=agent_num, |
|
>>> obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )), |
|
>>> action_shape=action_dim, |
|
>>> actor_hidden_size_list=[128, 64], |
|
>>> ) |
|
>>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)] |
|
>>> data = { |
|
>>> 'obs': { |
|
>>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), |
|
>>> 'action_mask': None, |
|
>>> }, |
|
>>> 'prev_state': prev_state, |
|
>>> } |
|
>>> output = coma_model(data, mode='compute_actor') |
|
>>> data= { |
|
>>> 'obs': { |
|
>>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), |
|
>>> 'global_state': torch.randn(T, bs, global_obs_dim), |
|
>>> }, |
|
>>> 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)), |
|
>>> } |
|
>>> output = coma_model(data, mode='compute_critic') |
|
""" |
|
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) |
|
if mode == 'compute_actor': |
|
return self.actor(inputs) |
|
elif mode == 'compute_critic': |
|
return self.critic(inputs) |
|
|