from typing import Tuple, Callable
from collections import namedtuple
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn
from easydict import EasyDict

from ding.worker import IBuffer
from ding.envs import BaseEnv
from ding.utils import deep_merge_dicts
from ding.world_model.utils import get_rollout_length_scheduler

from ding.utils import import_module, WORLD_MODEL_REGISTRY


def get_world_model_cls(cfg):
    import_module(cfg.get('import_names', []))
    return WORLD_MODEL_REGISTRY.get(cfg.type)


def create_world_model(cfg, *args, **kwargs):
    import_module(cfg.get('import_names', []))
    return WORLD_MODEL_REGISTRY.build(cfg.type, cfg, *args, **kwargs)
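
# Illustrative usage sketch (not part of the original module): building a world model
# from a config. The ``type`` value and the ``env`` / ``tb_logger`` objects below are
# assumptions for illustration; the chosen type must be registered in
# WORLD_MODEL_REGISTRY, which ``import_names`` can trigger by importing its module.
#
#   cfg = EasyDict(dict(type='mbpo', import_names=['ding.world_model.mbpo'], ...))  # plus model-specific fields
#   world_model = create_world_model(cfg, env, tb_logger)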


class WorldModel(ABC):
    r"""
    Overview:
        Abstract base class for world models.

    Interfaces:
        should_train, should_eval, train, eval, step
    """

    config = dict(
        train_freq=250,  # w.r.t environment step
        eval_freq=250,  # w.r.t environment step
        cuda=True,
        rollout_length_scheduler=dict(
            type='linear',
            rollout_start_step=20000,
            rollout_end_step=150000,
            rollout_length_min=1,
            rollout_length_max=25,
        )
    )
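
    # Note (interpretation of the default scheduler above, assuming the linear scheduler in
    # ``ding.world_model.utils``): imagined rollouts have length 1 until 20k real environment
    # steps, then the length grows linearly with envstep up to 25 at 150k steps.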

    def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'):  # noqa
        self.cfg = cfg
        self.env = env
        self.tb_logger = tb_logger

        self._cuda = cfg.cuda
        self.train_freq = cfg.train_freq
        self.eval_freq = cfg.eval_freq
        self.rollout_length_scheduler = get_rollout_length_scheduler(cfg.rollout_length_scheduler)

        self.last_train_step = 0
        self.last_eval_step = 0

    @classmethod
    def default_config(cls: type) -> EasyDict:
        # cannot call default_config() recursively
        # because ``config`` would be overwritten by subclasses
        merge_cfg = EasyDict(cfg_type=cls.__name__ + 'Dict')
        while cls != ABC:
            merge_cfg = deep_merge_dicts(merge_cfg, cls.config)
            cls = cls.__base__
        return merge_cfg
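
    # Illustrative sketch: ``default_config`` walks up the ``__base__`` chain and merges every
    # class-level ``config``, so a subclass automatically inherits the base defaults, e.g.
    #
    #   cfg = DynaWorldModel.default_config()
    #   # -> contains WorldModel.config keys (train_freq, eval_freq, cuda,
    #   #    rollout_length_scheduler) as well as DynaWorldModel.config's ``other`` block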

    def should_train(self, envstep: int):
        r"""
        Overview:
            Check whether the world model should be trained at the current environment step.
        """
        return (envstep - self.last_train_step) >= self.train_freq

    def should_eval(self, envstep: int):
        r"""
        Overview:
            Check whether the world model should be evaluated at the current environment step.
            Evaluation is skipped until the model has been trained at least once.
        """
        return (envstep - self.last_eval_step) >= self.eval_freq and self.last_train_step != 0

    @abstractmethod
    def train(self, env_buffer: IBuffer, envstep: int, train_iter: int):
        r"""
        Overview:
            Train world model using data from env_buffer.

        Arguments:
            - env_buffer (:obj:`IBuffer`): the buffer which collects real environment steps
            - envstep (:obj:`int`): the current number of environment steps in real environment
            - train_iter (:obj:`int`): the current number of policy training iterations
        """
        raise NotImplementedError

    @abstractmethod
    def eval(self, env_buffer: IBuffer, envstep: int, train_iter: int):
        r"""
        Overview:
            Evaluate world model using data from env_buffer.

        Arguments:
            - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
            - envstep (:obj:`int`): the current number of environment steps in real environment
            - train_iter (:obj:`int`): the current number of policy training iterations
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, obs: Tensor, action: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Overview:
            Take one step in world model.

        Arguments:
            - obs (:obj:`torch.Tensor`): current observations :math:`S_t`
            - action (:obj:`torch.Tensor`): current actions :math:`A_t`

        Returns:
            - reward (:obj:`torch.Tensor`): rewards :math:`R_t`
            - next_obs (:obj:`torch.Tensor`): next observations :math:`S_{t+1}`
            - done (:obj:`torch.Tensor`): whether the episode terminates

        Shapes:
            :math:`B`: batch size
            :math:`O`: observation dimension
            :math:`A`: action dimension

            - obs:      [B, O]
            - action:   [B, A]
            - reward:   [B, ]
            - next_obs: [B, O]
            - done:     [B, ]
        """
        raise NotImplementedError
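
# Minimal subclass sketch (illustrative, not part of the original module): a concrete world
# model only needs to implement ``train``, ``eval`` and ``step``; ``should_train`` /
# ``should_eval`` rely on the subclass updating ``last_train_step`` / ``last_eval_step``.
# The registry name and the linear dynamics model below are placeholder assumptions.
#
#   @WORLD_MODEL_REGISTRY.register('toy')
#   class ToyWorldModel(WorldModel, nn.Module):
#
#       def __init__(self, cfg, env, tb_logger):
#           WorldModel.__init__(self, cfg, env, tb_logger)
#           nn.Module.__init__(self)
#           self.net = nn.Linear(cfg.obs_dim + cfg.action_dim, cfg.obs_dim + 2)
#
#       def train(self, env_buffer, envstep, train_iter):
#           # fit self.net on transitions sampled from env_buffer, then record the step
#           self.last_train_step = envstep
#
#       def eval(self, env_buffer, envstep, train_iter):
#           # log validation metrics to self.tb_logger, then record the step
#           self.last_eval_step = envstep
#
#       def step(self, obs, action):
#           out = self.net(torch.cat([obs, action], dim=-1))
#           reward, next_obs, done_logit = out[:, 0], out[:, 1:-1], out[:, -1]
#           return reward, next_obs, done_logit > 0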


class DynaWorldModel(WorldModel, ABC):
    r"""
    Overview:
        Dyna-style world model (summarized in arXiv 1907.02057) which stores and\
        reuses imagination rollouts in an imagination buffer.

    Interfaces:
        sample, fill_img_buffer, should_train, should_eval, train, eval, step
    """

    config = dict(
        other=dict(
            real_ratio=0.05,
            rollout_retain=4,
            rollout_batch_size=100000,
            imagination_buffer=dict(
                type='elastic',
                replay_buffer_size=6000000,
                deepcopy=False,
                enable_track_used_data=False,
                # set_buffer_size=set_buffer_size,
                periodic_thruput_seconds=60,
            ),
        )
    )

    def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'):  # noqa
        super().__init__(cfg, env, tb_logger)
        self.real_ratio = cfg.other.real_ratio
        self.rollout_batch_size = cfg.other.rollout_batch_size
        self.rollout_retain = cfg.other.rollout_retain
        self.buffer_size_scheduler = \
            lambda x: self.rollout_length_scheduler(x) * self.rollout_batch_size * self.rollout_retain

    def sample(self, env_buffer: IBuffer, img_buffer: IBuffer, batch_size: int, train_iter: int) -> list:
        r"""
        Overview:
            Sample from the combination of the environment buffer and the imagination buffer with\
            a fixed real-data ratio to generate batched data for policy training.

        Arguments:
            - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
            - img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps
            - batch_size (:obj:`int`): the batch size for policy training
            - train_iter (:obj:`int`): the current number of policy training iterations

        Returns:
            - data (:obj:`list`): the training data for policy training
        """
        env_batch_size = int(batch_size * self.real_ratio)
        img_batch_size = batch_size - env_batch_size
        env_data = env_buffer.sample(env_batch_size, train_iter)
        img_data = img_buffer.sample(img_batch_size, train_iter)
        train_data = env_data + img_data
        return train_data
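    # Example: with ``batch_size=256`` and ``real_ratio=0.05``, ``int(256 * 0.05) = 12``
    # transitions come from env_buffer and the remaining 244 from img_buffer.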

    def fill_img_buffer(
        self, policy: namedtuple, env_buffer: IBuffer, img_buffer: IBuffer, envstep: int, train_iter: int
    ):
        r"""
        Overview:
            Sample rollout start states from the env_buffer, generate imagined transitions by rolling out\
            the world model with the collect-mode policy, and push them into the img_buffer.

        Arguments:
            - policy (:obj:`namedtuple`): policy in collect mode
            - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
            - img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps
            - envstep (:obj:`int`): the current number of environment steps in real environment
            - train_iter (:obj:`int`): the current number of policy training iterations
        """
        from ding.torch_utils import to_tensor
        from ding.envs import BaseEnvTimestep
        from ding.worker.collector.base_serial_collector import to_tensor_transitions

        def step(obs, act):
            # This function has the same input and output format as env manager's step
            data_id = list(obs.keys())
            obs = torch.stack([obs[id] for id in data_id], dim=0)
            act = torch.stack([act[id] for id in data_id], dim=0)
            with torch.no_grad():
                rewards, next_obs, terminals = self.step(obs, act)
            # terminals = self.termination_fn(next_obs)
            timesteps = {
                id: BaseEnvTimestep(n, r, d, {})
                for id, n, r, d in zip(
                    data_id,
                    next_obs.cpu().numpy(),
                    rewards.unsqueeze(-1).cpu().numpy(),  # ding api
                    terminals.cpu().numpy()
                )
            }
            return timesteps

        # set rollout length
        rollout_length = self.rollout_length_scheduler(envstep)
        # load data
        data = env_buffer.sample(self.rollout_batch_size, train_iter, replace=True)
        obs = {id: data[id]['obs'] for id in range(len(data))}
        # rollout
        buffer = [[] for id in range(len(obs))]
        new_data = []
        for i in range(rollout_length):
            # get action
            obs = to_tensor(obs, dtype=torch.float32)
            policy_output = policy.forward(obs)
            actions = {id: output['action'] for id, output in policy_output.items()}
            # predict next obs and reward
            # timesteps = self.step(obs, actions, env_model)
            timesteps = step(obs, actions)
            obs_new = {}
            for id, timestep in timesteps.items():
                transition = policy.process_transition(obs[id], policy_output[id], timestep)
                transition['collect_iter'] = train_iter
                buffer[id].append(transition)
                if not timestep.done:
                    obs_new[id] = timestep.obs
                if timestep.done or i + 1 == rollout_length:
                    transitions = to_tensor_transitions(buffer[id])
                    train_sample = policy.get_train_sample(transitions)
                    new_data.extend(train_sample)
            if len(obs_new) == 0:
                break
            obs = obs_new

        img_buffer.push(new_data, cur_collector_envstep=envstep)
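
# Illustrative Dyna-style training-loop sketch (the ``policy``, ``learner`` and buffer objects
# are assumptions, not part of this module): the world model is periodically retrained on real
# data, the imagination buffer is refilled by model rollouts, and the policy is updated on the
# mixed batch returned by ``sample``.
#
#   if world_model.should_train(envstep):
#       world_model.train(env_buffer, envstep, learner.train_iter)
#       world_model.fill_img_buffer(policy.collect_mode, env_buffer, img_buffer, envstep, learner.train_iter)
#   if world_model.should_eval(envstep):
#       world_model.eval(env_buffer, envstep, learner.train_iter)
#   train_data = world_model.sample(env_buffer, img_buffer, batch_size, learner.train_iter)
#   learner.train(train_data, envstep)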


class DreamWorldModel(WorldModel, ABC):
    r"""
    Overview:
        Dreamer-style world model which uses each imagination rollout only once\
        and backpropagates through time (over the rollout) to optimize the policy.

    Interfaces:
        rollout, should_train, should_eval, train, eval, step
    """

    def rollout(self, obs: Tensor, actor_fn: Callable[[Tensor], Tuple[Tensor, Tensor]], envstep: int,
                **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""
        Overview:
            Generate batched imagination rollouts starting from the current observations.\
            This function is useful for value gradients where the policy is optimized by BPTT.

        Arguments:
            - obs (:obj:`Tensor`): the current observations :math:`S_t`
            - actor_fn (:obj:`Callable`): the unified API :math:`(A_t, H_t) = \pi(S_t)`
            - envstep (:obj:`int`): the current number of environment steps in real environment

        Returns:
            - obss (:obj:`Tensor`):        :math:`S_t, ..., S_{t+n}`
            - actions (:obj:`Tensor`):     :math:`A_t, ..., A_{t+n}`
            - rewards (:obj:`Tensor`):     :math:`R_t, ..., R_{t+n-1}`
            - aug_rewards (:obj:`Tensor`): :math:`H_t, ..., H_{t+n}`, this can be the entropy bonus as in SAC,
                                                otherwise it should be a zero tensor
            - dones (:obj:`Tensor`):       :math:`\text{done}_t, ..., \text{done}_{t+n-1}`

        Shapes:
            :math:`N`: time step
            :math:`B`: batch size
            :math:`O`: observation dimension
            :math:`A`: action dimension

            - obss:        :math:`[N+1, B, O]`, where obss[0] are the real observations
            - actions:     :math:`[N+1, B, A]`
            - rewards:     :math:`[N,   B]`
            - aug_rewards: :math:`[N+1, B]`
            - dones:       :math:`[N,   B]`

        .. note::
            - The rollout length is determined by the rollout length scheduler.

            - The input and output shapes of actor_fn are similar to those of WorldModel.step().
        """
        horizon = self.rollout_length_scheduler(envstep)
        if isinstance(self, nn.Module):
            # Rollouts should propagate gradients only to policy,
            # so make sure that the world model is not updated by rollout.
            self.requires_grad_(False)
        obss = [obs]
        actions = []
        rewards = []
        aug_rewards = []  # -temperature*logprob
        dones = []
        for _ in range(horizon):
            action, aug_reward = actor_fn(obs)
            # done: probability of termination
            reward, obs, done = self.step(obs, action, **kwargs)
            reward = reward + aug_reward
            obss.append(obs)
            actions.append(action)
            rewards.append(reward)
            aug_rewards.append(aug_reward)
            dones.append(done)
        action, aug_reward = actor_fn(obs)
        actions.append(action)
        aug_rewards.append(aug_reward)
        if isinstance(self, nn.Module):
            self.requires_grad_(True)
        return (
            torch.stack(obss),
            torch.stack(actions),
            # rewards is an empty list when horizon=0
            torch.stack(rewards) if rewards else torch.tensor(rewards, device=obs.device),
            torch.stack(aug_rewards),
            torch.stack(dones) if dones else torch.tensor(dones, device=obs.device)
        )
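
# Illustrative sketch of optimizing a policy through ``rollout`` (``policy_net`` and ``alpha``
# are assumptions): gradients flow from the imagined rewards back into the actor via
# ``actor_fn``, while the world model itself is frozen inside ``rollout``.
#
#   def actor_fn(obs):
#       action, logprob = policy_net(obs)          # reparameterized action and its log-prob
#       return action, -alpha * logprob            # entropy bonus as the augmented reward
#
#   obss, actions, rewards, aug_rewards, dones = world_model.rollout(obs, actor_fn, envstep)
#   actor_loss = -rewards.sum(0).mean()            # rewards already include the per-step bonus
#   actor_loss.backward()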


class HybridWorldModel(DynaWorldModel, DreamWorldModel, ABC):
    r"""
    Overview:
        The hybrid world model that combines Dyna-style reused rollouts with Dreamer-style on-the-fly rollouts.

    Interfaces:
        rollout, sample, fill_img_buffer, should_train, should_eval, train, eval, step
    """

    def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'):  # noqa
        DynaWorldModel.__init__(self, cfg, env, tb_logger)
        DreamWorldModel.__init__(self, cfg, env, tb_logger)