File size: 11,266 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import os
import pytest
import time
import random
import functools
import tempfile
from typing import Callable
from ding.data.buffer import DequeBuffer
from ding.data.buffer.buffer import BufferedData
from torch.utils.data import DataLoader


class RateLimit:
    r"""
    Buffer middleware that enforces a sliding-window rate limit on the push action.

    At most ``max_rate`` pushes are accepted within any ``window_seconds``-long
    window; pushes beyond the limit are silently dropped (the wrapped chain is
    not called and ``None`` is returned). All other actions pass through untouched.
    """

    def __init__(self, max_rate: float = float("inf"), window_seconds: int = 30) -> None:
        # The default is float("inf"), so the annotation is float rather than int
        # (an int is still accepted, since comparisons work on both).
        self.max_rate = max_rate
        self.window_seconds = window_seconds
        # Timestamps of pushes accepted within the current window
        self.buffered = []

    def __call__(self, action: str, chain: Callable, *args, **kwargs):
        # Only the "push" action is rate limited; everything else is forwarded.
        if action == "push":
            return self.push(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    def push(self, chain: Callable, data, *args, **kwargs) -> None:
        current = time.time()
        # Cut off records that fell out of the sliding window
        self.buffered = [t for t in self.buffered if t > current - self.window_seconds]
        if len(self.buffered) < self.max_rate:
            self.buffered.append(current)
            return chain(data, *args, **kwargs)
        else:
            # Over the limit: swallow the push
            return None


def add_10() -> Callable:
    """
    Middleware factory: returns a buffer view that adds 10 to each
    sampled item's data, leaving index and meta untouched.
    """

    def _shift(chain: Callable, size: int, replace: bool = False, *args, **kwargs):
        raw = chain(size, replace, *args, **kwargs)
        shifted = []
        for entry in raw:
            shifted.append(BufferedData(data=entry.data + 10, index=entry.index, meta=entry.meta))
        return shifted

    def _subview(action: str, chain: Callable, *args, **kwargs):
        # Only transform the sample action; forward everything else as-is.
        if action != "sample":
            return chain(*args, **kwargs)
        return _shift(chain, *args, **kwargs)

    return _subview


@pytest.mark.unittest
def test_naive_push_sample():
    """Basic push/sample/clear behaviour plus replacement and range-restricted sampling."""
    # Pushing more items than capacity evicts the oldest ones
    buffer = DequeBuffer(size=10)
    for value in range(20):
        buffer.push(value)
    assert buffer.count() == 10
    sampled_values = [item.data for item in buffer.sample(10)]
    assert 0 not in sampled_values

    # Clearing empties the buffer
    buffer.clear()
    assert buffer.count() == 0

    # Sampling with replacement can return more items than are stored
    for value in range(5):
        buffer.push(value)
    assert buffer.count() == 5
    assert len(buffer.sample(10, replace=True)) == 10

    # Sampling with a slice restricts candidates to that range
    buffer.clear()
    for value in range(10):
        buffer.push(value)
    assert len(buffer.sample(5, sample_range=slice(5, 10))) == 5
    assert 0 not in [item.data for item in buffer.sample(5, sample_range=slice(5, 10))]


@pytest.mark.unittest
def test_rate_limit_push_sample():
    """With a max_rate of 5, only the first 5 of 10 pushes are stored."""
    buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
    for i in range(10):
        buffer.push(i)
    assert buffer.count() == 5
    # Compare against the unwrapped data: `sample` returns BufferedData objects,
    # so the original `5 not in buffer.sample(5)` was vacuously true.
    assert 5 not in [item.data for item in buffer.sample(5)]


@pytest.mark.unittest
def test_load_and_save():
    """Data (and meta index) survive a save_data/load_data round trip."""
    buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
    buffer.meta_index = {"label": []}
    for i in range(10):
        buffer.push(i, meta={"label": i})
    assert buffer.count() == 5
    # Compare against the unwrapped data: `sample` returns BufferedData objects,
    # so a bare `5 not in buffer.sample(5)` would be vacuously true.
    assert 5 not in [item.data for item in buffer.sample(5)]
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_file = os.path.join(tmpdirname, "data.hkl")
        buffer.save_data(test_file)
        buffer_new = DequeBuffer(size=10).use(RateLimit(max_rate=5))
        buffer_new.load_data(test_file)
        assert buffer_new.count() == 5
        assert 5 not in [item.data for item in buffer_new.sample(5)]
        # NOTE(review): these check the ORIGINAL buffer's meta_index; if load_data
        # is meant to restore it, `buffer_new.meta_index` should be checked too — confirm.
        assert len(buffer.meta_index["label"]) == 5
        assert all([index < 5 for index in buffer.meta_index["label"]])


@pytest.mark.unittest
def test_buffer_view():
    """A view shares storage with its parent but applies its own middleware chain."""
    parent = DequeBuffer(size=10)
    parent.push(0)
    assert parent.count() == 1

    child = parent.view().use(RateLimit(max_rate=5)).use(add_10())

    for value in range(10):
        child.push(value)
    # One record written via the parent plus five rate-limited records via the view
    assert len(parent._middleware) == 0
    assert parent.count() == 6
    # Data sampled through the view is shifted by `add_10`, so all values are >= 10
    assert all(item.data >= 10 for item in child.sample(5))
    # Raw storage is untouched by the view's sample middleware
    assert all(item.data < 10 for item in parent.sample(5))


@pytest.mark.unittest
def test_sample_with_index():
    """Sampling can be driven by explicit indices taken from a previous sample."""
    buf = DequeBuffer(size=10)
    for i in range(10):
        buf.push({"data": i}, {"meta": i})
    # Sample everything once and collect the indices
    all_indices = [sampled.index for sampled in buf.sample(10)]
    assert len(all_indices) == 10
    random.shuffle(all_indices)
    chosen = all_indices[:5]

    # Resampling by those indices returns exactly those records
    resampled = [sampled.index for sampled in buf.sample(indices=chosen)]
    assert len(resampled) == len(chosen)
    assert all(index in chosen for index in resampled)


@pytest.mark.unittest
def test_update():
    """update() mutates stored records by index and fails on unknown indices."""
    buf = DequeBuffer(size=10)
    for i in range(1):
        buf.push({"data": i}, {"meta": i})

    # Update one record's data in place
    [item] = buf.sample(1)
    item.data["new_prop"] = "any"
    success = buf.update(item.index, item.data, item.meta)
    assert success
    # Resample and confirm the mutation is visible
    # (the original `meta = None` / `assert meta is None` pair was dead code
    # asserting on a local that was never reassigned — removed)
    [item] = buf.sample(1)
    assert "new_prop" in item.data
    # Updating an index that does not exist fails
    success = buf.update("invalidindex", {}, None)
    assert not success

    # When exceeding buffer size, indices stay consistent with storage order
    for i in range(20):
        buf.push({"data": i})
    assert len(buf.indices) == 10
    assert len(buf.storage) == 10
    for i in range(10):
        index = buf.storage[i].index
        assert buf.indices.get(index) == i


@pytest.mark.unittest
def test_delete():
    """After deleting records and appending again, indices stay aligned with storage."""
    maxlen, cumlen, dellen = 100, 40, 20
    buf = DequeBuffer(size=maxlen)
    for value in range(cumlen):
        buf.push(value)
    # Remove a sampled subset
    buf.delete([item.index for item in buf.sample(dellen)])
    # Append fresh records afterwards
    for value in range(10):
        buf.push(value)
    remlen = min(cumlen, maxlen) - dellen + 10
    assert len(buf.indices) == remlen
    assert len(buf.storage) == remlen
    # Every stored record's index maps back to its storage position
    for position in range(remlen):
        assert buf.indices.get(buf.storage[position].index) == position


@pytest.mark.unittest
def test_ignore_insufficient():
    """Requesting more samples than stored raises unless explicitly ignored."""
    buffer = DequeBuffer(size=10)
    buffer.push(0)
    buffer.push(1)

    with pytest.raises(ValueError):
        buffer.sample(3, ignore_insufficient=False)
    # With ignore_insufficient=True an empty result is returned instead
    assert len(buffer.sample(3, ignore_insufficient=True)) == 0


@pytest.mark.unittest
def test_independence():
    """Records returned by oversampling are independent copies, not aliases."""
    # Oversampling via replace=True
    buffer = DequeBuffer(size=1)
    buffer.push({"key": "origin"})
    pair = buffer.sample(2, replace=True)
    assert len(pair) == 2
    pair[0].data["key"] = "new"
    assert pair[1].data["key"] == "origin"

    # Oversampling via duplicated indices
    buffer = DequeBuffer(size=1)
    buffered = buffer.push({"key": "origin"})
    pair = buffer.sample(indices=[buffered.index, buffered.index])
    assert len(pair) == 2
    pair[0].data["key"] = "new"
    assert pair[1].data["key"] == "origin"


@pytest.mark.unittest
def test_groupby():
    """Sampling with groupby partitions records by their meta value."""
    buffer = DequeBuffer(size=3)
    buffer.push("a", {"group": 1})
    buffer.push("b", {"group": 2})
    buffer.push("c", {"group": 2})

    groups = buffer.sample(2, groupby="group")
    assert len(groups) == 2
    group1 = groups[0] if len(groups[0]) == 1 else groups[1]
    group2 = groups[0] if len(groups[0]) == 2 else groups[1]
    # The singleton group holds "a"
    assert group1[0].data == "a"
    # The two-element group holds "b" and "c"
    group2_values = [buffered.data for buffered in group2]
    assert "b" in group2_values
    assert "c" in group2_values

    # Pushing "d" evicts "a", leaving a single group of three (all group 2)
    buffer.push("d", {"group": 2})
    groups = buffer.sample(1, groupby="group")
    assert len(groups) == 1
    assert len(groups[0]) == 3
    assert "d" in [buffered.data for buffered in groups[0]]

    # Moving the first record to group 1 restores two groups
    first: BufferedData = buffer.storage[0]
    buffer.update(first.index, first.data, {"group": 1})
    groups = buffer.sample(2, groupby="group")
    assert len(groups) == 2

    # Deleting the last record leaves one record per group
    last: BufferedData = buffer.storage[-1]
    buffer.delete(last.index)
    groups = buffer.sample(2, groupby="group")
    assert len(groups) == 2


@pytest.mark.unittest
def test_dataset():
    """DequeBuffer can be consumed directly by a torch DataLoader."""
    buffer = DequeBuffer(size=10)
    for value in range(10):
        buffer.push(value)
    loader = DataLoader(buffer, batch_size=6, shuffle=True, collate_fn=lambda batch: batch)
    # 10 items with batch_size 6 yields one batch of 6 and one of 4
    for batch in loader:
        assert len(batch) in (4, 6)


@pytest.mark.unittest
def test_unroll_len_in_group():
    """groupby + unroll_len returns contiguous runs drawn from a single group."""
    buffer = DequeBuffer(size=100)
    for step in range(10):
        for env_id in "ABC":
            buffer.push(step, {"env": env_id})

    groups = buffer.sample(3, groupby="env", unroll_len=4)
    assert len(groups) == 3
    for run in groups:
        assert len(run) == 4
        # Every record in a run comes from the same env
        assert len({record.meta["env"] for record in run}) == 1
        # The reduce chains through the run and survives (returns the last
        # BufferedData) only if each data value is exactly one above the previous
        chained = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, run)
        assert isinstance(chained, BufferedData), "Not continuous"


