|
from typing import Union |
|
import torch |
|
from torch.distributions import Categorical, Independent, Normal |
|
|
|
|
|
def compute_importance_weights(
        target_output: Union[torch.Tensor, dict],
        behaviour_output: Union[torch.Tensor, dict],
        action: torch.Tensor,
        action_space_type: str = 'discrete',
        requires_grad: bool = False
):
    """
    Overview:
        Compute the importance sampling weight of ``action``, i.e. the probability (density) ratio \
        ``pi_target(action) / pi_behaviour(action)``, from the outputs of the two policies.
    Arguments:
        - target_output (:obj:`Union[torch.Tensor,dict]`): the output of the current (target) policy network. \
            If the action space is discrete, it is passed as logits to a categorical distribution; \
            if the action space is continuous, it is a dict whose ``'mu'`` and ``'sigma'`` entries are used as \
            the mean and standard deviation of a diagonal Gaussian.
        - behaviour_output (:obj:`Union[torch.Tensor,dict]`): the output of the behaviour policy network, \
            in the same format as ``target_output``.
        - action (:obj:`torch.Tensor`): the chosen action (index for the discrete action space) in trajectory, \
            i.e.: behaviour_action
        - action_space_type (:obj:`str`): action space types in ['discrete', 'continuous']
        - requires_grad (:obj:`bool`): whether requires grad computation. If ``False`` the whole computation \
            runs under ``torch.no_grad()``, otherwise under ``torch.enable_grad()``.
    Returns:
        - rhos (:obj:`torch.Tensor`): Importance sampling weight
    Raises:
        - TypeError: if ``action`` is not a ``torch.Tensor``.
        - ValueError: if ``action_space_type`` is neither ``'discrete'`` nor ``'continuous'``.
    Shapes:
        - target_output (:obj:`Union[torch.FloatTensor,dict]`): :math:`(T, B, N)`, \
            where T is timestep, B is batch size and N is action dim \
            (in the continuous case this is the shape of ``'mu'`` and ``'sigma'``)
        - behaviour_output (:obj:`Union[torch.FloatTensor,dict]`): :math:`(T, B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(T, B)` for the discrete case; in the continuous case the \
            last dimension is treated as the event dimension, so the action is expected to be :math:`(T, B, N)`
        - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> target_output = torch.randn(2, 3, 4)
        >>> behaviour_output = torch.randn(2, 3, 4)
        >>> action = torch.randint(0, 4, (2, 3))
        >>> rhos = compute_importance_weights(target_output, behaviour_output, action)
    """
    # Validate with explicit exceptions instead of ``assert``: assertions are removed under ``python -O``,
    # in which case an unsupported ``action_space_type`` would match no branch and silently return ``None``.
    if not isinstance(action, torch.Tensor):
        raise TypeError("action must be a torch.Tensor, but got {}".format(type(action).__name__))
    if action_space_type not in ('discrete', 'continuous'):
        raise ValueError(
            "action_space_type must be 'discrete' or 'continuous', but got {!r}".format(action_space_type)
        )

    grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
    with grad_context:
        if action_space_type == 'continuous':
            # ``Independent(..., 1)`` sums the per-dimension Normal log-probs over the last dimension,
            # giving one joint log-density per action vector.
            dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
            dist_behaviour = Independent(Normal(loc=behaviour_output['mu'], scale=behaviour_output['sigma']), 1)
        else:
            dist_target = Categorical(logits=target_output)
            dist_behaviour = Categorical(logits=behaviour_output)
        # Form the ratio as a difference of log-probabilities and exponentiate once.
        log_rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
        return torch.exp(log_rhos)
|
|