import random

import pytest
import torch

from ding.model.template import HAVAC
from ding.torch_utils import is_differentiable


@pytest.mark.unittest
class TestHAVAC:
    """Smoke tests for the three forward modes of an LSTM-enabled ``HAVAC``.

    Every test feeds random inputs shaped ``(T, bs, dim)`` to one randomly
    chosen agent, then checks the returned keys, the output shapes, and that
    a scalar loss built from the outputs is accepted by ``is_differentiable``
    for the relevant sub-module.
    """

    # Problem sizes shared by all tests. ``T`` and ``bs`` are the two leading
    # dimensions of every input tensor and of the asserted output shapes.
    bs, T = 3, 8
    obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
    agent_num = 5

    def _make_data(self, *prev_state_keys):
        """Build a random model input.

        Arguments:
            - prev_state_keys: names of the previous-state entries to add
              (``'actor_prev_state'`` and/or ``'critic_prev_state'``); each one
              is filled with ``bs`` ``None`` placeholders.
        Returns:
            - data (dict): ``{'obs': {...}, <prev_state_key>: [None] * bs, ...}``
        """
        data = {
            'obs': {
                'agent_state': torch.randn(self.T, self.bs, self.obs_dim),
                'global_state': torch.randn(self.T, self.bs, self.global_obs_dim),
                # Integer mask with entries drawn from {0, 1}.
                'action_mask': torch.randint(0, 2, size=(self.T, self.bs, self.action_dim)),
            },
        }
        for key in prev_state_keys:
            data[key] = [None for _ in range(self.bs)]
        return data

    def _make_model(self):
        """Instantiate ``HAVAC`` with the shared sizes and ``use_lstm=True``."""
        return HAVAC(
            agent_obs_shape=self.obs_dim,
            global_obs_shape=self.global_obs_dim,
            action_shape=self.action_dim,
            agent_num=self.agent_num,
            use_lstm=True,
        )

    def _forward(self, mode, *prev_state_keys):
        """Run one randomly selected agent through ``mode``.

        Returns:
            - agent_model: ``model.agent_models[agent_idx]`` for the chosen agent.
            - output (dict): whatever the model returned for ``mode``.
        """
        # Keep this order (data -> model -> agent index) so the sequence of
        # random draws matches the original, un-factored tests.
        data = self._make_data(*prev_state_keys)
        model = self._make_model()
        # NOTE(review): the agent index is unseeded, so a failure that only
        # affects one agent may not reproduce on every run.
        agent_idx = random.randint(0, self.agent_num - 1)
        output = model(agent_idx, data, mode=mode)
        return model.agent_models[agent_idx], output

    def test_havac_rnn_actor(self):
        """``compute_actor`` returns logits plus actor recurrent states."""
        agent_model, output = self._forward('compute_actor', 'actor_prev_state')
        assert set(output.keys()) == {'logit', 'actor_next_state', 'actor_hidden_state'}
        assert output['logit'].shape == (self.T, self.bs, self.action_dim)
        assert len(output['actor_next_state']) == self.bs
        print(output['actor_next_state'][0]['h'].shape)
        loss = output['logit'].sum()
        is_differentiable(loss, agent_model.actor)

    def test_havac_rnn_critic(self):
        """``compute_critic`` returns values plus critic recurrent states."""
        agent_model, output = self._forward('compute_critic', 'critic_prev_state')
        assert set(output.keys()) == {'value', 'critic_next_state', 'critic_hidden_state'}
        assert output['value'].shape == (self.T, self.bs)
        assert len(output['critic_next_state']) == self.bs
        print(output['critic_next_state'][0]['h'].shape)
        loss = output['value'].sum()
        is_differentiable(loss, agent_model.critic)

    def test_havac_rnn_actor_critic(self):
        """``compute_actor_critic`` returns the union of actor and critic outputs."""
        agent_model, output = self._forward(
            'compute_actor_critic', 'actor_prev_state', 'critic_prev_state'
        )
        assert set(output.keys()) == {
            'logit',
            'actor_next_state',
            'actor_hidden_state',
            'value',
            'critic_next_state',
            'critic_hidden_state',
        }
        assert output['logit'].shape == (self.T, self.bs, self.action_dim)
        assert output['value'].shape == (self.T, self.bs)
        # The combined loss touches both heads, so the whole agent model is checked.
        loss = output['logit'].sum() + output['value'].sum()
        is_differentiable(loss, agent_model)