diffusers-sdxl-controlnet
/
src
/diffusers
/pipelines
/controlnet_xs
/pipeline_controlnet_xs_sd_xl.py
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import inspect | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import numpy as np | |
import PIL.Image | |
import torch | |
import torch.nn.functional as F | |
from transformers import ( | |
CLIPImageProcessor, | |
CLIPTextModel, | |
CLIPTextModelWithProjection, | |
CLIPTokenizer, | |
) | |
from diffusers.utils.import_utils import is_invisible_watermark_available | |
from ...callbacks import MultiPipelineCallbacks, PipelineCallback | |
from ...image_processor import PipelineImageInput, VaeImageProcessor | |
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin | |
from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel | |
from ...models.attention_processor import ( | |
AttnProcessor2_0, | |
XFormersAttnProcessor, | |
) | |
from ...models.lora import adjust_lora_scale_text_encoder | |
from ...schedulers import KarrasDiffusionSchedulers | |
from ...utils import ( | |
USE_PEFT_BACKEND, | |
logging, | |
replace_example_docstring, | |
scale_lora_layers, | |
unscale_lora_layers, | |
) | |
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor | |
from ..pipeline_utils import DiffusionPipeline | |
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput | |
if is_invisible_watermark_available(): | |
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker | |
# Module-level logger. `logging` here is diffusers' own utility module (imported from `...utils`
# at the top of the file), not the standard-library `logging`.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage example for the pipeline. The standalone `Examples:` line is presumably the marker that
# `replace_example_docstring` (imported above) substitutes into `__call__`'s docstring.
# The fp16-fix VAE is passed to the pipeline (otherwise it is loaded for nothing), and the defined
# `negative_prompt` is actually used in the call.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> # !pip install opencv-python transformers accelerate
        >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch

        >>> import cv2
        >>> from PIL import Image

        >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
        >>> negative_prompt = "low quality, bad quality, sketches"

        >>> # download an image
        >>> image = load_image(
        ...     "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
        ... )

        >>> # initialize the models and pipeline
        >>> controlnet_conditioning_scale = 0.5
        >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        >>> controlnet = ControlNetXSAdapter.from_pretrained(
        ...     "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
        ... )
        >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
        ... )
        >>> pipe.enable_model_cpu_offload()

        >>> # get canny image
        >>> image = np.array(image)
        >>> image = cv2.Canny(image, 100, 200)
        >>> image = image[:, :, None]
        >>> image = np.concatenate([image, image, image], axis=2)
        >>> canny_image = Image.fromarray(image)

        >>> # generate image
        >>> image = pipe(
        ...     prompt,
        ...     negative_prompt=negative_prompt,
        ...     controlnet_conditioning_scale=controlnet_conditioning_scale,
        ...     image=canny_image,
        ... ).images[0]
        ```
"""
class StableDiffusionXLControlNetXSPipeline( | |
DiffusionPipeline, | |
TextualInversionLoaderMixin, | |
StableDiffusionXLLoraLoaderMixin, | |
FromSingleFileMixin, | |
): | |
r"""
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance.

This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).

The pipeline also inherits the following loading methods:
    - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
    - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
    - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files

Args:
    vae ([`AutoencoderKL`]):
        Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
    text_encoder ([`~transformers.CLIPTextModel`]):
        Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
    text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
        Second frozen text-encoder
        ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
    tokenizer ([`~transformers.CLIPTokenizer`]):
        A `CLIPTokenizer` to tokenize text.
    tokenizer_2 ([`~transformers.CLIPTokenizer`]):
        A `CLIPTokenizer` to tokenize text.
    unet ([`UNet2DConditionModel`] or [`UNetControlNetXSModel`]):
        A [`UNet2DConditionModel`] that is combined with `controlnet` into a [`UNetControlNetXSModel`], or an
        already combined [`UNetControlNetXSModel`] that is used as is. It denoises the encoded image latents.
    controlnet ([`ControlNetXSAdapter`]):
        A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
    scheduler ([`SchedulerMixin`]):
        A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
        [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
        Whether the negative prompt embeddings are set to zeros (instead of encoding an empty string) when no
        negative prompt is given. Also see the config of `stabilityai/stable-diffusion-xl-base-1.0`.
    add_watermarker (`bool`, *optional*):
        Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
        watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
        watermarker is used.
    feature_extractor ([`~transformers.CLIPImageProcessor`], *optional*):
        An optional image processor that is registered as one of the pipeline's modules.
"""
# Order in which models are offloaded by `enable_model_cpu_offload` (presumably consumed by
# `DiffusionPipeline`; the offloading logic is not part of this file section).
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
# Components that may be absent (`None`) when the pipeline is constructed or loaded.
_optional_components = [
    "tokenizer",
    "tokenizer_2",
    "text_encoder",
    "text_encoder_2",
    "feature_extractor",
]
# The only tensor names that `callback_on_step_end_tensor_inputs` may request (enforced in `check_inputs`).
_callback_tensor_inputs = [
    "latents",
    "prompt_embeds",
    "negative_prompt_embeds",
    "add_text_embeds",
    "add_time_ids",
    "negative_pooled_prompt_embeds",
    "negative_add_time_ids",
]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
    controlnet: ControlNetXSAdapter,
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    feature_extractor: Optional[CLIPImageProcessor] = None,
):
    """Register the pipeline components and set up the image processors and the optional watermarker.

    A plain `UNet2DConditionModel` is fused with `controlnet` via `UNetControlNetXSModel.from_unet`; an
    already fused `UNetControlNetXSModel` is registered unchanged.

    Raises:
        ImportError: if `add_watermarker=True` is requested but the `invisible_watermark` package is not
            installed (the watermarker class is only imported when that package is available).
    """
    super().__init__()

    if isinstance(unet, UNet2DConditionModel):
        unet = UNetControlNetXSModel.from_unet(unet, controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # Conditioning images are not normalized the way VAE inputs are (`do_normalize=False`).
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )

    watermark_available = is_invisible_watermark_available()
    add_watermarker = add_watermarker if add_watermarker is not None else watermark_available
    if add_watermarker and not watermark_available:
        # `StableDiffusionXLWatermarker` is only imported when the package is installed; fail with a clear
        # message instead of a `NameError`.
        raise ImportError(
            "`add_watermarker=True` requires the `invisible_watermark` package. Install it with"
            " `pip install invisible-watermark` or pass `add_watermarker=False`."
        )
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device; defaults to `self._execution_device`
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., when
            `do_classifier_free_guidance` is `False`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        `prompt_embeds` is reshaped to `(batch * num_images_per_prompt, seq_len, -1)` and the pooled
        embeddings to `(batch * num_images_per_prompt, -1)`. When `do_classifier_free_guidance` is `False`, the
        two negative entries are returned exactly as passed in (possibly `None`).

    Raises:
        TypeError: if `negative_prompt` and the prompt have different types.
        ValueError: if the negative prompt batch size differs from the prompt batch size.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    # (when `self.tokenizer` / `self.text_encoder` are None, only the second encoder is used, and the
    # `zip` below then consumes only the first entry of `prompts`)
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE(review): the loop variable rebinds `prompt`; after the loop it holds the last value processed
        # here, which is what the type check in the negative-prompt branch below compares against.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn (but continue) when the prompt exceeds the tokenizer's maximum length.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # Both encoders' hidden states are concatenated along the feature (last) dimension.
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad negative prompts to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            # NOTE(review): unlike the positive branch, `clip_skip` is not applied here, so negative embeddings
            # always come from the penultimate hidden state — confirm this asymmetry is intended.
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    # Cast to the second text encoder's dtype when it exists, otherwise to the UNet's dtype.
    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` actually accepts.

    Schedulers do not share one `step` signature: `eta` (the η of the DDIM paper,
    https://arxiv.org/abs/2010.02502, expected in [0, 1]) is only meaningful for DDIM-style schedulers,
    and not every scheduler takes a `generator`. Each argument is forwarded only when `step` declares it.

    Returns:
        `dict`: a subset of `{"eta": eta, "generator": generator}`.
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters)
    step_kwargs = {}
    if "eta" in step_parameters:
        step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        step_kwargs["generator"] = generator
    return step_kwargs
