gomoku / DI-engine /ding /utils /data /tests /test_collate_fn.py
zjowowen's picture
init space
079c32c
import pytest
from collections import namedtuple
import random
import numpy as np
import torch
from ding.utils.data import timestep_collate, default_collate, default_decollate, diff_shape_collate
B, T = 4, 3
@pytest.mark.unittest
class TestTimestepCollate:
def get_data(self):
data = {
'obs': [torch.randn(4) for _ in range(T)],
'reward': [torch.FloatTensor([0]) for _ in range(T)],
'done': [False for _ in range(T)],
'prev_state': [(torch.randn(3), torch.randn(3)) for _ in range(T)],
'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
}
return data
def get_multi_shape_state_data(self):
data = {
'obs': [torch.randn(4) for _ in range(T)],
'reward': [torch.FloatTensor([0]) for _ in range(T)],
'done': [False for _ in range(T)],
'prev_state': [
[(torch.randn(3), torch.randn(5)), (torch.randn(4), ), (torch.randn(5), torch.randn(6))]
for _ in range(T)
],
'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
}
return data
def test(self):
batch = timestep_collate([self.get_data() for _ in range(B)])
assert isinstance(batch, dict)
assert set(batch.keys()) == set(['obs', 'reward', 'done', 'prev_state', 'action'])
assert batch['obs'].shape == (T, B, 4)
assert batch['reward'].shape == (T, B)
assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool
assert isinstance(batch['prev_state'], list)
assert len(batch['prev_state']) == T and len(batch['prev_state'][0]) == B
assert isinstance(batch['action'], list) and len(batch['action']) == T
assert batch['action'][0][0].shape == (B, 3)
assert batch['action'][0][1].shape == (B, 5)
# hidden_state might contain multi prev_states with different shapes
batch = timestep_collate([self.get_multi_shape_state_data() for _ in range(B)])
assert isinstance(batch, dict)
assert set(batch.keys()) == set(['obs', 'reward', 'done', 'prev_state', 'action'])
assert batch['obs'].shape == (T, B, 4)
assert batch['reward'].shape == (T, B)
assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool
assert isinstance(batch['prev_state'], list)
print(batch['prev_state'][0][0])
assert len(batch['prev_state']) == T and len(batch['prev_state'][0]
) == B and len(batch['prev_state'][0][0]) == 3
assert isinstance(batch['action'], list) and len(batch['action']) == T
assert batch['action'][0][0].shape == (B, 3)
assert batch['action'][0][1].shape == (B, 5)
@pytest.mark.unittest
class TestDefaultCollate:
def test_numpy(self):
data = [np.random.randn(4, 3).astype(np.float64) for _ in range(5)]
data = default_collate(data)
assert data.shape == (5, 4, 3)
assert data.dtype == torch.float64
data = [float(np.random.randn(1)[0]) for _ in range(6)]
data = default_collate(data)
assert data.shape == (6, )
assert data.dtype == torch.float32
with pytest.raises(TypeError):
default_collate([np.array(['str']) for _ in range(3)])
def test_basic(self):
data = [random.random() for _ in range(3)]
data = default_collate(data)
assert data.shape == (3, )
assert data.dtype == torch.float32
data = [random.randint(0, 10) for _ in range(3)]
data = default_collate(data)
assert data.shape == (3, )
assert data.dtype == torch.int64
data = ['str' for _ in range(4)]
data = default_collate(data)
assert len(data) == 4
assert all([s == 'str' for s in data])
T = namedtuple('T', ['x', 'y'])
data = [T(1, 2) for _ in range(4)]
data = default_collate(data)
assert isinstance(data, T)
assert data.x.shape == (4, ) and data.x.eq(1).sum() == 4
assert data.y.shape == (4, ) and data.y.eq(2).sum() == 4
with pytest.raises(TypeError):
default_collate([object() for _ in range(4)])
data = [{'collate_ignore_data': random.random()} for _ in range(4)]
data = default_collate(data)
assert isinstance(data, dict)
assert len(data['collate_ignore_data']) == 4
@pytest.mark.unittest
class TestDefaultDecollate:
def test(self):
with pytest.raises(TypeError):
default_decollate([object() for _ in range(4)])
data = torch.randn(4, 3, 5)
data = default_decollate(data)
print([d.shape for d in data])
assert len(data) == 4 and all([d.shape == (3, 5) for d in data])
data = [torch.randn(8, 2, 4), torch.randn(8, 5)]
data = default_decollate(data)
assert len(data) == 8 and all([d[0].shape == (2, 4) and d[1].shape == (5, ) for d in data])
data = {
'logit': torch.randn(4, 13),
'action': torch.randint(0, 13, size=(4, )),
'prev_state': [(torch.zeros(3, 1, 12), torch.zeros(3, 1, 12)) for _ in range(4)],
}
data = default_decollate(data)
assert len(data) == 4 and isinstance(data, list)
assert all([d['logit'].shape == (13, ) for d in data])
assert all([d['action'].shape == (1, ) for d in data])
assert all([len(d['prev_state']) == 2 and d['prev_state'][0].shape == (3, 1, 12) for d in data])
@pytest.mark.unittest
class TestDiffShapeCollate:
def test(self):
with pytest.raises(TypeError):
diff_shape_collate([object() for _ in range(4)])
data = [
{
'item1': torch.randn(4),
'item2': None,
'item3': torch.randn(3),
'item4': np.random.randn(5, 6)
},
{
'item1': torch.randn(5),
'item2': torch.randn(6),
'item3': torch.randn(3),
'item4': np.random.randn(5, 6)
},
]
data = diff_shape_collate(data)
assert isinstance(data['item1'], list) and len(data['item1']) == 2
assert isinstance(data['item2'], list) and len(data['item2']) == 2 and data['item2'][0] is None
assert data['item3'].shape == (2, 3)
assert data['item4'].shape == (2, 5, 6)
data = [
{
'item1': 1,
'item2': 3,
'item3': 2.0
},
{
'item1': None,
'item2': 4,
'item3': 2.0
},
]
data = diff_shape_collate(data)
assert isinstance(data['item1'], list) and len(data['item1']) == 2 and data['item1'][1] is None
assert data['item2'].shape == (2, ) and data['item2'].dtype == torch.int64
assert data['item3'].shape == (2, ) and data['item3'].dtype == torch.float32