File size: 2,290 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import numpy as np
import pytest
from itertools import product

from ding.model.template import ContinuousQAC
from ding.torch_utils import is_differentiable
from ding.utils import squeeze
from easydict import EasyDict
# Batch size used for every tensor built in the test below.
B = 4
# NOTE(review): T is not referenced anywhere in the visible part of this file.
T = 6
# Passed as both the critic-head and actor-head hidden size of the model.
embedding_size = 32
# Default arguments for the hybrid-action test: a discrete ``action_type``
# part and a continuous ``action_args`` part, evaluated with a twin critic.
hybrid_args = dict(
    action_shape=EasyDict({
        'action_type_shape': (4, ),
        'action_args_shape': (6, )
    }),
    twin=True,
    action_space='hybrid',
)


@pytest.mark.unittest
class TestHybridContinuousQAC:
    """Unit test for ``ContinuousQAC`` built with a hybrid action space."""

    def test_hybrid_qac(
        self,
        action_shape=hybrid_args['action_shape'],
        twin=hybrid_args['twin'],
        action_space=hybrid_args['action_space']
    ):
        """
        Run the critic and actor forward passes of a hybrid ``ContinuousQAC``.

        The test feeds a random batch through ``mode='compute_critic'`` and
        ``mode='compute_actor'``, checks the shapes of the actor outputs and
        passes the resulting scalars to ``is_differentiable`` together with the
        sub-module that produced them.

        Arguments:
            - action_shape: object exposing ``action_type_shape`` and \
                ``action_args_shape`` attributes.
            - twin: whether the model is built with ``twin_critic`` enabled.
            - action_space: must be ``'hybrid'``.
        """
        N = 32
        assert action_space == 'hybrid'
        # Resolve both action sizes once, before the model is constructed, so
        # every later use sees the same values regardless of what the model
        # constructor does with ``action_shape``.
        action_type_dim = squeeze(action_shape.action_type_shape)
        action_args_dim = squeeze(action_shape.action_args_shape)
        inputs = {
            'obs': torch.randn(B, N),
            'action': {
                'action_type': torch.randint(0, action_type_dim, (B, )),
                'action_args': torch.rand(B, action_args_dim)
            },
            'logit': torch.randn(B, action_type_dim)
        }
        model = ContinuousQAC(
            obs_shape=(N, ),
            action_shape=action_shape,
            action_space=action_space,
            critic_head_hidden_size=embedding_size,
            actor_head_hidden_size=embedding_size,
            twin_critic=twin,
        )
        # compute_q
        q = model(inputs, mode='compute_critic')['q_value']
        if twin:
            # With a twin critic the two q values are checked separately, each
            # against its own entry of ``model.critic[1]``.
            is_differentiable(q[0].sum(), model.critic[1][0])
            is_differentiable(q[1].sum(), model.critic[1][1])
        else:
            is_differentiable(q.sum(), model.critic)

        # compute_action
        print(model)

        output = model(inputs['obs'], mode='compute_actor')
        discrete_logit = output['logit']
        continuous_args = output['action_args']
        # test discrete action_type + continuous action_args
        if action_type_dim == 1:
            assert discrete_logit.shape == (B, )
        else:
            assert discrete_logit.shape == (B, action_type_dim)
        # ``action_args_shape`` is declared as a 1-tuple, so using the raw
        # attribute here would build the nested expectation ``(B, (6, ))``.
        # NOTE(review): the previous form of this assertion could only hold if
        # ContinuousQAC replaces ``action_shape.action_args_shape`` with an int
        # in place -- confirm; the squeezed size is correct either way.
        assert continuous_args.shape == (B, action_args_dim)
        is_differentiable(discrete_logit.sum() + continuous_args.sum(), model.actor)