from typing import Union
import torch
from torch.distributions import Categorical, Independent, Normal


def compute_importance_weights(
    target_output: Union[torch.Tensor, dict],
    behaviour_output: Union[torch.Tensor, dict],
    action: torch.Tensor,
    action_space_type: str = 'discrete',
    requires_grad: bool = False
) -> torch.Tensor:
    """
    Overview:
        Compute the importance sampling weights (rhos) of the target policy relative to the \
        behaviour policy for the given actions.
    Arguments:
        - target_output (:obj:`Union[torch.Tensor, dict]`): the output of the current (target) \
            policy network for the given action; usually the network's output logits if the \
            action space is discrete, or a dict containing the parameters (``mu`` and ``sigma``) \
            of the action distribution if the action space is continuous.
        - behaviour_output (:obj:`Union[torch.Tensor, dict]`): the output of the behaviour \
            policy network for the given action, in the same format as ``target_output``.
        - action (:obj:`torch.Tensor`): the action chosen in the trajectory (an index for a \
            discrete action space), i.e. the behaviour action.
        - action_space_type (:obj:`str`): the type of the action space, one of \
            ['discrete', 'continuous'].
        - requires_grad (:obj:`bool`): whether gradient computation is required.
    Returns:
        - rhos (:obj:`torch.Tensor`): the importance sampling weights.
    Shapes:
        - target_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`, \
            where T is the timestep, B is the batch size and N is the action dim
        - behaviour_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(T, B)` for a discrete action space; \
            a :obj:`torch.FloatTensor` of shape :math:`(T, B, N)` for a continuous one
        - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> target_output = torch.randn(2, 3, 4)
        >>> behaviour_output = torch.randn(2, 3, 4)
        >>> action = torch.randint(0, 4, (2, 3))
        >>> rhos = compute_importance_weights(target_output, behaviour_output, action)
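        >>> # Continuous action space (a sketch): outputs are dicts with 'mu' and 'sigma'.
        >>> target_output = {'mu': torch.randn(2, 3, 4), 'sigma': torch.rand(2, 3, 4) + 0.1}
        >>> behaviour_output = {'mu': torch.randn(2, 3, 4), 'sigma': torch.rand(2, 3, 4) + 0.1}
        >>> action = torch.randn(2, 3, 4)
        >>> rhos = compute_importance_weights(target_output, behaviour_output, action, 'continuous')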
    """
    grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
    assert isinstance(action, torch.Tensor)
    assert action_space_type in ['discrete', 'continuous']

    with grad_context:
        if action_space_type == 'continuous':
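            # Diagonal Gaussian policy: Independent(..., 1) treats the last (action) dim as the
            # event dim, so log_prob sums over action dimensions and returns shape (T, B).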
            dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
            dist_behaviour = Independent(Normal(loc=behaviour_output['mu'], scale=behaviour_output['sigma']), 1)
            rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
            rhos = torch.exp(rhos)
            return rhos
        elif action_space_type == 'discrete':
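            # Categorical policy over the last dim of the logits; the ratio
            # rho = pi_target(a) / pi_behaviour(a) is computed in log space for numerical stability.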
            dist_target = Categorical(logits=target_output)
            dist_behaviour = Categorical(logits=behaviour_output)
            rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
            rhos = torch.exp(rhos)
            return rhos
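

# Minimal sanity check (an illustrative sketch, not part of the original module): when the
# target and behaviour outputs coincide, every importance weight must be exactly 1.
if __name__ == '__main__':
    # Discrete case: logits of shape (T=2, B=3, N=4), actions of shape (T, B).
    logits = torch.randn(2, 3, 4)
    act = torch.randint(0, 4, (2, 3))
    rhos = compute_importance_weights(logits, logits.clone(), act)
    assert rhos.shape == (2, 3) and torch.allclose(rhos, torch.ones(2, 3))
    # Continuous case: dict outputs with positive 'sigma', actions of shape (T, B, N).
    mu, sigma = torch.randn(2, 3, 4), torch.rand(2, 3, 4) + 0.1
    cont_rhos = compute_importance_weights(
        {'mu': mu, 'sigma': sigma}, {'mu': mu.clone(), 'sigma': sigma.clone()},
        torch.randn(2, 3, 4), action_space_type='continuous'
    )
    assert torch.allclose(cont_rhos, torch.ones(2, 3))
    print('importance-weight sanity checks passed')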