import logging from dataclasses import dataclass from functools import partial from typing import Protocol, Union import matplotlib.pyplot as plt import numpy as np import scipy import torch import torch.nn.functional as F from torch import Tensor, nn from tqdm import trange from .wn import WN logger = logging.getLogger(__name__) class VelocityField(Protocol): def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ... class Solver: def __init__( self, method="midpoint", nfe=32, viz_name="solver", viz_every=100, mel_fn=None, time_mapping_divisor=4, verbose=False, ): self.configurate_(nfe=nfe, method=method) self.verbose = verbose self.viz_every = viz_every self.viz_name = viz_name self._camera = None self._mel_fn = mel_fn self._time_mapping = partial( self.exponential_decay_mapping, n=time_mapping_divisor ) def configurate_(self, nfe=None, method=None): if nfe is None: nfe = self.nfe if method is None: method = self.method if nfe == 1 and method in ("midpoint", "rk4"): logger.warning( f"1 NFE is not supported for {method}, using euler method instead." ) method = "euler" self.nfe = nfe self.method = method @property def time_mapping(self): return self._time_mapping @staticmethod def exponential_decay_mapping(t, n=4): """ Args: n: target step """ def h(t, a): return (a**t - 1) / (a - 1) # Solve h(1/n) = 0.5 a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0)) t = h(t, a=a) return t @torch.no_grad() def _maybe_camera_snap(self, *, ψt, t): camera = self._camera if camera is not None: if ψt.shape[1] == 1: # Waveform, b 1 t, plot every 100 samples plt.subplot(211) plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue") if self._mel_fn is not None: plt.subplot(212) mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0]) plt.imshow(mel, origin="lower", interpolation="none") elif ψt.shape[1] == 2: # Complex plt.subplot(121) plt.imshow( ψt.detach().cpu().numpy()[0, 0], origin="lower", interpolation="none", ) plt.subplot(122) plt.imshow( ψt.detach().cpu().numpy()[0, 1], origin="lower", interpolation="none", ) else: # Spectrogram, b c t plt.imshow( ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none" ) ax = plt.gca() ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center") camera.snap() @staticmethod def _euler_step(t, ψt, dt, f: VelocityField): return ψt + dt * f(t=t, ψt=ψt, dt=dt) @staticmethod def _midpoint_step(t, ψt, dt, f: VelocityField): return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt) @staticmethod def _rk4_step(t, ψt, dt, f: VelocityField): k1 = f(t=t, ψt=ψt, dt=dt) k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt) k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt) k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt) return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6 @property def _step(self): if self.method == "euler": return self._euler_step elif self.method == "midpoint": return self._midpoint_step elif self.method == "rk4": return self._rk4_step else: raise ValueError(f"Unknown method: {self.method}") def get_running_train_loop(self): try: # Lazy import from ...utils.train_loop import TrainLoop return TrainLoop.get_running_loop() except ImportError: return None @property def visualizing(self): loop = self.get_running_train_loop() if loop is None: return out_path = loop.make_current_step_viz_path(self.viz_name, ".gif") return loop.global_step % self.viz_every == 0 and not out_path.exists() def _reset_camera(self): try: from celluloid import Camera self._camera = Camera(plt.figure()) except: pass def _maybe_dump_camera(self): camera = self._camera loop = self.get_running_train_loop() if camera is not None and loop is not None: animation = camera.animate() out_path = loop.make_current_step_viz_path(self.viz_name, ".gif") out_path.parent.mkdir(exist_ok=True, parents=True) animation.save(out_path, writer="pillow", fps=4) plt.close() self._camera = None @property def n_steps(self): n = self.nfe if self.method == "euler": pass elif self.method == "midpoint": n //= 2 elif self.method == "rk4": n //= 4 else: raise ValueError(f"Unknown method: {self.method}") return n def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0): ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1)) if self.visualizing: self._reset_camera() if self.verbose: steps = trange(self.n_steps, desc="CFM inference") else: steps = range(self.n_steps) ψt = ψ0 for i in steps: dt = ts[i + 1] - ts[i] t = ts[i] self._maybe_camera_snap(ψt=ψt, t=t) ψt = self._step(t=t, ψt=ψt, dt=dt, f=f) self._maybe_camera_snap(ψt=ψt, t=ts[-1]) ψ1 = ψt del ψt self._maybe_dump_camera() return ψ1 def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0): return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1) class SinusodialTimeEmbedding(nn.Module): def __init__(self, d_embed): super().__init__() self.d_embed = d_embed assert d_embed % 2 == 0 def forward(self, t): t = t.unsqueeze(-1) # ... 1 p = torch.linspace(0, 4, self.d_embed // 2).to(t) while p.dim() < t.dim(): p = p.unsqueeze(0) # ... d/2 sin = torch.sin(t * 10**p) cos = torch.cos(t * 10**p) return torch.cat([sin, cos], dim=-1) @dataclass(eq=False) class CFM(nn.Module): """ This mixin is for general diffusion models. ψ0 stands for the gaussian noise, and ψ1 is the data point. Here we follow the CFM style: The generation process (reverse process) is from t=0 to t=1. The forward process is from t=1 to t=0. """ cond_dim: int output_dim: int time_emb_dim: int = 128 viz_name: str = "cfm" solver_nfe: int = 32 solver_method: str = "midpoint" time_mapping_divisor: int = 4 def __post_init__(self): super().__init__() self.solver = Solver( viz_name=self.viz_name, viz_every=1, nfe=self.solver_nfe, method=self.solver_method, time_mapping_divisor=self.time_mapping_divisor, ) self.emb = SinusodialTimeEmbedding(self.time_emb_dim) self.net = WN( input_dim=self.output_dim, output_dim=self.output_dim, local_dim=self.cond_dim, global_dim=self.time_emb_dim, ) def _perturb(self, ψ1: Tensor, t: Union[Tensor, None] = None): """ Perturb ψ1 to ψt. """ raise NotImplementedError def _sample_ψ0(self, x: Tensor): """ Args: x: (b c t), which implies the shape of ψ0 """ shape = list(x.shape) shape[1] = self.output_dim if self.training: g = None else: g = torch.Generator(device=x.device) g.manual_seed(0) # deterministic sampling during eval ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g) return ψ0 @property def sigma(self): return 1e-4 def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor): """ Eq (22) """ while t.dim() < ψ1.dim(): t = t.unsqueeze(-1) μ = t * ψ1 + (1 - t) * ψ0 return μ + torch.randn_like(μ) * self.sigma def _to_u(self, *, ψ1, ψ0: Tensor): """ Eq (21) """ return ψ1 - ψ0 def _to_v(self, *, ψt, x, t: Union[float, Tensor]): """ Args: ψt: (b c t) x: (b c t) t: (b) Returns: v: (b c t) """ if isinstance(t, (float, int)): t = torch.full(ψt.shape[:1], t).to(ψt) t = t.clamp(0, 1) # [0, 1) g = self.emb(t) # (b d) v = self.net(ψt, l=x, g=g) return v def compute_losses(self, x, y, ψ0) -> dict: """ Args: x: (b c t) y: (b c t) Returns: losses: dict """ t = torch.rand(len(x), device=x.device, dtype=x.dtype) t = self.solver.time_mapping(t) if ψ0 is None: ψ0 = self._sample_ψ0(x) ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0) v = self._to_v(ψt=ψt, t=t, x=x) u = self._to_u(ψ1=y, ψ0=ψ0) losses = dict(l1=F.l1_loss(v, u)) return losses @torch.inference_mode() def sample(self, x, ψ0=None, t0=0.0): """ Args: x: (b c t) Returns: y: (b ... t) """ if ψ0 is None: ψ0 = self._sample_ψ0(x) f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x) ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0) return ψ1 def forward( self, x: Tensor, y: Union[Tensor, None] = None, ψ0: Union[Tensor, None] = None, t0=0.0, ): if y is None: y = self.sample(x, ψ0=ψ0, t0=t0) else: self.losses = self.compute_losses(x, y, ψ0=ψ0) return y