Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from importlib import import_module | |
from diffusers import ( | |
StableDiffusionControlNetImg2ImgPipeline, | |
StableDiffusionControlNetPipeline, | |
StableDiffusionImg2ImgPipeline, | |
StableDiffusionPipeline, | |
) | |
from diffusers.loaders.single_file import ( | |
SINGLE_FILE_OPTIONAL_COMPONENTS, | |
load_single_file_sub_model, | |
) | |
from diffusers.loaders.single_file_utils import fetch_diffusers_config, load_single_file_checkpoint | |
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT | |
from diffusers.pipelines.pipeline_loading_utils import ( | |
ALL_IMPORTABLE_CLASSES, | |
_get_pipeline_class, | |
load_sub_model, | |
) | |
from diffusers.utils import logging | |
from huggingface_hub import snapshot_download | |
from huggingface_hub.utils import validate_hf_hub_args | |
class CustomDiffusionMixin:
    r"""
    Overrides DiffusionPipeline methods.

    Provides trimmed-down ``from_pretrained`` and ``from_single_file`` class
    methods adapted from diffusers v0.30.3. Compared to upstream they accept an
    optional ``progress`` object whose ``tqdm`` attribute wraps the
    component-loading loop. They also pass fixed values for many loading
    options (``device_map=None``, ``max_memory=None``, ``offload_folder=None``,
    ``from_flax=False``, ...).

    List this mixin *before* the diffusers pipeline class in the bases so its
    methods take precedence in the MRO.
    """

    @staticmethod
    def _find_model_variants(cached_folder, config_dict, variant):
        """Map component sub-folders to ``variant`` when a variant checkpoint exists.

        A sub-folder of ``cached_folder`` qualifies when its name is a key of
        ``config_dict`` and it contains at least one file whose second
        dot-separated name part starts with ``variant``. For example,
        ``diffusion_pytorch_model.fp16.safetensors`` qualifies for
        ``variant="fp16"``.

        Args:
            cached_folder: Local directory holding the downloaded pipeline.
            config_dict: Pipeline config; only membership of its keys is checked.
            variant: Variant name such as ``"fp16"``, or ``None``.

        Returns:
            dict mapping folder name -> ``variant``. It is empty when ``variant``
            is ``None`` or no folder has a matching file.
        """
        if variant is None:
            return {}

        def has_variant_file(folder_path):
            for filename in os.listdir(folder_path):
                parts = filename.split(".")
                # Dot-less files (e.g. "README") have no variant part; indexing
                # parts[1] unconditionally would raise IndexError.
                if len(parts) > 1 and parts[1].startswith(variant):
                    return True
            return False

        model_variants = {}
        for folder in os.listdir(cached_folder):
            folder_path = os.path.join(cached_folder, folder)
            if folder in config_dict and os.path.isdir(folder_path) and has_variant_file(folder_path):
                model_variants[folder] = variant
        return model_variants

    # Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/pipelines/pipeline_utils.py#L480
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, progress=None, **kwargs):
        """Instantiate the pipeline from a diffusers-format repo or local folder.

        Args:
            pretrained_model_name_or_path: Repo id or path passed to ``cls.download``.
            progress: Optional object with a ``tqdm`` attribute. It is called as
                ``progress.tqdm(iterable, desc=...)`` around the component-loading
                loop. When ``None``, diffusers' ``logging.tqdm`` is used.
            **kwargs: ``torch_dtype``, ``variant`` and ``token`` are consumed here.
                Keys naming one of the pipeline's expected components (e.g.
                ``scheduler``) replace the component that would otherwise be
                loaded; an explicit ``None`` skips loading it. Keys naming optional
                ``__init__`` arguments go to the constructor. Every kwarg other
                than the three consumed ones is also forwarded to ``cls.download``.
                Whatever remains after component and optional-argument extraction
                is passed to ``extract_init_dict``.

        Returns:
            The instantiated pipeline, with ``_name_or_path`` registered in its config.
        """
        torch_dtype = kwargs.pop("torch_dtype", None)
        variant = kwargs.pop("variant", None)
        token = kwargs.pop("token", None)

        # download the checkpoints and configs
        cached_folder = cls.download(
            pretrained_model_name_or_path,
            variant=variant,
            token=token,
            **kwargs,
        )

        # pop out "_ignore_files" as it is only needed for download
        config_dict = cls.load_config(cached_folder)
        config_dict.pop("_ignore_files", None)

        # Components that should load variant weights, detected by looking for
        # variant checkpoints in their sub-folders.
        model_variants = cls._find_model_variants(cached_folder, config_dict, variant)

        # load the pipeline class
        pipeline_class = _get_pipeline_class(cls, config=config_dict)

        # define expected modules given pipeline signature and define non-None initialized modules (=`init_kwargs`)
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        def load_module(name, value):
            # value is the (library_name, class_name) pair from the config.
            # Skip entries whose library is None and components explicitly overridden with None.
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            return True

        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
        init_kwargs = {
            k: init_dict.pop(k)
            for k in optional_kwargs
            if k in init_dict and k not in pipeline_class._optional_components
        }
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        # load each module in the pipeline
        pipelines = import_module("diffusers.pipelines")
        tqdm = logging.tqdm if progress is None else progress.tqdm
        for name, (library_name, class_name) in tqdm(
            sorted(init_dict.items()),
            desc="Loading pipeline components",
        ):
            if name in passed_class_obj:
                # passed as an argument like "scheduler"
                loaded_sub_model = passed_class_obj[name]
            else:
                loaded_sub_model = load_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    importable_classes=ALL_IMPORTABLE_CLASSES,
                    pipelines=pipelines,
                    is_pipeline_module=hasattr(pipelines, library_name),
                    pipeline_class=pipeline_class,
                    torch_dtype=torch_dtype,
                    provider=None,
                    sess_options=None,
                    device_map=None,
                    max_memory=None,
                    offload_folder=None,
                    offload_state_dict=False,
                    model_variants=model_variants,
                    name=name,
                    from_flax=False,
                    variant=variant,
                    low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT,
                    cached_folder=cached_folder,
                )
            init_kwargs[name] = loaded_sub_model

        # potentially add passed objects if expected (None when nothing was passed)
        for module in set(expected_modules) - set(init_kwargs):
            init_kwargs[module] = passed_class_obj.get(module, None)

        # instantiate the pipeline
        model = pipeline_class(**init_kwargs)
        # save where the model was instantiated from
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        return model

    # Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/loaders/single_file.py#L270
    @classmethod
    def from_single_file(cls, pretrained_model_link_or_path, progress=None, **kwargs):
        """Instantiate the pipeline from a single checkpoint file.

        The checkpoint is loaded with ``load_single_file_checkpoint``. Only the
        config/text files (``*.json``, ``*.txt``, ``*.model``) of the diffusers
        repo named by ``fetch_diffusers_config`` are downloaded. Each component
        is then built from the checkpoint with ``load_single_file_sub_model``.

        Args:
            pretrained_model_link_or_path: Checkpoint location passed to
                ``load_single_file_checkpoint``.
            progress: Optional object with a ``tqdm`` attribute. It is called as
                ``progress.tqdm(iterable, desc=...)`` around the component-loading
                loop. When ``None``, diffusers' ``logging.tqdm`` is used.
            **kwargs: ``token`` and ``torch_dtype`` are consumed here. Keys naming
                one of the pipeline's expected components override that component;
                an explicit ``None`` skips it. Keys naming optional ``__init__``
                arguments go to the constructor. Remaining kwargs are forwarded to
                ``extract_init_dict`` and ``load_single_file_sub_model``.
                Components listed in ``SINGLE_FILE_OPTIONAL_COMPONENTS`` are never
                loaded from the checkpoint; they receive the passed value or ``None``.

        Returns:
            The instantiated pipeline, with ``_name_or_path`` registered in its config.
        """
        token = kwargs.pop("token", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        # load the pipeline class
        pipeline_class = _get_pipeline_class(cls, config=None)

        checkpoint = load_single_file_checkpoint(pretrained_model_link_or_path, token=token)
        config = fetch_diffusers_config(checkpoint)
        default_pretrained_model_config_name = config["pretrained_model_name_or_path"]

        # attempt to download the config files for the pipeline
        cached_model_config_path = snapshot_download(
            default_pretrained_model_config_name,
            token=token,
            allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"],
        )

        # pop out "_ignore_files" as it is only needed for download
        config_dict = pipeline_class.load_config(cached_model_config_path)
        config_dict.pop("_ignore_files", None)

        # define expected modules given pipeline signature and define non-None initialized modules (=`init_kwargs`)
        # The signature is read from pipeline_class, the class that is instantiated below.
        expected_modules, optional_kwargs = pipeline_class._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        def load_module(name, value):
            # value is the (library_name, class_name) pair from the config.
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            # optional components are not built from the single-file checkpoint
            if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
                return False
            return True

        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        # load each module in the pipeline
        pipelines = import_module("diffusers.pipelines")
        tqdm = logging.tqdm if progress is None else progress.tqdm
        for name, (library_name, class_name) in tqdm(
            sorted(init_dict.items()),
            desc="Loading pipeline components",
        ):
            if name in passed_class_obj:
                # passed as an argument like "scheduler"
                loaded_sub_model = passed_class_obj[name]
            else:
                loaded_sub_model = load_single_file_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    name=name,
                    checkpoint=checkpoint,
                    is_pipeline_module=hasattr(pipelines, library_name),
                    cached_model_config_path=cached_model_config_path,
                    pipelines=pipelines,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )
            init_kwargs[name] = loaded_sub_model

        # potentially add passed objects if expected (None when nothing was passed)
        for module in set(expected_modules) - set(init_kwargs):
            init_kwargs[module] = passed_class_obj.get(module, None)

        # instantiate the pipeline
        pipe = pipeline_class(**init_kwargs)
        # save where the model was instantiated from
        pipe.register_to_config(_name_or_path=pretrained_model_link_or_path)
        return pipe
class CustomStableDiffusionPipeline(
    CustomDiffusionMixin,
    StableDiffusionPipeline,
):
    """Stable Diffusion text-to-image pipeline whose ``from_pretrained`` and
    ``from_single_file`` resolve to CustomDiffusionMixin (listed first in the MRO)."""
class CustomStableDiffusionImg2ImgPipeline(
    CustomDiffusionMixin,
    StableDiffusionImg2ImgPipeline,
):
    """Stable Diffusion image-to-image pipeline whose ``from_pretrained`` and
    ``from_single_file`` resolve to CustomDiffusionMixin (listed first in the MRO)."""
class CustomStableDiffusionControlNetPipeline(CustomDiffusionMixin, StableDiffusionControlNetPipeline):
    """Stable Diffusion ControlNet pipeline whose ``from_pretrained`` and
    ``from_single_file`` resolve to CustomDiffusionMixin (listed first in the MRO)."""
class CustomStableDiffusionControlNetImg2ImgPipeline(CustomDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
    """Stable Diffusion ControlNet image-to-image pipeline whose ``from_pretrained``
    and ``from_single_file`` resolve to CustomDiffusionMixin (listed first in the MRO)."""