import os
from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional
from easydict import EasyDict
from ditk import logging
import torch
from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
from ding.utils import get_rank

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None) -> Callable:
    """
    Overview:
        Push episodes or trajectories into the buffer.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): Buffer to push the data into.
        - group_by_env (:obj:`Optional[bool]`): Whether to group the pushed transitions by their env id, \
            which is stored in the buffer meta under the key ``'env'``.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _push(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, either `ctx.trajectories` or `ctx.episodes` should not be None.
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): Trajectories.
            - episodes (:obj:`List[Dict]`): Episodes.
        """

        if ctx.trajectories is not None:  # each data in buffer is a transition
            if group_by_env:
                for i, t in enumerate(ctx.trajectories):
                    buffer_.push(t, {'env': t.env_data_id.item()})
            else:
                for t in ctx.trajectories:
                    buffer_.push(t)
            ctx.trajectories = None
        elif ctx.episodes is not None:  # each data in buffer is an episode
            for t in ctx.episodes:
                buffer_.push(t)
            ctx.episodes = None
        else:
            raise RuntimeError("Either ctx.trajectories or ctx.episodes should not be None.")

    return _push
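
# Usage sketch (illustrative): wiring ``data_pusher`` into a collect-train pipeline. The surrounding
# middleware names and the ``DequeBuffer`` construction are assumptions based on typical DI-engine
# entry scripts, not part of this module.
#
#   buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
#   with task.start(ctx=OnlineRLContext()):
#       task.use(StepCollector(cfg, policy.collect_mode, collector_env))
#       task.use(data_pusher(cfg, buffer_))
#       task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
#       task.run()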


def buffer_saver(cfg: EasyDict, buffer_: Buffer, every_envstep: int = 1000, replace: bool = False) -> Callable:
    """
    Overview:
        Save the current buffer data to disk periodically.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): Buffer whose data will be saved.
        - every_envstep (:obj:`int`): Save the buffer every this many env steps.
        - replace (:obj:`bool`): Whether to overwrite the previously saved file instead of writing a new one.
    """

    buffer_saver_env_counter = -every_envstep

    def _save(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, `ctx.env_step` should not be None.
        Input of ctx:
            - env_step (:obj:`int`): env step.
        """
        nonlocal buffer_saver_env_counter
        if ctx.env_step is not None:
            if ctx.env_step >= every_envstep + buffer_saver_env_counter:
                buffer_saver_env_counter = ctx.env_step
                if replace:
                    buffer_.save_data(os.path.join(cfg.exp_name, "replaybuffer", "data_latest.hkl"))
                else:
                    buffer_.save_data(
                        os.path.join(cfg.exp_name, "replaybuffer", "data_envstep_{}.hkl".format(ctx.env_step))
                    )
        else:
            raise RuntimeError("buffer_saver only supports collecting data by step rather than episode.")

    return _save
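
# Usage sketch (illustrative): placed after the pusher so that the buffer contents are dumped to
# ``<exp_name>/replaybuffer/`` every 1000 env steps. The surrounding pipeline calls are assumptions.
#
#   task.use(data_pusher(cfg, buffer_))
#   task.use(buffer_saver(cfg, buffer_, every_envstep=1000, replace=False))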


def offpolicy_data_fetcher(
        cfg: EasyDict,
        buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
        data_shortage_warning: bool = False,
) -> Callable:
    """
    Overview:
        The return function is a generator which meanly fetch a batch of data from a buffer, \
        a list of buffers, or a dict of buffers.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
        - buffer (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): \
            The buffer where the data is fetched from. \
            ``Buffer`` type means a buffer.\
            ``List[Tuple[Buffer, float]]`` type means a list of tuple. In each tuple there is a buffer and a float. \
            The float defines, how many batch_size is the size of the data \
            which is sampled from the corresponding buffer.\
            ``Dict[str, Buffer]`` type means a dict in which the value of each element is a buffer. \
            For each key-value pair of dict, batch_size of data will be sampled from the corresponding buffer \
            and assigned to the same key of `ctx.train_data`.
        - data_shortage_warning (:obj:`bool`): Whether to output warning when data shortage occurs in fetching.
    """

    def _fetch(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_output (:obj:`Union[Dict, Deque[Dict]]`): This attribute should exist \
                if `buffer_` is of type Buffer and if `buffer_` use the middleware `PriorityExperienceReplay`. \
                The meta data `priority` of the sampled data in the `buffer_` will be updated \
                to the `priority` attribute of `ctx.train_output` if `ctx.train_output` is a dict, \
                or the `priority` attribute of `ctx.train_output`'s popped element \
                if `ctx.train_output` is a deque of dicts.
        Output of ctx:
            - train_data (:obj:`Union[List[Dict], Dict[str, List[Dict]]]`): The fetched data. \
                ``List[Dict]`` type means a list of data.
                    `train_data` is of this type if the type of `buffer_` is Buffer or List.
                ``Dict[str, List[Dict]]]`` type means a dict, in which the value of each key-value pair
                    is a list of data. `train_data` is of this type if the type of `buffer_` is Dict.
        """
        try:
            unroll_len = cfg.policy.collect.unroll_len
            if isinstance(buffer_, Buffer):
                if unroll_len > 1:
                    buffered_data = buffer_.sample(
                        cfg.policy.learn.batch_size, groupby="env", unroll_len=unroll_len, replace=True
                    )
                    ctx.train_data = [[t.data for t in d] for d in buffered_data]  # B, unroll_len
                else:
                    buffered_data = buffer_.sample(cfg.policy.learn.batch_size)
                    ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, List):  # like sqil, r2d3
                assert unroll_len == 1, "unroll_len > 1 is not supported for a list of buffers"
                buffered_data = []
                for buffer_elem, p in buffer_:
                    data_elem = buffer_elem.sample(int(cfg.policy.learn.batch_size * p))
                    assert data_elem is not None
                    buffered_data.append(data_elem)
                buffered_data = sum(buffered_data, [])
                ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, Dict):  # like ppg_offpolicy
                assert unroll_len == 1, "unroll_len > 1 is not supported for a dict of buffers"
                buffered_data = {k: v.sample(cfg.policy.learn.batch_size) for k, v in buffer_.items()}
                ctx.train_data = {k: [d.data for d in v] for k, v in buffered_data.items()}
            else:
                raise TypeError("not support buffer argument type: {}".format(type(buffer_)))

            assert buffered_data is not None
        except (ValueError, AssertionError):
            if data_shortage_warning:
                # You can modify the data collection config to avoid this warning, e.g. by increasing n_sample or n_episode.
                # The fetcher will skip this attempt.
                logging.warning(
                    "Replay buffer's data is not enough to support training, so this training step is skipped to wait for more data."
                )
            ctx.train_data = None
            return

        yield

        if isinstance(buffer_, Buffer):
            if any([isinstance(m, PriorityExperienceReplay) for m in buffer_._middleware]):
                index = [d.index for d in buffered_data]
                meta = [d.meta for d in buffered_data]
                # such as priority
                if isinstance(ctx.train_output, List):
                    priority = ctx.train_output.pop()['priority']
                else:
                    priority = ctx.train_output['priority']
                for idx, m, p in zip(index, meta, priority):
                    m['priority'] = p
                    buffer_.update(index=idx, data=None, meta=m)

    return _fetch
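
# Usage sketch (illustrative): the three accepted forms of the ``buffer_`` argument. The concrete
# buffer objects and the 0.5/0.5 split below are assumptions for demonstration only.
#
#   # 1) single buffer: ctx.train_data becomes a list of batch_size samples
#   task.use(offpolicy_data_fetcher(cfg, buffer_))
#
#   # 2) list of (buffer, ratio) pairs, e.g. SQIL/R2D3: half agent data, half expert data per batch
#   task.use(offpolicy_data_fetcher(cfg, [(agent_buffer, 0.5), (expert_buffer, 0.5)]))
#
#   # 3) dict of buffers, e.g. off-policy PPG: ctx.train_data mirrors the dict keys
#   task.use(offpolicy_data_fetcher(cfg, {'policy': policy_buffer, 'value': value_buffer}))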


def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        Fetch offline data with a background producer thread, which repeatedly slices a batch from the \
        dataset, stacks it, moves it to the target device on a dedicated CUDA stream, and buffers the \
        result in a queue for the training loop to consume.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: \
            `cfg.policy.learn.batch_size`, `cfg.policy.cuda`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
    """

    from threading import Thread
    from queue import Queue
    import time
    # Dedicated CUDA stream so that the producer's host-to-device copies can overlap with training.
    stream = torch.cuda.Stream()

    def producer(queue, dataset, batch_size, device):
        torch.set_num_threads(4)
        nonlocal stream
        idx_iter = iter(range(len(dataset) - batch_size))

        if len(dataset) < batch_size:
            logging.warning('batch_size is larger than the dataset size, so no full batch can be fetched!')
        with torch.cuda.stream(stream):
            while True:
                if queue.full():
                    time.sleep(0.1)
                else:
                    try:
                        start_idx = next(idx_iter)
                    except StopIteration:
                        del idx_iter
                        idx_iter = iter(range(len(dataset) - batch_size))
                        start_idx = next(idx_iter)
                    data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)]
                    data = [[i[j] for i in data] for j in range(len(data[0]))]
                    data = [torch.stack(x).to(device) for x in data]
                    queue.put(data)

    queue = Queue(maxsize=50)
    device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
    producer_thread = Thread(
        target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
    )

    def _fetch(ctx: "OfflineRLContext"):
        nonlocal queue, producer_thread
        if not producer_thread.is_alive():
            time.sleep(5)
            producer_thread.start()
        while queue.empty():
            time.sleep(0.001)
        ctx.train_data = queue.get()

    return _fetch
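
# Usage sketch (illustrative): intended for offline training where host-to-device transfer dominates;
# ``create_dataset`` is a hypothetical helper returning a ``torch.utils.data.Dataset``.
#
#   dataset = create_dataset(cfg)
#   task.use(offline_data_fetcher_from_mem(cfg, dataset))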


def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
        The return function is a generator which each time fetches a batch of data from the previous `DataLoader`.\
        Please refer to the link https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \
        and https://pytorch.org/docs/stable/data.html for more details.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
    """
    # collate_fn is executed in policy now
    dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
    dataloader = iter(dataloader)

    def _fetch(ctx: "OfflineRLContext"):
        """
        Overview:
            Every time this generator is iterated, the fetched data will be assigned to ctx.train_data. \
            After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1.
        Input of ctx:
            - train_epoch (:obj:`int`): Number of `train_epoch`.
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): The fetched data batch.
        """
        nonlocal dataloader
        try:
            ctx.train_data = next(dataloader)  # noqa
        except StopIteration:
            ctx.train_epoch += 1
            del dataloader
            dataloader = DataLoader(
                dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x
            )
            dataloader = iter(dataloader)
            ctx.train_data = next(dataloader)
        # TODO apply data update (e.g. priority) in offline setting when necessary
        ctx.trained_env_step += len(ctx.train_data)

    return _fetch
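
# Usage sketch (illustrative): a minimal offline RL pipeline where this fetcher feeds a trainer
# middleware; the trainer name is an assumption.
#
#   task.use(offline_data_fetcher(cfg, dataset))
#   task.use(trainer(cfg, policy.learn_mode))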


def offline_data_saver(data_path: str, data_type: str = 'hdf5') -> Callable:
    """
    Overview:
        Save the expert data of offline RL in a directory.
    Arguments:
        - data_path (:obj:`str`): File path where the expert data will be written into, which is usually ./expert.pkl'.
        - data_type (:obj:`str`): Define the type of the saved data. \
            The type of saved data is pkl if `data_type == 'naive'`. \
            The type of saved data is hdf5 if `data_type == 'hdf5'`.
    """

    def _save(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Tensor]`): The expert data to be saved.
        """
        data = ctx.trajectories
        offline_data_save_type(data, data_path, data_type)
        ctx.trajectories = None

    return _save
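
# Usage sketch (illustrative): dumping collected trajectories as expert demonstrations; the file
# path below is an example value only.
#
#   task.use(offline_data_saver('./expert_demo.hdf5', data_type='hdf5'))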


def sqil_data_pusher(cfg: EasyDict, buffer_: Buffer, expert: bool) -> Callable:
    """
    Overview:
        Push trajectories into the buffer in sqil learning pipeline.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer (:obj:`Buffer`): Buffer to push the data in.
        - expert (:obj:`bool`): Whether the pushed data is expert data or not. \
            In each element of the pushed data, the reward will be set to 1 if this attribute is `True`, otherwise 0.
    """

    def _pusher(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): The trajectories to be pushed.
        """
        for t in ctx.trajectories:
            if expert:
                t.reward = torch.ones_like(t.reward)
            else:
                t.reward = torch.zeros_like(t.reward)
            buffer_.push(t)
        ctx.trajectories = None

    return _pusher
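
# Usage sketch (illustrative): one pusher instance per buffer in the SQIL pipeline, so that
# agent-collected transitions get reward 0 and expert transitions get reward 1. The buffer
# names and their surrounding collectors are assumptions.
#
#   task.use(sqil_data_pusher(cfg, agent_buffer, expert=False))   # after the agent collector
#   task.use(sqil_data_pusher(cfg, expert_buffer, expert=True))   # after the expert collector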