File size: 7,050 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import pytest
from collections import namedtuple
import random
import numpy as np
import torch
from ding.utils.data import timestep_collate, default_collate, default_decollate, diff_shape_collate

# Shared test dimensions: B samples are collated per batch, and every sample
# carries T timesteps.
B = 4
T = 3


@pytest.mark.unittest
class TestTimestepCollate:
    """Unit tests for ``timestep_collate`` on batches of ``B`` samples with ``T`` timesteps each."""

    def _make_sample(self, make_prev_state):
        """Build one sample dict in which every field is a list with ``T`` entries.

        Arguments:
            - make_prev_state: zero-argument callable returning the ``prev_state`` entry of one timestep.
        Returns:
            - dict with the keys ``obs``, ``reward``, ``done``, ``prev_state`` and ``action``.
        """
        return {
            'obs': [torch.randn(4) for _ in range(T)],
            'reward': [torch.FloatTensor([0]) for _ in range(T)],
            'done': [False for _ in range(T)],
            'prev_state': [make_prev_state() for _ in range(T)],
            'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
        }

    def get_data(self):
        """Return a sample whose per-timestep ``prev_state`` is a pair of 3-element tensors."""
        return self._make_sample(lambda: (torch.randn(3), torch.randn(3)))

    def get_multi_shape_state_data(self):
        """Return a sample whose ``prev_state`` holds three tuples of differently shaped tensors per timestep."""
        return self._make_sample(
            lambda: [(torch.randn(3), torch.randn(5)), (torch.randn(4), ), (torch.randn(5), torch.randn(6))]
        )

    def _check_common_fields(self, batch):
        """Assert the properties that both kinds of collated batch are expected to satisfy."""
        assert isinstance(batch, dict)
        assert set(batch.keys()) == {'obs', 'reward', 'done', 'prev_state', 'action'}
        assert batch['obs'].shape == (T, B, 4)
        assert batch['reward'].shape == (T, B)
        assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool
        assert isinstance(batch['prev_state'], list)
        assert len(batch['prev_state']) == T and len(batch['prev_state'][0]) == B
        assert isinstance(batch['action'], list) and len(batch['action']) == T
        assert batch['action'][0][0].shape == (B, 3)
        assert batch['action'][0][1].shape == (B, 5)

    def test(self):
        """Collate ``B`` samples of each kind and verify the layout of the result."""
        self._check_common_fields(timestep_collate([self.get_data() for _ in range(B)]))

        # hidden_state might contain multi prev_states with different shapes
        batch = timestep_collate([self.get_multi_shape_state_data() for _ in range(B)])
        self._check_common_fields(batch)
        # Each (timestep, sample) entry must still expose its three sub-states.
        assert len(batch['prev_state'][0][0]) == 3


@pytest.mark.unittest
class TestDefaultCollate:
    """Unit tests for ``default_collate`` on numpy arrays, Python scalars, strings, namedtuples and dicts."""

    def test_numpy(self):
        """Check shapes and dtypes for numpy-derived inputs and the rejection of string arrays."""
        # A list of float64 arrays is stacked along a new leading dimension and keeps its dtype.
        data = [np.random.randn(4, 3).astype(np.float64) for _ in range(5)]
        data = default_collate(data)
        assert data.shape == (5, 4, 3)
        assert data.dtype == torch.float64
        # Plain Python floats (extracted from numpy) collate to a 1-D float32 tensor.
        data = [float(np.random.randn(1)[0]) for _ in range(6)]
        data = default_collate(data)
        assert data.shape == (6, )
        assert data.dtype == torch.float32
        # Arrays of strings are not collatable.
        with pytest.raises(TypeError):
            default_collate([np.array(['str']) for _ in range(3)])

    def test_basic(self):
        """Check Python floats, ints, strings, namedtuples, unsupported objects and ignored dict keys."""
        data = [random.random() for _ in range(3)]
        data = default_collate(data)
        assert data.shape == (3, )
        assert data.dtype == torch.float32

        data = [random.randint(0, 10) for _ in range(3)]
        data = default_collate(data)
        assert data.shape == (3, )
        assert data.dtype == torch.int64

        data = ['str' for _ in range(4)]
        data = default_collate(data)
        assert len(data) == 4
        assert all(s == 'str' for s in data)

        # Named ``Pair`` rather than ``T`` so the module-level timestep constant ``T`` is not shadowed.
        Pair = namedtuple('Pair', ['x', 'y'])
        data = [Pair(1, 2) for _ in range(4)]
        data = default_collate(data)
        assert isinstance(data, Pair)
        assert data.x.shape == (4, ) and data.x.eq(1).sum() == 4
        assert data.y.shape == (4, ) and data.y.eq(2).sum() == 4

        with pytest.raises(TypeError):
            default_collate([object() for _ in range(4)])

        data = [{'collate_ignore_data': random.random()} for _ in range(4)]
        data = default_collate(data)
        assert isinstance(data, dict)
        assert len(data['collate_ignore_data']) == 4


