gomoku / DI-engine /ding /utils /tests /test_default_helper.py
zjowowen's picture
init space
079c32c
raw
history blame
9.49 kB
from collections import namedtuple
import numpy as np
import pytest
import torch
import treetensor.torch as ttorch
from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper, \
list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict, RunningMeanStd, \
one_time_warning, split_data_generator, get_shape0
@pytest.mark.unittest
class TestDefaultHelper():
    """Unit tests for the general-purpose helpers imported from ``ding.utils.default_helper``."""

    def test_get_shape0(self):
        a = {
            'a': {
                'b': torch.randn(4, 3)
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        # ``b`` and ``c`` wrap the plain nested dict; ``a`` itself is rebound to a treetensor below.
        b = [a, a]
        c = (a, a)
        # The leaf under 'a'/'b' is a list of strings, which has no leading tensor dimension.
        d = {
            'a': {
                'b': ["a", "b", "c", "d"]
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        a = ttorch.as_tensor(a)
        assert get_shape0(a) == 4
        assert get_shape0(b) == 4
        assert get_shape0(c) == 4
        # Call the helper bare rather than inside an ``assert``: an AssertionError raised inside
        # ``pytest.raises(Exception)`` would also be caught, letting a wrong-but-non-raising
        # result pass silently.
        with pytest.raises(Exception):
            get_shape0(d)

    def test_lists_to_dicts(self):
        set_pkg_seed(12)
        # An empty list raises ValueError; a list of non-dict elements raises TypeError.
        with pytest.raises(ValueError):
            lists_to_dicts([])
        with pytest.raises(TypeError):
            lists_to_dicts([1])
        assert lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}]) == {1: [1, 2], 10: [3, 4]}
        # A list of namedtuples is collated into one namedtuple of the same class.
        T = namedtuple('T', ['location', 'race'])
        data = [T({'x': 1, 'y': 2}, 'zerg') for _ in range(3)]
        output = lists_to_dicts(data)
        assert isinstance(output, T) and output.__class__ == T
        assert len(output.location) == 3
        # ``recursive=True`` also collates the nested 'obs' dict.
        data = [{'value': torch.randn(1), 'obs': {'scalar': torch.randn(4)}} for _ in range(3)]
        output = lists_to_dicts(data, recursive=True)
        assert isinstance(output, dict)
        assert len(output['value']) == 3
        assert len(output['obs']['scalar']) == 3

    def test_dicts_to_lists(self):
        # Inverse of ``lists_to_dicts``: a dict of lists becomes a list of dicts.
        assert dicts_to_lists({1: [1, 2], 10: [3, 4]}) == [{1: 1, 10: 3}, {1: 2, 10: 4}]

    def test_squeeze(self):
        # Single-element containers unwrap to their element; longer lists become tuples.
        assert squeeze((4, )) == 4
        assert squeeze({'a': 4}) == 4
        assert squeeze([1, 3]) == (1, 3)
        # Multi-element arrays come back with unchanged values.
        data = np.random.randn(3)
        output = squeeze(data)
        assert (output == data).all()

    def test_default_get(self):
        # A missing key falls back to ``default_value`` or ``default_fn()``, validated by ``judge_fn``.
        assert default_get({}, 'a', default_value=1, judge_fn=lambda x: x < 2) == 1
        assert default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 2) == 1
        # A default rejected by ``judge_fn`` raises AssertionError.
        with pytest.raises(AssertionError):
            default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 0)
        # A present key wins over the default.
        assert default_get({'val': 1}, 'val', default_value=2) == 1

    def test_override(self):

        class foo(object):

            def fun(self):
                raise NotImplementedError

        class foo1(foo):

            @override(foo)
            def fun(self):
                return "a"

        # Decorating a method name the parent class does not define raises NameError.
        with pytest.raises(NameError):

            class foo2(foo):

                @override(foo)
                def func(self):
                    pass

        with pytest.raises(NotImplementedError):
            foo().fun()
        # The overriding implementation must actually be the one that runs.
        assert foo1().fun() == "a"

    def test_error_wrapper(self):

        def good_ret(a, b=1):
            return a + b

        # Without an error the wrapper is transparent.
        wrap_good_ret = error_wrapper(good_ret, 0)
        assert good_ret(1) == wrap_good_ret(1)

        def bad_ret(a, b=0):
            return a / b

        # On an error (division by zero here) the wrapper returns the supplied default.
        wrap_bad_ret = error_wrapper(bad_ret, 0)
        assert wrap_bad_ret(1) == 0
        # The third argument is presumably an extra message attached to the logged error
        # (TODO confirm); either way the fallback value must still be returned.
        wrap_bad_ret_with_customized_log = error_wrapper(bad_ret, 0, 'customized_information')
        assert wrap_bad_ret_with_customized_log(1) == 0

    def test_list_split(self):
        data = [i for i in range(10)]
        # 10 items in steps of 4: two full chunks plus a residual of 2.
        output, residual = list_split(data, step=4)
        assert len(output) == 2
        assert output[0] == [0, 1, 2, 3]
        assert output[1] == [4, 5, 6, 7]
        assert residual == [8, 9]
        # An exact split leaves no residual (``None`` rather than an empty list).
        output, residual = list_split(data, step=5)
        assert len(output) == 2
        assert output[1] == [5, 6, 7, 8, 9]
        assert residual is None
