|
import pytest |
|
from collections import namedtuple |
|
import random |
|
import numpy as np |
|
import torch |
|
from ding.utils.data import timestep_collate, default_collate, default_decollate, diff_shape_collate |
|
|
|
B, T = 4, 3 |
|
|
|
|
|
@pytest.mark.unittest |
|
class TestTimestepCollate: |
|
|
|
def get_data(self): |
|
data = { |
|
'obs': [torch.randn(4) for _ in range(T)], |
|
'reward': [torch.FloatTensor([0]) for _ in range(T)], |
|
'done': [False for _ in range(T)], |
|
'prev_state': [(torch.randn(3), torch.randn(3)) for _ in range(T)], |
|
'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)], |
|
} |
|
return data |
|
|
|
def get_multi_shape_state_data(self): |
|
data = { |
|
'obs': [torch.randn(4) for _ in range(T)], |
|
'reward': [torch.FloatTensor([0]) for _ in range(T)], |
|
'done': [False for _ in range(T)], |
|
'prev_state': [ |
|
[(torch.randn(3), torch.randn(5)), (torch.randn(4), ), (torch.randn(5), torch.randn(6))] |
|
for _ in range(T) |
|
], |
|
'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)], |
|
} |
|
return data |
|
|
|
def test(self): |
|
batch = timestep_collate([self.get_data() for _ in range(B)]) |
|
assert isinstance(batch, dict) |
|
assert set(batch.keys()) == set(['obs', 'reward', 'done', 'prev_state', 'action']) |
|
assert batch['obs'].shape == (T, B, 4) |
|
assert batch['reward'].shape == (T, B) |
|
assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool |
|
assert isinstance(batch['prev_state'], list) |
|
assert len(batch['prev_state']) == T and len(batch['prev_state'][0]) == B |
|
assert isinstance(batch['action'], list) and len(batch['action']) == T |
|
assert batch['action'][0][0].shape == (B, 3) |
|
assert batch['action'][0][1].shape == (B, 5) |
|
|
|
|
|
batch = timestep_collate([self.get_multi_shape_state_data() for _ in range(B)]) |
|
assert isinstance(batch, dict) |
|
assert set(batch.keys()) == set(['obs', 'reward', 'done', 'prev_state', 'action']) |
|
assert batch['obs'].shape == (T, B, 4) |
|
assert batch['reward'].shape == (T, B) |
|
assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool |
|
assert isinstance(batch['prev_state'], list) |
|
print(batch['prev_state'][0][0]) |
|
assert len(batch['prev_state']) == T and len(batch['prev_state'][0] |
|
) == B and len(batch['prev_state'][0][0]) == 3 |
|
assert isinstance(batch['action'], list) and len(batch['action']) == T |
|
assert batch['action'][0][0].shape == (B, 3) |
|
assert batch['action'][0][1].shape == (B, 5) |
|
|
|
|
|
@pytest.mark.unittest |
|
class TestDefaultCollate: |
|
|
|
def test_numpy(self): |
|
data = [np.random.randn(4, 3).astype(np.float64) for _ in range(5)] |
|
data = default_collate(data) |
|
assert data.shape == (5, 4, 3) |
|
assert data.dtype == torch.float64 |
|
data = [float(np.random.randn(1)[0]) for _ in range(6)] |
|
data = default_collate(data) |
|
assert data.shape == (6, ) |
|
assert data.dtype == torch.float32 |
|
with pytest.raises(TypeError): |
|
default_collate([np.array(['str']) for _ in range(3)]) |
|
|
|
def test_basic(self): |
|
data = [random.random() for _ in range(3)] |
|
data = default_collate(data) |
|
assert data.shape == (3, ) |
|
assert data.dtype == torch.float32 |
|
data = [random.randint(0, 10) for _ in range(3)] |
|
data = default_collate(data) |
|
assert data.shape == (3, ) |
|
assert data.dtype == torch.int64 |
|
data = ['str' for _ in range(4)] |
|
data = default_collate(data) |
|
assert len(data) == 4 |
|
assert all([s == 'str' for s in data]) |
|
T = namedtuple('T', ['x', 'y']) |
|
data = [T(1, 2) for _ in range(4)] |
|
data = default_collate(data) |
|
assert isinstance(data, T) |
|
assert data.x.shape == (4, ) and data.x.eq(1).sum() == 4 |
|
assert data.y.shape == (4, ) and data.y.eq(2).sum() == 4 |
|
with pytest.raises(TypeError): |
|
default_collate([object() for _ in range(4)]) |
|
|
|
data = [{'collate_ignore_data': random.random()} for _ in range(4)] |
|
data = default_collate(data) |
|
assert isinstance(data, dict) |
|
assert len(data['collate_ignore_data']) == 4 |
|
|
|
|
|
@pytest.mark.unittest |
|
class TestDefaultDecollate: |
|
|
|
def test(self): |
|
with pytest.raises(TypeError): |
|
default_decollate([object() for _ in range(4)]) |
|
data = torch.randn(4, 3, 5) |
|
data = default_decollate(data) |
|
print([d.shape for d in data]) |
|
assert len(data) == 4 and all([d.shape == (3, 5) for d in data]) |
|
data = [torch.randn(8, 2, 4), torch.randn(8, 5)] |
|
data = default_decollate(data) |
|
assert len(data) == 8 and all([d[0].shape == (2, 4) and d[1].shape == (5, ) for d in data]) |
|
data = { |
|
'logit': torch.randn(4, 13), |
|
'action': torch.randint(0, 13, size=(4, )), |
|
'prev_state': [(torch.zeros(3, 1, 12), torch.zeros(3, 1, 12)) for _ in range(4)], |
|
} |
|
data = default_decollate(data) |
|
assert len(data) == 4 and isinstance(data, list) |
|
assert all([d['logit'].shape == (13, ) for d in data]) |
|
assert all([d['action'].shape == (1, ) for d in data]) |
|
assert all([len(d['prev_state']) == 2 and d['prev_state'][0].shape == (3, 1, 12) for d in data]) |
|
|
|
|
|
@pytest.mark.unittest |
|
class TestDiffShapeCollate: |
|
|
|
def test(self): |
|
with pytest.raises(TypeError): |
|
diff_shape_collate([object() for _ in range(4)]) |
|
data = [ |
|
{ |
|
'item1': torch.randn(4), |
|
'item2': None, |
|
'item3': torch.randn(3), |
|
'item4': np.random.randn(5, 6) |
|
}, |
|
{ |
|
'item1': torch.randn(5), |
|
'item2': torch.randn(6), |
|
'item3': torch.randn(3), |
|
'item4': np.random.randn(5, 6) |
|
}, |
|
] |
|
data = diff_shape_collate(data) |
|
assert isinstance(data['item1'], list) and len(data['item1']) == 2 |
|
assert isinstance(data['item2'], list) and len(data['item2']) == 2 and data['item2'][0] is None |
|
assert data['item3'].shape == (2, 3) |
|
assert data['item4'].shape == (2, 5, 6) |
|
data = [ |
|
{ |
|
'item1': 1, |
|
'item2': 3, |
|
'item3': 2.0 |
|
}, |
|
{ |
|
'item1': None, |
|
'item2': 4, |
|
'item3': 2.0 |
|
}, |
|
] |
|
data = diff_shape_collate(data) |
|
assert isinstance(data['item1'], list) and len(data['item1']) == 2 and data['item1'][1] is None |
|
assert data['item2'].shape == (2, ) and data['item2'].dtype == torch.int64 |
|
assert data['item3'].shape == (2, ) and data['item3'].dtype == torch.float32 |
|
|