# STAR/video_to_video/diffusion/diffusion_sdedit.py
import random
import torch
from .schedules_sdedit import karras_schedule
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
from video_to_video.utils.logger import get_logger
logger = get_logger()
__all__ = ['GaussianDiffusion']
def _i(tensor, t, x):
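    """Index a 1-D schedule tensor at timesteps `t` and reshape the result to
    shape (x.size(0), 1, ..., 1) on x's device, ready for broadcasting."""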
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t.to(tensor.device)].view(shape).to(x.device)
class GaussianDiffusion(object):
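    """Gaussian diffusion with a variance-preserving schedule.

    `sigmas` is a 1-D tensor of per-timestep noise levels and alphas are derived
    as sqrt(1 - sigmas**2), so alpha_t**2 + sigma_t**2 == 1 at every timestep.
    """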
def __init__(self, sigmas):
self.sigmas = sigmas
self.alphas = torch.sqrt(1 - sigmas**2)
self.num_timesteps = len(sigmas)
def diffuse(self, x0, t, noise=None):
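        """Forward diffusion: x_t = alpha_t * x_0 + sigma_t * noise."""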
noise = torch.randn_like(x0) if noise is None else noise
xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
return xt
def get_velocity(self, x0, xt, t):
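        """v-prediction target: v = (alpha_t * x_t - x_0) / sigma_t, which equals
        alpha_t * eps - sigma_t * x_0 under alpha_t**2 + sigma_t**2 = 1."""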
sigmas = _i(self.sigmas, t, xt)
alphas = _i(self.alphas, t, xt)
velocity = (alphas * xt - x0) / sigmas
return velocity
def get_x0(self, v, xt, t):
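        """Recover x_0 from a v-prediction: x_0 = alpha_t * x_t - sigma_t * v."""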
sigmas = _i(self.sigmas, t, xt)
alphas = _i(self.alphas, t, xt)
x0 = alphas * xt - sigmas * v
return x0
def denoise(self,
xt,
t,
s,
model,
model_kwargs={},
guide_scale=None,
guide_rescale=None,
clamp=None,
percentile=None,
variant_info=None,):
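        """Run the model once at timestep t and compute one reverse-step's quantities.

        Supports classifier-free guidance (model_kwargs passed as a list of kwarg
        dicts), optional guidance rescaling, and restriction of the predicted x0
        either to a fixed range (`clamp`) or by per-sample dynamic thresholding
        (`percentile`). Returns (mu, var, log_var, x0, eps) for the posterior
        q(x_s | x_t, x_0), with s defaulting to t - 1.
        """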
s = t - 1 if s is None else s
# hyperparams
sigmas = _i(self.sigmas, t, xt)
alphas = _i(self.alphas, t, xt)
alphas_s = _i(self.alphas, s.clamp(0), xt)
alphas_s[s < 0] = 1.
sigmas_s = torch.sqrt(1 - alphas_s**2)
# precompute variables
betas = 1 - (alphas / alphas_s)**2
coef1 = betas * alphas_s / sigmas**2
coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
var = betas * (sigmas_s / sigmas)**2
log_var = torch.log(var).clamp_(-20, 20)
# prediction
if guide_scale is None:
assert isinstance(model_kwargs, dict)
out = model(xt, t=t, **model_kwargs)
else:
# classifier-free guidance
assert isinstance(model_kwargs, list)
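            # model_kwargs as a list of kwarg dicts: [0] conditional, [1] unconditional,
            # [2] shared; the long form below appears to expect six entries ([3]..[5])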
if len(model_kwargs) > 3:
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
else:
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], variant_info=variant_info)
if guide_scale == 1.:
out = y_out
else:
if len(model_kwargs) > 3:
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
else:
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], variant_info=variant_info)
out = u_out + guide_scale * (y_out - u_out)
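                # guidance rescaling (cf. the CFG-rescale trick, arXiv:2305.08891):
                # pull the guided output's per-sample std back toward that of y_out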
if guide_rescale is not None:
assert guide_rescale >= 0 and guide_rescale <= 1
ratio = (
y_out.flatten(1).std(dim=1) / # noqa
(out.flatten(1).std(dim=1) + 1e-12)
).view((-1, ) + (1, ) * (y_out.ndim - 1))
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
x0 = alphas * xt - sigmas * out
# restrict the range of x0
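        # percentile: per-sample dynamic thresholding (clamp x0 to its |x0| quantile,
        # then rescale into [-1, 1]); clamp: plain symmetric clipping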
if percentile is not None:
assert percentile > 0 and percentile <= 1
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
x0 = torch.min(s, torch.max(-s, x0)) / s
elif clamp is not None:
x0 = x0.clamp(-clamp, clamp)
# recompute eps using the restricted x0
eps = (xt - alphas * x0) / sigmas
# compute mu (mean of posterior distribution) using the restricted x0
mu = coef1 * x0 + coef2 * xt
return mu, var, log_var, x0, eps
@torch.no_grad()
def sample(self,
noise,
model,
model_kwargs={},
condition_fn=None,
guide_scale=None,
guide_rescale=None,
clamp=None,
percentile=None,
solver='euler_a',
solver_mode='fast',
steps=20,
t_max=None,
t_min=None,
discretization=None,
discard_penultimate_step=None,
return_intermediate=None,
show_progress=False,
seed=-1,
chunk_inds=None,
**kwargs):
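        """Sample from noise with a k-diffusion style solver.

        Builds a timestep discretization ('leading', 'linspace' or 'trailing';
        for 'trailing' with solver_mode == 'fast' a hard-coded coarse-then-fine
        split around t = 500 is used instead), converts it to continuous sigmas,
        and runs the selected solver. Only 'heun' and 'dpmpp_2m_sde' are registered
        in `solver_fn`. When `chunk_inds` is given, the video is denoised in
        overlapping temporal chunks and stitched back together. Returns x0, or
        (x0, intermediates) when `return_intermediate` is set.
        """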
# sanity check
assert isinstance(steps, (int, torch.LongTensor))
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
assert discretization in (None, 'leading', 'linspace', 'trailing')
assert discard_penultimate_step in (None, True, False)
assert return_intermediate in (None, 'x0', 'xt')
# function of diffusion solver
solver_fn = {
'heun': sample_heun,
'dpmpp_2m_sde': sample_dpmpp_2m_sde
}[solver]
# options
schedule = 'karras' if 'karras' in solver else None
discretization = discretization or 'linspace'
seed = seed if seed >= 0 else random.randint(0, 2**31)
if isinstance(steps, torch.LongTensor):
discard_penultimate_step = False
if discard_penultimate_step is None:
discard_penultimate_step = True if solver in (
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
# function for denoising xt to get x0
intermediates = []
def model_fn(xt, sigma):
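            # solver callback: map the continuous sigma back to a discrete timestep,
            # denoise once, and return the predicted x0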
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile)[-2]
# collect intermediate outputs
if return_intermediate == 'xt':
intermediates.append(xt)
elif return_intermediate == 'x0':
intermediates.append(x0)
return x0
mask_cond = model_kwargs[3]['mask_cond']
def model_chunk_fn(xt, sigma):
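            # denoise the video in overlapping temporal chunks to bound memory,
            # slicing mask_cond per chunk; the overlapping halves of neighbouring
            # chunks are dropped when the results are stitched back together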
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
cut_f_ind = O_LEN//2
results_list = []
for i in range(len(chunk_inds)):
ind_start, ind_end = chunk_inds[i]
xt_chunk = xt[:,:,ind_start:ind_end].clone()
cur_f = xt_chunk.size(2)
model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile)[-2]
if i == 0:
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
elif i == len(chunk_inds)-1:
results_list.append(x0_chunk[:,:,cut_f_ind:])
else:
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
x0 = torch.concat(results_list, dim=2)
torch.cuda.empty_cache()
return x0
# get timesteps
if isinstance(steps, int):
steps += 1 if discard_penultimate_step else 0
t_max = self.num_timesteps - 1 if t_max is None else t_max
t_min = 0 if t_min is None else t_min
# discretize timesteps
if discretization == 'leading':
steps = torch.arange(t_min, t_max + 1,
(t_max - t_min + 1) / steps).flip(0)
elif discretization == 'linspace':
steps = torch.linspace(t_max, t_min, steps)
elif discretization == 'trailing':
steps = torch.arange(t_max, t_min - 1,
-((t_max - t_min + 1) / steps))
if solver_mode == 'fast':
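                    # 'fast' mode replaces the trailing schedule with a coarse-then-fine
                    # split: a few large steps down to t_mid, then finer steps to t_min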
t_mid = 500
steps1 = torch.arange(t_max, t_mid - 1,
-((t_max - t_mid + 1) / 4))
steps2 = torch.arange(t_mid, t_min - 1,
-((t_mid - t_min + 1) / 11))
steps = torch.concat([steps1, steps2])
else:
raise NotImplementedError(
f'{discretization} discretization not implemented')
steps = steps.clamp_(t_min, t_max)
steps = torch.as_tensor(
steps, dtype=torch.float32, device=noise.device)
# get sigmas
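        # map discrete timesteps to the solvers' continuous noise scale
        # sigma_t / alpha_t, then append a terminal sigma of 0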
sigmas = self._t_to_sigma(steps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if schedule == 'karras':
if sigmas[0] == float('inf'):
sigmas = karras_schedule(
n=len(steps) - 1,
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas[sigmas < float('inf')].max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([
sigmas.new_tensor([float('inf')]), sigmas,
sigmas.new_zeros([1])
])
else:
sigmas = karras_schedule(
n=len(steps),
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas.max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if discard_penultimate_step:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
fn = model_chunk_fn if chunk_inds is not None else model_fn
x0 = solver_fn(
noise, fn, sigmas, show_progress=show_progress, **kwargs)
return (x0, intermediates) if return_intermediate is not None else x0
@torch.no_grad()
def sample_sr(self,
noise,
model,
model_kwargs={},
condition_fn=None,
guide_scale=None,
guide_rescale=None,
clamp=None,
percentile=None,
solver='euler_a',
solver_mode='fast',
steps=20,
t_max=None,
t_min=None,
discretization=None,
discard_penultimate_step=None,
return_intermediate=None,
show_progress=False,
seed=-1,
chunk_inds=None,
variant_info=None,
**kwargs):
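        """Super-resolution variant of `sample`.

        Uses the same schedule construction and solvers as `sample`, but the
        chunked path writes a per-chunk slice of the `hint` conditioning into
        model_kwargs[2]['hint_chunk'], and `variant_info` is threaded through
        both the denoiser and the solver.
        """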
# sanity check
assert isinstance(steps, (int, torch.LongTensor))
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
assert discretization in (None, 'leading', 'linspace', 'trailing')
assert discard_penultimate_step in (None, True, False)
assert return_intermediate in (None, 'x0', 'xt')
# function of diffusion solver
solver_fn = {
'heun': sample_heun,
'dpmpp_2m_sde': sample_dpmpp_2m_sde
}[solver]
# options
schedule = 'karras' if 'karras' in solver else None
discretization = discretization or 'linspace'
seed = seed if seed >= 0 else random.randint(0, 2**31)
if isinstance(steps, torch.LongTensor):
discard_penultimate_step = False
if discard_penultimate_step is None:
discard_penultimate_step = True if solver in (
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
# function for denoising xt to get x0
intermediates = []
def model_fn(xt, sigma, variant_info=None):
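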
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
# collect intermediate outputs
if return_intermediate == 'xt':
intermediates.append(xt)
elif return_intermediate == 'x0':
                logger.info('add intermediate outputs x0')
intermediates.append(x0)
return x0
# mask_cond = model_kwargs[3]['mask_cond']
def model_chunk_fn(xt, sigma, variant_info=None):
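            # chunked denoising as in `sample`, but with a per-chunk slice of the
            # `hint` conditioning written into model_kwargs[2]['hint_chunk']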
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
cut_f_ind = O_LEN//2
results_list = []
for i in range(len(chunk_inds)):
ind_start, ind_end = chunk_inds[i]
xt_chunk = xt[:,:,ind_start:ind_end].clone()
model_kwargs[2]['hint_chunk'] = model_kwargs[2]['hint'][:,:,ind_start:ind_end].clone() # new added
cur_f = xt_chunk.size(2)
# model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
if i == 0:
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
elif i == len(chunk_inds)-1:
results_list.append(x0_chunk[:,:,cut_f_ind:])
else:
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
x0 = torch.concat(results_list, dim=2)
torch.cuda.empty_cache()
return x0
# get timesteps
if isinstance(steps, int):
steps += 1 if discard_penultimate_step else 0
t_max = self.num_timesteps - 1 if t_max is None else t_max
t_min = 0 if t_min is None else t_min
# discretize timesteps
if discretization == 'leading':
steps = torch.arange(t_min, t_max + 1,
(t_max - t_min + 1) / steps).flip(0)
elif discretization == 'linspace':
steps = torch.linspace(t_max, t_min, steps)
elif discretization == 'trailing':
steps = torch.arange(t_max, t_min - 1,
-((t_max - t_min + 1) / steps))
if solver_mode == 'fast':
t_mid = 500
steps1 = torch.arange(t_max, t_mid - 1,
-((t_max - t_mid + 1) / 4))
steps2 = torch.arange(t_mid, t_min - 1,
-((t_mid - t_min + 1) / 11))
steps = torch.concat([steps1, steps2])
else:
raise NotImplementedError(
f'{discretization} discretization not implemented')
steps = steps.clamp_(t_min, t_max)
steps = torch.as_tensor(
steps, dtype=torch.float32, device=noise.device)
# get sigmas
sigmas = self._t_to_sigma(steps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if schedule == 'karras':
if sigmas[0] == float('inf'):
sigmas = karras_schedule(
n=len(steps) - 1,
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas[sigmas < float('inf')].max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([
sigmas.new_tensor([float('inf')]), sigmas,
sigmas.new_zeros([1])
])
else:
sigmas = karras_schedule(
n=len(steps),
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas.max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if discard_penultimate_step:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
fn = model_chunk_fn if chunk_inds is not None else model_fn
x0 = solver_fn(
noise, fn, sigmas, variant_info=variant_info, show_progress=show_progress, **kwargs)
return (x0, intermediates) if return_intermediate is not None else x0
def _sigma_to_t(self, sigma):
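        """Invert `_t_to_sigma`: find the (fractional) timestep whose continuous
        noise scale matches `sigma`, interpolating in log-sigma space."""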
if sigma == float('inf'):
t = torch.full_like(sigma, len(self.sigmas) - 1)
else:
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
(1 - self.sigmas**2)).log().to(sigma)
log_sigma = sigma.log()
dists = log_sigma - log_sigmas[:, None]
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = log_sigmas[low_idx], log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
if t.ndim == 0:
t = t.unsqueeze(0)
return t
def _t_to_sigma(self, t):
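        """Map a (possibly fractional) timestep to the continuous noise scale
        sigma_t / alpha_t = sqrt(sigma_t**2 / (1 - sigma_t**2)) used by the solvers,
        interpolating linearly in log space between neighbouring timesteps."""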
t = t.float()
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
(1 - self.sigmas**2)).log().to(t)
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
log_sigma[torch.isnan(log_sigma)
| torch.isinf(log_sigma)] = float('inf')
return log_sigma.exp()