|
from collections import namedtuple |
|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
import treetensor.torch as ttorch |
|
|
|
from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper, \ |
|
list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict, RunningMeanStd, \ |
|
one_time_warning, split_data_generator, get_shape0 |
|
|
|
|
|
@pytest.mark.unittest
class TestDefaultHelper():
    """Unit tests for the general-purpose helpers in ``ding.utils.default_helper``."""

    def test_get_shape0(self):
        # Nested dict whose tensor leaves all share a leading dimension of 4.
        a = {
            'a': {
                'b': torch.randn(4, 3)
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        b = [a, a]
        c = (a, a)
        # Same layout, but the first leaf is a list of strings instead of a tensor.
        d = {
            'a': {
                'b': ["a", "b", "c", "d"]
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        # Only `a` is converted to a treetensor; b and c keep the plain-dict version.
        a = ttorch.as_tensor(a)
        assert get_shape0(a) == 4
        assert get_shape0(b) == 4
        assert get_shape0(c) == 4
        # Call without an `assert`: an AssertionError raised by a wrong return value
        # would otherwise satisfy pytest.raises(Exception) and hide the failure.
        with pytest.raises(Exception):
            get_shape0(d)

    def test_lists_to_dicts(self):
        set_pkg_seed(12)
        # An empty list and a list of non-mapping items are both rejected.
        with pytest.raises(ValueError):
            lists_to_dicts([])
        with pytest.raises(TypeError):
            lists_to_dicts([1])
        assert lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}]) == {1: [1, 2], 10: [3, 4]}
        # A list of namedtuples collapses into a single namedtuple of the same class.
        T = namedtuple('T', ['location', 'race'])
        data = [T({'x': 1, 'y': 2}, 'zerg') for _ in range(3)]
        output = lists_to_dicts(data)
        assert isinstance(output, T) and output.__class__ == T
        assert len(output.location) == 3
        # recursive=True also collates the nested 'obs' dict.
        data = [{'value': torch.randn(1), 'obs': {'scalar': torch.randn(4)}} for _ in range(3)]
        output = lists_to_dicts(data, recursive=True)
        assert isinstance(output, dict)
        assert len(output['value']) == 3
        assert len(output['obs']['scalar']) == 3

    def test_dicts_to_lists(self):
        assert dicts_to_lists({1: [1, 2], 10: [3, 4]}) == [{1: 1, 10: 3}, {1: 2, 10: 4}]

    def test_squeeze(self):
        # Single-element containers unwrap; multi-element lists become tuples.
        assert squeeze((4, )) == 4
        assert squeeze({'a': 4}) == 4
        assert squeeze([1, 3]) == (1, 3)
        # Arrays pass through with unchanged values.
        data = np.random.randn(3)
        output = squeeze(data)
        assert (output == data).all()

    def test_default_get(self):
        assert default_get({}, 'a', default_value=1, judge_fn=lambda x: x < 2) == 1
        assert default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 2) == 1
        # A default that fails judge_fn surfaces as an AssertionError.
        with pytest.raises(AssertionError):
            default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 0)
        # An existing key wins over the default.
        assert default_get({'val': 1}, 'val', default_value=2) == 1

    def test_override(self):

        class foo(object):

            def fun(self):
                raise NotImplementedError

        class foo1(foo):

            @override(foo)
            def fun(self):
                return "a"

        # Decorating a method name the parent does not define must fail at class creation.
        with pytest.raises(NameError):

            class foo2(foo):

                @override(foo)
                def func(self):
                    pass

        with pytest.raises(NotImplementedError):
            foo().fun()
        # The decorated override must still be the callable that runs.
        assert foo1().fun() == "a"

    def test_error_wrapper(self):

        def good_ret(a, b=1):
            return a + b

        # Without an exception, the wrapper returns the wrapped function's result.
        wrap_good_ret = error_wrapper(good_ret, 0)
        assert good_ret(1) == wrap_good_ret(1)

        def bad_ret(a, b=0):
            return a / b

        # On an exception (division by zero here), the wrapper returns the default.
        wrap_bad_ret = error_wrapper(bad_ret, 0)
        assert wrap_bad_ret(1) == 0
        # Exercise the path that also emits a customized warning message; the
        # original test built this wrapper but never invoked it.
        wrap_bad_ret_with_customized_log = error_wrapper(bad_ret, 0, 'customized_information')
        assert wrap_bad_ret_with_customized_log(1) == 0

    def test_list_split(self):
        data = [i for i in range(10)]
        # 10 items in steps of 4: two full chunks plus a 2-item residual.
        output, residual = list_split(data, step=4)
        assert len(output) == 2
        assert output[1] == [4, 5, 6, 7]
        assert residual == [8, 9]
        # An exact split leaves no residual (None, not an empty list).
        output, residual = list_split(data, step=5)
        assert len(output) == 2
        assert output[1] == [5, 6, 7, 8, 9]
        assert residual is None
|
|
|
|
|
@pytest.mark.unittest
class TestLimitedSpaceContainer:
    """Exercise the acquire/release bookkeeping of ``LimitedSpaceContainer``."""

    def test_container(self):
        box = LimitedSpaceContainer(0, 5)

        # One successful acquisition moves the cursor from 0 to 1.
        got_first = box.acquire_space()
        assert got_first
        assert box.cur == 1

        # Claiming the residual space reports what was left and fills the container.
        remaining = box.get_residual_space()
        assert remaining == 4
        assert box.cur == box.max_val == 5

        # A full container refuses another acquisition...
        assert not box.acquire_space()

        # ...until its capacity is raised; decrease_space() below restores max_val to 5.
        box.increase_space()
        assert box.acquire_space()

        # Six releases walk the cursor down one step at a time: 5, 4, ..., 0.
        for expected_cur in reversed(range(6)):
            box.release_space()
            assert box.cur == expected_cur

        box.decrease_space()
        assert box.max_val == 5
