AnnaMats's picture
Second Push
05c9ac2
# # Unity ML-Agents Toolkit
from typing import Dict, Any, Optional, List
import os
import attr
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger
# Module-level logger; used below for the debug messages emitted while
# removing checkpoint files.
logger = get_logger(__name__)
@attr.attrs(auto_attribs=True)
class ModelCheckpoint:
    """
    Record describing a single saved model checkpoint.

    Instances are converted with ``attr.asdict`` before being stored, so the
    attribute names below are also the keys of the stored checkpoint dicts.
    """

    # Step count associated with the checkpoint
    # (presumably the training step at save time -- confirm with the caller).
    steps: int
    # Path of the main checkpoint file; deleted when the checkpoint is removed.
    file_path: str
    # Reward associated with the checkpoint; may be None.
    reward: Optional[float]
    # Creation timestamp (units are not visible here -- presumably epoch seconds).
    creation_time: float
    # Extra files belonging to the checkpoint, deleted together with file_path.
    # The (misspelled) name is kept because it is used as a dict key elsewhere.
    auxillary_file_paths: List[str] = attr.ib(default=attr.Factory(list))
class ModelCheckpointManager:
    """
    Bookkeeping helpers for the checkpoints recorded per behavior.

    Checkpoint records are plain dicts (the ``attr.asdict`` form of
    ``ModelCheckpoint``) kept in ``GlobalTrainingStatus`` under
    ``StatusType.CHECKPOINTS``; the final model record is kept under
    ``StatusType.FINAL_CHECKPOINT``.
    """

    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Returns the checkpoint list recorded for a behavior.
        If nothing (or an empty value) is recorded yet, a fresh empty list is
        registered in GlobalTrainingStatus and returned.
        :param behavior_name: Behavior name whose checkpoints are requested.
        :return: The list of checkpoint dicts stored for the behavior.
        """
        checkpoint_list = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if not checkpoint_list:
            checkpoint_list = []
            GlobalTrainingStatus.set_parameter_state(
                behavior_name, StatusType.CHECKPOINTS, checkpoint_list
            )
        return checkpoint_list

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Deletes the files that belong to a checkpoint record from disk.
        Files that cannot be found are skipped with a debug message; any
        other OS error raised by the removal propagates to the caller.
        :param checkpoint: A checkpoint stored in checkpoint_list
        """
        file_paths: List[str] = [checkpoint["file_path"]]
        # A record whose auxiliary entry is missing or None simply has no
        # extra files to delete.
        file_paths.extend(checkpoint.get("auxillary_file_paths") or [])
        for file_path in file_paths:
            if not os.path.exists(file_path):
                logger.debug(f"Checkpoint at {file_path} could not be found.")
                continue
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # The file vanished between the existence check and the
                # removal; treat it like any other missing file.
                logger.debug(f"Checkpoint at {file_path} could not be found.")
            else:
                logger.debug(f"Removed checkpoint model {file_path}.")

    @classmethod
    def _cleanup_extra_checkpoints(
        cls, checkpoints: List[Dict], keep_checkpoints: int
    ) -> List[Dict]:
        """
        Ensures that the number of checkpoints stored are within the number
        of checkpoints the user defines. While the list is over the limit,
        the record at the front of the list is dropped and its files deleted.
        A limit of zero or less disables the cleanup entirely.
        :param checkpoints: The checkpoint records to manage (modified in place).
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        :return: The same list object that was passed in, after trimming.
        """
        if keep_checkpoints <= 0:
            return checkpoints
        while len(checkpoints) > keep_checkpoints:
            # Records are appended as they are added, so index 0 is the
            # earliest one still tracked.
            cls.remove_checkpoint(checkpoints.pop(0))
        return checkpoints

    @classmethod
    def add_checkpoint(
        cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
    ) -> None:
        """
        Records the new checkpoint information, then removes the earliest
        checkpoints if the list exceeds the allowed number.
        :param behavior_name: Behavior name for the checkpoint.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        """
        new_checkpoint_dict = attr.asdict(new_checkpoint)
        checkpoints = cls.get_checkpoints(behavior_name)
        checkpoints.append(new_checkpoint_dict)
        # Trims the list in place; the returned list is the same object.
        cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, checkpoints
        )

    @classmethod
    def track_final_checkpoint(
        cls, behavior_name: str, final_checkpoint: ModelCheckpoint
    ) -> None:
        """
        Stores the information about the final model (or intermediate model if
        training is interrupted). The list of regular checkpoints is not
        touched by this call.
        :param behavior_name: Behavior name of the model.
        :param final_checkpoint: Checkpoint information for the final model.
        """
        final_model_dict = attr.asdict(final_checkpoint)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict
        )