|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple, Union |
|
|
|
import flax |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
from ..configuration_utils import ConfigMixin, register_to_config |
|
from .scheduling_utils_flax import ( |
|
CommonSchedulerState, |
|
FlaxKarrasDiffusionSchedulers, |
|
FlaxSchedulerMixin, |
|
FlaxSchedulerOutput, |
|
add_noise_common, |
|
) |
|
|
|
|
|
@flax.struct.dataclass |
|
class PNDMSchedulerState: |
|
common: CommonSchedulerState |
|
final_alpha_cumprod: jnp.ndarray |
|
|
|
|
|
init_noise_sigma: jnp.ndarray |
|
timesteps: jnp.ndarray |
|
num_inference_steps: Optional[int] = None |
|
prk_timesteps: Optional[jnp.ndarray] = None |
|
plms_timesteps: Optional[jnp.ndarray] = None |
|
|
|
|
|
cur_model_output: Optional[jnp.ndarray] = None |
|
counter: Optional[jnp.int32] = None |
|
cur_sample: Optional[jnp.ndarray] = None |
|
ets: Optional[jnp.ndarray] = None |
|
|
|
@classmethod |
|
def create( |
|
cls, |
|
common: CommonSchedulerState, |
|
final_alpha_cumprod: jnp.ndarray, |
|
init_noise_sigma: jnp.ndarray, |
|
timesteps: jnp.ndarray, |
|
): |
|
return cls( |
|
common=common, |
|
final_alpha_cumprod=final_alpha_cumprod, |
|
init_noise_sigma=init_noise_sigma, |
|
timesteps=timesteps, |
|
) |
|
|
|
|
|
@dataclass |
|
class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput): |
|
state: PNDMSchedulerState |
|
|
|
|
|
class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): |
|
""" |
|
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, |
|
namely Runge-Kutta method and a linear multi-step method. |
|
|
|
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` |
|
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. |
|
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and |
|
[`~SchedulerMixin.from_pretrained`] functions. |
|
|
|
For more details, see the original paper: https://arxiv.org/abs/2202.09778 |
|
|
|
Args: |
|
num_train_timesteps (`int`): number of diffusion steps used to train the model. |
|
beta_start (`float`): the starting `beta` value of inference. |
|
beta_end (`float`): the final `beta` value. |
|
beta_schedule (`str`): |
|
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from |
|
`linear`, `scaled_linear`, or `squaredcos_cap_v2`. |
|
trained_betas (`jnp.ndarray`, optional): |
|
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. |
|
skip_prk_steps (`bool`): |
|
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required |
|
before plms steps; defaults to `False`. |
|
set_alpha_to_one (`bool`, default `False`): |
|
each diffusion step uses the value of alphas product at that step and at the previous one. For the final |
|
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, |
|
otherwise it uses the value of alpha at step 0. |
|
steps_offset (`int`, default `0`): |
|
An offset added to the inference steps, as required by some model families. |
|
prediction_type (`str`, default `epsilon`, optional): |
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion |
|
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 |
|
https://imagen.research.google/video/paper.pdf) |
|
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): |
|
the `dtype` used for params and computation. |
|
""" |
|
|
|
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] |
|
|
|
dtype: jnp.dtype |
|
pndm_order: int |
|
|
|
@property |
|
def has_state(self): |
|
return True |
|
|
|
@register_to_config |
|
def __init__( |
|
self, |
|
num_train_timesteps: int = 1000, |
|
beta_start: float = 0.0001, |
|
beta_end: float = 0.02, |
|
beta_schedule: str = "linear", |
|
trained_betas: Optional[jnp.ndarray] = None, |
|
skip_prk_steps: bool = False, |
|
set_alpha_to_one: bool = False, |
|
steps_offset: int = 0, |
|
prediction_type: str = "epsilon", |
|
dtype: jnp.dtype = jnp.float32, |
|
): |
|
self.dtype = dtype |
|
|
|
|
|
|
|
|
|
self.pndm_order = 4 |
|
|
|
def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState: |
|
if common is None: |
|
common = CommonSchedulerState.create(self) |
|
|
|
|
|
|
|
|
|
|
|
final_alpha_cumprod = ( |
|
jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] |
|
) |
|
|
|
|
|
init_noise_sigma = jnp.array(1.0, dtype=self.dtype) |
|
|
|
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] |
|
|
|
return PNDMSchedulerState.create( |
|
common=common, |
|
final_alpha_cumprod=final_alpha_cumprod, |
|
init_noise_sigma=init_noise_sigma, |
|
timesteps=timesteps, |
|
) |
|
|
|
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: |
|
""" |
|
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. |
|
|
|
Args: |
|
state (`PNDMSchedulerState`): |
|
the `FlaxPNDMScheduler` state data class instance. |
|
num_inference_steps (`int`): |
|
the number of diffusion steps used when generating samples with a pre-trained model. |
|
shape (`Tuple`): |
|
the shape of the samples to be generated. |
|
""" |
|
|
|
step_ratio = self.config.num_train_timesteps // num_inference_steps |
|
|
|
|
|
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset |
|
|
|
if self.config.skip_prk_steps: |
|
|
|
|
|
|
|
|
|
prk_timesteps = jnp.array([], dtype=jnp.int32) |
|
plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1] |
|
|
|
else: |
|
prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( |
|
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), |
|
self.pndm_order, |
|
) |
|
|
|
prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1] |
|
plms_timesteps = _timesteps[:-3][::-1] |
|
|
|
timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]) |
|
|
|
|
|
|
|
cur_model_output = jnp.zeros(shape, dtype=self.dtype) |
|
counter = jnp.int32(0) |
|
cur_sample = jnp.zeros(shape, dtype=self.dtype) |
|
ets = jnp.zeros((4,) + shape, dtype=self.dtype) |
|
|
|
return state.replace( |
|
timesteps=timesteps, |
|
num_inference_steps=num_inference_steps, |
|
prk_timesteps=prk_timesteps, |
|
plms_timesteps=plms_timesteps, |
|
cur_model_output=cur_model_output, |
|
counter=counter, |
|
cur_sample=cur_sample, |
|
ets=ets, |
|
) |
|
|
|
def scale_model_input( |
|
self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None |
|
) -> jnp.ndarray: |
|
""" |
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the |
|
current timestep. |
|
|
|
Args: |
|
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. |
|
sample (`jnp.ndarray`): input sample |
|
timestep (`int`, optional): current timestep |
|
|
|
Returns: |
|
`jnp.ndarray`: scaled input sample |
|
""" |
|
return sample |
|
|
|
def step( |
|
self, |
|
state: PNDMSchedulerState, |
|
model_output: jnp.ndarray, |
|
timestep: int, |
|
sample: jnp.ndarray, |
|
return_dict: bool = True, |
|
) -> Union[FlaxPNDMSchedulerOutput, Tuple]: |
|
""" |
|
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion |
|
process from the learned model outputs (most often the predicted noise). |
|
|
|
This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. |
|
|
|
Args: |
|
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. |
|
model_output (`jnp.ndarray`): direct output from learned diffusion model. |
|
timestep (`int`): current discrete timestep in the diffusion chain. |
|
sample (`jnp.ndarray`): |
|
current instance of sample being created by diffusion process. |
|
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class |
|
|
|
Returns: |
|
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a |
|
`tuple`. When returning a tuple, the first element is the sample tensor. |
|
|
|
""" |
|
|
|
if state.num_inference_steps is None: |
|
raise ValueError( |
|
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
|
) |
|
|
|
if self.config.skip_prk_steps: |
|
prev_sample, state = self.step_plms(state, model_output, timestep, sample) |
|
else: |
|
prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample) |
|
plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample) |
|
|
|
cond = state.counter < len(state.prk_timesteps) |
|
|
|
prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample) |
|
|
|
state = state.replace( |
|
cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output), |
|
ets=jax.lax.select(cond, prk_state.ets, plms_state.ets), |
|
cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample), |
|
counter=jax.lax.select(cond, prk_state.counter, plms_state.counter), |
|
) |
|
|
|
if not return_dict: |
|
return (prev_sample, state) |
|
|
|
return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state) |
|
|
|
def step_prk( |
|
self, |
|
state: PNDMSchedulerState, |
|
model_output: jnp.ndarray, |
|
timestep: int, |
|
sample: jnp.ndarray, |
|
) -> Union[FlaxPNDMSchedulerOutput, Tuple]: |
|
""" |
|
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the |
|
solution to the differential equation. |
|
|
|
Args: |
|
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. |
|
model_output (`jnp.ndarray`): direct output from learned diffusion model. |
|
timestep (`int`): current discrete timestep in the diffusion chain. |
|
sample (`jnp.ndarray`): |
|
current instance of sample being created by diffusion process. |
|
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class |
|
|
|
Returns: |
|
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a |
|
`tuple`. When returning a tuple, the first element is the sample tensor. |
|
|
|
""" |
|
|
|
if state.num_inference_steps is None: |
|
raise ValueError( |
|
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
|
) |
|
|
|
diff_to_prev = jnp.where( |
|
state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 |
|
) |
|
prev_timestep = timestep - diff_to_prev |
|
timestep = state.prk_timesteps[state.counter // 4 * 4] |
|
|
|
model_output = jax.lax.select( |
|
(state.counter % 4) != 3, |
|
model_output, |
|
state.cur_model_output + 1 / 6 * model_output, |
|
) |
|
|
|
state = state.replace( |
|
cur_model_output=jax.lax.select_n( |
|
state.counter % 4, |
|
state.cur_model_output + 1 / 6 * model_output, |
|
state.cur_model_output + 1 / 3 * model_output, |
|
state.cur_model_output + 1 / 3 * model_output, |
|
jnp.zeros_like(state.cur_model_output), |
|
), |
|
ets=jax.lax.select( |
|
(state.counter % 4) == 0, |
|
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), |
|
state.ets, |
|
), |
|
cur_sample=jax.lax.select( |
|
(state.counter % 4) == 0, |
|
sample, |
|
state.cur_sample, |
|
), |
|
) |
|
|
|
cur_sample = state.cur_sample |
|
prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output) |
|
state = state.replace(counter=state.counter + 1) |
|
|
|
return (prev_sample, state) |
|
|
|
def step_plms( |
|
self, |
|
state: PNDMSchedulerState, |
|
model_output: jnp.ndarray, |
|
timestep: int, |
|
sample: jnp.ndarray, |
|
) -> Union[FlaxPNDMSchedulerOutput, Tuple]: |
|
""" |
|
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple |
|
times to approximate the solution. |
|
|
|
Args: |
|
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. |
|
model_output (`jnp.ndarray`): direct output from learned diffusion model. |
|
timestep (`int`): current discrete timestep in the diffusion chain. |
|
sample (`jnp.ndarray`): |
|
current instance of sample being created by diffusion process. |
|
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class |
|
|
|
Returns: |
|
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a |
|
`tuple`. When returning a tuple, the first element is the sample tensor. |
|
|
|
""" |
|
|
|
if state.num_inference_steps is None: |
|
raise ValueError( |
|
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
|
) |
|
|
|
|
|
|
|
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps |
|
prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) |
|
timestep = jnp.where( |
|
state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
state = state.replace( |
|
ets=jax.lax.select( |
|
state.counter != 1, |
|
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), |
|
state.ets, |
|
), |
|
cur_sample=jax.lax.select( |
|
state.counter != 1, |
|
sample, |
|
state.cur_sample, |
|
), |
|
) |
|
|
|
state = state.replace( |
|
cur_model_output=jax.lax.select_n( |
|
jnp.clip(state.counter, 0, 4), |
|
model_output, |
|
(model_output + state.ets[-1]) / 2, |
|
(3 * state.ets[-1] - state.ets[-2]) / 2, |
|
(23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, |
|
(1 / 24) |
|
* (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), |
|
), |
|
) |
|
|
|
sample = state.cur_sample |
|
model_output = state.cur_model_output |
|
prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output) |
|
state = state.replace(counter=state.counter + 1) |
|
|
|
return (prev_sample, state) |
|
|
|
def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
alpha_prod_t = state.common.alphas_cumprod[timestep] |
|
alpha_prod_t_prev = jnp.where( |
|
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod |
|
) |
|
beta_prod_t = 1 - alpha_prod_t |
|
beta_prod_t_prev = 1 - alpha_prod_t_prev |
|
|
|
if self.config.prediction_type == "v_prediction": |
|
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample |
|
elif self.config.prediction_type != "epsilon": |
|
raise ValueError( |
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) |
|
|
|
|
|
model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( |
|
alpha_prod_t * beta_prod_t * alpha_prod_t_prev |
|
) ** (0.5) |
|
|
|
|
|
prev_sample = ( |
|
sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff |
|
) |
|
|
|
return prev_sample |
|
|
|
def add_noise( |
|
self, |
|
state: PNDMSchedulerState, |
|
original_samples: jnp.ndarray, |
|
noise: jnp.ndarray, |
|
timesteps: jnp.ndarray, |
|
) -> jnp.ndarray: |
|
return add_noise_common(state.common, original_samples, noise, timesteps) |
|
|
|
def __len__(self): |
|
return self.config.num_train_timesteps |
|
|