File size: 1,776 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pytest
import numpy as np
import torch
from itertools import product

from ding.model import mavac
from ding.model.template.mavac import MAVAC
from ding.torch_utils import is_differentiable

# Batch size: leading dimension of every tensor built in the test below.
B = 32
# Number of agents: second dimension of every tensor built in the test below.
agent_num = 8
# Size of the last dimension of the action mask; also passed to the model constructor.
action_shape = 14

# Candidate sizes for the last dimension of the per-agent and global observations.
agent_obs_shape = [216, 265]
global_obs_shape = [264, 324]

# Every (agent_obs_shape, global_obs_shape) pairing, consumed by pytest's parametrize.
args = list(product(agent_obs_shape, global_obs_shape))


@pytest.mark.unittest
@pytest.mark.parametrize('agent_obs_shape, global_obs_shape', args)
class TestVAC:
    """Forward/backward smoke tests for ``MAVAC``, run once per observation-shape pair in ``args``."""

    @staticmethod
    def _clear_grads(model):
        """Set ``grad`` to ``None`` on every parameter of ``model``."""
        # NOTE(review): presumably ``is_differentiable`` expects parameters to start
        # without gradients; confirm against ding.torch_utils.
        for p in model.parameters():
            p.grad = None

    def output_check(self, model, outputs, action_shape):
        """Reduce ``outputs`` to a single scalar loss and pass it to ``is_differentiable``.

        Args:
            model: Module handed to ``is_differentiable`` together with the loss.
            outputs: A single tensor when ``action_shape`` is a scalar, or an
                iterable of tensors when ``action_shape`` is a tuple/list.
            action_shape: Selects how ``outputs`` is reduced (see above).

        Raises:
            TypeError: If ``action_shape`` is neither a tuple/list nor a scalar.
                Previously this case fell through and failed with an
                ``UnboundLocalError`` on ``loss``.
        """
        if isinstance(action_shape, (tuple, list)):
            loss = sum(t.sum() for t in outputs)
        elif np.isscalar(action_shape):
            loss = outputs.sum()
        else:
            raise TypeError(
                'action_shape must be a tuple, list or scalar, got {}'.format(type(action_shape).__name__)
            )
        is_differentiable(loss, model)

    def test_vac(self, agent_obs_shape, global_obs_shape):
        """Exercise the three forward modes of ``MAVAC`` and check each is differentiable.

        Args:
            agent_obs_shape: Last-dimension size of the ``agent_state`` input.
            global_obs_shape: Last-dimension size of the ``global_state`` input.
        """
        data = {
            'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            'global_state': torch.randn(B, agent_num, global_obs_shape),
            'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
        }
        model = MAVAC(agent_obs_shape, global_obs_shape, action_shape, agent_num)

        # One forward pass yields both heads; no need to run the model twice.
        actor_critic_output = model(data, mode='compute_actor_critic')
        logit = actor_critic_output['logit']
        value = actor_critic_output['value']

        # Combine both heads into one scalar so a single check covers the whole model.
        outputs = value.sum() + logit.sum()
        self.output_check(model, outputs, action_shape)

        self._clear_grads(model)
        logit = model(data, mode='compute_actor')['logit']
        self.output_check(model.actor, logit, model.action_shape)

        self._clear_grads(model)
        value = model(data, mode='compute_critic')['value']
        assert value.shape == (B, agent_num)
        self.output_check(model.critic, value, action_shape)