|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from transformers import CLIPTextModelWithProjection, CLIPTokenizer |
|
|
|
from ...image_processor import PipelineImageInput, VaeImageProcessor |
|
from ...models import UVit2DModel, VQModel |
|
from ...schedulers import AmusedScheduler |
|
from ...utils import replace_example_docstring |
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput |
|
|
|
|
|
EXAMPLE_DOC_STRING = """ |
|
Examples: |
|
```py |
|
>>> import torch |
|
>>> from diffusers import AmusedInpaintPipeline |
|
>>> from diffusers.utils import load_image |
|
|
|
>>> pipe = AmusedInpaintPipeline.from_pretrained( |
|
... "amused/amused-512", variant="fp16", torch_dtype=torch.float16 |
|
... ) |
|
>>> pipe = pipe.to("cuda") |
|
|
|
>>> prompt = "fall mountains" |
|
>>> input_image = ( |
|
... load_image( |
|
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg" |
|
... ) |
|
... .resize((512, 512)) |
|
... .convert("RGB") |
|
... ) |
|
>>> mask = ( |
|
... load_image( |
|
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" |
|
... ) |
|
... .resize((512, 512)) |
|
... .convert("L") |
|
... ) |
|
>>> pipe(prompt, input_image, mask).images[0].save("out.png") |
|
``` |
|
""" |
|
|
|
|
|
class AmusedInpaintPipeline(DiffusionPipeline): |
|
image_processor: VaeImageProcessor |
|
vqvae: VQModel |
|
tokenizer: CLIPTokenizer |
|
text_encoder: CLIPTextModelWithProjection |
|
transformer: UVit2DModel |
|
scheduler: AmusedScheduler |
|
|
|
model_cpu_offload_seq = "text_encoder->transformer->vqvae" |
|
|
|
|
|
|
|
|
|
_exclude_from_cpu_offload = ["vqvae"] |
|
|
|
def __init__( |
|
self, |
|
vqvae: VQModel, |
|
tokenizer: CLIPTokenizer, |
|
text_encoder: CLIPTextModelWithProjection, |
|
transformer: UVit2DModel, |
|
scheduler: AmusedScheduler, |
|
): |
|
super().__init__() |
|
|
|
self.register_modules( |
|
vqvae=vqvae, |
|
tokenizer=tokenizer, |
|
text_encoder=text_encoder, |
|
transformer=transformer, |
|
scheduler=scheduler, |
|
) |
|
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) |
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) |
|
self.mask_processor = VaeImageProcessor( |
|
vae_scale_factor=self.vae_scale_factor, |
|
do_normalize=False, |
|
do_binarize=True, |
|
do_convert_grayscale=True, |
|
do_resize=True, |
|
) |
|
self.scheduler.register_to_config(masking_schedule="linear") |
|
|
|
@torch.no_grad() |
|
@replace_example_docstring(EXAMPLE_DOC_STRING) |
|
def __call__( |
|
self, |
|
prompt: Optional[Union[List[str], str]] = None, |
|
image: PipelineImageInput = None, |
|
mask_image: PipelineImageInput = None, |
|
strength: float = 1.0, |
|
num_inference_steps: int = 12, |
|
guidance_scale: float = 10.0, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
generator: Optional[torch.Generator] = None, |
|
prompt_embeds: Optional[torch.Tensor] = None, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
negative_prompt_embeds: Optional[torch.Tensor] = None, |
|
negative_encoder_hidden_states: Optional[torch.Tensor] = None, |
|
output_type="pil", |
|
return_dict: bool = True, |
|
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, |
|
callback_steps: int = 1, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
micro_conditioning_aesthetic_score: int = 6, |
|
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), |
|
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), |
|
): |
|
""" |
|
The call function to the pipeline for generation. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. |
|
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): |
|
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both |
|
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list |
|
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a |
|
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image |
|
latents as `image`, but if passing latents directly it is not encoded again. |
|
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): |
|
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask |
|
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a |
|
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one |
|
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, |
|
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, |
|
1)`, or `(H, W)`. |
|
strength (`float`, *optional*, defaults to 1.0): |
|
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a |
|
starting point and more noise is added the higher the `strength`. The number of denoising steps depends |
|
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising |
|
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 |
|
essentially ignores `image`. |
|
num_inference_steps (`int`, *optional*, defaults to 16): |
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
|
expense of slower inference. |
|
guidance_scale (`float`, *optional*, defaults to 10.0): |
|
A higher guidance scale value encourages the model to generate images closely linked to the text |
|
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide what to not include in image generation. If not defined, you need to |
|
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). |
|
num_images_per_prompt (`int`, *optional*, defaults to 1): |
|
The number of images to generate per prompt. |
|
generator (`torch.Generator`, *optional*): |
|
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make |
|
generation deterministic. |
|
prompt_embeds (`torch.Tensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not |
|
provided, text embeddings are generated from the `prompt` input argument. A single vector from the |
|
pooled and projected final hidden states. |
|
encoder_hidden_states (`torch.Tensor`, *optional*): |
|
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. |
|
negative_prompt_embeds (`torch.Tensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If |
|
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. |
|
negative_encoder_hidden_states (`torch.Tensor`, *optional*): |
|
Analogous to `encoder_hidden_states` for the positive prompt. |
|
output_type (`str`, *optional*, defaults to `"pil"`): |
|
The output format of the generated image. Choose between `PIL.Image` or `np.array`. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
|
plain tuple. |
|
callback (`Callable`, *optional*): |
|
A function that calls every `callback_steps` steps during inference. The function is called with the |
|
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. |
|
callback_steps (`int`, *optional*, defaults to 1): |
|
The frequency at which the `callback` function is called. If not specified, the callback is called at |
|
every step. |
|
cross_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in |
|
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
|
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): |
|
The targeted aesthetic score according to the laion aesthetic classifier. See |
|
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of |
|
https://arxiv.org/abs/2307.01952. |
|
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): |
|
The targeted height, width crop coordinates. See the micro-conditioning section of |
|
https://arxiv.org/abs/2307.01952. |
|
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): |
|
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. |
|
|
|
Examples: |
|
|
|
Returns: |
|
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: |
|
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a |
|
`tuple` is returned where the first element is a list with the generated images. |
|
""" |
|
|
|
if (prompt_embeds is not None and encoder_hidden_states is None) or ( |
|
prompt_embeds is None and encoder_hidden_states is not None |
|
): |
|
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") |
|
|
|
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( |
|
negative_prompt_embeds is None and negative_encoder_hidden_states is not None |
|
): |
|
raise ValueError( |
|
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" |
|
) |
|
|
|
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): |
|
raise ValueError("pass only one of `prompt` or `prompt_embeds`") |
|
|
|
if isinstance(prompt, str): |
|
prompt = [prompt] |
|
|
|
if prompt is not None: |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
batch_size = batch_size * num_images_per_prompt |
|
|
|
if prompt_embeds is None: |
|
input_ids = self.tokenizer( |
|
prompt, |
|
return_tensors="pt", |
|
padding="max_length", |
|
truncation=True, |
|
max_length=self.tokenizer.model_max_length, |
|
).input_ids.to(self._execution_device) |
|
|
|
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) |
|
prompt_embeds = outputs.text_embeds |
|
encoder_hidden_states = outputs.hidden_states[-2] |
|
|
|
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) |
|
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) |
|
|
|
if guidance_scale > 1.0: |
|
if negative_prompt_embeds is None: |
|
if negative_prompt is None: |
|
negative_prompt = [""] * len(prompt) |
|
|
|
if isinstance(negative_prompt, str): |
|
negative_prompt = [negative_prompt] |
|
|
|
input_ids = self.tokenizer( |
|
negative_prompt, |
|
return_tensors="pt", |
|
padding="max_length", |
|
truncation=True, |
|
max_length=self.tokenizer.model_max_length, |
|
).input_ids.to(self._execution_device) |
|
|
|
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) |
|
negative_prompt_embeds = outputs.text_embeds |
|
negative_encoder_hidden_states = outputs.hidden_states[-2] |
|
|
|
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) |
|
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) |
|
|
|
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) |
|
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) |
|
|
|
image = self.image_processor.preprocess(image) |
|
|
|
height, width = image.shape[-2:] |
|
|
|
|
|
|
|
micro_conds = torch.tensor( |
|
[ |
|
width, |
|
height, |
|
micro_conditioning_crop_coord[0], |
|
micro_conditioning_crop_coord[1], |
|
micro_conditioning_aesthetic_score, |
|
], |
|
device=self._execution_device, |
|
dtype=encoder_hidden_states.dtype, |
|
) |
|
|
|
micro_conds = micro_conds.unsqueeze(0) |
|
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) |
|
|
|
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) |
|
num_inference_steps = int(len(self.scheduler.timesteps) * strength) |
|
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps |
|
|
|
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast |
|
|
|
if needs_upcasting: |
|
self.vqvae.float() |
|
|
|
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents |
|
latents_bsz, channels, latents_height, latents_width = latents.shape |
|
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) |
|
|
|
mask = self.mask_processor.preprocess( |
|
mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor |
|
) |
|
mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device) |
|
latents[mask] = self.scheduler.config.mask_token_id |
|
|
|
starting_mask_ratio = mask.sum() / latents.numel() |
|
|
|
latents = latents.repeat(num_images_per_prompt, 1, 1) |
|
|
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i in range(start_timestep_idx, len(self.scheduler.timesteps)): |
|
timestep = self.scheduler.timesteps[i] |
|
|
|
if guidance_scale > 1.0: |
|
model_input = torch.cat([latents] * 2) |
|
else: |
|
model_input = latents |
|
|
|
model_output = self.transformer( |
|
model_input, |
|
micro_conds=micro_conds, |
|
pooled_text_emb=prompt_embeds, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
) |
|
|
|
if guidance_scale > 1.0: |
|
uncond_logits, cond_logits = model_output.chunk(2) |
|
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) |
|
|
|
latents = self.scheduler.step( |
|
model_output=model_output, |
|
timestep=timestep, |
|
sample=latents, |
|
generator=generator, |
|
starting_mask_ratio=starting_mask_ratio, |
|
).prev_sample |
|
|
|
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
step_idx = i // getattr(self.scheduler, "order", 1) |
|
callback(step_idx, timestep, latents) |
|
|
|
if output_type == "latent": |
|
output = latents |
|
else: |
|
output = self.vqvae.decode( |
|
latents, |
|
force_not_quantize=True, |
|
shape=( |
|
batch_size, |
|
height // self.vae_scale_factor, |
|
width // self.vae_scale_factor, |
|
self.vqvae.config.latent_channels, |
|
), |
|
).sample.clip(0, 1) |
|
output = self.image_processor.postprocess(output, output_type) |
|
|
|
if needs_upcasting: |
|
self.vqvae.half() |
|
|
|
self.maybe_free_model_hooks() |
|
|
|
if not return_dict: |
|
return (output,) |
|
|
|
return ImagePipelineOutput(output) |
|
|