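"""Command-line entry point for launching an ML-Agents training run.

This module parses run options, creates the Unity environment workers, and
hands control to the TrainerController. It is normally invoked through the
``mlagents-learn`` console script rather than imported directly.
"""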
from mlagents import torch_utils
import yaml

import os
import numpy as np
import json

from typing import Callable, Optional, List

import mlagents.trainers
import mlagents_envs
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.trainer import TrainerFactory
from mlagents.trainers.directory_utils import (
    validate_existing_directories,
    setup_init_path,
)
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.settings import RunOptions

from mlagents.trainers.training_status import GlobalTrainingStatus
from mlagents_envs.base_env import BaseEnv
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.timers import (
    hierarchical_timer,
    get_timer_tree,
    add_metadata as add_timer_metadata,
)
from mlagents_envs import logging_util
from mlagents.plugins.stats_writer import register_stats_writer_plugins
from mlagents.plugins.trainer_type import register_trainer_plugins

logger = logging_util.get_logger(__name__)

TRAINING_STATUS_FILE_NAME = "training_status.json"


def get_version_string() -> str:
    return f""" Version information:
  ml-agents: {mlagents.trainers.__version__},
  ml-agents-envs: {mlagents_envs.__version__},
  Communicator API: {UnityEnvironment.API_VERSION},
  PyTorch: {torch_utils.torch.__version__}"""


def parse_command_line(
    argv: Optional[List[str]] = None,
) -> RunOptions:
    """Register any trainer plugins, then parse CLI arguments into a RunOptions object."""
    _, _ = register_trainer_plugins()
    args = parser.parse_args(argv)
    return RunOptions.from_argparse(args)
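# Example (illustrative paths): parse_command_line(["config/ppo/3DBall.yaml", "--run-id=ball"])
# returns a RunOptions whose behaviors section is loaded from the YAML trainer config.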


def run_training(run_seed: int, options: RunOptions, num_areas: int) -> None:
    """
    Launches the training session.
    :param run_seed: Random seed used for training.
    :param options: Parsed command line arguments.
    :param num_areas: Number of training areas to instantiate.
    """
    with hierarchical_timer("run_training.setup"):
        torch_utils.set_torch_config(options.torch_settings)
        checkpoint_settings = options.checkpoint_settings
        env_settings = options.env_settings
        engine_settings = options.engine_settings

        run_logs_dir = checkpoint_settings.run_logs_dir
        port: Optional[int] = env_settings.base_port

        # Check whether we have an existing run directory.
        validate_existing_directories(
            checkpoint_settings.write_path,
            checkpoint_settings.resume,
            checkpoint_settings.force,
            checkpoint_settings.maybe_init_path,
        )

        os.makedirs(run_logs_dir, exist_ok=True)

        # Load any saved state when resuming a run.
        if checkpoint_settings.resume:
            GlobalTrainingStatus.load_state(
                os.path.join(run_logs_dir, TRAINING_STATUS_FILE_NAME)
            )
        # When initializing from another run, resolve the full init_path for all behaviors.
        elif checkpoint_settings.maybe_init_path is not None:
            setup_init_path(options.behaviors, checkpoint_settings.maybe_init_path)

        # Register stats writers (e.g. Tensorboard) with the StatsReporter.
        stats_writers = register_stats_writer_plugins(options)
        for sw in stats_writers:
            StatsReporter.add_writer(sw)

        if env_settings.env_path is None:
            port = None
        env_factory = create_environment_factory(
            env_settings.env_path,
            engine_settings.no_graphics,
            run_seed,
            num_areas,
            port,
            env_settings.env_args,
            os.path.abspath(run_logs_dir),
        )

        env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs)
        env_parameter_manager = EnvironmentParameterManager(
            options.environment_parameters, run_seed, restore=checkpoint_settings.resume
        )

        trainer_factory = TrainerFactory(
            trainer_config=options.behaviors,
            output_path=checkpoint_settings.write_path,
            train_model=not checkpoint_settings.inference,
            load_model=checkpoint_settings.resume,
            seed=run_seed,
            param_manager=env_parameter_manager,
            init_path=checkpoint_settings.maybe_init_path,
            multi_gpu=False,
        )

        # Create the controller that drives training.
        tc = TrainerController(
            trainer_factory,
            checkpoint_settings.write_path,
            checkpoint_settings.run_id,
            env_parameter_manager,
            not checkpoint_settings.inference,
            run_seed,
        )

    # Begin training.
    try:
        tc.start_learning(env_manager)
    finally:
        env_manager.close()
        write_run_options(checkpoint_settings.write_path, options)
        write_timing_tree(run_logs_dir)
        write_training_status(run_logs_dir)


def write_run_options(output_dir: str, run_options: RunOptions) -> None:
    """Save the resolved run configuration to configuration.yaml in the output directory."""
    run_options_path = os.path.join(output_dir, "configuration.yaml")
    try:
        with open(run_options_path, "w") as f:
            try:
                yaml.dump(run_options.as_dict(), f, sort_keys=False)
            except TypeError:
                # Older versions of PyYAML do not support the sort_keys argument.
                yaml.dump(run_options.as_dict(), f)
    except FileNotFoundError:
        logger.warning(
            f"Unable to save configuration to {run_options_path}. Make sure the directory exists"
        )


def write_training_status(output_dir: str) -> None:
    """Save the global training status to its JSON file in the output directory."""
    GlobalTrainingStatus.save_state(os.path.join(output_dir, TRAINING_STATUS_FILE_NAME))


def write_timing_tree(output_dir: str) -> None:
    """Dump the hierarchical timer tree to timers.json in the output directory."""
    timing_path = os.path.join(output_dir, "timers.json")
    try:
        with open(timing_path, "w") as f:
            json.dump(get_timer_tree(), f, indent=4)
    except FileNotFoundError:
        logger.warning(
            f"Unable to save to {timing_path}. Make sure the directory exists"
        )


def create_environment_factory(
    env_path: Optional[str],
    no_graphics: bool,
    seed: int,
    num_areas: int,
    start_port: Optional[int],
    env_args: Optional[List[str]],
    log_folder: str,
) -> Callable[[int, List[SideChannel]], BaseEnv]:
    """Return a factory that builds a UnityEnvironment for a given worker id."""

    def create_unity_environment(
        worker_id: int, side_channels: List[SideChannel]
    ) -> UnityEnvironment:
        # Make sure that each environment worker gets a different seed.
        env_seed = seed + worker_id
        return UnityEnvironment(
            file_name=env_path,
            worker_id=worker_id,
            seed=env_seed,
            num_areas=num_areas,
            no_graphics=no_graphics,
            base_port=start_port,
            additional_args=env_args,
            side_channels=side_channels,
            log_folder=log_folder,
        )

    return create_unity_environment


def run_cli(options: RunOptions) -> None:
    """Set up logging, print version information, and start training with the given options."""
    # Printing the Unicode banner can fail on consoles with limited encodings;
    # fall back to plain text in that case.
    try:
        print(
            """

                        ┐  ╖
                    ╓╖╬│╡  ││╬╖╖
                ╓╖╬│││││┘  ╬│││││╬╖
             ╖╬│││││╬╜        ╙╬│││││╖╖                               ╗╗╗
             ╬╬╬╬╖││╦╖        ╖╬││╗╣╣╣╬      ╟╣╣╬    ╟╣╣╣             ╜╜╜  ╟╣╣
             ╬╬╬╬╬╬╬╬╖│╬╖╖╓╬╪│╓╣╣╣╣╣╣╣╬      ╟╣╣╬    ╟╣╣╣ ╒╣╣╖╗╣╣╣╗   ╣╣╣ ╣╣╣╣╣╣ ╟╣╣╖   ╣╣╣
             ╬╬╬╬┐  ╙╬╬╬╬│╓╣╣╣╝╜  ╫╣╣╣╬      ╟╣╣╬    ╟╣╣╣ ╟╣╣╣╙ ╙╣╣╣  ╣╣╣ ╙╟╣╣╜╙  ╫╣╣  ╟╣╣
             ╬╬╬╬┐     ╙╬╬╣╣      ╫╣╣╣╬      ╟╣╣╬    ╟╣╣╣ ╟╣╣╬   ╣╣╣  ╣╣╣  ╟╣╣     ╣╣╣┌╣╣╜
             ╬╬╬╜       ╬╬╣╣      ╙╝╣╣╬      ╙╣╣╣╗╖╓╗╣╣╣╜ ╟╣╣╬   ╣╣╣  ╣╣╣  ╟╣╣╦╓    ╣╣╣╣╣
             ╙   ╓╦╖    ╬╬╣╣   ╓╗╗╖            ╙╝╣╣╣╣╝╜   ╘╝╝╜   ╝╝╝  ╝╝╝   ╙╣╣╣    ╟╣╣╣
                ╩╬╬╬╬╬╬╦╦╬╬╣╣╗╣╣╣╣╣╣╣╝                                             ╫╣╣╣╣
                    ╙╬╬╬╬╬╬╬╣╣╣╣╣╣╝╜
                         ╙╬╬╬╣╣╣╜
                            ╙
        """
        )
    except Exception:
        print("\n\n\tUnity Technologies\n")
    print(get_version_string())

    if options.debug:
        log_level = logging_util.DEBUG
    else:
        log_level = logging_util.INFO

    logging_util.set_log_level(log_level)

    logger.debug("Configuration for this run:")
    logger.debug(json.dumps(options.as_dict(), indent=4))

    # Warn about deprecated command line options.
    if options.checkpoint_settings.load_model:
        logger.warning(
            "The --load option has been deprecated. Please use the --resume option instead."
        )
    if options.checkpoint_settings.train_model:
        logger.warning(
            "The --train option has been deprecated. Train mode is now the default. Use "
            "--inference to run in inference mode."
        )

    run_seed = options.env_settings.seed
    num_areas = options.env_settings.num_areas

    # Record version metadata in the timer tree.
    add_timer_metadata("mlagents_version", mlagents.trainers.__version__)
    add_timer_metadata("mlagents_envs_version", mlagents_envs.__version__)
    add_timer_metadata("communication_protocol_version", UnityEnvironment.API_VERSION)
    add_timer_metadata("pytorch_version", torch_utils.torch.__version__)
    add_timer_metadata("numpy_version", np.__version__)

    # Pick a random seed if none was specified.
    if options.env_settings.seed == -1:
        run_seed = np.random.randint(0, 10000)
        logger.debug(f"run_seed set to {run_seed}")
    run_training(run_seed, options, num_areas)


def main():
    """Entry point used by the `mlagents-learn` console script."""
    run_cli(parse_command_line())
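# Typical invocation goes through the console script rather than this module,
# e.g. (config path and run id are illustrative):
#   mlagents-learn config/ppo/3DBall.yaml --run-id=first_3DBall_run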


if __name__ == "__main__":
    main()