|
import inspect |
|
from typing import Callable, List, Optional, Union |
|
|
|
import torch |
|
from PIL import Image |
|
from retriever import Retriever, normalize_images, preprocess_images |
|
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer |
|
|
|
from diffusers import ( |
|
AutoencoderKL, |
|
DDIMScheduler, |
|
DiffusionPipeline, |
|
DPMSolverMultistepScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
EulerDiscreteScheduler, |
|
ImagePipelineOutput, |
|
LMSDiscreteScheduler, |
|
PNDMScheduler, |
|
UNet2DConditionModel, |
|
) |
|
from diffusers.image_processor import VaeImageProcessor |
|
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin |
|
from diffusers.utils import logging |
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
|
|
|
# Module-level logger for this file; the pipeline uses it to warn when a prompt is truncated by the CLIP tokenizer.
logger = logging.get_logger(__name__)
|
|
|
|
|
class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
    r"""
    Pipeline for text-to-image generation using Retrieval Augmented Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The UNet is conditioned on a sequence made of the L2-normalized CLIP text embedding of the prompt followed by the
    L2-normalized CLIP image embeddings of the retrieved images (user-supplied and/or fetched by `retriever`).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        clip ([`CLIPModel`]):
            Frozen CLIP model used to embed both the prompt and the retrieved images. Retrieval Augmented Diffusion
            uses the CLIP model, specifically the
            [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], or [`DPMSolverMultistepScheduler`].
        feature_extractor ([`CLIPFeatureExtractor`]):
            Image processor used to preprocess the retrieved images before they are embedded with `clip`.
        retriever (`Retriever`, *optional*):
            Retriever used to fetch the `knn` nearest images for each prompt embedding. It is stored as a plain
            attribute rather than registered via `register_modules`. When `None`, only the user-supplied
            `retrieved_images` are used for conditioning.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        clip: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        feature_extractor: CLIPFeatureExtractor,
        retriever: Optional[Retriever] = None,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            clip=clip,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        # Each VAE down block halves the spatial resolution, so latents are this many times smaller than pixels.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.retriever = retriever

    def _encode_prompt(self, prompt):
        """Embed `prompt` with CLIP's text tower.

        Prompts longer than the tokenizer's `model_max_length` are truncated, and the dropped text is reported
        through a logger warning.

        Args:
            prompt (`str` or `List[str]`): The prompt or prompts to encode.

        Returns:
            `torch.Tensor` of shape `(batch_size, 1, embed_dim)`: L2-normalized text features with a singleton
            sequence axis, so image embeddings can be concatenated after it along `dim=1`.
        """
        max_length = self.tokenizer.model_max_length
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # `truncation=True` already clipped `text_input_ids`, so comparing its length against `max_length` can never
        # detect truncation. Tokenize again without truncation to find out what (if anything) was dropped.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            # Skip the final (end-of-text) token; the slice starts where the truncated sequence stopped keeping text.
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        prompt_embeds = self.clip.get_text_features(text_input_ids.to(self.device))
        prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
        return prompt_embeds[:, None, :]

    def _encode_image(self, retrieved_images, batch_size):
        """Embed the retrieved images of every prompt with CLIP's image tower.

        The input lists are not modified.

        Args:
            retrieved_images (`List[List]`):
                One list of images per prompt. Every list must have the same length.
            batch_size (`int`): Number of prompts.

        Returns:
            `torch.Tensor` of shape `(batch_size, num_images, embed_dim)` holding L2-normalized image features, or
            `None` when there are no images to embed.

        Raises:
            `ValueError`: If the prompts do not all have the same number of retrieved images.
        """
        if not retrieved_images or len(retrieved_images[0]) == 0:
            return None

        # The final reshape to (batch_size, -1, d) silently mis-assigns embeddings when counts differ, so fail loudly.
        counts = {len(images) for images in retrieved_images}
        if len(counts) != 1:
            raise ValueError(
                f"Every prompt must have the same number of retrieved images, but got counts {sorted(counts)}."
            )

        pixel_batches = [
            preprocess_images(normalize_images(images), self.feature_extractor).to(
                self.clip.device, dtype=self.clip.dtype
            )
            for images in retrieved_images
        ]
        _, c, h, w = pixel_batches[0].shape
        pixel_values = torch.reshape(torch.cat(pixel_batches, dim=0), (-1, c, h, w))

        image_embeddings = self.clip.get_image_features(pixel_values)
        image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)
        _, d = image_embeddings.shape
        return torch.reshape(image_embeddings, (batch_size, -1, d))

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Create (or move) the initial latents and scale them by the scheduler's `init_noise_sigma`.

        Args:
            batch_size (`int`): Number of latents to sample.
            num_channels_latents (`int`): Channel count of the UNet input.
            height (`int`), width (`int`): Output image size in pixels.
            dtype (`torch.dtype`), device (`torch.device`): Placement of freshly sampled latents.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*): RNG(s) used for sampling.
            latents (`torch.Tensor`, *optional*):
                Pre-generated latents. They are only moved to `device`; their shape and dtype are not checked.

        Returns:
            `torch.Tensor`: The scaled initial latents. Freshly sampled latents have shape
            `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)`.

        Raises:
            `ValueError`: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # Scale the initial noise by the standard deviation required by the scheduler.
        return latents * self.scheduler.init_noise_sigma

    def retrieve_images(self, retrieved_images, prompt_embeds, knn=10):
        """Append the `knn` retriever results for each prompt to that prompt's image list.

        A new outer list with new per-prompt lists is returned; the lists passed in are not modified. This matters
        because callers may pass the same image list for several prompts.

        Args:
            retrieved_images (`List[List]`): One list of already-available images per prompt.
            prompt_embeds (`torch.Tensor`):
                Prompt embeddings as returned by `_encode_prompt`. The first sequence position is used as the query.
            knn (`int`, *optional*, defaults to 10): Number of images to fetch per prompt.

        Returns:
            `List[List]`: The per-prompt image lists, extended with the retrieved images when a retriever is set.
        """
        if self.retriever is None:
            return retrieved_images

        additional_images = self.retriever.retrieve_imgs_batch(prompt_embeds[:, 0].cpu(), knn).total_examples
        image_column = self.retriever.config.image_column
        return [
            list(images) + list(additional_images[i][image_column]) for i, images in enumerate(retrieved_images)
        ]

    def _prepare_extra_step_kwargs(self, generator, eta):
        """Build the optional keyword arguments for `self.scheduler.step`.

        Schedulers do not share one `step` signature, so `eta` and `generator` are passed only when the scheduler's
        `step` accepts a parameter of that name.
        """
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        retrieved_images: Optional[List[Image.Image]] = None,
        height: int = 768,
        width: int = 768,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        knn: Optional[int] = 10,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. Always required, because it determines the
                batch size even when `prompt_embeds` is given.
            retrieved_images (`List[PIL.Image.Image]`, *optional*):
                Images to condition on, in addition to any fetched by the retriever. The same images are used for
                every prompt in the batch. The list passed in is not modified.
            height (`int`, *optional*, defaults to 768):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 768):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. The unconditional branch uses all-zero embeddings.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic. A list must contain one generator per generated image.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings of shape `(batch_size, 1, embed_dim)`. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`, or `"latent"` to
                skip VAE decoding.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            knn (`int`, *optional*, defaults to 10):
                Number of images fetched by the retriever for each prompt. Ignored when the pipeline has no retriever.
            kwargs:
                Accepted for compatibility and ignored.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.

        Raises:
            `ValueError`: If `prompt` is not a `str`/`list`, if `height`/`width` are not divisible by 8, if
            `callback_steps` is not a positive integer, or if the prompts end up with different numbers of
            retrieved images.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if retrieved_images is not None:
            # Give every prompt its own copy: sharing one list object would let per-prompt retrievals leak into every
            # other prompt (and into the caller's list).
            retrieved_images = [list(retrieved_images) for _ in range(batch_size)]
        else:
            retrieved_images = [[] for _ in range(batch_size)]

        device = self._execution_device

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # Conditioning sequence: [text embedding, image embedding_1, ..., image embedding_k] per prompt.
        if prompt_embeds is None:
            prompt_embeds = self._encode_prompt(prompt)
        retrieved_images = self.retrieve_images(retrieved_images, prompt_embeds, knn=knn)
        image_embeddings = self._encode_image(retrieved_images, batch_size)
        if image_embeddings is not None:
            prompt_embeds = torch.cat([prompt_embeds, image_embeddings], dim=1)

        # Duplicate the conditioning for each generation per prompt (mps-friendly repeat + view).
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # `guidance_scale <= 1` disables classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            # The unconditional branch is conditioned on all-zero embeddings; batching both branches together lets
            # a single UNet forward pass compute them.
            uncond_embeddings = torch.zeros_like(prompt_embeds)
            prompt_embeds = torch.cat([uncond_embeddings, prompt_embeds])

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        extra_step_kwargs = self._prepare_extra_step_kwargs(generator, eta)

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # Expand the latents when doing classifier-free guidance.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=[True] * image.shape[0]
        )

        # Offload the last model to CPU when model CPU offloading is enabled.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|
|