@pytest.mark.unittest
def test_insufficient_unroll_len_in_group():
    """
    When a group has fewer records than unroll_len, sampling raises a ValueError
    unless replace=True is passed.
    """
    buffer = DequeBuffer(size=100)

    num = 3  # Items in group A,B,C is 3,4,5
    for env_id in list("ABC"):
        for i in range(num):
            buffer.push(i, {"env": env_id})
        num += 1

    with pytest.raises(ValueError) as exc_info:
        buffer.sample(3, groupby="env", unroll_len=4)
    # Use pytest's public `value` attribute rather than the private `_excinfo` tuple
    assert "There are less than" in str(exc_info.value)

    # Sample with replace
    sampled_data = buffer.sample(3, groupby="env", unroll_len=4, replace=True)
    assert len(sampled_data) == 3
    for grouped_data in sampled_data:
        assert len(grouped_data) == 4
        # Ensure each group has the same env
        env_ids = set(map(lambda sample: sample.meta["env"], grouped_data))
        assert len(env_ids) == 1
        # Ensure samples in each group are continuous (reduce survives only
        # if every data value is exactly one above the previous)
        result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
        assert isinstance(result, BufferedData), "Not continuous"


@pytest.mark.unittest
def test_slice_unroll_len_in_group():
    """With sliced=True, unrolled runs may start only at fixed slice boundaries."""
    buffer = DequeBuffer(size=100, sliced=True)
    data_len, unroll_len = 10, 4
    # Valid run starts: 0, 4, 8 from stepping, plus the tail-aligned 6
    valid_starts = list(range(0, data_len, unroll_len)) + [data_len - unroll_len]
    for step in range(data_len):
        for env_id in "ABC":
            buffer.push(step, {"env": env_id})

    groups = buffer.sample(3, groupby="env", unroll_len=unroll_len)
    assert len(groups) == 3
    for run in groups:
        assert len(run) == 4
        # Same env within a run
        assert len({record.meta["env"] for record in run}) == 1
        # Contiguous data along the run (reduce survives only for step-1 sequences)
        chained = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, run)
        assert isinstance(chained, BufferedData), "Not continuous"
        # The run begins at a legal slice boundary
        assert run[0].data in valid_starts