import sys
import numpy as np
import torch
import torch.nn.functional as F
from random import randrange
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from diffusers import DDIMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
sys.path.insert(0, "src/utils")
from base_pipeline import BasePipeline
from cross_attention import prep_unet
class DDIMInversion(BasePipeline):

    def auto_corr_loss(self, x, random_shift=True):
        # Auto-correlation regularization: penalize spatial correlation in the
        # predicted noise at multiple scales so it stays close to white noise.
        B, C, H, W = x.shape
        assert B == 1
        x = x.squeeze(0)
        # x must be shape [C, H, W] now
        reg_loss = 0.0
        for ch_idx in range(x.shape[0]):
            noise = x[ch_idx][None, None, :, :]
            while True:
                if random_shift:
                    roll_amount = randrange(noise.shape[2] // 2)
                else:
                    roll_amount = 1
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                # halve the spatial resolution and regularize the next scale
                noise = F.avg_pool2d(noise, kernel_size=2)
        return reg_loss

    def kl_divergence(self, x):
        # KL divergence of the empirical N(mu, var) of x to the standard normal
        _mu = x.mean()
        _var = x.var()
        return _var + _mu**2 - 1 - torch.log(_var + 1e-7)
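
    # For x ~ N(mu, var), KL(N(mu, var) || N(0, 1)) = (var + mu^2 - 1 - log var) / 2;
    # kl_divergence() drops the constant factor 1/2, which only rescales the
    # effective lambda_kl. During inversion the predicted noise e_t is nudged
    # by gradient steps on
    #     L = lambda_ac * auto_corr_loss(e_t) + lambda_kl * kl_divergence(e_t)
    # so that it stays statistically close to white Gaussian noise.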
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        num_inversion_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        img=None,  # the input image as a PIL image
        torch_dtype=torch.float32,
        # inversion regularization parameters
        lambda_ac: float = 20.0,
        lambda_kl: float = 20.0,
        num_reg_steps: int = 5,
        num_ac_rolls: int = 5,
    ):
        # 0. Modify the UNet so that its cross-attention maps can be accessed
        self.unet = prep_unet(self.unet)
        # set the scheduler to be the inverse DDIM scheduler
        # self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        self.scheduler.set_timesteps(num_inversion_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 1. Encode the input image with the first-stage (VAE) model
        x0 = np.array(img) / 255
        x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).to(device)
        x0 = (x0 - 0.5) * 2.0  # map pixel values from [0, 1] to [-1, 1]
        with torch.no_grad():
            x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
        # scale by the VAE latent scaling factor (0.18215 for Stable Diffusion)
        x0_enc = 0.18215 * x0_enc
        latents = x0_enc

        # 2. Decode the latent back to pixels as a reconstruction reference
        with torch.no_grad():
            x0_dec = self.decode_latents(x0_enc.detach())
        image_x0_dec = self.numpy_to_pil(x0_dec)

        # 3. Encode the prompt
        with torch.no_grad():
            prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
        extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)

        # 4. Do the inversion
        num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order  # 0 for a first-order scheduler
        with self.progress_bar(total=num_inversion_steps) as progress_bar:
            # iterate the timesteps in reverse (noising) order, skipping the
            # first and last step
            for i, t in enumerate(timesteps.flip(0)[1:-1]):
                # expand the latents if we are doing classifier-free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # regularize the noise prediction
                e_t = noise_pred
                for _outer in range(num_reg_steps):
                    if lambda_ac > 0:
                        for _inner in range(num_ac_rolls):
                            _var = e_t.detach().clone().requires_grad_(True)
                            l_ac = self.auto_corr_loss(_var)
                            l_ac.backward()
                            _grad = _var.grad.detach() / num_ac_rolls
                            e_t = e_t - lambda_ac * _grad
                    if lambda_kl > 0:
                        _var = e_t.detach().clone().requires_grad_(True)
                        l_kld = self.kl_divergence(_var)
                        l_kld.backward()
                        _grad = _var.grad.detach()
                        e_t = e_t - lambda_kl * _grad
                    e_t = e_t.detach()
                noise_pred = e_t

                # take one inverse DDIM step: x_t -> x_{t+1}
                latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
        x_inv = latents.detach().clone()

        # 5. Post-processing: reconstruct the image from the inverted latent
        image = self.decode_latents(latents.detach())
        image = self.numpy_to_pil(image)
        return x_inv, image, image_x0_dec
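

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original pipeline).
# Assumptions: the checkpoint name, image path, and prompt below are
# placeholders; `BasePipeline` is assumed to expose `from_pretrained` the way
# diffusers pipelines do, and `self.scheduler` must be the project's modified
# DDIM scheduler whose `step()` accepts `reverse=True`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    pipe = DDIMInversion.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint
        torch_dtype=torch.float32,
    ).to("cuda")
    img = Image.open("input.png").convert("RGB").resize((512, 512))  # placeholder path
    x_inv, image, image_x0_dec = pipe(
        prompt="a photo of a cat",  # placeholder prompt
        img=img,
        num_inversion_steps=50,
        guidance_scale=1.0,  # CFG > 1 typically degrades inversion fidelity
    )
    # `numpy_to_pil` returns a list of PIL images
    image[0].save("inverted_reconstruction.png")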