from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
from torch import nn
from copy import deepcopy
from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_train_sample
from ding.utils import POLICY_REGISTRY, deep_merge_dicts
from ding.utils.data import default_collate, default_decollate
from ding.policy import Policy
from ding.model import model_wrap
from ding.policy.common_utils import default_preprocess_learn
from .utils import imagine, compute_target, compute_actor_loss, RewardEMA, tensorstats


@POLICY_REGISTRY.register('dreamer')
class DREAMERPolicy(Policy):
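    r"""
    Overview:
        Policy class of the Dreamer algorithm. The actor and the critic are trained purely on
        imagined rollouts in the latent space of a learned world model, so the learn/collect/eval
        forward methods all take the world model as an extra argument and are expected to be
        driven by a model-based training entry.
    """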
config = dict(
# (str) RL policy register name (refer to function "POLICY_REGISTRY").
type='dreamer',
# (bool) Whether to use cuda for network and loss computation.
cuda=False,
# (int) Number of training samples (randomly collected) in replay buffer when training starts.
random_collect_size=5000,
        # (bool) Whether policy-specific data is needed when processing transitions in the collector.
        transition_with_policy_data=False,
        # (int) Number of imagination (latent rollout) steps used to train the actor and critic.
        imag_horizon=15,
learn=dict(
# (float) Lambda for TD-lambda return.
lambda_=0.95,
# (float) Max norm of gradients.
grad_clip=100,
            # (float) Learning rate of both the actor and the critic optimizers.
            learning_rate=3e-5,
            # (int) Number of sequences in each training batch.
            batch_size=16,
            # (int) Length (number of time steps) of each training sequence.
            batch_length=64,
            # (bool) Whether to sample the stochastic latent (instead of using its mode) during imagination.
            imag_sample=True,
            # (bool) Whether to regularize the critic towards a slowly updated copy of itself.
            slow_value_target=True,
            # (int) Interval (in training iterations) between slow critic updates.
            slow_target_update=1,
            # (float) Fraction of the current critic weights mixed into the slow critic at each update.
            slow_target_fraction=0.02,
            # (float) Discount factor used in the lambda-return.
            discount=0.997,
            # (bool) Whether to normalize returns with an exponential moving average of their scale (see ``RewardEMA``).
            reward_EMA=True,
            # (float) Coefficient of the actor entropy bonus.
            actor_entropy=3e-4,
            # (float) Coefficient of the latent state entropy bonus.
            actor_state_entropy=0.0,
            # (float) Coefficient of the value regularization term added to the critic loss.
            value_decay=0.0,
),
)
def default_model(self) -> Tuple[str, List[str]]:
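        r"""
        Overview:
            Return the registered name of the default model and the import path of its implementation.
        """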
return 'dreamervac', ['ding.model.template.vac']
def _init_learn(self) -> None:
r"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init the optimizer, algorithm config, main and target models.
"""
# Algorithm config
self._lambda = self._cfg.learn.lambda_
self._grad_clip = self._cfg.learn.grad_clip
self._critic = self._model.critic
self._actor = self._model.actor
if self._cfg.learn.slow_value_target:
self._slow_value = deepcopy(self._critic)
self._updates = 0
# Optimizer
self._optimizer_value = Adam(
self._critic.parameters(),
lr=self._cfg.learn.learning_rate,
)
self._optimizer_actor = Adam(
self._actor.parameters(),
lr=self._cfg.learn.learning_rate,
)
self._learn_model = model_wrap(self._model, wrapper_name='base')
self._learn_model.reset()
self._forward_learn_cnt = 0
if self._cfg.learn.reward_EMA:
self.reward_ema = RewardEMA(device=self._device)
def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]:
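        r"""
        Overview:
            Train the actor and the critic on imagined latent rollouts. ``start`` is a dict of
            posterior states (``stoch``, ``deter``, ``logit``) produced by the world model, from
            which trajectories of length ``imag_horizon`` are imagined with the current actor.
            Returns a dict of losses, gradient norms and statistics for logging.
        """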
# log dict
log_vars = {}
self._learn_model.train()
self._update_slow_target()
self._actor.requires_grad_(requires_grad=True)
# start is dict of {stoch, deter, logit}
if self._cuda:
start = to_device(start, self._device)
# train self._actor
imag_feat, imag_state, imag_action = imagine(
self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon
)
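        # imag_feat / imag_state / imag_action are the features, latent states and actions of the
        # imagined rollout (one entry per imagination step).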
reward = world_model.heads["reward"](world_model.dynamics.get_feat(imag_state)).mode()
actor_ent = self._actor(imag_feat).entropy()
state_ent = world_model.dynamics.get_dist(imag_state).entropy()
        # This target is not scaled here; the RewardEMA normalization is applied later inside ``compute_actor_loss``.
        # ``slow`` is a flag indicating whether the slow target critic is used for the lambda-return.
target, weights, base = compute_target(
self._cfg.learn, world_model, self._critic, imag_feat, imag_state, reward, actor_ent, state_ent
)
actor_loss, mets = compute_actor_loss(
self._cfg.learn,
self._actor,
self.reward_ema,
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
)
log_vars.update(mets)
value_input = imag_feat
self._actor.requires_grad_(requires_grad=False)
self._critic.requires_grad_(requires_grad=True)
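        # Train the critic on detached imagined features; the last imagined step is dropped so that
        # the predicted values line up with the lambda-return targets.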
value = self._critic(value_input[:-1].detach())
        # TODO: target = torch.stack(target, dim=1)
# (time, batch, 1), (time, batch, 1) -> (time, batch)
value_loss = -value.log_prob(target.detach())
slow_target = self._slow_value(value_input[:-1].detach())
if self._cfg.learn.slow_value_target:
value_loss = value_loss - value.log_prob(slow_target.mode().detach())
if self._cfg.learn.value_decay:
value_loss += self._cfg.learn.value_decay * value.mode()
# (time, batch, 1), (time, batch, 1) -> (1,)
value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
self._critic.requires_grad_(requires_grad=False)
log_vars.update(tensorstats(value.mode(), "value"))
log_vars.update(tensorstats(target, "target"))
log_vars.update(tensorstats(reward, "imag_reward"))
log_vars.update(tensorstats(imag_action, "imag_action"))
log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy().item()
# ====================
# actor-critic update
# ====================
self._model.requires_grad_(requires_grad=True)
world_model.requires_grad_(requires_grad=True)
loss_dict = {
'critic_loss': value_loss,
'actor_loss': actor_loss,
}
norm_dict = self._update(loss_dict)
self._model.requires_grad_(requires_grad=False)
world_model.requires_grad_(requires_grad=False)
# =============
# after update
# =============
self._forward_learn_cnt += 1
return {
**log_vars,
**norm_dict,
**loss_dict,
}
def _update(self, loss_dict):
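        r"""
        Overview:
            Back-propagate the actor and critic losses separately, clip the gradient norm of each
            network to ``grad_clip`` and step the corresponding optimizer. Returns both gradient
            norms for logging.
        """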
# update actor
self._optimizer_actor.zero_grad()
loss_dict['actor_loss'].backward()
actor_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
self._optimizer_actor.step()
# update critic
self._optimizer_value.zero_grad()
loss_dict['critic_loss'].backward()
critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
self._optimizer_value.step()
return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm}
def _update_slow_target(self):
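        r"""
        Overview:
            Every ``slow_target_update`` training iterations, move the slow critic towards the
            current critic by the fraction ``slow_target_fraction`` (soft / EMA update).
        """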
if self._cfg.learn.slow_value_target:
if self._updates % self._cfg.learn.slow_target_update == 0:
mix = self._cfg.learn.slow_target_fraction
for s, d in zip(self._critic.parameters(), self._slow_value.parameters()):
d.data = mix * s.data + (1 - mix) * d.data
self._updates += 1
def _state_dict_learn(self) -> Dict[str, Any]:
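        r"""
        Overview:
            Return the state_dict of learn mode, including the model and both optimizers.
        """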
ret = {
'model': self._learn_model.state_dict(),
'optimizer_value': self._optimizer_value.state_dict(),
'optimizer_actor': self._optimizer_actor.state_dict(),
}
return ret
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
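        r"""
        Overview:
            Load the state_dict saved by ``_state_dict_learn`` into the model and the two optimizers.
        """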
self._learn_model.load_state_dict(state_dict['model'])
self._optimizer_value.load_state_dict(state_dict['optimizer_value'])
self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
def _init_collect(self) -> None:
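        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``. Record the unroll length and wrap
            the collect model.
        """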
self._unroll_len = self._cfg.collect.unroll_len
self._collect_model = model_wrap(self._model, wrapper_name='base')
self._collect_model.reset()
def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=None) -> dict:
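        r"""
        Overview:
            Policy forward function of collect mode. The recurrent ``state`` is a tuple of
            ``(latent, action)`` per environment; entries whose ``reset`` flag is set are zeroed
            before the world model's ``obs_step`` updates the latent, and the next action is
            sampled from the actor on the resulting features.
        """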
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._collect_model.eval()
if state is None:
batch_size = len(data_id)
latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter}
action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
else:
#state = default_collate(list(state.values()))
latent = to_device(default_collate(list(zip(*state))[0]), self._device)
action = to_device(default_collate(list(zip(*state))[1]), self._device)
if len(action.shape) == 1:
action = action.unsqueeze(-1)
if reset.any():
mask = 1 - reset
for key in latent.keys():
for i in range(latent[key].shape[0]):
latent[key][i] *= mask[i]
for i in range(len(action)):
action[i] *= mask[i]
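        # Center the observations (assumed to be scaled to [0, 1]) around zero before encoding.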
data = data - 0.5
embed = world_model.encoder(data)
latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
feat = world_model.dynamics.get_feat(latent)
actor = self._actor(feat)
action = actor.sample()
logprob = actor.log_prob(action)
latent = {k: v.detach() for k, v in latent.items()}
action = action.detach()
state = (latent, action)
output = {"action": action, "logprob": logprob, "state": state}
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}
def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
r"""
Overview:
Generate dict type transition data from inputs.
Arguments:
- obs (:obj:`Any`): Env observation
- model_output (:obj:`dict`): Output of collect model, including at least ['action']
- timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
(here 'obs' indicates obs after env step).
Returns:
- transition (:obj:`dict`): Dict type transition data.
"""
transition = {
'obs': obs,
'action': model_output['action'],
            # TODO(zp): random collect only provides 'action', so 'logprob' is not stored for now.
            # 'logprob': model_output['logprob'],
'reward': timestep.reward,
'discount': timestep.info['discount'],
'done': timestep.done,
}
return transition
def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
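        r"""
        Overview:
            Cut the collected trajectory into training samples of length ``unroll_len``.
        """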
return get_train_sample(data, self._unroll_len)
def _init_eval(self) -> None:
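        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``. Wrap the eval model.
        """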
self._eval_model = model_wrap(self._model, wrapper_name='base')
self._eval_model.reset()
def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict:
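        r"""
        Overview:
            Policy forward function of eval mode. Same latent-state bookkeeping as
            ``_forward_collect``, but the action is taken as the mode of the actor distribution
            instead of being sampled.
        """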
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._eval_model.eval()
if state is None:
batch_size = len(data_id)
latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter}
action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
else:
#state = default_collate(list(state.values()))
latent = to_device(default_collate(list(zip(*state))[0]), self._device)
action = to_device(default_collate(list(zip(*state))[1]), self._device)
if len(action.shape) == 1:
action = action.unsqueeze(-1)
if reset.any():
mask = 1 - reset
for key in latent.keys():
for i in range(latent[key].shape[0]):
latent[key][i] *= mask[i]
for i in range(len(action)):
action[i] *= mask[i]
data = data - 0.5
embed = world_model.encoder(data)
latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
feat = world_model.dynamics.get_feat(latent)
actor = self._actor(feat)
action = actor.mode()
logprob = actor.log_prob(action)
latent = {k: v.detach() for k, v in latent.items()}
action = action.detach()
state = (latent, action)
output = {"action": action, "logprob": logprob, "state": state}
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}
def _monitor_vars_learn(self) -> List[str]:
r"""
Overview:
Return variables' name if variables are to used in monitor.
Returns:
- vars (:obj:`List[str]`): Variables' name list.
"""
return [
'normed_target_mean', 'normed_target_std', 'normed_target_min', 'normed_target_max', 'EMA_005', 'EMA_095',
'actor_entropy', 'actor_state_entropy', 'value_mean', 'value_std', 'value_min', 'value_max', 'target_mean',
'target_std', 'target_min', 'target_max', 'imag_reward_mean', 'imag_reward_std', 'imag_reward_min',
'imag_reward_max', 'imag_action_mean', 'imag_action_std', 'imag_action_min', 'imag_action_max', 'actor_ent',
'actor_loss', 'critic_loss', 'actor_grad_norm', 'critic_grad_norm'
]