zjowowen's picture
init space
079c32c
raw
history blame
3.03 kB
from typing import Union
import torch
from torch.distributions import Categorical, Independent, Normal
def compute_importance_weights(
        target_output: Union[torch.Tensor, dict],
        behaviour_output: Union[torch.Tensor, dict],
        action: torch.Tensor,
        action_space_type: str = 'discrete',
        requires_grad: bool = False
):
    """
    Overview:
        Compute the importance sampling ratio ``pi_target(action) / pi_behaviour(action)`` for the \
        actions stored in a trajectory, evaluated as ``exp(log_prob_target - log_prob_behaviour)``.
    Arguments:
        - target_output (:obj:`Union[torch.Tensor, dict]`): output of the current (target) policy. \
            For a discrete action space this is a tensor of logits; for a continuous action space it \
            is a dict with keys ``'mu'`` and ``'sigma'``, used as the loc and scale of a Normal.
        - behaviour_output (:obj:`Union[torch.Tensor, dict]`): output of the behaviour policy, in the \
            same format as ``target_output``.
        - action (:obj:`torch.Tensor`): the action recorded in the trajectory (i.e. the behaviour \
            action); an index tensor in the discrete case.
        - action_space_type (:obj:`str`): one of ``'discrete'`` or ``'continuous'``.
        - requires_grad (:obj:`bool`): if True the computation runs under ``torch.enable_grad()``, \
            otherwise under ``torch.no_grad()`` so no autograd graph is recorded for the result.
    Returns:
        - rhos (:obj:`torch.Tensor`): the importance sampling weights.
    Raises:
        - AssertionError: if ``action`` is not a ``torch.Tensor`` or ``action_space_type`` is not \
            one of the supported values (these checks are ``assert`` statements).
    Shapes:
        Discrete action space:
        - target_output / behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is \
            the number of timesteps, B the batch size and N the number of actions
        - action (:obj:`torch.LongTensor`): :math:`(T, B)`
        - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`
        Continuous action space:
        - target_output['mu'] / target_output['sigma'] (and the same for behaviour_output): \
            :math:`(T, B, N)`, where N is the action dimension
        - action (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`; the last dimension is treated as the \
            event dimension, so its per-dimension log-probabilities are summed
    Examples:
        >>> target_output = torch.randn(2, 3, 4)
        >>> behaviour_output = torch.randn(2, 3, 4)
        >>> action = torch.randint(0, 4, (2, 3))
        >>> rhos = compute_importance_weights(target_output, behaviour_output, action)
    """
    assert isinstance(action, torch.Tensor)
    assert action_space_type in ['discrete', 'continuous']
    # Gradients flow through the weights only when explicitly requested.
    with (torch.enable_grad() if requires_grad else torch.no_grad()):
        if action_space_type == 'discrete':
            target_dist = Categorical(logits=target_output)
            behaviour_dist = Categorical(logits=behaviour_output)
        elif action_space_type == 'continuous':
            # Independent(..., 1) reinterprets the last dim as the event dim, so log_prob sums over it.
            target_dist = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
            behaviour_dist = Independent(Normal(loc=behaviour_output['mu'], scale=behaviour_output['sigma']), 1)
        else:
            # Only reachable when the asserts above are stripped (``python -O``).
            return None
        # Take the difference in log space, then exponentiate once to get the probability ratio.
        log_rhos = target_dist.log_prob(action) - behaviour_dist.log_prob(action)
        return torch.exp(log_rhos)