import torch
import torch.nn.functional as F
from collections import namedtuple
from ding.rl_utils.isw import compute_importance_weights


def compute_q_retraces(
        q_values: torch.Tensor,
        v_pred: torch.Tensor,
        rewards: torch.Tensor,
        actions: torch.Tensor,
        weights: torch.Tensor,
        ratio: torch.Tensor,
        gamma: float = 0.9
) -> torch.Tensor:
    """
    Shapes:
        - q_values (:obj:`torch.Tensor`): :math:`(T + 1, B, N)`, where T is unroll_len, B is batch size, N is discrete \
            action dim.
        - v_pred (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
        - rewards (:obj:`torch.Tensor`): :math:`(T, B)`
        - actions (:obj:`torch.Tensor`): :math:`(T, B)`
        - weights (:obj:`torch.Tensor`): :math:`(T, B)`
        - ratio (:obj:`torch.Tensor`): :math:`(T, B, N)`
        - q_retraces (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
    Examples:
        >>> T=2
        >>> B=3
        >>> N=4
        >>> q_values=torch.randn(T+1, B, N)
        >>> v_pred=torch.randn(T+1, B, 1)
        >>> rewards=torch.randn(T, B)
        >>> actions=torch.randint(0, N, (T, B))
        >>> weights=torch.ones(T, B)
        >>> ratio=torch.randn(T, B, N)
        >>> q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)

    .. note::
        The q_retrace operation does not need to compute gradients; it only executes the forward computation.
    """
    T = q_values.size()[0] - 1  # unroll length
    rewards = rewards.unsqueeze(-1)  # shape T,B,1
    actions = actions.unsqueeze(-1)  # shape T,B,1
    weights = weights.unsqueeze(-1)  # shape T,B,1
    q_retraces = torch.zeros_like(v_pred)  # shape (T+1),B,1
    tmp_retraces = v_pred[-1]  # shape B,1
    # bootstrap the final step with the value prediction
    q_retraces[-1] = v_pred[-1]

    q_gather = torch.zeros_like(v_pred)
    q_gather[0:-1] = q_values[0:-1].gather(-1, actions)  # shape (T+1),B,1
    ratio_gather = ratio.gather(-1, actions)  # shape T,B,1

    for idx in reversed(range(T)):
        # one-step target: r_t + gamma * weight_t * (retrace term carried back from step t+1)
        q_retraces[idx] = rewards[idx] + gamma * weights[idx] * tmp_retraces
        # retrace correction c_t * (q_ret_t - Q(s_t, a_t)) + V(s_t), with c_t = min(1, ratio) truncating
        # the importance weight of the taken action
        tmp_retraces = ratio_gather[idx].clamp(max=1.0) * (q_retraces[idx] - q_gather[idx]) + v_pred[idx]
    return q_retraces  # shape (T+1),B,1
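

# Minimal usage sketch (illustration only, not part of the library API): it shows one plausible
# way to build the `ratio` tensor, assumed here to be the per-action importance weight
# pi_target(a|s) / pi_behaviour(a|s). The `target_logit` / `behaviour_logit` tensors below are
# hypothetical placeholders; in practice they would come from the current and behaviour policies.
if __name__ == "__main__":
    T, B, N = 2, 3, 4
    q_values = torch.randn(T + 1, B, N)
    v_pred = torch.randn(T + 1, B, 1)
    rewards = torch.randn(T, B)
    actions = torch.randint(0, N, (T, B))
    weights = torch.ones(T, B)
    # hypothetical policy logits, used only to form a valid (non-negative) importance ratio
    target_logit = torch.randn(T, B, N)
    behaviour_logit = torch.randn(T, B, N)
    ratio = F.softmax(target_logit, dim=-1) / (F.softmax(behaviour_logit, dim=-1) + 1e-8)
    with torch.no_grad():  # retrace targets are used as fixed regression targets
        q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)
    assert q_retraces.shape == (T + 1, B, 1)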