# Source: gomoku / DI-engine / ding/rl_utils/tests/test_value_rescale.py
# Commit 079c32c ("init space") by zjowowen; file size 1.66 kB.
import pytest
import torch
from ding.rl_utils.value_rescale import value_inv_transform, value_transform, symlog, inv_symlog
@pytest.mark.unittest
class TestValueRescale:
    """Unit tests for ``value_transform`` and its inverse ``value_inv_transform``."""

    def test_value_transform(self):
        """``value_transform`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # torch.rand samples [0, 1); rescale to [-1, 1) so negative inputs are covered too.
            t = torch.rand((2, 3)) * 2 - 1
            out = value_transform(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_value_inv_transform(self):
        """``value_inv_transform`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # Signed inputs in [-1, 1), not just the non-negative [0, 1) of torch.rand.
            t = torch.rand((2, 3)) * 2 - 1
            out = value_inv_transform(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_trans_inverse(self):
        """``value_inv_transform(value_transform(t))`` recovers ``t`` within an absolute tolerance of 2e-5."""
        for _ in range(10):
            # Include negative values so the round-trip is checked on both sides of zero.
            t = torch.rand((4, 16)) * 2 - 1
            diff = value_inv_transform(value_transform(t)) - t
            assert diff.abs().max().item() == pytest.approx(0, abs=2e-5)
@pytest.mark.unittest
class TestSymlog:
    """Unit tests for ``symlog`` and its inverse ``inv_symlog``."""

    def test_symlog(self):
        """``symlog`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # torch.rand samples [0, 1); rescale to [-1, 1) so negative inputs are covered too.
            t = torch.rand((3, 4)) * 2 - 1
            out = symlog(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_inv_symlog(self):
        """``inv_symlog`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # Signed inputs in [-1, 1), not just the non-negative [0, 1) of torch.rand.
            t = torch.rand((3, 4)) * 2 - 1
            out = inv_symlog(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_trans_inverse(self):
        """``inv_symlog(symlog(t))`` recovers ``t`` within an absolute tolerance of 2e-5."""
        for _ in range(10):
            # Include negative values so the round-trip is checked on both sides of zero.
            t = torch.rand((4, 16)) * 2 - 1
            diff = inv_symlog(symlog(t)) - t
            assert diff.abs().max().item() == pytest.approx(0, abs=2e-5)