LuChengTHU commited on
Commit
531ea40
·
1 Parent(s): 69e4d47

add dpm-solver support (much faster than plms)

Browse files

Former-commit-id: 8ee518a5a26d8f57d61985f81dc19d7e17a74a7d

ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import DPMSolverSampler
ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule='discrete',
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.,
14
+ ):
15
+ """Create a wrapper class for the forward SDE (VP type).
16
+
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
19
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+
22
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+
30
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
31
+
32
+ t = self.inverse_lambda(lambda_t)
33
+
34
+ ===============================================================
35
+
36
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
37
+
38
+ 1. For discrete-time DPMs:
39
+
40
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
41
+ t_i = (i + 1) / N
42
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
43
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
44
+
45
+ Args:
46
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
47
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
48
+
49
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
50
+
51
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
52
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
53
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
54
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
55
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
56
+ and
57
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
58
+
59
+
60
+ 2. For continuous-time DPMs:
61
+
62
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
63
+ schedule are the default settings in DDPM and improved-DDPM:
64
+
65
+ Args:
66
+ beta_min: A `float` number. The smallest beta for the linear schedule.
67
+ beta_max: A `float` number. The largest beta for the linear schedule.
68
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
69
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
70
+ T: A `float` number. The ending time of the forward process.
71
+
72
+ ===============================================================
73
+
74
+ Args:
75
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
76
+ 'linear' or 'cosine' for continuous-time DPMs.
77
+ Returns:
78
+ A wrapper object of the forward SDE (VP type).
79
+
80
+ ===============================================================
81
+
82
+ Example:
83
+
84
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
85
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
86
+
87
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
88
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
89
+
90
+ # For continuous-time DPMs (VPSDE), linear schedule:
91
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
92
+
93
+ """
94
+
95
+ if schedule not in ['discrete', 'linear', 'cosine']:
96
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
97
+
98
+ self.schedule = schedule
99
+ if schedule == 'discrete':
100
+ if betas is not None:
101
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
102
+ else:
103
+ assert alphas_cumprod is not None
104
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
105
+ self.total_N = len(log_alphas)
106
+ self.T = 1.
107
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
108
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
109
+ else:
110
+ self.total_N = 1000
111
+ self.beta_0 = continuous_beta_0
112
+ self.beta_1 = continuous_beta_1
113
+ self.cosine_s = 0.008
114
+ self.cosine_beta_max = 999.
115
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
116
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
117
+ self.schedule = schedule
118
+ if schedule == 'cosine':
119
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
120
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
121
+ self.T = 0.9946
122
+ else:
123
+ self.T = 1.
124
+
125
+ def marginal_log_mean_coeff(self, t):
126
+ """
127
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
128
+ """
129
+ if self.schedule == 'discrete':
130
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
131
+ elif self.schedule == 'linear':
132
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
133
+ elif self.schedule == 'cosine':
134
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
135
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
136
+ return log_alpha_t
137
+
138
+ def marginal_alpha(self, t):
139
+ """
140
+ Compute alpha_t of a given continuous-time label t in [0, T].
141
+ """
142
+ return torch.exp(self.marginal_log_mean_coeff(t))
143
+
144
+ def marginal_std(self, t):
145
+ """
146
+ Compute sigma_t of a given continuous-time label t in [0, T].
147
+ """
148
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
149
+
150
+ def marginal_lambda(self, t):
151
+ """
152
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
153
+ """
154
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
155
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
156
+ return log_mean_coeff - log_std
157
+
158
+ def inverse_lambda(self, lamb):
159
+ """
160
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
161
+ """
162
+ if self.schedule == 'linear':
163
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
164
+ Delta = self.beta_0**2 + tmp
165
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
166
+ elif self.schedule == 'discrete':
167
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
168
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
169
+ return t.reshape((-1,))
170
+ else:
171
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
172
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
173
+ t = t_fn(log_alpha)
174
+ return t
175
+
176
+
177
+ def model_wrapper(
178
+ model,
179
+ noise_schedule,
180
+ model_type="noise",
181
+ model_kwargs={},
182
+ guidance_type="uncond",
183
+ condition=None,
184
+ unconditional_condition=None,
185
+ guidance_scale=1.,
186
+ classifier_fn=None,
187
+ classifier_kwargs={},
188
+ ):
189
+ """Create a wrapper function for the noise prediction model.
190
+
191
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
192
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
193
+
194
+ We support four types of the diffusion model by setting `model_type`:
195
+
196
+ 1. "noise": noise prediction model. (Trained by predicting noise).
197
+
198
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
199
+
200
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
201
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
202
+
203
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
204
+ arXiv preprint arXiv:2202.00512 (2022).
205
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
206
+ arXiv preprint arXiv:2210.02303 (2022).
207
+
208
+ 4. "score": marginal score function. (Trained by denoising score matching).
209
+ Note that the score function and the noise prediction model follows a simple relationship:
210
+ ```
211
+ noise(x_t, t) = -sigma_t * score(x_t, t)
212
+ ```
213
+
214
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
215
+ 1. "uncond": unconditional sampling by DPMs.
216
+ The input `model` has the following format:
217
+ ``
218
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
219
+ ``
220
+
221
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
222
+ The input `model` has the following format:
223
+ ``
224
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
225
+ ``
226
+
227
+ The input `classifier_fn` has the following format:
228
+ ``
229
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
230
+ ``
231
+
232
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
233
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
234
+
235
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
236
+ The input `model` has the following format:
237
+ ``
238
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
239
+ ``
240
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
241
+
242
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
243
+ arXiv preprint arXiv:2207.12598 (2022).
244
+
245
+
246
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
247
+ or continuous-time labels (i.e. epsilon to T).
248
+
249
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
250
+ ``
251
+ def model_fn(x, t_continuous) -> noise:
252
+ t_input = get_model_input_time(t_continuous)
253
+ return noise_pred(model, x, t_input, **model_kwargs)
254
+ ``
255
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
256
+
257
+ ===============================================================
258
+
259
+ Args:
260
+ model: A diffusion model with the corresponding format described above.
261
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
262
+ model_type: A `str`. The parameterization type of the diffusion model.
263
+ "noise" or "x_start" or "v" or "score".
264
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
265
+ guidance_type: A `str`. The type of the guidance for sampling.
266
+ "uncond" or "classifier" or "classifier-free".
267
+ condition: A pytorch tensor. The condition for the guided sampling.
268
+ Only used for "classifier" or "classifier-free" guidance type.
269
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
270
+ Only used for "classifier-free" guidance type.
271
+ guidance_scale: A `float`. The scale for the guided sampling.
272
+ classifier_fn: A classifier function. Only used for the classifier guidance.
273
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
274
+ Returns:
275
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
276
+ """
277
+
278
+ def get_model_input_time(t_continuous):
279
+ """
280
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
281
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
282
+ For continuous-time DPMs, we just use `t_continuous`.
283
+ """
284
+ if noise_schedule.schedule == 'discrete':
285
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
286
+ else:
287
+ return t_continuous
288
+
289
+ def noise_pred_fn(x, t_continuous, cond=None):
290
+ if t_continuous.reshape((-1,)).shape[0] == 1:
291
+ t_continuous = t_continuous.expand((x.shape[0]))
292
+ t_input = get_model_input_time(t_continuous)
293
+ if cond is None:
294
+ output = model(x, t_input, **model_kwargs)
295
+ else:
296
+ output = model(x, t_input, cond, **model_kwargs)
297
+ if model_type == "noise":
298
+ return output
299
+ elif model_type == "x_start":
300
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
301
+ dims = x.dim()
302
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
303
+ elif model_type == "v":
304
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
305
+ dims = x.dim()
306
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
307
+ elif model_type == "score":
308
+ sigma_t = noise_schedule.marginal_std(t_continuous)
309
+ dims = x.dim()
310
+ return -expand_dims(sigma_t, dims) * output
311
+
312
+ def cond_grad_fn(x, t_input):
313
+ """
314
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
315
+ """
316
+ with torch.enable_grad():
317
+ x_in = x.detach().requires_grad_(True)
318
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
319
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
320
+
321
+ def model_fn(x, t_continuous):
322
+ """
323
+ The noise predicition model function that is used for DPM-Solver.
324
+ """
325
+ if t_continuous.reshape((-1,)).shape[0] == 1:
326
+ t_continuous = t_continuous.expand((x.shape[0]))
327
+ if guidance_type == "uncond":
328
+ return noise_pred_fn(x, t_continuous)
329
+ elif guidance_type == "classifier":
330
+ assert classifier_fn is not None
331
+ t_input = get_model_input_time(t_continuous)
332
+ cond_grad = cond_grad_fn(x, t_input)
333
+ sigma_t = noise_schedule.marginal_std(t_continuous)
334
+ noise = noise_pred_fn(x, t_continuous)
335
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
336
+ elif guidance_type == "classifier-free":
337
+ if guidance_scale == 1. or unconditional_condition is None:
338
+ return noise_pred_fn(x, t_continuous, cond=condition)
339
+ else:
340
+ x_in = torch.cat([x] * 2)
341
+ t_in = torch.cat([t_continuous] * 2)
342
+ c_in = torch.cat([unconditional_condition, condition])
343
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
344
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
345
+
346
+ assert model_type in ["noise", "x_start", "v"]
347
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
348
+ return model_fn
349
+
350
+
351
+ class DPM_Solver:
352
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
353
+ """Construct a DPM-Solver.
354
+
355
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
356
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
357
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
358
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
359
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
360
+
361
+ Args:
362
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
363
+ ``
364
+ def model_fn(x, t_continuous):
365
+ return noise
366
+ ``
367
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
368
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
369
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
370
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
371
+
372
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
373
+ """
374
+ self.model = model_fn
375
+ self.noise_schedule = noise_schedule
376
+ self.predict_x0 = predict_x0
377
+ self.thresholding = thresholding
378
+ self.max_val = max_val
379
+
380
+ def noise_prediction_fn(self, x, t):
381
+ """
382
+ Return the noise prediction model.
383
+ """
384
+ return self.model(x, t)
385
+
386
+ def data_prediction_fn(self, x, t):
387
+ """
388
+ Return the data prediction model (with thresholding).
389
+ """
390
+ noise = self.noise_prediction_fn(x, t)
391
+ dims = x.dim()
392
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
393
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
394
+ if self.thresholding:
395
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
396
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
397
+ s = expand_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), dims)
398
+ x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
399
+ return x0
400
+
401
+ def model_fn(self, x, t):
402
+ """
403
+ Convert the model to the noise prediction model or the data prediction model.
404
+ """
405
+ if self.predict_x0:
406
+ return self.data_prediction_fn(x, t)
407
+ else:
408
+ return self.noise_prediction_fn(x, t)
409
+
410
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
411
+ """Compute the intermediate time steps for sampling.
412
+
413
+ Args:
414
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
415
+ - 'logSNR': uniform logSNR for the time steps.
416
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
417
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
418
+ t_T: A `float`. The starting time of the sampling (default is T).
419
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
420
+ N: A `int`. The total number of the spacing of the time steps.
421
+ device: A torch device.
422
+ Returns:
423
+ A pytorch tensor of the time steps, with the shape (N + 1,).
424
+ """
425
+ if skip_type == 'logSNR':
426
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
427
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
428
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
429
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
430
+ elif skip_type == 'time_uniform':
431
+ return torch.linspace(t_T, t_0, N + 1).to(device)
432
+ elif skip_type == 'time_quadratic':
433
+ t_order = 2
434
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
435
+ return t
436
+ else:
437
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
438
+
439
+ def get_orders_for_singlestep_solver(self, steps, order):
440
+ """
441
+ Get the order of each step for sampling by the singlestep DPM-Solver.
442
+
443
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
444
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
445
+ - If order == 1:
446
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
447
+ - If order == 2:
448
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
449
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
450
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
451
+ - If order == 3:
452
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
453
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
454
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
455
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
456
+
457
+ ============================================
458
+ Args:
459
+ order: A `int`. The max order for the solver (2 or 3).
460
+ steps: A `int`. The total number of function evaluations (NFE).
461
+ Returns:
462
+ orders: A list of the solver order of each step.
463
+ """
464
+ if order == 3:
465
+ K = steps // 3 + 1
466
+ if steps % 3 == 0:
467
+ orders = [3,] * (K - 2) + [2, 1]
468
+ elif steps % 3 == 1:
469
+ orders = [3,] * (K - 1) + [1]
470
+ else:
471
+ orders = [3,] * (K - 1) + [2]
472
+ return orders
473
+ elif order == 2:
474
+ K = steps // 2
475
+ if steps % 2 == 0:
476
+ orders = [2,] * K
477
+ else:
478
+ orders = [2,] * K + [1]
479
+ return orders
480
+ elif order == 1:
481
+ return [1,] * steps
482
+ else:
483
+ raise ValueError("'order' must be '1' or '2' or '3'.")
484
+
485
+ def denoise_fn(self, x, s):
486
+ """
487
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
488
+ """
489
+ return self.data_prediction_fn(x, s)
490
+
491
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
492
+ """
493
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
494
+
495
+ Args:
496
+ x: A pytorch tensor. The initial value at time `s`.
497
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
498
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
499
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
500
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
501
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
502
+ Returns:
503
+ x_t: A pytorch tensor. The approximated solution at time `t`.
504
+ """
505
+ ns = self.noise_schedule
506
+ dims = x.dim()
507
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
508
+ h = lambda_t - lambda_s
509
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
510
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
511
+ alpha_t = torch.exp(log_alpha_t)
512
+
513
+ if self.predict_x0:
514
+ phi_1 = torch.expm1(-h)
515
+ if model_s is None:
516
+ model_s = self.model_fn(x, s)
517
+ x_t = (
518
+ expand_dims(sigma_t / sigma_s, dims) * x
519
+ - expand_dims(alpha_t * phi_1, dims) * model_s
520
+ )
521
+ if return_intermediate:
522
+ return x_t, {'model_s': model_s}
523
+ else:
524
+ return x_t
525
+ else:
526
+ phi_1 = torch.expm1(h)
527
+ if model_s is None:
528
+ model_s = self.model_fn(x, s)
529
+ x_t = (
530
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
531
+ - expand_dims(sigma_t * phi_1, dims) * model_s
532
+ )
533
+ if return_intermediate:
534
+ return x_t, {'model_s': model_s}
535
+ else:
536
+ return x_t
537
+
538
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'):
539
+ """
540
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
541
+
542
+ Args:
543
+ x: A pytorch tensor. The initial value at time `s`.
544
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
545
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
546
+ r1: A `float`. The hyperparameter of the second-order solver.
547
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
548
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
549
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
550
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
551
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
552
+ Returns:
553
+ x_t: A pytorch tensor. The approximated solution at time `t`.
554
+ """
555
+ if solver_type not in ['dpm_solver', 'taylor']:
556
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
557
+ if r1 is None:
558
+ r1 = 0.5
559
+ ns = self.noise_schedule
560
+ dims = x.dim()
561
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
562
+ h = lambda_t - lambda_s
563
+ lambda_s1 = lambda_s + r1 * h
564
+ s1 = ns.inverse_lambda(lambda_s1)
565
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
566
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
567
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
568
+
569
+ if self.predict_x0:
570
+ phi_11 = torch.expm1(-r1 * h)
571
+ phi_1 = torch.expm1(-h)
572
+
573
+ if model_s is None:
574
+ model_s = self.model_fn(x, s)
575
+ x_s1 = (
576
+ expand_dims(sigma_s1 / sigma_s, dims) * x
577
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
578
+ )
579
+ model_s1 = self.model_fn(x_s1, s1)
580
+ if solver_type == 'dpm_solver':
581
+ x_t = (
582
+ expand_dims(sigma_t / sigma_s, dims) * x
583
+ - expand_dims(alpha_t * phi_1, dims) * model_s
584
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
585
+ )
586
+ elif solver_type == 'taylor':
587
+ x_t = (
588
+ expand_dims(sigma_t / sigma_s, dims) * x
589
+ - expand_dims(alpha_t * phi_1, dims) * model_s
590
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s)
591
+ )
592
+ else:
593
+ phi_11 = torch.expm1(r1 * h)
594
+ phi_1 = torch.expm1(h)
595
+
596
+ if model_s is None:
597
+ model_s = self.model_fn(x, s)
598
+ x_s1 = (
599
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
600
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
601
+ )
602
+ model_s1 = self.model_fn(x_s1, s1)
603
+ if solver_type == 'dpm_solver':
604
+ x_t = (
605
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
606
+ - expand_dims(sigma_t * phi_1, dims) * model_s
607
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
608
+ )
609
+ elif solver_type == 'taylor':
610
+ x_t = (
611
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
612
+ - expand_dims(sigma_t * phi_1, dims) * model_s
613
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
614
+ )
615
+ if return_intermediate:
616
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
617
+ else:
618
+ return x_t
619
+
620
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'):
621
+ """
622
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
623
+
624
+ Args:
625
+ x: A pytorch tensor. The initial value at time `s`.
626
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
627
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
628
+ r1: A `float`. The hyperparameter of the third-order solver.
629
+ r2: A `float`. The hyperparameter of the third-order solver.
630
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
631
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
632
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
633
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
634
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
635
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
636
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
637
+ Returns:
638
+ x_t: A pytorch tensor. The approximated solution at time `t`.
639
+ """
640
+ if solver_type not in ['dpm_solver', 'taylor']:
641
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
642
+ if r1 is None:
643
+ r1 = 1. / 3.
644
+ if r2 is None:
645
+ r2 = 2. / 3.
646
+ ns = self.noise_schedule
647
+ dims = x.dim()
648
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
649
+ h = lambda_t - lambda_s
650
+ lambda_s1 = lambda_s + r1 * h
651
+ lambda_s2 = lambda_s + r2 * h
652
+ s1 = ns.inverse_lambda(lambda_s1)
653
+ s2 = ns.inverse_lambda(lambda_s2)
654
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
655
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
656
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
657
+
658
+ if self.predict_x0:
659
+ phi_11 = torch.expm1(-r1 * h)
660
+ phi_12 = torch.expm1(-r2 * h)
661
+ phi_1 = torch.expm1(-h)
662
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
663
+ phi_2 = phi_1 / h + 1.
664
+ phi_3 = phi_2 / h - 0.5
665
+
666
+ if model_s is None:
667
+ model_s = self.model_fn(x, s)
668
+ if model_s1 is None:
669
+ x_s1 = (
670
+ expand_dims(sigma_s1 / sigma_s, dims) * x
671
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
672
+ )
673
+ model_s1 = self.model_fn(x_s1, s1)
674
+ x_s2 = (
675
+ expand_dims(sigma_s2 / sigma_s, dims) * x
676
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
677
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
678
+ )
679
+ model_s2 = self.model_fn(x_s2, s2)
680
+ if solver_type == 'dpm_solver':
681
+ x_t = (
682
+ expand_dims(sigma_t / sigma_s, dims) * x
683
+ - expand_dims(alpha_t * phi_1, dims) * model_s
684
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
685
+ )
686
+ elif solver_type == 'taylor':
687
+ D1_0 = (1. / r1) * (model_s1 - model_s)
688
+ D1_1 = (1. / r2) * (model_s2 - model_s)
689
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
690
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
691
+ x_t = (
692
+ expand_dims(sigma_t / sigma_s, dims) * x
693
+ - expand_dims(alpha_t * phi_1, dims) * model_s
694
+ + expand_dims(alpha_t * phi_2, dims) * D1
695
+ - expand_dims(alpha_t * phi_3, dims) * D2
696
+ )
697
+ else:
698
+ phi_11 = torch.expm1(r1 * h)
699
+ phi_12 = torch.expm1(r2 * h)
700
+ phi_1 = torch.expm1(h)
701
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
702
+ phi_2 = phi_1 / h - 1.
703
+ phi_3 = phi_2 / h - 0.5
704
+
705
+ if model_s is None:
706
+ model_s = self.model_fn(x, s)
707
+ if model_s1 is None:
708
+ x_s1 = (
709
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
710
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
711
+ )
712
+ model_s1 = self.model_fn(x_s1, s1)
713
+ x_s2 = (
714
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
715
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
716
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
717
+ )
718
+ model_s2 = self.model_fn(x_s2, s2)
719
+ if solver_type == 'dpm_solver':
720
+ x_t = (
721
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
722
+ - expand_dims(sigma_t * phi_1, dims) * model_s
723
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
724
+ )
725
+ elif solver_type == 'taylor':
726
+ D1_0 = (1. / r1) * (model_s1 - model_s)
727
+ D1_1 = (1. / r2) * (model_s2 - model_s)
728
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
729
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
730
+ x_t = (
731
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
732
+ - expand_dims(sigma_t * phi_1, dims) * model_s
733
+ - expand_dims(sigma_t * phi_2, dims) * D1
734
+ - expand_dims(sigma_t * phi_3, dims) * D2
735
+ )
736
+
737
+ if return_intermediate:
738
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
739
+ else:
740
+ return x_t
741
+
742
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
743
+ """
744
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
745
+
746
+ Args:
747
+ x: A pytorch tensor. The initial value at time `s`.
748
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
749
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
750
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
751
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
752
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
753
+ Returns:
754
+ x_t: A pytorch tensor. The approximated solution at time `t`.
755
+ """
756
+ if solver_type not in ['dpm_solver', 'taylor']:
757
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
758
+ ns = self.noise_schedule
759
+ dims = x.dim()
760
+ model_prev_1, model_prev_0 = model_prev_list
761
+ t_prev_1, t_prev_0 = t_prev_list
762
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
763
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
764
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
765
+ alpha_t = torch.exp(log_alpha_t)
766
+
767
+ h_0 = lambda_prev_0 - lambda_prev_1
768
+ h = lambda_t - lambda_prev_0
769
+ r0 = h_0 / h
770
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
771
+ if self.predict_x0:
772
+ if solver_type == 'dpm_solver':
773
+ x_t = (
774
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
775
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
776
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
777
+ )
778
+ elif solver_type == 'taylor':
779
+ x_t = (
780
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
781
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
782
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
783
+ )
784
+ else:
785
+ if solver_type == 'dpm_solver':
786
+ x_t = (
787
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
788
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
789
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
790
+ )
791
+ elif solver_type == 'taylor':
792
+ x_t = (
793
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
794
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
795
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
796
+ )
797
+ return x_t
798
+
799
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
800
+ """
801
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
802
+
803
+ Args:
804
+ x: A pytorch tensor. The initial value at time `s`.
805
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
806
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
807
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
808
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
809
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
810
+ Returns:
811
+ x_t: A pytorch tensor. The approximated solution at time `t`.
812
+ """
813
+ ns = self.noise_schedule
814
+ dims = x.dim()
815
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
816
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
817
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
818
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
819
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
820
+ alpha_t = torch.exp(log_alpha_t)
821
+
822
+ h_1 = lambda_prev_1 - lambda_prev_2
823
+ h_0 = lambda_prev_0 - lambda_prev_1
824
+ h = lambda_t - lambda_prev_0
825
+ r0, r1 = h_0 / h, h_1 / h
826
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
827
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
828
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
829
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
830
+ if self.predict_x0:
831
+ x_t = (
832
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
833
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
834
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
835
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2
836
+ )
837
+ else:
838
+ x_t = (
839
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
840
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
841
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
842
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2
843
+ )
844
+ return x_t
845
+
846
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None):
847
+ """
848
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
849
+
850
+ Args:
851
+ x: A pytorch tensor. The initial value at time `s`.
852
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
853
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
854
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
855
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
856
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
857
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
858
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
859
+ r2: A `float`. The hyperparameter of the third-order solver.
860
+ Returns:
861
+ x_t: A pytorch tensor. The approximated solution at time `t`.
862
+ """
863
+ if order == 1:
864
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
865
+ elif order == 2:
866
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
867
+ elif order == 3:
868
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
869
+ else:
870
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
871
+
872
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
873
+ """
874
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
875
+
876
+ Args:
877
+ x: A pytorch tensor. The initial value at time `s`.
878
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
879
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
880
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
881
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
882
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
883
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
884
+ Returns:
885
+ x_t: A pytorch tensor. The approximated solution at time `t`.
886
+ """
887
+ if order == 1:
888
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
889
+ elif order == 2:
890
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
891
+ elif order == 3:
892
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
893
+ else:
894
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
895
+
896
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
897
+ """
898
+ The adaptive step size solver based on singlestep DPM-Solver.
899
+
900
+ Args:
901
+ x: A pytorch tensor. The initial value at time `t_T`.
902
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
903
+ t_T: A `float`. The starting time of the sampling (default is T).
904
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
905
+ h_init: A `float`. The initial step size (for logSNR).
906
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
907
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
908
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
909
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
910
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
911
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
912
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
913
+ Returns:
914
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
915
+
916
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
917
+ """
918
+ ns = self.noise_schedule
919
+ s = t_T * torch.ones((x.shape[0],)).to(x)
920
+ lambda_s = ns.marginal_lambda(s)
921
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
922
+ h = h_init * torch.ones_like(s).to(x)
923
+ x_prev = x
924
+ nfe = 0
925
+ if order == 2:
926
+ r1 = 0.5
927
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
928
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
929
+ elif order == 3:
930
+ r1, r2 = 1. / 3., 2. / 3.
931
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
932
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
933
+ else:
934
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
935
+ while torch.abs((s - t_0)).mean() > t_err:
936
+ t = ns.inverse_lambda(lambda_s + h)
937
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
938
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
939
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
940
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
941
+ E = norm_fn((x_higher - x_lower) / delta).max()
942
+ if torch.all(E <= 1.):
943
+ x = x_higher
944
+ s = t
945
+ x_prev = x_lower
946
+ lambda_s = ns.marginal_lambda(s)
947
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
948
+ nfe += order
949
+ print('adaptive solver nfe', nfe)
950
+ return x
951
+
952
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
953
+ method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078,
954
+ rtol=0.05,
955
+ ):
956
+ """
957
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
958
+
959
+ =====================================================
960
+
961
+ We support the following algorithms for both noise prediction model and data prediction model:
962
+ - 'singlestep':
963
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
964
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
965
+ The total number of function evaluations (NFE) == `steps`.
966
+ Given a fixed NFE == `steps`, the sampling procedure is:
967
+ - If `order` == 1:
968
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
969
+ - If `order` == 2:
970
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
971
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
972
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
973
+ - If `order` == 3:
974
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
975
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
976
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
977
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
978
+ - 'multistep':
979
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
980
+ We initialize the first `order` values by lower order multistep solvers.
981
+ Given a fixed NFE == `steps`, the sampling procedure is:
982
+ Denote K = steps.
983
+ - If `order` == 1:
984
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
985
+ - If `order` == 2:
986
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
987
+ - If `order` == 3:
988
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
989
+ - 'singlestep_fixed':
990
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
991
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
992
+ - 'adaptive':
993
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
994
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
995
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
996
+ (NFE) and the sample quality.
997
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
998
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
999
+
1000
+ =====================================================
1001
+
1002
+ Some advices for choosing the algorithm:
1003
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1004
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
1005
+ e.g.
1006
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
1007
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1008
+ skip_type='time_uniform', method='singlestep')
1009
+ - For **guided sampling with large guidance scale** by DPMs:
1010
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
1011
+ e.g.
1012
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
1013
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1014
+ skip_type='time_uniform', method='multistep')
1015
+
1016
+ We support three types of `skip_type`:
1017
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1018
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1019
+ - 'time_quadratic': quadratic time for the time steps.
1020
+
1021
+ =====================================================
1022
+ Args:
1023
+ x: A pytorch tensor. The initial value at time `t_start`
1024
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1025
+ steps: A `int`. The total number of function evaluations (NFE).
1026
+ t_start: A `float`. The starting time of the sampling.
1027
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1028
+ t_end: A `float`. The ending time of the sampling.
1029
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1030
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1031
+ For discrete-time DPMs:
1032
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1033
+ For continuous-time DPMs:
1034
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1035
+ order: A `int`. The order of DPM-Solver.
1036
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1037
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1038
+ denoise: A `bool`. Whether to denoise at the final step. Default is False.
1039
+ If `denoise` is True, the total NFE is (`steps` + 1).
1040
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1041
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1042
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1043
+ Returns:
1044
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1045
+
1046
+ """
1047
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1048
+ t_T = self.noise_schedule.T if t_start is None else t_start
1049
+ device = x.device
1050
+ if method == 'adaptive':
1051
+ with torch.no_grad():
1052
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
1053
+ elif method == 'multistep':
1054
+ assert steps >= order
1055
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1056
+ assert timesteps.shape[0] - 1 == steps
1057
+ with torch.no_grad():
1058
+ vec_t = timesteps[0].expand((x.shape[0]))
1059
+ model_prev_list = [self.model_fn(x, vec_t)]
1060
+ t_prev_list = [vec_t]
1061
+ # Init the first `order` values by lower order multistep DPM-Solver.
1062
+ for init_order in range(1, order):
1063
+ vec_t = timesteps[init_order].expand(x.shape[0])
1064
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
1065
+ model_prev_list.append(self.model_fn(x, vec_t))
1066
+ t_prev_list.append(vec_t)
1067
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1068
+ for step in range(order, steps + 1):
1069
+ vec_t = timesteps[step].expand(x.shape[0])
1070
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order, solver_type=solver_type)
1071
+ for i in range(order - 1):
1072
+ t_prev_list[i] = t_prev_list[i + 1]
1073
+ model_prev_list[i] = model_prev_list[i + 1]
1074
+ t_prev_list[-1] = vec_t
1075
+ # We do not need to evaluate the final model value.
1076
+ if step < steps:
1077
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1078
+ elif method in ['singlestep', 'singlestep_fixed']:
1079
+ if method == 'singlestep':
1080
+ orders = self.get_orders_for_singlestep_solver(steps=steps, order=order)
1081
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1082
+ elif method == 'singlestep_fixed':
1083
+ K = steps // order
1084
+ orders = [order,] * K
1085
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=(K * order), device=device)
1086
+ with torch.no_grad():
1087
+ i = 0
1088
+ for order in orders:
1089
+ vec_s, vec_t = timesteps[i].expand(x.shape[0]), timesteps[i + order].expand(x.shape[0])
1090
+ h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i])
1091
+ r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
1092
+ r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
1093
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1094
+ i += order
1095
+ if denoise:
1096
+ x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1097
+ return x
1098
+
1099
+
1100
+
1101
+ #############################################################
1102
+ # other utility functions
1103
+ #############################################################
1104
+
1105
+ def interpolate_fn(x, xp, yp):
1106
+ """
1107
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1108
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1109
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1110
+
1111
+ Args:
1112
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1113
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1114
+ yp: PyTorch tensor with shape [C, K].
1115
+ Returns:
1116
+ The function values f(x), with shape [N, C].
1117
+ """
1118
+ N, K = x.shape[0], xp.shape[1]
1119
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1120
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1121
+ x_idx = torch.argmin(x_indices, dim=2)
1122
+ cand_start_idx = x_idx - 1
1123
+ start_idx = torch.where(
1124
+ torch.eq(x_idx, 0),
1125
+ torch.tensor(1, device=x.device),
1126
+ torch.where(
1127
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1128
+ ),
1129
+ )
1130
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1131
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1132
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1133
+ start_idx2 = torch.where(
1134
+ torch.eq(x_idx, 0),
1135
+ torch.tensor(0, device=x.device),
1136
+ torch.where(
1137
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1138
+ ),
1139
+ )
1140
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1141
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1142
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1143
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1144
+ return cand
1145
+
1146
+
1147
+ def expand_dims(v, dims):
1148
+ """
1149
+ Expand the tensor `v` to the dim `dims`.
1150
+
1151
+ Args:
1152
+ `v`: a PyTorch tensor with shape [N].
1153
+ `dim`: a `int`.
1154
+ Returns:
1155
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1156
+ """
1157
+ return v[(...,) + (None,)*(dims - 1)]
ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6
+
7
+
8
+ class DPMSolverSampler(object):
9
+ def __init__(self, model, **kwargs):
10
+ super().__init__()
11
+ self.model = model
12
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
13
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
14
+
15
+ def register_buffer(self, name, attr):
16
+ if type(attr) == torch.Tensor:
17
+ if attr.device != torch.device("cuda"):
18
+ attr = attr.to(torch.device("cuda"))
19
+ setattr(self, name, attr)
20
+
21
+ @torch.no_grad()
22
+ def sample(self,
23
+ S,
24
+ batch_size,
25
+ shape,
26
+ conditioning=None,
27
+ callback=None,
28
+ normals_sequence=None,
29
+ img_callback=None,
30
+ quantize_x0=False,
31
+ eta=0.,
32
+ mask=None,
33
+ x0=None,
34
+ temperature=1.,
35
+ noise_dropout=0.,
36
+ score_corrector=None,
37
+ corrector_kwargs=None,
38
+ verbose=True,
39
+ x_T=None,
40
+ log_every_t=100,
41
+ unconditional_guidance_scale=1.,
42
+ unconditional_conditioning=None,
43
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
44
+ **kwargs
45
+ ):
46
+ if conditioning is not None:
47
+ if isinstance(conditioning, dict):
48
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
49
+ if cbs != batch_size:
50
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
51
+ else:
52
+ if conditioning.shape[0] != batch_size:
53
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
54
+
55
+ # sampling
56
+ C, H, W = shape
57
+ size = (batch_size, C, H, W)
58
+
59
+ # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
60
+
61
+ device = self.model.betas.device
62
+ if x_T is None:
63
+ img = torch.randn(size, device=device)
64
+ else:
65
+ img = x_T
66
+
67
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
68
+
69
+ model_fn = model_wrapper(
70
+ lambda x, t, c: self.model.apply_model(x, t, c),
71
+ ns,
72
+ model_type="noise",
73
+ guidance_type="classifier-free",
74
+ condition=conditioning,
75
+ unconditional_condition=unconditional_conditioning,
76
+ guidance_scale=unconditional_guidance_scale,
77
+ )
78
+
79
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
80
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2)
81
+
82
+ return x.to(device), None
scripts/txt2img.py CHANGED
@@ -17,6 +17,7 @@ from contextlib import contextmanager, nullcontext
17
  from ldm.util import instantiate_from_config
