Spaces:
Paused
Paused
import pdb, sys | |
import numpy as np | |
import torch | |
from typing import Any, Callable, Dict, List, Optional, Union | |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput | |
sys.path.insert(0, "src/utils") | |
from base_pipeline import BasePipeline | |
from cross_attention import prep_unet | |
class EditingPipeline(BasePipeline):
    """Pipeline that reconstructs an input latent and produces an edited version of it.

    ``__call__`` runs two denoising passes from the same starting latents:

    1. a reconstruction pass with the original prompt embeddings that records,
       for every timestep, the ``attn_probs`` of each cross-attention ("attn2")
       layer of the UNet;
    2. an editing pass in which ``edit_dir`` is added to the conditional prompt
       embeddings and, at each timestep, the latent model input takes one
       gradient step (step size ``guidance_amount``) that pulls the current
       cross-attention maps towards the ones recorded in pass 1.
    """

    def _cross_attention_layers(self):
        """Yield ``(name, module)`` for each UNet module of type ``CrossAttention`` whose name contains ``'attn2'``.

        NOTE(review): assumes ``prep_unet`` makes these modules store their
        attention probabilities in ``module.attn_probs`` during forward — confirm
        in ``cross_attention.prep_unet``.
        """
        for name, module in self.unet.named_modules():
            if type(module).__name__ == "CrossAttention" and "attn2" in name:
                yield name, module

    @staticmethod
    def _apply_guidance(noise_pred, guidance_scale, do_classifier_free_guidance):
        """Combine the unconditional/conditional halves of ``noise_pred`` with classifier-free guidance.

        Returns ``noise_pred`` unchanged when guidance is disabled.
        """
        if not do_classifier_free_guidance:
            return noise_pred
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # pix2pix parameters
        guidance_amount=0.1,
        edit_dir=None,
        x_in=None,
    ):
        """Reconstruct ``x_in`` and return it together with an edited version.

        Args:
            prompt: text prompt(s); when ``None`` the batch size is taken from
                ``prompt_embeds.shape[0]``.
            height, width: output size in pixels; when falsy they default to
                ``self.unet.config.sample_size * self.vae_scale_factor``.
            num_inference_steps: number of scheduler timesteps used by each pass.
            guidance_scale: classifier-free guidance weight; guidance is enabled
                when it is greater than 1.0.
            negative_prompt, prompt_embeds, negative_prompt_embeds: forwarded to
                ``self._encode_prompt``.
            num_images_per_prompt: forwarded to ``self._encode_prompt`` and used
                to size the latent batch.
            eta: forwarded to ``self.prepare_extra_step_kwargs``.
            generator: forwarded to ``self.prepare_latents`` and
                ``self.prepare_extra_step_kwargs``.
            latents: accepted for interface compatibility; not read.
            cross_attention_kwargs: forwarded to every UNet call.
            guidance_amount: step size of the per-timestep gradient step taken
                on the latent model input during the editing pass.
            edit_dir: tensor added to the conditional prompt embeddings for the
                editing pass (must broadcast against them).
            x_in: tensor handed to ``self.prepare_latents`` as the starting point
                of both passes (presumably the inverted latent — confirm); required.

        Returns:
            ``(image_rec, image_edit)``: the outputs of ``self.numpy_to_pil`` for
            the reconstruction pass and the editing pass. Only the edited images
            go through ``self.run_safety_checker``.

        Raises:
            ValueError: if ``x_in`` or ``edit_dir`` is ``None``.
            RuntimeError: if the UNet has no cross-attention layer to guide.
        """
        if x_in is None:
            raise ValueError("x_in is required: pass the tensor to start denoising from.")
        if edit_dir is None:
            raise ValueError("edit_dir is required: pass the edit direction to add to the prompt embeddings.")

        # 0. Patch the UNet so its cross-attention layers expose their maps.
        self.unet = prep_unet(self.unet)

        # 1. Reference cross-attention maps from the reconstruction pass:
        #    {timestep (int): {layer name: attn_probs stored on the CPU}}.
        d_ref_t2attn = {}

        # 2. Default height and width to the UNet sample size.
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # TODO: add the input checker function (self.check_inputs).

        # 3. Batch size comes from the prompt, or else from precomputed embeddings.
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        x_in = x_in.to(dtype=self.unet.dtype, device=device)

        # 4. Encode the prompt. With classifier-free guidance the embeddings are
        #    stacked as [unconditional; conditional] along dim 0.
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 5. Timesteps.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Starting latents (built from x_in); both passes start from them.
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            x_in,
        )
        latents_init = latents.clone()

        # TODO: extra-step-kwargs logic should ideally move out of the pipeline.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        def _progress_tick(step_index):
            # Advance the bar on the last step, and once per scheduler order after warm-up.
            return step_index == len(timesteps) - 1 or (
                step_index + 1 > num_warmup_steps and (step_index + 1) % self.scheduler.order == 0
            )

        # 7. Reconstruction pass: denoise with the original prompt and record the
        #    cross-attention maps of every guided layer at each timestep.
        with torch.no_grad():
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample
                    # Stored on the CPU (presumably to save device memory); moved
                    # back to the live tensor's device during the editing pass.
                    d_ref_t2attn[t.item()] = {
                        name: module.attn_probs.detach().cpu()
                        for name, module in self._cross_attention_layers()
                    }
                    noise_pred = self._apply_guidance(noise_pred, guidance_scale, do_classifier_free_guidance)
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                    if _progress_tick(i):
                        progress_bar.update()

        # Reference image (reconstruction).
        image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))

        # 8. Add the edit direction to the conditional embeddings only. With CFG
        #    the first half of the batch holds the unconditional (negative) ones;
        #    without CFG every row is conditional.
        prompt_embeds_edit = prompt_embeds.clone()
        num_uncond = prompt_embeds_edit.shape[0] // 2 if do_classifier_free_guidance else 0
        prompt_embeds_edit[num_uncond:] += edit_dir

        # 9. Editing pass: at each timestep take one gradient step on the latent
        #    model input that pulls the current cross-attention maps towards the
        #    reference ones, then denoise from the updated input.
        latents = latents_init
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                x_in = latent_model_input.detach().clone().requires_grad_(True)

                # enable_grad so the step also works when the caller runs under no_grad.
                with torch.enable_grad():
                    # Forward pass only to populate attn_probs; the noise prediction
                    # is recomputed from the updated input below.
                    self.unet(
                        x_in,
                        t,
                        encoder_hidden_states=prompt_embeds_edit.detach(),
                        cross_attention_kwargs=cross_attention_kwargs,
                    )
                    ref_maps = d_ref_t2attn[t.item()]
                    loss = None
                    for name, module in self._cross_attention_layers():
                        curr = module.attn_probs
                        ref = ref_maps[name].to(device=curr.device)
                        term = ((curr - ref) ** 2).sum((1, 2)).mean(0)
                        loss = term if loss is None else loss + term
                    if loss is None:
                        raise RuntimeError(
                            "No 'CrossAttention' attn2 layers found in the UNet; "
                            "cannot compute the cross-attention guidance loss."
                        )
                    # Gradient w.r.t. x_in only: unlike loss.backward(), this does not
                    # write .grad into every UNet parameter.
                    (grad,) = torch.autograd.grad(loss, x_in)

                with torch.no_grad():
                    # Plain SGD update with lr=guidance_amount (no momentum/decay).
                    x_in = x_in.detach().add(grad, alpha=-guidance_amount)
                    noise_pred = self.unet(
                        x_in,
                        t,
                        encoder_hidden_states=prompt_embeds_edit,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample

                # NOTE(review): the latents are taken from the *scaled* model input,
                # which is only exact for schedulers whose scale_model_input is the
                # identity (e.g. DDIM) — confirm for other schedulers.
                latents = x_in.chunk(2)[0] if do_classifier_free_guidance else x_in
                noise_pred = self._apply_guidance(noise_pred, guidance_scale, do_classifier_free_guidance)
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                if _progress_tick(i):
                    progress_bar.update()

        # 10. Post-processing; the safety checker runs on the edited images only
        #     and its NSFW flags are not returned.
        image = self.decode_latents(latents.detach())
        image, _has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        image_edit = self.numpy_to_pil(image)
        return image_rec, image_edit