import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Independent, Normal
from collections import namedtuple
from .isw import compute_importance_weights
from ding.hpc_rl import hpc_wrapper


def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95):
    """
    Overview:
        Computation of the V-trace value targets ``v_s``. The targets are built with the backward recursion
        ``v_s - V(x_s) = delta_s + gamma * lambda_ * c_s * (v_{s+1} - V(x_{s+1}))``, where
        ``delta_s = rho_s * (r_s + gamma * V(x_{s+1}) - V(x_s))``.
    Arguments:
        - clipped_rhos (:obj:`torch.FloatTensor`): clipped importance weights (rho) that scale the TD errors.
        - clipped_cs (:obj:`torch.FloatTensor`): clipped importance weights (c) used as trace coefficients.
        - reward (:obj:`torch.FloatTensor`): reward of every timestep.
        - bootstrap_values (:obj:`torch.FloatTensor`): value estimates of every timestep plus one extra \
            bootstrap value for the state after the last timestep.
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99.
        - lambda_ (:obj:`float`): trace mixing factor multiplied onto every ``c``, defaults to 0.95.
    Returns:
        - vtrace_return (:obj:`torch.FloatTensor`): the V-trace value target of every timestep. This is \
            one entry per ``(t, b)``, not a 0-dim loss.
    Shapes:
        - clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - vtrace_return (:obj:`torch.FloatTensor`): :math:`(T, B)`
    .. note::
        No episode-termination (done) mask is applied here.
        NOTE(review): presumably callers ensure trajectories do not cross episode boundaries; confirm.
    """
    deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1])
    factor = gamma * lambda_
    # Start from V(x_s) and add the correction term v_s - V(x_s), accumulated back-to-front.
    result = bootstrap_values[:-1].clone()
    vtrace_item = 0.
    for t in reversed(range(reward.shape[0])):
        vtrace_item = deltas[t] + factor * clipped_cs[t] * vtrace_item
        result[t] += vtrace_item
    return result


def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma):
    """
    Overview:
        Computation of the V-trace policy-gradient advantage ``rho_pg_s * (r_s + gamma * v_{s+1} - V(x_s))``.
    Arguments:
        - clipped_pg_rhos (:obj:`torch.FloatTensor`): clipped importance weights (rho) for the policy gradient.
        - reward (:obj:`torch.FloatTensor`): reward of every timestep.
        - return_ (:obj:`torch.FloatTensor`): the V-trace target of the *next* timestep, ``v_{s+1}``.
        - bootstrap_values (:obj:`torch.FloatTensor`): value estimates ``V(x_s)`` of the current timestep.
        - gamma (:obj:`float`): the future discount factor.
    Returns:
        - vtrace_advantage (:obj:`torch.FloatTensor`): the advantage of every timestep. This is a plain \
            tensor, not a namedtuple.
    Shapes:
        - clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - return_ (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
    """
    return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values)


# Input bundle for the vtrace_error_* functions. ``weight`` may be None, meaning all-ones weights.
vtrace_data = namedtuple('vtrace_data', ['target_output', 'behaviour_output', 'action', 'value', 'reward', 'weight'])
# Output bundle: three 0-dim loss tensors.
vtrace_loss = namedtuple('vtrace_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def shape_fn_vtrace_discrete_action(args, kwargs):
    r"""
    Overview:
        Return the shape of ``target_output`` for the hpc implementation of vtrace.
        ``data`` is taken from the first positional argument, or from the ``data`` keyword when it is
        passed by keyword.
    Returns:
        shape: [T, B, N]
    """
    if len(args) <= 0:
        tmp = kwargs['data'].target_output.shape
    else:
        tmp = args[0].target_output.shape
    return tmp


def _vtrace_targets(IS, reward, value, gamma, lambda_, rho_clip_ratio, c_clip_ratio, rho_pg_clip_ratio):
    """
    Overview:
        Shared helper that turns importance weights into the V-trace value targets and policy-gradient
        advantages. Callers invoke it inside ``torch.no_grad()``.
    Returns:
        - return_ (:obj:`torch.FloatTensor`): :math:`(T, B)` value targets ``v_s``.
        - adv (:obj:`torch.FloatTensor`): :math:`(T, B)` policy-gradient advantages.
    """
    rhos = torch.clamp(IS, max=rho_clip_ratio)
    cs = torch.clamp(IS, max=c_clip_ratio)
    return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
    pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
    # v_{s+1}: shift the targets by one step and bootstrap the last one with V(x_T).
    return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
    adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)
    return return_, adv


def _vtrace_losses(dist_target, action, value, reward, weight, return_, adv):
    """
    Overview:
        Shared helper that builds the weighted policy, value and entropy losses from the target-policy
        distribution and the precomputed V-trace targets and advantages.
    Returns:
        - loss (:obj:`vtrace_loss`): namedtuple of three 0-dim tensors.
    """
    if weight is None:
        weight = torch.ones_like(reward)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)