18
  from ldm.models.diffusion.ddim import DDIMSampler
19
  from ldm.models.diffusion.plms import PLMSSampler
 
20
 
21
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
22
  from transformers import AutoFeatureExtractor
@@ -132,6 +133,11 @@ def main():
132
  action='store_true',
133
  help="use plms sampling",
134
  )
 
 
 
 
 
135
  parser.add_argument(
136
  "--laion400m",
137
  action='store_true',
@@ -242,7 +248,9 @@ def main():
242
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
243
  model = model.to(device)
244
 
245
- if opt.plms:
 
 
246
  sampler = PLMSSampler(model)
247
  else:
248
  sampler = DDIMSampler(model)
 
17
  from ldm.util import instantiate_from_config
18
  from ldm.models.diffusion.ddim import DDIMSampler
19
  from ldm.models.diffusion.plms import PLMSSampler
20
+ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
21
 
22
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
23
  from transformers import AutoFeatureExtractor
 
133
  action='store_true',
134
  help="use plms sampling",
135
  )
136
+ parser.add_argument(
137
+ "--dpm_solver",
138
+ action='store_true',
139
+ help="use dpm_solver sampling",
140
+ )
141
  parser.add_argument(
142
  "--laion400m",
143
  action='store_true',
 
248
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
249
  model = model.to(device)
250
 
251
+ if opt.dpm_solver:
252
+ sampler = DPMSolverSampler(model)
253
+ elif opt.plms:
254
  sampler = PLMSSampler(model)
255
  else:
256
  sampler = DDIMSampler(model)