import copy

import torch
from torch import nn

from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts
from ding.utils.data import default_collate
from ding.model import ConvEncoder
from ding.world_model.base_world_model import WorldModel
from ding.world_model.model.networks import RSSM, ConvDecoder
from ding.torch_utils import to_device
from ding.torch_utils.network.dreamer import DenseHead


@WORLD_MODEL_REGISTRY.register('dreamer')
class DREAMERWorldModel(WorldModel, nn.Module):
    config = dict(
        pretrain=100,
        train_freq=2,
        model=dict(
            state_size=None,
            action_size=None,
            model_lr=1e-4,
            reward_size=1,
            hidden_size=200,
            batch_size=256,
            max_epochs_since_update=5,
            dyn_stoch=32,
            dyn_deter=512,
            dyn_hidden=512,
            dyn_input_layers=1,
            dyn_output_layers=1,
            dyn_rec_depth=1,
            dyn_shared=False,
            dyn_discrete=32,
            act='SiLU',
            norm='LayerNorm',
            grad_heads=['image', 'reward', 'discount'],
            units=512,
            reward_layers=2,
            discount_layers=2,
            value_layers=2,
            actor_layers=2,
            cnn_depth=32,
            encoder_kernels=[4, 4, 4, 4],
            decoder_kernels=[4, 4, 4, 4],
            reward_head='twohot_symlog',
            kl_lscale=0.1,
            kl_rscale=0.5,
            kl_free=1.0,
            kl_forward=False,
            pred_discount=True,
            dyn_mean_act='none',
            dyn_std_act='sigmoid2',
            dyn_temp_post=True,
            dyn_min_std=0.1,
            dyn_cell='gru_layer_norm',
            unimix_ratio=0.01,
            device='cuda' if torch.cuda.is_available() else 'cpu',
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        WorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        self.pretrain_flag = True
        self._cfg = cfg.model
        # resolve the activation/normalization class names from the config,
        # e.g. 'SiLU' -> nn.SiLU, 'LayerNorm' -> nn.LayerNorm
        self._cfg.act = getattr(torch.nn, self._cfg.act)
        self._cfg.norm = getattr(torch.nn, self._cfg.norm)
        self.state_size = self._cfg.state_size
        self.action_size = self._cfg.action_size
        self.reward_size = self._cfg.reward_size
        self.hidden_size = self._cfg.hidden_size
        self.batch_size = self._cfg.batch_size
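        # Sanity check for the encoder sizes below: a sketch, assuming the usual
        # DMC-style image observation with state_size = (3, 64, 64) in (C, H, W)
        # layout (state_size is None by default and is filled in by the entry-point
        # config):
        #
        #     depths = [32 * 2 ** i for i in range(4)]   # [32, 64, 128, 256] channels
        #     side = 64 // 2 ** 4                        # four stride-2 convs: 64 -> 4
        #     side * side * depths[-1]                   # 4 * 4 * 256 = 4096
        #
        # which is why hidden_size_list ends in 4096 and agrees with the embed_size
        # formula computed after the encoder is built.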
        self.encoder = ConvEncoder(
            self.state_size,
            hidden_size_list=[32, 64, 128, 256, 4096],
            activation=torch.nn.SiLU(),
            kernel_size=self._cfg.encoder_kernels,
            layer_norm=True
        )
        # flattened feature size of the encoder output (see the sketch above)
        self.embed_size = (
            (self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
            2 ** (len(self._cfg.encoder_kernels) - 1)
        )
        self.dynamics = RSSM(
            self._cfg.dyn_stoch,
            self._cfg.dyn_deter,
            self._cfg.dyn_hidden,
            self._cfg.dyn_input_layers,
            self._cfg.dyn_output_layers,
            self._cfg.dyn_rec_depth,
            self._cfg.dyn_shared,
            self._cfg.dyn_discrete,
            self._cfg.act,
            self._cfg.norm,
            self._cfg.dyn_mean_act,
            self._cfg.dyn_std_act,
            self._cfg.dyn_temp_post,
            self._cfg.dyn_min_std,
            self._cfg.dyn_cell,
            self._cfg.unimix_ratio,
            self._cfg.action_size,
            self.embed_size,
            self._cfg.device,
        )
        self.heads = nn.ModuleDict()
        # the model feature is the concatenated stochastic and deterministic state
        if self._cfg.dyn_discrete:
            feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
        else:
            feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter
        self.heads["image"] = ConvDecoder(
            feat_size,
            self._cfg.cnn_depth,
            self._cfg.act,
            self._cfg.norm,
            self.state_size,
            self._cfg.decoder_kernels,
        )
        self.heads["reward"] = DenseHead(
            feat_size,
            (255, ),  # 255 bins for the two-hot symlog reward distribution
            self._cfg.reward_layers,
            self._cfg.units,
            'SiLU',  # self._cfg.act
            'LN',  # self._cfg.norm
            dist=self._cfg.reward_head,
            outscale=0.0,
            device=self._cfg.device,
        )
        if self._cfg.pred_discount:
            self.heads["discount"] = DenseHead(
                feat_size,
                [],
                self._cfg.discount_layers,
                self._cfg.units,
                'SiLU',  # self._cfg.act
                'LN',  # self._cfg.norm
                dist="binary",
                device=self._cfg.device,
            )
        if self._cuda:
            self.cuda()
        # TODO: add grad_clip and weight_decay to the optimizer
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr)

    def step(self, obs, act):
        pass

    def eval(self, env_buffer, envstep, train_iter):
        pass

    def should_pretrain(self):
        # return True exactly once, to trigger the initial pretraining phase
        if self.pretrain_flag:
            self.pretrain_flag = False
            return True
        return False

    def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
        self.last_train_step = envstep
        data = env_buffer.sample(
            batch_size, batch_length, train_iter
        )  # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
        data = default_collate(data)  # -> [len=T, ele={dict_key: Tensor(B, any_dims)}]
        data = lists_to_dicts(data, recursive=True)  # -> {dict_key: T lists}, each list is [B, any_dims]
        data = {k: torch.stack(data[k], dim=1) for k in data}  # -> {dict_key: Tensor(B, T, any_dims)}

        data['discount'] = data.get('discount', 1.0 - data['done'].float())
        data['discount'] *= 0.997  # discount factor gamma, so the head predicts gamma * (1 - done)
        data['weight'] = data.get('weight', None)
        data['image'] = data['obs'] - 0.5  # center pixels around zero (assumes obs normalized to [0, 1])
        data = to_device(data, self._cfg.device)
        if len(data['reward'].shape) == 2:
            data['reward'] = data['reward'].unsqueeze(-1)
        if len(data['action'].shape) == 2:
            data['action'] = data['action'].unsqueeze(-1)
        if len(data['discount'].shape) == 2:
            data['discount'] = data['discount'].unsqueeze(-1)

        self.requires_grad_(requires_grad=True)

        # fold (B, T) into one batch axis for the conv encoder, then restore it
        image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
        embed = self.encoder(image)
        embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]])

        post, prior = self.dynamics.observe(embed, data['action'])
        kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
            post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale
        )
        losses = {}
        likes = {}
        for name, head in self.heads.items():
            # heads not listed in grad_heads are trained on detached features,
            # so their gradients do not flow back into the RSSM
            grad_head = name in self._cfg.grad_heads
            feat = self.dynamics.get_feat(post)
            feat = feat if grad_head else feat.detach()
            pred = head(feat)
            like = pred.log_prob(data[name])
            likes[name] = like
            losses[name] = -torch.mean(like)
        model_loss = sum(losses.values()) + kl_loss
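        # model_loss is the standard Dreamer world-model objective: the summed
        # negative log-likelihoods of the decoder heads plus the KL between the
        # posterior q(z_t | h_t, x_t) and the prior p(z_t | h_t). Schematically
        # (the exact stop-gradient placement and free-bits clipping live in
        # RSSM.kl_loss):
        #
        #     L = -log p(image | z) - log p(reward | z) - log p(discount | z)
        #         + kl_lscale * KL_lhs + kl_rscale * KL_rhs
        #
        # where the lhs/rhs terms are the two stop-gradient variants of the KL used
        # for DreamerV3-style KL balancing, each clipped below by kl_free.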
        # ====================
        # world model update
        # ====================
        self.optimizer.zero_grad()
        model_loss.backward()
        self.optimizer.step()
        self.requires_grad_(requires_grad=False)

        # tensorboard logging
        if self.tb_logger is not None:
            for name, loss in losses.items():
                self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep)
            self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep)
            self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep)
            self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy().item(), envstep)
            prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).detach().cpu().numpy()
            post_ent = torch.mean(self.dynamics.get_dist(post).entropy()).detach().cpu().numpy()
            self.tb_logger.add_scalar('prior_ent', prior_ent.item(), envstep)
            self.tb_logger.add_scalar('post_ent', post_ent.item(), envstep)

        # the detached posterior and this context are consumed downstream,
        # e.g. as the starting states for the policy's imagination rollouts
        context = dict(
            embed=embed,
            feat=self.dynamics.get_feat(post),
            kl=kl_value,
            postent=self.dynamics.get_dist(post).entropy(),
        )
        post = {k: v.detach() for k, v in post.items()}
        return post, context

    def _save_states(self):
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        updated = False
        for i, current in enumerate(holdout_losses):
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            # snapshot a member whenever its holdout loss improves by more than 1%
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                updated = True
        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update
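# A minimal usage sketch (hypothetical values; it assumes the usual DI-engine
# `default_config()` helper, and in a real pipeline the entry point assembles the
# config and the sequence replay buffer that backs `env_buffer`):
#
#     from easydict import EasyDict
#     cfg = EasyDict(DREAMERWorldModel.default_config())
#     cfg.model.state_size = (3, 64, 64)  # (C, H, W) image observations
#     cfg.model.action_size = 6           # e.g. a 6-dim continuous control task
#     world_model = DREAMERWorldModel(cfg, env=None, tb_logger=tb_logger)
#     post, context = world_model.train(
#         env_buffer, envstep, train_iter, batch_size=16, batch_length=64
#     )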