# audioEditingFULL / inversion_utils.py
# Last commit (7c56def) by hilamanor:
#   "Stable Audio Open + progbars + mp3 + batched forward + cleanup"
import torch
from tqdm import tqdm
from typing import List, Optional, Tuple
from models import PipelineWrapper
import gradio as gr
def inversion_forward_process(model: PipelineWrapper,
                              x0: torch.Tensor,
                              etas: Optional[float] = None,
                              prompts: List[str] = [""],
                              cfg_scales: List[float] = [3.5],
                              num_inference_steps: int = 50,
                              numerical_fix: bool = False,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple:
    """Invert a latent ``x0`` into per-timestep noise maps ``zs``.

    A noisy trajectory ``xts`` is first sampled from ``x0`` with
    ``model.sample_xts_from_x0``. The scheduler's timesteps are then walked in
    order. At each step the denoiser is evaluated at ``xts[idx + 1]``, optionally
    with classifier-free guidance from ``prompts``. That prediction and
    ``xts[idx]`` are handed to ``model.get_zs_from_xts``, which returns the
    step's noise map ``z``, a (possibly adjusted) ``x_{t-1}`` that replaces
    ``xts[idx]``, and an extra per-step value.

    Args:
        model: Pipeline wrapper providing the scheduler, text encoding
            (``encode_text``) and the denoiser (``unet_forward``).
        x0: Latent to invert.
        etas: Per-step eta values, or a single int/float broadcast to every
            scheduler step. ``None`` is not handled here: it fails when
            ``etas[idx]`` is indexed, so callers must pass a value.
        prompts: Source prompt(s). ``[""]`` means unconditional inversion.
        cfg_scales: Guidance scale. Only ``cfg_scales[0]`` is used.
        num_inference_steps: Number of inversion steps. Presumably must match
            the scheduler's configured step count (TODO confirm), since
            ``etas`` is broadcast using the scheduler's value.
        numerical_fix: Forwarded to ``model.get_zs_from_xts``.
        duration: Forwarded to ``model.setup_extra_inputs`` as ``audio_end_in_s``.
        first_order: Forwarded to ``model.get_zs_from_xts``.
        save_compute: When true and ``prompts[0]`` is non-empty, the
            unconditional and conditional passes run as one batch of two.
        progress: Gradio progress tracker advanced alongside the tqdm bar.

    Returns:
        ``(xt, zs, xts, extra_info)``:
        - ``xt`` is the latent fed to the denoiser on the final iteration.
        - ``zs`` are the noise maps, with ``zs[0]`` zeroed.
        - ``xts`` is the trajectory after in-place updates.
        - ``extra_info`` holds the per-step extras from ``get_zs_from_xts``.
    """
    # Conditional embeddings are only needed when some prompt is non-empty.
    has_prompt = len(prompts) > 1 or prompts[0] != ""
    # Use a single batched (uncond, cond) denoiser call. This is keyed on the
    # first prompt, so setup_extra_inputs and the forward pass agree.
    batched = save_compute and prompts[0] != ""

    if has_prompt:
        text_hidden, text_class_labels, text_mask = model.encode_text(prompts)
        # Negative prompts are not supported in the forward process yet (TODO).
        # The unconditional branch always uses the empty prompt.
        uncond_hidden, uncond_class_labels, uncond_mask = model.encode_text(
            [""], negative=True, save_compute=save_compute,
            cond_length=text_class_labels.shape[1] if text_class_labels is not None else None)
    else:
        uncond_hidden, uncond_class_labels, uncond_mask = model.encode_text(
            [""], negative=True, save_compute=False)

    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
    # Broadcast a scalar eta to every step. isinstance also accepts scalar
    # subclasses (e.g. numpy floats), which an exact type() check rejected.
    if isinstance(etas, (int, float)):
        etas = [etas] * model.model.scheduler.num_inference_steps

    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)

    # Map each timestep value to its position in the schedule. .item() yields a
    # Python int for integer schedules and a float for floating ones. This
    # matches int()/float() keys for int64/float32 and also covers other dtypes
    # (e.g. float64), for which the mapping used to be left undefined.
    t_to_idx = {v.item(): k for k, v in enumerate(timesteps)}

    xt = x0
    op = tqdm(timesteps, desc="Inverting")
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
                             save_compute=batched)
    app_op = progress.tqdm(timesteps, desc="Inverting")
    try:
        # zip() stops as soon as the tqdm bar is exhausted; the gradio
        # iterable only mirrors progress.
        for t, _ in zip(op, app_op):
            idx = num_inference_steps - t_to_idx[t.item()] - 1

            # 1. Predict the noise residual at the trajectory point x_t.
            xt = xts[idx + 1][None]
            xt_inp = model.model.scheduler.scale_model_input(xt, t)
            with torch.no_grad():
                if batched:
                    pair_inp = (xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet')
                                else xt_inp.expand(2, -1, -1))
                    comb_out, _, _ = model.unet_forward(
                        pair_inp,
                        timestep=t,
                        encoder_hidden_states=(torch.cat([uncond_hidden, text_hidden], dim=0)
                                               if uncond_hidden is not None else None),
                        class_labels=(torch.cat([uncond_class_labels, text_class_labels], dim=0)
                                      if uncond_class_labels is not None else None),
                        encoder_attention_mask=(torch.cat([uncond_mask, text_mask], dim=0)
                                                if uncond_mask is not None else None),
                    )
                    out, cond_out = comb_out.sample.chunk(2, dim=0)
                else:
                    out = model.unet_forward(xt_inp, timestep=t,
                                             encoder_hidden_states=uncond_hidden,
                                             class_labels=uncond_class_labels,
                                             encoder_attention_mask=uncond_mask)[0].sample
                    if has_prompt:
                        cond_out = model.unet_forward(
                            xt_inp,
                            timestep=t,
                            encoder_hidden_states=text_hidden,
                            class_labels=text_class_labels,
                            encoder_attention_mask=text_mask)[0].sample

            if has_prompt:
                # Classifier-free guidance. With several prompts, the guidance
                # terms are summed over the prompt batch.
                noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
            else:
                noise_pred = out

            # 2. Ask get_zs_from_xts for this step's noise map, given x_t, the
            #    sampled x_{t-1} and the noise prediction. It also returns a
            #    (possibly adjusted) x_{t-1}.
            xtm1 = xts[idx][None]
            z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
                                                   eta=etas[idx], numerical_fix=numerical_fix,
                                                   first_order=first_order)
            zs[idx] = z
            # The next iteration reads xts[idx] as its x_t, so storing the
            # returned x_{t-1} here propagates any adjustment down the chain.
            xts[idx] = xtm1
            extra_info[idx] = extra
    finally:
        # Because zip() stops on the tqdm bar first, the gradio tracker never
        # sees StopIteration, so its tracked iterable is dropped manually.
        # NOTE(review): relies on gr.Progress exposing ``iterables``. Guarded so
        # the call does not fail when nothing is being tracked (e.g. outside a
        # gradio event) or when an exception interrupted the loop.
        tracked = getattr(app_op, "iterables", None)
        if tracked:
            del tracked[0]

    # zs[0] is always returned as zeros.
    zs[0] = torch.zeros_like(zs[0])
    return xt, zs, xts, extra_info
