zjowowen's picture
init space
079c32c
raw
history blame
1.12 kB
import pytest
import torch
from ding.torch_utils import is_differentiable
from ding.model.template import MADQN
@pytest.mark.unittest
def test_madqn():
    """Smoke-test a multi-step ``MADQN`` forward pass and its output structure.

    Builds a model and a random input dict, calls the model with
    ``cooperation=True, single_step=False``, then checks that ``total_q`` has
    shape ``(T, bs)`` and that ``next_state`` holds ``bs`` entries of
    ``agent_num`` items each.
    """
    # NOTE(review): T and bs are presumably the sequence length and batch size
    # -- confirm against the MADQN documentation.
    agent_num, bs, T = 4, 3, 8
    obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
    embedding_dim = 64

    madqn_model = MADQN(
        agent_num=agent_num,
        obs_shape=obs_dim,
        action_shape=action_dim,
        hidden_size_list=[embedding_dim, embedding_dim],
        global_obs_shape=global_obs_dim
    )

    data = {
        'obs': {
            'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            'global_state': torch.randn(T, bs, agent_num, global_obs_dim),
            # Random 0/1 entries for every (step, sample, agent, action) slot.
            'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim))
        },
        # One ``None`` per agent for each of the ``bs`` entries.
        # NOTE(review): presumably "no previous hidden state" -- confirm.
        'prev_state': [[None for _ in range(agent_num)] for _ in range(bs)],
        'action': torch.randint(0, action_dim, size=(T, bs, agent_num))
    }

    output = madqn_model(data, cooperation=True, single_step=False)

    assert output['total_q'].shape == (T, bs)
    # Checked separately so a failure shows which condition broke: the outer
    # length or a per-entry length.
    assert len(output['next_state']) == bs
    assert all(len(n) == agent_num for n in output['next_state'])