# Source: STAR/video_to_video/diffusion/schedules_sdedit.py
# (commit f0e9666, "Add application file", by xierui.0097; 2.64 kB)
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
def betas_to_sigmas(betas):
    """Turn per-step betas into cumulative noise levels.

    With alpha_bar_t = prod_{s<=t} (1 - beta_s) (cumulative product along
    dim 0), the returned value at step t is sqrt(1 - alpha_bar_t).
    """
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)
    noise_variance = 1 - alphas_cumprod
    return torch.sqrt(noise_variance)
def sigmas_to_betas(sigmas):
    """Recover per-step betas from cumulative noise levels.

    With alpha_bar_t = 1 - sigma_t**2, this returns
    beta_0 = 1 - alpha_bar_0 and beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    for t >= 1, computed along dim 0.
    """
    alphas_cumprod = 1 - sigmas**2
    # The first step has no predecessor, so its ratio is alpha_bar_0 itself.
    first = alphas_cumprod[:1]
    rest = alphas_cumprod[1:] / alphas_cumprod[:-1]
    step_ratios = torch.cat([first, rest])
    return 1 - step_ratios
def logsnrs_to_sigmas(logsnrs):
    """Map log signal-to-noise ratios to noise levels.

    Uses sigma**2 = sigmoid(-logsnr). This matches
    logsnr = log(alpha**2 / sigma**2) with alpha**2 + sigma**2 = 1, so a large
    log-SNR gives sigma near 0 and a very negative one gives sigma near 1.
    """
    noise_variance = torch.sigmoid(-logsnrs)
    return torch.sqrt(noise_variance)
def sigmas_to_logsnrs(sigmas):
    """Convert noise levels back to log signal-to-noise ratios.

    With alpha**2 = 1 - sigma**2, the log-SNR is log(alpha**2 / sigma**2).
    Small sigmas therefore give large positive values, and sigma**2 = 0.5
    gives 0. This is the inverse of ``logsnrs_to_sigmas``, which uses
    sigma**2 = sigmoid(-logsnr).

    Args:
        sigmas: Tensor of noise levels, presumably in (0, 1) (endpoints give
            +/-inf).

    Returns:
        Tensor of log-SNR values with the same shape as ``sigmas``.
    """
    square_sigmas = sigmas**2
    # Ratio is alpha^2 / sigma^2 (not sigma^2 / alpha^2). The reciprocal would
    # return the negated log-SNR and break the round trip with
    # logsnrs_to_sigmas.
    return torch.log((1 - square_sigmas) / square_sigmas)
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
t_min = math.atan(math.exp(-0.5 * logsnr_min))
t_max = math.atan(math.exp(-0.5 * logsnr_max))
t = torch.linspace(1, 0, n)
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
return logsnrs
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
    """Cosine log-SNR schedule shifted by a constant offset.

    Every value of ``_logsnr_cosine(n, logsnr_min, logsnr_max)`` is offset by
    2 * log(1 / scale). For scale > 1 this lowers the log-SNR, i.e. adds noise.
    """
    offset = 2 * math.log(1 / scale)
    base = _logsnr_cosine(n, logsnr_min, logsnr_max)
    return base + offset
def _logsnr_cosine_interp(n,
                          logsnr_min=-15,
                          logsnr_max=15,
                          scale_min=2,
                          scale_max=4):
    """Blend two shifted cosine log-SNR schedules step by step.

    The weight on the ``scale_min``-shifted schedule falls linearly from 1 at
    the first step to 0 at the last. The remaining weight goes to the
    ``scale_max``-shifted schedule.
    """
    shifted_lo = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
    shifted_hi = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
    weight = torch.linspace(1, 0, n)
    return weight * shifted_lo + (1 - weight) * shifted_hi
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Rho-spaced noise schedule (Karras et al., 2022), mapped to [0, 1).

    Noise levels are interpolated linearly in sigma**(1/rho) space, from
    ``sigma_min`` at the first step to ``sigma_max`` at the last, then raised
    back to the power ``rho``. Each value is then rescaled as
    s / sqrt(1 + s**2), which keeps the result below 1.
    """
    inv_rho = 1 / rho
    lo = sigma_min**inv_rho
    hi = sigma_max**inv_rho
    ramp = torch.linspace(1, 0, n)
    raw = (hi + ramp * (lo - hi))**rho
    raw_sq = raw**2
    return torch.sqrt(raw_sq / (1 + raw_sq))
def logsnr_cosine_interp_schedule(n,
                                  logsnr_min=-15,
                                  logsnr_max=15,
                                  scale_min=2,
                                  scale_max=4):
    """Noise levels (sigmas) of the interpolated, shifted cosine schedule.

    Builds the blended log-SNR curve with ``_logsnr_cosine_interp`` and
    converts it to sigmas with ``logsnrs_to_sigmas``.
    """
    blended = _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min,
                                    scale_max)
    return logsnrs_to_sigmas(blended)
def noise_schedule(schedule='logsnr_cosine_interp',
                   n=1000,
                   zero_terminal_snr=False,
                   **kwargs):
    """Build a tensor of ``n`` noise levels (sigmas) for a named schedule.

    Args:
        schedule: Schedule name. Supported values are
            ``'logsnr_cosine_interp'`` (default) and ``'karras'``.
        n: Number of timesteps.
        zero_terminal_snr: If True and the largest sigma is not already 1.0,
            apply an affine rescale. The smallest sigma stays fixed and the
            largest becomes exactly 1.0, i.e. zero SNR at that step.
        **kwargs: Forwarded unchanged to the selected schedule function.

    Returns:
        The sigma tensor produced by the schedule, optionally rescaled.

    Raises:
        KeyError: If ``schedule`` is not a supported name.
        ValueError: If ``zero_terminal_snr`` is requested but all sigmas are
            equal, so the rescale is undefined.
    """
    schedules = {
        'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
        'karras': karras_schedule,
    }
    try:
        schedule_fn = schedules[schedule]
    except KeyError:
        # Keep KeyError for existing callers, but say what is valid.
        raise KeyError(f'unknown noise schedule {schedule!r}; '
                       f'expected one of {sorted(schedules)}') from None
    sigmas = schedule_fn(n, **kwargs)

    # post-processing
    if zero_terminal_snr and sigmas.max() != 1.0:
        lo = sigmas.min()
        hi = sigmas.max()
        if hi == lo:
            # (1 - lo) / 0 would silently turn every sigma into NaN.
            raise ValueError(
                'zero_terminal_snr requires at least two distinct sigmas')
        # Affine map: lo stays at lo, hi is sent to exactly 1.0.
        scale = (1.0 - lo) / (hi - lo)
        sigmas = lo + scale * (sigmas - lo)
    return sigmas