import os
import random
import shutil

import numpy as np
import pytest
import torch

from ding.envs.common.common_function import sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \
    reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, \
    batch_binary_encode, get_postion_vector, \
    affine_transform, save_frames_as_gif

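# Raw values that the reorder fixtures below remap to contiguous indices 0..4.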
VALUES = [2, 3, 5, 7, 11]


@pytest.fixture(scope="function")
def setup_reorder_array():
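    # Array lookup table: index = raw value, entry = its position in VALUES (-1 for unused slots).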
    ret = np.full((12, ), -1)
    for i, v in enumerate(VALUES):
        ret[v] = i
    return ret


@pytest.fixture(scope="function")
def setup_reorder_dict():
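    # Dict lookup table equivalent to the array fixture above.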
    return {v: i for i, v in enumerate(VALUES)}


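# Helper that builds a transition-like dict: 'obs' is always present, while 'priority' is
# missing, None, or a random float with roughly equal probability.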
def generate_data():
    ret = {
        'obs': np.random.randn(4),
    }
    p_weight = np.random.uniform()
    if p_weight < 1. / 3:
        pass  # roughly one third of samples omit 'priority' entirely
    elif p_weight < 2. / 3:
        ret['priority'] = None
    else:
        ret['priority'] = np.random.uniform()

    return ret


@pytest.mark.unittest
class TestEnvCommonFunc:

    def test_one_hot(self):
        a = torch.Tensor([[3, 4, 5], [1, 2, 6]])

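        # sqrt_one_hot presumably one-hot encodes floor(sqrt(x)); with max value 6,
        # floor(sqrt(6)) = 2, so the encoding dimension is 3.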
        a_sqrt = sqrt_one_hot(a, 6)
        assert a_sqrt.max().item() == 1
        assert [j.sum().item() for i in a_sqrt for j in i] == [1 for _ in range(6)]
        sqrt_dim = 3
        assert a_sqrt.shape == (2, 3, sqrt_dim)

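        # div_one_hot presumably one-hot encodes x // 2; with max value 6 the
        # encoding dimension is 6 // 2 + 1 = 4.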
        a_div = div_one_hot(a, 6, 2)
        assert a_div.max().item() == 1
        assert [j.sum().item() for i in a_div for j in i] == [1 for _ in range(6)]
        div_dim = 4
        assert a_div.shape == (2, 3, div_dim)

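        # div_func appears to divide by the ratio and insert a singleton dimension.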
        a_di = div_func(a, 2)
        assert a_di.shape == (2, 1, 3)
        assert torch.eq(a_di.squeeze() * 2, a).all()

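        # clip_one_hot presumably clips the values into [0, 3] before one-hot encoding
        # with 4 classes.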
        a_clip = clip_one_hot(a.long(), 4)
        assert a_clip.max().item() == 1
        assert [j.sum().item() for i in a_clip for j in i] == [1 for _ in range(6)]
        clip_dim = 4
        assert a_clip.shape == (2, 3, clip_dim)

    def test_reorder(self, setup_reorder_array, setup_reorder_dict):
        a = torch.LongTensor([2, 7])

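        # The array- and dict-based lookups should yield the same one-hot encoding of the
        # remapped indices.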
        a_array = reorder_one_hot_array(a, setup_reorder_array, 5)
        a_dict = reorder_one_hot(a, setup_reorder_dict, 5)
        assert torch.eq(a_array, a_dict).all()
        assert a_array.max().item() == 1
        assert [j.sum().item() for j in a_array] == [1 for _ in range(2)]
        reorder_dim = 5
        assert a_array.shape == (2, reorder_dim)

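        # The boolean vector marks every remapped index present in `a`, which here equals
        # the sum of the one-hot rows.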
        a_bool = reorder_boolean_vector(a, setup_reorder_dict, 5)
        assert a_bool.max().item() == 1
        assert torch.eq(a_bool, sum([_ for _ in a_array])).all()

    def test_binary(self):
        a = torch.LongTensor([445, 1023])
        a_binary = batch_binary_encode(a, 10)
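        # Build the expected 10-bit, zero-padded binary encodings by hand.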
        ans = []
        for number in a:
            one = [int(_) for _ in list(bin(number))[2:]]
            for _ in range(10 - len(one)):
                one.insert(0, 0)
            ans.append(one)
        ans = torch.Tensor(ans)
        assert torch.eq(a_binary, ans).all()

    def test_position(self):
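        # get_postion_vector presumably maps the 32 input ids to a single 64-dim
        # positional-encoding-style feature vector.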
        a = [random.randint(0, 5000) for _ in range(32)]
        a_position = get_postion_vector(a)
        assert a_position.shape == (64, )

    def test_affine_transform(self):
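        # Normalize the input to [-1, 1]; affine_transform then maps it linearly, either to
        # [min_val, max_val] or to alpha * x + beta.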
        a = torch.rand(4, 3)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, min_val=-2, max_val=2)
        assert ans.shape == (4, 3)
        assert ans.min() == -2 and ans.max() == 2

        a = np.random.rand(3, 5)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, alpha=4, beta=1)
        assert ans.shape == (3, 5)
        assert ans.min() == -3 and ans.max() == 5


@pytest.mark.other
def test_save_frames_as_gif():
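    # Render random RGB frames to a gif in a temporary directory, then clean up.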
    # uint8 frames, as image/gif writers generally expect
    frames = [np.random.randint(0, 255, [84, 84, 3], dtype=np.uint8) for _ in range(100)]
    replay_path_gif = './replay_path_gif'
    env_id = 'test'
    save_replay_count = 1
    if not os.path.exists(replay_path_gif):
        os.makedirs(replay_path_gif)
    path = os.path.join(replay_path_gif, '{}_episode_{}.gif'.format(env_id, save_replay_count))
    save_frames_as_gif(frames, path)
    shutil.rmtree(replay_path_gif)