Stable-X committed on
Commit b353dc0
1 Parent(s): 35c32ba

Update scheduler

stablenormal/scheduler/heuristics_ddimsampler.py ADDED
@@ -0,0 +1,236 @@
+ import math
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+ from diffusers.configuration_utils import register_to_config, ConfigMixin
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
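+     """DDIM scheduler with a heuristic two-phase schedule: the first half of the
+     inference steps repeatedly refines the x_0 prediction at a fixed mid-chain
+     noise level, and the second half runs standard DDIM denoising down to t = 0."""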
+
+     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+         Args:
+             num_inference_steps (`int`):
+                 The number of diffusion steps used when generating samples with a pre-trained model.
+         """
+
+         if num_inference_steps > self.config.num_train_timesteps:
+             raise ValueError(
+                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
+                 f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps}, as the UNet model"
+                 f" trained with this scheduler can handle at most {self.config.num_train_timesteps} timesteps."
+             )
+
+         self.num_inference_steps = num_inference_steps
+
+         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+         if self.config.timestep_spacing == "linspace":
+             timesteps = (
+                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+                 .round()[::-1]
+                 .copy()
+                 .astype(np.int64)
+             )
+         elif self.config.timestep_spacing == "leading":
+             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+             # creates integer timesteps by multiplying by ratio
+             # casting to int to avoid issues when num_inference_step is power of 3
+             timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+             timesteps += self.config.steps_offset
+         elif self.config.timestep_spacing == "trailing":
+             step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+             # creates integer timesteps by multiplying by ratio
+             # casting to int to avoid issues when num_inference_step is power of 3
+             timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+             timesteps -= 1
+         else:
+             raise ValueError(
+                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of"
+                 " 'linspace', 'leading', or 'trailing'."
+             )
+
+         timesteps = torch.from_numpy(timesteps).to(device)
+         naive_sampling_step = num_inference_steps // 2
+
+         self.naive_sampling_step = naive_sampling_step
+
+         # pin the first half of the schedule to the midpoint timestep: refine there for
+         # naive_sampling_step iterations, then run standard DDIM backward (e.g. with 10
+         # steps: refine at step 5 for 5 iterations, then denoise from step 6 onward)
+         timesteps[:naive_sampling_step] = timesteps[naive_sampling_step]
+
+         timesteps = [timestep + 1 for timestep in timesteps]
+
+         self.timesteps = timesteps
+         self.gap = self.config.num_train_timesteps // self.num_inference_steps
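+         # prev_timesteps pairs each timestep with the value step() moves to: the
+         # schedule shifted left by one, terminated with t = 0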
+         self.prev_timesteps = [timestep for timestep in self.timesteps[1:]]
+         self.prev_timesteps.append(torch.zeros_like(self.prev_timesteps[-1]))
+
+     def step(
+         self,
+         model_output: torch.Tensor,
+         timestep: int,
+         prev_timestep: int,
+         sample: torch.Tensor,
+         eta: float = 0.0,
+         use_clipped_model_output: bool = False,
+         generator=None,
+         cur_step=None,
+         gauss_latent=None,
+         variance_noise: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> Union[DDIMSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.Tensor`):
+                 The direct output from the learned diffusion model.
+             timestep (`int`):
+                 The current discrete timestep in the diffusion chain.
+             prev_timestep (`int`):
+                 The discrete timestep the sample is stepped to (the next entry in the schedule).
+             sample (`torch.Tensor`):
+                 A current instance of a sample created by the diffusion process.
+             eta (`float`):
+                 The weight of noise for added noise in diffusion step.
+             use_clipped_model_output (`bool`, defaults to `False`):
+                 If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+                 because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+                 clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+                 `use_clipped_model_output` has no effect.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             cur_step (`int`, *optional*):
+                 Index of the current iteration, used to detect the naive-sampling phase.
+             gauss_latent (`torch.Tensor`, *optional*):
+                 Fixed Gaussian noise reused when re-noising during the naive-sampling phase; drawn once if `None`.
+             variance_noise (`torch.Tensor`):
+                 Alternative to generating noise with `generator` by directly providing the noise for the variance
+                 itself. Useful for methods such as [`CycleDiffusion`].
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+         Returns:
+             [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+                 If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+                 tuple is returned where the first element is the sample tensor.
+         """
+         if self.num_inference_steps is None:
+             raise ValueError(
+                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+             )
+
+         # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+         # for a detailed derivation of the update rule.
+
+         # Notation (<variable name> -> <name in paper>)
+         # - pred_noise_t -> e_theta(x_t, t)
+         # - pred_original_sample -> f_theta(x_t, t) or x_0
+         # - std_dev_t -> sigma_t
+         # - eta -> η
+         # - pred_sample_direction -> "direction pointing to x_t"
+         # - pred_prev_sample -> "x_t-1"
+
+         # 1. get previous step value (=t-1)
+         # heuristic sampling trick: on the first step after the naive phase, bump the
+         # current timestep up by one schedule gap so it sits above `prev_timestep`
+         if cur_step == self.naive_sampling_step and timestep == prev_timestep:
+             timestep += self.gap
+
+         # 2. compute alphas, betas
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+         beta_prod_t = 1 - alpha_prod_t
+
+         # 3. compute predicted original sample from predicted noise also called
+         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         if self.config.prediction_type == "epsilon":
+             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+             pred_epsilon = model_output
+         elif self.config.prediction_type == "sample":
+             pred_original_sample = model_output
+             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+         elif self.config.prediction_type == "v_prediction":
+             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+         else:
+             raise ValueError(
+                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                 " `v_prediction`"
+             )
+
+         # 4. Clip or threshold "predicted x_0"
+         if self.config.thresholding:
+             pred_original_sample = self._threshold_sample(pred_original_sample)
+
+         # 5. compute variance: "sigma_t(η)" -> see formula (16)
+         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+         variance = self._get_variance(timestep, prev_timestep)
+         std_dev_t = eta * variance ** (0.5)
+
+         if use_clipped_model_output:
+             # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         if gauss_latent is None:
+             gauss_latent = torch.randn_like(pred_original_sample)
+         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+         if eta > 0:
+             if variance_noise is not None and generator is not None:
+                 raise ValueError(
+                     "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+                     " `variance_noise` stays `None`."
+                 )
+
+             if variance_noise is None:
+                 variance_noise = randn_tensor(
+                     model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                 )
+             variance = std_dev_t * variance_noise
+
+             prev_sample = prev_sample + variance
+
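+         # naive sampling phase: re-noise the current x_0 estimate back to the pinned
+         # timestep so the next iteration refines it at the same noise level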
+         if cur_step < self.naive_sampling_step:
+             prev_sample = self.add_noise(pred_original_sample, gauss_latent, timestep)
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+     def add_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.IntTensor,
+     ) -> torch.Tensor:
+         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+         # for the subsequent add_noise calls
+         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+         alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+         timesteps = timesteps.to(original_samples.device)
+
+         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+         return noisy_samples
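
A minimal usage sketch (not part of the commit): the scheduler config values and the 10-step count below are illustrative, and the random `model_output` stands in for a real UNet prediction. It shows the calling convention the new `step` expects, with `prev_timestep`, `cur_step`, and a fixed `gauss_latent` passed explicitly.

import torch
from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler

# Illustrative config; match prediction_type and timestep_spacing to the trained model.
scheduler = HEURI_DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    prediction_type="v_prediction",
    timestep_spacing="trailing",
)
scheduler.set_timesteps(num_inference_steps=10, device="cpu")

sample = torch.randn(1, 4, 64, 64)       # initial latent
gauss_latent = torch.randn_like(sample)  # fixed noise reused during the naive phase
for cur_step, (t, t_prev) in enumerate(zip(scheduler.timesteps, scheduler.prev_timesteps)):
    model_output = torch.randn_like(sample)  # stand-in for unet(sample, t).sample
    sample = scheduler.step(
        model_output,
        t,
        t_prev,
        sample,
        cur_step=cur_step,
        gauss_latent=gauss_latent,
    ).prev_sample

Passing the same `gauss_latent` every iteration keeps the re-noising in the naive phase deterministic across refinement passes, which is what `step` relies on when it calls `add_noise` for `cur_step < naive_sampling_step`.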