import os import shutil from mlagents.torch_utils import torch from typing import Dict, Union, Optional, cast, Tuple, List from mlagents_envs.exception import UnityPolicyException from mlagents_envs.logging_util import get_logger from mlagents.trainers.model_saver.model_saver import BaseModelSaver from mlagents.trainers.settings import TrainerSettings, SerializationSettings from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.torch_entities.model_serialization import ModelSerializer logger = get_logger(__name__) DEFAULT_CHECKPOINT_NAME = "checkpoint.pt" class TorchModelSaver(BaseModelSaver): """ ModelSaver class for PyTorch """ def __init__( self, trainer_settings: TrainerSettings, model_path: str, load: bool = False ): super().__init__() self.model_path = model_path self.initialize_path = trainer_settings.init_path self._keep_checkpoints = trainer_settings.keep_checkpoints self.load = load self.policy: Optional[TorchPolicy] = None self.exporter: Optional[ModelSerializer] = None self.modules: Dict[str, torch.nn.Modules] = {} def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None: if isinstance(module, TorchPolicy) or isinstance(module, TorchOptimizer): self.modules.update(module.get_modules()) # type: ignore else: raise UnityPolicyException( "Registering Object of unsupported type {} to ModelSaver ".format( type(module) ) ) if self.policy is None and isinstance(module, TorchPolicy): self.policy = module self.exporter = ModelSerializer(self.policy) def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]: if not os.path.exists(self.model_path): os.makedirs(self.model_path) checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}") state_dict = { name: module.state_dict() for name, module in self.modules.items() } pytorch_ckpt_path = f"{checkpoint_path}.pt" export_ckpt_path = f"{checkpoint_path}.onnx" torch.save(state_dict, f"{checkpoint_path}.pt") torch.save(state_dict, os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME)) self.export(checkpoint_path, behavior_name) return export_ckpt_path, [pytorch_ckpt_path] def export(self, output_filepath: str, behavior_name: str) -> None: if self.exporter is not None: self.exporter.export_policy_model(output_filepath) def initialize_or_load(self, policy: Optional[TorchPolicy] = None) -> None: # Initialize/Load registered self.policy by default. # If given input argument policy, use the input policy instead. # This argument is mainly for initialization of the ghost trainer's fixed policy. reset_steps = not self.load if self.initialize_path is not None: logger.info(f"Initializing from {self.initialize_path}.") self._load_model( self.initialize_path, policy, reset_global_steps=reset_steps ) elif self.load: logger.info(f"Resuming from {self.model_path}.") self._load_model( os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME), policy, reset_global_steps=reset_steps, ) def _load_model( self, load_path: str, policy: Optional[TorchPolicy] = None, reset_global_steps: bool = False, ) -> None: saved_state_dict = torch.load(load_path) if policy is None: modules = self.modules policy = self.policy else: modules = policy.get_modules() policy = cast(TorchPolicy, policy) for name, mod in modules.items(): try: if isinstance(mod, torch.nn.Module): missing_keys, unexpected_keys = mod.load_state_dict( saved_state_dict[name], strict=False ) if missing_keys: logger.warning( f"Did not find these keys {missing_keys} in checkpoint. Initializing." ) if unexpected_keys: logger.warning( f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring." ) else: # If module is not an nn.Module, try to load as one piece mod.load_state_dict(saved_state_dict[name]) # KeyError is raised if the module was not present in the last run but is being # accessed in the saved_state_dict. # ValueError is raised by the optimizer's load_state_dict if the parameters have # have changed. Note, the optimizer uses a completely different load_state_dict # function because it is not an nn.Module. # RuntimeError is raised by PyTorch if there is a size mismatch between modules # of the same name. This will still partially assign values to those layers that # have not changed shape. except (KeyError, ValueError, RuntimeError) as err: logger.warning(f"Failed to load for module {name}. Initializing") logger.debug(f"Module loading error : {err}") if reset_global_steps: policy.set_step(0) logger.info( "Starting training from step 0 and saving to {}.".format( self.model_path ) ) else: logger.info(f"Resuming training from step {policy.get_current_step()}.") def copy_final_model(self, source_nn_path: str) -> None: """ Copy the .nn file at the given source to the destination. Also copies the corresponding .onnx file if it exists. """ final_model_name = os.path.splitext(source_nn_path)[0] if SerializationSettings.convert_to_onnx: try: source_path = f"{final_model_name}.onnx" destination_path = f"{self.model_path}.onnx" shutil.copyfile(source_path, destination_path) logger.info(f"Copied {source_path} to {destination_path}.") except OSError: pass