|
import numpy as np |
|
import copy |
|
import torch |
|
from torch import nn |
|
|
|
from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts |
|
from ding.utils.data import default_collate |
|
from ding.model import ConvEncoder |
|
from ding.world_model.base_world_model import WorldModel |
|
from ding.world_model.model.networks import RSSM, ConvDecoder |
|
from ding.torch_utils import to_device |
|
from ding.torch_utils.network.dreamer import DenseHead |
|
|
|
|
|
@WORLD_MODEL_REGISTRY.register('dreamer')
class DREAMERWorldModel(WorldModel, nn.Module):
    """World model registered under the name ``'dreamer'``.

    The model is composed of a ``ConvEncoder`` for observations, an ``RSSM``
    dynamics model and a set of prediction heads (an image ``ConvDecoder``, a
    reward ``DenseHead`` and, when ``pred_discount`` is set, a discount
    ``DenseHead``).  All parameters are optimised jointly by one Adam
    optimizer in :meth:`train`.
    """

    config = dict(
        pretrain=100,
        train_freq=2,
        model=dict(
            state_size=None,
            action_size=None,
            model_lr=1e-4,
            reward_size=1,
            hidden_size=200,
            batch_size=256,
            max_epochs_since_update=5,
            dyn_stoch=32,
            dyn_deter=512,
            dyn_hidden=512,
            dyn_input_layers=1,
            dyn_output_layers=1,
            dyn_rec_depth=1,
            dyn_shared=False,
            dyn_discrete=32,
            act='SiLU',
            norm='LayerNorm',
            grad_heads=['image', 'reward', 'discount'],
            units=512,
            reward_layers=2,
            discount_layers=2,
            value_layers=2,
            actor_layers=2,
            cnn_depth=32,
            encoder_kernels=[4, 4, 4, 4],
            decoder_kernels=[4, 4, 4, 4],
            reward_head='twohot_symlog',
            kl_lscale=0.1,
            kl_rscale=0.5,
            kl_free=1.0,
            kl_forward=False,
            pred_discount=True,
            dyn_mean_act='none',
            dyn_std_act='sigmoid2',
            dyn_temp_post=True,
            dyn_min_std=0.1,
            dyn_cell='gru_layer_norm',
            unimix_ratio=0.01,
            # NOTE: evaluated once, when the class body is executed.
            device='cuda' if torch.cuda.is_available() else 'cpu',
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        """Build the encoder, dynamics model, prediction heads and optimizer.

        Args:
            cfg: Configuration object; ``cfg.model`` must expose the keys
                listed under ``config['model']`` via attribute access.
            env: Environment handle, forwarded to ``WorldModel.__init__``.
            tb_logger: Logger forwarded to ``WorldModel.__init__``; ``train``
                calls ``add_scalar`` on ``self.tb_logger`` when it is not
                ``None``.
        """
        WorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        self.pretrain_flag = True
        self._cfg = cfg.model

        # The configured ``act`` / ``norm`` values are overwritten in place with
        # the torch classes below, so the config entries are not consulted.
        self._cfg.act = nn.modules.activation.SiLU
        self._cfg.norm = nn.modules.normalization.LayerNorm
        self.state_size = self._cfg.state_size
        self.action_size = self._cfg.action_size
        self.reward_size = self._cfg.reward_size
        self.hidden_size = self._cfg.hidden_size
        self.batch_size = self._cfg.batch_size

        # Bookkeeping read by ``_save_best`` / ``_save_state`` / ``_load_states``.
        # Initialised here so those helpers do not fail with AttributeError.
        self.max_epochs_since_update = getattr(self._cfg, 'max_epochs_since_update', 5)
        self._epochs_since_update = 0
        self._snapshots = {}  # index -> (epoch, best holdout loss)
        self._states = None  # deep copy of ``state_dict()``, created lazily

        self.encoder = ConvEncoder(
            self.state_size,
            hidden_size_list=[32, 64, 128, 256, 4096],
            activation=torch.nn.SiLU(),
            kernel_size=self._cfg.encoder_kernels,
            layer_norm=True
        )
        # Uses ``state_size[1]`` as the spatial extent and squares it.
        # NOTE(review): presumably state_size is (C, H, W) with H == W -- confirm.
        num_encoder_layers = len(self._cfg.encoder_kernels)
        self.embed_size = (
            (self.state_size[1] // 2 ** num_encoder_layers) ** 2 * self._cfg.cnn_depth *
            2 ** (num_encoder_layers - 1)
        )
        self.dynamics = RSSM(
            self._cfg.dyn_stoch,
            self._cfg.dyn_deter,
            self._cfg.dyn_hidden,
            self._cfg.dyn_input_layers,
            self._cfg.dyn_output_layers,
            self._cfg.dyn_rec_depth,
            self._cfg.dyn_shared,
            self._cfg.dyn_discrete,
            self._cfg.act,
            self._cfg.norm,
            self._cfg.dyn_mean_act,
            self._cfg.dyn_std_act,
            self._cfg.dyn_temp_post,
            self._cfg.dyn_min_std,
            self._cfg.dyn_cell,
            self._cfg.unimix_ratio,
            self._cfg.action_size,
            self.embed_size,
            self._cfg.device,
        )
        self.heads = nn.ModuleDict()
        # Width of the feature vector fed to every head.
        if self._cfg.dyn_discrete:
            feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
        else:
            feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter
        self.heads["image"] = ConvDecoder(
            feat_size,
            self._cfg.cnn_depth,
            self._cfg.act,
            self._cfg.norm,
            self.state_size,
            self._cfg.decoder_kernels,
        )
        # NOTE(review): the (255, ) output shape presumably matches the bin
        # count expected by the configured ``reward_head`` -- confirm.
        self.heads["reward"] = DenseHead(
            feat_size,
            (255, ),
            self._cfg.reward_layers,
            self._cfg.units,
            'SiLU',
            'LN',
            dist=self._cfg.reward_head,
            outscale=0.0,
            device=self._cfg.device,
        )
        if self._cfg.pred_discount:
            self.heads["discount"] = DenseHead(
                feat_size,
                [],
                self._cfg.discount_layers,
                self._cfg.units,
                'SiLU',
                'LN',
                dist="binary",
                device=self._cfg.device,
            )

        # ``self._cuda`` is not set in this class; presumably provided by
        # ``WorldModel.__init__`` -- confirm.
        if self._cuda:
            self.cuda()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr)

    def step(self, obs, act):
        """Not implemented for this model; does nothing and returns ``None``."""
        pass

    def eval(self, env_buffer, envstep, train_iter):
        """Not implemented for this model; does nothing and returns ``None``."""
        pass

    def should_pretrain(self):
        """Return ``True`` on the first call only, ``False`` on every later call."""
        if self.pretrain_flag:
            self.pretrain_flag = False
            return True
        return False

    def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
        """Run one optimisation step of the world model on a sampled batch.

        Args:
            env_buffer: Object exposing
                ``sample(batch_size, batch_length, train_iter)``.
            envstep: Environment step counter; stored in
                ``self.last_train_step`` and used as the logging step.
            train_iter: Training iteration, forwarded to ``env_buffer.sample``.
            batch_size: Forwarded to ``env_buffer.sample``.
            batch_length: Forwarded to ``env_buffer.sample``.

        Returns:
            tuple: ``(post, context)`` where ``post`` is the posterior state
            dict returned by ``self.dynamics.observe`` with every tensor
            detached, and ``context`` is a dict with keys ``embed``, ``feat``,
            ``kl`` and ``postent``.
        """
        self.last_train_step = envstep
        data = env_buffer.sample(batch_size, batch_length, train_iter)
        data = default_collate(data)
        data = lists_to_dicts(data, recursive=True)
        # Stack each per-key list along dim 1.
        # NOTE(review): presumably yields (batch, time, ...) tensors -- confirm.
        data = {k: torch.stack(data[k], dim=1) for k in data}

        # Derive the discount target from ``done`` only when none is supplied,
        # so that ``done`` is not required when ``discount`` is present.
        if 'discount' not in data:
            data['discount'] = 1.0 - data['done'].float()
        # Out-of-place scaling also works if a supplied discount is integer-typed.
        data['discount'] = data['discount'] * 0.997
        data.setdefault('weight', None)
        data['image'] = data['obs'] - 0.5
        data = to_device(data, self._cfg.device)
        # Give 2-D entries a trailing singleton dimension.
        for key in ('reward', 'action', 'discount'):
            if data[key].dim() == 2:
                data[key] = data[key].unsqueeze(-1)

        self.requires_grad_(requires_grad=True)

        # Collapse all leading dims for the encoder, then restore them.
        leading_shape = list(data['image'].shape[:-3])
        image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
        embed = self.encoder(image)
        embed = embed.reshape(leading_shape + [embed.shape[-1]])

        post, prior = self.dynamics.observe(embed, data["action"])
        kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
            post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale
        )

        # The posterior feature is identical for every head; compute it once.
        feat = self.dynamics.get_feat(post)
        losses = {}
        for name, head in self.heads.items():
            # Heads not listed in ``grad_heads`` receive a detached input so
            # their loss does not back-propagate into the dynamics/encoder.
            head_input = feat if name in self._cfg.grad_heads else feat.detach()
            pred = head(head_input)
            losses[name] = -torch.mean(pred.log_prob(data[name]))
        model_loss = sum(losses.values()) + kl_loss

        self.optimizer.zero_grad()
        model_loss.backward()
        self.optimizer.step()

        self.requires_grad_(requires_grad=False)

        post_entropy = self.dynamics.get_dist(post).entropy()

        if self.tb_logger is not None:
            for name, loss in losses.items():
                self.tb_logger.add_scalar(name + '_loss', loss.item(), envstep)
            self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep)
            self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep)
            self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep)
            self.tb_logger.add_scalar('loss_lhs', loss_lhs.item(), envstep)
            self.tb_logger.add_scalar('loss_rhs', loss_rhs.item(), envstep)
            self.tb_logger.add_scalar('kl', torch.mean(kl_value).item(), envstep)

            prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).item()
            post_ent = torch.mean(post_entropy).item()

            self.tb_logger.add_scalar('prior_ent', prior_ent, envstep)
            self.tb_logger.add_scalar('post_ent', post_ent, envstep)

        context = dict(
            embed=embed,
            feat=feat,
            kl=kl_value,
            postent=post_entropy,
        )
        post = {k: v.detach() for k, v in post.items()}
        return post, context

    def _save_states(self):
        """Store a deep copy of the full ``state_dict()`` in ``self._states``."""
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        """Copy slice ``id`` of every weight/bias entry into ``self._states``.

        Args:
            id: Index applied to the first dimension of each state-dict entry
                whose key contains ``'weight'`` or ``'bias'``.
        """
        if self._states is None:
            # No snapshot exists yet: take a full one so the per-key targets exist.
            self._save_states()
        for k, v in self.state_dict().items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        """Restore the parameters previously captured in ``self._states``.

        Raises:
            RuntimeError: If no state has been saved yet.
        """
        if self._states is None:
            raise RuntimeError("no saved states to load; call _save_states() first")
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        """Record per-index improvements and report whether to stop.

        An entry counts as improved when its relative improvement over the
        recorded best exceeds 1%, or when no best has been recorded yet.

        Args:
            epoch: Epoch number stored alongside each new best loss.
            holdout_losses: Sequence of losses, one per tracked index.

        Returns:
            bool: ``True`` once the number of consecutive calls without any
            improvement exceeds ``self.max_epochs_since_update``.
        """
        updated = False
        for i, current in enumerate(holdout_losses):
            snapshot = self._snapshots.get(i)
            if snapshot is None:
                improved = True
            else:
                _, best = snapshot
                if best == 0:
                    # Relative improvement is undefined; fall back to a direct comparison.
                    improved = current < best
                else:
                    improved = (best - current) / best > 0.01
            if improved:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                updated = True

        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update
|
|