import pytest import torch import random from ding.torch_utils import is_differentiable from ding.model.template import HAVAC @pytest.mark.unittest class TestHAVAC: def test_havac_rnn_actor(self): # discrete+rnn bs, T = 3, 8 obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 agent_num = 5 data = { 'obs': { 'agent_state': torch.randn(T, bs, obs_dim), 'global_state': torch.randn(T, bs, global_obs_dim), 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) }, 'actor_prev_state': [None for _ in range(bs)], } model = HAVAC( agent_obs_shape=obs_dim, global_obs_shape=global_obs_dim, action_shape=action_dim, agent_num=agent_num, use_lstm=True, ) agent_idx = random.randint(0, agent_num - 1) output = model(agent_idx, data, mode='compute_actor') assert set(output.keys()) == set(['logit', 'actor_next_state', 'actor_hidden_state']) assert output['logit'].shape == (T, bs, action_dim) assert len(output['actor_next_state']) == bs print(output['actor_next_state'][0]['h'].shape) loss = output['logit'].sum() is_differentiable(loss, model.agent_models[agent_idx].actor) def test_havac_rnn_critic(self): # discrete+rnn bs, T = 3, 8 obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 agent_num = 5 data = { 'obs': { 'agent_state': torch.randn(T, bs, obs_dim), 'global_state': torch.randn(T, bs, global_obs_dim), 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) }, 'critic_prev_state': [None for _ in range(bs)], } model = HAVAC( agent_obs_shape=obs_dim, global_obs_shape=global_obs_dim, action_shape=action_dim, agent_num=agent_num, use_lstm=True, ) agent_idx = random.randint(0, agent_num - 1) output = model(agent_idx, data, mode='compute_critic') assert set(output.keys()) == set(['value', 'critic_next_state', 'critic_hidden_state']) assert output['value'].shape == (T, bs) assert len(output['critic_next_state']) == bs print(output['critic_next_state'][0]['h'].shape) loss = output['value'].sum() is_differentiable(loss, model.agent_models[agent_idx].critic) def test_havac_rnn_actor_critic(self): # discrete+rnn bs, T = 3, 8 obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 agent_num = 5 data = { 'obs': { 'agent_state': torch.randn(T, bs, obs_dim), 'global_state': torch.randn(T, bs, global_obs_dim), 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) }, 'actor_prev_state': [None for _ in range(bs)], 'critic_prev_state': [None for _ in range(bs)], } model = HAVAC( agent_obs_shape=obs_dim, global_obs_shape=global_obs_dim, action_shape=action_dim, agent_num=agent_num, use_lstm=True, ) agent_idx = random.randint(0, agent_num - 1) output = model(agent_idx, data, mode='compute_actor_critic') assert set(output.keys()) == set( ['logit', 'actor_next_state', 'actor_hidden_state', 'value', 'critic_next_state', 'critic_hidden_state'] ) assert output['logit'].shape == (T, bs, action_dim) assert output['value'].shape == (T, bs) loss = output['logit'].sum() + output['value'].sum() is_differentiable(loss, model.agent_models[agent_idx]) # test_havac_rnn_actor() # test_havac_rnn_critic() # test_havac_rnn_actor_critic()