"""Unit tests for ``NonegativeParameter`` and ``TanhParameter`` from ``ding.torch_utils.parameter``."""
import unittest

import pytest
import torch

from ding.torch_utils.parameter import NonegativeParameter, TanhParameter

# Absolute tolerance for comparing a parameter's output with the value it was
# constructed with / set to. A tolerance (rather than exact equality) keeps the
# checks independent of floating-point round-off in the parameter's internal
# representation.
_ATOL = 1e-6


def _max_abs_diff(actual, expected):
    """Return the largest element-wise absolute difference between two tensors/numbers as a Python float.

    ``expected`` may be a tensor or a plain number; subtraction lets torch
    promote dtypes, so an integer expectation can be compared with a float output.
    """
    return torch.max(torch.abs(actual - torch.as_tensor(expected))).item()


@pytest.mark.unittest
def test_nonegative_parameter():
    """NonegativeParameter should return the value it was created with, and the value last passed to ``set_data``."""
    nonegative_parameter = NonegativeParameter(torch.tensor([2.0, 3.0]))
    assert _max_abs_diff(nonegative_parameter(), torch.tensor([2.0, 3.0])) <= _ATOL
    # set_data accepts an integer tensor here; the output is compared with a tolerance.
    nonegative_parameter.set_data(torch.tensor(1))
    assert _max_abs_diff(nonegative_parameter(), 1) <= _ATOL


@pytest.mark.unittest
def test_tanh_parameter():
    """TanhParameter should return the value it was created with, and the value last passed to ``set_data``."""
    tanh_parameter = TanhParameter(torch.tensor([0.5, -0.2]))
    assert _max_abs_diff(tanh_parameter(), torch.tensor([0.5, -0.2])) <= _ATOL
    tanh_parameter.set_data(torch.tensor(0.3))
    assert _max_abs_diff(tanh_parameter(), 0.3) <= _ATOL


if __name__ == "__main__":
    test_nonegative_parameter()
    test_tanh_parameter()