import pytest
import torch
from ding.torch_utils.distribution import Pd, CategoricalPd, CategoricalPdPytorch


@pytest.mark.unittest
class TestProbDistribution:
    """Shape / autograd smoke tests for the probability-distribution helpers."""

    def test_Pd(self):
        """Every method exercised on a bare ``Pd`` is expected to raise ``NotImplementedError``."""
        base = Pd()
        # Each call is wrapped in a lambda so that it is evaluated inside the
        # ``pytest.raises`` context, in the same order as listed here.
        unimplemented_calls = (
            lambda: base.neglogp(torch.randn(5, )),
            lambda: base.noise_mode(),
            lambda: base.mode(),
            lambda: base.sample(),
        )
        for call in unimplemented_calls:
            with pytest.raises(NotImplementedError):
                call()

    def test_CatePD(self):
        """Check output shapes and gradient tracking of both categorical distributions."""
        scalar_shape = torch.Size([])
        batch_shape = torch.Size([3])

        # A (3, 5) float tensor that tracks gradients, plus three integer
        # indices drawn from [0, 5) that are handed to ``neglogp``.
        logits = torch.randn(3, 5, requires_grad=True)
        indices = torch.randint(5, (3, ), dtype=torch.int64)

        cate_pd = CategoricalPd()
        cate_pd.update_logits(logits)

        # neglogp: scalar result that stays attached to the autograd graph.
        neg_log_prob = cate_pd.neglogp(indices)
        assert neg_log_prob.requires_grad
        assert neg_log_prob.shape == scalar_shape

        # entropy: scalar by default, one value per row with ``reduction=None``.
        reduced_entropy = cate_pd.entropy()
        assert reduced_entropy.requires_grad
        assert reduced_entropy.shape == scalar_shape

        per_row_entropy = cate_pd.entropy(reduction=None)
        assert per_row_entropy.requires_grad
        assert per_row_entropy.shape == batch_shape

        # Each selection method is called plainly and then with ``viz=True``;
        # in the latter case only the first element of the result is checked.
        for method_name in ('sample', 'mode', 'noise_mode'):
            select = getattr(cate_pd, method_name)
            plain = select()
            assert plain.shape == batch_shape
            with_viz = select(viz=True)
            assert with_viz[0].shape == batch_shape

        # Same logits through the second categorical implementation.
        torch_pd = CategoricalPdPytorch()
        torch_pd.update_logits(logits)

        sampled = torch_pd.sample()
        assert sampled.shape == batch_shape
        greedy = torch_pd.mode()
        assert greedy.shape == batch_shape

        mean_entropy = torch_pd.entropy(reduction='mean')
        assert mean_entropy.requires_grad
        assert mean_entropy.shape == scalar_shape

        raw_entropy = torch_pd.entropy(reduction=None)
        assert raw_entropy.requires_grad
        assert raw_entropy.shape == batch_shape