import os
import itertools
import random
import uuid
from ditk import logging
import hickle
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
from collections import Counter
from collections import defaultdict, deque, OrderedDict
from ding.data.buffer import Buffer, apply_middleware, BufferedData
from ding.utils import fastcopy
from ding.torch_utils import get_null_data


class BufferIndex:
    """
    Overview:
        Maintain the mapping from index strings to their offsets in the underlying deque,
        so that a record can be located by its index.
    """

    def __init__(self, maxlen: int, *args, **kwargs):
        self.maxlen = maxlen
        self.__map = OrderedDict(*args, **kwargs)
        self._last_key = next(reversed(self.__map)) if len(self) > 0 else None
        self._cumlen = len(self.__map)

    def get(self, key: str) -> int:
        # Map the monotonically increasing insertion counter to the current deque offset:
        # once the buffer is full, the oldest (self._cumlen - self.maxlen) records have been
        # evicted, so the stored counter is shifted back by that amount.
        value = self.__map[key]
        value = value % self._cumlen + min(0, (self.maxlen - self._cumlen))
        return value

    def __len__(self) -> int:
        return len(self.__map)

    def has(self, key: str) -> bool:
        return key in self.__map

    def append(self, key: str):
        self.__map[key] = self.__map[self._last_key] + 1 if self._last_key else 0
        self._last_key = key
        self._cumlen += 1
        if len(self) > self.maxlen:
            self.__map.popitem(last=False)

    def clear(self):
        self.__map = OrderedDict()
        self._last_key = None
        self._cumlen = 0
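
# A minimal sketch of how ``BufferIndex`` maps keys to deque offsets once the
# buffer wraps around (the short keys below are illustrative; real keys are
# uuid hex strings generated by the buffer):
#
#   index = BufferIndex(maxlen=2)
#   for key in ["a", "b", "c"]:
#       index.append(key)
#   assert not index.has("a")  # "a" has been evicted
#   assert index.get("b") == 0 and index.get("c") == 1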


class DequeBuffer(Buffer):
    """
    Overview:
        A buffer implementation based on the deque structure.
    """

    def __init__(self, size: int, sliced: bool = False) -> None:
        """
        Overview:
            The initialization method of DequeBuffer.
        Arguments:
            - size (:obj:`int`): The maximum number of objects that the buffer can hold.
            - sliced (:obj:`bool`): Whether to slice data into pieces of length ``unroll_len``\
                when sampling by group. Default to ``False``.
        """
        super().__init__(size=size)
        self.storage = deque(maxlen=size)
        self.indices = BufferIndex(maxlen=size)
        self.sliced = sliced
        # Meta index is a dict whose values are deques aligned with the storage, one deque per meta key
        self.meta_index = {}

    @apply_middleware("push")
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            The method that input the objects and the related meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The input object which can be in any format.
            - meta (:obj:`Optional[dict]`): A dict that helps describe data, such as\
                category, label, priority, etc. Default to ``None``.
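        Returns:
            - buffered_data (:obj:`BufferedData`): The pushed data wrapped together with its\
                generated index and meta.
        Examples:
            A minimal usage sketch (the ``"env"`` meta key below is illustrative, not required):

            >>> buffer = DequeBuffer(size=10)
            >>> buffered = buffer.push("any_data", meta={"env": "env_0"})
            >>> buffer.count()
            1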
        """
        return self._push(data, meta)

    @apply_middleware("sample")
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            The method that randomly sample data from the buffer or retrieve certain data by indices.
        Arguments:
            - size (:obj:`Optional[int]`): The number of objects to be obtained from the buffer.
                If ``indices`` is not specified, the ``size`` is required to randomly sample the\
                corresponding number of objects from the buffer.
            - indices (:obj:`Optional[List[str]]`): Only used when you want to retrieve data by indices.
                Default to ``None``.
            - replace (:obj:`bool`): As the sampling process is carried out one by one, this parameter\
                determines whether the previous samples will be put back into the buffer for subsequent\
                sampling. Default to ``False``, it means that duplicate samples will not appear in one\
                ``sample`` call.
            - sample_range (:obj:`Optional[slice]`): The indices range to sample data. Default to ``None``,\
                it means no restrictions on the range of indices for the sampling process.
            - ignore_insufficient (:obj:`bool`): Whether to ignore the sample operation (logging a\
                warning) instead of raising ``ValueError`` when the sampled size is smaller than the\
                required size. Default to ``False``.
            - groupby (:obj:`Optional[str]`): If set, the method returns ``size`` groups of records,\
                where records in the same group share the same value of this meta key. Default to ``None``.
            - unroll_len (:obj:`Optional[int]`): The unroll length of a trajectory, used only when the\
                ``groupby`` is activated.
        Returns:
            - sampled_data (Union[List[BufferedData], List[List[BufferedData]]]): The sampling result.
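        Examples:
            A minimal sketch (the ``"env"`` meta key and the sizes are assumptions for illustration):

            >>> buffer = DequeBuffer(size=10)
            >>> for i in range(4):
            ...     _ = buffer.push(i, meta={"env": "env_{}".format(i % 2)})
            >>> records = buffer.sample(2)                  # 2 random records
            >>> groups = buffer.sample(2, groupby="env")    # 2 groups sharing the same "env"
            >>> len(records), len(groups)
            (2, 2)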
        """
        storage = self.storage
        if sample_range:
            storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))

        # Size and indices
        assert size or indices, "One of size and indices must not be empty."
        if (size and indices) and (size != len(indices)):
            raise AssertionError("Size and indices length must be equal.")
        if not size:
            size = len(indices)
        # Indices and groupby
        assert not (indices and groupby), "Cannot use groupby and indices at the same time."
        # Groupby and unroll_len
        assert not unroll_len or (
            unroll_len and groupby
        ), "Parameter unroll_len needs to be used in conjunction with groupby."

        value_error = None
        sampled_data = []
        if indices:
            indices_set = set(indices)
            hashed_data = filter(lambda item: item.index in indices_set, storage)
            hashed_data = map(lambda item: (item.index, item), hashed_data)
            hashed_data = dict(hashed_data)
            # Return the records in the order of the requested indices
            sampled_data = [hashed_data[index] for index in indices]
        elif groupby:
            sampled_data = self._sample_by_group(
                size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced
            )
        else:
            if replace:
                sampled_data = random.choices(storage, k=size)
            else:
                try:
                    sampled_data = random.sample(storage, k=size)
                except ValueError as e:
                    value_error = e

        if value_error or len(sampled_data) != size:
            if ignore_insufficient:
                logging.warning(
                    "Sample operation is ignored due to insufficient data: buffer count is {} "
                    "while the requested sample size is {}".format(self.count(), size)
                )
            else:
                raise ValueError("There are fewer than {} records/groups in the buffer ({})".format(size, self.count()))

        sampled_data = self._independence(sampled_data)

        return sampled_data

    @apply_middleware("update")
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            The method that updates the data and the related meta information at a certain index.
        Arguments:
            - index (:obj:`str`): The index of the data to be updated.
            - data (:obj:`Optional[Any]`): The data which is supposed to replace the old one. If you set it\
                to ``None``, nothing will happen to the old record.
            - meta (:obj:`Optional[dict]`): The new dict which is supposed to merge with the old one.
        Returns:
            - success (:obj:`bool`): Whether the update succeeded, i.e. whether ``index`` exists in the buffer.
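        Examples:
            A minimal sketch (the ``"priority"`` meta key is illustrative, not required):

            >>> buffer = DequeBuffer(size=10)
            >>> buffered = buffer.push("old_data")
            >>> buffer.update(buffered.index, data="new_data", meta={"priority": 1.0})
            True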
        """
        if not self.indices.has(index):
            return False
        i = self.indices.get(index)
        item = self.storage[i]
        if data is not None:
            item.data = data
        if meta is not None:
            item.meta = meta
            for key in self.meta_index:
                self.meta_index[key][i] = meta[key] if key in meta else None
        return True

    @apply_middleware("delete")
    def delete(self, indices: Union[str, Iterable[str]]) -> None:
        """
        Overview:
            The method that deletes the data and related meta information by specific indices.
        Arguments:
            - indices (Union[str, Iterable[str]]): The indices of the data to be deleted from the buffer.
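        Examples:
            A minimal sketch:

            >>> buffer = DequeBuffer(size=10)
            >>> buffered = buffer.push("any_data")
            >>> buffer.delete(buffered.index)
            >>> buffer.count()
            0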
        """
        if isinstance(indices, str):
            indices = [indices]
        del_idx = []
        for index in indices:
            if self.indices.has(index):
                del_idx.append(self.indices.get(index))
        if len(del_idx) == 0:
            return
        del_idx = sorted(del_idx, reverse=True)
        for idx in del_idx:
            del self.storage[idx]
        remain_indices = [item.index for item in self.storage]
        # Rebuild the index mapping from the remaining records; the enumeration must cover
        # all remaining records, not only as many entries as there were deleted indices.
        key_value_pairs = zip(remain_indices, range(len(remain_indices)))
        self.indices = BufferIndex(self.storage.maxlen, key_value_pairs)

    def save_data(self, file_name: str):
        """
        Overview:
            The method that saves all data, indices and meta index in the buffer to a file via ``hickle``.
        Arguments:
            - file_name (:obj:`str`): The path of the target file. Its parent directory is\
                created if it does not exist.
        """
        # Create the parent directory of the target file if necessary.
        dir_name = os.path.dirname(file_name)
        if dir_name != "" and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        hickle.dump(
            py_obj=(
                self.storage,
                self.indices,
                self.meta_index,
            ), file_obj=file_name
        )

    def load_data(self, file_name: str):
        """
        Overview:
            The method that restores the data, indices and meta index from a file saved by ``save_data``.
        Arguments:
            - file_name (:obj:`str`): The path of the file to load from.
        """
        self.storage, self.indices, self.meta_index = hickle.load(file_name)
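
    # A save/load round-trip sketch (the file path is an assumption for illustration):
    #
    #   buffer = DequeBuffer(size=10)
    #   buffer.push("any_data")
    #   buffer.save_data("./data/buffer.hkl")
    #   new_buffer = DequeBuffer(size=10)
    #   new_buffer.load_data("./data/buffer.hkl")
    #   assert new_buffer.count() == 1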

    def count(self) -> int:
        """
        Overview:
            The method that returns the current length of the buffer.
        """
        return len(self.storage)

    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            The method that returns the BufferedData object given a specific index.
        """
        return self.storage[idx]

    @apply_middleware("clear")
    def clear(self) -> None:
        """
        Overview:
            The method that clear all data, indices, and the meta information in the buffer.
        """
        self.storage.clear()
        self.indices.clear()
        self.meta_index = {}

    def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        index = uuid.uuid1().hex
        if meta is None:
            meta = {}
        buffered = BufferedData(data=data, index=index, meta=meta)
        self.storage.append(buffered)
        self.indices.append(index)
        # Add meta index
        for key in self.meta_index:
            self.meta_index[key].append(meta[key] if key in meta else None)

        return buffered

    def _independence(
        self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Make sure that records sampled more than once are distinct objects, by copying
            duplicates. Note that this is different from ``clone_object``: records that appear
            only once are not copied, so modifying such a record still changes the data in the buffer.
        Arguments:
            - buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`): Sampled data,\
                which can be nested if ``groupby`` has been set.
        """
        if len(buffered_samples) == 0:
            return buffered_samples
        occurred = defaultdict(int)

        for i, buffered in enumerate(buffered_samples):
            if isinstance(buffered, list):
                sampled_list = buffered
                # Loop over nested samples
                for j, buffered in enumerate(sampled_list):
                    occurred[buffered.index] += 1
                    if occurred[buffered.index] > 1:
                        sampled_list[j] = fastcopy.copy(buffered)
            elif isinstance(buffered, BufferedData):
                occurred[buffered.index] += 1
                if occurred[buffered.index] > 1:
                    buffered_samples[i] = fastcopy.copy(buffered)
            else:
                raise Exception("Get unexpected buffered type {}".format(type(buffered)))
        return buffered_samples

    def _sample_by_group(
            self,
            size: int,
            groupby: str,
            replace: bool = False,
            unroll_len: Optional[int] = None,
            storage: deque = None,
            sliced: bool = False
    ) -> List[List[BufferedData]]:
        """
        Overview:
            Sample by ``group`` instead of by individual records. The result is a list of ``size``
            groups, where each group is a list of records; different groups may have different lengths.
        """
        if storage is None:
            storage = self.storage
        if groupby not in self.meta_index:
            self._create_index(groupby)

        def filter_by_unroll_len():
            "Filter groups by unroll len, ensure count of items in each group is greater than unroll_len."
            group_count = Counter(self.meta_index[groupby])
            group_names = []
            for key, count in group_count.items():
                if count >= unroll_len:
                    group_names.append(key)
            return group_names

        if unroll_len and unroll_len > 1:
            group_names = filter_by_unroll_len()
            if len(group_names) == 0:
                return []
        else:
            group_names = list(set(self.meta_index[groupby]))

        sampled_groups = []
        if replace:
            sampled_groups = random.choices(group_names, k=size)
        else:
            try:
                sampled_groups = random.sample(group_names, k=size)
            except ValueError:
                raise ValueError("There are fewer than {} groups in the buffer ({} groups)".format(size, len(group_names)))

        # Build dict like {"group name": [records]}
        sampled_data = defaultdict(list)
        for buffered in storage:
            meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
            if meta_value in sampled_groups:
                sampled_data[buffered.meta[groupby]].append(buffered)

        final_sampled_data = []
        for group in sampled_groups:
            seq_data = sampled_data[group]
            # Filter records by unroll_len
            if unroll_len:
                # Slice the trajectory into pieces of length unroll_len. Without slicing,
                # overlapping samples are more likely to contain duplicate data, which can
                # easily destabilize training.
                if sliced:
                    # Pick a random record, align it to the start of its unroll_len-sized chunk,
                    # and fall back to the last unroll_len records for the final (possibly partial) chunk.
                    chunk_index = random.choice(range(max(1, len(seq_data)))) // unroll_len
                    if chunk_index == (len(seq_data) - 1) // unroll_len:
                        seq_data = seq_data[-unroll_len:]
                    else:
                        seq_data = seq_data[chunk_index * unroll_len:(chunk_index + 1) * unroll_len]
                else:
                    # Pick a random window of unroll_len consecutive records.
                    start_index = random.choice(range(max(1, len(seq_data) - unroll_len)))
                    seq_data = seq_data[start_index:start_index + unroll_len]

            final_sampled_data.append(seq_data)

        return final_sampled_data

    def _create_index(self, meta_key: str):
        self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
        for data in self.storage:
            self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)

    def __iter__(self) -> Iterator:
        return iter(self.storage)

    def __copy__(self) -> "DequeBuffer":
        buffer = type(self)(size=self.storage.maxlen)
        buffer.storage = self.storage
        buffer.meta_index = self.meta_index
        buffer.indices = self.indices
        return buffer
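

# A minimal end-to-end sketch, kept under ``__main__`` for illustration only; the
# ``"env"`` meta key, the sizes and the unroll length are assumptions, not part of the API.
if __name__ == "__main__":
    buffer = DequeBuffer(size=8)
    for step in range(8):
        buffer.push({"obs": step}, meta={"env": "env_{}".format(step % 2)})
    # Sample 2 trajectories of 3 consecutive records each, grouped by the "env" meta key.
    trajectories = buffer.sample(2, groupby="env", unroll_len=3)
    assert len(trajectories) == 2
    assert all(len(traj) == 3 for traj in trajectories)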