"""Loss and gradient helpers for ACER (actor-critic with experience replay)."""
from collections import namedtuple
from typing import Tuple, List

import torch
import torch.nn.functional as F

# Added to the importance ratio before dividing so that a zero ratio does not divide by zero.
EPS = 1e-8


def acer_policy_error(
        q_values: torch.Tensor,
        q_retraces: torch.Tensor,
        v_pred: torch.Tensor,
        target_logit: torch.Tensor,
        actions: torch.Tensor,
        ratio: torch.Tensor,
        c_clip_ratio: float = 10.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overview:
        Get ACER policy loss: the truncated importance-weighted term for the replayed action \
        and the bias correction term summed over all actions.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method)
        - v_pred (:obj:`torch.Tensor`): V values
        - target_logit (:obj:`torch.Tensor`): The new policy's log-probability \
            (``torch.exp(target_logit)`` is used as the new policy's probability)
        - actions (:obj:`torch.Tensor`): The actions in replay buffer
        - ratio (:obj:`torch.Tensor`): ratio of new policy with behavior policy
        - c_clip_ratio (:obj:`float`): clip value for ratio
    Returns:
        - actor_loss (:obj:`torch.Tensor`): policy loss from q_retrace
        - bc_loss (:obj:`torch.Tensor`): bias correction policy loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> v_pred = torch.randn(2, 3, 1)
        >>> target_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> ratio = torch.rand(2, 3, 4)
        >>> actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
    """
    # (T, B) -> (T, B, 1) so the indices can be used with ``gather`` on the last dim.
    actions = actions.unsqueeze(-1)

    # The advantages act as fixed weights: no gradient reaches q_values, q_retraces or v_pred from here.
    with torch.no_grad():
        advantage_retraces = q_retraces - v_pred  # shape T,B,1
        advantage_native = q_values - v_pred  # shape T,B,env_action_shape

    # Importance ratio of the taken action, truncated from above at c_clip_ratio.
    clipped_ratio = ratio.gather(-1, actions).clamp(max=c_clip_ratio)
    actor_loss = clipped_ratio * advantage_retraces * target_logit.gather(-1, actions)  # shape T,B,1

    # Bias correction term: only non-zero where ratio exceeds c_clip_ratio. The probability weight
    # exp(target_logit) is detached, so gradient flows only through the trailing target_logit factor.
    correction_weight = (1.0 - c_clip_ratio / (ratio + EPS)).clamp(min=0.0)
    bias_correction_loss = correction_weight * torch.exp(target_logit).detach() * \
        advantage_native * target_logit  # shape T,B,env_action_shape
    bias_correction_loss = bias_correction_loss.sum(-1, keepdim=True)
    return actor_loss, bias_correction_loss


def acer_value_error(q_values: torch.Tensor, q_retraces: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Get ACER critic loss: half the squared difference between the retrace target and the \
        Q value of the replayed action.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method)
        - actions (:obj:`torch.Tensor`): The actions in replay buffer
    Returns:
        - critic_loss (:obj:`torch.Tensor`): critic loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> loss = acer_value_error(q_values, q_retraces, actions)
    """
    # (T, B) -> (T, B, 1) so the indices can be used with ``gather`` on the last dim.
    actions = actions.unsqueeze(-1)
    critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
    return critic_loss


def acer_trust_region_update(
        actor_gradients: List[torch.Tensor], target_logit: torch.Tensor, avg_logit: torch.Tensor,
        trust_region_value: float
) -> List[torch.Tensor]:
    """
    Overview:
        Calculate gradient with trust region constraint. Only the first element of \
        ``actor_gradients`` is processed, so the returned list always has one element.
    Arguments:
        - actor_gradients (:obj:`list(torch.Tensor)`): gradients value's for different part
        - target_logit (:obj:`torch.Tensor`): The new policy's log-probability (currently not used \
            by the computation)
        - avg_logit (:obj:`torch.Tensor`): The average policy's log-probability \
            (``torch.exp(avg_logit)`` is used as the constraint direction)
        - trust_region_value (:obj:`float`): the range of trust region
    Returns:
        - update_gradients (:obj:`list(torch.Tensor)`): gradients with trust region constraint
    Shapes:
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
    Examples:
        >>> actor_gradients = [torch.randn(2, 3, 4)]
        >>> target_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> avg_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> update_gradients = acer_trust_region_update(actor_gradients, target_logit, avg_logit, 0.1)
    """
    # The constraint direction is treated as a constant: no gradient flows into avg_logit.
    with torch.no_grad():
        kl_gradient = torch.exp(avg_logit)

    # TODO: only the first element of the list is used; more elements may be supported in the future.
    actor_gradient = actor_gradients[0]

    # Amount by which the gradient's component along kl_gradient exceeds the trust region,
    # normalised by the squared norm of kl_gradient and floored at zero (no change when within it).
    excess = actor_gradient.mul(kl_gradient).sum(-1, keepdim=True) - trust_region_value
    squared_norm = kl_gradient.mul(kl_gradient).sum(-1, keepdim=True)
    scale = torch.div(excess, squared_norm).clamp(min=0.0)

    return [actor_gradient - scale * kl_gradient]