def inversion_reverse_process(model: PipelineWrapper,
                              xT: torch.Tensor,
                              tstart: torch.Tensor,
                              etas: float = 0,
                              prompts: List[str] = [""],
                              neg_prompts: List[str] = [""],
                              cfg_scales: Optional[List[float]] = None,
                              zs: Optional[List[torch.Tensor]] = None,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              extra_info: Optional[List] = None,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:
    """Regenerate a latent from stored noise maps under new prompt guidance.

    Starting from ``xT[tstart.max()]``, the last ``zs.shape[0]`` scheduler
    timesteps are walked. At each step, the classifier-free-guided noise
    prediction and the stored noise map ``zs[idx]`` are passed to
    ``model.reverse_step_with_custom_noise`` to produce the next latent.

    Args:
        model: Pipeline wrapper providing the scheduler, text encoding
            (``encode_text``) and the denoiser (``unet_forward``).
        xT: Latents indexed by step. Only ``xT[tstart.max()]`` is read.
            Presumably the ``xts`` trajectory from the forward inversion
            (TODO confirm with callers).
        tstart: Tensor whose maximum selects the starting latent in ``xT``.
        etas: Per-step eta values, a single int/float broadcast to every
            scheduler step, or ``None`` (treated as 0). A sequence must have
            exactly one entry per scheduler inference step.
        prompts: Target prompt(s) for the conditional branch.
        neg_prompts: Prompt(s) encoded for the unconditional/negative branch.
        cfg_scales: Guidance scale. Only ``cfg_scales[0]`` is used. Required
            despite the ``None`` default.
        zs: Noise maps indexed by step. Must be a tensor (its ``.shape`` is
            read). Required despite the ``None`` default.
        duration: Forwarded to ``model.setup_extra_inputs`` as ``audio_end_in_s``.
        first_order: Forwarded to ``model.reverse_step_with_custom_noise``.
        extra_info: Forwarded to ``model.setup_extra_inputs``.
        save_compute: When true, the unconditional and conditional passes
            run as one batch of two.
        progress: Gradio progress tracker advanced alongside the tqdm bar.

    Returns:
        ``(xt, zs)``: the final latent and the ``zs`` that was passed in
        (not modified here).
    """
    text_hidden, text_class_labels, text_mask = model.encode_text(prompts)
    uncond_hidden, uncond_class_labels, uncond_mask = model.encode_text(
        neg_prompts,
        negative=True,
        save_compute=save_compute,
        cond_length=text_class_labels.shape[1] if text_class_labels is not None else None)

    xt = xT[tstart.max()].unsqueeze(0)

    num_steps = model.model.scheduler.num_inference_steps
    if etas is None:
        etas = 0
    # isinstance also accepts scalar subclasses (e.g. numpy floats), which an
    # exact type() check rejected.
    if isinstance(etas, (int, float)):
        etas = [etas] * num_steps
    assert len(etas) == num_steps

    timesteps = model.model.scheduler.timesteps.to(model.device)
    n_edit = zs.shape[0]
    # Only the last n_edit timesteps are denoised.
    edit_timesteps = timesteps[-n_edit:]
    op = tqdm(edit_timesteps, desc="Editing")
    # Timestep value -> position within edit_timesteps. .item() matches the
    # int()/float() keys for int64/float32 and also supports other dtypes.
    t_to_idx = {v.item(): k for k, v in enumerate(edit_timesteps)}
    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-n_edit],
                             audio_end_in_s=duration, save_compute=save_compute)
    app_op = progress.tqdm(edit_timesteps, desc="Editing")
    try:
        # zip() stops as soon as the tqdm bar is exhausted; the gradio
        # iterable only mirrors progress.
        for t, _ in zip(op, app_op):
            # Same value as num_steps - pos - (num_steps - n_edit + 1): the
            # first edit timestep uses the last stored noise map.
            idx = n_edit - 1 - t_to_idx[t.item()]
            xt_inp = model.model.scheduler.scale_model_input(xt, t)
            with torch.no_grad():
                if save_compute:
                    pair_inp = (xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet')
                                else xt_inp.expand(2, -1, -1))
                    comb_out, _, _ = model.unet_forward(
                        pair_inp,
                        timestep=t,
                        encoder_hidden_states=(torch.cat([uncond_hidden, text_hidden], dim=0)
                                               if uncond_hidden is not None else None),
                        class_labels=(torch.cat([uncond_class_labels, text_class_labels], dim=0)
                                      if uncond_class_labels is not None else None),
                        encoder_attention_mask=(torch.cat([uncond_mask, text_mask], dim=0)
                                                if uncond_mask is not None else None),
                    )
                    uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
                else:
                    # Unconditional (negative-prompt) pass.
                    uncond_out = model.unet_forward(
                        xt_inp, timestep=t,
                        encoder_hidden_states=uncond_hidden,
                        class_labels=uncond_class_labels,
                        encoder_attention_mask=uncond_mask,
                    )[0].sample
                    # Conditional (target-prompt) pass.
                    cond_out = model.unet_forward(
                        xt_inp,
                        timestep=t,
                        encoder_hidden_states=text_hidden,
                        class_labels=text_class_labels,
                        encoder_attention_mask=text_mask,
                    )[0].sample

            z = zs[idx].unsqueeze(0)
            # Classifier-free guidance. With several prompts, the guidance
            # terms are summed over the prompt batch.
            noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
            # Compute the less noisy latent, injecting the stored noise map.
            xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
                                                      eta=etas[idx], first_order=first_order)
    finally:
        # Because zip() stops on the tqdm bar first, the gradio tracker never
        # sees StopIteration, so its tracked iterable is dropped manually.
        # NOTE(review): relies on gr.Progress exposing ``iterables``. Guarded so
        # the call does not fail when nothing is being tracked or when an
        # exception interrupted the loop.
        tracked = getattr(app_op, "iterables", None)
        if tracked:
            del tracked[0]

    return xt, zs