import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
from torch import Tensor

from xora.utils.torch_utils import append_dims


def simple_diffusion_resolution_dependent_timestep_shift(
    samples: Tensor,
    timesteps: Tensor,
    n: int = 32 * 32,
) -> Tensor:
    # m is the number of spatial (or spatio-temporal) positions in the sample.
    if len(samples.shape) == 3:
        _, m, _ = samples.shape
    elif len(samples.shape) in [4, 5]:
        m = math.prod(samples.shape[2:])
    else:
        raise ValueError(
            "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
        )

    # Shift the schedule in log-SNR space by 2 * log(m / n), following the
    # resolution-dependent shifting idea from Simple Diffusion (Hoogeboom et
    # al., 2023): samples larger than the base resolution n are pushed toward
    # noisier timesteps.
    snr = (timesteps / (1 - timesteps)) ** 2
    shift_snr = torch.log(snr) + 2 * math.log(m / n)
    shifted_timesteps = torch.sigmoid(0.5 * shift_snr)

    return shifted_timesteps
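
# Rough worked example (illustrative values, not from the original source):
# with the default base n = 32 * 32 and a (b, c, 64, 64) sample, m / n = 4,
# so t = 0.5 maps to sigmoid(log(0.5 / 0.5) + log(4)) = 0.8, i.e. larger
# samples are shifted toward noisier timesteps.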


def time_shift(mu: float, sigma: float, t: Tensor) -> Tensor:
    # Logistic reschedule of t, where mu controls the shift and sigma the slope.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
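
# Note (illustrative): with mu = 0 and sigma = 1 this is the identity map,
# since 1 / (1 + (1 / t - 1)) = t; increasing mu pushes timesteps toward 1.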


def get_normal_shift(
    n_tokens: int,
    min_tokens: int = 1024,
    max_tokens: int = 4096,
    min_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    # Linearly interpolate the shift between the two anchor points
    # (min_tokens, min_shift) and (max_tokens, max_shift).
    m = (max_shift - min_shift) / (max_tokens - min_tokens)
    b = min_shift - m * min_tokens
    return m * n_tokens + b
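
# Endpoint check (illustrative): get_normal_shift(1024) == 0.95 and
# get_normal_shift(4096) == 2.05; a 2560-token latent sits halfway at 1.5.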


def sd3_resolution_dependent_timestep_shift(
    samples: Tensor, timesteps: Tensor
) -> Tensor:
    """
    Shifts the timestep schedule as a function of the generated resolution.

    In the SD3 paper, the authors empirically determined how to shift the timesteps based on the resolution of the
    target images. For more details: https://arxiv.org/pdf/2403.03206

    In Flux, they later proposed a more dynamic resolution-dependent timestep shift, see:
    https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66

    Args:
        samples (Tensor): A batch of samples with shape (batch_size, channels, height, width),
            (batch_size, channels, frame, height, width) or (batch_size, tokens, channels).
        timesteps (Tensor): A batch of timesteps with shape (batch_size,).

    Returns:
        Tensor: The shifted timesteps.
    """
    if len(samples.shape) == 3:
        _, m, _ = samples.shape
    elif len(samples.shape) in [4, 5]:
        m = math.prod(samples.shape[2:])
    else:
        raise ValueError(
            "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
        )

    shift = get_normal_shift(m)
    return time_shift(shift, 1, timesteps)
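
# Putting the pieces together (illustrative): a (b, c, 64, 64) latent has
# m = 4096 tokens, so get_normal_shift yields 2.05 and time_shift(2.05, 1, t)
# maps t = 0.5 to roughly 0.89, again biasing larger resolutions toward
# noisier timesteps.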


class TimestepShifter(ABC):
    @abstractmethod
    def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor:
        pass


@dataclass
class RectifiedFlowSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        shifting: Optional[str] = None,
        base_resolution: int = 32**2,
    ):
        super().__init__()
        self.init_noise_sigma = 1.0
        self.num_inference_steps = None
        # In rectified flow the timesteps and the noise levels (sigmas) coincide;
        # the schedule runs from 1 (pure noise) down to 1 / num_train_timesteps.
        self.timesteps = self.sigmas = torch.linspace(
            1, 1 / num_train_timesteps, num_train_timesteps
        )
        # Per-step Euler step sizes t_i - t_{i+1}, with the final step landing on 0.
        self.delta_timesteps = self.timesteps - torch.cat(
            [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
        )
        self.shifting = shifting
        self.base_resolution = base_resolution

    def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor:
        if self.shifting == "SD3":
            return sd3_resolution_dependent_timestep_shift(samples, timesteps)
        elif self.shifting == "SimpleDiffusion":
            return simple_diffusion_resolution_dependent_timestep_shift(
                samples, timesteps, self.base_resolution
            )
        return timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        samples: Tensor,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`): The number of diffusion steps used when generating samples.
            samples (`Tensor`): A batch of samples with shape (batch_size, channels, height, width),
                (batch_size, channels, frame, height, width) or (batch_size, tokens, channels), used to
                compute the resolution-dependent timestep shift.
            device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
        """
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(
            device
        )
        self.timesteps = self.shift_timesteps(samples, timesteps)
        self.delta_timesteps = self.timesteps - torch.cat(
            [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
        )
        self.num_inference_steps = num_inference_steps
        self.sigmas = self.timesteps
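
    # Typical call order (a sketch with assumed shapes, not from the original
    # source): construct the scheduler, then call set_timesteps once with the
    # latents before the denoising loop so the shift matches their resolution.
    #
    #     scheduler = RectifiedFlowScheduler(shifting="SD3")
    #     scheduler.set_timesteps(num_inference_steps=50, samples=latents)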

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Optional[int] = None
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, *optional*): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: torch.FloatTensor,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`torch.FloatTensor`):
                The current timestep in the diffusion chain; either a scalar (global timestep) or a
                `(batch_size, num_tokens)` tensor of per-token timesteps.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in a diffusion step. Unused by this scheduler; accepted for
                interface compatibility.
            use_clipped_model_output (`bool`, defaults to `False`):
                Unused by this scheduler; accepted for interface compatibility.
            generator (`torch.Generator`, *optional*):
                A random number generator. Unused by this scheduler; accepted for interface compatibility.
            variance_noise (`torch.FloatTensor`, *optional*):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Unused by this scheduler; accepted for interface compatibility.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`RectifiedFlowSchedulerOutput`] or `tuple`.

        Returns:
            [`RectifiedFlowSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`RectifiedFlowSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if timestep.ndim == 0:
            # Global timestep: look up the closest scheduled timestep and its step size.
            current_index = (self.timesteps - timestep).abs().argmin()
            dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
        else:
            # Per-token timesteps of shape (batch_size, num_tokens).
            assert timestep.ndim == 2
            current_index = (
                (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
            )
            dt = self.delta_timesteps[current_index]
            # Tokens at timestep 0 are already clean latents; leave them untouched.
            dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]

        # Euler step along the learned velocity field: x_{t - dt} = x_t - dt * v.
        prev_sample = sample - dt * model_output

        if not return_dict:
            return (prev_sample,)

        return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
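
    # Minimal denoising loop sketch (hypothetical `model` and latent shapes,
    # assuming the model predicts the flow velocity; not part of the original
    # module):
    #
    #     scheduler = RectifiedFlowScheduler()
    #     latents = torch.randn(1, 4096, 128)
    #     scheduler.set_timesteps(num_inference_steps=40, samples=latents)
    #     for t in scheduler.timesteps:
    #         velocity = model(latents, t)
    #         latents = scheduler.step(velocity, t, latents).prev_sample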

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Rectified-flow forward process: linear interpolation between data and
        # noise, x_t = (1 - t) * x_0 + t * noise.
        sigmas = timesteps
        sigmas = append_dims(sigmas, original_samples.ndim)
        alphas = 1 - sigmas
        noisy_samples = alphas * original_samples + sigmas * noise
        return noisy_samples
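
    # Endpoint check (illustrative): at t = 0 this returns the clean sample,
    # and at t = 1 it returns pure noise, matching the convention that t = 1
    # is fully noised and t = 0 is clean data.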