Dionyssos committed on
Commit
d9ecbcf
·
1 Parent(s): 70184a3

cleanup diffusion

Browse files
Modules/diffusion/diffusion.py CHANGED
@@ -54,15 +54,6 @@ def get_default_model_kwargs():
54
  def get_default_sampling_kwargs():
55
  return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
56
 
57
-
58
- class AudioDiffusionModel(Model1d):
59
- def __init__(self, **kwargs):
60
- super().__init__(**{**get_default_model_kwargs(), **kwargs})
61
-
62
- def sample(self, *args, **kwargs):
63
- return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
64
-
65
-
66
  class AudioDiffusionConditional(Model1d):
67
  def __init__(
68
  self,
 
54
  def get_default_sampling_kwargs():
55
  return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
56
 
 
 
 
 
 
 
 
 
 
57
  class AudioDiffusionConditional(Model1d):
58
  def __init__(
59
  self,
Modules/diffusion/sampler.py CHANGED
@@ -1,27 +1,14 @@
1
  from math import atan, cos, pi, sin, sqrt
2
  from typing import Any, Callable, List, Optional, Tuple, Type
3
-
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
- from einops import rearrange, reduce
8
  from torch import Tensor
9
-
10
  from .utils import *
11
 
12
- """
13
- Diffusion Training
14
- """
15
-
16
- """ Distributions """
17
-
18
-
19
- class Distribution:
20
- def __call__(self, num_samples: int, device: torch.device):
21
- raise NotImplementedError()
22
-
23
 
24
- class LogNormalDistribution(Distribution):
25
  def __init__(self, mean: float, std: float):
26
  self.mean = mean
27
  self.std = std
@@ -33,55 +20,11 @@ class LogNormalDistribution(Distribution):
33
  return normal.exp()
34
 
35
 
36
- class UniformDistribution(Distribution):
37
  def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
38
  return torch.rand(num_samples, device=device)
39
 
40
 
41
- class VKDistribution(Distribution):
42
- def __init__(
43
- self,
44
- min_value: float = 0.0,
45
- max_value: float = float("inf"),
46
- sigma_data: float = 1.0,
47
- ):
48
- self.min_value = min_value
49
- self.max_value = max_value
50
- self.sigma_data = sigma_data
51
-
52
- def __call__(
53
- self, num_samples: int, device: torch.device = torch.device("cpu")
54
- ) -> Tensor:
55
- sigma_data = self.sigma_data
56
- min_cdf = atan(self.min_value / sigma_data) * 2 / pi
57
- max_cdf = atan(self.max_value / sigma_data) * 2 / pi
58
- u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
59
- return torch.tan(u * pi / 2) * sigma_data
60
-
61
-
62
- """ Diffusion Classes """
63
-
64
-
65
- def pad_dims(x: Tensor, ndim: int) -> Tensor:
66
- # Pads additional ndims to the right of the tensor
67
- return x.view(*x.shape, *((1,) * ndim))
68
-
69
-
70
- def clip(x: Tensor, dynamic_threshold: float = 0.0):
71
- if dynamic_threshold == 0.0:
72
- return x.clamp(-1.0, 1.0)
73
- else:
74
- # Dynamic thresholding
75
- # Find dynamic threshold quantile for each batch
76
- x_flat = rearrange(x, "b ... -> b (...)")
77
- scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
78
- # Clamp to a min of 1.0
79
- scale.clamp_(min=1.0)
80
- # Clamp all values and scale
81
- scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
82
- x = x.clamp(-scale, scale) / scale
83
- return x
84
-
85
 
86
  def to_batch(
87
  batch_size: int,
@@ -96,73 +39,7 @@ def to_batch(
96
  assert exists(xs)
97
  return xs
98
 
99
-
100
- class Diffusion(nn.Module):
101
-
102
- alias: str = ""
103
-
104
- """Base diffusion class"""
105
-
106
- def denoise_fn(
107
- self,
108
- x_noisy: Tensor,
109
- sigmas: Optional[Tensor] = None,
110
- sigma: Optional[float] = None,
111
- **kwargs,
112
- ) -> Tensor:
113
- raise NotImplementedError("Diffusion class missing denoise_fn")
114
-
115
- def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
116
- raise NotImplementedError("Diffusion class missing forward function")
117
-
118
-
119
- class VDiffusion(Diffusion):
120
-
121
- alias = "v"
122
-
123
- def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
124
- super().__init__()
125
- self.net = net
126
- self.sigma_distribution = sigma_distribution
127
-
128
- def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
129
- angle = sigmas * pi / 2
130
- alpha = torch.cos(angle)
131
- beta = torch.sin(angle)
132
- return alpha, beta
133
-
134
- def denoise_fn(
135
- self,
136
- x_noisy: Tensor,
137
- sigmas: Optional[Tensor] = None,
138
- sigma: Optional[float] = None,
139
- **kwargs,
140
- ) -> Tensor:
141
- batch_size, device = x_noisy.shape[0], x_noisy.device
142
- sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
143
- return self.net(x_noisy, sigmas, **kwargs)
144
-
145
- def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
146
- batch_size, device = x.shape[0], x.device
147
-
148
- # Sample amount of noise to add for each batch element
149
- sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
150
- sigmas_padded = rearrange(sigmas, "b -> b 1 1")
151
-
152
- # Get noise
153
- noise = default(noise, lambda: torch.randn_like(x))
154
-
155
- # Combine input and noise weighted by half-circle
156
- alpha, beta = self.get_alpha_beta(sigmas_padded)
157
- x_noisy = x * alpha + noise * beta
158
- x_target = noise * alpha - x * beta
159
-
160
- # Denoise and return loss
161
- x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
162
- return F.mse_loss(x_denoised, x_target)
163
-
164
-
165
- class KDiffusion(Diffusion):
166
  """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
167
 
168
  alias = "k"
@@ -171,7 +48,7 @@ class KDiffusion(Diffusion):
171
  self,
172
  net: nn.Module,
173
  *,
174
- sigma_distribution: Distribution,
175
  sigma_data: float, # data distribution standard deviation
176
  dynamic_threshold: float = 0.0,
177
  ):
@@ -196,127 +73,32 @@ class KDiffusion(Diffusion):
196
  sigmas: Optional[Tensor] = None,
197
  sigma: Optional[float] = None,
198
  **kwargs,
199
- ) -> Tensor:
 
200
  batch_size, device = x_noisy.shape[0], x_noisy.device
201
  sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
202
 
203
  # Predict network output and add skip connection
 
204
  c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
205
  x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
206
  x_denoised = c_skip * x_noisy + c_out * x_pred
207
 
208
  return x_denoised
209
 
210
- def loss_weight(self, sigmas: Tensor) -> Tensor:
211
- # Computes weight depending on data distribution
212
- return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
213
-
214
- def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
215
- batch_size, device = x.shape[0], x.device
216
- from einops import rearrange, reduce
217
-
218
- # Sample amount of noise to add for each batch element
219
- sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
220
- sigmas_padded = rearrange(sigmas, "b -> b 1 1")
221
-
222
- # Add noise to input
223
- noise = default(noise, lambda: torch.randn_like(x))
224
- x_noisy = x + sigmas_padded * noise
225
-
226
- # Compute denoised values
227
- x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
228
-
229
- # Compute weighted loss
230
- losses = F.mse_loss(x_denoised, x, reduction="none")
231
- losses = reduce(losses, "b ... -> b", "mean")
232
- losses = losses * self.loss_weight(sigmas)
233
- loss = losses.mean()
234
- return loss
235
-
236
-
237
- class VKDiffusion(Diffusion):
238
-
239
- alias = "vk"
240
-
241
- def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
242
- super().__init__()
243
- self.net = net
244
- self.sigma_distribution = sigma_distribution
245
-
246
- def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
247
- sigma_data = 1.0
248
- sigmas = rearrange(sigmas, "b -> b 1 1")
249
- c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
250
- c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
251
- c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
252
- return c_skip, c_out, c_in
253
-
254
- def sigma_to_t(self, sigmas: Tensor) -> Tensor:
255
- return sigmas.atan() / pi * 2
256
-
257
- def t_to_sigma(self, t: Tensor) -> Tensor:
258
- return (t * pi / 2).tan()
259
-
260
- def denoise_fn(
261
- self,
262
- x_noisy: Tensor,
263
- sigmas: Optional[Tensor] = None,
264
- sigma: Optional[float] = None,
265
- **kwargs,
266
- ) -> Tensor:
267
- batch_size, device = x_noisy.shape[0], x_noisy.device
268
- sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
269
-
270
- # Predict network output and add skip connection
271
- c_skip, c_out, c_in = self.get_scale_weights(sigmas)
272
- x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
273
- x_denoised = c_skip * x_noisy + c_out * x_pred
274
- return x_denoised
275
-
276
- def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
277
- batch_size, device = x.shape[0], x.device
278
 
279
- # Sample amount of noise to add for each batch element
280
- sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
281
- sigmas_padded = rearrange(sigmas, "b -> b 1 1")
282
 
283
- # Add noise to input
284
- noise = default(noise, lambda: torch.randn_like(x))
285
- x_noisy = x + sigmas_padded * noise
286
 
287
- # Compute model output
288
- c_skip, c_out, c_in = self.get_scale_weights(sigmas)
289
- x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
290
 
291
- # Compute v-objective target
292
- v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
293
 
294
- # Compute loss
295
- loss = F.mse_loss(x_pred, v_target)
296
- return loss
297
 
298
 
299
- """
300
- Diffusion Sampling
301
- """
302
 
303
- """ Schedules """
304
 
305
 
306
- class Schedule(nn.Module):
307
- """Interface used by different sampling schedules"""
308
-
309
- def forward(self, num_steps: int, device: torch.device) -> Tensor:
310
- raise NotImplementedError()
311
 
312
 
313
- class LinearSchedule(Schedule):
314
- def forward(self, num_steps: int, device: Any) -> Tensor:
315
- sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
316
- return sigmas
317
-
318
-
319
- class KarrasSchedule(Schedule):
320
  """https://arxiv.org/abs/2206.00364 equation 5"""
321
 
322
  def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
@@ -342,7 +124,7 @@ class KarrasSchedule(Schedule):
342
 
343
  class Sampler(nn.Module):
344
 
345
- diffusion_types: List[Type[Diffusion]] = []
346
 
347
  def forward(
348
  self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
@@ -361,127 +143,10 @@ class Sampler(nn.Module):
361
  raise NotImplementedError("Inpainting not available with current sampler")
362
 
363
 
364
- class VSampler(Sampler):
365
-
366
- diffusion_types = [VDiffusion]
367
-
368
- def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
369
- angle = sigma * pi / 2
370
- alpha = cos(angle)
371
- beta = sin(angle)
372
- return alpha, beta
373
-
374
- def forward(
375
- self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
376
- ) -> Tensor:
377
- x = sigmas[0] * noise
378
- alpha, beta = self.get_alpha_beta(sigmas[0].item())
379
-
380
- for i in range(num_steps - 1):
381
- is_last = i == num_steps - 1
382
-
383
- x_denoised = fn(x, sigma=sigmas[i])
384
- x_pred = x * alpha - x_denoised * beta
385
- x_eps = x * beta + x_denoised * alpha
386
-
387
- if not is_last:
388
- alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
389
- x = x_pred * alpha + x_eps * beta
390
-
391
- return x_pred
392
-
393
-
394
- class KarrasSampler(Sampler):
395
- """https://arxiv.org/abs/2206.00364 algorithm 1"""
396
-
397
- diffusion_types = [KDiffusion, VKDiffusion]
398
-
399
- def __init__(
400
- self,
401
- s_tmin: float = 0,
402
- s_tmax: float = float("inf"),
403
- s_churn: float = 0.0,
404
- s_noise: float = 1.0,
405
- ):
406
- super().__init__()
407
- self.s_tmin = s_tmin
408
- self.s_tmax = s_tmax
409
- self.s_noise = s_noise
410
- self.s_churn = s_churn
411
-
412
- def step(
413
- self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
414
- ) -> Tensor:
415
- """Algorithm 2 (step)"""
416
- # Select temporarily increased noise level
417
- sigma_hat = sigma + gamma * sigma
418
- # Add noise to move from sigma to sigma_hat
419
- epsilon = self.s_noise * torch.randn_like(x)
420
- x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
421
- # Evaluate ∂x/∂sigma at sigma_hat
422
- d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
423
- # Take euler step from sigma_hat to sigma_next
424
- x_next = x_hat + (sigma_next - sigma_hat) * d
425
- # Second order correction
426
- if sigma_next != 0:
427
- model_out_next = fn(x_next, sigma=sigma_next)
428
- d_prime = (x_next - model_out_next) / sigma_next
429
- x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
430
- return x_next
431
-
432
- def forward(
433
- self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
434
- ) -> Tensor:
435
- x = sigmas[0] * noise
436
- # Compute gammas
437
- gammas = torch.where(
438
- (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
439
- min(self.s_churn / num_steps, sqrt(2) - 1),
440
- 0.0,
441
- )
442
- # Denoise to sample
443
- for i in range(num_steps - 1):
444
- x = self.step(
445
- x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
446
- )
447
-
448
- return x
449
-
450
-
451
- class AEulerSampler(Sampler):
452
-
453
- diffusion_types = [KDiffusion, VKDiffusion]
454
-
455
- def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
456
- sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
457
- sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
458
- return sigma_up, sigma_down
459
-
460
- def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
461
- # Sigma steps
462
- sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
463
- # Derivative at sigma (∂x/∂sigma)
464
- d = (x - fn(x, sigma=sigma)) / sigma
465
- # Euler method
466
- x_next = x + d * (sigma_down - sigma)
467
- # Add randomness
468
- x_next = x_next + torch.randn_like(x) * sigma_up
469
- return x_next
470
-
471
- def forward(
472
- self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
473
- ) -> Tensor:
474
- x = sigmas[0] * noise
475
- # Denoise to sample
476
- for i in range(num_steps - 1):
477
- x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
478
- return x
479
-
480
-
481
  class ADPM2Sampler(Sampler):
482
  """https://www.desmos.com/calculator/jbxjlqd9mb"""
483
 
484
- diffusion_types = [KDiffusion, VKDiffusion]
485
 
486
  def __init__(self, rho: float = 1.0):
487
  super().__init__()
@@ -510,52 +175,23 @@ class ADPM2Sampler(Sampler):
510
  return x_next
511
 
512
  def forward(
513
- self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
514
- ) -> Tensor:
515
  x = sigmas[0] * noise
516
  # Denoise to sample
517
  for i in range(num_steps - 1):
518
  x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
519
  return x
520
 
521
- def inpaint(
522
- self,
523
- source: Tensor,
524
- mask: Tensor,
525
- fn: Callable,
526
- sigmas: Tensor,
527
- num_steps: int,
528
- num_resamples: int,
529
- ) -> Tensor:
530
- x = sigmas[0] * torch.randn_like(source)
531
-
532
- for i in range(num_steps - 1):
533
- # Noise source to current noise level
534
- source_noisy = source + sigmas[i] * torch.randn_like(source)
535
- for r in range(num_resamples):
536
- # Merge noisy source and current then denoise
537
- x = source_noisy * mask + x * ~mask
538
- x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
539
- # Renoise if not last resample step
540
- if r < num_resamples - 1:
541
- sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
542
- x = x + sigma * torch.randn_like(x)
543
-
544
- return source * mask + x * ~mask
545
-
546
-
547
- """ Main Classes """
548
-
549
-
550
  class DiffusionSampler(nn.Module):
551
  def __init__(
552
  self,
553
- diffusion: Diffusion,
554
  *,
555
- sampler: Sampler,
556
- sigma_schedule: Schedule,
557
- num_steps: Optional[int] = None,
558
- clamp: bool = True,
559
  ):
560
  super().__init__()
561
  self.denoise_fn = diffusion.denoise_fn
@@ -571,8 +207,8 @@ class DiffusionSampler(nn.Module):
571
  assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
572
 
573
  def forward(
574
- self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
575
- ) -> Tensor:
576
  device = noise.device
577
  num_steps = default(num_steps, self.num_steps) # type: ignore
578
  assert exists(num_steps), "Parameter `num_steps` must be provided"
@@ -583,109 +219,4 @@ class DiffusionSampler(nn.Module):
583
  # Sample using sampler
584
  x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
585
  x = x.clamp(-1.0, 1.0) if self.clamp else x
586
- return x
587
-
588
-
589
- class DiffusionInpainter(nn.Module):
590
- def __init__(
591
- self,
592
- diffusion: Diffusion,
593
- *,
594
- num_steps: int,
595
- num_resamples: int,
596
- sampler: Sampler,
597
- sigma_schedule: Schedule,
598
- ):
599
- super().__init__()
600
- self.denoise_fn = diffusion.denoise_fn
601
- self.num_steps = num_steps
602
- self.num_resamples = num_resamples
603
- self.inpaint_fn = sampler.inpaint
604
- self.sigma_schedule = sigma_schedule
605
-
606
- @torch.no_grad()
607
- def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
608
- x = self.inpaint_fn(
609
- source=inpaint,
610
- mask=inpaint_mask,
611
- fn=self.denoise_fn,
612
- sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
613
- num_steps=self.num_steps,
614
- num_resamples=self.num_resamples,
615
- )
616
- return x
617
-
618
-
619
- def sequential_mask(like: Tensor, start: int) -> Tensor:
620
- length, device = like.shape[2], like.device
621
- mask = torch.ones_like(like, dtype=torch.bool)
622
- mask[:, :, start:] = torch.zeros((length - start,), device=device)
623
- return mask
624
-
625
-
626
- class SpanBySpanComposer(nn.Module):
627
- def __init__(
628
- self,
629
- inpainter: DiffusionInpainter,
630
- *,
631
- num_spans: int,
632
- ):
633
- super().__init__()
634
- self.inpainter = inpainter
635
- self.num_spans = num_spans
636
-
637
- def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
638
- half_length = start.shape[2] // 2
639
-
640
- spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
641
- # Inpaint second half from first half
642
- inpaint = torch.zeros_like(start)
643
- inpaint[:, :, :half_length] = start[:, :, half_length:]
644
- inpaint_mask = sequential_mask(like=start, start=half_length)
645
-
646
- for i in range(self.num_spans):
647
- # Inpaint second half
648
- span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
649
- # Replace first half with generated second half
650
- second_half = span[:, :, half_length:]
651
- inpaint[:, :, :half_length] = second_half
652
- # Save generated span
653
- spans.append(second_half)
654
-
655
- return torch.cat(spans, dim=2)
656
-
657
-
658
- class XDiffusion(nn.Module):
659
- def __init__(self, type: str, net: nn.Module, **kwargs):
660
- super().__init__()
661
-
662
- diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
663
- aliases = [t.alias for t in diffusion_classes] # type: ignore
664
- message = f"type='{type}' must be one of {*aliases,}"
665
- assert type in aliases, message
666
- self.net = net
667
-
668
- for XDiffusion in diffusion_classes:
669
- if XDiffusion.alias == type: # type: ignore
670
- self.diffusion = XDiffusion(net=net, **kwargs)
671
-
672
- def forward(self, *args, **kwargs) -> Tensor:
673
- return self.diffusion(*args, **kwargs)
674
-
675
- def sample(
676
- self,
677
- noise: Tensor,
678
- num_steps: int,
679
- sigma_schedule: Schedule,
680
- sampler: Sampler,
681
- clamp: bool,
682
- **kwargs,
683
- ) -> Tensor:
684
- diffusion_sampler = DiffusionSampler(
685
- diffusion=self.diffusion,
686
- sampler=sampler,
687
- sigma_schedule=sigma_schedule,
688
- num_steps=num_steps,
689
- clamp=clamp,
690
- )
691
- return diffusion_sampler(noise, **kwargs)
 
1
  from math import atan, cos, pi, sin, sqrt
2
  from typing import Any, Callable, List, Optional, Tuple, Type
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
+ from einops import rearrange
7
  from torch import Tensor
 
8
  from .utils import *
9
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ class LogNormalDistribution():
12
  def __init__(self, mean: float, std: float):
13
  self.mean = mean
14
  self.std = std
 
20
  return normal.exp()
21
 
22
 
23
+ class UniformDistribution():
24
  def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
25
  return torch.rand(num_samples, device=device)
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def to_batch(
30
  batch_size: int,
 
39
  assert exists(xs)
40
  return xs
41
 
42
+ class KDiffusion(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
44
 
45
  alias = "k"
 
48
  self,
49
  net: nn.Module,
50
  *,
51
+ sigma_distribution,
52
  sigma_data: float, # data distribution standard deviation
53
  dynamic_threshold: float = 0.0,
54
  ):
 
73
  sigmas: Optional[Tensor] = None,
74
  sigma: Optional[float] = None,
75
  **kwargs,
76
+ ):
77
+ # raise ValueError
78
  batch_size, device = x_noisy.shape[0], x_noisy.device
79
  sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
80
 
81
  # Predict network output and add skip connection
82
+ # print('\n\n\n\n', kwargs, '\nKWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWAr\n\n\n\n') 'embedding tensor'
83
  c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
84
  x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
85
  x_denoised = c_skip * x_noisy + c_out * x_pred
86
 
87
  return x_denoised
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
 
 
 
90
 
 
 
 
91
 
 
 
 
92
 
 
 
93
 
 
 
 
94
 
95
 
 
 
 
96
 
 
97
 
98
 
 
 
 
 
 
99
 
100
 
101
+ class KarrasSchedule(nn.Module):
 
 
 
 
 
 
102
  """https://arxiv.org/abs/2206.00364 equation 5"""
103
 
104
  def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
 
124
 
125
  class Sampler(nn.Module):
126
 
127
+
128
 
129
  def forward(
130
  self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
 
143
  raise NotImplementedError("Inpainting not available with current sampler")
144
 
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  class ADPM2Sampler(Sampler):
147
  """https://www.desmos.com/calculator/jbxjlqd9mb"""
148
 
149
+ diffusion_types = [KDiffusion,] # VKDiffusion]
150
 
151
  def __init__(self, rho: float = 1.0):
152
  super().__init__()
 
175
  return x_next
176
 
177
  def forward(
178
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int):
179
+ # raise ValueError
180
  x = sigmas[0] * noise
181
  # Denoise to sample
182
  for i in range(num_steps - 1):
183
  x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
184
  return x
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  class DiffusionSampler(nn.Module):
187
  def __init__(
188
  self,
189
+ diffusion,
190
  *,
191
+ sampler,
192
+ sigma_schedule,
193
+ num_steps=None,
194
+ clamp=True,
195
  ):
196
  super().__init__()
197
  self.denoise_fn = diffusion.denoise_fn
 
207
  assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
208
 
209
  def forward(
210
+ self, noise, num_steps=None, **kwargs):
211
+ # raise ValueError
212
  device = noise.device
213
  num_steps = default(num_steps, self.num_steps) # type: ignore
214
  assert exists(num_steps), "Parameter `num_steps` must be provided"
 
219
  # Sample using sampler
220
  x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
221
  x = x.clamp(-1.0, 1.0) if self.clamp else x
222
+ return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Utils/text_utils.py CHANGED
@@ -84,7 +84,8 @@ def split_into_sentences(text):
84
  sentences = [s.strip() for s in sentences]
85
 
86
  # Split Very long sentences >500 phoneme - StyleTTS2 crashes
87
- sentences = [sub_sent+' ' for s in sentences for sub_sent in textwrap.wrap(s, 400, break_long_words=0)]
 
88
 
89
  if sentences and not sentences[-1]: sentences = sentences[:-1]
90
  return sentences
 
84
  sentences = [s.strip() for s in sentences]
85
 
86
  # Split Very long sentences >500 phoneme - StyleTTS2 crashes
87
+ # -- even 400 phonemes sometimes OOM in cuda:4
88
+ sentences = [sub_sent+' ' for s in sentences for sub_sent in textwrap.wrap(s, 300, break_long_words=0)]
89
 
90
  if sentences and not sentences[-1]: sentences = sentences[:-1]
91
  return sentences