File size: 2,979 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections import namedtuple
import torch
import torch.nn.functional as F
from ding.rl_utils.td import generalized_lambda_returns

# Inputs consumed by ``coma_error``: policy logits, taken actions, critic Q-values, target critic Q-values,
# team rewards and a loss weight (which ``coma_error`` accepts as ``None``).
coma_data = namedtuple('coma_data', 'logit action q_value target_q_value reward weight')
# Outputs produced by ``coma_error``: the policy, critic (Q-value) and entropy loss terms.
coma_loss = namedtuple('coma_loss', 'policy_loss q_value_loss entropy_loss')


def coma_error(data: namedtuple, gamma: float, lambda_: float) -> namedtuple:
    """
    Overview:
        Implementation of COMA (Counterfactual Multi-Agent Policy Gradients). The critic ``q_value`` of the taken \
        action is regressed onto lambda-returns bootstrapped from ``target_q_value`` (the target is detached, so no \
        gradient flows into ``target_q_value``). Each agent's policy is trained with the counterfactual advantage \
        ``Q(a) - sum_a' pi(a') Q(a')``, computed from ``q_value`` and ``logit``.
    Arguments:
        - data (:obj:`namedtuple`): coma input data with fields shown in ``coma_data``
        - gamma (:obj:`float`): discount factor, forwarded to ``generalized_lambda_returns``
        - lambda_ (:obj:`float`): lambda mixing factor, forwarded to ``generalized_lambda_returns``
    Returns:
        - coma_loss (:obj:`namedtuple`): the coma loss item, all of them are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`, where T is the number of time steps (must be at \
            least 2, since the last step is only used as a bootstrap value), B is batch size, A is the agent num, \
            and N is action dim
        - action (:obj:`torch.LongTensor`): :math:`(T, B, A)`
        - q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
        - target_q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, shared by all agents
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(T, B, A)`; ``None`` means all-ones weights
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - q_value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> agent_num = 3
        >>> data = coma_data(
        ...     logit=torch.randn(2, 3, agent_num, action_dim),
        ...     action=torch.randint(0, action_dim, (2, 3, agent_num)),
        ...     q_value=torch.randn(2, 3, agent_num, action_dim),
        ...     target_q_value=torch.randn(2, 3, agent_num, action_dim),
        ...     reward=torch.randn(2, 3),
        ...     weight=torch.ones(2, 3, agent_num),
        ... )
        >>> loss = coma_error(data, 0.99, 0.99)
    """
    logit, action, q_value, target_q_value, reward, weight = data
    if weight is None:
        # Uniform weighting. Use the Q-value dtype so the default weight is a float tensor rather than an
        # integer copy of ``action``'s LongTensor dtype.
        weight = torch.ones_like(action, dtype=q_value.dtype)
    # Q-values of the actions actually taken: (T, B, A).
    q_taken = torch.gather(q_value, -1, index=action.unsqueeze(-1)).squeeze(-1)
    target_q_taken = torch.gather(target_q_value, -1, index=action.unsqueeze(-1)).squeeze(-1)
    T, B, A = target_q_taken.shape
    # Broadcast the (T, B) reward to every agent, then flatten the batch and agent dims so each
    # (batch, agent) pair is an independent column for the return computation.
    reward = reward.unsqueeze(-1).expand_as(target_q_taken).reshape(T, -1)
    target_q_taken = target_q_taken.reshape(T, -1)
    # T bootstrap values with T - 1 rewards give T - 1 returns; the final step only serves as a bootstrap value.
    # The TD target is detached so the critic loss never back-propagates into ``target_q_value``.
    return_ = generalized_lambda_returns(target_q_taken, reward[:-1], gamma, lambda_).detach()
    return_ = return_.reshape(T - 1, B, A)
    q_value_loss = (F.mse_loss(return_, q_taken[:-1], reduction='none') * weight[:-1]).mean()

    dist = torch.distributions.categorical.Categorical(logits=logit)
    logp = dist.log_prob(action)
    # Counterfactual baseline: expected Q-value under the agent's current policy, treated as a constant.
    baseline = (torch.softmax(logit, dim=-1) * q_value).sum(-1).detach()
    adv = (q_taken - baseline).detach()
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    return coma_loss(policy_loss, q_value_loss, entropy_loss)