import os
from typing import Dict, Optional, Union

import safetensors
import safetensors.torch  # explicit: `import safetensors` alone does not load the torch submodule
import torch
from diffusers.utils import _get_model_file, logging
from safetensors import safe_open

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CustomAdapterMixin:
    """Mixin adding init/load/save entry points for a custom adapter.

    The public methods handle argument plumbing and file I/O, then delegate
    to the ``_init_custom_adapter``, ``_load_custom_adapter`` and
    ``_save_custom_adapter`` hooks, which subclasses must override (the
    versions defined here raise ``NotImplementedError``).
    """

    def init_custom_adapter(self, *args, **kwargs):
        """Forward all arguments to the ``_init_custom_adapter`` hook."""
        self._init_custom_adapter(*args, **kwargs)

    def _init_custom_adapter(self, *args, **kwargs):
        """Hook: set up the adapter. Must be overridden by subclasses."""
        raise NotImplementedError

    def load_custom_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Obtain an adapter state dict and pass it to ``_load_custom_adapter``.

        Args:
            pretrained_model_name_or_path_or_dict: Either an already-loaded
                state dict (used as-is), or a value that is passed to
                ``_get_model_file`` to resolve the weights file.
            weight_name: File name of the weights. A ``.safetensors`` suffix
                selects the safetensors reader; anything else is read with
                ``torch.load``.
            subfolder: Optional subfolder passed to ``_get_model_file``.
            **kwargs: ``cache_dir``, ``force_download``, ``proxies``,
                ``local_files_only``, ``token`` and ``revision`` are popped
                and passed to ``_get_model_file``; other keys are ignored.
        """
        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            state_dict = pretrained_model_name_or_path_or_dict
        else:
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                subfolder=subfolder,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    state_dict = {key: f.get_tensor(key) for key in f.keys()}
            else:
                # SECURITY NOTE: torch.load unpickles the file. Depending on
                # the installed torch version's `weights_only` default this
                # can execute arbitrary code, so only load non-safetensors
                # checkpoints from trusted sources.
                state_dict = torch.load(model_file, map_location="cpu")

        self._load_custom_adapter(state_dict)

    def _load_custom_adapter(self, state_dict):
        """Hook: consume ``state_dict``. Must be overridden by subclasses."""
        raise NotImplementedError

    def save_custom_adapter(
        self,
        save_directory: Union[str, os.PathLike],
        weight_name: str,
        safe_serialization: bool = False,
        **kwargs,
    ):
        """Write the state dict from ``_save_custom_adapter`` to disk.

        Args:
            save_directory: Target directory. It is created if missing. If
                the path is an existing file, an error is logged and nothing
                is written.
            weight_name: File name to write inside ``save_directory``.
            safe_serialization: If true, write with
                ``safetensors.torch.save_file``; otherwise use ``torch.save``.
            **kwargs: Forwarded to ``_save_custom_adapter``.
        """
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return

        # Neither torch.save nor safetensors creates missing parent folders.
        os.makedirs(save_directory, exist_ok=True)

        if safe_serialization:

            def save_function(weights, filename):
                return safetensors.torch.save_file(
                    weights, filename, metadata={"format": "pt"}
                )

        else:
            save_function = torch.save

        # Save the model
        state_dict = self._save_custom_adapter(**kwargs)
        output_path = os.path.join(save_directory, weight_name)
        save_function(state_dict, output_path)
        logger.info(f"Custom adapter weights saved in {output_path}")

    def _save_custom_adapter(self, **kwargs):
        """Hook: return the state dict to save. Must be overridden by subclasses.

        Accepts ``**kwargs`` because ``save_custom_adapter`` forwards its
        extra keyword arguments here.
        """
        raise NotImplementedError