|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from transformers import CLIPTextModelWithProjection, CLIPTokenizer |
|
|
|
from ...image_processor import VaeImageProcessor |
|
from ...models import UVit2DModel, VQModel |
|
from ...schedulers import AmusedScheduler |
|
from ...utils import replace_example_docstring |
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput |
|
|
|
|
|
# Usage example injected into `AmusedPipeline.__call__`'s docstring via `@replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import AmusedPipeline

        >>> pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""
|
|
|
|
|
class AmusedPipeline(DiffusionPipeline):
    """
    Text-to-image pipeline operating on a grid of VQ-VAE latent tokens.

    Generation starts from a latent grid filled entirely with the scheduler's mask token. The `transformer`
    refines it step by step under `scheduler` control, conditioned on outputs of the CLIP text encoder, and
    the final token grid is decoded into pixels by `vqvae`.

    Args:
        vqvae ([`VQModel`]):
            Vector-quantized autoencoder used to decode the final latent tokens into images.
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer for prompts and negative prompts.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Text encoder providing the pooled/projected prompt embedding (`text_embeds`) and the penultimate
            hidden states used as `encoder_hidden_states`.
        transformer ([`UVit2DModel`]):
            Model that predicts logits over the latent tokens.
        scheduler ([`AmusedScheduler`]):
            Scheduler that produces the next latent grid from the transformer output at each step.
    """

    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: UVit2DModel
    scheduler: AmusedScheduler

    # Component order for model CPU offloading (presumably consumed by `DiffusionPipeline`).
    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: UVit2DModel,
        scheduler: AmusedScheduler,
    ):
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Ratio between image size in pixels and latent grid size (used to size the token grid in `__call__`).
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        # `do_normalize=False`: decoded samples are already clipped to [0, 1] in `__call__`.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    def _encode_text(self, text: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize `text` (padded/truncated to the tokenizer's max length) and run the CLIP text encoder.

        Returns:
            `(text_embeds, penultimate_hidden_states)`: the encoder's `text_embeds` output and its second-to-last
            hidden state, which the transformer consumes as `encoder_hidden_states`.
        """
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids.to(self._execution_device)

        outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
        return outputs.text_embeds, outputs.hidden_states[-2]

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 12,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.IntTensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 12):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. A single string is used
                for every prompt in the batch; a list must contain one entry per prompt. If neither this nor
                `negative_prompt_embeds` is given, the empty string is used. Ignored when not using guidance
                (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.IntTensor`, *optional*):
                Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
                generation. If not provided, the starting latents will be completely masked.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument. A single vector from the
                pooled and projected final hidden states. Must be passed together with `encoder_hidden_states`.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. Must be
                passed together with `negative_encoder_hidden_states`.
            negative_encoder_hidden_states (`torch.Tensor`, *optional*):
                Analogous to `encoder_hidden_states` for the negative prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"latent"` returns the final latent tokens without
                decoding; any other value (e.g. `"pil"` or `"np"`) is forwarded to the image processor.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_utils.ImagePipelineOutput`] instead of a plain
                tuple.
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with
                the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. Must be positive when `callback` is
                given. If not specified, the callback is called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
                The targeted aesthetic score according to the laion aesthetic classifier. See
                https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
                The targeted height, width crop coordinates. See the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
                Configures the temperature scheduler on `self.scheduler`; see `AmusedScheduler.set_timesteps`.

        Examples:

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images.

        Raises:
            `ValueError`: If exactly one of `prompt` / `prompt_embeds` is not given, if `prompt_embeds` and
                `encoder_hidden_states` (or their negative counterparts) are not passed together, if a list
                `negative_prompt` does not match the number of prompts, or if `callback_steps` is not positive
                while a `callback` is given.
        """
        if (prompt_embeds is None) != (encoder_hidden_states is None):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is None) != (negative_encoder_hidden_states is None):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if (prompt is None) == (prompt_embeds is None):
            raise ValueError("pass only one of `prompt` or `prompt_embeds`")

        # Validate up front: otherwise `i % callback_steps` fails only after a full transformer forward pass.
        if callback is not None and callback_steps <= 0:
            raise ValueError(f"`callback_steps` must be a positive integer, got {callback_steps}")

        if isinstance(prompt, str):
            prompt = [prompt]

        # Number of distinct prompts, before expansion by `num_images_per_prompt`. `prompt` is None when only
        # `prompt_embeds` was passed, so this count is also used to size the default negative prompt.
        prompt_batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
        batch_size = prompt_batch_size * num_images_per_prompt

        if height is None:
            height = self.transformer.config.sample_size * self.vae_scale_factor
        if width is None:
            width = self.transformer.config.sample_size * self.vae_scale_factor

        if prompt_embeds is None:
            prompt_embeds, encoder_hidden_states = self._encode_text(prompt)

        prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

        do_classifier_free_guidance = guidance_scale > 1.0

        if do_classifier_free_guidance:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * prompt_batch_size
                elif isinstance(negative_prompt, str):
                    # Broadcast a single negative prompt so its rows line up with the conditional rows.
                    negative_prompt = [negative_prompt] * prompt_batch_size
                elif len(negative_prompt) != prompt_batch_size:
                    raise ValueError(
                        f"`negative_prompt` has {len(negative_prompt)} entries, but {prompt_batch_size} prompts "
                        "were given; pass a single string or one negative prompt per prompt"
                    )

                negative_prompt_embeds, negative_encoder_hidden_states = self._encode_text(negative_prompt)

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            # Unconditional rows first, conditional second: matches `model_output.chunk(2)` in the loop below.
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )
        # One identical micro-conditioning row per transformer input row (doubled under guidance).
        micro_conds = micro_conds.unsqueeze(0).expand(
            2 * batch_size if do_classifier_free_guidance else batch_size, -1
        )

        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor

        if latents is None:
            # Start from a fully masked token grid.
            latents = torch.full(
                (batch_size, latent_height, latent_width),
                self.scheduler.config.mask_token_id,
                dtype=torch.long,
                device=self._execution_device,
            )

        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)

        num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, timestep in enumerate(self.scheduler.timesteps):
                model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                model_output = self.transformer(
                    model_input,
                    micro_conds=micro_conds,
                    pooled_text_emb=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

                if do_classifier_free_guidance:
                    uncond_logits, cond_logits = model_output.chunk(2)
                    model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

                latents = self.scheduler.step(
                    model_output=model_output,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

                if i == len(self.scheduler.timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

        if output_type == "latent":
            output = latents
        else:
            needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast

            if needs_upcasting:
                self.vqvae.float()

            try:
                output = self.vqvae.decode(
                    latents,
                    force_not_quantize=True,
                    shape=(
                        batch_size,
                        latent_height,
                        latent_width,
                        self.vqvae.config.latent_channels,
                    ),
                ).sample.clip(0, 1)
                output = self.image_processor.postprocess(output, output_type)
            finally:
                # Restore half precision even if decoding fails, so the pipeline is not left in fp32.
                if needs_upcasting:
                    self.vqvae.half()

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(output)
|
|