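# NOTE: the lines below are residue from a web file viewer, not part of the test code.
# They are kept as comments so this module stays importable.
# zjowowen's picture
# init space
# 079c32c
# raw
# history blame
# 4.01 kB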
import torch
import numpy as np
import pytest
from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead, EnsembleHead
from ding.torch_utils import is_differentiable
# Dimensions shared by the head tests in this module.
B, T = 4, 6  # B: batch size used for every random input; T is not referenced by any test here
embedding_dim, action_shape = 64, 12  # input feature width and default head output size
@pytest.mark.unittest
class TestHead:
    """Unit tests for the heads exported by ``ding.model.common.head``.

    Each test builds a head, runs a forward pass on random inputs, checks the
    output shapes and verifies that gradients reach the head's parameters.
    """

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar loss and run ``is_differentiable`` on it.

        Arguments:
            - model: the module whose parameters are expected to receive gradients.
            - outputs: a tensor, a list/tuple of tensors, or a dict whose values are tensors.
        Raises:
            - TypeError: if ``outputs`` is not one of the supported types.

        Callers rely on this back-propagating ``loss``: they inspect ``.grad``
        attributes afterwards (e.g. ``log_sigma_param.grad``, ``mu.grad``).
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, (list, tuple)):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Without this branch ``loss`` would be unbound and the failure would
            # surface as an opaque UnboundLocalError.
            raise TypeError(f"unsupported outputs type for output_check: {type(outputs)}")
        is_differentiable(loss, model)

    def test_dueling(self):
        """DuelingHead maps (B, embedding_dim) features to (B, action_shape) logits."""
        inputs = torch.randn(B, embedding_dim)
        model = DuelingHead(embedding_dim, action_shape, 3, 3)
        outputs = model(inputs)['logit']
        self.output_check(model, outputs)
        assert outputs.shape == (B, action_shape)

    @pytest.mark.parametrize('action_shape', [1, 8])
    def test_reparameterization(self, action_shape):
        """ReparameterizationHead yields ``mu``/``sigma`` of shape (B, action_shape) for every sigma_type."""
        inputs = torch.randn(B, embedding_dim)
        for sigma_type in ['fixed', 'independent', 'conditioned']:
            if sigma_type == 'fixed':
                model = ReparameterizationHead(
                    embedding_dim, action_shape, sigma_type=sigma_type, fixed_sigma_value=0.5
                )
            else:
                model = ReparameterizationHead(embedding_dim, action_shape, sigma_type=sigma_type)
            outputs = model(inputs)
            mu, sigma = outputs['mu'], outputs['sigma']
            assert mu.shape == (B, action_shape) and sigma.shape == (B, action_shape)
            if sigma_type == 'fixed':
                # A fixed sigma must be exactly the configured constant everywhere.
                assert sigma.eq(torch.full((B, action_shape), 0.5)).all()
            self.output_check(model, outputs)
            if sigma_type == 'independent':
                # The state-independent log-sigma parameter must be trained too.
                assert model.log_sigma_param.grad is not None

    def test_multi_head(self):
        """MultiHead returns one (B, d) logit tensor per entry of ``output_size_list``."""
        output_size_list = [2, 3, 7]
        head = MultiHead(DuelingHead, embedding_dim, output_size_list, activation=torch.nn.Tanh())
        inputs = torch.randn(B, embedding_dim)
        outputs = head(inputs)
        assert isinstance(outputs, dict)
        self.output_check(head, outputs['logit'])
        for i, d in enumerate(output_size_list):
            assert outputs['logit'][i].shape == (B, d)

    # NOTE(review): ``tmp`` looks like a temporary selection marker; confirm it is
    # still wanted (the class-level ``unittest`` mark applies regardless).
    @pytest.mark.tmp
    def test_stochastic_dueling(self):
        """StochasticDuelingHead yields (B, 1) q/v values and back-propagates into ``mu`` and ``sigma``."""
        obs = torch.randn(B, embedding_dim)
        behaviour_action = torch.randn(B, action_shape).clamp(-1, 1)
        mu = torch.randn(B, action_shape).requires_grad_(True)
        sigma = torch.rand(B, action_shape).requires_grad_(True)
        model = StochasticDuelingHead(embedding_dim, action_shape, 3, 3)
        assert mu.grad is None and sigma.grad is None
        outputs = model(obs, behaviour_action, mu, sigma)
        self.output_check(model, outputs['q_value'])
        # The policy inputs are leaf tensors outside the model, so gradients must
        # reach them through the q_value computation.
        assert isinstance(mu.grad, torch.Tensor)
        assert isinstance(sigma.grad, torch.Tensor)
        assert outputs['q_value'].shape == (B, 1)
        assert outputs['v_value'].shape == (B, 1)

    def test_ensemble(self):
        """EnsembleHead with 3 members outputs predictions of shape (B, action_shape * 3)."""
        # Input is (B, embedding_dim * 3, 1), presumably one embedding_dim-wide
        # slice per ensemble member -- TODO confirm against EnsembleHead.
        inputs = torch.randn(B, embedding_dim * 3, 1)
        model = EnsembleHead(embedding_dim, action_shape, 3, 3, 3)
        outputs = model(inputs)['pred']
        self.output_check(model, outputs)
        assert outputs.shape == (B, action_shape * 3)