# Copyright 2024 Katherine Crowson, AniMemory Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
    EulerAncestralDiscreteScheduler,
    EulerAncestralDiscreteSchedulerOutput,
    rescale_zero_terminal_snr,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class EulerAncestralDiscreteXPredScheduler(EulerAncestralDiscreteScheduler):
    """
    Ancestral sampling with Euler method steps, for models that predict the clean sample.

    This model inherits from [`EulerAncestralDiscreteScheduler`]. Check the superclass documentation for the args
    and returns.

    Unlike the superclass, `self.sigmas` holds `sqrt(1 - alphas_cumprod)` and [`~step`] only accepts
    `prediction_type == "sample"`; any other value makes [`~step`] raise a `ValueError`.

    For more details, see the original paper: https://arxiv.org/abs/2403.08381
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # NOTE(review): the default `prediction_type` is "epsilon", which `step()` rejects; callers are
        # presumably expected to configure "sample" explicitly -- confirm before changing the default.
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
            timestep_spacing,
            steps_offset,
        )

        # Replace the superclass sigmas with sqrt(1 - alphas_cumprod), one entry per training timestep.
        sigmas = self._compute_base_sigmas().astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

    def _compute_base_sigmas(self) -> np.ndarray:
        """Return `sqrt(1 - alphas_cumprod)` for every training timestep as a NumPy array."""
        return np.array((1 - self.alphas_cumprod) ** 0.5)

    def rescale_betas_zero_snr(self):
        """
        Rescale `self.betas` with `rescale_zero_terminal_snr` and refresh every quantity derived from them
        (`alphas`, `alphas_cumprod` and `sigmas`).
        """
        self.betas = rescale_zero_terminal_snr(self.betas)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        self.sigmas = torch.from_numpy(self._compute_base_sigmas())

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        return 1.0

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Return `sample` unchanged; this scheduler applies no input scaling.

        Args:
            sample (`torch.FloatTensor`): input sample.
            timestep (`float` or `torch.FloatTensor`): current timestep (unused).

        Returns:
            `torch.FloatTensor`: the same `sample` that was passed in.
        """
        # Record that the pipeline called this hook, even though it is a no-op here.
        self.is_scale_input_called = True
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Updates `self.num_inference_steps`, `self.timesteps` and `self.sigmas`; `self.sigmas` gets one extra
        trailing `0.0` entry so that `step()` can read the sigma following the last timestep.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

        Raises:
            ValueError: if `self.config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
        """
        self.num_inference_steps = num_inference_steps
        num_train_timesteps = self.config.num_train_timesteps
        spacing = self.config.timestep_spacing

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if spacing == "linspace":
            timesteps = np.linspace(
                0, num_train_timesteps - 1, num_inference_steps, dtype=float
            )[::-1].copy()
        elif spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # Integer multiples of the ratio, in descending order, stored as floats.
            timesteps = (
                (np.arange(0, num_inference_steps) * step_ratio)
                .round()[::-1]
                .copy()
                .astype(float)
            )
            timesteps += self.config.steps_offset
        elif spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # Walk down from `num_train_timesteps` in (possibly fractional) strides, round to whole
            # timesteps, store as floats, then shift to 0-based indexing.
            timesteps = (
                np.arange(num_train_timesteps, 0, -step_ratio)
                .round()
                .copy()
                .astype(float)
            )
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate the per-training-timestep sigmas at the selected (possibly fractional) timesteps.
        sigmas = self._compute_base_sigmas()
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (here, the predicted original sample).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain; must be one of `self.timesteps`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, optional): Random number generator.
            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
            tensor.

        Raises:
            ValueError: if `timestep` is an integer / integer tensor, or if `self.config.prediction_type` is not
                `sample`.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerAncestralDiscreteXPredScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        # Position of `timestep` in the schedule; `.item()` requires exactly one match.
        step_index = (self.timesteps == timestep).nonzero().item()

        if self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be `sample`"
            )

        # t is the current step, s the next (less noisy) one; alpha is derived as sqrt(1 - sigma**2).
        sigma_t = self.sigmas[step_index]
        sigma_s = self.sigmas[step_index + 1]
        alpha_t = (1 - sigma_t**2) ** 0.5
        alpha_s = (1 - sigma_s**2) ** 0.5

        # Weights of the current sample, the predicted clean sample and fresh noise in the update.
        # With x = alpha * x0 + sigma * eps these are the mean / std coefficients of the Gaussian
        # posterior q(x_s | x_t, x0).
        coef_sample = (sigma_s / sigma_t) ** 2 * alpha_t / alpha_s
        coef_noise = (sigma_s / sigma_t) * (1 - (alpha_t / alpha_s) ** 2) ** 0.5
        coef_x = alpha_s * (1 - alpha_t**2 / alpha_s**2) / sigma_t**2

        device = model_output.device
        noise = randn_tensor(
            model_output.shape,
            dtype=model_output.dtype,
            device=device,
            generator=generator,
        )
        prev_sample = (
            coef_sample * sample + coef_x * pred_original_sample + coef_noise * noise
        )

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Add `noise` to `original_samples`, scaled by the sigma of each entry in `timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): samples to be noised.
            noise (`torch.FloatTensor`): noise to add.
            timesteps (`torch.FloatTensor`):
                one timestep per sample; each must match exactly one entry of `self.timesteps`.

        Returns:
            `torch.FloatTensor`: `original_samples + noise * sigma`.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(
            device=original_samples.device, dtype=original_samples.dtype
        )
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(
                original_samples.device, dtype=torch.float32
            )
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # Map every requested timestep to its index in the current schedule.
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        # Append trailing singleton dims so sigma broadcasts against the samples.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        # NOTE(review): `step()` derives alpha = sqrt(1 - sigma**2), yet the samples are not scaled by
        # alpha here -- confirm that the unscaled form is intended.
        noisy_samples = original_samples + noise * sigma
        return noisy_samples