# File: AniMemory-alpha/scheduler/scheduling_euler_ancestral_discrete_x_pred.py
# Uploaded by hlky (HF staff), commit 06e58fe: "Remote code for Diffusers" (10 kB).
# Copyright 2024 Katherine Crowson, AniMemory Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
EulerAncestralDiscreteScheduler,
EulerAncestralDiscreteSchedulerOutput,
rescale_zero_terminal_snr,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class EulerAncestralDiscreteXPredScheduler(EulerAncestralDiscreteScheduler):
    """
    Ancestral sampling with Euler method steps for models that predict the clean sample (x0). This model inherits
    from [`EulerAncestralDiscreteScheduler`]. Check the superclass documentation for the args and returns.

    Unlike the parent class, this scheduler uses a variance-preserving (VP) parameterization:
    ``sigma_t = sqrt(1 - alphas_cumprod[t])`` and ``alpha_t = sqrt(1 - sigma_t**2)``, so a noisy sample is
    ``alpha_t * x0 + sigma_t * noise``. Model inputs are therefore not rescaled and the initial noise has unit
    standard deviation. `step` only supports `prediction_type="sample"`.

    For more details, see the original paper: https://arxiv.org/abs/2403.08381
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # NOTE: `step` raises unless prediction_type == "sample"; the "epsilon" default is kept unchanged so
        # existing configs keep loading the same way.
        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            prediction_type=prediction_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )
        # Replace the parent's sigmas with VP sigmas, one per training timestep (ascending t).
        self.sigmas = self._vp_sigmas()

    def _vp_sigmas(self) -> torch.Tensor:
        """Return ``sqrt(1 - alphas_cumprod)`` as a float32 tensor, one entry per training timestep."""
        return ((1 - self.alphas_cumprod) ** 0.5).to(torch.float32)

    def rescale_betas_zero_snr(self):
        """
        Rescale the betas with diffusers' `rescale_zero_terminal_snr` and rebuild `alphas`, `alphas_cumprod` and
        the per-training-timestep VP `sigmas` from them.

        `set_timesteps` reads `alphas_cumprod`, so call this before `set_timesteps` for it to take effect there.
        """
        self.betas = rescale_zero_terminal_snr(self.betas)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sigmas = self._vp_sigmas()

    @property
    def init_noise_sigma(self):
        # VP parameterization: the initial latent is standard normal noise.
        return 1.0

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Return `sample` unchanged: in the VP parameterization the noisy sample is already what the model expects,
        so no sigma-dependent input scaling is applied (unlike the parent class).
        """
        self.is_scale_input_called = True
        return sample

    def set_timesteps(
        self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        After this call `self.sigmas[i]` is the VP sigma at `self.timesteps[i]` (linearly interpolated between
        training timesteps), followed by one trailing 0.0 used as the target sigma of the final step.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

        Raises:
            ValueError: if `config.timestep_spacing` is not "linspace", "leading" or "trailing".
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(
                0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float
            )[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer-valued timesteps by multiplying by ratio (stored as float)
            timesteps = (
                (np.arange(0, num_inference_steps) * step_ratio)
                .round()[::-1]
                .copy()
                .astype(float)
            )
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # rounds the fractional ratio grid to integer-valued timesteps (stored as float)
            timesteps = (
                (np.arange(self.config.num_train_timesteps, 0, -step_ratio))
                .round()
                .copy()
                .astype(float)
            )
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate the per-training-timestep VP sigmas at the (possibly fractional) inference timesteps.
        train_sigmas = self._vp_sigmas().numpy()
        sigmas = np.interp(timesteps, np.arange(0, len(train_sigmas)), train_sigmas)
        # Trailing 0.0: the final step lands on the clean sample.
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by drawing from the DDPM posterior q(x_s | x_t, x0), where x0
        is the model's prediction of the clean sample.

        Args:
            model_output (`torch.FloatTensor`): direct output from the learned diffusion model; must be a
                prediction of the clean sample (`prediction_type="sample"`).
            timestep (`float` or `torch.FloatTensor`): current timestep; must be one of `self.timesteps`.
            sample (`torch.FloatTensor`): current noisy sample x_t.
            generator (`torch.Generator`, optional): Random number generator for the ancestral noise.
            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
            a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `timestep` is an integer index rather than a value from `self.timesteps`, or if
                `config.prediction_type` is not "sample".
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerAncestralDiscreteXPredScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero().item()

        if self.config.prediction_type != "sample":
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be `sample`; "
                "`EulerAncestralDiscreteXPredScheduler` only supports x0 (sample) prediction."
            )
        pred_original_sample = model_output

        # VP parameterization (alpha^2 + sigma^2 = 1); t is the current (noisier) step, s the next (cleaner) one.
        # With alpha_{t|s} = alpha_t / alpha_s and sigma_{t|s}^2 = sigma_t^2 - alpha_{t|s}^2 * sigma_s^2
        # (= 1 - alpha_{t|s}^2 here), the posterior q(x_s | x_t, x0) is Gaussian with
        #   mean = alpha_{t|s} * sigma_s^2 / sigma_t^2 * x_t + alpha_s * sigma_{t|s}^2 / sigma_t^2 * x0
        #   std  = sigma_s * sigma_{t|s} / sigma_t
        # On the final step sigma_s is the trailing 0.0, so the noise term vanishes and the result is ~x0.
        sigma_t = self.sigmas[step_index]
        sigma_s = self.sigmas[step_index + 1]
        alpha_t = (1 - sigma_t**2) ** 0.5
        alpha_s = (1 - sigma_s**2) ** 0.5
        coef_sample = (sigma_s / sigma_t) ** 2 * alpha_t / alpha_s
        coef_noise = (sigma_s / sigma_t) * (1 - (alpha_t / alpha_s) ** 2) ** 0.5
        coef_x = alpha_s * (1 - alpha_t**2 / alpha_s**2) / sigma_t**2

        noise = randn_tensor(
            model_output.shape,
            dtype=model_output.dtype,
            device=model_output.device,
            generator=generator,
        )
        prev_sample = (
            coef_sample * sample + coef_x * pred_original_sample + coef_noise * noise
        )

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Diffuse `original_samples` to `timesteps` with the same VP forward process that `step` inverts:
        ``alpha_t * x0 + sigma_t * noise`` with ``alpha_t = sqrt(1 - sigma_t**2)``.

        Each entry of `timesteps` is located in `self.timesteps`, and the sigma at that position in `self.sigmas`
        is used. `set_timesteps` is what builds `self.sigmas` aligned with `self.timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples x0, batch along dim 0.
            noise (`torch.FloatTensor`): noise broadcastable to `original_samples`.
            timesteps (`torch.FloatTensor`): 1-D tensor of values taken from `self.timesteps`, one per batch item.

        Returns:
            `torch.FloatTensor`: the noisy samples, in the dtype of `original_samples`.
        """
        device = original_samples.device
        if device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(device, dtype=torch.float32)
            timesteps = timesteps.to(device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(device)
            timesteps = timesteps.to(device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        # Derive alpha in float32 before casting to the sample dtype, for precision under fp16/bf16.
        sigma = self.sigmas.to(device)[step_indices].flatten()
        alpha = (1.0 - sigma**2).clamp(min=0.0) ** 0.5
        sigma = sigma.to(original_samples.dtype)
        alpha = alpha.to(original_samples.dtype)
        # Broadcast the per-sample coefficients over the trailing (non-batch) dimensions.
        while sigma.ndim < original_samples.ndim:
            sigma = sigma.unsqueeze(-1)
            alpha = alpha.unsqueeze(-1)

        return alpha * original_samples + sigma * noise