import pytest
import torch
from ding.torch_utils import is_differentiable
from ding.model.template.coma import COMACriticNetwork, COMAActorNetwork
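

# Check that COMACriticNetwork produces a 'q_value' of shape (T, bs, agent_num, action_dim)
# from stacked agent/global observations and actions, and that gradients flow back to the model.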
@pytest.mark.unittest
def test_coma_critic():
    agent_num, bs, T = 4, 3, 8
    obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
    coma_model = COMACriticNetwork(obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
    data = {
        'obs': {
            'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            'global_state': torch.randn(T, bs, global_obs_dim),
        },
        'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
    }
    output = coma_model(data)
    assert set(output.keys()) == set(['q_value'])
    assert output['q_value'].shape == (T, bs, agent_num, action_dim)
    loss = output['q_value'].sum()
    is_differentiable(loss, coma_model)
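

# Check that COMAActorNetwork unrolls its recurrent state over T timesteps, returning
# per-agent logits of shape (B, A, action_dim) and a nested prev_state of B x A
# two-element hidden states, and that the final step backpropagates through the model.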
@pytest.mark.unittest
def test_rnn_actor_net():
    T, B, A, N = 4, 8, 3, 32
    embedding_dim = 64
    action_dim = 6
    data = torch.randn(T, B, A, N)
    model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
    prev_state = [[None for _ in range(A)] for _ in range(B)]
    for t in range(T):
        inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
        outputs = model(inputs)
        logit, prev_state = outputs['logit'], outputs['next_state']
    assert len(prev_state) == B
    assert all([len(o) == A and all([len(o1) == 2 for o1 in o]) for o in prev_state])
    assert logit.shape == (B, A, action_dim)
    # test that the last step can backward correctly
    loss = logit.sum()
    is_differentiable(loss, model)