import pytest
import torch

from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import clone_object, use_time_check, staleness_check, sample_range_view
from ding.data.buffer.middleware import PriorityExperienceReplay, group_sample
from ding.data.buffer.middleware.padding import padding
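# Unit tests for DequeBuffer middlewares: clone_object, use_time_check, staleness_check,
# PriorityExperienceReplay, padding, group_sample and sample_range_view.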


@pytest.mark.unittest
def test_clone_object():
    buffer = DequeBuffer(size=10).use(clone_object())

    arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
    for o in arr:
        buffer.push(o)

    # Mutate the sampled items in place.
    for item in buffer.sample(len(arr)):
        item = item.data
        if isinstance(item, dict):
            item["key"] = "v2"
        elif isinstance(item, list):
            item.append("b")
        elif isinstance(item, torch.Tensor):
            item[0] = 3
        else:
            raise Exception("Unexpected type")

    # clone_object hands out copies, so the data stored in the buffer stays untouched.
    for item in buffer.sample(len(arr)):
        item = item.data
        if isinstance(item, dict):
            assert item["key"] == "v1"
        elif isinstance(item, list):
            assert len(item) == 1
        elif isinstance(item, torch.Tensor):
            assert item[0] == 1
        else:
            raise Exception("Unexpected type")


def get_data():
    return {'obs': torch.randn(4), 'reward': torch.randn(1), 'info': 'xxx'}


@pytest.mark.unittest
def test_use_time_check():
    N = 6
    buffer = DequeBuffer(size=10)
    buffer.use(use_time_check(buffer, max_use=2))

    for _ in range(N):
        buffer.push(get_data())

    # Each record can be sampled at most `max_use` times, so two full sweeps succeed
    # and a further sample raises ValueError.
    for _ in range(2):
        data = buffer.sample(size=N, replace=False)
        assert len(data) == N
    with pytest.raises(ValueError):
        buffer.sample(size=1, replace=False)


@pytest.mark.unittest
def test_staleness_check():
    N = 6
    buffer = DequeBuffer(size=10)
    buffer.use(staleness_check(buffer, max_staleness=10))

    # Pushing without `train_iter_data_collected` in meta is rejected.
    with pytest.raises(AssertionError):
        buffer.push(get_data())
    for _ in range(N):
        buffer.push(get_data(), meta={'train_iter_data_collected': 0})
    data = buffer.sample(size=N, replace=False, train_iter_sample_data=9)
    assert len(data) == N
    data = buffer.sample(size=N, replace=False, train_iter_sample_data=10)
    assert len(data) == N
    for _ in range(2):
        buffer.push(get_data(), meta={'train_iter_data_collected': 5})
    assert buffer.count() == 8
    # At train_iter 11 the first N records exceed max_staleness and are dropped,
    # leaving too few records to sample without replacement.
    with pytest.raises(ValueError):
        data = buffer.sample(size=N, replace=False, train_iter_sample_data=11)
    assert buffer.count() == 2


@pytest.mark.unittest
def test_priority():
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    for _ in range(N):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == N
    for _ in range(N):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == N + N
    data = buffer.sample(size=N + N, replace=False)
    assert len(data) == N + N
    for item in data:
        meta = item.meta
        assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
        meta['priority'] = 3.0
    # Write the modified priorities back into the buffer.
    for item in data:
        buffer.update(item.index, item.data, item.meta)
    data = buffer.sample(size=1)
    assert data[0].meta['priority'] == 3.0
    buffer.delete(data[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0


@pytest.mark.unittest
def test_priority_from_collector():
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    # Here the priority is carried inside the pushed data instead of the meta dict.
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(tmp_data)
    assert buffer.count() == N
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(tmp_data)
    assert buffer.count() == N + N
    data = buffer.sample(size=N + N, replace=False)
    assert len(data) == N + N
    for item in data:
        meta = item.meta
        assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
        meta['priority'] = 3.0
    # Write the modified priorities back into the buffer.
    for item in data:
        buffer.update(item.index, item.data, item.meta)
    data = buffer.sample(size=1)
    assert data[0].meta['priority'] == 3.0
    buffer.delete(data[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0


@pytest.mark.unittest
def test_padding():
    buffer = DequeBuffer(size=10)
    buffer.use(padding())
    # Grouping by `i & 5` yields four groups with sizes 3, 3, 2 and 2.
    for i in range(10):
        buffer.push(i, {"group": i & 5})
    sampled_data = buffer.sample(4, groupby="group")
    assert len(sampled_data) == 4
    # padding fills every sampled group up to the largest group size.
    for grouped_data in sampled_data:
        assert len(grouped_data) == 3


@pytest.mark.unittest
def test_group_sample():
    buffer = DequeBuffer(size=10)
    buffer.use(padding(policy="none")).use(
        group_sample(size_in_group=5, ordered_in_group=True, max_use_in_group=True)
    )
    for i in range(4):
        buffer.push(i, {"episode": 0})
    for i in range(6):
        buffer.push(i, {"episode": 1})
    sampled_data = buffer.sample(2, groupby="episode")
    assert len(sampled_data) == 2

    def check_group0(grouped_data):
        # Episode 0 has only 4 real records, so one padded item (data is None) is expected.
        n_none = 0
        for item in grouped_data:
            if item.data is None:
                n_none += 1
        assert n_none == 1

    def check_group1(grouped_data):
        # Episode 1 has 6 real records, so no padded item should appear.
        for item in grouped_data:
            assert item.data is not None

    for grouped_data in sampled_data:
        assert len(grouped_data) == 5
        meta = grouped_data[0].meta
        if meta and "episode" in meta and meta["episode"] == 1:
            check_group1(grouped_data)
        else:
            check_group0(grouped_data)


@pytest.mark.unittest
def test_sample_range_view():
    buffer_ = DequeBuffer(size=10)
    for _ in range(5):
        buffer_.push({'data': 'x'})
    for _ in range(5, 5 + 3):
        buffer_.push({'data': 'y'})
    for _ in range(8, 8 + 2):
        buffer_.push({'data': 'z'})

    # Restrict sampling to indices [-5, -2), i.e. the 'y' records.
    buffer1 = buffer_.view()
    buffer1.use(sample_range_view(buffer1, start=-5, end=-2))
    for _ in range(10):
        sampled_data = buffer1.sample(1)
        assert sampled_data[0].data['data'] == 'y'

    # Restrict sampling to the last two records, i.e. the 'z' records.
    buffer2 = buffer_.view()
    buffer2.use(sample_range_view(buffer2, start=-2))
    for _ in range(10):
        sampled_data = buffer2.sample(1)
        assert sampled_data[0].data['data'] == 'z'