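# NOTE(review): the lines below are non-code residue from a web-page copy of
# this file; kept verbatim as comments so the module can be parsed.
# 3v324v23's picture
# Add code
# a84a65c
# raw
# history blame
# 14.9 kB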
import torch as th
import numpy as np
import logging
import enum
from . import path
from .utils import EasyDict, log_state, mean_flat
from .integrators import ode, sde
class ModelType(enum.Enum):
    """
    Which type of output the model predicts.

    NOISE    -- the model predicts the added Gaussian noise (epsilon).
    SCORE    -- the model predicts the score, grad_x log p(x).
    VELOCITY -- the model predicts the velocity field v(x).
    """
    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts grad_x log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)
class PathType(enum.Enum):
    """
    Which type of interpolation path to use.

    Transport.__init__ maps each member to a path plan class:
    LINEAR -> path.ICPlan, GVP -> path.GVPCPlan, VP -> path.VPCPlan.
    """
    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()
class WeightType(enum.Enum):
    """
    Which type of loss weighting to use.

    NOTE(review): no code in this module reads WeightType; confirm whether it
    is consumed elsewhere or is a leftover option.
    """
    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()
class SNRType(enum.Enum):
    """
    Distribution used by Transport.sample to draw training times t.

    UNIFORM -- t uniform over the interval [t0, t1].
    LOGNORM -- logit-normal: sigmoid of a standard-normal draw, rescaled to
               the interval [t0, t1].
    """
    UNIFORM = enum.auto()
    LOGNORM = enum.auto()
class Transport:
    """Couples a model-prediction type with an interpolation path.

    Provides the training objective (``sample`` + ``training_losses``) and the
    drift / score callables that ``Sampler`` uses at inference time.
    """

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        train_eps,
        sample_eps,
        snr_type
    ):
        """Configure the transport.

        Args:
            model_type: ``ModelType`` member describing what the network predicts.
            path_type: ``PathType`` member selecting the path plan
                (LINEAR -> ``path.ICPlan``, GVP -> ``path.GVPCPlan``,
                VP -> ``path.VPCPlan``); an instance is stored as
                ``self.path_sampler``.
            loss_type: stored as ``self.loss_type``; not read by this class.
            train_eps: interval-truncation epsilon used by ``check_interval``
                when ``eval=False`` (training).
            sample_eps: interval-truncation epsilon used when ``eval=True``.
            snr_type: ``SNRType`` member choosing how training times are drawn.

        Raises:
            KeyError: if ``path_type`` is not one of the three handled members.
        """
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }
        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps
        self.snr_type = snr_type

    def prior_logp(self, z):
        """Log-density of a standard multivariate normal prior, per batch element.

        Args:
            z: batched tensor; dimension 0 is the batch, the remaining
                dimensions (N elements in total) form one sample.

        Returns:
            Tensor with one entry per batch element:
            ``-N/2 * log(2*pi) - sum(z_i ** 2) / 2``.
        """
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])  # number of elements in one sample
        _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        """Return the (t0, t1) time interval, truncated where the path needs it.

        Args:
            train_eps: truncation used when ``eval`` is False.
            sample_eps: truncation used when ``eval`` is True.
            diffusion_form: SDE diffusion form; only "SBDM" affects t0 here.
            sde: whether the interval is for an SDE sampler.
            reverse: if True, map both endpoints through ``t -> 1 - t``.
            eval: select ``sample_eps`` instead of ``train_eps``.
            last_step_size: when non-zero and ``sde`` is True, t1 becomes
                ``1 - last_step_size`` instead of ``1 - eps``, leaving room for
                the sampler's separate last step.

        Returns:
            Tuple ``(t0, t1)``.
        """
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        # Exact type checks (not isinstance) so each plan class gets its own rule.
        if (type(self.path_sampler) in [path.VPCPlan]):
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
                and (self.model_type != ModelType.VELOCITY or sde):
            # avoid numerical issue by taking a first semi-implicit step
            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
        if reverse:
            t0, t1 = 1 - t0, 1 - t1
        return t0, t1

    def sample(self, x1):
        """Draw noise x0 and training times t matching the data batch x1.

        Args:
            x1: data batch, either a tensor of shape [batch, *dim] or a
                list/tuple of per-sample tensors.

        Returns:
            ``(t, x0, x1)``:
              * ``t`` has one entry per sample, lies in the interval [t0, t1]
                from ``check_interval``, and is moved to the device/dtype of
                ``x1[0]``;
              * ``x0`` is standard Gaussian noise shaped like ``x1``;
              * ``x1`` is returned unchanged.

        Raises:
            ValueError: for an unrecognized ``snr_type``.
        """
        if isinstance(x1, (list, tuple)):
            x0 = [th.randn_like(img_start) for img_start in x1]
        else:
            x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        if self.snr_type == SNRType.UNIFORM:
            t = th.rand((len(x1),)) * (t1 - t0) + t0
        elif self.snr_type == SNRType.LOGNORM:
            # logit-normal: sigmoid of a standard-normal draw, rescaled
            u = th.normal(mean=0., std=1., size=(len(x1),))
            t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
        else:
            raise ValueError(f"Unknown snr type: {self.snr_type}")
        t = t.to(x1[0])
        return t, x0, x1

    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None
    ):
        """Loss for training the model.

        Args:
            model: backbone called as ``model(xt, t, **model_kwargs)``.
                Only velocity prediction is supported here.
            x1: data batch (tensor, or list/tuple of per-sample tensors).
            model_kwargs: additional keyword arguments for the model.

        Returns:
            Dict with ``'loss'`` (per-sample MSE between the model output and
            the path's target ``ut``, differentiable) and ``'task_loss'``
            (a detached copy of the same values).

        Raises:
            NotImplementedError: if ``self.model_type`` is not VELOCITY.
        """
        if self.model_type != ModelType.VELOCITY:
            # Fail before spending a sample + model forward on an unsupported setup.
            raise NotImplementedError(
                f"training_losses only supports ModelType.VELOCITY, got {self.model_type}"
            )
        if model_kwargs is None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)

        terms = {}
        if isinstance(x1, (list, tuple)):
            # Per-sample tensors may differ in shape, so reduce each separately.
            assert len(model_output) == len(ut) == len(x1)
            per_sample = []
            for out_i, ut_i, x1_i in zip(model_output, ut, x1):
                assert out_i.shape == ut_i.shape == x1_i.shape, (
                    f"{out_i.shape} {ut_i.shape} {x1_i.shape}"
                )
                per_sample.append(((ut_i - out_i) ** 2).mean())
            terms["task_loss"] = th.stack(per_sample, dim=0)
        else:
            terms['task_loss'] = mean_flat(((model_output - ut) ** 2))

        terms['loss'] = terms['task_loss']
        terms['task_loss'] = terms['task_loss'].clone().detach()
        return terms

    def get_drift(
        self
    ):
        """Return the probability-flow ODE drift as ``body_fn(x, t, model, **model_kwargs)``.

        The model output is converted according to ``self.model_type``:
          * SCORE: ``-drift_mean + drift_var * model_output``;
          * NOISE: the output is first turned into a score via ``-output / sigma_t``
            and then used as for SCORE;
          * VELOCITY (or any other value): the output is the drift itself.

        The returned callable asserts the drift has the same shape as ``x``.
        """
        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return (-drift_mean + drift_var * model_output)  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return (-drift_mean + drift_var * score)

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """Return ``score_fn(x, t, model, **kwargs)`` for x_t = alpha_t * x + sigma_t * eps.

        * NOISE: ``model_output / -sigma_t``;
        * SCORE: the model output as-is;
        * VELOCITY: converted by ``path_sampler.get_score_from_velocity``.

        Raises:
            NotImplementedError: if ``self.model_type`` is none of the above.
        """
        if self.model_type == ModelType.NOISE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
        elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
        else:
            raise NotImplementedError(f"Unsupported model type: {self.model_type}")
        return score_fn
class Sampler:
    """Sampler class for the transport model.

    Builds SDE / ODE sampling functions (and an ODE log-likelihood estimator)
    from the drift and score callables of a ``Transport``.
    """

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler supporting different sampling methods.

        Args:
            transport: a Transport object specifying the model prediction type
                and the interpolant (path) type. Its ``get_drift()`` and
                ``get_score()`` results are cached as ``self.drift`` / ``self.score``.
        """
        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        """Return ``(sde_drift, sde_diffusion)`` callables for the sampling SDE.

        ``sde_diffusion(x, t)`` is the path sampler's diffusion coefficient for
        the given form and norm. ``sde_drift(x, t, model, **kwargs)`` is the
        probability-flow drift plus ``diffusion * score``.
        """
        def diffusion_fn(x, t):
            return self.transport.path_sampler.compute_diffusion(
                x, t, form=diffusion_form, norm=diffusion_norm
            )

        def sde_drift(x, t, model, **kwargs):
            return self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)

        return sde_drift, diffusion_fn

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver.

        Args:
            sde_drift: drift of the sampling SDE (used by "Mean").
            last_step: None (identity), "Mean" (one deterministic step along
                ``sde_drift``), "Tweedie" (denoise with the score), or "Euler"
                (one step along the probability-flow drift).
            last_step_size: step size used by "Mean" and "Euler".

        Returns:
            Callable ``(x, t, model, **model_kwargs) -> x``.

        Raises:
            NotImplementedError: for any other ``last_step`` value.
        """
        if last_step is None:
            def last_step_fn(x, t, model, **model_kwargs):
                return x
        elif last_step == "Mean":
            def last_step_fn(x, t, model, **model_kwargs):
                return x + sde_drift(x, t, model, **model_kwargs) * last_step_size
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t
            sigma = self.transport.path_sampler.compute_sigma_t

            def last_step_fn(x, t, model, **model_kwargs):
                # Tweedie's formula: E[x_data | x_t] = (x_t + sigma_t^2 * score) / alpha_t.
                # The first [0] picks the coefficient from the returned pair (compute_sigma_t
                # returns a pair elsewhere in this file; compute_alpha_t presumably does too).
                # The second [0] takes the first batch entry; sample_sde passes a constant t.
                alpha_t = alpha(t)[0][0]
                sigma_t = sigma(t)[0][0]
                return x / alpha_t + (sigma_t ** 2) / alpha_t * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            def last_step_fn(x, t, model, **model_kwargs):
                return x + self.drift(x, t, model, **model_kwargs) * last_step_size
        else:
            raise NotImplementedError(f"Unknown last_step: {last_step!r}")
        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """Return a sampling function with the given SDE settings.

        Args:
            - sampling_method: type of sampler used in solving the SDE; default "Euler" (Euler-Maruyama)
            - diffusion_form: functional form of the diffusion coefficient; default "SBDM"
            - diffusion_norm: magnitude of the diffusion coefficient; default 1.0
            - last_step: type of the last step: None (identity), "Mean", "Tweedie" or "Euler";
              default "Mean"
            - last_step_size: size of the last step; default 0.04. Forced to 0.0 when
              ``last_step`` is None. When non-zero it also shortens the SDE interval to end at
              ``1 - last_step_size``.
            - num_steps: total integration steps of the SDE

        Returns:
            ``_sample(init, model, **model_kwargs)``: the list of states returned by the SDE
            integrator with the last-step output appended.
        """
        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            # The last step is evaluated at the end of the truncated interval.
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)
            # NOTE(review): relies on the sde integrator returning num_steps - 1 states.
            assert len(xs) == num_steps, "Samples does not match the number of steps"
            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
        time_shifting_factor=None,
    ):
        """Return a sampling function with the given ODE settings.

        Args:
            - sampling_method: type of sampler used in solving the ODE; default "dopri5"
            - num_steps:
                - fixed solver (Euler, Heun): the actual number of integration steps performed
                - adaptive solver (Dopri5): the number of datapoints saved during integration;
                  produced by interpolation
            - atol: absolute error tolerance for the solver
            - rtol: relative error tolerance for the solver
            - reverse: whether to solve the ODE in reverse (data to noise); default False
            - time_shifting_factor: forwarded unchanged to the ``ode`` integrator; see its
              documentation for semantics. Default None.

        Returns:
            The integrator's ``sample`` method.
        """
        if reverse:
            # Evaluate the drift at 1 - t so integration runs from data to noise.
            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
            time_shifting_factor=time_shifting_factor,
        )

        return _ode.sample

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):
        """Return a function computing log-likelihoods via the probability-flow ODE.

        Args:
            - sampling_method: type of sampler used in solving the ODE; default "dopri5"
            - num_steps:
                - fixed solver (Euler, Heun): the actual number of integration steps performed
                - adaptive solver (Dopri5): the number of datapoints saved during integration;
                  produced by interpolation
            - atol: absolute error tolerance for the solver
            - rtol: relative error tolerance for the solver

        Returns:
            ``_sample_fn(x, model, **model_kwargs) -> (logp, z)``: ``z`` is the final ODE
            state and ``logp = prior_logp(z) - delta_logp``.
        """
        def _likelihood_drift(x, t, model, **model_kwargs):
            # The integrator state is (x, accumulated log-prob); only x is needed here.
            x, _ = x
            # Rademacher probe (entries +/-1) for Hutchinson's trace estimator.
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            # Reversed time: the solver's t maps to 1 - t (data towards noise).
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                # eps^T (d drift / d x) eps estimates the divergence of the drift.
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                # NOTE(review): the drift is evaluated a second time here (two model forwards
                # per step); the first evaluation could be reused if the model is deterministic.
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            init_logp = th.zeros(x.size(0)).to(x)
            state = (x, init_logp)
            z, delta_logp = _ode.sample(state, model, **model_kwargs)
            # Keep only the final integration point.
            z, delta_logp = z[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(z)
            logp = prior_logp - delta_logp
            return logp, z

        return _sample_fn