File size: 11,216 Bytes
05b4fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 |
"""
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
"""
import abc
import warnings
import numpy as np
from sgmse.util.tensors import batch_broadcast
import torch
from sgmse.util.registry import Registry
# Registry that the SDE classes in this module add themselves to through the
# `@SDERegistry.register("<name>")` class decorator (see "ouve" and "ouvp" below).
# NOTE(review): presumably used elsewhere to look an SDE class up by that name — confirm in sgmse.util.registry.
SDERegistry = Registry("SDE")
def _expand_batch_dims(values, like):
    """Append singleton dims to a per-sample tensor so that it broadcasts against `like`.

    A tensor of shape ``(B,)`` becomes ``(B, 1, ..., 1)`` with ``like.ndim`` dimensions in total,
    which for 4-D `like` is exactly what ``values[:, None, None, None]`` produces. Python scalars and
    0-dim tensors already broadcast against anything and are returned unchanged.
    """
    if not torch.is_tensor(values) or values.ndim == 0:
        return values
    return values.reshape(*values.shape, *([1] * (like.ndim - values.ndim)))


class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the ``(drift, diffusion)`` coefficients of the SDE at state `x` and time `t`."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)
            y: extra argument forwarded to `self.sde` (e.g. the steady-state mean of the OU SDEs)
            stepsize: the step size dt, either a torch tensor or a plain Python number
        Returns:
            f, G
        """
        dt = stepsize
        drift, diffusion = self.sde(x, t, y)
        # torch.sqrt only accepts tensors; plain numbers get an ordinary float square root instead.
        sqrt_dt = torch.sqrt(dt) if torch.is_tensor(dt) else float(dt) ** 0.5
        f = drift * dt
        G = diffusion * sqrt_dt
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        Returns:
            An instance of a dynamically created subclass of ``type(oself)`` whose `sde` and
            `discretize` implement the reverse-time dynamics.
        """
        # Capture the forward process in closures; `oself` is used so the nested class can use `self`.
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                # Deliberately does not call the parent __init__: only N and the flow flag are needed.
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def rsde_parts(self, x, t, *args):
                """Return the reverse drift/diffusion together with their individual components."""
                sde_drift, sde_diffusion = sde_fn(x, t, *args)
                score = score_model(x, t, *args)
                # g(t)^2 * score; halved for the probability-flow ODE. The per-sample diffusion is
                # expanded to the score's rank so it broadcasts along the batch dimension only.
                score_drift = -_expand_batch_dims(sde_diffusion, score) ** 2 * score * (0.5 if self.probability_flow else 1.)
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                total_drift = sde_drift + score_drift
                return {
                    'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
                    'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
                }

            def discretize(self, x, t, y, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t, y, stepsize)
                rev_f = f - _expand_batch_dims(G, x) ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new instance of this SDE with the same parameters."""
        pass
@SDERegistry.register("ouve")
class OUVESDE(SDE):
    """Ornstein-Uhlenbeck SDE with a variance-exploding diffusion schedule (registered as "ouve")."""

    @staticmethod
    def add_argparse_args(parser):
        """Add the constructor arguments of this SDE to `parser` and return the parser."""
        parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 1000 by default.")
        parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
        parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
        parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
        return parser

    def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
        """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

            dx = theta (y - x) dt + sigma(t) dw

        with

            sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

        Args:
            theta: stiffness parameter.
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            N: number of discretization steps
            **ignored_kwargs: accepted and ignored.
        """
        super().__init__(N)  # the base class stores N as self.N
        self.theta = theta
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.logsig = np.log(self.sigma_max / self.sigma_min)

    def copy(self):
        """Return a new OUVESDE with the same parameters."""
        return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)

    @property
    def T(self):
        return 1

    def sde(self, x, t, y):
        """Return the drift theta * (y - x) and the diffusion sigma(t) at time `t`."""
        drift = self.theta * (y - x)
        # the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
        # standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
        # with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
        # unless this sqrt(2*logsig) factor is included.
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        diffusion = sigma * np.sqrt(2 * self.logsig)
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Marginal mean: exp(-theta t) * x0 + (1 - exp(-theta t)) * y."""
        theta = self.theta
        exp_interp = torch.exp(-theta * t)
        # Per-sample factor -> (B, 1, ..., 1) so it broadcasts over all non-batch dims of x0.
        exp_interp = exp_interp.reshape(-1, *([1] * (x0.ndim - 1)))
        return exp_interp * x0 + (1 - exp_interp) * y

    def alpha(self, t):
        """Return exp(-theta * t), the weight of x0 in the marginal mean."""
        return torch.exp(-self.theta * t)

    def _std(self, t):
        """Marginal standard deviation at time `t`."""
        # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
        sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
        # could maybe replace the two torch.exp(... * t) terms here by cached values **t
        return torch.sqrt(
            (
                sigma_min**2
                * torch.exp(-2 * theta * t)
                * (torch.exp(2 * (theta + logsig) * t) - 1)
                * logsig
            )
            /
            (theta + logsig)
        )

    def marginal_prob(self, x0, t, y):
        """Return the ``(mean, std)`` of the perturbation kernel p_t(x | x0, y)."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T = y + std(T) * z; the shape is always taken from `y`."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        # Per-sample std -> (B, 1, ..., 1) so it broadcasts over all non-batch dims of y.
        x_T = y + torch.randn_like(y) * std.reshape(-1, *([1] * (y.ndim - 1)))
        return x_T

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
@SDERegistry.register("ouvp")
class OUVPSDE(SDE):
    """Ornstein-Uhlenbeck SDE with a variance-preserving (linear beta) schedule, registered as "ouvp"."""
    # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!

    @staticmethod
    def add_argparse_args(parser):
        """Add the constructor arguments of this SDE to `parser` and return the parser."""
        parser.add_argument("--sde-n", type=int, default=1000,
            help="The number of timesteps in the SDE discretization. 1000 by default")
        parser.add_argument("--beta-min", type=float, required=True,
            help="The minimum beta to use.")
        parser.add_argument("--beta-max", type=float, required=True,
            help="The maximum beta to use.")
        parser.add_argument("--stiffness", type=float, default=1,
            help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
        return parser

    def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
        """
        !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!

        Construct an Ornstein-Uhlenbeck Variance Preserving SDE:

            dx = 1/2 * beta(t) * stiffness * (y - x) dt + sqrt(beta(t)) * dw

        with

            beta(t) = beta_min + t(beta_max - beta_min)

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

        Args:
            beta_min: smallest beta.
            beta_max: largest beta.
            stiffness: stiffness factor of the drift. 1 by default.
            N: number of discretization steps
            **ignored_kwargs: accepted and ignored.
        """
        super().__init__(N)  # the base class stores N as self.N
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.stiffness = stiffness

    def copy(self):
        """Return a new OUVPSDE with the same parameters."""
        return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)

    @property
    def T(self):
        return 1

    def _beta(self, t):
        """Linear schedule beta(t) = beta_min + t * (beta_max - beta_min)."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x, t, y):
        """Return the drift 0.5 * stiffness * beta(t) * (y - x) and the diffusion sqrt(beta(t))."""
        drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
        diffusion = torch.sqrt(self._beta(t))
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Marginal mean: y + exp(-0.5 * stiffness * B(t)) * (x0 - y), with B(t) = int_0^t beta(s) ds."""
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        x0y_fac = torch.exp(-0.25 * s * t * (t * (b1 - b0) + 2 * b0))
        # Per-sample factor -> (B, 1, ..., 1) so it broadcasts over all non-batch dims of x0.
        x0y_fac = x0y_fac.reshape(-1, *([1] * (x0.ndim - 1)))
        return y + x0y_fac * (x0 - y)

    def _std(self, t):
        """Marginal standard deviation sqrt((1 - exp(-stiffness * B(t))) / stiffness)."""
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        # Solving the variance ODE gives Var(t) = (1 - exp(-s * B(t))) / s with
        # B(t) = 0.5 * t * (t * (b1 - b0) + 2 * b0); the standard deviation is its square root.
        variance = (1 - torch.exp(-0.5 * s * t * (t * (b1 - b0) + 2 * b0))) / s
        return torch.sqrt(variance)

    def marginal_prob(self, x0, t, y):
        """Return the ``(mean, std)`` of the perturbation kernel p_t(x | x0, y)."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T = y + std(T) * z; the shape is always taken from `y`."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        # Per-sample std -> (B, 1, ..., 1) so it broadcasts over all non-batch dims of y.
        x_T = y + torch.randn_like(y) * std.reshape(-1, *([1] * (y.ndim - 1)))
        return x_T

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
|