File size: 4,703 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import numpy as np
import pytest
from itertools import product

from ding.model.template import DiscreteMAQAC, ContinuousMAQAC
from ding.torch_utils import is_differentiable
from ding.utils.default_helper import squeeze

# Sizes shared by the MAQAC tests in this module.
B = 32  # batch size: leading dim of every input tensor built below
agent_obs_shape = [216, 265]  # candidate sizes for the last dim of 'agent_state'
global_obs_shape = [264, 324]  # candidate sizes for the last dim of 'global_state'
agent_num = 8  # second dim of every input tensor
action_shape = 14  # number of actions
# Every (agent_obs_shape, global_obs_shape, twin_critic) combination, used to
# parametrize TestDiscreteMAQAC.
args = list(product(agent_obs_shape, global_obs_shape, [False, True]))


@pytest.mark.unittest
@pytest.mark.parametrize('agent_obs_shape, global_obs_shape, twin_critic', args)
class TestDiscreteMAQAC:
    """Forward-shape and gradient checks for ``DiscreteMAQAC``."""

    def output_check(self, model, outputs, action_shape):
        """Reduce ``outputs`` to a scalar loss and pass it to ``is_differentiable``.

        Args:
            model: module handed to ``is_differentiable`` together with the loss
                (presumably it asserts every parameter of ``model`` gets a gradient).
            outputs: an iterable of tensors when ``action_shape`` is a tuple,
                otherwise a single tensor.
            action_shape: tuple or scalar; selects how ``outputs`` is reduced.

        Raises:
            TypeError: if ``action_shape`` is neither a tuple nor a scalar.
        """
        if isinstance(action_shape, tuple):
            loss = sum(t.sum() for t in outputs)
        elif np.isscalar(action_shape):
            loss = outputs.sum()
        else:
            # Without this branch ``loss`` would be unbound and the test would fail
            # with an UnboundLocalError that hides the real cause.
            raise TypeError('unsupported action_shape type: {}'.format(type(action_shape).__name__))
        is_differentiable(loss, model)

    @staticmethod
    def _sum_q_value(value, twin_critic):
        """Sum critic output to a scalar; with ``twin_critic`` it is an iterable of tensors."""
        return sum(t.sum() for t in value) if twin_critic else value.sum()

    def test_maqac(self, agent_obs_shape, global_obs_shape, twin_critic):
        data = {
            'obs': {
                'agent_state': torch.randn(B, agent_num, agent_obs_shape),
                'global_state': torch.randn(B, agent_num, global_obs_shape),
                'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            }
        }
        model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=twin_critic)

        # Joint check: actor and critic outputs summed, checked against the whole model.
        logit = model(data, mode='compute_actor')['logit']
        value = model(data, mode='compute_critic')['q_value']
        outputs = self._sum_q_value(value, twin_critic) + logit.sum()
        self.output_check(model, outputs, action_shape)

        # Actor alone; clear gradients left over from the joint backward pass first.
        for p in model.parameters():
            p.grad = None
        logit = model(data, mode='compute_actor')['logit']
        self.output_check(model.actor, logit, action_shape)

        # Critic alone: one Q-value per action for every (batch, agent) pair.
        for p in model.parameters():
            p.grad = None
        value = model(data, mode='compute_critic')['q_value']
        if twin_critic:
            for v in value:
                assert v.shape == (B, agent_num, action_shape)
        else:
            assert value.shape == (B, agent_num, action_shape)
        self.output_check(model.critic, self._sum_q_value(value, twin_critic), action_shape)


# The continuous-action tests reuse B, agent_num, action_shape and the
# observation-size lists defined at the top of this module; only the action
# space choice is new here. Rebinding ``args`` is safe: the parametrize
# decorator on TestDiscreteMAQAC already captured the earlier list.
action_space = ['regression', 'reparameterization']
args = list(product(agent_obs_shape, global_obs_shape, action_space, [False, True]))


@pytest.mark.unittest
@pytest.mark.parametrize('agent_obs_shape, global_obs_shape, action_space, twin_critic', args)
class TestContinuousMAQAC:
    """Forward-shape and gradient checks for ``ContinuousMAQAC``."""

    def output_check(self, model, outputs, action_shape):
        """Reduce ``outputs`` to a scalar loss and pass it to ``is_differentiable``.

        Args:
            model: module handed to ``is_differentiable`` together with the loss
                (presumably it asserts every parameter of ``model`` gets a gradient).
            outputs: an iterable of tensors when ``action_shape`` is a tuple,
                otherwise a single tensor.
            action_shape: tuple or scalar; selects how ``outputs`` is reduced.

        Raises:
            TypeError: if ``action_shape`` is neither a tuple nor a scalar.
        """
        if isinstance(action_shape, tuple):
            loss = sum(t.sum() for t in outputs)
        elif np.isscalar(action_shape):
            loss = outputs.sum()
        else:
            # Without this branch ``loss`` would be unbound and the test would fail
            # with an UnboundLocalError that hides the real cause.
            raise TypeError('unsupported action_shape type: {}'.format(type(action_shape).__name__))
        is_differentiable(loss, model)

    def test_continuousmaqac(self, agent_obs_shape, global_obs_shape, action_space, twin_critic):
        data = {
            'obs': {
                'agent_state': torch.randn(B, agent_num, agent_obs_shape),
                'global_state': torch.randn(B, agent_num, global_obs_shape),
                'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            },
            'action': torch.randn(B, agent_num, squeeze(action_shape))
        }
        model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_space, twin_critic=twin_critic) \
            if False else ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, action_space, twin_critic=twin_critic)

        for p in model.parameters():
            p.grad = None

        # Actor checks; the output key and structure depend on the action space.
        if action_space == 'regression':
            action = model(data['obs'], mode='compute_actor')['action']
            if squeeze(action_shape) == 1:
                # A singleton action dim is expected to be dropped, but the agent
                # axis must remain, as in every other shape check in this module.
                # NOTE(review): confirm against the regression head's squeeze rule.
                assert action.shape == (B, agent_num)
            else:
                assert action.shape == (B, agent_num, squeeze(action_shape))
            # Regression actions must already lie in [-1, 1].
            assert action.eq(action.clamp(-1, 1)).all()
            self.output_check(model.actor, action, action_shape)
        elif action_space == 'reparameterization':
            (mu, sigma) = model(data['obs'], mode='compute_actor')['logit']
            assert mu.shape == (B, agent_num, action_shape)
            assert sigma.shape == (B, agent_num, action_shape)
            is_differentiable(mu.sum() + sigma.sum(), model.actor)
        else:
            # Fail loudly instead of silently skipping every actor check.
            raise ValueError('unexpected action_space: {}'.format(action_space))

        # Critic checks: one Q-value per (batch, agent) pair.
        for p in model.parameters():
            p.grad = None
        value = model(data, mode='compute_critic')['q_value']
        if twin_critic:
            for v in value:
                assert v.shape == (B, agent_num)
        else:
            assert value.shape == (B, agent_num)
        self.output_check(model.critic, sum(t.sum() for t in value) if twin_critic else value.sum(), action_shape)