|
import pytest |
|
import yaml |
|
from unittest.mock import MagicMock, patch, mock_open |
|
from mlagents.trainers import learn |
|
from mlagents.trainers.trainer_controller import TrainerController |
|
from mlagents.trainers.learn import parse_command_line |
|
from mlagents.trainers.cli_utils import DetectDefault |
|
from mlagents_envs.exception import UnityEnvironmentException |
|
from mlagents.trainers.stats import StatsReporter |
|
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager |
|
import os.path |
|
|
|
|
|
def basic_options(extra_args=None):
    """Parse run options for the fixed trainer-config path ``"basic_path"``.

    :param extra_args: Optional mapping of extra settings. Each entry is
        rendered as one ``key=value`` token and appended after the config
        path. ``None`` (or an empty mapping) adds nothing.
    :return: Whatever ``parse_command_line`` returns for the assembled
        argument list.
    """
    overrides = extra_args if extra_args else {}
    cli_tokens = ["basic_path"]
    cli_tokens.extend(f"{name}={value}" for name, value in overrides.items())
    return parse_command_line(cli_tokens)
|
|
|
|
|
# Minimal trainer configuration, served through the patched ``open`` in
# test_commandline_args and test_env_args.
# ``{}`` is YAML's empty flow mapping; it is indented beneath ``behaviors:``
# so that it parses as that key's value (the tests assert
# ``opt.behaviors == {}``).
MOCK_YAML = """
    behaviors:
        {}
    """
|
|
|
# Trainer configuration used as the ``load_config`` return value in
# test_run_training: no behaviors, plus a checkpoint section naming the run
# to initialize from. ``initialize_from`` must be nested under
# ``checkpoint_settings`` -- that test expects the path
# ``results/notuselessrun`` to be derived from it.
MOCK_INITIALIZE_YAML = """
    behaviors:
        {}
    checkpoint_settings:
        initialize_from: notuselessrun
    """
|
|
|
# Trainer configuration served through the patched ``open`` in
# test_yaml_args. It sets a non-default value for each option that test
# reads back, so the test can tell YAML-supplied values from command-line
# overrides.
# Nesting matters: the ``env_settings`` and ``checkpoint_settings`` children
# are indented under their section keys, while ``debug`` stays at top level
# (the test reads it as ``opt.debug``).
MOCK_PARAMETER_YAML = """
    behaviors:
        {}
    env_settings:
        env_path: "./oldenvfile"
        num_envs: 4
        num_areas: 4
        base_port: 4001
        seed: 9870
    checkpoint_settings:
        run_id: uselessrun
        initialize_from: notuselessrun
    debug: false
    """
|
|
|
|
|
@patch("mlagents.trainers.learn.write_timing_tree")
@patch("mlagents.trainers.learn.write_run_options")
@patch("mlagents.trainers.learn.validate_existing_directories")
@patch("mlagents.trainers.learn.TrainerFactory")
@patch("mlagents.trainers.learn.SubprocessEnvManager")
@patch("mlagents.trainers.learn.create_environment_factory")
@patch("mlagents.trainers.settings.load_config")
def test_run_training(
    load_config,
    create_environment_factory,
    subproc_env_mock,
    trainer_factory_mock,
    handle_dir_mock,
    write_run_options_mock,
    write_timing_tree_mock,
):
    """Check how ``learn.run_training`` wires up its collaborators.

    Every collaborator named in the decorators is replaced by a mock (the
    mock arguments arrive bottom decorator first). The configuration comes
    from ``MOCK_INITIALIZE_YAML`` and the options from ``basic_options()``.
    The test then asserts the exact arguments passed to
    ``TrainerController.__init__``, ``validate_existing_directories``,
    ``write_timing_tree`` and ``write_run_options``.
    """
    mock_env = MagicMock()
    mock_env.external_brain_names = []
    mock_env.academy_name = "TestAcademyName"
    create_environment_factory.return_value = mock_env
    load_config.return_value = yaml.safe_load(MOCK_INITIALIZE_YAML)

    # Replacing ``__new__`` makes every EnvironmentParameterManager(...) call
    # evaluate to this plain string, which can then be matched by equality in
    # the TrainerController.__init__ assertion below.
    mock_param_manager = MagicMock(return_value="mock_param_manager")
    # ``__init__`` must return None, hence the explicit return_value.
    mock_init = MagicMock(return_value=None)

    try:
        with patch.object(EnvironmentParameterManager, "__new__", mock_param_manager):
            with patch.object(TrainerController, "__init__", mock_init):
                with patch.object(TrainerController, "start_learning", MagicMock()):
                    options = basic_options()
                    learn.run_training(0, options, 1)
                    mock_init.assert_called_once_with(
                        trainer_factory_mock.return_value,
                        os.path.join("results", "ppo"),
                        "ppo",
                        "mock_param_manager",
                        True,
                        0,
                    )
                    handle_dir_mock.assert_called_once_with(
                        os.path.join("results", "ppo"),
                        False,
                        False,
                        os.path.join("results", "notuselessrun"),
                    )
                    write_timing_tree_mock.assert_called_once_with(
                        os.path.join("results", "ppo", "run_logs")
                    )
                    write_run_options_mock.assert_called_once_with(
                        os.path.join("results", "ppo"), options
                    )
    finally:
        # ``StatsReporter.writers`` is class-level state; presumably
        # run_training registers writers on it (verify in learn.py). Clear it
        # even when an assertion above fails so nothing leaks into the tests
        # that run afterwards.
        StatsReporter.writers.clear()
