Spaces:
Running
Running
import torch | |
from tqdm import tqdm | |
from typing import List, Optional, Tuple | |
from models import PipelineWrapper | |
import gradio as gr | |
def inversion_forward_process(model: PipelineWrapper,
                              x0: torch.Tensor,
                              etas: Optional[float] = None,
                              prompts: List[str] = [""],
                              cfg_scales: List[float] = [3.5],
                              num_inference_steps: int = 50,
                              numerical_fix: bool = False,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple:
    """Invert ``x0`` into per-step noise maps ``zs``.

    A noisy trajectory ``xts`` is sampled from ``x0`` with
    ``model.sample_xts_from_x0``. For every scheduler timestep the model's noise
    prediction at ``xts[idx + 1]`` (classifier-free guided when a prompt is
    given) is passed to ``model.get_zs_from_xts``. That call returns the noise
    map stored in ``zs[idx]`` and a value that overwrites ``xts[idx]``.

    Args:
        model: Pipeline wrapper providing text encoding, the denoiser
            (``unet_forward``), the scheduler and the inversion helpers.
        x0: Clean input to invert (whatever ``model.sample_xts_from_x0`` accepts).
        etas: Per-step eta values, or a single number broadcast to every
            scheduler step. Despite the ``None`` default, a value is required.
        prompts: Source prompts. ``[""]`` means unconditional inversion.
        cfg_scales: Guidance scales; only ``cfg_scales[0]`` is read.
        num_inference_steps: Number of inversion steps. Used to sample the
            trajectory, size ``zs`` and map timesteps to indices.
        numerical_fix: Forwarded to ``model.get_zs_from_xts``.
        duration: Forwarded to ``model.setup_extra_inputs`` as ``audio_end_in_s``.
        first_order: Forwarded to ``model.get_zs_from_xts``.
        save_compute: When the first prompt is non-empty, evaluate the
            unconditional and conditional branches in a single batched call.
        progress: Gradio progress tracker, advanced alongside the console bar.

    Returns:
        ``(xt, zs, xts, extra_info)``:
            - the ``x_t`` used in the final loop iteration;
            - the noise maps, with ``zs[0]`` zeroed;
            - the trajectory, with each visited ``xts[idx]`` replaced by
              ``get_zs_from_xts``;
            - the per-step extra outputs of ``get_zs_from_xts``.

    Raises:
        ValueError: If ``etas`` is ``None``.
    """
    if etas is None:
        # etas[idx] is read on every step; fail up front with a clear message
        # instead of a TypeError from inside the denoising loop.
        raise ValueError("etas must be a number or a per-step sequence, not None")

    # Any non-empty prompt (or several prompts) enables the conditional branch and CFG.
    use_text = len(prompts) > 1 or prompts[0] != ""
    # The batched uncond+cond pass is keyed on the first prompt only.
    batched = save_compute and prompts[0] != ""

    if use_text:
        text_embeddings_hidden_states, text_embeddings_class_labels, \
            text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
        cond_length = (text_embeddings_class_labels.shape[1]
                       if text_embeddings_class_labels is not None else None)
        # Negative prompts are not supported in the forward pass yet (TODO):
        # the unconditional branch always encodes the empty prompt.
        uncond_embeddings_hidden_states, uncond_embeddings_class_labels, \
            uncond_boolean_prompt_mask = model.encode_text(
                [""], negative=True, save_compute=save_compute, cond_length=cond_length)
    else:
        uncond_embeddings_hidden_states, uncond_embeddings_class_labels, \
            uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=False)

    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
    if isinstance(etas, (int, float)):
        # Broadcast a scalar eta to every scheduler step.
        etas = [etas] * model.model.scheduler.num_inference_steps

    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)

    # Map each timestep, as a hashable Python scalar, to its position in the
    # schedule. The conversion follows the tensor's dtype, so any integer or
    # floating schedule works (not only int64/float32).
    to_key = float if timesteps.is_floating_point() else int
    t_to_idx = {to_key(v): k for k, v in enumerate(timesteps)}

    xt = x0
    op = tqdm(timesteps, desc="Inverting")
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
                             save_compute=batched)
    app_op = progress.tqdm(timesteps, desc="Inverting")
    try:
        for t, _ in zip(op, app_op):
            idx = num_inference_steps - t_to_idx[to_key(t)] - 1

            # 1. Predict the noise residual at x_t.
            xt = xts[idx + 1][None]
            xt_inp = model.model.scheduler.scale_model_input(xt, t)
            with torch.no_grad():
                if batched:
                    # Unconditional and conditional inputs in one batch of two.
                    comb_out, _, _ = model.unet_forward(
                        xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet')
                        else xt_inp.expand(2, -1, -1),
                        timestep=t,
                        encoder_hidden_states=torch.cat(
                            [uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0)
                        if uncond_embeddings_hidden_states is not None else None,
                        class_labels=torch.cat(
                            [uncond_embeddings_class_labels, text_embeddings_class_labels], dim=0)
                        if uncond_embeddings_class_labels is not None else None,
                        encoder_attention_mask=torch.cat(
                            [uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0)
                        if uncond_boolean_prompt_mask is not None else None,
                    )
                    out, cond_out = comb_out.sample.chunk(2, dim=0)
                else:
                    out = model.unet_forward(
                        xt_inp, timestep=t,
                        encoder_hidden_states=uncond_embeddings_hidden_states,
                        class_labels=uncond_embeddings_class_labels,
                        encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                    if use_text:
                        cond_out = model.unet_forward(
                            xt_inp, timestep=t,
                            encoder_hidden_states=text_embeddings_hidden_states,
                            class_labels=text_embeddings_class_labels,
                            encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

            # 2. Classifier-free guidance (guided difference summed over dim 0).
            if use_text:
                noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
            else:
                noise_pred = out

            # 3. Recover the noise map for this step; the helper also returns the
            #    value that replaces xts[idx] in the trajectory.
            xtm1 = xts[idx][None]
            z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
                                                   eta=etas[idx], numerical_fix=numerical_fix,
                                                   first_order=first_order)
            zs[idx] = z
            xts[idx] = xtm1
            extra_info[idx] = extra
    finally:
        op.close()
        # progress.tqdm may not register an entry (presumably when no Gradio
        # progress callback is active), so only remove one if it exists.
        # This also runs on errors, so the tracker is never left stale.
        if app_op.iterables:
            del app_op.iterables[0]

    zs[0] = torch.zeros_like(zs[0])
    return xt, zs, xts, extra_info
def inversion_reverse_process(model: PipelineWrapper,
                              xT: torch.Tensor,
                              tstart: torch.Tensor,
                              etas: float = 0,
                              prompts: List[str] = [""],
                              neg_prompts: List[str] = [""],
                              cfg_scales: Optional[List[float]] = None,
                              zs: Optional[torch.Tensor] = None,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              extra_info: Optional[List] = None,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:
    """Regenerate a sample from stored noise maps, guided by new prompts.

    Starting from ``xT[tstart.max()]``, the last ``zs.shape[0]`` scheduler
    timesteps are denoised. Each step uses classifier-free guidance between
    ``neg_prompts`` and ``prompts``. It then injects that step's stored noise
    map via ``model.reverse_step_with_custom_noise``.

    Args:
        model: Pipeline wrapper providing text encoding, the denoiser
            (``unet_forward``), the scheduler and the reverse-step helper.
        xT: Indexed by ``tstart.max()`` to pick the starting sample
            (presumably the ``xts`` trajectory from the forward process).
        tstart: Tensor whose maximum selects the starting entry of ``xT``.
        etas: Per-step eta values or a single number broadcast to every
            scheduler step; ``None`` is treated as ``0``.
        prompts: Target prompts (conditional branch).
        neg_prompts: Prompts for the unconditional/negative branch.
        cfg_scales: Guidance scales; only ``cfg_scales[0]`` is read. Required.
        zs: Noise maps indexed along dim 0, one per editing step. Required
            and non-empty.
        duration: Forwarded to ``model.setup_extra_inputs`` as ``audio_end_in_s``.
        first_order: Forwarded to ``model.reverse_step_with_custom_noise``.
        extra_info: Forwarded to ``model.setup_extra_inputs``.
        save_compute: Evaluate the unconditional and conditional branches in a
            single batched call.
        progress: Gradio progress tracker, advanced alongside the console bar.

    Returns:
        ``(xt, zs)``: the final sample and the ``zs`` argument as passed in.

    Raises:
        ValueError: If ``zs`` is ``None`` or empty, or ``cfg_scales`` is ``None``.
        AssertionError: If ``etas`` does not have one entry per scheduler step.
    """
    # Both are indexed unconditionally below; reject them before any model work.
    # An empty zs would make the ``[-0:]`` slice select the whole schedule.
    if zs is None or zs.shape[0] == 0:
        raise ValueError("zs must be a non-empty tensor of noise maps (one per editing step)")
    if cfg_scales is None:
        raise ValueError("cfg_scales must be provided")

    text_embeddings_hidden_states, text_embeddings_class_labels, \
        text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
    cond_length = (text_embeddings_class_labels.shape[1]
                   if text_embeddings_class_labels is not None else None)
    uncond_embeddings_hidden_states, uncond_embeddings_class_labels, \
        uncond_boolean_prompt_mask = model.encode_text(neg_prompts,
                                                       negative=True,
                                                       save_compute=save_compute,
                                                       cond_length=cond_length)

    xt = xT[tstart.max()].unsqueeze(0)

    if etas is None:
        etas = 0
    if isinstance(etas, (int, float)):
        # Broadcast a scalar eta to every scheduler step.
        etas = [etas] * model.model.scheduler.num_inference_steps
    assert len(etas) == model.model.scheduler.num_inference_steps

    timesteps = model.model.scheduler.timesteps.to(model.device)
    num_edit_steps = zs.shape[0]
    # Only the tail of the schedule, one timestep per stored noise map, is denoised.
    edit_timesteps = timesteps[-num_edit_steps:]
    # Hashable Python-scalar keys; the conversion follows the tensor's dtype so
    # any integer or floating schedule works (not only int64/float32).
    to_key = float if timesteps.is_floating_point() else int
    t_to_idx = {to_key(v): k for k, v in enumerate(edit_timesteps)}

    op = tqdm(edit_timesteps, desc="Editing")
    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-num_edit_steps],
                             audio_end_in_s=duration, save_compute=save_compute)
    app_op = progress.tqdm(edit_timesteps, desc="Editing")
    try:
        for t, _ in zip(op, app_op):
            # Editing step k reads zs / etas at index num_edit_steps - 1 - k,
            # matching the index the forward process used for the same timestep.
            idx = num_edit_steps - 1 - t_to_idx[to_key(t)]
            xt_inp = model.model.scheduler.scale_model_input(xt, t)
            with torch.no_grad():
                if save_compute:
                    # Unconditional and conditional inputs in one batch of two.
                    comb_out, _, _ = model.unet_forward(
                        xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet')
                        else xt_inp.expand(2, -1, -1),
                        timestep=t,
                        encoder_hidden_states=torch.cat(
                            [uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0)
                        if uncond_embeddings_hidden_states is not None else None,
                        class_labels=torch.cat(
                            [uncond_embeddings_class_labels, text_embeddings_class_labels], dim=0)
                        if uncond_embeddings_class_labels is not None else None,
                        encoder_attention_mask=torch.cat(
                            [uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0)
                        if uncond_boolean_prompt_mask is not None else None,
                    )
                    uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
                else:
                    uncond_out = model.unet_forward(
                        xt_inp, timestep=t,
                        encoder_hidden_states=uncond_embeddings_hidden_states,
                        class_labels=uncond_embeddings_class_labels,
                        encoder_attention_mask=uncond_boolean_prompt_mask,
                    )[0].sample
                    cond_out = model.unet_forward(
                        xt_inp, timestep=t,
                        encoder_hidden_states=text_embeddings_hidden_states,
                        class_labels=text_embeddings_class_labels,
                        encoder_attention_mask=text_embeddings_boolean_prompt_mask,
                    )[0].sample

            z = zs[idx].unsqueeze(0)
            # Classifier-free guidance (guided difference summed over dim 0).
            noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
            # Compute the less noisy sample, injecting this step's stored noise map.
            xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
                                                      eta=etas[idx], first_order=first_order)
    finally:
        op.close()
        # progress.tqdm may not register an entry (presumably when no Gradio
        # progress callback is active), so only remove one if it exists.
        # This also runs on errors, so the tracker is never left stale.
        if app_op.iterables:
            del app_op.iterables[0]
    return xt, zs