@pytest.mark.unittest
class TestDefaultDecollate:
    """Unit tests for ``default_decollate`` on tensors, lists of tensors and dicts."""

    def test(self):
        """Split batched data along its leading dimension and check the per-sample layout."""
        # Unsupported element types must be rejected.
        with pytest.raises(TypeError):
            default_decollate([object() for _ in range(4)])

        # A single tensor is split into one item per leading index.
        data = torch.randn(4, 3, 5)
        data = default_decollate(data)
        assert len(data) == 4 and all(d.shape == (3, 5) for d in data)

        # A list of tensors sharing the leading dimension yields per-sample groups.
        data = [torch.randn(8, 2, 4), torch.randn(8, 5)]
        data = default_decollate(data)
        assert len(data) == 8 and all(d[0].shape == (2, 4) and d[1].shape == (5, ) for d in data)

        # A dict is turned into a list of per-sample dicts.
        data = {
            'logit': torch.randn(4, 13),
            'action': torch.randint(0, 13, size=(4, )),
            'prev_state': [(torch.zeros(3, 1, 12), torch.zeros(3, 1, 12)) for _ in range(4)],
        }
        data = default_decollate(data)
        assert len(data) == 4 and isinstance(data, list)
        assert all(d['logit'].shape == (13, ) for d in data)
        assert all(d['action'].shape == (1, ) for d in data)
        assert all(len(d['prev_state']) == 2 and d['prev_state'][0].shape == (3, 1, 12) for d in data)


@pytest.mark.unittest
class TestDiffShapeCollate:
    """Unit tests for ``diff_shape_collate``."""

    def test(self):
        """Collate samples whose fields differ in shape or are ``None`` and check each field's result."""
        # Unsupported element types must be rejected.
        with pytest.raises(TypeError):
            diff_shape_collate([object() for _ in range(4)])

        # Tensor / ndarray fields: item1 differs in shape, item2 is missing in the first sample.
        first = {
            'item1': torch.randn(4),
            'item2': None,
            'item3': torch.randn(3),
            'item4': np.random.randn(5, 6)
        }
        second = {
            'item1': torch.randn(5),
            'item2': torch.randn(6),
            'item3': torch.randn(3),
            'item4': np.random.randn(5, 6)
        }
        collated = diff_shape_collate([first, second])
        assert isinstance(collated['item1'], list)
        assert len(collated['item1']) == 2
        assert isinstance(collated['item2'], list)
        assert len(collated['item2']) == 2
        assert collated['item2'][0] is None
        assert collated['item3'].shape == (2, 3)
        assert collated['item4'].shape == (2, 5, 6)

        # Scalar fields: item1 is None in the second sample, item2 and item3 are always present.
        first = {'item1': 1, 'item2': 3, 'item3': 2.0}
        second = {'item1': None, 'item2': 4, 'item3': 2.0}
        collated = diff_shape_collate([first, second])
        assert isinstance(collated['item1'], list)
        assert len(collated['item1']) == 2
        assert collated['item1'][1] is None
        assert collated['item2'].shape == (2, )
        assert collated['item2'].dtype == torch.int64
        assert collated['item3'].shape == (2, )
        assert collated['item3'].dtype == torch.float32