|
from typing import Dict, Any |
|
from enum import Enum |
|
from collections import defaultdict |
|
import json |
|
import attr |
|
import cattr |
|
|
|
from mlagents.torch_utils import torch |
|
from mlagents_envs.logging_util import get_logger |
|
from mlagents.trainers import __version__ |
|
from mlagents.trainers.exception import TrainerError |
|
|
|
# Module-level logger for warnings emitted during status load/save.
logger = get_logger(__name__)

# Version string recorded in StatusMetaData.stats_format_version when the
# training-status file is written.
STATUS_FORMAT_VERSION = "0.3.0"
|
|
|
|
|
class StatusType(Enum):
    """Keys under which values are stored in the saved training-status dict."""

    LESSON_NUM = "lesson_num"
    # Key for the StatusMetaData (version info) entry in the saved state.
    STATS_METADATA = "metadata"
    CHECKPOINTS = "checkpoints"
    FINAL_CHECKPOINT = "final_checkpoint"
    ELO = "elo"
|
|
|
|
|
@attr.s(auto_attribs=True)
class StatusMetaData:
    """
    Version information written alongside the saved training status, so a
    resumed run can detect that it was saved by different library versions.
    """

    # Version of the status-file format itself.
    stats_format_version: str = STATUS_FORMAT_VERSION
    # ML-Agents version that wrote the status file.
    mlagents_version: str = __version__
    # PyTorch version that wrote the status file.
    torch_version: str = torch.__version__

    def to_dict(self) -> Dict[str, str]:
        """Convert to a plain dict (via cattr) suitable for JSON serialization."""
        return cattr.unstructure(self)

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        """Reconstruct a StatusMetaData from a dict produced by to_dict()."""
        return cattr.structure(import_dict, StatusMetaData)

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Check compatibility with a loaded StatusMetaData and warn the user
        if versions mismatch. This is used for resuming from old checkpoints.
        """
        # Only warns; never raises — a mismatch is allowed but may not resume cleanly.
        if self.mlagents_version != other.mlagents_version:
            logger.warning(
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly."
            )
        if self.torch_version != other.torch_version:
            logger.warning(
                "PyTorch checkpoint was saved with a different version of PyTorch. Model may not resume properly."
            )
|
|
|
|
|
class GlobalTrainingStatus:
    """
    GlobalTrainingStatus class that contains static methods to save global training status and
    load it on a resume. These are values that might be needed for the training resume that
    cannot/should not be captured in a model checkpoint, such as curriculum lesson.
    """

    # Maps category (usually behavior name) -> {StatusType value -> stored value}.
    # defaultdict so set_parameter_state can write to a new category directly.
    saved_state: Dict[str, Dict[str, Any]] = defaultdict(dict)

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state.

        Logs a warning and leaves the state unchanged if the file does not exist.
        :param path: Path to the JSON file containing the state.
        :raises TrainerError: If the file has no metadata entry (written by an
            incompatible ML-Agents version).
        """
        try:
            with open(path) as f:
                loaded_dict = json.load(f)

            # Compare the versions that wrote the file against the current ones
            # and warn on any mismatch before adopting the loaded state.
            _metadata = loaded_dict[StatusType.STATS_METADATA.value]
            StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())

            GlobalTrainingStatus.saved_state.update(loaded_dict)
        except FileNotFoundError:
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
        except KeyError as err:
            # Chain the original KeyError so the missing key shows in tracebacks.
            raise TrainerError(
                "Metadata not found, resuming from an incompatible version of ML-Agents."
            ) from err

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state.
        :param path: Path to the JSON file containing the state.
        """
        # Always refresh the metadata entry with the current versions before writing.
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        with open(path, "w") as f:
            json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        """
        Stores an arbitrary-named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        """
        Loads an arbitrary-named parameter from training_status.json.
        If not found, returns None.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The statistic, e.g. lesson number.
        :return: The stored value, or None if not present.
        """
        return GlobalTrainingStatus.saved_state[category].get(key.value)
|
|