from typing import Union, List
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from ding.utils import list_split, squeeze, MODEL_REGISTRY
from ding.torch_utils.network.nn_module import fc_block, MLP
from ding.torch_utils.network.transformer import ScaledDotProductAttention
from ding.torch_utils import to_tensor, tensor_to_list
from .q_learning import DRQN


@MODEL_REGISTRY.register('qtran')
class QTran(nn.Module):
"""
Overview:
QTRAN network
Interface:
__init__, forward
"""
def __init__(
self,
agent_num: int,
obs_shape: int,
global_obs_shape: int,
action_shape: int,
hidden_size_list: list,
embedding_size: int,
lstm_type: str = 'gru',
dueling: bool = False
) -> None:
"""
Overview:
            Initialize the QTRAN network.
        Arguments:
            - agent_num (:obj:`int`): the number of agents
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - global_obs_shape (:obj:`int`): the dimension of the global observation state
            - action_shape (:obj:`int`): the dimension of each agent's action space
            - hidden_size_list (:obj:`list`): the list of hidden sizes of the per-agent DRQN
            - embedding_size (:obj:`int`): the hidden dimension of the joint Q and V networks
            - lstm_type (:obj:`str`): which RNN cell the DRQN uses, 'lstm' or 'gru', default to 'gru'
            - dueling (:obj:`bool`): whether to use a dueling head, default to False.
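        Examples:
            >>> # An illustrative instantiation only; all sizes below are assumed values.
            >>> model = QTran(
            >>>     agent_num=4, obs_shape=32, global_obs_shape=64, action_shape=9,
            >>>     hidden_size_list=[128, 128, 64], embedding_size=64
            >>> )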
"""
super(QTran, self).__init__()
self._act = nn.ReLU()
self._q_network = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
q_input_size = global_obs_shape + hidden_size_list[-1] + action_shape
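        # Joint action-value Q, taking the global state and the summed state-action encoding as input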
self.Q = nn.Sequential(
nn.Linear(q_input_size, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size), nn.ReLU(),
nn.Linear(embedding_size, 1)
)
# V(s)
self.V = nn.Sequential(
nn.Linear(global_obs_shape, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size),
nn.ReLU(), nn.Linear(embedding_size, 1)
)
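        # Per-agent (hidden_state, one-hot action) encoder, whose outputs are summed across agents in forward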
ae_input = hidden_size_list[-1] + action_shape
        self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), nn.ReLU(), nn.Linear(ae_input, ae_input))

    def forward(self, data: dict, single_step: bool = True) -> dict:
"""
Overview:
            Forward computation graph of the QTRAN network, which outputs each agent's q_value, the joint \
                ``total_q`` and the state value ``vs``.
Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'], where 'obs' \
                contains 'agent_state', 'global_state' and 'action_mask'
                - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
                - global_state (:obj:`torch.Tensor`): global state (obs)
- prev_state (:obj:`list`): previous rnn state
- action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to\
calculate ``agent_q_act``
            - single_step (:obj:`bool`): whether to use single-step forward; if so, a timestep dim is added \
                before forward and removed afterwards
        Returns:
            - ret (:obj:`dict`): output data dict with keys \
                ['total_q', 'logit', 'agent_q_act', 'vs', 'next_state', 'action_mask']
                - total_q (:obj:`torch.Tensor`): total q_value, which is the result of the mixer network
                - agent_q (:obj:`torch.Tensor`): each agent's q_value, returned under the key ``logit``
                - agent_q_act (:obj:`torch.Tensor`): the q_value of each agent's executed (or argmax) action
                - vs (:obj:`torch.Tensor`): the state value of the global state
                - next_state (:obj:`list`): next rnn state
Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, \
                A is agent_num, N is obs_shape
- global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
            - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
- action (:obj:`torch.Tensor`): :math:`(T, B, A)`
            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`
            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
            - agent_q_act (:obj:`torch.Tensor`): :math:`(T, B, A)`
            - vs (:obj:`torch.Tensor`): :math:`(T, B)`
            - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
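        Examples:
            >>> # An illustrative single-step call; all sizes below are assumed values.
            >>> B, A, N, M, P = 3, 4, 32, 64, 9
            >>> model = QTran(A, N, M, P, [128, 128, 64], 64)
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, A, N),
            >>>         'global_state': torch.randn(B, M),
            >>>         'action_mask': torch.ones(B, A, P)
            >>>     },
            >>>     'prev_state': [[None for _ in range(A)] for _ in range(B)],
            >>>     'action': None
            >>> }
            >>> output = model(data, single_step=True)
            >>> assert output['total_q'].shape == (B, )
            >>> assert output['logit'].shape == (B, A, P)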
"""
agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
'prev_state']
action = data.get('action', None)
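        # For single-step forward, add a leading timestep dim (T=1) so the shared T-major path below can be reused.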
if single_step:
agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
T, B, A = agent_state.shape[:3]
assert len(prev_state) == B and all(
[len(p) == A for p in prev_state]
), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
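        # Flatten the nested (B, A) rnn states into a flat list of length B*A for the shared DRQN.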
prev_state = reduce(lambda x, y: x + y, prev_state)
agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
agent_q, next_state = output['logit'], output['next_state']
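        # Regroup the flat B*A next rnn states back into a list of length B whose elements have length A.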
next_state, _ = list_split(next_state, step=A)
agent_q = agent_q.reshape(T, B, A, -1)
if action is None:
# For target forward process
if len(data['obs']['action_mask'].shape) == 3:
action_mask = data['obs']['action_mask'].unsqueeze(0)
else:
action_mask = data['obs']['action_mask']
agent_q[action_mask == 0.0] = -9999999
action = agent_q.argmax(dim=-1)
agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
agent_q_act = agent_q_act.squeeze(-1) # T, B, A
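        # Build (hidden_state, one-hot action) pairs per agent; their encodings are summed over agents below,
        # which makes the joint-Q input invariant to agent ordering.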
hidden_states = output['hidden_state'].reshape(T * B, A, -1)
action = action.reshape(T * B, A).unsqueeze(-1)
action_onehot = torch.zeros(size=(T * B, A, agent_q.shape[-1]), device=action.device)
action_onehot = action_onehot.scatter(2, action, 1)
agent_state_action_input = torch.cat([hidden_states, action_onehot], dim=2)
agent_state_action_encoding = self.action_encoding(agent_state_action_input.reshape(T * B * A,
-1)).reshape(T * B, A, -1)
agent_state_action_encoding = agent_state_action_encoding.sum(dim=1) # Sum across agents
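        # Joint action-value Q: concatenate the global state with the summed state-action encoding.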
inputs = torch.cat([global_state.reshape(T * B, -1), agent_state_action_encoding], dim=1)
q_outputs = self.Q(inputs)
q_outputs = q_outputs.reshape(T, B)
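        # State value V(s), computed from the global state only.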
v_outputs = self.V(global_state.reshape(T * B, -1))
v_outputs = v_outputs.reshape(T, B)
if single_step:
q_outputs, agent_q, agent_q_act, v_outputs = q_outputs.squeeze(0), agent_q.squeeze(0), agent_q_act.squeeze(
0
), v_outputs.squeeze(0)
return {
'total_q': q_outputs,
'logit': agent_q,
'agent_q_act': agent_q_act,
'vs': v_outputs,
'next_state': next_state,
'action_mask': data['obs']['action_mask']
}