from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
from copy import deepcopy

import torch
from torch import nn

from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_train_sample
from ding.utils import POLICY_REGISTRY, deep_merge_dicts
from ding.utils.data import default_collate, default_decollate
from ding.policy import Policy
from ding.model import model_wrap
from ding.policy.common_utils import default_preprocess_learn
from .utils import imagine, compute_target, compute_actor_loss, RewardEMA, tensorstats


@POLICY_REGISTRY.register('dreamer')
class DREAMERPolicy(Policy):
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='dreamer',
        # (bool) Whether to use cuda for network and loss computation.
        cuda=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        random_collect_size=5000,
        # (bool) Whether to need policy-specific data in preprocess transition.
        transition_with_policy_data=False,
        # (int) Number of steps of the imagined rollout used for actor-critic learning.
        imag_horizon=15,
        learn=dict(
            # (float) Lambda for TD-lambda return.
            lambda_=0.95,
            # (float) Max norm of gradients.
            grad_clip=100,
            # (float) Learning rate of both the actor and the critic optimizer.
            learning_rate=3e-5,
            # (int) Number of sequences in each training batch.
            batch_size=16,
            # (int) Length of each sampled sequence.
            batch_length=64,
            # (bool) Whether to sample actions (rather than take the mode) during imagination.
            imag_sample=True,
            # (bool) Whether to regularize the critic towards a slowly-updated copy of itself.
            slow_value_target=True,
            # (int) Update the slow critic every ``slow_target_update`` critic updates.
            slow_target_update=1,
            # (float) EMA mixing fraction used when updating the slow critic.
            slow_target_fraction=0.02,
            # (float) Discount factor (gamma).
            discount=0.997,
            # (bool) Whether to normalize returns with an EMA of their statistics (see ``RewardEMA``).
            reward_EMA=True,
            # (float) Coefficient of the actor entropy regularization.
            actor_entropy=3e-4,
            # (float) Coefficient of the state entropy regularization.
            actor_state_entropy=0.0,
            # (float) Coefficient of the extra value regularization term in the critic loss.
            value_decay=0.0,
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'dreamervac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``. Initialize the optimizers, the algorithm
            config, and the main and slow-target value models.
        """
        # Algorithm config
        self._lambda = self._cfg.learn.lambda_
        self._grad_clip = self._cfg.learn.grad_clip

        self._critic = self._model.critic
        self._actor = self._model.actor

        if self._cfg.learn.slow_value_target:
            self._slow_value = deepcopy(self._critic)
            self._updates = 0

        # Optimizer
        self._optimizer_value = Adam(
            self._critic.parameters(),
            lr=self._cfg.learn.learning_rate,
        )
        self._optimizer_actor = Adam(
            self._actor.parameters(),
            lr=self._cfg.learn.learning_rate,
        )

        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._forward_learn_cnt = 0

        if self._cfg.learn.reward_EMA:
            self.reward_ema = RewardEMA(device=self._device)

    def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]:
        # log dict
        log_vars = {}
        self._learn_model.train()
        self._update_slow_target()

        self._actor.requires_grad_(requires_grad=True)
        # start is a dict of {stoch, deter, logit}
        if self._cuda:
            start = to_device(start, self._device)

        # train self._actor
        imag_feat, imag_state, imag_action = imagine(
            self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon
        )
        reward = world_model.heads["reward"](world_model.dynamics.get_feat(imag_state)).mode()
        actor_ent = self._actor(imag_feat).entropy()
        state_ent = world_model.dynamics.get_dist(imag_state).entropy()
        # this target is not scaled
        # slow is a flag to indicate whether slow_target is used for the lambda-return
        target, weights, base = compute_target(
            self._cfg.learn, world_model, self._critic, imag_feat, imag_state, reward, actor_ent, state_ent
        )
        actor_loss, mets = compute_actor_loss(
            self._cfg.learn,
            self._actor,
            self.reward_ema,
            imag_feat,
            imag_state,
            imag_action,
            target,
            actor_ent,
            state_ent,
            weights,
            base,
        )
        log_vars.update(mets)
        value_input = imag_feat
        self._actor.requires_grad_(requires_grad=False)

        self._critic.requires_grad_(requires_grad=True)
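        # Critic update: the critic outputs a distribution over returns, so the loss below is the
        # negative log-likelihood of the detached lambda-return target, optionally regularized
        # towards the slow critic's prediction plus a small value-decay term.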
        value = self._critic(value_input[:-1].detach())
        # TODO
        # target = torch.stack(target, dim=1)
        # (time, batch, 1), (time, batch, 1) -> (time, batch)
        value_loss = -value.log_prob(target.detach())
        slow_target = self._slow_value(value_input[:-1].detach())
        if self._cfg.learn.slow_value_target:
            value_loss = value_loss - value.log_prob(slow_target.mode().detach())
        if self._cfg.learn.value_decay:
            value_loss += self._cfg.learn.value_decay * value.mode()
        # (time, batch, 1), (time, batch, 1) -> (1,)
        value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
        self._critic.requires_grad_(requires_grad=False)

        log_vars.update(tensorstats(value.mode(), "value"))
        log_vars.update(tensorstats(target, "target"))
        log_vars.update(tensorstats(reward, "imag_reward"))
        log_vars.update(tensorstats(imag_action, "imag_action"))
        log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy().item()
        # ====================
        # actor-critic update
        # ====================
        self._model.requires_grad_(requires_grad=True)
        world_model.requires_grad_(requires_grad=True)

        loss_dict = {
            'critic_loss': value_loss,
            'actor_loss': actor_loss,
        }

        norm_dict = self._update(loss_dict)

        self._model.requires_grad_(requires_grad=False)
        world_model.requires_grad_(requires_grad=False)

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1

        return {
            **log_vars,
            **norm_dict,
            **loss_dict,
        }

    def _update(self, loss_dict):
        # update actor
        self._optimizer_actor.zero_grad()
        loss_dict['actor_loss'].backward()
        actor_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_actor.step()
        # update critic
        self._optimizer_value.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_value.step()
        return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm}

    def _update_slow_target(self):
        if self._cfg.learn.slow_value_target:
            if self._updates % self._cfg.learn.slow_target_update == 0:
                mix = self._cfg.learn.slow_target_fraction
                for s, d in zip(self._critic.parameters(), self._slow_value.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1

    def _state_dict_learn(self) -> Dict[str, Any]:
        ret = {
            'model': self._learn_model.state_dict(),
            'optimizer_value': self._optimizer_value.state_dict(),
            'optimizer_actor': self._optimizer_actor.state_dict(),
        }
        return ret

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer_value.load_state_dict(state_dict['optimizer_value'])
        self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])

    def _init_collect(self) -> None:
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='base')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=None) -> dict:
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()

        if state is None:
            batch_size = len(data_id)
            latent = world_model.dynamics.initial(batch_size)  # {logit, stoch, deter}
            action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
        else:
            # state = default_collate(list(state.values()))
            latent = to_device(default_collate(list(zip(*state))[0]), self._device)
            action = to_device(default_collate(list(zip(*state))[1]), self._device)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)
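            # For environments that have just been reset, zero out the cached latent state and the
            # previous action (``mask`` is 0 where ``reset`` is 1).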
            if reset.any():
                mask = 1 - reset
                for key in latent.keys():
                    for i in range(latent[key].shape[0]):
                        latent[key][i] *= mask[i]
                for i in range(len(action)):
                    action[i] *= mask[i]

        data = data - 0.5
        embed = world_model.encoder(data)
        latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
        feat = world_model.dynamics.get_feat(latent)

        actor = self._actor(feat)
        action = actor.sample()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()

        state = (latent, action)
        output = {"action": action, "logprob": logprob, "state": state}

        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation.
            - model_output (:obj:`dict`): Output of the collect model, including at least ['action'].
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'action': model_output['action'],
            # TODO(zp): randomly collected transitions only contain 'action', so 'logprob' is omitted for now.
            # 'logprob': model_output['logprob'],
            'reward': timestep.reward,
            'discount': timestep.info['discount'],
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict:
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()

        if state is None:
            batch_size = len(data_id)
            latent = world_model.dynamics.initial(batch_size)  # {logit, stoch, deter}
            action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
        else:
            # state = default_collate(list(state.values()))
            latent = to_device(default_collate(list(zip(*state))[0]), self._device)
            action = to_device(default_collate(list(zip(*state))[1]), self._device)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)
            # Same latent/action reset masking as in ``_forward_collect``.
            if reset.any():
                mask = 1 - reset
                for key in latent.keys():
                    for i in range(latent[key].shape[0]):
                        latent[key][i] *= mask[i]
                for i in range(len(action)):
                    action[i] *= mask[i]

        data = data - 0.5
        embed = world_model.encoder(data)
        latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
        feat = world_model.dynamics.get_feat(latent)

        actor = self._actor(feat)
        # Evaluation takes the mode of the action distribution instead of sampling.
        action = actor.mode()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()

        state = (latent, action)
        output = {"action": action, "logprob": logprob, "state": state}

        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}
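    # The variable names returned below are assumed to match the keys produced upstream:
    # ``tensorstats(...)`` yields the ``*_mean/std/min/max`` entries, while ``compute_actor_loss``
    # (together with ``RewardEMA``) provides the ``normed_target_*`` and ``EMA_005`` / ``EMA_095`` metrics.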
""" return [ 'normed_target_mean', 'normed_target_std', 'normed_target_min', 'normed_target_max', 'EMA_005', 'EMA_095', 'actor_entropy', 'actor_state_entropy', 'value_mean', 'value_std', 'value_min', 'value_max', 'target_mean', 'target_std', 'target_min', 'target_max', 'imag_reward_mean', 'imag_reward_std', 'imag_reward_min', 'imag_reward_max', 'imag_action_mean', 'imag_action_std', 'imag_action_min', 'imag_action_max', 'actor_ent', 'actor_loss', 'critic_loss', 'actor_grad_norm', 'critic_grad_norm' ]