File size: 6,620 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import shutil
from mlagents.torch_utils import torch
from typing import Dict, Union, Optional, cast, Tuple, List
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.model_saver.model_saver import BaseModelSaver
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.torch_entities.model_serialization import ModelSerializer
# Module-level logger, obtained from get_logger with this module's name.
logger = get_logger(__name__)
# File name of the "latest" checkpoint inside the model directory: it is
# rewritten on every TorchModelSaver.save_checkpoint() call and is the file
# read back by TorchModelSaver.initialize_or_load() when resuming.
DEFAULT_CHECKPOINT_NAME = "checkpoint.pt"
class TorchModelSaver(BaseModelSaver):
    """
    ModelSaver class for PyTorch.

    Collects the modules exposed by registered policies/optimizers, writes
    their state dicts to ``.pt`` checkpoints, hands the first registered policy
    to a ``ModelSerializer`` for export, and restores module state from a
    previously written checkpoint.
    """

    def __init__(
        self, trainer_settings: TrainerSettings, model_path: str, load: bool = False
    ):
        """
        :param trainer_settings: Settings object; ``init_path`` and
            ``keep_checkpoints`` are read from it.
        :param model_path: Directory checkpoints are written to and, when
            ``load`` is set, resumed from.
        :param load: Whether to resume from the default checkpoint found in
            ``model_path`` (only used when no ``init_path`` is configured).
        """
        super().__init__()
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load

        # The first registered TorchPolicy and the serializer built around it.
        self.policy: Optional[TorchPolicy] = None
        self.exporter: Optional[ModelSerializer] = None

        # Name -> stateful object. Values are mostly nn.Modules, but may also be
        # other objects exposing state_dict()/load_state_dict() (e.g. optimizers);
        # _load_model handles both cases.
        self.modules: Dict[str, torch.nn.Module] = {}

    def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None:
        """
        Add the modules of a policy or optimizer to the set that is saved/loaded.

        The first TorchPolicy registered also becomes ``self.policy`` and is
        wrapped in a ModelSerializer used by :meth:`export`.

        :param module: The policy or optimizer whose ``get_modules()`` result
            should be tracked.
        :raises UnityPolicyException: If ``module`` is neither a TorchPolicy nor
            a TorchOptimizer.
        """
        if not isinstance(module, (TorchPolicy, TorchOptimizer)):
            raise UnityPolicyException(
                f"Registering Object of unsupported type {type(module)} to ModelSaver "
            )
        self.modules.update(module.get_modules())  # type: ignore

        if self.policy is None and isinstance(module, TorchPolicy):
            self.policy = module
            self.exporter = ModelSerializer(self.policy)

    def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
        """
        Write the state of all registered modules to disk and export the policy.

        Two identical ``.pt`` files are written into ``model_path``: a
        step-specific ``<behavior_name>-<step>.pt`` and the rolling
        ``DEFAULT_CHECKPOINT_NAME`` file that is used for resuming.

        :param behavior_name: Behavior name used as the checkpoint file prefix.
        :param step: Training step used as the checkpoint file suffix.
        :return: Tuple of the ``.onnx`` export path and a list holding the
            step-specific ``.pt`` checkpoint path.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.model_path, exist_ok=True)

        checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
        pytorch_ckpt_path = f"{checkpoint_path}.pt"
        export_ckpt_path = f"{checkpoint_path}.onnx"

        state_dict = {
            name: module.state_dict() for name, module in self.modules.items()
        }
        torch.save(state_dict, pytorch_ckpt_path)
        torch.save(state_dict, os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME))

        # The exporter receives the extension-less path; export_ckpt_path is the
        # location reported to the caller for the exported model.
        self.export(checkpoint_path, behavior_name)
        return export_ckpt_path, [pytorch_ckpt_path]

    def export(self, output_filepath: str, behavior_name: str) -> None:
        """
        Export the registered policy via the ModelSerializer, if one exists.

        Does nothing when no TorchPolicy has been registered yet.

        :param output_filepath: Path (without extension) passed to the exporter.
        :param behavior_name: Unused here; kept for interface compatibility.
        """
        if self.exporter is not None:
            self.exporter.export_policy_model(output_filepath)

    def initialize_or_load(self, policy: Optional[TorchPolicy] = None) -> None:
        """
        Initialize from ``init_path`` or resume from ``model_path``.

        ``init_path`` takes precedence over resuming. If neither applies, nothing
        is loaded. The step counter is reset to 0 unless ``self.load`` is set.

        :param policy: Policy to load into. Defaults to the registered
            ``self.policy``; passing one explicitly is mainly for initializing
            the ghost trainer's fixed policy.
        """
        reset_steps = not self.load
        if self.initialize_path is not None:
            logger.info(f"Initializing from {self.initialize_path}.")
            self._load_model(
                self.initialize_path, policy, reset_global_steps=reset_steps
            )
        elif self.load:
            logger.info(f"Resuming from {self.model_path}.")
            self._load_model(
                os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME),
                policy,
                reset_global_steps=reset_steps,
            )

    def _load_model(
        self,
        load_path: str,
        policy: Optional[TorchPolicy] = None,
        reset_global_steps: bool = False,
    ) -> None:
        """
        Load a saved state dict into the tracked (or the given policy's) modules.

        Modules that cannot be loaded (missing from the checkpoint, changed
        parameters, size mismatch) are left at their current values and a
        warning is logged; loading continues with the remaining modules.

        :param load_path: Path of the ``.pt`` checkpoint to read.
        :param policy: Policy whose modules should be loaded. When omitted, all
            registered modules and the registered policy are used.
        :param reset_global_steps: If True, set the policy's step to 0 after
            loading instead of keeping the step stored in the checkpoint.
        :raises UnityPolicyException: If no policy was given and none has been
            registered.
        """
        if policy is None:
            modules = self.modules
            policy = self.policy
        else:
            modules = policy.get_modules()

        # Fail early with a clear message instead of an AttributeError on None
        # after the modules have already been partially overwritten.
        if policy is None:
            raise UnityPolicyException(
                "Cannot load a model before a TorchPolicy has been registered "
                "with the ModelSaver."
            )

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints that come from a trusted source.
        saved_state_dict = torch.load(load_path)

        for name, mod in modules.items():
            try:
                if isinstance(mod, torch.nn.Module):
                    # strict=False tolerates added/removed keys; report them.
                    missing_keys, unexpected_keys = mod.load_state_dict(
                        saved_state_dict[name], strict=False
                    )
                    if missing_keys:
                        logger.warning(
                            f"Did not find these keys {missing_keys} in checkpoint. Initializing."
                        )
                    if unexpected_keys:
                        logger.warning(
                            f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
                        )
                else:
                    # If module is not an nn.Module, try to load as one piece
                    mod.load_state_dict(saved_state_dict[name])

            # KeyError is raised if the module was not present in the last run but
            # is being accessed in the saved_state_dict.
            # ValueError is raised by the optimizer's load_state_dict if the
            # parameters have changed. Note, the optimizer uses a completely
            # different load_state_dict function because it is not an nn.Module.
            # RuntimeError is raised by PyTorch if there is a size mismatch between
            # modules of the same name. This will still partially assign values to
            # those layers that have not changed shape.
            except (KeyError, ValueError, RuntimeError) as err:
                logger.warning(f"Failed to load for module {name}. Initializing")
                logger.debug(f"Module loading error : {err}")

        if reset_global_steps:
            policy.set_step(0)
            logger.info(
                f"Starting training from step 0 and saving to {self.model_path}."
            )
        else:
            logger.info(f"Resuming training from step {policy.get_current_step()}.")

    def copy_final_model(self, source_nn_path: str) -> None:
        """
        Copy the exported ``.onnx`` model to ``<model_path>.onnx``.

        The source file is ``source_nn_path`` with its extension replaced by
        ``.onnx``. Nothing is copied unless
        ``SerializationSettings.convert_to_onnx`` is set. The copy is
        best-effort: a failure is logged and not raised.

        :param source_nn_path: Path of the exported model; only its
            extension-less stem is used.
        """
        if not SerializationSettings.convert_to_onnx:
            return

        final_model_name = os.path.splitext(source_nn_path)[0]
        source_path = f"{final_model_name}.onnx"
        destination_path = f"{self.model_path}.onnx"
        try:
            shutil.copyfile(source_path, destination_path)
        except OSError as err:
            # Keep this non-fatal, but do not hide the failure from the user.
            logger.warning(
                f"Could not copy {source_path} to {destination_path}: {err}"
            )
        else:
            logger.info(f"Copied {source_path} to {destination_path}.")
|