File size: 4,291 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from unittest.mock import MagicMock, patch
import pytest
from mlagents.torch_utils import torch

from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.ghost.controller import GhostController


@pytest.fixture
def basic_trainer_controller():
    """Build a TrainerController backed by a mocked trainer factory.

    The factory is a MagicMock whose ``ghost_controller`` attribute is a real
    GhostController; the controller is created in training mode with seed 99.
    """
    factory = MagicMock(ghost_controller=GhostController())
    controller = TrainerController(
        trainer_factory=factory,
        output_path="test_model_path",
        run_id="test_run_id",
        param_manager=EnvironmentParameterManager(),
        train=True,
        training_seed=99,
    )
    return controller


@patch("numpy.random.seed")
@patch.object(torch, "manual_seed")
def test_initialization_seed(torch_set_seed, numpy_random_seed):
    """Check that constructing a TrainerController seeds numpy and torch.

    Stacked ``patch`` decorators are applied bottom-up: the decorator closest
    to the function (``torch.manual_seed``) supplies the first positional mock,
    and the outermost one (``numpy.random.seed``) supplies the second. The
    parameter names follow that order so each assertion checks the mock it
    names.
    """
    seed = 27
    trainer_factory_mock = MagicMock()
    trainer_factory_mock.ghost_controller = GhostController()
    # Only construction matters here; the instance itself is not needed.
    TrainerController(
        trainer_factory=trainer_factory_mock,
        output_path="",
        run_id="1",
        param_manager=None,
        train=True,
        training_seed=seed,
    )
    numpy_random_seed.assert_called_with(seed)
    torch_set_seed.assert_called_with(seed)


@pytest.fixture
def trainer_controller_with_start_learning_mocks(basic_trainer_controller):
    """Return ``(controller, trainer_mock)`` prepared for start_learning tests.

    The controller's ``advance`` and ``_save_models`` are replaced by mocks.
    Each ``advance`` call bumps the mocked trainer's ``get_step``; once the
    step exceeds ``get_max_steps`` (5) ``should_still_train`` becomes False,
    and once it exceeds 10 a KeyboardInterrupt is raised.
    """
    brain = "testbrain"

    trainer_mock = MagicMock(
        get_step=0,
        get_max_steps=5,
        should_still_train=True,
        parameters={"some": "parameter"},
        write_tensorboard_text=MagicMock(),
    )

    tc = basic_trainer_controller
    tc.trainers = {brain: trainer_mock}
    tc.trainers[brain].get_step = 0

    def advance_side_effect(env):
        # Look the trainer up on every call, as the controller would see it.
        trainer = tc.trainers[brain]
        trainer.get_step += 1
        if trainer.get_step > trainer.get_max_steps:
            trainer.should_still_train = False
        if trainer.get_step > 10:
            raise KeyboardInterrupt
        return 1

    tc.advance = MagicMock(side_effect=advance_side_effect)
    tc._save_models = MagicMock()
    return tc, trainer_mock


def test_start_learning_trains_forever_if_no_train_model(
    trainer_controller_with_start_learning_mocks,
):
    """With ``train_model`` off, learning runs until the mocked interrupt.

    The fixture's ``advance`` side effect raises KeyboardInterrupt on the
    11th call, so exactly 11 calls are expected, with the environment reset
    once and no model save.
    """
    controller, _trainer = trainer_controller_with_start_learning_mocks
    controller.train_model = False

    env = MagicMock(
        close=MagicMock(),
        reset=MagicMock(),
        training_behaviors=MagicMock(),
    )

    controller.start_learning(env)

    env.reset.assert_called_once()
    assert controller.advance.call_count == 11
    controller._save_models.assert_not_called()


def test_start_learning_trains_until_max_steps_then_saves(
    trainer_controller_with_start_learning_mocks,
):
    """Learning stops one ``advance`` call past max steps and saves once.

    The fixture flips ``should_still_train`` to False on the call that pushes
    the step count above ``get_max_steps``, hence the ``+ 1`` in the expected
    call count.
    """
    controller, trainer = trainer_controller_with_start_learning_mocks

    initial_brain_info = MagicMock()
    env = MagicMock(
        close=MagicMock(),
        reset=MagicMock(return_value=initial_brain_info),
        training_behaviors=MagicMock(),
    )

    controller.start_learning(env)

    env.reset.assert_called_once()
    assert controller.advance.call_count == trainer.get_max_steps + 1
    controller._save_models.assert_called_once()


@pytest.fixture
def trainer_controller_with_take_step_mocks(basic_trainer_controller):
    """Return ``(controller, trainer_mock)`` prepared for ``advance`` tests.

    A mocked trainer and a mocked manager are both registered under the
    "testbrain" key on the controller.
    """
    brain = "testbrain"

    trainer_mock = MagicMock(
        get_step=0,
        get_max_steps=5,
        parameters={"some": "parameter"},
        write_tensorboard_text=MagicMock(),
    )

    controller = basic_trainer_controller
    controller.trainers = {brain: trainer_mock}
    controller.managers = {brain: MagicMock()}

    return controller, trainer_mock


def test_advance_adds_experiences_to_trainer_and_trains(
    trainer_controller_with_take_step_mocks,
):
    """A single ``advance`` fetches and processes steps without resetting."""
    controller, _trainer = trainer_controller_with_take_step_mocks
    env = MagicMock()

    # Register the behavior identifier under its own brain name.
    brain = "testbrain"
    controller.brain_name_to_identifier[brain].add(brain)

    controller.advance(env)

    env.reset.assert_not_called()
    env.get_steps.assert_called_once()
    env.process_steps.assert_called_once()
    # May have been called many times due to thread
    # assert trainer_mock.advance.call_count > 0