|
|
|
|
|
@pytest.mark.unittest
class TestDict:
    """Tests for the dict helpers, ``RunningMeanStd`` and ``split_data_generator``."""

    def test_deep_merge_dicts(self):
        dict1 = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                }
            }
        }
        dict2 = {
            'b': {
                'c': 5,
                'd': 6,
                'g': 4,
            }
        }
        new_dict = deep_merge_dicts(dict1, dict2)
        assert new_dict['a'] == 3
        assert isinstance(new_dict['b'], dict)
        assert new_dict['b']['c'] == 5
        # A non-dict value in dict2 replaces the nested dict stored under the same key.
        # (The original test repeated the 'c' check here instead of verifying 'd'.)
        assert new_dict['b']['d'] == 6
        assert new_dict['b']['g'] == 4

    def test_deep_update(self):
        dict1 = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                },
                'z': 4,
            }
        }
        dict2 = {
            'b': {
                'c': 5,
                'd': 6,
                'g': 4,
            }
        }
        # 'g' does not exist in dict1['b'], so the update is rejected when new keys are
        # disallowed; whitelisting 'b' permits new keys beneath it.
        with pytest.raises(RuntimeError):
            deep_update(dict1, dict2, new_keys_allowed=False)
        new2 = deep_update(dict1, dict2, new_keys_allowed=False, whitelist=['b'])
        assert new2['a'] == 3
        assert new2['b']['c'] == 5
        assert new2['b']['d'] == 6
        assert new2['b']['g'] == 4
        assert new2['b']['z'] == 4

        dict1 = {
            'a': 3,
            'b': {
                'type': 'old',
                'z': 4,
            }
        }
        dict2 = {
            'b': {
                'type': 'new',
                'c': 5,
            }
        }
        # When b['type'] changes, 'b' is replaced wholesale, so the stale 'z' disappears.
        new3 = deep_update(dict1, dict2, new_keys_allowed=True, whitelist=[], override_all_if_type_changes=['b'])
        assert new3['a'] == 3
        assert new3['b']['type'] == 'new'
        assert new3['b']['c'] == 5
        assert 'z' not in new3['b']

    def test_flatten_dict(self):
        # Named `nested` rather than `dict` to avoid shadowing the builtin.
        nested = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                },
                'z': 4,
            }
        }
        flat = flatten_dict(nested)
        # Nested keys are joined with '/'.
        assert flat['a'] == 3
        assert flat['b/c'] == 3
        assert flat['b/d/e'] == 6
        assert flat['b/d/f'] == 5
        assert flat['b/z'] == 4

    def test_one_time_warning(self):
        one_time_warning('test_one_time_warning')

    def test_running_mean_std(self):
        running = RunningMeanStd()
        running.reset()
        running.update(np.arange(1, 10))
        assert running.mean == pytest.approx(5, abs=1e-4)
        assert running.std == pytest.approx(2.582030, abs=1e-6)
        # A second update folds the new batch into the running statistics.
        running.update(np.arange(2, 11))
        assert running.mean == pytest.approx(5.5, abs=1e-4)
        assert running.std == pytest.approx(2.629981, abs=1e-6)
        # reset() returns to the initial state: the same data gives the same statistics.
        running.reset()
        running.update(np.arange(1, 10))
        assert running.mean == pytest.approx(5, abs=1e-4)
        assert running.std == pytest.approx(2.582030, abs=1e-6)
        new_shape = running.new_shape((2, 4), (3, ), (1, ))
        assert isinstance(new_shape, tuple) and len(new_shape) == 3

        # With a vector shape, mean/std come back as torch tensors of that shape.
        running = RunningMeanStd(shape=(4, ))
        running.reset()
        running.update(np.random.random((10, 4)))
        assert isinstance(running.mean, torch.Tensor) and running.mean.shape == (4, )
        assert isinstance(running.std, torch.Tensor) and running.std.shape == (4, )

    def test_split_data_generator(self):

        def get_data():
            return {
                'obs': torch.randn(5),
                'action': torch.randint(0, 10, size=(1, )),
                'prev_state': [None, None],
                'info': {
                    'other_obs': torch.randn(5)
                },
            }

        data = [get_data() for _ in range(4)]
        data = lists_to_dicts(data)
        data['obs'] = torch.stack(data['obs'])
        data['action'] = torch.stack(data['action'])
        data['info'] = {'other_obs': torch.stack([t['other_obs'] for t in data['info']])}
        assert len(data['obs']) == 4
        data['NoneKey'] = None

        def check_batches(batches):
            # 4 samples split with size 3 yield 2 batches, each of full size 3;
            # None-valued keys pass through unchanged.
            assert len(batches) == 2
            for batch in batches:
                assert batch['NoneKey'] is None
                assert len(batch['obs']) == 3
                assert batch['info']['other_obs'].shape == (3, 5)

        generator = split_data_generator(data, 3)
        check_batches(list(generator))

        # The generator body only runs when iterated, so consume the un-shuffled one too.
        generator = split_data_generator(data, 3, shuffle=False)
        ordered = list(generator)
        check_batches(ordered)
        # Without shuffling, the first batch is expected to be the first three samples in order.
        assert torch.equal(ordered[0]['obs'], data['obs'][:3])
|
|