@pytest.mark.unittest
class TestLimitedSpaceContainer():
    """Exercise the acquire/release bookkeeping of ``LimitedSpaceContainer``."""

    def test_container(self):
        container = LimitedSpaceContainer(0, 5)

        # Taking one slot out of a fresh container succeeds and bumps ``cur``.
        got_first = container.acquire_space()
        assert got_first
        assert container.cur == 1

        # Claiming everything that is left returns the number of slots taken and fills it up.
        remaining = container.get_residual_space()
        assert remaining == 4
        assert container.cur == container.max_val == 5

        # A full container refuses another slot ...
        got_extra = container.acquire_space()
        assert not got_extra

        # ... until increase_space() makes room for one more.
        container.increase_space()
        got_sixth = container.acquire_space()
        assert got_sixth

        # Each release lowers ``cur`` by one: 5, 4, ..., 0.
        for expected_cur in range(5, -1, -1):
            container.release_space()
            assert container.cur == expected_cur

        container.decrease_space()
        assert container.max_val == 5
@pytest.mark.unittest
class TestDict:
def test_deep_merge_dicts(self):
    """``deep_merge_dicts`` merges nested dicts, with the second argument winning on conflicts."""
    dict1 = {
        'a': 3,
        'b': {
            'c': 3,
            'd': {
                'e': 6,
                'f': 5,
            }
        }
    }
    dict2 = {
        'b': {
            'c': 5,
            'd': 6,
            'g': 4,
        }
    }
    new_dict = deep_merge_dicts(dict1, dict2)
    # Keys only present in ``dict1`` survive the merge.
    assert new_dict['a'] == 3
    assert isinstance(new_dict['b'], dict)
    # Scalar conflicts resolve in favour of ``dict2`` ...
    assert new_dict['b']['c'] == 5
    # ... including a nested dict being replaced wholesale by a scalar.
    assert new_dict['b']['d'] == 6
    # Keys only present in ``dict2`` are added.
    assert new_dict['b']['g'] == 4
    # The merge builds a new dict instead of writing into the first argument.
    assert dict1['b']['c'] == 3
def test_deep_update(self):
    """``deep_update`` honours ``new_keys_allowed``, ``whitelist`` and ``override_all_if_type_changes``."""

    def make_base():
        # A fresh base dict per call. deep_update may write into its first argument
        # (TODO confirm), so a call that raises part-way could leave 'c'/'d' already updated.
        # That would mask whether a later call really applied them.
        return {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                },
                'z': 4,
            }
        }

    dict2 = {
        'b': {
            'c': 5,
            'd': 6,
            'g': 4,
        }
    }
    # 'g' is not a key of the base 'b' sub-dict, so it is rejected when new keys are disallowed.
    with pytest.raises(RuntimeError):
        deep_update(make_base(), dict2, new_keys_allowed=False)
    # Whitelisting 'b' lets new keys into that sub-dict even with new_keys_allowed=False.
    new2 = deep_update(make_base(), dict2, new_keys_allowed=False, whitelist=['b'])
    assert new2['a'] == 3
    assert new2['b']['c'] == 5
    assert new2['b']['d'] == 6
    assert new2['b']['g'] == 4
    # Keys absent from the update are kept.
    assert new2['b']['z'] == 4

    dict1 = {
        'a': 3,
        'b': {
            'type': 'old',
            'z': 4,
        }
    }
    dict2 = {
        'b': {
            'type': 'new',
            'c': 5,
        }
    }
    # When 'type' changes for a key listed in override_all_if_type_changes, the whole sub-dict is
    # replaced, so the old-only key 'z' disappears.
    new3 = deep_update(dict1, dict2, new_keys_allowed=True, whitelist=[], override_all_if_type_changes=['b'])
    assert new3['a'] == 3
    assert new3['b']['type'] == 'new'
    assert new3['b']['c'] == 5
    assert 'z' not in new3['b']
