from functools import reduce
import torch
import torch.nn as nn
from ding.utils import list_split, MODEL_REGISTRY
from .q_learning import DRQN


@MODEL_REGISTRY.register('qtran')
class QTran(nn.Module):
    """
    Overview:
        QTRAN network
    Interface:
        __init__, forward
    """

    def __init__(
            self,
            agent_num: int,
            obs_shape: int,
            global_obs_shape: int,
            action_shape: int,
            hidden_size_list: list,
            embedding_size: int,
            lstm_type: str = 'gru',
            dueling: bool = False
    ) -> None:
        """
        Overview:
            initialize QTRAN network
        Arguments:
            - agent_num (:obj:`int`): the number of agent
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - global_obs_shape (:obj:`int`): the dimension of global observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size_list (:obj:`list`): the list of hidden size
            - embedding_size (:obj:`int`): the dimension of embedding
            - lstm_type (:obj:`str`): use lstm or gru, default to gru
            - dueling (:obj:`bool`): use dueling head or not, default to False.
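        Examples:
            A minimal construction sketch; the sizes below are illustrative assumptions:

            >>> model = QTran(
            >>>     agent_num=4,
            >>>     obs_shape=32,
            >>>     global_obs_shape=64,
            >>>     action_shape=6,
            >>>     hidden_size_list=[128, 128, 64],
            >>>     embedding_size=64,
            >>> )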
        """
        super(QTran, self).__init__()
        self._act = nn.ReLU()
        self._q_network = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
        q_input_size = global_obs_shape + hidden_size_list[-1] + action_shape
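        # Q(s, u): joint action-value head over the global state and the summed agent action encodings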
        self.Q = nn.Sequential(
            nn.Linear(q_input_size, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size), nn.ReLU(),
            nn.Linear(embedding_size, 1)
        )

        # V(s): state-value head over the global state
        self.V = nn.Sequential(
            nn.Linear(global_obs_shape, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size),
            nn.ReLU(), nn.Linear(embedding_size, 1)
        )
        ae_input = hidden_size_list[-1] + action_shape
        self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), nn.ReLU(), nn.Linear(ae_input, ae_input))

    def forward(self, data: dict, single_step: bool = True) -> dict:
        """
        Overview:
            forward computation graph of qtran network
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
                - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
                - global_state (:obj:`torch.Tensor`): global state(obs)
                - prev_state (:obj:`list`): previous rnn state
                - action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to\
                    calculate ``agent_q_act``
            - single_step (:obj:`bool`): whether single_step forward, if so, add timestep dim before forward and\
                remove it after forward
        Return:
            - ret (:obj:`dict`): output data dict with keys ['total_q', 'logit', 'next_state']
                - total_q (:obj:`torch.Tensor`): total q_value, which is the result of mixer network
                - agent_q (:obj:`torch.Tensor`): each agent q_value
                - next_state (:obj:`list`): next rnn state
        Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size\
                A is agent_num, N is obs_shape
            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
            - prev_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A
            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`
            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`
            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
            - next_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A
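        Examples:
            A minimal usage sketch; the sizes and the None-initialized rnn states below are illustrative\
                assumptions, not values from any particular config:

            >>> B, A = 2, 4
            >>> model = QTran(A, obs_shape=32, global_obs_shape=64, action_shape=6,
            >>>               hidden_size_list=[128, 128, 64], embedding_size=64)
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, A, 32),
            >>>         'global_state': torch.randn(B, 64),
            >>>         'action_mask': torch.ones(B, A, 6),
            >>>     },
            >>>     'prev_state': [[None for _ in range(A)] for _ in range(B)],
            >>>     'action': None,
            >>> }
            >>> output = model.forward(data, single_step=True)
            >>> assert output['total_q'].shape == (B, )
            >>> assert output['logit'].shape == (B, A, 6)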
        """
        agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
            'prev_state']
        action = data.get('action', None)
        if single_step:
            agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
        T, B, A = agent_state.shape[:3]
        assert len(prev_state) == B and all(
            [len(p) == A for p in prev_state]
        ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
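        # Flatten the nested (B, A) prev_state list into a flat list of length B*A for the shared DRQN.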
        prev_state = reduce(lambda x, y: x + y, prev_state)
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
        agent_q, next_state = output['logit'], output['next_state']
        next_state, _ = list_split(next_state, step=A)
        agent_q = agent_q.reshape(T, B, A, -1)
        if action is None:
            # For the target forward process: mask out unavailable actions, then act greedily
            if len(data['obs']['action_mask'].shape) == 3:
                action_mask = data['obs']['action_mask'].unsqueeze(0)
            else:
                action_mask = data['obs']['action_mask']
            agent_q[action_mask == 0.0] = -9999999
            action = agent_q.argmax(dim=-1)
        agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
        agent_q_act = agent_q_act.squeeze(-1)  # T, B, A

        hidden_states = output['hidden_state'].reshape(T * B, A, -1)
        action = action.reshape(T * B, A).unsqueeze(-1)
        action_onehot = torch.zeros(size=(T * B, A, agent_q.shape[-1]), device=action.device)
        action_onehot = action_onehot.scatter(2, action, 1)
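        # Concatenate each agent's rnn hidden state with its one-hot action before encoding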
        agent_state_action_input = torch.cat([hidden_states, action_onehot], dim=2)
        agent_state_action_encoding = self.action_encoding(
            agent_state_action_input.reshape(T * B * A, -1)
        ).reshape(T * B, A, -1)
        agent_state_action_encoding = agent_state_action_encoding.sum(dim=1)  # Sum across agents

        inputs = torch.cat([global_state.reshape(T * B, -1), agent_state_action_encoding], dim=1)
        q_outputs = self.Q(inputs)
        q_outputs = q_outputs.reshape(T, B)
        v_outputs = self.V(global_state.reshape(T * B, -1))
        v_outputs = v_outputs.reshape(T, B)
        if single_step:
            q_outputs = q_outputs.squeeze(0)
            agent_q = agent_q.squeeze(0)
            agent_q_act = agent_q_act.squeeze(0)
            v_outputs = v_outputs.squeeze(0)
        return {
            'total_q': q_outputs,
            'logit': agent_q,
            'agent_q_act': agent_q_act,
            'vs': v_outputs,
            'next_state': next_state,
            'action_mask': data['obs']['action_mask']
        }