from collections import namedtuple
import torch
import torch.nn.functional as F
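
# ``ppg_data`` bundles one batch of joint-phase inputs for ``ppg_joint_error``;
# ``ppg_joint_loss`` carries the two scalar loss terms it returns.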

ppg_data = namedtuple('ppg_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'])
ppg_joint_loss = namedtuple('ppg_joint_loss', ['auxiliary_loss', 'behavioral_cloning_loss'])


def ppg_joint_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
) -> namedtuple:
    '''
    Overview:
        Compute the PPG joint-phase loss, which consists of the auxiliary value loss
        and the behavioral cloning (policy distillation) loss.
    Arguments:
        - data (:obj:`namedtuple`): ppg input data with fields shown in ``ppg_data``
        - clip_ratio (:obj:`float`): the clip range for value clipping
        - use_value_clip (:obj:`bool`): whether to use value clipping
    Returns:
        - ppg_joint_loss (:obj:`namedtuple`): the ppg loss items, both of which are differentiable 0-dim tensors
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B,)`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, 1)`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, 1)`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, 1)`
        - weight (:obj:`torch.FloatTensor`): :math:`(B,)` or :math:`(B, 1)`
        - auxiliary_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - behavioral_cloning_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppg_data(
        >>>     logit_new=torch.randn(3, action_dim),
        >>>     logit_old=torch.randn(3, action_dim),
        >>>     action=torch.randint(0, action_dim, (3,)),
        >>>     value_new=torch.randn(3, 1),
        >>>     value_old=torch.randn(3, 1),
        >>>     return_=torch.randn(3, 1),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = ppg_joint_error(data, clip_ratio=0.2, use_value_clip=True)
    '''
    logit_new, logit_old, action, value_new, value_old, return_, weight = data

    if weight is None:
        weight = torch.ones_like(return_)
    elif weight.dim() == 1:
        weight = weight.unsqueeze(-1)  # (B,) -> (B, 1), so it broadcasts with the (B, 1) values

    # auxiliary loss: PPO-style clipped value regression. The clipped estimate is
    # kept within clip_ratio of the old value, and taking the elementwise max of
    # the two squared errors gives a pessimistic objective.
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        auxiliary_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        auxiliary_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

    # behavioral cloning loss: KL(pi_old || pi_new) over the full action
    # distribution, averaged over the batch, which keeps the policy close to its
    # behavior before the auxiliary phase. F.kl_div expects log-probabilities for
    # the input and, with log_target=True, for the target as well.
    logp_new = F.log_softmax(logit_new, dim=-1)
    logp_old = F.log_softmax(logit_old, dim=-1)
    behavioral_cloning_loss = F.kl_div(logp_new, logp_old, reduction='batchmean', log_target=True)

    return ppg_joint_loss(auxiliary_loss, behavioral_cloning_loss)
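

# A minimal usage sketch (illustrative only, not part of the module interface):
# build a random batch, compute the joint loss, and combine the two terms with
# an example behavioral cloning weight of 1.0 before backpropagating. The batch
# size, action dim, and the 1.0 weight are arbitrary assumptions for the demo.
if __name__ == '__main__':
    B, N = 8, 4
    data = ppg_data(
        logit_new=torch.randn(B, N, requires_grad=True),
        logit_old=torch.randn(B, N),
        action=torch.randint(0, N, (B, )),
        value_new=torch.randn(B, 1, requires_grad=True),
        value_old=torch.randn(B, 1),
        return_=torch.randn(B, 1),
        weight=None,
    )
    loss = ppg_joint_error(data)
    total_loss = loss.auxiliary_loss + 1.0 * loss.behavioral_cloning_loss
    total_loss.backward()
    assert loss.auxiliary_loss.shape == () and loss.behavioral_cloning_loss.shape == ()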