# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch


def betas_to_sigmas(betas):
    """Convert per-step betas to cumulative noise levels.

    Computes ``sigma_t = sqrt(1 - prod_{i<=t}(1 - beta_i))`` along dim 0.
    """
    return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))


def sigmas_to_betas(sigmas):
    """Invert :func:`betas_to_sigmas`: recover per-step betas from sigmas.

    ``1 - sigma_t**2`` is the cumulative product of ``1 - beta``; dividing
    consecutive entries recovers each individual ``1 - beta_t``.
    """
    square_alphas = 1 - sigmas**2
    betas = 1 - torch.cat(
        [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
    return betas


def logsnrs_to_sigmas(logsnrs):
    """Map log-SNR values to sigmas via ``sigma**2 = sigmoid(-logsnr)``."""
    return torch.sqrt(torch.sigmoid(-logsnrs))


def sigmas_to_logsnrs(sigmas):
    """Invert :func:`logsnrs_to_sigmas`.

    From ``sigma**2 = sigmoid(-logsnr)`` it follows that
    ``logsnr = log((1 - sigma**2) / sigma**2)``. The ratio must be
    ``(1 - sigma**2) / sigma**2`` (not its reciprocal) so that
    ``sigmas_to_logsnrs(logsnrs_to_sigmas(x)) == x``.
    """
    square_sigmas = sigmas**2
    return torch.log((1 - square_sigmas) / square_sigmas)


def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
    """Return ``n`` log-SNR values running from ``logsnr_max`` to ``logsnr_min``.

    The values are ``-2 * log(tan(angle))`` where ``angle`` is linearly
    interpolated between the angles corresponding to the two end points.
    """
    t_min = math.atan(math.exp(-0.5 * logsnr_min))
    t_max = math.atan(math.exp(-0.5 * logsnr_max))
    # t runs 1 -> 0, so the first entry sits at t_max (i.e. logsnr_max).
    t = torch.linspace(1, 0, n)
    logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
    return logsnrs


def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
    """Cosine log-SNR curve shifted by the constant ``2 * log(1 / scale)``."""
    logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
    return logsnrs + 2 * math.log(1 / scale)


def _logsnr_cosine_interp(n,
                          logsnr_min=-15,
                          logsnr_max=15,
                          scale_min=2,
                          scale_max=4):
    """Blend two shifted cosine log-SNR curves.

    The weight ``t`` runs linearly from 1 to 0 over the ``n`` steps, so the
    result starts on the ``scale_min`` curve and ends on the ``scale_max``
    curve.
    """
    t = torch.linspace(1, 0, n)
    logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
    logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
    logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
    return logsnrs


def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Return ``n`` sigmas from a rho-warped ramp, squashed into [0, 1).

    The raw values are interpolated linearly in ``sigma**(1 / rho)`` space,
    starting at ``sigma_min`` and ending at ``sigma_max``, then mapped
    through ``sigma / sqrt(1 + sigma**2)``.
    """
    ramp = torch.linspace(1, 0, n)
    min_inv_rho = sigma_min**(1 / rho)
    max_inv_rho = sigma_max**(1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
    sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
    return sigmas


def logsnr_cosine_interp_schedule(n,
                                  logsnr_min=-15,
                                  logsnr_max=15,
                                  scale_min=2,
                                  scale_max=4):
    """Sigmas for the interpolated shifted-cosine log-SNR schedule."""
    return logsnrs_to_sigmas(
        _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min,
                              scale_max))


def noise_schedule(schedule='logsnr_cosine_interp',
                   n=1000,
                   zero_terminal_snr=False,
                   **kwargs):
    """Build a 1-D tensor of ``n`` sigmas for the named schedule.

    Args:
        schedule: Name of the schedule; currently only
            ``'logsnr_cosine_interp'`` is registered.
        n: Number of steps.
        zero_terminal_snr: If true and the largest sigma is not exactly 1,
            linearly rescale the sigmas so the minimum is kept and the
            maximum becomes 1.
        **kwargs: Forwarded to the selected schedule function.

    Returns:
        The tensor of sigmas.

    Raises:
        KeyError: If ``schedule`` is not a registered schedule name.
    """
    schedules = {'logsnr_cosine_interp': logsnr_cosine_interp_schedule}
    try:
        schedule_fn = schedules[schedule]
    except KeyError:
        raise KeyError(f'Unknown schedule {schedule!r}; '
                       f'available: {sorted(schedules)}') from None

    # compute sigmas
    sigmas = schedule_fn(n, **kwargs)

    # post-processing
    if zero_terminal_snr and sigmas.max() != 1.0:
        lowest = sigmas.min()
        highest = sigmas.max()
        scale = (1.0 - lowest) / (highest - lowest)
        sigmas = lowest + scale * (sigmas - lowest)
    return sigmas