from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal

a2c_data = namedtuple('a2c_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def a2c_error(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for discrete action space
    Arguments:
        - data (:obj:`namedtuple`): A2C input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the A2C loss item, all of which are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        >>>     logit=torch.randn(2, 3),
        >>>     action=torch.randint(0, 3, (2, )),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value)
    # Categorical policy distribution over discrete actions, parameterized by logits.
    dist = torch.distributions.Categorical(logits=logit)
    logp = dist.log_prob(action)
    # Per-sample weighted entropy bonus, policy-gradient loss and value regression loss,
    # each averaged over the batch.
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)


def a2c_error_continuous(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for continuous action space
    Arguments:
        - data (:obj:`namedtuple`): A2C input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the A2C loss item, all of which are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        >>>     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
        >>>     action=torch.randn(2, 3),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error_continuous(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value)

    # Diagonal Gaussian policy; ``Independent(..., 1)`` sums log-probs over the action dim.
    dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
    logp = dist.log_prob(action)
    # Per-sample weighted entropy bonus, policy-gradient loss and value regression loss,
    # each averaged over the batch.
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)
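

if __name__ == '__main__':
    # Minimal usage sketch, added for illustration (not part of the original module).
    # It shows how the three terms returned by ``a2c_error`` / ``a2c_error_continuous``
    # are typically combined into one scalar for backpropagation. The coefficients
    # ``value_weight`` and ``entropy_weight`` below are assumed example values, not
    # values prescribed by this file.
    value_weight, entropy_weight = 0.5, 0.01

    # Discrete action space: B=4 samples, N=3 actions.
    logit = torch.randn(4, 3, requires_grad=True)
    discrete_data = a2c_data(
        logit=logit,
        action=torch.randint(0, 3, (4, )),
        value=torch.randn(4, requires_grad=True),
        adv=torch.randn(4),
        return_=torch.randn(4),
        weight=None,
    )
    loss = a2c_error(discrete_data)
    total_loss = loss.policy_loss + value_weight * loss.value_loss - entropy_weight * loss.entropy_loss
    total_loss.backward()
    print('discrete total loss:', total_loss.item())

    # Continuous action space: B=4 samples, N=3 action dims.
    mu = torch.randn(4, 3, requires_grad=True)
    sigma = torch.rand(4, 3) + 0.1  # strictly positive std
    continuous_data = a2c_data(
        logit={'mu': mu, 'sigma': sigma},
        action=torch.randn(4, 3),
        value=torch.randn(4, requires_grad=True),
        adv=torch.randn(4),
        return_=torch.randn(4),
        weight=None,
    )
    loss = a2c_error_continuous(continuous_data)
    total_loss = loss.policy_loss + value_weight * loss.value_loss - entropy_weight * loss.entropy_loss
    total_loss.backward()
    print('continuous total loss:', total_loss.item())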