|
import os |
|
import pytest |
|
import time |
|
import random |
|
import functools |
|
import tempfile |
|
from typing import Callable |
|
from ding.data.buffer import DequeBuffer |
|
from ding.data.buffer.buffer import BufferedData |
|
from torch.utils.data import DataLoader |
|
|
|
|
|
class RateLimit:
    """
    Middleware that limits how many ``push`` calls are forwarded per time window.

    At most ``max_rate`` pushes are forwarded to ``chain`` within any sliding
    window of ``window_seconds`` seconds; pushes above that limit are dropped
    and ``None`` is returned. Every other action is forwarded unchanged.
    """

    def __init__(self, max_rate: float = float("inf"), window_seconds: float = 30) -> None:
        """
        Arguments:
            - max_rate: Maximum number of pushes accepted per window (unlimited by default).
            - window_seconds: Length of the sliding window, in seconds.
        """
        self.max_rate = max_rate
        self.window_seconds = window_seconds
        # Timestamps of the accepted pushes that are still inside the window.
        self.buffered = []

    def __call__(self, action: str, chain: Callable, *args, **kwargs):
        """
        Dispatch a middleware call: ``push`` is rate limited, anything else passes through.

        Arguments:
            - action: Name of the action being performed.
            - chain: Next callable in the middleware chain.
        Returns:
            - Whatever ``chain`` returns, or ``None`` when a push is rejected.
        """
        if action == "push":
            return self.push(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    def push(self, chain: Callable, data, *args, **kwargs):
        """
        Forward ``data`` to ``chain`` unless the rate limit is already exhausted.

        Arguments:
            - chain: Next callable in the middleware chain.
            - data: Payload of the push.
        Returns:
            - The result of ``chain`` when the push is accepted, otherwise ``None``.
        """
        # A monotonic clock is used because only elapsed time matters here and
        # the wall clock may jump (NTP sync, manual changes), corrupting the window.
        now = time.monotonic()

        # Forget accepted pushes that have fallen out of the window.
        window_start = now - self.window_seconds
        self.buffered = [stamp for stamp in self.buffered if stamp > window_start]

        if len(self.buffered) >= self.max_rate:
            # Limit reached: the push is dropped.
            return None
        self.buffered.append(now)
        return chain(data, *args, **kwargs)
|
|
|
|
|
def add_10() -> Callable:
    """
    Build a middleware that adds 10 to the data of every sampled item.

    Only the ``sample`` action is transformed; every other action is forwarded
    to ``chain`` unchanged.

    Returns:
        - A middleware callable with the signature ``(action, chain, *args, **kwargs)``.
    """

    def sample(chain: Callable, *args, **kwargs):
        # Forward the arguments exactly as the caller gave them, so that calls
        # without a positional size (e.g. ``sample(indices=...)``) keep working
        # and no argument is moved into a different positional slot.
        sampled_data = chain(*args, **kwargs)
        # Build new items so the objects returned by ``chain`` are not mutated.
        return [BufferedData(data=item.data + 10, index=item.index, meta=item.meta) for item in sampled_data]

    def _subview(action: str, chain: Callable, *args, **kwargs):
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _subview
|
|
|
|
|
@pytest.mark.unittest
def test_naive_push_sample():
    """Exercise plain push/sample: overflow, clear, sampling with replacement and range-limited sampling."""
    # Overflow: 20 pushes into a buffer of size 10; the very first value must be gone.
    queue = DequeBuffer(size=10)
    for value in range(20):
        queue.push(value)
    assert queue.count() == 10
    drawn = {item.data for item in queue.sample(10)}
    assert 0 not in drawn

    # Clearing empties the buffer.
    queue.clear()
    assert queue.count() == 0

    # Sampling with replacement may return more items than are stored.
    for value in range(5):
        queue.push(value)
    assert queue.count() == 5
    oversampled = queue.sample(10, replace=True)
    assert len(oversampled) == 10

    # Sampling restricted to the second half of the storage never yields the first value.
    queue.clear()
    for value in range(10):
        queue.push(value)
    upper_half = slice(5, 10)
    assert len(queue.sample(5, sample_range=upper_half)) == 5
    drawn = {item.data for item in queue.sample(5, sample_range=upper_half)}
    assert 0 not in drawn
|
|
|
|
|
@pytest.mark.unittest
def test_rate_limit_push_sample():
    """Check that a rate-limited buffer only stores the pushes accepted by the limiter."""
    buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
    for i in range(10):
        buffer.push(i)
    # Only the first 5 pushes fit into the limiter's window.
    assert buffer.count() == 5
    # Compare against the payloads: sampled items are BufferedData wrappers, so
    # testing membership of a bare int in the raw sample could never fail.
    assert 5 not in [item.data for item in buffer.sample(5)]
|
|
|
|
|
@pytest.mark.unittest
def test_load_and_save():
    """Check that data saved from a rate-limited buffer can be loaded into a fresh buffer."""
    buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
    buffer.meta_index = {"label": []}
    for i in range(10):
        buffer.push(i, meta={"label": i})
    assert buffer.count() == 5
    # Compare against the payloads: sampled items are BufferedData wrappers, so
    # testing membership of a bare int in the raw sample could never fail.
    assert 5 not in [item.data for item in buffer.sample(5)]
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_file = os.path.join(tmpdirname, "data.hkl")
        buffer.save_data(test_file)
        buffer_new = DequeBuffer(size=10).use(RateLimit(max_rate=5))
        buffer_new.load_data(test_file)
        assert buffer_new.count() == 5
        assert 5 not in [item.data for item in buffer_new.sample(5)]
        # NOTE(review): these checks look at the original buffer's meta index,
        # not the reloaded one - confirm whether load_data is expected to rebuild it.
        assert len(buffer.meta_index["label"]) == 5
        assert all(index < 5 for index in buffer.meta_index["label"])
|
|
|
|
|
@pytest.mark.unittest
def test_buffer_view():
    """Check that middleware attached to a view does not leak into the buffer the view was created from."""
    source = DequeBuffer(size=10)
    source.push(0)
    assert source.count() == 1

    # The view carries its own middleware stack: rate limit on push, +10 on sample.
    limited = source.view().use(RateLimit(max_rate=5))
    view = limited.use(add_10())

    for value in range(10):
        view.push(value)

    # The source buffer has no middleware, yet received the 5 pushes the limiter accepted.
    assert len(source._middleware) == 0
    assert source.count() == 6

    # Sampling through the view applies the +10 transformation ...
    for entry in view.sample(5):
        assert entry.data >= 10

    # ... while sampling the source directly returns the raw values.
    for entry in source.sample(5):
        assert entry.data < 10
|
|
|
|
|
@pytest.mark.unittest
def test_sample_with_index():
    """Check that sampling by explicit indices returns exactly the requested items."""
    buf = DequeBuffer(size=10)
    for value in range(10):
        buf.push({"data": value}, {"meta": value})

    # Collect every stored index, then keep a random half of them.
    known = [entry.index for entry in buf.sample(10)]
    assert len(known) == 10
    random.shuffle(known)
    wanted = known[:5]

    # Sampling by index must return one item per requested index and nothing else.
    returned = [entry.index for entry in buf.sample(indices=wanted)]
    assert len(returned) == len(wanted)
    assert all(index in wanted for index in returned)
|
|
|
|
|
@pytest.mark.unittest
def test_update():
    """Check in-place updates, rejection of unknown indices and index bookkeeping after overflow."""
    buf = DequeBuffer(size=10)
    for i in range(1):
        buf.push({"data": i}, {"meta": i})

    # Update the only stored item, handing back its own (unchanged) meta.
    [item] = buf.sample(1)
    item.data["new_prop"] = "any"
    success = buf.update(item.index, item.data, item.meta)
    assert success

    # Resample: the new data must be visible and the meta must be preserved.
    # (The previous check asserted on an unrelated local that was always None.)
    [item] = buf.sample(1)
    assert "new_prop" in item.data
    assert item.meta["meta"] == 0

    # Updating an index that is not in the buffer reports failure.
    success = buf.update("invalidindex", {}, None)
    assert not success

    # After overflowing the buffer, the index table must still map every
    # stored item to its current position.
    for i in range(20):
        buf.push({"data": i})
    assert len(buf.indices) == 10
    assert len(buf.storage) == 10
    for i in range(10):
        index = buf.storage[i].index
        assert buf.indices.get(index) == i
|
|
|
|
|
@pytest.mark.unittest
def test_delete():
    """Check that deleting items and pushing new ones keeps storage and index table consistent."""
    capacity, pushed, removed, refilled = 100, 40, 20, 10

    buf = DequeBuffer(size=capacity)
    for value in range(pushed):
        buf.push(value)

    # Delete a random subset of the stored items.
    victims = [entry.index for entry in buf.sample(removed)]
    buf.delete(victims)

    # Push again after the deletion.
    for value in range(refilled):
        buf.push(value)

    expected = min(pushed, capacity) - removed + refilled
    assert len(buf.indices) == expected
    assert len(buf.storage) == expected
    # Every stored item must be registered under its current position.
    for position in range(expected):
        stored = buf.storage[position]
        assert buf.indices.get(stored.index) == position
|
|
|
|
|
@pytest.mark.unittest
def test_ignore_insufficient():
    """Check both behaviours when more items are requested than the buffer holds."""
    buffer = DequeBuffer(size=10)
    buffer.push(0)
    buffer.push(1)

    # Without the flag, asking for too many items is an error.
    with pytest.raises(ValueError):
        buffer.sample(3, ignore_insufficient=False)

    # With the flag, the same request yields an empty result instead.
    result = buffer.sample(3, ignore_insufficient=True)
    assert len(result) == 0
|
|
|
|
|
@pytest.mark.unittest
def test_independence():
    """Check that two samples of the same stored item do not share their data object."""
    # Case 1: the same item drawn twice via sampling with replacement.
    buffer = DequeBuffer(size=1)
    buffer.push({"key": "origin"})
    first, second = None, None
    samples = buffer.sample(2, replace=True)
    assert len(samples) == 2
    first, second = samples[0], samples[1]
    first.data["key"] = "new"
    assert second.data["key"] == "origin"

    # Case 2: the same item requested twice by index.
    buffer = DequeBuffer(size=1)
    stored = buffer.push({"key": "origin"})
    samples = buffer.sample(indices=[stored.index, stored.index])
    assert len(samples) == 2
    first, second = samples[0], samples[1]
    first.data["key"] = "new"
    assert second.data["key"] == "origin"
|
|
|
|
|
@pytest.mark.unittest
def test_groupby():
    """Check grouped sampling, including after an overflow push, a meta update and a delete."""
    buffer = DequeBuffer(size=3)
    for payload, group in (("a", 1), ("b", 2), ("c", 2)):
        buffer.push(payload, {"group": group})

    # Two groups are expected: one with a single item, one with two items.
    groups = buffer.sample(2, groupby="group")
    assert len(groups) == 2
    head, tail = groups[0], groups[1]
    single = head if len(head) == 1 else tail
    pair = head if len(head) == 2 else tail

    assert "a" == single[0].data

    pair_payloads = [entry.data for entry in pair]
    assert "b" in pair_payloads
    assert "c" in pair_payloads

    # Pushing a fourth item into the size-3 buffer leaves one group of three.
    buffer.push("d", {"group": 2})
    groups = buffer.sample(1, groupby="group")
    assert len(groups) == 1
    assert len(groups[0]) == 3
    payloads = [entry.data for entry in groups[0]]
    assert "d" in payloads

    # Moving the first stored item to group 1 yields two groups again.
    oldest: BufferedData = buffer.storage[0]
    buffer.update(oldest.index, oldest.data, {"group": 1})
    groups = buffer.sample(2, groupby="group")
    assert len(groups) == 2

    # Deleting the last stored item still leaves two groups.
    newest: BufferedData = buffer.storage[-1]
    buffer.delete(newest.index)
    groups = buffer.sample(2, groupby="group")
    assert len(groups) == 2
|
|
|
|
|
@pytest.mark.unittest
def test_dataset():
    """Check that the buffer can be consumed through a torch DataLoader."""

    def keep_batch(batch):
        # Identity collate: hand the raw list of items back untouched.
        return batch

    buffer = DequeBuffer(size=10)
    for value in range(10):
        buffer.push(value)

    loader = DataLoader(buffer, batch_size=6, shuffle=True, collate_fn=keep_batch)
    # 10 items in batches of 6: one full batch and one remainder of 4.
    for batch in loader:
        assert len(batch) in (4, 6)
|
|
|
|
|
@pytest.mark.unittest
def test_unroll_len_in_group():
    """Check that grouped sampling with ``unroll_len`` returns consecutive items from a single group."""

    def step(previous, current):
        # Hand the current item on while the run is consecutive; False sticks once a gap is seen.
        return previous and previous.data + 1 == current.data and current

    buffer = DequeBuffer(size=100)
    for value in range(10):
        for env_id in "ABC":
            buffer.push(value, {"env": env_id})

    groups = buffer.sample(3, groupby="env", unroll_len=4)
    assert len(groups) == 3
    for group in groups:
        assert len(group) == 4

        # All items of one group come from the same environment.
        env_ids = {entry.meta["env"] for entry in group}
        assert len(env_ids) == 1

        # The fold ends on the last item only if every neighbouring pair is consecutive.
        outcome = functools.reduce(step, group)
        assert isinstance(outcome, BufferedData), "Not continuous"
|
|
|
|
|
@pytest.mark.unittest
def test_insufficient_unroll_len_in_group():
    """Check grouped unroll sampling when some groups hold fewer items than ``unroll_len``."""
    buffer = DequeBuffer(size=100)

    # Groups A, B and C receive 3, 4 and 5 items respectively.
    num = 3
    for env_id in list("ABC"):
        for i in range(num):
            buffer.push(i, {"env": env_id})
        num += 1

    # Without replacement the request cannot be satisfied and must raise.
    with pytest.raises(ValueError) as exc_info:
        buffer.sample(3, groupby="env", unroll_len=4)
    # Use the public ExceptionInfo.value accessor instead of the private _excinfo tuple.
    assert "There are less than" in str(exc_info.value)

    # With replacement the request is expected to succeed.
    sampled_data = buffer.sample(3, groupby="env", unroll_len=4, replace=True)
    assert len(sampled_data) == 3
    for grouped_data in sampled_data:
        assert len(grouped_data) == 4

        # All items of one group come from the same environment.
        env_ids = set(map(lambda sample: sample.meta["env"], grouped_data))
        assert len(env_ids) == 1

        # The fold ends on the last item only if every neighbouring pair is
        # consecutive; otherwise it collapses to False.
        result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
        assert isinstance(result, BufferedData), "Not continuous"
|
|
|
|
|
@pytest.mark.unittest
def test_slice_unroll_len_in_group():
    """Check grouped unroll sampling on a buffer created with ``sliced=True``."""

    def step(previous, current):
        # Hand the current item on while the run is consecutive; False sticks once a gap is seen.
        return previous and previous.data + 1 == current.data and current

    data_len = 10
    unroll_len = 4
    # Accepted first values of a sampled run: every multiple of unroll_len
    # below data_len, plus the start of the final full-length run.
    allowed_starts = [*range(0, data_len, unroll_len), data_len - unroll_len]

    buffer = DequeBuffer(size=100, sliced=True)
    for value in range(data_len):
        for env_id in "ABC":
            buffer.push(value, {"env": env_id})

    groups = buffer.sample(3, groupby="env", unroll_len=unroll_len)
    assert len(groups) == 3
    for group in groups:
        assert len(group) == 4

        # All items of one group come from the same environment.
        env_ids = {entry.meta["env"] for entry in group}
        assert len(env_ids) == 1

        # The fold ends on the last item only if every neighbouring pair is consecutive.
        outcome = functools.reduce(step, group)
        assert isinstance(outcome, BufferedData), "Not continuous"

        # Each run must begin at one of the accepted start values.
        assert group[0].data in allowed_starts
|
|