|
from unittest.mock import MagicMock, patch |
|
import pytest |
|
from mlagents.torch_utils import torch |
|
|
|
from mlagents.trainers.trainer_controller import TrainerController |
|
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager |
|
from mlagents.trainers.ghost.controller import GhostController |
|
|
|
|
|
@pytest.fixture
def basic_trainer_controller():
    """Provide a training-mode TrainerController wired to a mocked trainer factory.

    The factory is a MagicMock whose only real attribute is ``ghost_controller``
    (a genuine GhostController); everything else on it is auto-mocked.
    """
    factory = MagicMock()
    factory.ghost_controller = GhostController()
    controller_kwargs = dict(
        trainer_factory=factory,
        output_path="test_model_path",
        run_id="test_run_id",
        param_manager=EnvironmentParameterManager(),
        train=True,
        training_seed=99,
    )
    return TrainerController(**controller_kwargs)
|
|
|
|
|
@patch("numpy.random.seed")
@patch.object(torch, "manual_seed")
def test_initialization_seed(torch_manual_seed, numpy_random_seed):
    """Constructing a TrainerController seeds numpy and torch with ``training_seed``.

    ``mock.patch`` decorators inject their mocks bottom-up: the decorator
    closest to the function supplies the first positional argument. Hence the
    first parameter is the ``torch.manual_seed`` mock and the second is the
    ``numpy.random.seed`` mock. (The previous names had these swapped, so each
    assertion was checking a different mock than its name suggested.)
    """
    seed = 27
    trainer_factory_mock = MagicMock()
    trainer_factory_mock.ghost_controller = GhostController()
    TrainerController(
        trainer_factory=trainer_factory_mock,
        output_path="",
        run_id="1",
        param_manager=None,
        train=True,
        training_seed=seed,
    )
    numpy_random_seed.assert_called_with(seed)
    torch_manual_seed.assert_called_with(seed)
|
|
|
|
|
@pytest.fixture
def trainer_controller_with_start_learning_mocks(basic_trainer_controller):
    """Return ``(controller, trainer)`` with ``advance`` replaced by a step counter.

    Each call to the fake ``advance`` bumps the trainer's ``get_step``. Once the
    step count exceeds ``get_max_steps`` it clears ``should_still_train``. Past
    10 steps it raises KeyboardInterrupt so that loops which ignore
    ``should_still_train`` still terminate. ``_save_models`` is also mocked so
    that tests can inspect whether saving happened.
    """
    trainer = MagicMock()
    trainer.get_step = 0
    trainer.get_max_steps = 5
    trainer.should_still_train = True
    trainer.parameters = {"some": "parameter"}
    trainer.write_tensorboard_text = MagicMock()

    controller = basic_trainer_controller
    controller.trainers = {"testbrain": trainer}

    def _fake_advance(env):
        # Look the trainer up on every call, matching the controller's current dict.
        current = controller.trainers["testbrain"]
        current.get_step += 1
        if current.get_step > current.get_max_steps:
            current.should_still_train = False
        if current.get_step > 10:
            raise KeyboardInterrupt
        return 1

    controller.advance = MagicMock(side_effect=_fake_advance)
    controller._save_models = MagicMock()
    return controller, trainer
|
|
|
|
|
def test_start_learning_trains_forever_if_no_train_model(
    trainer_controller_with_start_learning_mocks,
):
    """With ``train_model`` off, start_learning keeps advancing and never saves.

    The fixture's fake ``advance`` raises KeyboardInterrupt on its 11th call.
    The expected count below relies on start_learning stopping at that point
    without propagating the interrupt.
    """
    controller, _trainer = trainer_controller_with_start_learning_mocks
    controller.train_model = False

    env = MagicMock()
    env.close = MagicMock()
    env.reset = MagicMock()
    env.training_behaviors = MagicMock()

    controller.start_learning(env)

    env.reset.assert_called_once()
    assert controller.advance.call_count == 11
    controller._save_models.assert_not_called()
|
|
|
|
|
def test_start_learning_trains_until_max_steps_then_saves(
    trainer_controller_with_start_learning_mocks,
):
    """In training mode, start_learning stops one step past max and saves once.

    ``advance`` is expected to run ``get_max_steps + 1`` times: the extra call
    is the one on which the fixture's fake clears ``should_still_train``.
    """
    controller, trainer = trainer_controller_with_start_learning_mocks

    reset_result = MagicMock()
    env = MagicMock()
    env.close = MagicMock()
    env.reset = MagicMock(return_value=reset_result)
    env.training_behaviors = MagicMock()

    controller.start_learning(env)

    env.reset.assert_called_once()
    expected_advance_calls = trainer.get_max_steps + 1
    assert controller.advance.call_count == expected_advance_calls
    controller._save_models.assert_called_once()
|
|
|
|
|
@pytest.fixture
def trainer_controller_with_take_step_mocks(basic_trainer_controller):
    """Return ``(controller, trainer)`` with a single mocked trainer and manager.

    Both the trainer and its manager are registered under ``"testbrain"``.
    Unlike the start-learning fixture, ``advance`` is left real here.
    """
    brain = "testbrain"

    trainer = MagicMock()
    trainer.get_step = 0
    trainer.get_max_steps = 5
    trainer.parameters = {"some": "parameter"}
    trainer.write_tensorboard_text = MagicMock()

    controller = basic_trainer_controller
    controller.trainers = {brain: trainer}
    controller.managers = {brain: MagicMock()}
    return controller, trainer
|
|
|
|
|
def test_advance_adds_experiences_to_trainer_and_trains(
    trainer_controller_with_take_step_mocks,
):
    """One real ``advance`` call fetches and processes env steps without a reset.

    NOTE(review): despite its name, this test only asserts on the environment
    mock; nothing about the trainer's calls is checked.
    """
    controller, _trainer = trainer_controller_with_take_step_mocks
    env = MagicMock()

    # Map the behavior name to itself. This assumes brain_name_to_identifier
    # auto-creates a set per key (e.g. a defaultdict(set)) — TODO confirm.
    behavior = "testbrain"
    controller.brain_name_to_identifier[behavior].add(behavior)

    controller.advance(env)

    env.reset.assert_not_called()
    env.get_steps.assert_called_once()
    env.process_steps.assert_called_once()
|
|
|
|
|
|