|
|
|
|
|
def test_bad_env_path():
    """Expect ``UnityEnvironmentException`` for the env path ``/foo/bar``.

    Both creating the factory and invoking it happen inside the
    ``pytest.raises`` block, so the exception may come from either step.
    """
    # "/foo/bar" is presumably a path with no Unity build behind it.
    factory_kwargs = dict(
        env_path="/foo/bar",
        no_graphics=True,
        seed=-1,
        num_areas=1,
        start_port=8000,
        env_args=None,
        log_folder="results/log_folder",
    )
    with pytest.raises(UnityEnvironmentException):
        make_env = learn.create_environment_factory(**factory_kwargs)
        make_env(worker_id=-1, side_channels=[])
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_YAML)
def test_commandline_args(mock_file):
    """Exercise ``parse_command_line`` against a config that sets no options.

    Three parses are checked: the config path alone (defaults), a fully
    specified command line, and the same command line with ``--resume``
    added.
    """
    # 1) Only the config path: every option should report its default.
    defaults = parse_command_line(["mytrainerpath"])
    default_env = defaults.env_settings
    default_ckpt = defaults.checkpoint_settings
    assert defaults.behaviors == {}
    assert default_env.env_path is None
    assert default_ckpt.resume is False
    assert default_ckpt.inference is False
    assert default_ckpt.run_id == "ppo"
    assert default_ckpt.initialize_from is None
    assert default_env.seed == -1
    assert default_env.base_port == 5005
    assert default_env.num_envs == 1
    assert default_env.num_areas == 1
    assert defaults.engine_settings.no_graphics is False
    assert defaults.debug is False
    assert default_env.env_args is None

    # 2) Every flag given explicitly on the command line.
    cli = [
        "mytrainerpath",
        "--env=./myenvfile",
        "--inference",
        "--run-id=myawesomerun",
        "--seed=7890",
        "--train",
        "--base-port=4004",
        "--initialize-from=testdir",
        "--num-envs=2",
        "--num-areas=2",
        "--no-graphics",
        "--debug",
    ]
    explicit = parse_command_line(cli)
    explicit_env = explicit.env_settings
    explicit_ckpt = explicit.checkpoint_settings
    assert explicit.behaviors == {}
    assert explicit_env.env_path == "./myenvfile"
    assert explicit_ckpt.run_id == "myawesomerun"
    assert explicit_ckpt.initialize_from == "testdir"
    assert explicit_env.seed == 7890
    assert explicit_env.base_port == 4004
    assert explicit_env.num_envs == 2
    assert explicit_env.num_areas == 2
    assert explicit.engine_settings.no_graphics is True
    assert explicit.debug is True
    assert explicit_ckpt.inference is True
    assert explicit_ckpt.resume is False

    # 3) Same flags plus --resume: initialize_from is reported as None even
    #    though --initialize-from=testdir is still on the command line.
    resumed = parse_command_line(cli + ["--resume"])
    assert resumed.checkpoint_settings.initialize_from is None
    assert resumed.checkpoint_settings.resume is True
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_PARAMETER_YAML)
def test_yaml_args(mock_file):
    """Exercise ``parse_command_line`` against a config that sets options.

    First the values from ``MOCK_PARAMETER_YAML`` must come through when only
    the config path is given; then explicit command-line flags must replace
    them.
    """
    # NOTE(review): DetectDefault.non_default_args looks like class-level
    # state shared between parses; it is emptied here presumably so flags
    # recorded by earlier tests cannot mask the YAML values -- confirm.
    DetectDefault.non_default_args.clear()

    # 1) Config path only: values are the ones written in the YAML.
    from_yaml = parse_command_line(["mytrainerpath"])
    yaml_env = from_yaml.env_settings
    yaml_ckpt = from_yaml.checkpoint_settings
    assert from_yaml.behaviors == {}
    assert yaml_env.env_path == "./oldenvfile"
    assert yaml_ckpt.run_id == "uselessrun"
    assert yaml_ckpt.initialize_from == "notuselessrun"
    assert yaml_env.seed == 9870
    assert yaml_env.base_port == 4001
    assert yaml_env.num_envs == 4
    assert yaml_env.num_areas == 4
    assert from_yaml.engine_settings.no_graphics is False
    assert from_yaml.debug is False
    assert yaml_env.env_args is None

    # 2) Explicit flags: each one replaces the corresponding YAML value.
    cli = [
        "mytrainerpath",
        "--env=./myenvfile",
        "--resume",
        "--inference",
        "--run-id=myawesomerun",
        "--seed=7890",
        "--train",
        "--base-port=4004",
        "--num-envs=2",
        "--num-areas=2",
        "--no-graphics",
        "--debug",
        "--results-dir=myresults",
    ]
    overridden = parse_command_line(cli)
    cli_env = overridden.env_settings
    cli_ckpt = overridden.checkpoint_settings
    assert overridden.behaviors == {}
    assert cli_env.env_path == "./myenvfile"
    assert cli_ckpt.run_id == "myawesomerun"
    assert cli_env.seed == 7890
    assert cli_env.base_port == 4004
    assert cli_env.num_envs == 2
    assert cli_env.num_areas == 2
    assert overridden.engine_settings.no_graphics is True
    assert overridden.debug is True
    assert cli_ckpt.inference is True
    assert cli_ckpt.resume is True
    assert cli_ckpt.results_dir == "myresults"
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_YAML)
def test_env_args(mock_file):
    """Tokens after ``--env-args`` must land in ``env_settings.env_args``.

    The trailing tokens mix an option-style ``--foo=bar``, a bare flag and
    plain values; all four are expected back unchanged and in order.
    """
    passthrough = ["--foo=bar", "--blah", "baz", "100"]
    argv = ["mytrainerpath", "--env=./myenvfile", "--env-args"] + passthrough

    parsed = parse_command_line(argv)
    assert parsed.env_settings.env_args == ["--foo=bar", "--blah", "baz", "100"]
|
|