zjowowen's picture
init space
079c32c
raw
history blame
1.78 kB
import pytest
import numpy as np
import torch
from itertools import product
from ding.model import mavac
from ding.model.template.mavac import MAVAC
from ding.torch_utils import is_differentiable
# Size of the leading (batch) dimension of every tensor built in the test.
B = 32
# Candidate sizes for the last dimension of the per-agent and global observations.
agent_obs_shape = [216, 265]
global_obs_shape = [264, 324]
# Size of the second tensor dimension; also passed to the model constructor.
agent_num = 8
# Size of the last dimension of the action mask; also passed to the model constructor.
action_shape = 14
# Every (agent_obs, global_obs) size pair, used as the parametrize grid.
args = list(product(agent_obs_shape, global_obs_shape))
@pytest.mark.unittest
@pytest.mark.parametrize('agent_obs_shape, global_obs_shape', args)
class TestVAC:
    """Forward/backward smoke tests for ``MAVAC`` in its three forward modes.

    The class is parametrized over every pair in ``args``; each pair supplies the
    integer sizes used as the last dimension of the agent and global observation
    tensors fed to the model.
    """

    @staticmethod
    def _clear_grads(model):
        """Set ``grad`` to ``None`` on every parameter of ``model``."""
        for p in model.parameters():
            p.grad = None

    def output_check(self, model, outputs, action_shape):
        """Reduce ``outputs`` to a scalar loss and pass it to ``is_differentiable``.

        Arguments:
            - model: module handed to ``is_differentiable`` together with the loss.
            - outputs: an iterable of tensors when ``action_shape`` is a tuple, \
                otherwise a single tensor.
            - action_shape: only its type is inspected, to choose the reduction.

        Raises:
            - TypeError: if ``action_shape`` is neither a tuple nor a scalar.
        """
        if isinstance(action_shape, tuple):
            loss = sum(t.sum() for t in outputs)
        elif np.isscalar(action_shape):
            loss = outputs.sum()
        else:
            # Fail with a clear message instead of an UnboundLocalError on ``loss``.
            raise TypeError(
                "action_shape must be a tuple or a scalar, got {}".format(type(action_shape).__name__)
            )
        is_differentiable(loss, model)

    def test_vac(self, agent_obs_shape, global_obs_shape):
        """Check the ``compute_actor_critic``, ``compute_actor`` and ``compute_critic`` modes.

        Arguments:
            - agent_obs_shape: last-dimension size of the ``agent_state`` tensor.
            - global_obs_shape: last-dimension size of the ``global_state`` tensor.
        """
        data = {
            'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            'global_state': torch.randn(B, agent_num, global_obs_shape),
            'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)),
        }
        model = MAVAC(agent_obs_shape, global_obs_shape, action_shape, agent_num)

        # A single forward pass returns both entries; no need to run the model twice.
        output = model(data, mode='compute_actor_critic')
        outputs = output['value'].sum() + output['logit'].sum()
        self.output_check(model, outputs, action_shape)

        # Reset gradients between checks -- presumably ``is_differentiable`` expects
        # parameters without gradients on entry; confirm against ding.torch_utils.
        self._clear_grads(model)
        logit = model(data, mode='compute_actor')['logit']
        self.output_check(model.actor, logit, model.action_shape)

        self._clear_grads(model)
        value = model(data, mode='compute_critic')['value']
        assert value.shape == (B, agent_num)
        self.output_check(model.critic, value, action_shape)