import random

import torch

from .schedules_sdedit import karras_schedule
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
from video_to_video.utils.logger import get_logger

logger = get_logger()

__all__ = ['GaussianDiffusion']


def _i(tensor, t, x):
    """Index `tensor` at timesteps `t`, reshaped to broadcast against `x`."""
    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
    return tensor[t.to(tensor.device)].view(shape).to(x.device)


class GaussianDiffusion(object):

    def __init__(self, sigmas):
        self.sigmas = sigmas  # noise coefficients
        self.alphas = torch.sqrt(1 - sigmas**2)  # signal coefficients
        self.num_timesteps = len(sigmas)

    def diffuse(self, x0, t, noise=None):
        """Sample from q(x_t | x_0): x_t = alpha_t * x_0 + sigma_t * noise."""
        noise = torch.randn_like(x0) if noise is None else noise
        xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
        return xt

    def get_velocity(self, x0, xt, t):
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        velocity = (alphas * xt - x0) / sigmas
        return velocity

    def get_x0(self, v, xt, t):
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        x0 = alphas * xt - sigmas * v
        return x0

    def denoise(self,
                xt,
                t,
                s,
                model,
                model_kwargs={},
                guide_scale=None,
                guide_rescale=None,
                clamp=None,
                percentile=None,
                variant_info=None):
        """One denoising step from t to s; returns the posterior statistics
        plus the (range-restricted) x0 and eps estimates."""
        s = t - 1 if s is None else s

        # hyperparams
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s**2)

        # precompute variables
        betas = 1 - (alphas / alphas_s)**2
        coef1 = betas * alphas_s / sigmas**2
        coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
        var = betas * (sigmas_s / sigmas)**2
        log_var = torch.log(var).clamp_(-20, 20)

        # prediction
        if guide_scale is None:
            assert isinstance(model_kwargs, dict)
            out = model(xt, t=t, **model_kwargs)
        else:
            # classifier-free guidance: model_kwargs[0] holds the conditional
            # inputs, model_kwargs[1] the unconditional ones, and the
            # remaining entries are shared
            assert isinstance(model_kwargs, list)
            if len(model_kwargs) > 3:
                y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2],
                              **model_kwargs[3], **model_kwargs[4],
                              **model_kwargs[5])
            else:
                y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2],
                              variant_info=variant_info)
            if guide_scale == 1.:
                out = y_out
            else:
                if len(model_kwargs) > 3:
                    u_out = model(xt, t=t, **model_kwargs[1],
                                  **model_kwargs[2], **model_kwargs[3],
                                  **model_kwargs[4], **model_kwargs[5])
                else:
                    u_out = model(xt, t=t, **model_kwargs[1],
                                  **model_kwargs[2],
                                  variant_info=variant_info)
                out = u_out + guide_scale * (y_out - u_out)

            if guide_rescale is not None:
                assert guide_rescale >= 0 and guide_rescale <= 1
                ratio = (
                    y_out.flatten(1).std(dim=1) /  # noqa
                    (out.flatten(1).std(dim=1) + 1e-12)).view(
                        (-1, ) + (1, ) * (y_out.ndim - 1))
                out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

        # compute x0 from the v-prediction
        x0 = alphas * xt - sigmas * out

        # restrict the range of x0
        if percentile is not None:
            assert percentile > 0 and percentile <= 1
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
            s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute eps using the restricted x0
        eps = (xt - alphas * x0) / sigmas

        # compute mu (mean of posterior distribution) using the restricted x0
        mu = coef1 * x0 + coef2 * xt
        return mu, var, log_var, x0, eps
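    # --- Hedged example (not from the original code): the v-prediction
    # identities above round-trip exactly, since
    #   get_x0(get_velocity(x0, xt, t), xt, t)
    #     = alpha_t * xt - (alpha_t * xt - x0) = x0.
    # A minimal check, assuming a toy linear sigma grid purely for
    # illustration:
    #
    #   sigmas = torch.linspace(0.01, 0.99, 1000)
    #   gd = GaussianDiffusion(sigmas)
    #   x0 = torch.randn(2, 4, 8, 8)
    #   t = torch.randint(0, gd.num_timesteps, (2, ))
    #   xt = gd.diffuse(x0, t)
    #   v = gd.get_velocity(x0, xt, t)
    #   assert torch.allclose(gd.get_x0(v, xt, t), x0, atol=1e-5)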
    @torch.no_grad()
    def sample(self,
               noise,
               model,
               model_kwargs={},
               condition_fn=None,
               guide_scale=None,
               guide_rescale=None,
               clamp=None,
               percentile=None,
               solver='euler_a',
               solver_mode='fast',
               steps=20,
               t_max=None,
               t_min=None,
               discretization=None,
               discard_penultimate_step=None,
               return_intermediate=None,
               show_progress=False,
               seed=-1,
               chunk_inds=None,
               **kwargs):
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0
                                 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0
                                 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')

        # function of diffusion solver
        solver_fn = {
            'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }[solver]

        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        seed = seed if seed >= 0 else random.randint(0, 2**31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            discard_penultimate_step = True if solver in (
                'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
                'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False

        # function for denoising xt to get x0
        intermediates = []

        def model_fn(xt, sigma):
            # denoising
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
                              guide_rescale, clamp, percentile)[-2]

            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                intermediates.append(x0)
            return x0

        mask_cond = model_kwargs[3]['mask_cond']

        def model_chunk_fn(xt, sigma):
            # denoise overlapping temporal chunks and keep only the
            # non-overlapped part of each, so every frame is produced once
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            O_LEN = chunk_inds[0][-1] - chunk_inds[1][0]  # overlap length
            cut_f_ind = O_LEN // 2
            results_list = []
            for i in range(len(chunk_inds)):
                ind_start, ind_end = chunk_inds[i]
                xt_chunk = xt[:, :, ind_start:ind_end].clone()
                cur_f = xt_chunk.size(2)
                model_kwargs[3]['mask_cond'] = mask_cond[:, ind_start:
                                                         ind_end].clone()
                x0_chunk = self.denoise(xt_chunk, t, None, model,
                                        model_kwargs, guide_scale,
                                        guide_rescale, clamp,
                                        percentile)[-2]
                if i == 0:
                    results_list.append(
                        x0_chunk[:, :, :cur_f + cut_f_ind - O_LEN])
                elif i == len(chunk_inds) - 1:
                    results_list.append(x0_chunk[:, :, cut_f_ind:])
                else:
                    results_list.append(
                        x0_chunk[:, :, cut_f_ind:cur_f + cut_f_ind - O_LEN])
            x0 = torch.cat(results_list, dim=2)
            torch.cuda.empty_cache()
            return x0

        # get timesteps
        if isinstance(steps, int):
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min

            # discretize timesteps
            if discretization == 'leading':
                steps = torch.arange(t_min, t_max + 1,
                                     (t_max - t_min + 1) / steps).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1,
                                     -((t_max - t_min + 1) / steps))
                if solver_mode == 'fast':
                    t_mid = 500
                    steps1 = torch.arange(t_max, t_mid - 1,
                                          -((t_max - t_mid + 1) / 4))
                    steps2 = torch.arange(t_mid, t_min - 1,
                                          -((t_mid - t_min + 1) / 11))
                    steps = torch.cat([steps1, steps2])
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented')
            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(
            steps, dtype=torch.float32, device=noise.device)

        # get sigmas
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            if sigmas[0] == float('inf'):
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas,
                    sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        # sample with the chunked denoiser when chunk_inds is given
        fn = model_chunk_fn if chunk_inds is not None else model_fn
        x0 = solver_fn(
            noise, fn, sigmas, show_progress=show_progress, **kwargs)
        return (x0, intermediates) if return_intermediate is not None else x0
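    # --- Hedged usage sketch (illustrative; names like `unet`, `text_emb`,
    # `null_emb`, `hint` and `mask` are placeholders, not from this repo).
    # As written, `sample` reads model_kwargs[3]['mask_cond'] up front, and
    # the classifier-free-guidance path with len(model_kwargs) > 3 unpacks
    # indices 0-5, so it effectively expects a six-element kwargs list:
    #
    #   kwargs_list = [
    #       {'y': text_emb},       # 0: conditional inputs
    #       {'y': null_emb},       # 1: unconditional inputs
    #       {'hint': hint},        # 2: shared inputs
    #       {'mask_cond': mask},   # 3: also read by the chunked path
    #       {},                    # 4-5: extra shared inputs
    #       {},
    #   ]
    #   x0 = gd.sample(torch.randn(1, 4, 16, 32, 32), unet,
    #                  model_kwargs=kwargs_list, solver='dpmpp_2m_sde',
    #                  solver_mode='fast', steps=15, guide_scale=7.5)
    #
    # Note the 'euler_a' default is absent from the solver_fn mapping above
    # and would raise a KeyError; callers must pass 'heun' or 'dpmpp_2m_sde'.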
    @torch.no_grad()
    def sample_sr(self,
                  noise,
                  model,
                  model_kwargs={},
                  condition_fn=None,
                  guide_scale=None,
                  guide_rescale=None,
                  clamp=None,
                  percentile=None,
                  solver='euler_a',
                  solver_mode='fast',
                  steps=20,
                  t_max=None,
                  t_min=None,
                  discretization=None,
                  discard_penultimate_step=None,
                  return_intermediate=None,
                  show_progress=False,
                  seed=-1,
                  chunk_inds=None,
                  variant_info=None,
                  **kwargs):
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0
                                 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0
                                 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')

        # function of diffusion solver
        solver_fn = {
            'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }[solver]

        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        seed = seed if seed >= 0 else random.randint(0, 2**31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            discard_penultimate_step = True if solver in (
                'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
                'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False

        # function for denoising xt to get x0
        intermediates = []

        def model_fn(xt, sigma, variant_info=None):
            # denoising
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
                              guide_rescale, clamp, percentile,
                              variant_info=variant_info)[-2]

            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                logger.info('add intermediate outputs x0')
                intermediates.append(x0)
            return x0

        def model_chunk_fn(xt, sigma, variant_info=None):
            # denoising, chunk by chunk along the frame axis
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            O_LEN = chunk_inds[0][-1] - chunk_inds[1][0]  # overlap length
            cut_f_ind = O_LEN // 2
            results_list = []
            for i in range(len(chunk_inds)):
                ind_start, ind_end = chunk_inds[i]
                xt_chunk = xt[:, :, ind_start:ind_end].clone()
                # slice the conditioning hint to the current window
                model_kwargs[2]['hint_chunk'] = model_kwargs[2][
                    'hint'][:, :, ind_start:ind_end].clone()
                cur_f = xt_chunk.size(2)
                x0_chunk = self.denoise(xt_chunk, t, None, model,
                                        model_kwargs, guide_scale,
                                        guide_rescale, clamp, percentile,
                                        variant_info=variant_info)[-2]
                if i == 0:
                    results_list.append(
                        x0_chunk[:, :, :cur_f + cut_f_ind - O_LEN])
                elif i == len(chunk_inds) - 1:
                    results_list.append(x0_chunk[:, :, cut_f_ind:])
                else:
                    results_list.append(
                        x0_chunk[:, :, cut_f_ind:cur_f + cut_f_ind - O_LEN])
            x0 = torch.cat(results_list, dim=2)
            torch.cuda.empty_cache()
            return x0

        # get timesteps
        if isinstance(steps, int):
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min

            # discretize timesteps
            if discretization == 'leading':
                steps = torch.arange(t_min, t_max + 1,
                                     (t_max - t_min + 1) / steps).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1,
                                     -((t_max - t_min + 1) / steps))
                if solver_mode == 'fast':
                    t_mid = 500
                    steps1 = torch.arange(t_max, t_mid - 1,
                                          -((t_max - t_mid + 1) / 4))
                    steps2 = torch.arange(t_mid, t_min - 1,
                                          -((t_mid - t_min + 1) / 11))
                    steps = torch.cat([steps1, steps2])
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented')
            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(
            steps, dtype=torch.float32, device=noise.device)

        # get sigmas
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            if sigmas[0] == float('inf'):
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas,
                    sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        fn = model_chunk_fn if chunk_inds is not None else model_fn
        x0 = solver_fn(
            noise,
            fn,
            sigmas,
            variant_info=variant_info,
            show_progress=show_progress,
            **kwargs)
        return (x0, intermediates) if return_intermediate is not None else x0
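    # --- Hedged sketch of the chunk_inds convention (deduced from the two
    # chunk functions above, not an official spec). chunk_inds is a list of
    # (start, end) frame windows with a fixed overlap between consecutive
    # windows; O_LEN is inferred from the first two windows and each chunk
    # contributes only its non-overlapped span. For example, 32 frames in
    # 16-frame windows with a 4-frame overlap:
    #
    #   chunk_inds = [(0, 16), (12, 28), (24, 32)]  # O_LEN = 16 - 12 = 4
    #
    # With cut_f_ind = O_LEN // 2 = 2, the stitched outputs cover frames
    # [0, 14), [14, 26) and [26, 32): each frame is produced exactly once.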
    def _sigma_to_t(self, sigma):
        # here `sigma` is the Karras-style ratio sigma_t / alpha_t used by
        # the solvers; invert it by interpolating on the log-ratio grid
        if sigma == float('inf'):
            t = torch.full_like(sigma, len(self.sigmas) - 1)
        else:
            log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                    (1 - self.sigmas**2)).log().to(sigma)
            log_sigma = sigma.log()
            dists = log_sigma - log_sigmas[:, None]
            low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
                max=log_sigmas.shape[0] - 2)
            high_idx = low_idx + 1
            low, high = log_sigmas[low_idx], log_sigmas[high_idx]
            w = (low - log_sigma) / (low - high)
            w = w.clamp(0, 1)
            t = (1 - w) * low_idx + w * high_idx
            t = t.view(sigma.shape)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        return t

    def _t_to_sigma(self, t):
        # map (possibly fractional) timesteps to ratios sigma_t / alpha_t by
        # linear interpolation in log space
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                (1 - self.sigmas**2)).log().to(t)
        log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
        log_sigma[torch.isnan(log_sigma)
                  | torch.isinf(log_sigma)] = float('inf')
        return log_sigma.exp()
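# --- Hedged self-test (illustrative only; uses a toy linear sigma grid,
# not a schedule shipped with the repo) ---
if __name__ == '__main__':
    _gd = GaussianDiffusion(torch.linspace(0.01, 0.99, 1000))
    for _ti in (10.0, 250.5, 700.0):
        _t = torch.tensor([_ti])
        _sigma = _gd._t_to_sigma(_t)
        _t_back = _gd._sigma_to_t(_sigma)
        # _sigma_to_t and _t_to_sigma interpolate in the same log-ratio
        # space, so the round trip should hold on interior timesteps
        logger.info(f't={_ti} -> sigma={_sigma.item():.4f} '
                    f'-> t={_t_back.item():.2f}')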