File size: 3,730 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import pytest
import numpy as np
import torch
from itertools import product

from ding.model import VAC, DREAMERVAC
from ding.torch_utils import is_differentiable

from ding.model import ConvEncoder
from easydict import EasyDict

# Hybrid action-space spec: 4 discrete action types with 3-dim continuous args.
ezD = EasyDict({'action_args_shape': (3, ), 'action_type_shape': 4})
# Batch size and image channel/height/width used to build random observation batches.
B, C, H, W = 4, 3, 128, 128
# Observation shapes to sweep: flat int, 1-D tuple, and image-like (C, H, W).
obs_shape = [4, (8, ), (4, 64, 64)]
# (action_shape, action_space) pairs: single discrete, continuous,
# multi-head discrete, and hybrid (discrete type + continuous args).
act_args = [[6, 'discrete'], [(3, ), 'continuous'], [[2, 3, 6], 'discrete'], [ezD, 'hybrid']]
# act_args = [[(3, ), True]]
# Full cartesian product over obs shapes, action configs and the share_encoder flag.
args = list(product(*[obs_shape, act_args, [False, True]]))


def output_check(model, outputs, action_shape):
    """Reduce ``outputs`` to a scalar loss and verify it backprops through ``model``.

    Arguments:
        - model: module whose parameters must all receive gradients.
        - outputs: a tensor, or a sequence of tensors (one per head).
        - action_shape: scalar for a single-head output, list/tuple for a
          multi-head output, dict (e.g. EasyDict) for hybrid action spaces.
    """
    # NOTE(review): the original only matched ``tuple`` here, so a list-valued
    # action_shape (e.g. [2, 3, 6] from ``act_args``) fell through every branch
    # and raised ``NameError`` on the unbound ``loss``; lists are now accepted.
    if isinstance(action_shape, (list, tuple)):
        # Multi-head output: one logit tensor per head.
        loss = sum(t.sum() for t in outputs)
    elif np.isscalar(action_shape):
        loss = outputs.sum()
    elif isinstance(action_shape, dict):
        # Hybrid spaces pass an already-reduced scalar tensor; ``.sum()`` is a no-op.
        loss = outputs.sum()
    else:
        raise TypeError("unsupported action_shape type: {}".format(type(action_shape)))
    is_differentiable(loss, model)


def model_check(model, inputs):
    """Exercise all three forward modes of a VAC model and check gradients.

    Runs ``compute_actor_critic``, ``compute_actor`` and ``compute_critic`` in
    turn, collapsing each output to a scalar and asserting via ``output_check``
    that every model parameter receives a gradient.
    """

    def _reset_grads(module):
        # Clear stale gradients between the independent backward passes.
        for param in module.parameters():
            param.grad = None

    def _logit_sum(logit_out):
        # Collapse a structured actor logit (dict for continuous/hybrid) into a
        # scalar tensor; discrete logits are returned untouched.
        if model.action_space == 'continuous':
            return logit_out['mu'].sum() + logit_out['sigma'].sum()
        if model.action_space == 'hybrid':
            args_logit = logit_out['action_args']
            return logit_out['action_type'].sum() + args_logit['mu'].sum() + args_logit['sigma'].sum()
        return logit_out

    # Joint actor-critic forward: fold value and logit into one scalar.
    joint = model(inputs, mode='compute_actor_critic')
    value, logit = joint['value'], joint['logit']
    if model.action_space in ('continuous', 'hybrid'):
        total = value.sum() + _logit_sum(logit)
    elif model.multi_head:
        total = value.sum() + sum(t.sum() for t in logit)
    else:
        total = value.sum() + logit.sum()
    output_check(model, total, 1)

    # Actor-only forward.
    _reset_grads(model)
    actor_logit = model(inputs, mode='compute_actor')['logit']
    output_check(model.actor, _logit_sum(actor_logit), model.action_shape)

    # Critic-only forward: value must be one scalar per batch element.
    _reset_grads(model)
    critic_value = model(inputs, mode='compute_critic')['value']
    assert critic_value.shape == (B, )
    output_check(model.critic, critic_value, 1)


@pytest.mark.unittest
class TestDREAMERVAC:
    """Smoke test for ``DREAMERVAC``."""

    def test_DREAMERVAC(self):
        # Construction-only smoke test: it passes as long as DREAMERVAC builds
        # without raising. No forward pass or gradient check is performed here.
        obs_shape = 8
        act_shape = 6
        model = DREAMERVAC(obs_shape, act_shape)


@pytest.mark.unittest
@pytest.mark.parametrize('obs_shape, act_args, share_encoder', args)
class TestVACGeneral:
    """Parametrized sweep of VAC over obs shapes, action spaces and encoder sharing."""

    def test_vac(self, obs_shape, act_args, share_encoder):
        # A flat int obs becomes a (B, obs) batch; tuple shapes become (B, *obs).
        batch_shape = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
        obs_batch = torch.randn(B, *batch_shape)
        action_shape, action_space = act_args
        vac = VAC(obs_shape, action_shape=action_shape, action_space=action_space, share_encoder=share_encoder)
        model_check(vac, obs_batch)


@pytest.mark.unittest
# NOTE(review): the original parametrized with one-element tuples
# [(False, ), (True, )]; since ``(False,)`` is a truthy tuple, the
# share_encoder=False case was never actually exercised. Plain booleans fix it.
@pytest.mark.parametrize('share_encoder', [False, True])
class TestVACEncoder:
    """VAC construction with custom encoders (IMPALA CNN / externally built ConvEncoder)."""

    def test_vac_with_impala_encoder(self, share_encoder):
        # Image observation batch for the IMPALA-style CNN encoder.
        inputs = torch.randn(B, 4, 64, 64)
        model = VAC(
            obs_shape=(4, 64, 64),
            action_shape=6,
            action_space='discrete',
            share_encoder=share_encoder,
            impala_cnn_encoder=True
        )
        model_check(model, inputs)

    def test_encoder_assignment(self, share_encoder):
        # Verify an externally constructed encoder can be plugged into VAC.
        inputs = torch.randn(B, 4, 64, 64)

        # Head hidden sizes below must match the encoder's final feature size (64).
        special_encoder = ConvEncoder(obs_shape=(4, 64, 64), hidden_size_list=[16, 32, 32, 64])

        model = VAC(
            obs_shape=(4, 64, 64),
            action_shape=6,
            action_space='discrete',
            share_encoder=share_encoder,
            actor_head_hidden_size=64,
            critic_head_hidden_size=64,
            encoder=special_encoder
        )
        model_check(model, inputs)