"""Unit tests for the network heads exposed by ``ding.model.common.head``."""
import torch
import numpy as np
import pytest

from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead, EnsembleHead
from ding.torch_utils import is_differentiable

# Leading (batch) dimension of every random input built in this module.
B = 4
# Not referenced by any test visible in this file; kept for compatibility.
T = 6
# Feature size of the inputs fed to each head.
embedding_dim = 64
# Default output size used by the tests that are not parametrized over it.
action_shape = 12


@pytest.mark.unittest
class TestHead:
    """Shape and differentiability checks for each head type."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar loss and pass it to ``is_differentiable``.

        Arguments:
            - model: The head under test, forwarded unchanged to ``is_differentiable``.
            - outputs: A ``torch.Tensor``, a ``list`` of tensors, or a ``dict`` whose
              values are tensors; every tensor is summed into a single loss.

        Raises:
            - TypeError: If ``outputs`` is none of the supported container types.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Without this branch ``loss`` would be unbound and the helper would
            # fail with an opaque UnboundLocalError instead of naming the cause.
            raise TypeError(
                "output_check expects a torch.Tensor, list or dict, got {}".format(type(outputs).__name__)
            )
        is_differentiable(loss, model)

    def test_dueling(self):
        """``DuelingHead`` returns a ``(B, action_shape)`` ``'logit'`` tensor."""
        inputs = torch.randn(B, embedding_dim)
        model = DuelingHead(embedding_dim, action_shape, 3, 3)
        outputs = model(inputs)['logit']
        self.output_check(model, outputs)
        assert outputs.shape == (B, action_shape)

    @pytest.mark.parametrize('action_shape', [1, 8])
    def test_reparameterization(self, action_shape):
        """``ReparameterizationHead`` yields ``(B, action_shape)`` ``mu``/``sigma`` for each sigma type."""
        inputs = torch.randn(B, embedding_dim)
        expected_shape = (B, action_shape)
        for sigma_type in ['fixed', 'independent', 'conditioned']:
            # Only the 'fixed' variant is given an explicit sigma value; the
            # other variants are built with the head's own defaults.
            extra_kwargs = {'fixed_sigma_value': 0.5} if sigma_type == 'fixed' else {}
            model = ReparameterizationHead(embedding_dim, action_shape, sigma_type=sigma_type, **extra_kwargs)
            outputs = model(inputs)
            mu, sigma = outputs['mu'], outputs['sigma']
            assert mu.shape == expected_shape and sigma.shape == expected_shape
            if sigma_type == 'fixed':
                # A fixed sigma must equal the configured constant everywhere.
                assert sigma.eq(torch.full(expected_shape, 0.5)).all()
            self.output_check(model, outputs)
            if sigma_type == 'independent':
                # The test expects the sigma parameter to hold a gradient once
                # the differentiability check has run.
                assert model.log_sigma_param.grad is not None

    def test_multi_head(self):
        """``MultiHead`` over ``DuelingHead`` returns one ``'logit'`` entry per output size."""
        output_size_list = [2, 3, 7]
        head = MultiHead(DuelingHead, embedding_dim, output_size_list, activation=torch.nn.Tanh())
        print(head)
        inputs = torch.randn(B, embedding_dim)
        outputs = head(inputs)
        assert isinstance(outputs, dict)
        self.output_check(head, outputs['logit'])
        for i, d in enumerate(output_size_list):
            assert outputs['logit'][i].shape == (B, d)

    # NOTE(review): this test carries an extra ``tmp`` marker on top of the
    # class-level ``unittest`` marker - confirm that is still intended.
    @pytest.mark.tmp
    def test_stochastic_dueling(self):
        """``StochasticDuelingHead`` gives ``(B, 1)`` q/v values and gradients for ``mu``/``sigma``."""
        obs = torch.randn(B, embedding_dim)
        behaviour_action = torch.randn(B, action_shape).clamp(-1, 1)
        mu = torch.randn(B, action_shape).requires_grad_(True)
        sigma = torch.rand(B, action_shape).requires_grad_(True)
        model = StochasticDuelingHead(embedding_dim, action_shape, 3, 3)
        # No gradients may exist before the differentiability check runs.
        assert mu.grad is None and sigma.grad is None
        outputs = model(obs, behaviour_action, mu, sigma)
        self.output_check(model, outputs['q_value'])
        # After the check both distribution inputs are expected to hold gradients.
        assert isinstance(mu.grad, torch.Tensor)
        print(mu.grad)
        assert isinstance(sigma.grad, torch.Tensor)
        assert outputs['q_value'].shape == (B, 1)
        assert outputs['v_value'].shape == (B, 1)

    def test_ensemble(self):
        """``EnsembleHead`` maps a ``(B, embedding_dim * 3, 1)`` input to ``(B, action_shape * 3)``."""
        inputs = torch.randn(B, embedding_dim * 3, 1)
        model = EnsembleHead(embedding_dim, action_shape, 3, 3, 3)
        outputs = model(inputs)['pred']
        self.output_check(model, outputs)
        assert outputs.shape == (B, action_shape * 3)