File size: 4,291 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from unittest.mock import MagicMock, patch
import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.ghost.controller import GhostController
@pytest.fixture
def basic_trainer_controller():
    """Provide a TrainerController built around a mocked trainer factory.

    The factory is a MagicMock that carries a real GhostController. The
    controller is created in training mode, with a default-constructed
    EnvironmentParameterManager and a fixed training seed of 99.
    """
    factory = MagicMock()
    factory.ghost_controller = GhostController()
    controller = TrainerController(
        trainer_factory=factory,
        output_path="test_model_path",
        run_id="test_run_id",
        param_manager=EnvironmentParameterManager(),
        train=True,
        training_seed=99,
    )
    return controller
@patch("numpy.random.seed")
@patch.object(torch, "manual_seed")
def test_initialization_seed(torch_manual_seed, numpy_random_seed):
    """Constructing a TrainerController seeds both NumPy and torch.

    ``mock.patch`` decorators hand their mocks to the test bottom-up. The
    decorator closest to the function, ``patch.object(torch, "manual_seed")``,
    supplies the FIRST argument, and ``patch("numpy.random.seed")`` supplies
    the second. The parameter names follow that order, so each assertion
    below checks the mock it names.
    """
    seed = 27
    trainer_factory_mock = MagicMock()
    trainer_factory_mock.ghost_controller = GhostController()
    TrainerController(
        trainer_factory=trainer_factory_mock,
        output_path="",
        run_id="1",
        param_manager=None,
        train=True,
        training_seed=seed,
    )
    numpy_random_seed.assert_called_with(seed)
    torch_manual_seed.assert_called_with(seed)
@pytest.fixture
def trainer_controller_with_start_learning_mocks(basic_trainer_controller):
    """Return ``(controller, trainer_mock)`` prepared for start_learning tests.

    The controller holds one mocked trainer under "testbrain". Its step
    counter starts at 0, its max steps is 5, and it begins with
    ``should_still_train`` True. Both ``advance`` and ``_save_models`` on the
    controller are replaced by mocks.

    Each ``advance`` call bumps the trainer's step counter, clears
    ``should_still_train`` once the step exceeds the max, and raises
    KeyboardInterrupt once the step exceeds 10. The interrupt bounds any run
    that keeps going past the max.
    """
    trainer_mock = MagicMock()
    trainer_mock.get_step = 0
    trainer_mock.get_max_steps = 5
    trainer_mock.should_still_train = True
    trainer_mock.parameters = {"some": "parameter"}
    trainer_mock.write_tensorboard_text = MagicMock()

    controller = basic_trainer_controller
    controller.trainers = {"testbrain": trainer_mock}
    controller.advance = MagicMock()

    def _advance_side_effect(env):
        # Look the trainer up at call time, as the controller sees it.
        trainer = controller.trainers["testbrain"]
        trainer.get_step += 1
        if trainer.get_step > trainer.get_max_steps:
            trainer.should_still_train = False
        if trainer.get_step > 10:
            raise KeyboardInterrupt
        return 1

    controller.advance.side_effect = _advance_side_effect
    controller._save_models = MagicMock()
    return controller, trainer_mock
def test_start_learning_trains_forever_if_no_train_model(
    trainer_controller_with_start_learning_mocks,
):
    """With ``train_model`` off, training continues until interrupted and never saves.

    The mocked ``advance`` raises KeyboardInterrupt on its 11th call. That is
    well past the trainer's max of 5 steps. The expected call count of 11
    therefore shows that start_learning kept stepping past max steps, and
    ``_save_models`` must not have been called.
    """
    controller, _ = trainer_controller_with_start_learning_mocks
    controller.train_model = False

    env_mock = MagicMock()
    env_mock.close = MagicMock()
    env_mock.reset = MagicMock()
    env_mock.training_behaviors = MagicMock()

    controller.start_learning(env_mock)

    env_mock.reset.assert_called_once()
    assert controller.advance.call_count == 11
    controller._save_models.assert_not_called()
def test_start_learning_trains_until_max_steps_then_saves(
    trainer_controller_with_start_learning_mocks,
):
    """In training mode, advance runs until one step past max steps, then saves once.

    The environment is reset exactly once. ``advance`` is called
    ``get_max_steps + 1`` times, i.e. until the mocked trainer reports it
    should stop training. ``_save_models`` is then called exactly once.
    """
    controller, trainer_mock = trainer_controller_with_start_learning_mocks
    reset_result = MagicMock()

    env_mock = MagicMock()
    env_mock.close = MagicMock()
    env_mock.reset = MagicMock(return_value=reset_result)
    env_mock.training_behaviors = MagicMock()

    controller.start_learning(env_mock)

    env_mock.reset.assert_called_once()
    expected_advances = trainer_mock.get_max_steps + 1
    assert controller.advance.call_count == expected_advances
    controller._save_models.assert_called_once()
@pytest.fixture
def trainer_controller_with_take_step_mocks(basic_trainer_controller):
    """Return ``(controller, trainer_mock)`` for calling the real ``advance``.

    ``advance`` is not mocked here. The controller gets one mocked trainer
    and one mocked manager, both registered under "testbrain".
    """
    trainer_mock = MagicMock()
    trainer_mock.get_step = 0
    trainer_mock.get_max_steps = 5
    trainer_mock.parameters = {"some": "parameter"}
    trainer_mock.write_tensorboard_text = MagicMock()

    controller = basic_trainer_controller
    controller.trainers = dict(testbrain=trainer_mock)
    controller.managers = dict(testbrain=MagicMock())
    return controller, trainer_mock
def test_advance_adds_experiences_to_trainer_and_trains(
    trainer_controller_with_take_step_mocks,
):
    """A single ``advance`` pulls and processes env steps without resetting the env."""
    controller, _trainer_mock = trainer_controller_with_take_step_mocks
    brain_name = "testbrain"
    env_mock = MagicMock()
    # NOTE(review): assumes brain_name_to_identifier maps names to sets
    # (e.g. a defaultdict(set)); confirm against TrainerController.
    controller.brain_name_to_identifier[brain_name].add(brain_name)

    controller.advance(env_mock)

    env_mock.reset.assert_not_called()
    env_mock.get_steps.assert_called_once()
    env_mock.process_steps.assert_called_once()
    # NOTE: trainer_mock.advance is intentionally not asserted. Trainers may
    # be advanced from a separate thread, so its call count is presumably
    # not deterministic at this point.
|