def test_flatten_dict(self):
    """``flatten_dict`` joins nested keys with '/' into a single-level dict."""
    # Named ``nested`` rather than ``dict`` so the builtin is not shadowed.
    nested = {
        'a': 3,
        'b': {
            'c': 3,
            'd': {
                'e': 6,
                'f': 5,
            },
            'z': 4,
        }
    }
    flat = flatten_dict(nested)
    assert flat['a'] == 3
    assert flat['b/c'] == 3
    assert flat['b/d/e'] == 6
    assert flat['b/d/f'] == 5
    assert flat['b/z'] == 4
def test_one_time_warning(self):
    """Smoke test: emitting a one-time warning must not raise."""
    message = 'test_one_time_warning'
    one_time_warning(message)
def test_running_mean_std(self):
    """Check ``RunningMeanStd`` statistics for scalar and shaped inputs, plus ``reset`` and ``new_shape``."""
    running = RunningMeanStd()
    running.reset()
    running.update(np.arange(1, 10))
    assert running.mean == pytest.approx(5, abs=1e-4)
    assert running.std == pytest.approx(2.582030, abs=1e-6)
    # A second batch is folded into the running statistics: the combined mean of 1..9 and 2..10
    # is 5.5, not the second batch's own mean of 6.
    running.update(np.arange(2, 11))
    assert running.mean == pytest.approx(5.5, abs=1e-4)
    assert running.std == pytest.approx(2.629981, abs=1e-6)
    # reset() discards the accumulated state, so the first batch reproduces the first results.
    running.reset()
    running.update(np.arange(1, 10))
    assert running.mean == pytest.approx(5, abs=1e-4)
    assert running.std == pytest.approx(2.582030, abs=1e-6)
    new_shape = running.new_shape((2, 4), (3, ), (1, ))
    assert isinstance(new_shape, tuple) and len(new_shape) == 3
    # With an explicit shape, mean and std come back as tensors of that shape.
    running = RunningMeanStd(shape=(4, ))
    running.reset()
    running.update(np.random.random((10, 4)))
    assert isinstance(running.mean, torch.Tensor) and running.mean.shape == (4, )
    assert isinstance(running.std, torch.Tensor) and running.std.shape == (4, )
def test_split_data_generator(self):
def get_data():
return {
'obs': torch.randn(5),
'action': torch.randint(0, 10, size=(1, )),
'prev_state': [None, None],
'info': {
'other_obs': torch.randn(5)
},
}
data = [get_data() for _ in range(4)]
data = lists_to_dicts(data)
data['obs'] = torch.stack(data['obs'])
data['action'] = torch.stack(data['action'])
data['info'] = {'other_obs': torch.stack([t['other_obs'] for t in data['info']])}
assert len(data['obs']) == 4
data['NoneKey'] = None
generator = split_data_generator(data, 3)
generator_result = list(generator)
assert len(generator_result) == 2
assert generator_result[0]['NoneKey'] is None
assert len(generator_result[0]['obs']) == 3
assert generator_result[0]['info']['other_obs'].shape == (3, 5)
assert generator_result[1]['NoneKey'] is None
assert len(generator_result[1]['obs']) == 3
assert generator_result[1]['info']['other_obs'].shape == (3, 5)
generator = split_data_generator(data, 3, shuffle=False)