# AnnaMats's picture
# Second Push
from typing import Dict, Any
from enum import Enum
from collections import defaultdict
import json
import attr
import cattr
from mlagents.torch_utils import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers import __version__
from mlagents.trainers.exception import TrainerError
logger = get_logger(__name__)

# Format version of the saved training-status file; used as the default of
# StatusMetaData.stats_format_version. The definition was missing even though
# StatusMetaData references it (NameError at import).
# NOTE(review): "0.3.0" matches upstream ML-Agents — confirm for this project.
STATUS_FORMAT_VERSION = "0.3.0"
class StatusType(Enum):
    """
    Keys used to store entries in GlobalTrainingStatus.saved_state.
    Each member's value is the literal string key written to the JSON file.
    """

    LESSON_NUM = "lesson_num"
    STATS_METADATA = "metadata"  # version metadata entry
    CHECKPOINTS = "checkpoints"
    FINAL_CHECKPOINT = "final_checkpoint"
    ELO = "elo"
@attr.s(auto_attribs=True)
class StatusMetaData:
    """
    Version information stored alongside the saved training status, so that a
    resumed run can detect version mismatches with the run that saved it.
    """

    # Format version of the status file (defaults to STATUS_FORMAT_VERSION).
    stats_format_version: str = STATUS_FORMAT_VERSION
    # ML-Agents package version (defaults to the installed version).
    mlagents_version: str = __version__
    # PyTorch version (defaults to the installed version).
    torch_version: str = torch.__version__

    def to_dict(self) -> Dict[str, str]:
        """
        Convert this metadata into a plain dict via cattr, suitable for JSON.
        :return: Mapping of field name to value.
        """
        return cattr.unstructure(self)

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        """
        Build a StatusMetaData from a dict such as one produced by to_dict().
        :param import_dict: Mapping of field name to version string.
        :return: The structured StatusMetaData instance.
        """
        return cattr.structure(import_dict, StatusMetaData)

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Check compatibility with a loaded StatusMetaData and warn the user
        if versions mismatch. This is used for resuming from old checkpoints.
        Mismatches only log a warning; nothing is raised.
        :param other: The metadata to compare against.
        """
        # This should cover all stats version mismatches as well.
        if self.mlagents_version != other.mlagents_version:
            logger.warning(
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly."
            )
        if self.torch_version != other.torch_version:
            logger.warning(
                "PyTorch checkpoint was saved with a different version of PyTorch. Model may not resume properly."
            )
class GlobalTrainingStatus:
    """
    GlobalTrainingStatus class that contains static methods to save global training status and
    load it on a resume. These are values that might be needed for the training resume that
    cannot/should not be captured in a model checkpoint, such as curriculum lesson.
    """

    # Class-level store shared by all callers: category (usually a behavior
    # name) -> {StatusType value -> stored value}. Accessing a missing
    # category creates an empty dict for it.
    saved_state: Dict[str, Dict[str, Any]] = defaultdict(dict)

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state and merge it into saved_state.
        The stored version metadata is compared with the current versions, and a
        warning is logged on mismatch.
        :param path: Path to the JSON file containing the state.
        :raises TrainerError: If the file contains no metadata entry.
        """
        try:
            with open(path) as f:
                loaded_dict = json.load(f)
        except FileNotFoundError:
            # Not fatal: warn and leave saved_state unchanged.
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
            return
        try:
            _metadata = loaded_dict[StatusType.STATS_METADATA.value]
        except KeyError as err:
            raise TrainerError(
                "Metadata not found, resuming from an incompatible version of ML-Agents."
            ) from err
        # Compare the metadata
        StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())
        # Update saved state.
        GlobalTrainingStatus.saved_state.update(loaded_dict)

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state, stamped with the current
        version metadata.
        :param path: Path to the JSON file containing the state.
        """
        # Record current versions under the key load_state reads back.
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        with open(path, "w") as f:
            json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        """
        Stores an arbitrary-named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        """
        Loads an arbitrary-named parameter from the global saved state
        (as read from training_status.json). If not found, returns None.
        Note: an unknown category is created as an empty dict as a side effect.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The statistic, e.g. lesson number.
        :return: The stored value, or None if absent.
        """
        return GlobalTrainingStatus.saved_state[category].get(key.value, None)