@hpc_wrapper(
    shape_fn=shape_fn_vtrace_discrete_action,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3, 4, 5],
    include_kwargs=['data', 'gamma', 'lambda_', 'rho_clip_ratio', 'c_clip_ratio', 'rho_pg_clip_ratio']
)
def vtrace_error_discrete_action(
        data: namedtuple,
        gamma: float = 0.99,
        lambda_: float = 0.95,
        rho_clip_ratio: float = 1.0,
        c_clip_ratio: float = 1.0,
        rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of vtrace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner
        Architectures, arXiv:1802.01561) for discrete action spaces.
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`torch.Tensor`): the output of the current (target) policy network, \
                interpreted as categorical logits
            - behaviour_output (:obj:`torch.Tensor`): the output of the behaviour policy network that \
                produced the trajectory (collector), usually logits
            - action (:obj:`torch.Tensor`): the chosen action (index in the discrete action space) in the \
                trajectory, i.e. the behaviour action
            - value (:obj:`torch.Tensor`): value estimates including one extra bootstrap value
            - reward (:obj:`torch.Tensor`): reward of every timestep
            - weight (:obj:`torch.Tensor` or None): per-step loss weight; all ones when None
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating \
            the baseline targets (vs), defaults to 1.0
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating \
            the baseline targets (vs), defaults to 1.0
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when \
            calculating the policy gradient advantage, defaults to 1.0
    Returns:
        - trace_loss (:obj:`namedtuple`): ``vtrace_loss`` of 0-dim tensors. Targets and advantages are \
            computed under ``torch.no_grad()``, so gradients reach the policy only through \
            ``log_prob``/``entropy`` and reach the value only through ``value_loss``.
    Shapes:
        - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size \
            and N is action dim
        - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(T, B)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = torch.randn(T, B, N).requires_grad_(True)
        >>> behaviour_output = torch.randn(T, B, N)
        >>> action = torch.randint(0, N, size=(T, B))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        # NOTE(review): presumably pi_target(a) / pi_behaviour(a); defined in .isw.
        IS = compute_importance_weights(target_output, behaviour_output, action, 'discrete')
        return_, adv = _vtrace_targets(
            IS, reward, value, gamma, lambda_, rho_clip_ratio, c_clip_ratio, rho_pg_clip_ratio
        )
    dist_target = Categorical(logits=target_output)
    return _vtrace_losses(dist_target, action, value, reward, weight, return_, adv)


def vtrace_error_continuous_action(
        data: namedtuple,
        gamma: float = 0.99,
        lambda_: float = 0.95,
        rho_clip_ratio: float = 1.0,
        c_clip_ratio: float = 1.0,
        rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of vtrace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner
        Architectures, arXiv:1802.01561) for continuous action spaces. The target policy is a diagonal
        Gaussian ``Independent(Normal(mu, sigma), 1)``.
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`dict{key:torch.Tensor}`): the output of the current (target) policy \
                network; must contain keys ``'mu'`` and ``'sigma'`` (reparameterization parameters)
            - behaviour_output (:obj:`dict{key:torch.Tensor}`): the output of the behaviour policy network; \
                consumed by ``compute_importance_weights`` (presumably the same ``'mu'``/``'sigma'`` keys)
            - action (:obj:`torch.Tensor`): the continuous action sampled in the trajectory, i.e. the \
                behaviour action
            - value (:obj:`torch.Tensor`): value estimates including one extra bootstrap value
            - reward (:obj:`torch.Tensor`): reward of every timestep
            - weight (:obj:`torch.Tensor` or None): per-step loss weight; all ones when None
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating \
            the baseline targets (vs), defaults to 1.0
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating \
            the baseline targets (vs), defaults to 1.0
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when \
            calculating the policy gradient advantage, defaults to 1.0
    Returns:
        - trace_loss (:obj:`namedtuple`): ``vtrace_loss`` of 0-dim tensors. Targets and advantages are \
            computed under ``torch.no_grad()``.
    Shapes:
        - target_output (:obj:`dict{key:torch.FloatTensor}`): each value :math:`(T, B, N)`, where T is \
            timestep, B is batch size and N is action dim
        - behaviour_output (:obj:`dict{key:torch.FloatTensor}`): each value :math:`(T, B, N)`
        - action (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = dict(
        ...     mu=torch.randn(T, B, N).requires_grad_(True),
        ...     sigma=torch.exp(torch.randn(T, B, N).requires_grad_(True)),
        ... )
        >>> behaviour_output = dict(
        ...     mu=torch.randn(T, B, N),
        ...     sigma=torch.exp(torch.randn(T, B, N)),
        ... )
        >>> action = torch.randn((T, B, N))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        # NOTE(review): presumably pi_target(a) / pi_behaviour(a); defined in .isw.
        IS = compute_importance_weights(target_output, behaviour_output, action, 'continuous')
        return_, adv = _vtrace_targets(
            IS, reward, value, gamma, lambda_, rho_clip_ratio, c_clip_ratio, rho_pg_clip_ratio
        )
    # Independent(..., 1) sums log-probs/entropies over the last (action) dim, giving (T, B) per-step terms.
    dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
    return _vtrace_losses(dist_target, action, value, reward, weight, return_, adv)