|
import sys |
|
from typing import List |
|
|
|
|
|
|
|
if sys.version_info < (3, 8): |
|
import importlib_metadata |
|
else: |
|
import importlib.metadata as importlib_metadata |
|
|
|
from mlagents.trainers.stats import StatsWriter |
|
|
|
from mlagents_envs import logging_util |
|
from mlagents.plugins import ML_AGENTS_STATS_WRITER |
|
from mlagents.trainers.settings import RunOptions |
|
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter |
|
|
|
|
|
logger = logging_util.get_logger(__name__) |
|
|
|
|
|
def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    """
    Build the StatsWriters that mlagents-learn always uses.

    The returned list contains, in this order:
     * a TensorboardWriter that writes information to TensorBoard,
     * a GaugeWriter that records our internal stats,
     * a ConsoleWriter that outputs to stdout.

    :param run_options: Run configuration; only its ``checkpoint_settings``
        (``write_path`` and ``resume``) are read here.
    :return: A new list holding the three default writers.
    """
    settings = run_options.checkpoint_settings

    # Past TensorBoard data is cleared unless the run is being resumed.
    tensorboard_writer = TensorboardWriter(
        settings.write_path,
        clear_past_data=not settings.resume,
        hidden_keys=["Is Training", "Step"],
    )
    gauge_writer = GaugeWriter()
    console_writer = ConsoleWriter()

    return [tensorboard_writer, gauge_writer, console_writer]
|
|
|
|
|
def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
    """
    Registers all StatsWriter plugins (including the default one),
    evaluates them, and returns the list of all the StatsWriter implementations.

    Every entry point registered under the ``ML_AGENTS_STATS_WRITER`` group is
    loaded and called with ``run_options``; the writers it returns are appended
    to the result. A plugin that raises while loading or running is logged and
    skipped, so one broken plugin does not prevent the others from being used.

    :param run_options: Run configuration passed to every plugin function.
    :return: The StatsWriters gathered from all plugins that initialized
        successfully. If no entry points are registered for the group at all,
        a warning is logged and the result of ``get_default_stats_writers`` is
        returned instead.
    """
    # Query the installed distributions once; the lookup is not free.
    discovered = importlib_metadata.entry_points()

    # entry_points() returns a plain dict (group -> entry points) on older
    # Pythons / importlib_metadata releases, and a selectable collection on
    # newer ones (dict-style access was removed in Python 3.12). Support both.
    if hasattr(discovered, "select"):
        entry_points = discovered.select(group=ML_AGENTS_STATS_WRITER)
    else:
        entry_points = discovered.get(ML_AGENTS_STATS_WRITER, ())

    if not entry_points:
        logger.warning(
            f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. "
            "Uninstalling and reinstalling ml-agents via pip should resolve. "
            "Using default plugins for now."
        )
        return get_default_stats_writers(run_options)

    all_stats_writers: List[StatsWriter] = []
    for entry_point in entry_points:
        try:
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
            plugin_func = entry_point.load()
            plugin_stats_writers = plugin_func(run_options)
            logger.debug(
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
            )
            all_stats_writers += plugin_stats_writers
        except Exception:
            # Deliberately best-effort: a faulty plugin is reported and skipped.
            # Exception (not BaseException) so that KeyboardInterrupt and
            # SystemExit still propagate instead of being swallowed here.
            logger.exception(
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
            )
    return all_stats_writers
|
|