# diffusers-sdxl-controlnet/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
# Copyright 2024 Salesforce.com, inc.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union | |
import PIL.Image | |
import torch | |
from transformers import CLIPTokenizer | |
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel | |
from ...schedulers import PNDMScheduler | |
from ...utils import ( | |
logging, | |
replace_example_docstring, | |
) | |
from ...utils.torch_utils import randn_tensor | |
from ..blip_diffusion.blip_image_processing import BlipImageProcessor | |
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel | |
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel | |
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput | |
# Module-level logger for this pipeline module, obtained from diffusers' logging utilities (keyed by __name__).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage example for `BlipDiffusionControlNetPipeline`, written in the docstring-example format consumed by
# `replace_example_docstring` (imported at the top of this file). The doctest layout relies on indentation:
# `...` continuation lines are indented under their `>>>` statement so the rendered example stays readable.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
        >>> from diffusers.utils import load_image
        >>> from controlnet_aux import CannyDetector
        >>> import torch

        >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
        ...     "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
        ... ).to("cuda")

        >>> style_subject = "flower"
        >>> tgt_subject = "teapot"
        >>> text_prompt = "on a marble table"

        >>> cldm_cond_image = load_image(
        ...     "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
        ... ).resize((512, 512))
        >>> canny = CannyDetector()
        >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
        >>> style_image = load_image(
        ...     "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
        ... )
        >>> guidance_scale = 7.5
        >>> num_inference_steps = 50
        >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"


        >>> output = blip_diffusion_pipe(
        ...     text_prompt,
        ...     style_image,
        ...     cldm_cond_image,
        ...     style_subject,
        ...     tgt_subject,
        ...     guidance_scale=guidance_scale,
        ...     num_inference_steps=num_inference_steps,
        ...     neg_prompt=negative_prompt,
        ...     height=512,
        ...     width=512,
        ... ).images
        >>> output[0].save("image.png")
        ```
"""
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
    """
    Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer for the text encoder
        text_encoder ([`ContextCLIPTextModel`]):
            Text encoder to encode the text prompt
        vae ([`AutoencoderKL`]):
            VAE model to map the latents to the image
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        scheduler ([`PNDMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        qformer ([`Blip2QFormerModel`]):
            QFormer model to get multi-modal embeddings from the text and image.
        controlnet ([`ControlNetModel`]):
            ControlNet model to get the conditioning image embedding.
        image_processor ([`BlipImageProcessor`]):
            Image Processor to preprocess and postprocess the image.
        ctx_begin_pos (int, `optional`, defaults to 2):
            Position of the context token in the text encoder.
        mean (`List[float]`, *optional*):
            Forwarded to `image_processor.preprocess` as `image_mean` when preprocessing the reference image.
        std (`List[float]`, *optional*):
            Forwarded to `image_processor.preprocess` as `image_std` when preprocessing the reference image.
    """

    # Component order presumably consumed by DiffusionPipeline's model CPU offloading; note that `controlnet`
    # is not listed here.
    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: ContextCLIPTextModel,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: PNDMScheduler,
        qformer: Blip2QFormerModel,
        controlnet: ControlNetModel,
        image_processor: BlipImageProcessor,
        ctx_begin_pos: int = 2,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            qformer=qformer,
            controlnet=controlnet,
            image_processor=image_processor,
        )
        # Non-module settings live in the pipeline config and are read back later via `self.config`.
        self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)

    def get_query_embeddings(self, input_image, src_subject):
        """
        Run the Q-Former on the reference image together with its source-subject text.

        Returns exactly what `self.qformer(..., return_dict=False)` returns; `__call__` passes it on unchanged as the
        text encoder's `ctx_embeddings`.
        """
        return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)

    # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
    def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
        """
        Prefix each prompt with "a {target subject}" and repeat it to amplify its weight.

        Each resulting prompt is repeated `int(prompt_strength * prompt_reps)` times and joined with ", ". Prompts
        and subjects are paired with `zip`, so the output has as many entries as the shorter input list.
        """
        rv = []
        for prompt, tgt_subject in zip(prompts, tgt_subjects):
            prompt = f"a {tgt_subject} {prompt.strip()}"
            # a trick to amplify the prompt
            rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
        return rv

    # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_prompt(self, query_embeds, prompt, device=None):
        """
        Encode `prompt` with the context-aware text encoder, passing `query_embeds` as context.

        The tokenizer's length budget is the text encoder's `max_position_embeddings` minus the Q-Former's
        `num_query_tokens`, presumably to leave room for the query embeddings that the text encoder inserts at
        `config.ctx_begin_pos`. Returns the first element of the text encoder output.
        """
        device = device or self._execution_device

        # embeddings for prompt, with query_embeds as context
        max_len = self.text_encoder.text_model.config.max_position_embeddings
        max_len -= self.qformer.config.num_query_tokens

        tokenized_prompt = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)

        batch_size = query_embeds.shape[0]
        ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size

        text_embeddings = self.text_encoder(
            input_ids=tokenized_prompt.input_ids,
            ctx_embeddings=query_embeds,
            ctx_begin_pos=ctx_begin_pos,
        )[0]

        return text_embeddings

    # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
    def prepare_control_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
    ):
        """
        Preprocess the ControlNet conditioning image and expand it to the generation batch.

        The image is resized to `width` x `height` and rescaled, but not center-cropped or normalized. A single image
        is repeated `batch_size` times; otherwise each image is repeated `num_images_per_prompt` times. With
        classifier-free guidance the batch is duplicated so it lines up with the doubled latent batch.
        """
        image = self.image_processor.preprocess(
            image,
            size={"width": width, "height": height},
            do_rescale=True,
            do_center_crop=False,
            do_normalize=False,
            return_tensors="pt",
        )["pixel_values"].to(device)
        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    # Inference only: disable autograd so the denoising loop does not record graphs for every
    # ControlNet/UNet step (avoids large, useless memory growth).
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: List[str],
        reference_image: PIL.Image.Image,
        condtioning_image: PIL.Image.Image,
        source_subject_category: List[str],
        target_subject_category: List[str],
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        neg_prompt: Optional[str] = "",
        prompt_strength: float = 1.0,
        prompt_reps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`List[str]`):
                The prompt or prompts to guide the image generation.
            reference_image (`PIL.Image.Image`):
                The reference image to condition the generation on.
            condtioning_image (`PIL.Image.Image`):
                The conditioning canny edge image to condition the generation on. (The parameter name's spelling is
                kept as-is for compatibility with existing keyword callers.)
            source_subject_category (`List[str]`):
                The source subject category.
            target_subject_category (`List[str]`):
                The target subject category.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by random sampling.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            height (`int`, *optional*, defaults to 512):
                The height of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            neg_prompt (`str`, *optional*, defaults to ""):
                The prompt not to guide the image generation; it is used for every item in the batch and truncated to
                the text encoder's maximum length. `None` is treated as an empty string. Ignored when not using
                guidance (i.e., ignored if `guidance_scale` is less than or equal to `1`).
            prompt_strength (`float`, *optional*, defaults to 1.0):
                The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
                to amplify the prompt.
            prompt_reps (`int`, *optional*, defaults to 20):
                The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format, forwarded to `image_processor.postprocess`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        device = self._execution_device

        # 1. Preprocess the reference image for the Q-Former.
        reference_image = self.image_processor.preprocess(
            reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
        )["pixel_values"]
        reference_image = reference_image.to(device)

        # Accept single strings as well as lists.
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(source_subject_category, str):
            source_subject_category = [source_subject_category]
        if isinstance(target_subject_category, str):
            target_subject_category = [target_subject_category]
        # `neg_prompt` is Optional; the tokenizer needs a string, so None means "no negative prompt".
        if neg_prompt is None:
            neg_prompt = ""

        batch_size = len(prompt)

        # 2. Conditional text embeddings, with the Q-Former query embeddings as context.
        prompt = self._build_prompt(
            prompts=prompt,
            tgt_subjects=target_subject_category,
            prompt_strength=prompt_strength,
            prompt_reps=prompt_reps,
        )
        query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
        text_embeddings = self.encode_prompt(query_embeds, prompt, device)

        # 3. Unconditional embedding for classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            max_length = self.text_encoder.text_model.config.max_position_embeddings

            # Truncate like `encode_prompt` does: without it a long negative prompt yields more input ids than
            # the text encoder has position embeddings.
            uncond_input = self.tokenizer(
                [neg_prompt] * batch_size,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )

            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.to(device),
                ctx_embeddings=None,
            )[0]
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # 4. Initial latents at the UNet's spatial resolution.
        scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
        latents = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=height // scale_down_factor,
            width=width // scale_down_factor,
            generator=generator,
            latents=latents,
            dtype=self.unet.dtype,
            device=device,
        )

        # 5. Timesteps and ControlNet conditioning image.
        self.scheduler.set_timesteps(num_inference_steps)

        cond_image = self.prepare_control_image(
            image=condtioning_image,
            width=width,
            height=height,
            batch_size=batch_size,
            num_images_per_prompt=1,
            device=device,
            dtype=self.controlnet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        # 6. Denoising loop.
        for t in self.progress_bar(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                controlnet_cond=cond_image,
                return_dict=False,
            )

            noise_pred = self.unet(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
            )["prev_sample"]

        # 7. Decode latents and convert to the requested output format.
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)