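"""Unit tests for TorchSACOptimizer.

Covers SAC optimizer updates across discrete/continuous actions, visual/vector
observations, and recurrent/non-recurrent networks, plus curiosity reward-signal
updates.
"""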
import pytest

from mlagents.torch_utils import torch

from mlagents.trainers.buffer import BufferKey, RewardSignalUtil
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch_entities.networks import SimpleActor
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.tests.dummy_config import (
    sac_dummy_config,
    curiosity_dummy_config,
)


@pytest.fixture
def dummy_config():
    return sac_dummy_config()


VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 64
NUM_AGENTS = 12


def create_sac_optimizer_mock(dummy_config, use_rnn, use_discrete, use_visual):
    # Build a behavior spec and policy matching the requested action space,
    # observation type, and memory settings, then wrap them in a SAC optimizer.
    mock_brain = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE if not use_visual else 0,
    )
    trainer_settings = dummy_config
    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings(sequence_length=16, memory_size=12)
        if use_rnn
        else None
    )
    actor_kwargs = {
        "conditional_sigma": False,
        "tanh_squash": False,
    }
    policy = TorchPolicy(
        0, mock_brain, trainer_settings.network_settings, SimpleActor, actor_kwargs
    )
    optimizer = TorchSACOptimizer(policy, trainer_settings)
    return optimizer


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_sac_optimizer_update(dummy_config, rnn, visual, discrete):
    torch.manual_seed(0)
    optimizer = create_sac_optimizer_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    # Simulate a rollout to fill the update buffer.
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec, memory_size=12
    )
    # Mock out reward signal evaluation.
    update_buffer[RewardSignalUtil.rewards_key("extrinsic")] = update_buffer[
        BufferKey.ENVIRONMENT_REWARDS
    ]
    # Mock out critic (value) memories.
    update_buffer[BufferKey.CRITIC_MEMORY] = update_buffer[BufferKey.MEMORY]
    return_stats = optimizer.update(
        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )
    # Make sure the update reports all of the expected stats.
    required_stats = [
        "Losses/Policy Loss",
        "Losses/Value Loss",
        "Losses/Q1 Loss",
        "Losses/Q2 Loss",
        "Policy/Continuous Entropy Coeff",
        "Policy/Discrete Entropy Coeff",
        "Policy/Learning Rate",
    ]
    for stat in required_stats:
        assert stat in return_stats.keys()


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
def test_sac_update_reward_signals(dummy_config, curiosity_dummy_config, discrete):
    # Add a curiosity reward signal to the trainer configuration.
    dummy_config.reward_signals = curiosity_dummy_config
    optimizer = create_sac_optimizer_mock(
        dummy_config, use_rnn=False, use_discrete=discrete, use_visual=False
    )
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec
    )
    # Mock out reward signal evaluation for both extrinsic and curiosity rewards.
    update_buffer[RewardSignalUtil.rewards_key("extrinsic")] = update_buffer[
        BufferKey.ENVIRONMENT_REWARDS
    ]
    update_buffer[RewardSignalUtil.rewards_key("curiosity")] = update_buffer[
        BufferKey.ENVIRONMENT_REWARDS
    ]
    return_stats = optimizer.update_reward_signals(update_buffer)
    # Make sure both curiosity losses are reported.
    required_stats = ["Losses/Curiosity Forward Loss", "Losses/Curiosity Inverse Loss"]
    for stat in required_stats:
        assert stat in return_stats.keys()


if __name__ == "__main__":
    pytest.main()