from typing import Dict, Any, List
from functools import partial

import torch
from torch import Tensor
from torch import nn
from torch.distributions import Normal, Independent

from ding.torch_utils import to_device, fold_batch, unfold_batch, unsqueeze_repeat
from ding.utils import POLICY_REGISTRY
from ding.policy import SACPolicy
from ding.rl_utils import generalized_lambda_returns
from ding.policy.common_utils import default_preprocess_learn

from .utils import q_evaluation


@POLICY_REGISTRY.register('mbsac')
class MBSACPolicy(SACPolicy):
    """
    Overview:
        Model-based SAC with value expansion (arXiv: 1803.00101)
        and value gradient (arXiv: 1510.09142) w.r.t. the lambda-return.

        https://arxiv.org/pdf/1803.00101.pdf
        https://arxiv.org/pdf/1510.09142.pdf

    Config:
        == ===================== ======== ============= ==================================
        ID Symbol                Type     Default Value Description
        == ===================== ======== ============= ==================================
        1  ``learn.lambda_``     float    0.8           | Lambda for TD-lambda return.
        2  ``learn.grad_clip``   float    100.0         | Max norm of gradients.
        3  | ``learn.sample``    bool     True          | Whether to sample states or
           | ``_state``                                 | transitions from env buffer.
        == ===================== ======== ============= ==================================

    .. note::
        For other configs, please refer to ding.policy.sac.SACPolicy.
    """

    config = dict(
        learn=dict(
            # (float) Lambda for TD-lambda return.
            lambda_=0.8,
            # (float) Max norm of gradients.
            grad_clip=100,
            # (bool) Whether to sample states or transitions from environment buffer.
            sample_state=True,
        )
    )

    def _init_learn(self) -> None:
        super()._init_learn()
        self._target_model.requires_grad_(False)

        self._lambda = self._cfg.learn.lambda_
        self._grad_clip = self._cfg.learn.grad_clip
        self._sample_state = self._cfg.learn.sample_state
        self._auto_alpha = self._cfg.learn.auto_alpha
        # TODO: auto alpha
        assert not self._auto_alpha, "NotImplemented"

        # TODO: TanhTransform leads to NaN
        def actor_fn(obs: Tensor):
            # (mu, sigma) = self._learn_model.forward(
            #     obs, mode='compute_actor')['logit']
            # # enforce action bounds
            # dist = TransformedDistribution(
            #     Independent(Normal(mu, sigma), 1), [TanhTransform()])
            # action = dist.rsample()
            # log_prob = dist.log_prob(action)
            # return action, -self._alpha.detach() * log_prob
            (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            action = torch.tanh(pred)
            # tanh change-of-variables correction, using the numerically stable
            # identity log(1 - tanh(u)^2) = 2 * (log(2) - u - softplus(-2u))
            log_prob = dist.log_prob(pred) + \
                2 * (pred + torch.nn.functional.softplus(-2. * pred) - torch.log(torch.tensor(2.))).sum(-1)
            return action, -self._alpha.detach() * log_prob

        self._actor_fn = actor_fn

        def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
            eval_data = {'obs': obss, 'action': actions}
            q_values = model.forward(eval_data, mode='compute_critic')['q_value']
            return q_values

        self._critic_fn = critic_fn

        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
        # preprocess data
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )

        if self._cuda:
            data = to_device(data, self._device)

        if len(data['action'].shape) == 1:
            data['action'] = data['action'].unsqueeze(1)

        self._learn_model.train()
        self._target_model.train()
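
        # NOTE: the shapes below are assumptions inferred from how the tensors
        # are used later in this method: after the optional concatenation of
        # the real transition, ``obss``/``actions``/``aug_rewards`` cover T+1
        # imagined steps (time-major), ``rewards`` covers the T transitions
        # between them, and ``aug_rewards`` carries the entropy bonus
        # -alpha * log pi(a|s) returned by ``actor_fn``.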

        # TODO: use treetensor
        # rollout length is determined by world_model.rollout_length_scheduler
        if self._sample_state:
            # data['reward'], ... are not used
            obss, actions, rewards, aug_rewards, dones = \
                world_model.rollout(data['obs'], self._actor_fn, envstep)
        else:
            obss, actions, rewards, aug_rewards, dones = \
                world_model.rollout(data['next_obs'], self._actor_fn, envstep)
            obss = torch.cat([data['obs'].unsqueeze(0), obss])
            actions = torch.cat([data['action'].unsqueeze(0), actions])
            rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
            aug_rewards = torch.cat([torch.zeros_like(data['reward']).unsqueeze(0), aug_rewards])
            dones = torch.cat([data['done'].unsqueeze(0), dones])

        dones = torch.cat([torch.zeros_like(data['done']).unsqueeze(0), dones])

        # (T+1, B)
        target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
        if self._twin_critic:
            target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
        else:
            target_q_values = target_q_values + aug_rewards

        # (T, B)
        lambda_return = generalized_lambda_returns(target_q_values, rewards, self._gamma, self._lambda, dones[1:])

        # (T, B)
        # If S_t terminates, we should not consider loss from t+1,...
        weight = (1 - dones[:-1].detach()).cumprod(dim=0)

        # (T+1, B)
        q_values = q_evaluation(obss.detach(), actions.detach(), partial(self._critic_fn, model=self._learn_model))
        if self._twin_critic:
            critic_loss = 0.5 * torch.square(q_values[0][:-1] - lambda_return.detach()) \
                + 0.5 * torch.square(q_values[1][:-1] - lambda_return.detach())
        else:
            critic_loss = 0.5 * torch.square(q_values[:-1] - lambda_return.detach())

        # value expansion loss
        critic_loss = (critic_loss * weight).mean()

        # value gradient loss
        policy_loss = -(lambda_return * weight).mean()

        # alpha_loss = None

        loss_dict = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            # 'alpha_loss': alpha_loss.detach(),
        }

        norm_dict = self._update(loss_dict)

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())

        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'alpha': self._alpha.item(),
            'target_q_value': target_q_values.detach().mean().item(),
            **norm_dict,
            **loss_dict,
        }

    def _update(self, loss_dict):
        # update critic
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_q.step()
        # update policy
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_policy.step()
        # update temperature
        # self._alpha_optim.zero_grad()
        # loss_dict['alpha_loss'].backward()
        # self._alpha_optim.step()
        return {'policy_norm': policy_norm, 'critic_norm': critic_norm}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        alpha_loss = ['alpha_loss'] if self._auto_alpha else []
        return [
            'policy_loss',
            'critic_loss',
            'policy_norm',
            'critic_norm',
            'cur_lr_q',
            'cur_lr_p',
            'alpha',
            'target_q_value',
        ] + alpha_loss
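

# The sketch below is illustrative only and is not used by the policies in this
# file. It spells out the TD(lambda) recursion that ``generalized_lambda_returns``
# is assumed to implement, for bootstrap values of shape (T+1, B), rewards of
# shape (T, B) and done masks of shape (T, B); the exact argument handling in
# ``ding.rl_utils`` may differ.
#
#   G_T = V_T
#   G_t = r_t + gamma * (1 - done_t) * ((1 - lambda) * V_{t+1} + lambda * G_{t+1})
def _td_lambda_return_sketch(values: Tensor, rewards: Tensor, dones: Tensor, gamma: float, lambda_: float) -> Tensor:
    returns = torch.zeros_like(rewards)
    # bootstrap from the last value estimate
    next_return = values[-1]
    for t in reversed(range(rewards.shape[0])):
        next_return = rewards[t] + gamma * (1 - dones[t]) * \
            ((1 - lambda_) * values[t + 1] + lambda_ * next_return)
        returns[t] = next_return
    return returns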


@POLICY_REGISTRY.register('stevesac')
class STEVESACPolicy(SACPolicy):
    r"""
    Overview:
        Model-based SAC with stochastic ensemble value expansion (STEVE, arXiv: 1807.01675).
        This implementation also uses the value gradient w.r.t. the same STEVE target.

        https://arxiv.org/pdf/1807.01675.pdf

    Config:
        == ======================= ======== ============= ======================================
        ID Symbol                  Type     Default Value Description
        == ======================= ======== ============= ======================================
        1  ``learn.grad_clip``     float    100.0         | Max norm of gradients.
        2  ``learn.ensemble_size`` int      1             | The number of ensemble world models.
        == ======================= ======== ============= ======================================

    .. note::
        For other configs, please refer to ding.policy.sac.SACPolicy.
    """

    config = dict(
        learn=dict(
            # (float) Max norm of gradients.
            grad_clip=100,
            # (int) The number of ensemble world models.
            ensemble_size=1,
        )
    )

    def _init_learn(self) -> None:
        super()._init_learn()
        self._target_model.requires_grad_(False)

        self._grad_clip = self._cfg.learn.grad_clip
        self._ensemble_size = self._cfg.learn.ensemble_size
        self._auto_alpha = self._cfg.learn.auto_alpha
        # TODO: auto alpha
        assert not self._auto_alpha, "NotImplemented"

        def actor_fn(obs: Tensor):
            # fold the ensemble dimension into the batch before the actor forward pass
            obs, dim = fold_batch(obs, 1)
            (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            action = torch.tanh(pred)
            log_prob = dist.log_prob(pred) + \
                2 * (pred + torch.nn.functional.softplus(-2. * pred) - torch.log(torch.tensor(2.))).sum(-1)
            aug_reward = -self._alpha.detach() * log_prob
            return unfold_batch(action, dim), unfold_batch(aug_reward, dim)

        self._actor_fn = actor_fn

        def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
            eval_data = {'obs': obss, 'action': actions}
            q_values = model.forward(eval_data, mode='compute_critic')['q_value']
            return q_values

        self._critic_fn = critic_fn

        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
        # preprocess data
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )

        if self._cuda:
            data = to_device(data, self._device)

        if len(data['action'].shape) == 1:
            data['action'] = data['action'].unsqueeze(1)

        # [B, D] -> [E, B, D]
        data['next_obs'] = unsqueeze_repeat(data['next_obs'], self._ensemble_size)
        data['reward'] = unsqueeze_repeat(data['reward'], self._ensemble_size)
        data['done'] = unsqueeze_repeat(data['done'], self._ensemble_size)

        self._learn_model.train()
        self._target_model.train()

        obss, actions, rewards, aug_rewards, dones = \
            world_model.rollout(data['next_obs'], self._actor_fn, envstep, keep_ensemble=True)
        rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
        dones = torch.cat([data['done'].unsqueeze(0), dones])

        # (T, E, B)
        target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
        if self._twin_critic:
            target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
        else:
            target_q_values = target_q_values + aug_rewards

        # (T+1, E, B)
        discounts = ((1 - dones) * self._gamma).cumprod(dim=0)
        discounts = torch.cat([torch.ones_like(discounts)[:1], discounts])
        # (T, E, B)
        cum_rewards = (rewards * discounts[:-1]).cumsum(dim=0)
        discounted_q_values = target_q_values * discounts[1:]
        steve_return = cum_rewards + discounted_q_values

        # (T, B)
        steve_return_mean = steve_return.mean(1)
        with torch.no_grad():
            steve_return_inv_var = 1 / (1e-8 + steve_return.var(1, unbiased=False))
            steve_return_weight = steve_return_inv_var / (1e-8 + steve_return_inv_var.sum(dim=0))
        # (B, )
        steve_return = (steve_return_mean * steve_return_weight).sum(0)
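
        # The imagined rollout above is only used to build the STEVE target:
        # rollout depths are combined with weights proportional to the inverse
        # of their ensemble variance, so unreliable horizons are down-weighted.
        # The critic below is regressed toward this target at the real
        # transition only, while the policy maximizes the same target.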

        eval_data = {'obs': data['obs'], 'action': data['action']}
        q_values = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            critic_loss = 0.5 * torch.square(q_values[0] - steve_return.detach()) \
                + 0.5 * torch.square(q_values[1] - steve_return.detach())
        else:
            critic_loss = 0.5 * torch.square(q_values - steve_return.detach())
        critic_loss = critic_loss.mean()

        policy_loss = -steve_return.mean()

        # alpha_loss = None

        loss_dict = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            # 'alpha_loss': alpha_loss.detach(),
        }

        norm_dict = self._update(loss_dict)

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())

        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'alpha': self._alpha.item(),
            'target_q_value': target_q_values.detach().mean().item(),
            **norm_dict,
            **loss_dict,
        }

    def _update(self, loss_dict):
        # update critic
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_q.step()
        # update policy
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_policy.step()
        # update temperature
        # self._alpha_optim.zero_grad()
        # loss_dict['alpha_loss'].backward()
        # self._alpha_optim.step()
        return {'policy_norm': policy_norm, 'critic_norm': critic_norm}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        alpha_loss = ['alpha_loss'] if self._auto_alpha else []
        return [
            'policy_loss',
            'critic_loss',
            'policy_norm',
            'critic_norm',
            'cur_lr_q',
            'cur_lr_p',
            'alpha',
            'target_q_value',
        ] + alpha_loss
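

# Illustrative only: a minimal sketch of the policy-specific ``learn`` fields
# documented above, as they might appear inside an experiment config. The
# surrounding structure (env, world_model, collector, replay buffer, ...) and
# the remaining SAC hyper-parameters are assumed and omitted here.
example_mbsac_learn_cfg = dict(
    # TD-lambda coefficient for the value-expansion target.
    lambda_=0.8,
    # Max gradient norm applied to both the actor and the critic update.
    grad_clip=100,
    # Roll out from sampled states instead of sampled transitions.
    sample_state=True,
)
example_stevesac_learn_cfg = dict(
    # Max gradient norm applied to both the actor and the critic update.
    grad_clip=100,
    # Number of ensemble world models used to build the STEVE target.
    ensemble_size=1,
)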