import time

import pytest
import torch

from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform


@pytest.mark.unittest
def test_scaling_transform():
    """Check that the functional and class-based inverse transforms agree.

    The same random input is passed through ``inverse_scalar_transform`` and
    through an ``InverseScalarTransform`` instance built with the same
    argument (300). The test asserts that both results have shape ``(16, 1)``
    and are element-wise identical. The elapsed time of each call is printed
    (labels ``t1`` and ``t2``) for informal comparison only; no timing is
    asserted.
    """
    # NOTE(review): 601 presumably equals 2 * 300 + 1, i.e. one logit per
    # integer in [-300, 300] -- confirm against lzero.policy.scaling_transform.
    logit = torch.randn(16, 601)

    # perf_counter() is monotonic and high-resolution, so it is the right
    # clock for timing short calls; time.time() can jump and is coarser.
    start = time.perf_counter()
    output_1 = inverse_scalar_transform(logit, 300)
    print('t1', time.perf_counter() - start)

    # Construction happens before the timer starts, so only the call itself
    # is measured for the class-based variant.
    handle = InverseScalarTransform(300)
    start = time.perf_counter()
    output_2 = handle(logit)
    print('t2', time.perf_counter() - start)

    assert output_1.shape == output_2.shape == (16, 1)
    # Exact equality is intentional: both code paths are expected to produce
    # bit-identical values for the same input.
    assert (output_1 == output_2).all()