zjowowen's picture
init space
079c32c
raw
history blame
10.1 kB
import numpy as np
import copy
import torch
from torch import nn
from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts
from ding.utils.data import default_collate
from ding.model import ConvEncoder
from ding.world_model.base_world_model import WorldModel
from ding.world_model.model.networks import RSSM, ConvDecoder
from ding.torch_utils import to_device
from ding.torch_utils.network.dreamer import DenseHead
@WORLD_MODEL_REGISTRY.register('dreamer')
class DREAMERWorldModel(WorldModel, nn.Module):
    """
    Overview:
        Dreamer world model: a convolutional observation encoder, an RSSM latent dynamics model and a set of
        prediction heads (image decoder, reward head and, optionally, a discount head), all trained jointly by a
        single Adam optimizer.
    """
    config = dict(
        pretrain=100,
        train_freq=2,
        model=dict(
            state_size=None,
            action_size=None,
            model_lr=1e-4,
            reward_size=1,
            hidden_size=200,
            batch_size=256,
            max_epochs_since_update=5,
            dyn_stoch=32,
            dyn_deter=512,
            dyn_hidden=512,
            dyn_input_layers=1,
            dyn_output_layers=1,
            dyn_rec_depth=1,
            dyn_shared=False,
            dyn_discrete=32,
            act='SiLU',
            norm='LayerNorm',
            grad_heads=['image', 'reward', 'discount'],
            units=512,
            reward_layers=2,
            discount_layers=2,
            value_layers=2,
            actor_layers=2,
            cnn_depth=32,
            encoder_kernels=[4, 4, 4, 4],
            decoder_kernels=[4, 4, 4, 4],
            reward_head='twohot_symlog',
            kl_lscale=0.1,
            kl_rscale=0.5,
            kl_free=1.0,
            kl_forward=False,
            pred_discount=True,
            dyn_mean_act='none',
            dyn_std_act='sigmoid2',
            dyn_temp_post=True,
            dyn_min_std=0.1,
            dyn_cell='gru_layer_norm',
            unimix_ratio=0.01,
            device='cuda' if torch.cuda.is_available() else 'cpu',
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        """
        Overview:
            Build the encoder, RSSM dynamics, prediction heads and the optimizer over all of them.
        Arguments:
            - cfg: world model config; network hyper-parameters are read from ``cfg.model``.
            - env: environment handle, forwarded to ``WorldModel.__init__``.
            - tb_logger: logger forwarded to ``WorldModel.__init__``; ``train`` skips logging when
              ``self.tb_logger`` is ``None``.
        """
        WorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        self.pretrain_flag = True
        self._cfg = cfg.model
        # NOTE: the ``act`` / ``norm`` strings from the config are currently ignored: the RSSM and image decoder
        # always receive the SiLU / LayerNorm classes, and the dense heads below hard-code 'SiLU' / 'LN'.
        self._cfg.act = nn.modules.activation.SiLU  # nn.SiLU
        self._cfg.norm = nn.modules.normalization.LayerNorm  # nn.LayerNorm

        self.state_size = self._cfg.state_size
        self.action_size = self._cfg.action_size
        self.reward_size = self._cfg.reward_size
        self.hidden_size = self._cfg.hidden_size
        self.batch_size = self._cfg.batch_size

        # Early-stopping bookkeeping read by ``_save_best`` (previously never initialised in this class).
        self.max_epochs_since_update = getattr(self._cfg, 'max_epochs_since_update', 5)
        self._epochs_since_update = 0

        # NOTE(review): the last entry (4096) has to equal ``embed_size`` computed below; it does for 64x64
        # inputs with the default ``cnn_depth`` / ``encoder_kernels`` — confirm for other input sizes.
        self.encoder = ConvEncoder(
            self.state_size,
            hidden_size_list=[32, 64, 128, 256, 4096],  # to last layer 128?
            activation=torch.nn.SiLU(),
            kernel_size=self._cfg.encoder_kernels,
            layer_norm=True
        )
        # Flattened encoder output size: the spatial side is halved once per encoder kernel (and squared, so a
        # square image is assumed — presumably state_size is (C, H, W) with H == W; TODO confirm), times the
        # channel count of the last layer (cnn_depth doubled once per layer after the first).
        self.embed_size = (
            (self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
            2 ** (len(self._cfg.encoder_kernels) - 1)
        )

        self.dynamics = RSSM(
            self._cfg.dyn_stoch,
            self._cfg.dyn_deter,
            self._cfg.dyn_hidden,
            self._cfg.dyn_input_layers,
            self._cfg.dyn_output_layers,
            self._cfg.dyn_rec_depth,
            self._cfg.dyn_shared,
            self._cfg.dyn_discrete,
            self._cfg.act,
            self._cfg.norm,
            self._cfg.dyn_mean_act,
            self._cfg.dyn_std_act,
            self._cfg.dyn_temp_post,
            self._cfg.dyn_min_std,
            self._cfg.dyn_cell,
            self._cfg.unimix_ratio,
            self._cfg.action_size,
            self.embed_size,
            self._cfg.device,
        )

        # Size of the feature vector produced by ``self.dynamics.get_feat``: stochastic part (flattened over the
        # discrete classes when ``dyn_discrete`` is set) concatenated with the deterministic part.
        self.heads = nn.ModuleDict()
        if self._cfg.dyn_discrete:
            feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
        else:
            feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter

        self.heads["image"] = ConvDecoder(
            feat_size,  # pytorch version
            self._cfg.cnn_depth,
            self._cfg.act,
            self._cfg.norm,
            self.state_size,
            self._cfg.decoder_kernels,
        )
        self.heads["reward"] = DenseHead(
            feat_size,  # dyn_stoch * dyn_discrete + dyn_deter
            (255, ),
            self._cfg.reward_layers,
            self._cfg.units,
            'SiLU',  # self._cfg.act
            'LN',  # self._cfg.norm
            dist=self._cfg.reward_head,
            outscale=0.0,
            device=self._cfg.device,
        )
        if self._cfg.pred_discount:
            self.heads["discount"] = DenseHead(
                feat_size,  # pytorch version
                [],
                self._cfg.discount_layers,
                self._cfg.units,
                'SiLU',  # self._cfg.act
                'LN',  # self._cfg.norm
                dist="binary",
                device=self._cfg.device,
            )

        if self._cuda:
            self.cuda()
        # TODO: grad_clip, weight_decay
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr)

    def step(self, obs, act):
        """No-op: this world model does not implement single-step prediction here."""
        pass

    def eval(self, env_buffer, envstep, train_iter):
        """
        No-op evaluation hook.
        NOTE: this overrides ``nn.Module.eval``, so torch's eval-mode switch is not reachable through this name.
        """
        pass

    def should_pretrain(self):
        """Return True on the first call only, False on every call after that."""
        if self.pretrain_flag:
            self.pretrain_flag = False
            return True
        return False

    def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
        """
        Overview:
            Sample a batch of sequences from ``env_buffer`` and take one optimizer step on the world model loss:
            the sum of every head's negative log-likelihood plus the RSSM KL loss.
        Arguments:
            - env_buffer: buffer whose ``sample(batch_size, batch_length, train_iter)`` returns the sequences.
            - envstep: current env step; stored as ``last_train_step`` and used as the tensorboard x-axis.
            - train_iter: current train iteration, forwarded to ``env_buffer.sample``.
            - batch_size: number of sequences to sample.
            - batch_length: length of each sampled sequence.
        Returns:
            - post: posterior RSSM state dict with every tensor detached.
            - context: dict with the encoder output ``embed``, the posterior ``feat``, the ``kl`` value and the
              posterior entropy ``postent``.
        NOTE: this overrides ``nn.Module.train``, so torch's ``train(mode)`` API is not available on this module.
        """
        self.last_train_step = envstep
        data = env_buffer.sample(
            batch_size, batch_length, train_iter
        )  # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
        data = self._prepare_train_batch(data)

        self.requires_grad_(requires_grad=True)
        # Fold batch and time dims together for the encoder, then restore them.
        image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
        embed = self.encoder(image)
        embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]])

        post, prior = self.dynamics.observe(embed, data["action"])
        kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
            post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale
        )

        losses = {}
        # The posterior features are shared by every head, so compute them once.
        feat = self.dynamics.get_feat(post)
        for name, head in self.heads.items():
            # Heads not listed in ``grad_heads`` see detached features, so their loss does not
            # back-propagate into the dynamics model or the encoder.
            head_feat = feat if name in self._cfg.grad_heads else feat.detach()
            pred = head(head_feat)
            losses[name] = -torch.mean(pred.log_prob(data[name]))
        model_loss = sum(losses.values()) + kl_loss

        # ====================
        # world model update
        # ====================
        self.optimizer.zero_grad()
        model_loss.backward()
        self.optimizer.step()

        self.requires_grad_(requires_grad=False)

        if self.tb_logger is not None:
            self._log_train(losses, kl_value, loss_lhs, loss_rhs, prior, post, envstep)

        context = dict(
            embed=embed,
            feat=self.dynamics.get_feat(post),
            kl=kl_value,
            postent=self.dynamics.get_dist(post).entropy(),
        )
        post = {k: v.detach() for k, v in post.items()}
        return post, context

    def _prepare_train_batch(self, data):
        """
        Collate sampled sequences into ``{key: Tensor(B, T, ...)}``, add the derived ``discount`` / ``weight`` /
        ``image`` entries, move everything to ``self._cfg.device`` and give 2-D ``reward`` / ``action`` /
        ``discount`` tensors a trailing unit dim.
        """
        data = default_collate(data)  # -> [len=T, ele={dict_key: Tensor(B, any_dims)}]
        data = lists_to_dicts(data, recursive=True)  # -> {some_key: T lists}, each list is [B, some_dim]
        data = {k: torch.stack(data[k], dim=1) for k in data}  # -> {dict_key: Tensor([B, T, any_dims])}

        # Derive the discount from ``done`` only when the buffer did not provide one. This is evaluated
        # lazily, so a buffer that stores ``discount`` but no ``done`` no longer raises KeyError.
        if 'discount' not in data:
            data['discount'] = 1.0 - data['done'].float()
        data['discount'] = data['discount'] * 0.997  # hard-coded per-step discount factor
        data['weight'] = data.get('weight', None)
        # Shift observations by -0.5 (presumably obs is scaled to [0, 1], centring it on 0 — confirm).
        data['image'] = data['obs'] - 0.5
        data = to_device(data, self._cfg.device)

        for key in ('reward', 'action', 'discount'):
            if data[key].dim() == 2:
                data[key] = data[key].unsqueeze(-1)
        return data

    def _log_train(self, losses, kl_value, loss_lhs, loss_rhs, prior, post, envstep):
        """Write per-head losses, KL settings / statistics and prior / posterior entropies to ``self.tb_logger``."""
        for name, loss in losses.items():
            self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy().item(), envstep)
        self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep)
        self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep)
        self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep)
        self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy().item(), envstep)
        self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy().item(), envstep)
        self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy().item(), envstep)

        prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).detach().cpu().numpy()
        post_ent = torch.mean(self.dynamics.get_dist(post).entropy()).detach().cpu().numpy()
        self.tb_logger.add_scalar('prior_ent', prior_ent.item(), envstep)
        self.tb_logger.add_scalar('post_ent', post_ent.item(), envstep)

    def _save_states(self):
        """Snapshot a deep copy of the full ``state_dict`` into ``self._states``."""
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        """
        Copy index ``id`` (along dim 0) of every weight / bias tensor of the current ``state_dict`` into
        ``self._states``. ``_save_states`` must have been called first so that ``self._states`` exists.
        NOTE(review): this per-index copy looks written for an ensemble whose parameters are stacked along dim 0;
        for this model it copies a single row of each weight / bias — confirm intent before relying on it.
        """
        state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        """Restore the parameters previously snapshotted into ``self._states``."""
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        """
        Early-stopping helper. For every index ``i`` whose holdout loss improved by more than 1% relative to the
        best loss stored in ``self._snapshots[i]``, record ``(epoch, loss)`` there and save that index via
        ``_save_state``.
        Returns:
            - stop (bool): True once no index has improved for more than ``max_epochs_since_update``
              consecutive calls.
        NOTE(review): ``self._snapshots`` is never initialised in this class; it must already hold one
        ``(epoch, best_loss)`` pair per holdout loss before this is called.
        """
        updated = False
        for i, current in enumerate(holdout_losses):
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                updated = True

        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update