import torch
import numpy as np
import pytest

from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead, EnsembleHead
from ding.torch_utils import is_differentiable

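# Shared test dimensions: batch size B, sequence length T, feature embedding size,
# and number of discrete/continuous action dimensions.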
B = 4
T = 6
embedding_dim = 64
action_shape = 12


@pytest.mark.unittest
class TestHead:

    def output_check(self, model, outputs):
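        # Reduce the head outputs to a scalar loss and verify that gradients
        # propagate back to the model's parameters.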
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum([t.sum() for t in outputs])
        elif isinstance(outputs, dict):
            loss = sum([v.sum() for v in outputs.values()])
        is_differentiable(loss, model)

    def test_dueling(self):
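        # DuelingHead decomposes the Q-function into value and advantage streams
        # and outputs Q-value logits of shape (B, action_shape).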
        inputs = torch.randn(B, embedding_dim)
        model = DuelingHead(embedding_dim, action_shape, 3, 3)
        outputs = model(inputs)['logit']
        self.output_check(model, outputs)
        assert outputs.shape == (B, action_shape)

    @pytest.mark.parametrize('action_shape', [1, 8])
    def test_reparameterization(self, action_shape):
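        # ReparameterizationHead predicts the mean and standard deviation of a
        # Gaussian action distribution; sigma may be a fixed constant, an
        # independent learnable parameter, or conditioned on the input feature.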
        inputs = torch.randn(B, embedding_dim)
        for sigma_type in ['fixed', 'independent', 'conditioned']:
            if sigma_type == 'fixed':
                model = ReparameterizationHead(
                    embedding_dim, action_shape, sigma_type=sigma_type, fixed_sigma_value=0.5
                )
                outputs = model(inputs)
                mu, sigma = outputs['mu'], outputs['sigma']
                assert mu.shape == (B, action_shape) and sigma.shape == (B, action_shape)
                assert sigma.eq(torch.full((B, action_shape), 0.5)).all()
                self.output_check(model, outputs)
            elif sigma_type == 'independent':
                model = ReparameterizationHead(embedding_dim, action_shape, sigma_type=sigma_type)
                outputs = model(inputs)
                mu, sigma = outputs['mu'], outputs['sigma']
                assert mu.shape == (B, action_shape) and sigma.shape == (B, action_shape)
                self.output_check(model, outputs)
                assert model.log_sigma_param.grad is not None
            elif sigma_type == 'conditioned':
                model = ReparameterizationHead(embedding_dim, action_shape, sigma_type=sigma_type)
                outputs = model(inputs)
                mu, sigma = outputs['mu'], outputs['sigma']
                assert mu.shape == (B, action_shape) and sigma.shape == (B, action_shape)
                self.output_check(model, outputs)

    def test_multi_head(self):
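        # MultiHead builds one DuelingHead per entry of output_size_list and
        # returns their logits as a list under the 'logit' key.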
        output_size_list = [2, 3, 7]
        head = MultiHead(DuelingHead, embedding_dim, output_size_list, activation=torch.nn.Tanh())
        print(head)
        inputs = torch.randn(B, embedding_dim)
        outputs = head(inputs)
        assert isinstance(outputs, dict)
        self.output_check(head, outputs['logit'])
        for i, d in enumerate(output_size_list):
            assert outputs['logit'][i].shape == (B, d)

    @pytest.mark.tmp
    def test_stochastic_dueling(self):
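        # StochasticDuelingHead estimates Q(s, a) for the behaviour action and a
        # state value V(s) using a Gaussian policy N(mu, sigma), so the backward
        # pass should also produce gradients for mu and sigma.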
        obs = torch.randn(B, embedding_dim)
        behaviour_action = torch.randn(B, action_shape).clamp(-1, 1)
        mu = torch.randn(B, action_shape).requires_grad_(True)
        sigma = torch.rand(B, action_shape).requires_grad_(True)
        model = StochasticDuelingHead(embedding_dim, action_shape, 3, 3)

        assert mu.grad is None and sigma.grad is None
        outputs = model(obs, behaviour_action, mu, sigma)
        self.output_check(model, outputs['q_value'])
        assert isinstance(mu.grad, torch.Tensor)
        print(mu.grad)
        assert isinstance(sigma.grad, torch.Tensor)
        assert outputs['q_value'].shape == (B, 1)
        assert outputs['v_value'].shape == (B, 1)

    def test_ensemble(self):
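        # EnsembleHead runs 3 parallel sub-heads: inputs are stacked along the
        # channel dimension as (B, embedding_dim * 3, 1) and predictions come out
        # concatenated as (B, action_shape * 3).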
        inputs = torch.randn(B, embedding_dim * 3, 1)
        model = EnsembleHead(embedding_dim, action_shape, 3, 3, 3)
        outputs = model(inputs)['pred']
        self.output_check(model, outputs)
        assert outputs.shape == (B, action_shape * 3)