import random

import torch

from .schedules_sdedit import karras_schedule
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
from video_to_video.utils.logger import get_logger

logger = get_logger()

__all__ = ['GaussianDiffusion']

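
# Gather the schedule values `tensor[t]` for a batch of timesteps `t` and
# reshape them to [batch, 1, ..., 1] so they broadcast against `x`.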
def _i(tensor, t, x):
    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
    return tensor[t.to(tensor.device)].view(shape).to(x.device)


class GaussianDiffusion(object):
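    """Gaussian diffusion in the variance-preserving parameterization.

    x_t = alpha_t * x0 + sigma_t * eps, with alpha_t = sqrt(1 - sigma_t**2)
    (so alpha_t**2 + sigma_t**2 = 1). `sigmas` is a 1-D tensor of noise levels
    indexed by integer timestep.
    """
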
    def __init__(self, sigmas):
        self.sigmas = sigmas
        self.alphas = torch.sqrt(1 - sigmas**2)
        self.num_timesteps = len(sigmas)

    def diffuse(self, x0, t, noise=None):
        noise = torch.randn_like(x0) if noise is None else noise
        xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
        return xt

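    # v-prediction target: v_t = alpha_t * eps - sigma_t * x0, which (using
    # alpha_t**2 + sigma_t**2 = 1) equals (alpha_t * x_t - x0) / sigma_t.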
    def get_velocity(self, x0, xt, t):
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        velocity = (alphas * xt - x0) / sigmas
        return velocity

    def get_x0(self, v, xt, t):
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        x0 = alphas * xt - sigmas * v
        return x0

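    # Predict x0 from x_t with the model (v-prediction), optionally applying
    # classifier-free guidance, then form the mean/variance of the posterior
    # transition from timestep t back to timestep s.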
    def denoise(self,
                xt,
                t,
                s,
                model,
                model_kwargs={},
                guide_scale=None,
                guide_rescale=None,
                clamp=None,
                percentile=None,
                variant_info=None):
        s = t - 1 if s is None else s

        # hyperparams
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s**2)

        # precompute variables
        betas = 1 - (alphas / alphas_s)**2
        coef1 = betas * alphas_s / sigmas**2
        coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
        var = betas * (sigmas_s / sigmas)**2
        log_var = torch.log(var).clamp_(-20, 20)

        # prediction
        if guide_scale is None:
            assert isinstance(model_kwargs, dict)
            out = model(xt, t=t, **model_kwargs)
        else:
            # classifier-free guidance
            assert isinstance(model_kwargs, list)
            if len(model_kwargs) > 3:
                y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2],
                              **model_kwargs[3], **model_kwargs[4],
                              **model_kwargs[5])
            else:
                y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2],
                              variant_info=variant_info)
            if guide_scale == 1.:
                out = y_out
            else:
                if len(model_kwargs) > 3:
                    u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2],
                                  **model_kwargs[3], **model_kwargs[4],
                                  **model_kwargs[5])
                else:
                    u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2],
                                  variant_info=variant_info)
                out = u_out + guide_scale * (y_out - u_out)
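                # Rescale the per-sample std of the guided output toward that
                # of the conditional branch to counteract over-saturation at
                # large guidance scales; guide_rescale interpolates between the
                # raw and fully rescaled result.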
                if guide_rescale is not None:
                    assert guide_rescale >= 0 and guide_rescale <= 1
                    ratio = (
                        y_out.flatten(1).std(dim=1) /  # noqa
                        (out.flatten(1).std(dim=1) + 1e-12)
                    ).view((-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

        x0 = alphas * xt - sigmas * out

        # restrict the range of x0
        if percentile is not None:
            assert percentile > 0 and percentile <= 1
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
            s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute eps using the restricted x0
        eps = (xt - alphas * x0) / sigmas

        # compute mu (mean of posterior distribution) using the restricted x0
        mu = coef1 * x0 + coef2 * xt
        return mu, var, log_var, x0, eps

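    # Sampling: build a sigma schedule from the requested timesteps and run the
    # selected solver ('heun' or 'dpmpp_2m_sde') on a wrapped denoiser that
    # returns the predicted x0 at each noise level.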
    def sample(self,
               noise,
               model,
               model_kwargs={},
               condition_fn=None,
               guide_scale=None,
               guide_rescale=None,
               clamp=None,
               percentile=None,
               solver='euler_a',
               solver_mode='fast',
               steps=20,
               t_max=None,
               t_min=None,
               discretization=None,
               discard_penultimate_step=None,
               return_intermediate=None,
               show_progress=False,
               seed=-1,
               chunk_inds=None,
               **kwargs):
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')

        # function of diffusion solver
        solver_fn = {
            'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }[solver]

        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        seed = seed if seed >= 0 else random.randint(0, 2**31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            discard_penultimate_step = True if solver in (
                'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
                'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False

        # function for denoising xt to get x0
        intermediates = []

        def model_fn(xt, sigma):
            # denoising
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
                              guide_rescale, clamp, percentile)[-2]

            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                intermediates.append(x0)
            return x0

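        # Chunked sampling expects the classifier-free-guidance list form of
        # model_kwargs; keep the full-length mask_cond so each chunk can slice
        # its own temporal window from it below.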
        mask_cond = model_kwargs[3]['mask_cond']

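        # Denoise long sequences chunk by chunk along the frame axis (dim 2);
        # overlapping frames between neighbouring chunks are trimmed before the
        # kept segments are concatenated back together.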
        def model_chunk_fn(xt, sigma):
            # denoising
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            O_LEN = chunk_inds[0][-1] - chunk_inds[1][0]
            cut_f_ind = O_LEN // 2
            results_list = []
            for i in range(len(chunk_inds)):
                ind_start, ind_end = chunk_inds[i]
                xt_chunk = xt[:, :, ind_start:ind_end].clone()
                cur_f = xt_chunk.size(2)
                model_kwargs[3]['mask_cond'] = mask_cond[:, ind_start:ind_end].clone()
                x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs,
                                        guide_scale, guide_rescale, clamp,
                                        percentile)[-2]
                if i == 0:
                    results_list.append(x0_chunk[:, :, :cur_f + cut_f_ind - O_LEN])
                elif i == len(chunk_inds) - 1:
                    results_list.append(x0_chunk[:, :, cut_f_ind:])
                else:
                    results_list.append(x0_chunk[:, :, cut_f_ind:cur_f + cut_f_ind - O_LEN])
            x0 = torch.concat(results_list, dim=2)
            torch.cuda.empty_cache()
            return x0

        # get timesteps
        if isinstance(steps, int):
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min

            # discretize timesteps
            if discretization == 'leading':
                steps = torch.arange(t_min, t_max + 1,
                                     (t_max - t_min + 1) / steps).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1,
                                     -((t_max - t_min + 1) / steps))
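                # 'fast' mode replaces the trailing schedule with ~15 steps:
                # 4 coarse steps above t_mid = 500 and 11 finer steps below it.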
                if solver_mode == 'fast':
                    t_mid = 500
                    steps1 = torch.arange(t_max, t_mid - 1,
                                          -((t_max - t_mid + 1) / 4))
                    steps2 = torch.arange(t_mid, t_min - 1,
                                          -((t_mid - t_min + 1) / 11))
                    steps = torch.concat([steps1, steps2])
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented')

            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(
            steps, dtype=torch.float32, device=noise.device)

        # get sigmas
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            if sigmas[0] == float('inf'):
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas,
                    sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        fn = model_chunk_fn if chunk_inds is not None else model_fn
        x0 = solver_fn(
            noise, fn, sigmas, show_progress=show_progress, **kwargs)
        return (x0, intermediates) if return_intermediate is not None else x0

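    # `sample_sr` mirrors `sample`, additionally threading `variant_info`
    # through the denoiser and, in chunked mode, slicing the 'hint'
    # conditioning per chunk ('hint_chunk') instead of mask_cond.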
    def sample_sr(self,
                  noise,
                  model,
                  model_kwargs={},
                  condition_fn=None,
                  guide_scale=None,
                  guide_rescale=None,
                  clamp=None,
                  percentile=None,
                  solver='euler_a',
                  solver_mode='fast',
                  steps=20,
                  t_max=None,
                  t_min=None,
                  discretization=None,
                  discard_penultimate_step=None,
                  return_intermediate=None,
                  show_progress=False,
                  seed=-1,
                  chunk_inds=None,
                  variant_info=None,
                  **kwargs):
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')

        # function of diffusion solver
        solver_fn = {
            'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }[solver]

        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        seed = seed if seed >= 0 else random.randint(0, 2**31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            discard_penultimate_step = True if solver in (
                'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
                'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False

        # function for denoising xt to get x0
        intermediates = []

        def model_fn(xt, sigma, variant_info=None):
            # denoising
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
                              guide_rescale, clamp, percentile,
                              variant_info=variant_info)[-2]

            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                logger.info('add intermediate outputs x0')
                intermediates.append(x0)
            return x0

        # mask_cond = model_kwargs[3]['mask_cond']
        def model_chunk_fn(xt, sigma, variant_info=None):
            # denoising
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            O_LEN = chunk_inds[0][-1] - chunk_inds[1][0]
            cut_f_ind = O_LEN // 2
            results_list = []
            for i in range(len(chunk_inds)):
                ind_start, ind_end = chunk_inds[i]
                xt_chunk = xt[:, :, ind_start:ind_end].clone()
                model_kwargs[2]['hint_chunk'] = model_kwargs[2]['hint'][:, :, ind_start:ind_end].clone()  # new added
                cur_f = xt_chunk.size(2)
                # model_kwargs[3]['mask_cond'] = mask_cond[:, ind_start:ind_end].clone()
                x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs,
                                        guide_scale, guide_rescale, clamp,
                                        percentile,
                                        variant_info=variant_info)[-2]
                if i == 0:
                    results_list.append(x0_chunk[:, :, :cur_f + cut_f_ind - O_LEN])
                elif i == len(chunk_inds) - 1:
                    results_list.append(x0_chunk[:, :, cut_f_ind:])
                else:
                    results_list.append(x0_chunk[:, :, cut_f_ind:cur_f + cut_f_ind - O_LEN])
            x0 = torch.concat(results_list, dim=2)
            torch.cuda.empty_cache()
            return x0

        # get timesteps
        if isinstance(steps, int):
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min

            # discretize timesteps
            if discretization == 'leading':
                steps = torch.arange(t_min, t_max + 1,
                                     (t_max - t_min + 1) / steps).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1,
                                     -((t_max - t_min + 1) / steps))
                if solver_mode == 'fast':
                    t_mid = 500
                    steps1 = torch.arange(t_max, t_mid - 1,
                                          -((t_max - t_mid + 1) / 4))
                    steps2 = torch.arange(t_mid, t_min - 1,
                                          -((t_mid - t_min + 1) / 11))
                    steps = torch.concat([steps1, steps2])
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented')

            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(
            steps, dtype=torch.float32, device=noise.device)

        # get sigmas
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            if sigmas[0] == float('inf'):
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas,
                    sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        fn = model_chunk_fn if chunk_inds is not None else model_fn
        x0 = solver_fn(
            noise, fn, sigmas, variant_info=variant_info,
            show_progress=show_progress, **kwargs)
        return (x0, intermediates) if return_intermediate is not None else x0

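    # Conversion between discrete timesteps and the sigmas consumed by the
    # solvers: sigma(t) = sigma_t / alpha_t = sqrt(sigma_t**2 / (1 - sigma_t**2)),
    # with linear interpolation in log-sigma for fractional timesteps.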
    def _sigma_to_t(self, sigma):
        if sigma == float('inf'):
            t = torch.full_like(sigma, len(self.sigmas) - 1)
        else:
            log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                    (1 - self.sigmas**2)).log().to(sigma)
            log_sigma = sigma.log()
            dists = log_sigma - log_sigmas[:, None]
            low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
                max=log_sigmas.shape[0] - 2)
            high_idx = low_idx + 1
            low, high = log_sigmas[low_idx], log_sigmas[high_idx]
            w = (low - log_sigma) / (low - high)
            w = w.clamp(0, 1)
            t = (1 - w) * low_idx + w * high_idx
            t = t.view(sigma.shape)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        return t

    def _t_to_sigma(self, t):
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                (1 - self.sigmas**2)).log().to(t)
        log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
        log_sigma[torch.isnan(log_sigma)
                  | torch.isinf(log_sigma)] = float('inf')
        return log_sigma.exp()
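

if __name__ == '__main__':
    # Minimal smoke test (not part of the original module; the shapes and the
    # linear sigma schedule below are illustrative assumptions only). Because
    # this file uses relative imports, run it in its package context, e.g.
    # `python -m <package>.<this_module>`.
    sigmas = torch.linspace(0.01, 0.99, 1000)
    diffusion = GaussianDiffusion(sigmas)

    x0 = torch.randn(2, 4, 8, 8)  # [batch, channels, height, width]
    t = torch.randint(0, diffusion.num_timesteps, (2, ))
    noise = torch.randn_like(x0)

    xt = diffusion.diffuse(x0, t, noise)   # x_t = alpha_t * x0 + sigma_t * eps
    v = diffusion.get_velocity(x0, xt, t)  # ground-truth velocity target
    x0_rec = diffusion.get_x0(v, xt, t)    # should recover x0 up to fp error

    logger.info(f'x0 reconstruction error: {(x0 - x0_rec).abs().max().item():.2e}')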