zjowowen's picture
init space
079c32c
import itertools
import numpy as np
import copy
import torch
from torch import nn
from ding.utils import WORLD_MODEL_REGISTRY
from ding.utils.data import default_collate
from ding.world_model.base_world_model import HybridWorldModel
from ding.world_model.model.ensemble import EnsembleModel, StandardScaler
from ding.torch_utils import fold_batch, unfold_batch, unsqueeze_repeat
@WORLD_MODEL_REGISTRY.register('mbpo')
class MBPOWorldModel(HybridWorldModel, nn.Module):
config = dict(
model=dict(
ensemble_size=7,
elite_size=5,
state_size=None,
action_size=None,
reward_size=1,
hidden_size=200,
use_decay=False,
batch_size=256,
holdout_ratio=0.2,
max_epochs_since_update=5,
deterministic_rollout=True,
),
)
def __init__(self, cfg, env, tb_logger):
HybridWorldModel.__init__(self, cfg, env, tb_logger)
nn.Module.__init__(self)
cfg = cfg.model
self.ensemble_size = cfg.ensemble_size
self.elite_size = cfg.elite_size
self.state_size = cfg.state_size
self.action_size = cfg.action_size
self.reward_size = cfg.reward_size
self.hidden_size = cfg.hidden_size
self.use_decay = cfg.use_decay
self.batch_size = cfg.batch_size
self.holdout_ratio = cfg.holdout_ratio
self.max_epochs_since_update = cfg.max_epochs_since_update
self.deterministic_rollout = cfg.deterministic_rollout
self.ensemble_model = EnsembleModel(
self.state_size,
self.action_size,
self.reward_size,
self.ensemble_size,
self.hidden_size,
use_decay=self.use_decay
)
self.scaler = StandardScaler(self.state_size + self.action_size)
if self._cuda:
self.cuda()
self.ensemble_mse_losses = []
self.model_variances = []
self.elite_model_idxes = []
def step(self, obs, act, batch_size=8192, keep_ensemble=False):
if len(act.shape) == 1:
act = act.unsqueeze(1)
if self._cuda:
obs = obs.cuda()
act = act.cuda()
inputs = torch.cat([obs, act], dim=-1)
if keep_ensemble:
inputs, dim = fold_batch(inputs, 1)
inputs = self.scaler.transform(inputs)
inputs = unfold_batch(inputs, dim)
else:
inputs = self.scaler.transform(inputs)
# predict
ensemble_mean, ensemble_var = [], []
batch_dim = 0 if len(inputs.shape) == 2 else 1
for i in range(0, inputs.shape[batch_dim], batch_size):
if keep_ensemble:
# inputs: [E, B, D]
input = inputs[:, i:i + batch_size]
else:
# input: [B, D]
input = unsqueeze_repeat(inputs[i:i + batch_size], self.ensemble_size)
b_mean, b_var = self.ensemble_model(input, ret_log_var=False)
ensemble_mean.append(b_mean)
ensemble_var.append(b_var)
ensemble_mean = torch.cat(ensemble_mean, 1)
ensemble_var = torch.cat(ensemble_var, 1)
if keep_ensemble:
ensemble_mean[:, :, 1:] += obs
else:
ensemble_mean[:, :, 1:] += obs.unsqueeze(0)
ensemble_std = ensemble_var.sqrt()
# sample from the predicted distribution
if self.deterministic_rollout:
ensemble_sample = ensemble_mean
else:
ensemble_sample = ensemble_mean + torch.randn_like(ensemble_mean).to(ensemble_mean) * ensemble_std
if keep_ensemble:
# [E, B, D]
rewards, next_obs = ensemble_sample[:, :, 0], ensemble_sample[:, :, 1:]
next_obs_flatten, dim = fold_batch(next_obs)
done = unfold_batch(self.env.termination_fn(next_obs_flatten), dim)
return rewards, next_obs, done
# sample from ensemble
model_idxes = torch.from_numpy(np.random.choice(self.elite_model_idxes, size=len(obs))).to(inputs.device)
batch_idxes = torch.arange(len(obs)).to(inputs.device)
sample = ensemble_sample[model_idxes, batch_idxes]
rewards, next_obs = sample[:, 0], sample[:, 1:]
return rewards, next_obs, self.env.termination_fn(next_obs)
def eval(self, env_buffer, envstep, train_iter):
data = env_buffer.sample(self.eval_freq, train_iter)
data = default_collate(data)
data['done'] = data['done'].float()
data['weight'] = data.get('weight', None)
obs = data['obs']
action = data['action']
reward = data['reward']
next_obs = data['next_obs']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
if len(action.shape) == 1:
action = action.unsqueeze(1)
# build eval samples
inputs = torch.cat([obs, action], dim=1)
labels = torch.cat([reward, next_obs - obs], dim=1)
if self._cuda:
inputs = inputs.cuda()
labels = labels.cuda()
# normalize
inputs = self.scaler.transform(inputs)
# repeat for ensemble
inputs = unsqueeze_repeat(inputs, self.ensemble_size)
labels = unsqueeze_repeat(labels, self.ensemble_size)
# eval
with torch.no_grad():
mean, logvar = self.ensemble_model(inputs, ret_log_var=True)
loss, mse_loss = self.ensemble_model.loss(mean, logvar, labels)
ensemble_mse_loss = torch.pow(mean.mean(0) - labels[0], 2)
model_variance = mean.var(0)
self.tb_logger.add_scalar('env_model_step/eval_mse_loss', mse_loss.mean().item(), envstep)
self.tb_logger.add_scalar('env_model_step/eval_ensemble_mse_loss', ensemble_mse_loss.mean().item(), envstep)
self.tb_logger.add_scalar('env_model_step/eval_model_variances', model_variance.mean().item(), envstep)
self.last_eval_step = envstep
def train(self, env_buffer, envstep, train_iter):
data = env_buffer.sample(env_buffer.count(), train_iter)
data = default_collate(data)
data['done'] = data['done'].float()
data['weight'] = data.get('weight', None)
obs = data['obs']
action = data['action']
reward = data['reward']
next_obs = data['next_obs']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
if len(action.shape) == 1:
action = action.unsqueeze(1)
# build train samples
inputs = torch.cat([obs, action], dim=1)
labels = torch.cat([reward, next_obs - obs], dim=1)
if self._cuda:
inputs = inputs.cuda()
labels = labels.cuda()
# train
logvar = self._train(inputs, labels)
self.last_train_step = envstep
# log
if self.tb_logger is not None:
for k, v in logvar.items():
self.tb_logger.add_scalar('env_model_step/' + k, v, envstep)
def _train(self, inputs, labels):
#split
num_holdout = int(inputs.shape[0] * self.holdout_ratio)
train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]
#normalize
self.scaler.fit(train_inputs)
train_inputs = self.scaler.transform(train_inputs)
holdout_inputs = self.scaler.transform(holdout_inputs)
#repeat for ensemble
holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)
self._epochs_since_update = 0
self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
self._save_states()
for epoch in itertools.count():
train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
for _ in range(self.ensemble_size)]).to(train_inputs.device)
self.mse_loss = []
for start_pos in range(0, train_inputs.shape[0], self.batch_size):
idx = train_idx[:, start_pos:start_pos + self.batch_size]
train_input = train_inputs[idx]
train_label = train_labels[idx]
mean, logvar = self.ensemble_model(train_input, ret_log_var=True)
loss, mse_loss = self.ensemble_model.loss(mean, logvar, train_label)
self.ensemble_model.train(loss)
self.mse_loss.append(mse_loss.mean().item())
self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)
with torch.no_grad():
holdout_mean, holdout_logvar = self.ensemble_model(holdout_inputs, ret_log_var=True)
_, holdout_mse_loss = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_labels)
self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
break_train = self._save_best(epoch, holdout_mse_loss)
if break_train:
break
self._load_states()
with torch.no_grad():
holdout_mean, holdout_logvar = self.ensemble_model(holdout_inputs, ret_log_var=True)
_, holdout_mse_loss = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_labels)
sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
self.elite_model_idxes = sorted_loss_idx[:self.elite_size]
self.top_holdout_mse_loss = sorted_loss[0]
self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
self.bottom_holdout_mse_loss = sorted_loss[-1]
self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
return {
'mse_loss': self.mse_loss,
'curr_holdout_mse_loss': self.curr_holdout_mse_loss,
'best_holdout_mse_loss': self.best_holdout_mse_loss,
'top_holdout_mse_loss': self.top_holdout_mse_loss,
'middle_holdout_mse_loss': self.middle_holdout_mse_loss,
'bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
}
def _save_states(self, ):
self._states = copy.deepcopy(self.state_dict())
def _save_state(self, id):
state_dict = self.state_dict()
for k, v in state_dict.items():
if 'weight' in k or 'bias' in k:
self._states[k].data[id] = copy.deepcopy(v.data[id])
def _load_states(self):
self.load_state_dict(self._states)
def _save_best(self, epoch, holdout_losses):
updated = False
for i in range(len(holdout_losses)):
current = holdout_losses[i]
_, best = self._snapshots[i]
improvement = (best - current) / best
if improvement > 0.01:
self._snapshots[i] = (epoch, current)
self._save_state(i)
# self._save_state(i)
updated = True
# improvement = (best - current) / best
if updated:
self._epochs_since_update = 0
else:
self._epochs_since_update += 1
return self._epochs_since_update > self.max_epochs_since_update