def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the user-facing arguments of a pipeline call before any work is done.

    Args:
        prompt, prompt_2: text prompt(s) (`str` or `list`); each is mutually exclusive with `prompt_embeds`.
        image: the ControlNet-XS conditioning image(s); validated by `self.check_image`.
        negative_prompt, negative_prompt_2: negative prompt(s); mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds, negative_prompt_embeds: pre-computed embeddings; must have equal shapes when both are given
            and must be accompanied by their pooled counterparts.
        pooled_prompt_embeds, negative_pooled_prompt_embeds: pooled embeddings matching the above.
        controlnet_conditioning_scale (`float`): strength of the ControlNet-XS conditioning.
        control_guidance_start, control_guidance_end (`float`): must satisfy `0 <= start < end <= 1`.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*): every name must be in
            `self._callback_tensor_inputs`.

    Raises:
        ValueError: for conflicting, missing, mistyped prompt, or out-of-range arguments.
        TypeError: if `controlnet_conditioning_scale` is not a `float`, or if `self.unet` is not a
            `UNetControlNetXSModel` (optionally wrapped by `torch.compile`).
    """
    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    # Check `image` and `controlnet_conditioning_scale`.
    # `torch.compile` wraps the model; validate the wrapped module in that case.
    unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
    if not isinstance(unet, UNetControlNetXSModel):
        # An explicit raise (instead of `assert False`) keeps this check active under `python -O`.
        raise TypeError(
            f"`unet` must be a `UNetControlNetXSModel` (optionally compiled with `torch.compile`), but is {type(unet)}."
        )
    self.check_image(image, prompt, prompt_embeds)
    if not isinstance(controlnet_conditioning_scale, float):
        raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")

    start, end = control_guidance_start, control_guidance_end
    if start >= end:
        raise ValueError(
            f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
        )
    if start < 0.0:
        raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
    if end > 1.0:
        raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
    """Check that the ControlNet-XS conditioning `image` has a supported type and a compatible batch size.

    Args:
        image: a PIL image, numpy array, torch tensor, or a non-empty list of one of these.
        prompt (`str` or `List[str]`, *optional*): the prompt(s), used to infer the prompt batch size.
        prompt_embeds (`torch.Tensor`, *optional*): used for the prompt batch size when `prompt` is `None`.

    Raises:
        TypeError: if `image` is none of the supported types (an empty list included).
        ValueError: if the image batch size is neither 1 nor equal to the prompt batch size.
    """
    # Guard the `image[0]` lookups so an empty list is reported as a TypeError, not an IndexError.
    is_non_empty_list = isinstance(image, list) and len(image) > 0
    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)
    image_is_pil_list = is_non_empty_list and isinstance(image[0], PIL.Image.Image)
    image_is_tensor_list = is_non_empty_list and isinstance(image[0], torch.Tensor)
    image_is_np_list = is_non_empty_list and isinstance(image[0], np.ndarray)

    if not (
        image_is_pil
        or image_is_tensor
        or image_is_np
        or image_is_pil_list
        or image_is_tensor_list
        or image_is_np_list
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    if image_is_pil:
        image_batch_size = 1
    elif image_is_tensor or image_is_np:
        # Only a 4-D array/tensor has a leading batch dimension. A lower-rank input (e.g. an HWC array or a CHW
        # tensor) is a single image -- the image processor's `preprocess` stacks non-4-D inputs into a batch of
        # one -- so `len()` would wrongly report H or C as the batch size.
        image_batch_size = image.shape[0] if image.ndim == 4 else 1
    else:
        image_batch_size = len(image)

    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
):
    """Preprocess the ControlNet-XS conditioning image and tile it to the generation batch.

    Args:
        image: conditioning image(s) accepted by `self.control_image_processor.preprocess`.
        width, height (`int`): target size handed to the preprocessor.
        batch_size (`int`): repeat count used when the preprocessed image batch has exactly one image.
        num_images_per_prompt (`int`): repeat count per image otherwise.
        device, dtype: placement and dtype of the returned tensor.
        do_classifier_free_guidance (`bool`): when `True`, the batch is duplicated (two stacked copies).

    Returns:
        `torch.Tensor`: the repeated conditioning batch.
    """
    processed = self.control_image_processor.preprocess(image, height=height, width=width)
    processed = processed.to(dtype=torch.float32)

    # A lone image is broadcast over the whole batch; otherwise each prompt's image (the batch sizes are
    # expected to match) is repeated once per generated image.
    repeats = batch_size if processed.shape[0] == 1 else num_images_per_prompt
    processed = processed.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)

    return torch.cat([processed, processed]) if do_classifier_free_guidance else processed
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Create the initial latents (or move user-provided ones) and scale them by the scheduler's noise sigma.

    Args:
        batch_size (`int`): number of latents to draw; must equal `len(generator)` when a list is given.
        num_channels_latents (`int`): channel count of the latent tensor.
        height, width: pixel size, divided by `self.vae_scale_factor` to get the latent size.
        dtype, device: used for freshly sampled latents (provided latents are only moved to `device`).
        generator: a `torch.Generator` or a list of them, forwarded to `randn_tensor`.
        latents (`torch.Tensor`, *optional*): pre-generated latents to use instead of sampling.

    Raises:
        ValueError: if a list of generators does not match `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latents = (
        randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if latents is None
        else latents.to(device)
    )

    # The scheduler defines the standard deviation the initial noise must have.
    return latents * self.scheduler.init_noise_sigma
def _get_add_time_ids( | |
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None | |
): | |
add_time_ids = list(original_size + crops_coords_top_left + target_size) | |
passed_add_embed_dim = ( | |
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim | |
) | |
expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features | |
if expected_add_embed_dim != passed_add_embed_dim: | |
raise ValueError( | |
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." | |
) | |
add_time_ids = torch.tensor([add_time_ids], dtype=dtype) | |
return add_time_ids | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
    """Cast the VAE to float32, keeping selected decoder parts in the original dtype when possible.

    When the decoder's mid-block attention uses `AttnProcessor2_0` or `XFormersAttnProcessor`, the attention
    block does not need float32, so `post_quant_conv`, `decoder.conv_in` and `decoder.mid_block` are cast back
    to the VAE's previous dtype to save memory.
    """
    original_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    mid_block_processor = self.vae.decoder.mid_block.attentions[0].processor
    if not isinstance(mid_block_processor, (AttnProcessor2_0, XFormersAttnProcessor)):
        return

    for submodule in (self.vae.post_quant_conv, self.vae.decoder.conv_in, self.vae.decoder.mid_block):
        submodule.to(original_dtype)
# Read-only accessor for `self._guidance_scale` (assigned outside this section, presumably by `__call__`).
# `@property` gives attribute-style access (`self.guidance_scale`); as a plain method every such read would
# yield a bound method instead of the value.
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
def guidance_scale(self):
    return self._guidance_scale
# Read-only accessor for `self._clip_skip` (assigned outside this section, presumably by `__call__`).
# `@property` gives attribute-style access (`self.clip_skip`) instead of returning a bound method.
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
def clip_skip(self):
    return self._clip_skip
# True when classifier-free guidance applies: guidance scale above 1 and the UNet has no
# `time_cond_proj_dim` (i.e. it is not guidance-embedded). `@property` is essential here: as a plain method,
# `if self.do_classifier_free_guidance:` tests a bound method, which is always truthy.
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
def do_classifier_free_guidance(self):
    return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
# Read-only accessor for `self._cross_attention_kwargs` (assigned outside this section, presumably by
# `__call__`). `@property` gives attribute-style access instead of returning a bound method.
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
def cross_attention_kwargs(self):
    return self._cross_attention_kwargs
# Read-only accessor for `self._num_timesteps` (assigned outside this section, presumably by `__call__`).
# `@property` gives attribute-style access instead of returning a bound method.
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
def num_timesteps(self):
    return self._num_timesteps
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    control_guidance_start: float = 0.0,
    control_guidance_end: float = 1.0,
    original_size: Tuple[int, int] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Tuple[int, int] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
            specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
            as an image. If height and/or width are passed, `image` is resized accordingly. The size of the
            generated image is taken from the prepared control image.
        height (`int`, *optional*):
            The height in pixels of the generated image. The final height is read from the prepared control image
            (`image.shape[-2]`), so if not given it follows the control image. Anything below 512 pixels won't
            work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*):
            The width in pixels of the generated image. The final width is read from the prepared control image
            (`image.shape[-1]`), so if not given it follows the control image. Anything below 512 pixels won't
            work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
            and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, pooled text embeddings are generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
            weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`; `"latent"` skips
            VAE decoding and returns the denoised latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`.
        control_guidance_start (`float`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        original_size (`Tuple[int]`, *optional*):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to the `(height, width)` of the prepared control image if not specified.
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*):
            To negatively condition the generation process based on a specific image resolution. Only used when
            `negative_target_size` is also given; otherwise the positive micro-conditioning is reused. Part of
            SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Only used when
            both `negative_original_size` and `negative_target_size` are given. Part of SDXL's micro-conditioning
            as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*):
            To negatively condition the generation process based on a target image resolution. It should be as same
            as the `target_size` for most cases. Only used when `negative_original_size` is also given. Part of
            SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`. A callback may set
            `pipeline._interrupt = True` to skip the remaining denoising steps.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] is
            returned, otherwise a `tuple` is returned containing the output images.

    Raises:
        TypeError: If the (possibly compiled) `unet` is not a `UNetControlNetXSModel`.
    """
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # Unwrap a `torch.compile`d UNet so the type check and dtype lookup below see the underlying module.
    unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        image,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    # Reset here; a step-end callback may set it to True to stop further denoising (checked in the loop).
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt,
        prompt_2,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=clip_skip,
    )

    # 4. Prepare image
    # An explicit raise (instead of `assert False`) keeps this check active under `python -O`.
    if not isinstance(unet, UNetControlNetXSModel):
        raise TypeError(f"Expected `unet` to be a `UNetControlNetXSModel`, but got {type(unet).__name__}.")
    image = self.prepare_image(
        image=image,
        width=width,
        height=height,
        batch_size=batch_size * num_images_per_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        dtype=unet.dtype,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )
    # The generated image size always follows the prepared control image.
    height, width = image.shape[-2:]

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.unet.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Prepare added time ids & embeddings
    if isinstance(image, list):
        original_size = original_size or image[0].shape[-2:]
    else:
        original_size = original_size or image.shape[-2:]
    target_size = target_size or (height, width)

    add_text_embeds = pooled_prompt_embeds
    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )

    # Negative micro-conditioning is only built when both sizes are given; otherwise reuse the positive one.
    if negative_original_size is not None and negative_target_size is not None:
        negative_add_time_ids = self._get_add_time_ids(
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
    else:
        negative_add_time_ids = add_time_ids

    if do_classifier_free_guidance:
        # Unconditional entries first, matching the `noise_pred.chunk(2)` order in the loop.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    is_controlnet_compiled = is_compiled_module(self.unet)
    is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Honour an interrupt requested by a step-end callback.
            if self._interrupt:
                continue

            # Relevant thread:
            # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
            if is_controlnet_compiled and is_torch_higher_equal_2_1:
                torch._inductor.cudagraph_mark_step_begin()

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

            # predict the noise residual; control is only applied inside [start, end] of the schedule
            apply_control = (
                i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
            )
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=True,
                apply_control=apply_control,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    # NOTE: the former unconditional pre-decode `upcast_vae()` was removed. It ran even for
    # `output_type == "latent"` (changing the returned latents' dtype). It also made `self.vae.dtype` stop
    # reading float16, so `needs_upcasting` below was False and the VAE was never cast back to fp16.
    if not output_type == "latent":
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
    else:
        image = latents

    if not output_type == "latent":
        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)