# File size: 4,104 Bytes
# Revision: 079c32c
import os
import random
import shutil
import numpy as np
import pytest
import torch
from ding.envs.common.common_function import sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \
reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, \
batch_binary_encode, get_postion_vector, \
affine_transform, save_frames_as_gif
# Values exercised by the reorder tests; VALUES[i] is expected to map to index i.
VALUES = [2, 3, 5, 7, 11]


@pytest.fixture
def setup_reorder_array():
    """Return a lookup array mapping each value in VALUES to its position.

    The array is indexed by the raw value: ``ret[v] == i`` for ``v = VALUES[i]``,
    and every index that is not a member of VALUES holds -1.

    Must be a pytest fixture: ``test_reorder`` requests it by parameter name.
    """
    # Size the table from the data instead of a hard-coded 12 so every value
    # in VALUES is a valid index.
    ret = np.full(max(VALUES) + 1, -1)
    for i, v in enumerate(VALUES):
        ret[v] = i
    return ret
@pytest.fixture
def setup_reorder_dict():
    """Return a dict mapping each value in VALUES to its position in VALUES.

    Must be a pytest fixture: ``test_reorder`` requests it by parameter name.
    """
    return {v: i for i, v in enumerate(VALUES)}
def generate_data():
    """Build a random sample dict with an ``'obs'`` vector and an optional priority.

    Returns:
        dict: always contains ``'obs'`` (``np.random.randn(4)``). With roughly
        equal probability, ``'priority'`` is then absent, present as ``None``,
        or present as a float drawn from ``np.random.uniform()``.
    """
    ret = {
        'obs': np.random.randn(4),
    }
    p_weight = np.random.uniform()
    if p_weight < 1. / 3:
        pass  # no key 'priority'
    elif p_weight < 2. / 3:
        ret['priority'] = None
    else:
        # Without this ``else`` the float assignment ran unconditionally and
        # overwrote both the "missing" and the "None" cases.
        ret['priority'] = np.random.uniform()
    return ret
class TestEnvCommonFunc:
    """Tests for the encoding helpers imported from ``ding.envs.common.common_function``."""

    def test_one_hot(self):
        """Check sqrt/div/clip one-hot encoders and ``div_func`` on a 2x3 input."""
        a = torch.Tensor([[3, 4, 5], [1, 2, 6]])

        # Each encoded element must be a one-hot row: max value 1, every row sums to 1.
        a_sqrt = sqrt_one_hot(a, 6)
        assert a_sqrt.max().item() == 1
        assert [j.sum().item() for i in a_sqrt for j in i] == [1] * 6
        sqrt_dim = 3
        assert a_sqrt.shape == (2, 3, sqrt_dim)

        a_div = div_one_hot(a, 6, 2)
        assert a_div.max().item() == 1
        assert [j.sum().item() for i in a_div for j in i] == [1] * 6
        div_dim = 4
        assert a_div.shape == (2, 3, div_dim)

        # div_func adds a singleton dim; undoing the division recovers the input.
        a_di = div_func(a, 2)
        assert a_di.shape == (2, 1, 3)
        assert torch.eq(a_di.squeeze() * 2, a).all()

        a_clip = clip_one_hot(a.long(), 4)
        assert a_clip.max().item() == 1
        assert [j.sum().item() for i in a_clip for j in i] == [1] * 6
        clip_dim = 4
        assert a_clip.shape == (2, 3, clip_dim)

    def test_reorder(self, setup_reorder_array, setup_reorder_dict):
        """Check that the array- and dict-based reorder encoders agree with each other."""
        a = torch.LongTensor([2, 7])  # both are members of VALUES = [2, 3, 5, 7, 11]
        a_array = reorder_one_hot_array(a, setup_reorder_array, 5)
        a_dict = reorder_one_hot(a, setup_reorder_dict, 5)
        assert torch.eq(a_array, a_dict).all()
        assert a_array.max().item() == 1
        assert [j.sum().item() for j in a_array] == [1] * 2
        reorder_dim = 5
        assert a_array.shape == (2, reorder_dim)

        a_bool = reorder_boolean_vector(a, setup_reorder_dict, 5)
        # Check the boolean vector itself (this previously re-checked a_array).
        assert a_bool.max().item() == 1
        # The boolean vector should equal the sum of the one-hot rows.
        assert torch.eq(a_bool, a_array.sum(dim=0)).all()

    def test_binary(self):
        """Check ``batch_binary_encode`` against Python's zero-padded binary formatting."""
        a = torch.LongTensor([445, 1023])
        a_binary = batch_binary_encode(a, 10)
        # Expected 10-bit, most-significant-bit-first encodings. The previous loop
        # built each row but never appended it, so ``ans`` was always empty.
        ans = torch.Tensor([[int(bit) for bit in format(int(number), '010b')] for number in a])
        assert torch.eq(a_binary, ans).all()

    def test_position(self):
        """Check the output shape of ``get_postion_vector`` for 32 random integers."""
        a = [random.randint(0, 5000) for _ in range(32)]
        a_position = get_postion_vector(a)
        assert a_position.shape == (64, )

    def test_affine_transform(self):
        """Check ``affine_transform`` on torch and numpy inputs scaled to exactly [-1, 1]."""
        a = torch.rand(4, 3)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, min_val=-2, max_val=2)
        assert ans.shape == (4, 3)
        # Tolerant comparison: these are float results, not exact integers.
        assert np.isclose(float(ans.min()), -2) and np.isclose(float(ans.max()), 2)

        a = np.random.rand(3, 5)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, alpha=4, beta=1)
        assert ans.shape == (3, 5)
        assert np.isclose(float(ans.min()), -3) and np.isclose(float(ans.max()), 5)
def test_save_frames_as_gif():
    """Write 100 random 84x84 RGB frames to a GIF, then remove the output directory."""
    frames = [np.random.randint(0, 255, [84, 84, 3]) for _ in range(100)]
    replay_path_gif = './replay_path_gif'
    env_id = 'test'
    save_replay_count = 1
    # Always ensure the directory exists and always compute the path. Previously
    # ``path`` was only defined when the directory was missing, so a leftover
    # directory from an earlier run raised NameError.
    os.makedirs(replay_path_gif, exist_ok=True)
    path = os.path.join(replay_path_gif, '{}_episode_{}.gif'.format(env_id, save_replay_count))
    try:
        save_frames_as_gif(frames, path)
        assert os.path.exists(path)
    finally:
        # Do not leave test artifacts behind in the working directory.
        shutil.rmtree(replay_path_gif, ignore_errors=True)