# DI-engine: ding/data/buffer/tests/test_middleware.py
import pytest
import torch
from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import clone_object, use_time_check, staleness_check, sample_range_view
from ding.data.buffer.middleware import PriorityExperienceReplay, group_sample
from ding.data.buffer.middleware.padding import padding
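
# clone_object() is expected to make the buffer hand out deep copies, so mutating
# sampled data must not affect what is stored in the buffer.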
@pytest.mark.unittest
def test_clone_object():
    buffer = DequeBuffer(size=10).use(clone_object())
    # Store a dict, a list and a tensor
    arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
    for o in arr:
        buffer.push(o)
    # Mutate the sampled copies
    for item in buffer.sample(len(arr)):
        item = item.data
        if isinstance(item, dict):
            item["key"] = "v2"
        elif isinstance(item, list):
            item.append("b")
        elif isinstance(item, torch.Tensor):
            item[0] = 3
        else:
            raise Exception("Unexpected type")
    # Resample and check that the stored values are unchanged
    for item in buffer.sample(len(arr)):
        item = item.data
        if isinstance(item, dict):
            assert item["key"] == "v1"
        elif isinstance(item, list):
            assert len(item) == 1
        elif isinstance(item, torch.Tensor):
            assert item[0] == 1
        else:
            raise Exception("Unexpected type")
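
# Helper that builds a minimal dummy transition used by the tests below.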
def get_data():
    return {'obs': torch.randn(4), 'reward': torch.randn(1), 'info': 'xxx'}
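
# use_time_check(max_use=2) is expected to evict a record after it has been sampled
# twice, so a third full-size sample must fail with ValueError.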
@pytest.mark.unittest
def test_use_time_check():
    N = 6
    buffer = DequeBuffer(size=10)
    buffer.use(use_time_check(buffer, max_use=2))
    for _ in range(N):
        buffer.push(get_data())
    for _ in range(2):
        data = buffer.sample(size=N, replace=False)
        assert len(data) == N
    with pytest.raises(ValueError):
        buffer.sample(size=1, replace=False)
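
# staleness_check requires 'train_iter_data_collected' meta on push and drops records
# whose staleness exceeds max_staleness at sample time, which is why the last sample
# raises ValueError and shrinks the buffer to the two fresher records.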
@pytest.mark.unittest
def test_staleness_check():
    N = 6
    buffer = DequeBuffer(size=10)
    buffer.use(staleness_check(buffer, max_staleness=10))
    with pytest.raises(AssertionError):
        buffer.push(get_data())
    for _ in range(N):
        buffer.push(get_data(), meta={'train_iter_data_collected': 0})
    data = buffer.sample(size=N, replace=False, train_iter_sample_data=9)
    assert len(data) == N
    data = buffer.sample(size=N, replace=False, train_iter_sample_data=10)  # edge case: staleness == max_staleness
    assert len(data) == N
    for _ in range(2):
        buffer.push(get_data(), meta={'train_iter_data_collected': 5})
    assert buffer.count() == 8
    with pytest.raises(ValueError):
        data = buffer.sample(size=N, replace=False, train_iter_sample_data=11)
    assert buffer.count() == 2
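
# PriorityExperienceReplay is expected to attach 'priority', 'priority_idx' and
# 'priority_IS' meta to sampled records and honor priority updates via buffer.update.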
@pytest.mark.unittest
def test_priority():
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    for _ in range(N):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == N
    for _ in range(N):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == N + N
    data = buffer.sample(size=N + N, replace=False)
    assert len(data) == N + N
    for item in data:
        meta = item.meta
        assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
        meta['priority'] = 3.0
    for item in data:
        item_data, index, meta = item.data, item.index, item.meta
        buffer.update(index, item_data, meta)
    data = buffer.sample(size=1)
    assert data[0].meta['priority'] == 3.0
    buffer.delete(data[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0
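
# Same flow as test_priority, but the priority comes embedded in the pushed data
# (as a collector would provide it) instead of the explicit meta argument.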
@pytest.mark.unittest
def test_priority_from_collector():
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(tmp_data)
    assert buffer.count() == N
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(tmp_data)
    assert buffer.count() == N + N
    data = buffer.sample(size=N + N, replace=False)
    assert len(data) == N + N
    for item in data:
        meta = item.meta
        assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
        meta['priority'] = 3.0
    for item in data:
        item_data, index, meta = item.data, item.index, item.meta
        buffer.update(index, item_data, meta)
    data = buffer.sample(size=1)
    assert data[0].meta['priority'] == 3.0
    buffer.delete(data[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0
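
# padding() is expected to pad smaller groups up to the largest group size, so every
# group sampled with groupby returns the same number of records.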
@pytest.mark.unittest
def test_padding():
    buffer = DequeBuffer(size=10)
    buffer.use(padding())
    for i in range(10):
        # Bitwise AND maps i in [0, 10) to four groups with sizes [3, 3, 2, 2]
        buffer.push(i, {"group": i & 5})
    sampled_data = buffer.sample(4, groupby="group")
    assert len(sampled_data) == 4
    for grouped_data in sampled_data:
        assert len(grouped_data) == 3
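
# group_sample(size_in_group=5) should return groups of exactly five records; with
# padding(policy="none") the padded slots keep data as None.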
@pytest.mark.unittest
def test_group_sample():
    buffer = DequeBuffer(size=10)
    buffer.use(padding(policy="none")).use(group_sample(size_in_group=5, ordered_in_group=True, max_use_in_group=True))
    for i in range(4):
        buffer.push(i, {"episode": 0})
    for i in range(6):
        buffer.push(i, {"episode": 1})
    sampled_data = buffer.sample(2, groupby="episode")
    assert len(sampled_data) == 2

    def check_group0(grouped_data):
        # Group 0 only has 4 real records padded to 5, so exactly one record should have data None
        n_none = 0
        for item in grouped_data:
            if item.data is None:
                n_none += 1
        assert n_none == 1

    def check_group1(grouped_data):
        # Group 1 has enough records, so every record should carry real data
        for item in grouped_data:
            assert item.data is not None

    for grouped_data in sampled_data:
        assert len(grouped_data) == 5
        meta = grouped_data[0].meta
        if meta and "episode" in meta and meta["episode"] == 1:
            check_group1(grouped_data)
        else:
            check_group0(grouped_data)
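
# sample_range_view restricts sampling to a slice of the buffer given by negative
# offsets from the newest record, so each view only ever returns its own segment.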
@pytest.mark.unittest
def test_sample_range_view():
    buffer_ = DequeBuffer(size=10)
    for i in range(5):
        buffer_.push({'data': 'x'})
    for i in range(5, 5 + 3):
        buffer_.push({'data': 'y'})
    for i in range(8, 8 + 2):
        buffer_.push({'data': 'z'})
    # A view limited to the range [-5, -2) should only ever return 'y' records
    buffer1 = buffer_.view()
    buffer1.use(sample_range_view(buffer1, start=-5, end=-2))
    for _ in range(10):
        sampled_data = buffer1.sample(1)
        assert sampled_data[0].data['data'] == 'y'
    # A view limited to the last two records should only ever return 'z' records
    buffer2 = buffer_.view()
    buffer2.use(sample_range_view(buffer2, start=-2))
    for _ in range(10):
        sampled_data = buffer2.sample(1)
        assert sampled_data[0].data['data'] == 'z'