|
from typing import Tuple, Callable, List
|
from collections import namedtuple |
|
from abc import ABC, abstractmethod |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
from easydict import EasyDict |
|
|
|
from ding.worker import IBuffer |
|
from ding.envs import BaseEnv |
|
from ding.utils import deep_merge_dicts |
|
from ding.world_model.utils import get_rollout_length_scheduler |
|
|
|
from ding.utils import import_module, WORLD_MODEL_REGISTRY |
|
|
|
|
|
def get_world_model_cls(cfg):

    """Return the world model class registered under ``cfg.type`` (importing ``cfg.import_names`` first)."""

    import_module(cfg.get('import_names', []))

    return WORLD_MODEL_REGISTRY.get(cfg.type)
|
|
|
|
|
def create_world_model(cfg, *args, **kwargs):

    """Build and return a world model instance registered under ``cfg.type`` (importing ``cfg.import_names`` first)."""

    import_module(cfg.get('import_names', []))

    return WORLD_MODEL_REGISTRY.build(cfg.type, cfg, *args, **kwargs)
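
# A minimal usage sketch (the 'mbpo' type name and module path below are illustrative, not
# guaranteed to be registered): ``create_world_model`` imports ``cfg.import_names``, looks up the
# class registered under ``cfg.type`` in ``WORLD_MODEL_REGISTRY`` and instantiates it with the
# remaining arguments, while ``get_world_model_cls`` only returns the class.
#
#   cfg = EasyDict(type='mbpo', import_names=['ding.world_model.mbpo'])
#   cfg = deep_merge_dicts(get_world_model_cls(cfg).default_config(), cfg)
#   world_model = create_world_model(cfg, env, tb_logger)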
|
|
|
|
|
class WorldModel(ABC): |
|
r""" |
|
Overview: |
|
        Abstract base class for world models.
|
|
|
Interfaces: |
|
should_train, should_eval, train, eval, step |
|
""" |
|
|
|
config = dict( |
|
train_freq=250, |
|
eval_freq=250, |
|
cuda=True, |
|
rollout_length_scheduler=dict( |
|
type='linear', |
|
rollout_start_step=20000, |
|
rollout_end_step=150000, |
|
rollout_length_min=1, |
|
rollout_length_max=25, |
|
) |
|
) |
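
    # With the defaults above, the linear scheduler returned by ``get_rollout_length_scheduler``
    # presumably keeps the imagined rollout length at rollout_length_min=1 before envstep 20000 and
    # grows it linearly to rollout_length_max=25 by envstep 150000 (the exact clamping behaviour is
    # defined in ding.world_model.utils).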
|
|
|
def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'): |
|
self.cfg = cfg |
|
self.env = env |
|
self.tb_logger = tb_logger |
|
|
|
self._cuda = cfg.cuda |
|
self.train_freq = cfg.train_freq |
|
self.eval_freq = cfg.eval_freq |
|
self.rollout_length_scheduler = get_rollout_length_scheduler(cfg.rollout_length_scheduler) |
|
|
|
self.last_train_step = 0 |
|
self.last_eval_step = 0 |
|
|
|
@classmethod |
|
    def default_config(cls: type) -> EasyDict:

        r"""
        Overview:
            Merge the ``config`` dicts of this class and all its ancestors (up to ``ABC``) and
            return the result as an ``EasyDict``.
        """

        merge_cfg = EasyDict(cfg_type=cls.__name__ + 'Dict')
|
while cls != ABC: |
|
merge_cfg = deep_merge_dicts(merge_cfg, cls.config) |
|
cls = cls.__base__ |
|
return merge_cfg |
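
    # For example (values taken from the ``config`` dicts in this file):
    #   cfg = DynaWorldModel.default_config()
    #   cfg.train_freq         # 250, inherited from WorldModel.config
    #   cfg.other.real_ratio   # 0.05, contributed by DynaWorldModel.config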
|
|
|
def should_train(self, envstep: int): |
|
r""" |
|
Overview: |
|
            Check whether the world model should be trained at the current environment step.
|
""" |
|
return (envstep - self.last_train_step) >= self.train_freq |
|
|
|
def should_eval(self, envstep: int): |
|
r""" |
|
Overview: |
|
            Check whether the world model should be evaluated at the current environment step.
            Evaluation is skipped until the world model has been trained at least once.
|
""" |
|
return (envstep - self.last_eval_step) >= self.eval_freq and self.last_train_step != 0 |
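
    # Typical driver-loop usage (sketch; ``collector.envstep`` and ``learner.train_iter`` are
    # assumed to be provided by the surrounding training pipeline):
    #
    #   if world_model.should_eval(collector.envstep):
    #       world_model.eval(env_buffer, collector.envstep, learner.train_iter)
    #   if world_model.should_train(collector.envstep):
    #       world_model.train(env_buffer, collector.envstep, learner.train_iter)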
|
|
|
@abstractmethod |
|
def train(self, env_buffer: IBuffer, envstep: int, train_iter: int): |
|
r""" |
|
Overview: |
|
Train world model using data from env_buffer. |
|
|
|
Arguments: |
|
- env_buffer (:obj:`IBuffer`): the buffer which collects real environment steps |
|
- envstep (:obj:`int`): the current number of environment steps in real environment |
|
- train_iter (:obj:`int`): the current number of policy training iterations |
|
""" |
|
raise NotImplementedError |
|
|
|
@abstractmethod |
|
def eval(self, env_buffer: IBuffer, envstep: int, train_iter: int): |
|
r""" |
|
Overview: |
|
Evaluate world model using data from env_buffer. |
|
|
|
Arguments: |
|
- env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps |
|
- envstep (:obj:`int`): the current number of environment steps in real environment |
|
- train_iter (:obj:`int`): the current number of policy training iterations |
|
""" |
|
raise NotImplementedError |
|
|
|
@abstractmethod |
|
def step(self, obs: Tensor, action: Tensor) -> Tuple[Tensor, Tensor, Tensor]: |
|
r""" |
|
Overview: |
|
Take one step in world model. |
|
|
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): current observations :math:`S_t` |
|
- action (:obj:`torch.Tensor`): current actions :math:`A_t` |
|
|
|
Returns: |
|
- reward (:obj:`torch.Tensor`): rewards :math:`R_t` |
|
            - next_obs (:obj:`torch.Tensor`): next observations :math:`S_{t+1}`

            - done (:obj:`torch.Tensor`): whether the episode ends
|
|
|
Shapes: |
|
:math:`B`: batch size |
|
:math:`O`: observation dimension |
|
:math:`A`: action dimension |
|
|
|
- obs: [B, O] |
|
- action: [B, A] |
|
- reward: [B, ] |
|
- next_obs: [B, O] |
|
- done: [B, ] |
|
""" |
|
raise NotImplementedError |
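
    # A minimal concrete subclass sketch (illustrative only, not a registered model): it predicts
    # zero reward, an unchanged observation and done=False for every transition, just to show the
    # expected shapes reward [B], next_obs [B, O], done [B]:
    #
    #   @WORLD_MODEL_REGISTRY.register('dummy')
    #   class DummyWorldModel(WorldModel):
    #
    #       def train(self, env_buffer, envstep, train_iter):
    #           self.last_train_step = envstep
    #
    #       def eval(self, env_buffer, envstep, train_iter):
    #           self.last_eval_step = envstep
    #
    #       def step(self, obs, action):
    #           reward = torch.zeros(obs.shape[0], device=obs.device)
    #           done = torch.zeros(obs.shape[0], dtype=torch.bool, device=obs.device)
    #           return reward, obs, done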
|
|
|
|
|
class DynaWorldModel(WorldModel, ABC): |
|
r""" |
|
Overview: |
|
Dyna-style world model (summarized in arXiv: 1907.02057) which stores and\ |
|
        reuses imagination rollouts in the imagination buffer.
|
|
|
Interfaces: |
|
sample, fill_img_buffer, should_train, should_eval, train, eval, step |
|
""" |
|
|
|
config = dict( |
|
other=dict( |
|
real_ratio=0.05, |
|
rollout_retain=4, |
|
rollout_batch_size=100000, |
|
imagination_buffer=dict( |
|
type='elastic', |
|
replay_buffer_size=6000000, |
|
deepcopy=False, |
|
enable_track_used_data=False, |
|
|
|
periodic_thruput_seconds=60, |
|
), |
|
) |
|
) |
|
|
|
def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'): |
|
super().__init__(cfg, env, tb_logger) |
|
self.real_ratio = cfg.other.real_ratio |
|
self.rollout_batch_size = cfg.other.rollout_batch_size |
|
self.rollout_retain = cfg.other.rollout_retain |
|
self.buffer_size_scheduler = \ |
|
lambda x: self.rollout_length_scheduler(x) * self.rollout_batch_size * self.rollout_retain |
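
    # e.g. with the defaults rollout_batch_size=100000 and rollout_retain=4, a scheduled rollout
    # length of 1 gives buffer_size_scheduler(envstep) = 1 * 100000 * 4 = 400000 imagined
    # transitions to be kept in the imagination buffer.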
|
|
|
    def sample(self, env_buffer: IBuffer, img_buffer: IBuffer, batch_size: int, train_iter: int) -> List[dict]:
|
r""" |
|
Overview: |
|
Sample from the combination of environment buffer and imagination buffer with\ |
|
certain ratio to generate batched data for policy training. |
|
|
|
Arguments: |
|
- env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps |
|
- img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps |
|
- batch_size (:obj:`int`): the batch size for policy training |
|
- train_iter (:obj:`int`): the current number of policy training iterations |
|
|
|
Returns: |
|
            - data (:obj:`List`): the sampled training data for policy training (a mixture of real and imagined transitions)
|
""" |
|
env_batch_size = int(batch_size * self.real_ratio) |
|
img_batch_size = batch_size - env_batch_size |
|
env_data = env_buffer.sample(env_batch_size, train_iter) |
|
img_data = img_buffer.sample(img_batch_size, train_iter) |
|
train_data = env_data + img_data |
|
return train_data |
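
    # e.g. with batch_size=256 and the default real_ratio=0.05, a policy batch mixes
    # int(256 * 0.05) = 12 real transitions with 256 - 12 = 244 imagined transitions.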
|
|
|
def fill_img_buffer( |
|
self, policy: namedtuple, env_buffer: IBuffer, img_buffer: IBuffer, envstep: int, train_iter: int |
|
): |
|
r""" |
|
Overview: |
|
            Sample starting observations from the env_buffer, roll out the world model with the
            collect-mode policy to generate imagined transitions, and push them into the img_buffer.
|
|
|
Arguments: |
|
- policy (:obj:`namedtuple`): policy in collect mode |
|
- env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps |
|
- img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps |
|
- envstep (:obj:`int`): the current number of environment steps in real environment |
|
- train_iter (:obj:`int`): the current number of policy training iterations |
|
""" |
|
from ding.torch_utils import to_tensor |
|
from ding.envs import BaseEnvTimestep |
|
from ding.worker.collector.base_serial_collector import to_tensor_transitions |
|
|
|
        def step(obs, act):

            # mimic the env step interface on top of the world model: stack the per-env
            # observations and actions, query the model once, then split the batched
            # results back into per-env timesteps
            data_id = list(obs.keys())
|
obs = torch.stack([obs[id] for id in data_id], dim=0) |
|
act = torch.stack([act[id] for id in data_id], dim=0) |
|
with torch.no_grad(): |
|
rewards, next_obs, terminals = self.step(obs, act) |
|
|
|
timesteps = { |
|
id: BaseEnvTimestep(n, r, d, {}) |
|
for id, n, r, d in zip( |
|
data_id, |
|
next_obs.cpu().numpy(), |
|
rewards.unsqueeze(-1).cpu().numpy(), |
|
terminals.cpu().numpy() |
|
) |
|
} |
|
return timesteps |
|
|
|
|
|
rollout_length = self.rollout_length_scheduler(envstep) |
|
|
|
data = env_buffer.sample(self.rollout_batch_size, train_iter, replace=True) |
|
obs = {id: data[id]['obs'] for id in range(len(data))} |
|
|
|
buffer = [[] for id in range(len(obs))] |
|
new_data = [] |
|
for i in range(rollout_length): |
|
|
|
obs = to_tensor(obs, dtype=torch.float32) |
|
policy_output = policy.forward(obs) |
|
actions = {id: output['action'] for id, output in policy_output.items()} |
|
|
|
|
|
timesteps = step(obs, actions) |
|
obs_new = {} |
|
for id, timestep in timesteps.items(): |
|
transition = policy.process_transition(obs[id], policy_output[id], timestep) |
|
transition['collect_iter'] = train_iter |
|
buffer[id].append(transition) |
|
if not timestep.done: |
|
obs_new[id] = timestep.obs |
|
if timestep.done or i + 1 == rollout_length: |
|
transitions = to_tensor_transitions(buffer[id]) |
|
train_sample = policy.get_train_sample(transitions) |
|
new_data.extend(train_sample) |
|
if len(obs_new) == 0: |
|
break |
|
obs = obs_new |
|
|
|
img_buffer.push(new_data, cur_collector_envstep=envstep) |
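
    # Sketch of the Dyna-style (MBPO-like) outer loop built from these pieces; ``policy``,
    # ``learner`` and the buffers are assumed to come from the surrounding pipeline:
    #
    #   if world_model.should_train(envstep):
    #       world_model.train(env_buffer, envstep, train_iter)
    #       world_model.fill_img_buffer(policy.collect_mode, env_buffer, img_buffer, envstep, train_iter)
    #   train_data = world_model.sample(env_buffer, img_buffer, batch_size, train_iter)
    #   learner.train(train_data, envstep)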
|
|
|
|
|
class DreamWorldModel(WorldModel, ABC): |
|
r""" |
|
Overview: |
|
Dreamer-style world model which uses each imagination rollout only once\ |
|
        and backpropagates through time (over the rollout) to optimize the policy.
|
|
|
Interfaces: |
|
rollout, should_train, should_eval, train, eval, step |
|
""" |
|
|
|
def rollout(self, obs: Tensor, actor_fn: Callable[[Tensor], Tuple[Tensor, Tensor]], envstep: int, |
|
                **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
r""" |
|
Overview: |
|
Generate batched imagination rollouts starting from the current observations.\ |
|
This function is useful for value gradients where the policy is optimized by BPTT. |
|
|
|
Arguments: |
|
- obs (:obj:`Tensor`): the current observations :math:`S_t` |
|
            - actor_fn (:obj:`Callable`): the unified API :math:`(A_t, H_t) = \pi(S_t)`
|
- envstep (:obj:`int`): the current number of environment steps in real environment |
|
|
|
Returns: |
|
            - obss (:obj:`Tensor`): :math:`S_t, ..., S_{t+n}`

            - actions (:obj:`Tensor`): :math:`A_t, ..., A_{t+n}`

            - rewards (:obj:`Tensor`): :math:`R_t, ..., R_{t+n-1}`

            - aug_rewards (:obj:`Tensor`): :math:`H_t, ..., H_{t+n}`, this can be the entropy bonus as in SAC,
              otherwise it should be a zero tensor

            - dones (:obj:`Tensor`): :math:`\text{done}_t, ..., \text{done}_{t+n-1}`
|
|
|
Shapes: |
|
:math:`N`: time step |
|
:math:`B`: batch size |
|
:math:`O`: observation dimension |
|
:math:`A`: action dimension |
|
|
|
- obss: :math:`[N+1, B, O]`, where obss[0] are the real observations |
|
- actions: :math:`[N+1, B, A]` |
|
- rewards: :math:`[N, B]` |
|
- aug_rewards: :math:`[N+1, B]` |
|
- dones: :math:`[N, B]` |
|
|
|
.. note:: |
|
- The rollout length is determined by rollout length scheduler. |
|
|
|
            - The input and output shapes of actor_fn are similar to those of WorldModel.step().
|
""" |
|
horizon = self.rollout_length_scheduler(envstep) |
|
        if isinstance(self, nn.Module):

            # do not accumulate gradients in the world model parameters during the
            # imagined rollout; only the policy should be optimized through BPTT

            self.requires_grad_(False)
|
obss = [obs] |
|
actions = [] |
|
rewards = [] |
|
aug_rewards = [] |
|
dones = [] |
|
for _ in range(horizon): |
|
action, aug_reward = actor_fn(obs) |
|
|
|
reward, obs, done = self.step(obs, action, **kwargs) |
|
reward = reward + aug_reward |
|
obss.append(obs) |
|
actions.append(action) |
|
rewards.append(reward) |
|
aug_rewards.append(aug_reward) |
|
dones.append(done) |
|
action, aug_reward = actor_fn(obs) |
|
actions.append(action) |
|
aug_rewards.append(aug_reward) |
|
if isinstance(self, nn.Module): |
|
self.requires_grad_(True) |
|
return ( |
|
torch.stack(obss), |
|
            torch.stack(actions),

            # handle an empty rollout (horizon 0), where the reward and done lists cannot be stacked
            torch.stack(rewards) if rewards else torch.tensor(rewards, device=obs.device),
|
torch.stack(aug_rewards), |
|
torch.stack(dones) if dones else torch.tensor(dones, device=obs.device) |
|
) |
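
    # A hedged sketch of an ``actor_fn`` for SAC-style stochastic value gradients: it returns the
    # sampled action together with the entropy bonus ``-alpha * log_prob`` as the augmented reward
    # (``policy_net`` and ``alpha`` are assumed to exist in the caller):
    #
    #   def actor_fn(obs: Tensor) -> Tuple[Tensor, Tensor]:
    #       action, log_prob = policy_net(obs)    # action: [B, A], log_prob: [B]
    #       return action, -alpha * log_prob
    #
    #   obss, actions, rewards, aug_rewards, dones = world_model.rollout(obs, actor_fn, envstep)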
|
|
|
|
|
class HybridWorldModel(DynaWorldModel, DreamWorldModel, ABC): |
|
r""" |
|
Overview: |
|
The hybrid model that combines reused and on-the-fly rollouts. |
|
|
|
Interfaces: |
|
rollout, sample, fill_img_buffer, should_train, should_eval, train, eval, step |
|
""" |
|
|
|
def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'): |
|
DynaWorldModel.__init__(self, cfg, env, tb_logger) |
|
DreamWorldModel.__init__(self, cfg, env, tb_logger) |
|
|