Spaces:
Runtime error
Runtime error
import torch | |
import math | |
import numpy as np | |
import enum | |
class GaussingDistribution:
    """Diagonal Gaussian described by a stacked mean / log-variance tensor.

    Args:
        parameters: Tensor that is split into two equal chunks along
            dimension 1; the first chunk is the mean and the second is the
            log-variance.
    """

    def __init__(self, parameters: torch.Tensor) -> None:
        self.mean, log_variance = torch.chunk(parameters, 2, dim=1)
        # Bound the log-variance so the exponential below stays finite.
        self.log_variance = torch.clamp(log_variance, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.log_variance)

    def sample(self) -> torch.Tensor:
        """Return one reparameterised draw ``mean + std * eps``, eps ~ N(0, I)."""
        # The noise must be standard normal: uniform noise on [0, 1) would
        # shift the sample mean and shrink its spread, so it would not be a
        # draw from N(mean, std**2).
        return self.mean + self.std * torch.randn_like(self.std)
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Element-wise KL divergence KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)).

    At least one of the four arguments must be a ``torch.Tensor``; scalar
    log-variances are converted to tensors matching that tensor's dtype and
    device. Shapes only need to be broadcastable.

    Returns:
        Tensor of KL values (in nats) with the broadcast shape of the inputs.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # torch.exp needs tensors, so promote Python scalars.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)
def approx_standard_normal_cdf(x):
    """Tanh-based approximation of the standard normal CDF, applied element-wise.

    Args:
        x: Input tensor.

    Returns:
        Tensor of the same shape with values in [0, 1].
    """
    cubic_term = x + 0.044715 * torch.pow(x, 3)
    return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * cubic_term))
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """Evaluate the unit-normal log-density of ``(x - means) / exp(log_scales)``.

    Args:
        x: Observed values.
        means: Gaussian means, broadcastable against ``x``.
        log_scales: Log standard deviations, broadcastable against ``x``.

    Returns:
        Tensor of log-densities (in nats).

    NOTE(review): no ``- log_scales`` term is applied, so the result is the
    density of the standardized residual rather than of ``x`` itself --
    confirm that callers expect this.
    """
    standardized = (x - means) * torch.exp(-log_scales)
    unit_normal = torch.distributions.Normal(torch.zeros_like(x), torch.ones_like(x))
    return unit_normal.log_prob(standardized)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-probability of ``x`` under a Gaussian discretized into bins.

    Each value is assigned the probability mass of the interval
    ``[x - 1/255, x + 1/255]``. Values below -0.999 get the whole lower tail
    and values above 0.999 the whole upper tail.
    (Presumably ``x`` holds 8-bit images rescaled to [-1, 1] -- confirm
    against the caller.)

    Args:
        x: Target values.
        means: Gaussian means, same shape as ``x``.
        log_scales: Log standard deviations, same shape as ``x``.

    Returns:
        Tensor shaped like ``x`` containing log-probabilities in nats.
    """
    assert x.shape == means.shape == log_scales.shape
    half_bin = 1.0 / 255.0
    offset = x - means
    precision = torch.exp(-log_scales)

    upper_cdf = approx_standard_normal_cdf(precision * (offset + half_bin))
    lower_cdf = approx_standard_normal_cdf(precision * (offset - half_bin))

    # Clamp before log so that vanishing probabilities do not produce -inf.
    log_upper = torch.log(upper_cdf.clamp(min=1e-12))
    log_lower_tail = torch.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_interior = torch.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    log_probs = torch.where(
        x < -0.999,
        log_upper,
        torch.where(x > 0.999, log_lower_tail, log_interior),
    )
    assert log_probs.shape == x.shape
    return log_probs
################# Gaussian diffusion ####################
def mean_flat(tensor):
    """Average ``tensor`` over every dimension except the first."""
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
class ModelMeanType(enum.Enum):
    """Which quantity the network output represents."""

    # The model predicts x_{t-1}.
    PREVIOUS_X = enum.auto()
    # The model predicts x_0.
    START_X = enum.auto()
    # The model predicts epsilon.
    EPSILON = enum.auto()
class ModelVarType(enum.Enum):
    """How the variance of the reverse-process step is obtained."""

    # Variance channels are produced by the model.
    LEARNED = enum.auto()
    # Fixed variance schedule (smaller option).
    FIXED_SMALL = enum.auto()
    # Fixed variance schedule (larger option).
    FIXED_LARGE = enum.auto()
    # Variance channels are produced by the model, within a range.
    LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
    """Training objective selector."""

    # Use raw MSE loss (and KL when learning variances).
    MSE = enum.auto()
    # Use raw MSE loss (with RESCALED_KL when learning variances).
    RESCALED_MSE = enum.auto()
    # Use the variational lower-bound.
    KL = enum.auto()
    # Like KL, but rescale to estimate the full VLB.
    RESCALED_KL = enum.auto()

    def is_vb(self):
        """Return True for the two variational-bound (KL) objectives."""
        return self in (LossType.KL, LossType.RESCALED_KL)
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): | |
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) | |
warmup_time = int(num_diffusion_timesteps * warmup_frac) | |
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) | |
return betas | |
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Create a beta schedule by name.

    Args:
        beta_schedule: One of ``"quad"``, ``"linear"``, ``"warmup10"``,
            ``"warmup50"``, ``"const"`` or ``"jsd"``.
        beta_start: First beta value (ignored by ``"const"`` and ``"jsd"``).
        beta_end: Last beta value (ignored by ``"jsd"``).
        num_diffusion_timesteps: Number of entries in the schedule.

    Returns:
        float64 array of shape ``(num_diffusion_timesteps,)``.

    Raises:
        NotImplementedError: If ``beta_schedule`` is not a known name.
    """
    steps = num_diffusion_timesteps
    # Each builder is lazy so only the requested schedule is computed.
    builders = {
        # Linear in sqrt(beta), then squared.
        "quad": lambda: np.linspace(
            beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64
        ) ** 2,
        "linear": lambda: np.linspace(beta_start, beta_end, steps, dtype=np.float64),
        "warmup10": lambda: _warmup_beta(beta_start, beta_end, steps, 0.1),
        "warmup50": lambda: _warmup_beta(beta_start, beta_end, steps, 0.5),
        "const": lambda: beta_end * np.ones(steps, dtype=np.float64),
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        "jsd": lambda: 1.0 / np.linspace(steps, 1, steps, dtype=np.float64),
    }
    build = builders.get(beta_schedule) if isinstance(beta_schedule, str) else None
    if build is None:
        raise NotImplementedError(beta_schedule)
    betas = build()
    assert betas.shape == (steps,)
    return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """Return a predefined beta schedule for the given name.

    Args:
        schedule_name: ``"linear"`` or ``"squaredcos_cap_v2"``.
        num_diffusion_timesteps: Number of diffusion steps.

    Returns:
        1-D NumPy array of betas.

    Raises:
        NotImplementedError: If ``schedule_name`` is not recognised.
    """
    if schedule_name == "squaredcos_cap_v2":

        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # The endpoints are scaled by 1000 / T, so they equal 1e-4 and 0.02 at T = 1000.
    rescale = 1000 / num_diffusion_timesteps
    return get_beta_schedule(
        "linear",
        beta_start=rescale * 0.0001,
        beta_end=rescale * 0.02,
        num_diffusion_timesteps=num_diffusion_timesteps,
    )
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Discretize a cumulative-product function ``alpha_bar(t)`` into betas.

    Args:
        num_diffusion_timesteps: Number of betas to produce.
        alpha_bar: Callable taking ``t`` in [0, 1].
        max_beta: Upper bound applied to every beta.

    Returns:
        NumPy array where entry ``i`` is
        ``min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta)``.
    """
    total = num_diffusion_timesteps
    return np.array(
        [
            min(1 - alpha_bar((step + 1) / total) / alpha_bar(step / total), max_beta)
            for step in range(total)
        ]
    )
class GaussianDiffusion:
    """Utilities for training and sampling Gaussian diffusion models.

    All schedule-derived coefficients are precomputed as float64 NumPy arrays
    in ``__init__`` and gathered per timestep with ``_extract_into_tensor``.

    Args:
        betas: 1-D sequence of per-timestep betas, each in (0, 1].
        model_mean_type: ``ModelMeanType`` saying what the model outputs.
        model_var_type: ``ModelVarType`` saying how the variance is obtained.
        loss_type: ``LossType`` selecting the training objective.
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # The posterior variance is 0 at the beginning of the diffusion chain,
        # so the log uses the value at index 1 in place of index 0.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        ) if len(self.posterior_variance) > 1 else np.array([])

        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of q(x_t | x_0).

        Args:
            x_start: Noiseless inputs.
            t: Tensor of timestep indices, one per batch element.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to timestep ``t``, i.e. sample from q(x_t | x_0).

        Args:
            x_start: Noiseless inputs.
            t: Tensor of timestep indices.
            noise: Optional noise with the shape of ``x_start``; drawn from a
                standard normal when omitted.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
        """Run the model and derive p(x_{t-1} | x_t) plus a prediction of x_0.

        Args:
            model: Callable invoked as ``model(x, t, **model_kwargs)``; may
                return a tensor or a ``(tensor, extra)`` tuple.
            x: Tensor at timestep ``t``; its first two dims are read as (B, C).
            t: Tensor of shape ``(B,)`` with timestep indices.
            clip_denoised: If True, clamp the predicted x_0 to [-1, 1].
            denoised_fn: Optional callable applied to the predicted x_0
                before clamping.
            model_kwargs: Optional extra keyword arguments for the model.

        Returns:
            Dict with keys ``"mean"``, ``"variance"``, ``"log_variance"``,
            ``"pred_xstart"`` and ``"extra"``.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = torch.split(model_output, C, dim=1)
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # A model value of -1 maps to min_log and +1 to max_log, with
            # linear interpolation (in log space) in between.
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = torch.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # The first entry is replaced by the posterior variance at index 1.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.START_X:
            pred_xstart = process_xstart(model_output)
        else:
            # Every other mean type is interpreted as an epsilon prediction here.
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Recover x_0 from ``x_t`` and a noise prediction ``eps``."""
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """Recover the noise implied by ``x_t`` and a prediction of x_0."""
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """Shift the predicted mean by ``variance * cond_fn(x, t, **model_kwargs)``.

        Args:
            cond_fn: Callable returning a gradient-like tensor shaped like ``x``.
            p_mean_var: Dict providing ``"mean"`` and ``"variance"``.
            x: Current sample.
            t: Tensor of timestep indices.
            model_kwargs: Optional keyword arguments forwarded to ``cond_fn``.

        Returns:
            The shifted mean as a float tensor.
        """
        # ``**None`` raises TypeError, so treat a missing dict as "no kwargs".
        if model_kwargs is None:
            model_kwargs = {}
        gradient = cond_fn(x, t, **model_kwargs)
        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """Apply ``cond_fn`` to the implied noise and rebuild x_0 and the mean.

        The implied epsilon is shifted by
        ``sqrt(1 - alpha_bar) * cond_fn(x, t, **model_kwargs)``; then
        ``"pred_xstart"`` and ``"mean"`` are recomputed from it.

        Returns:
            A shallow copy of ``p_mean_var`` with updated ``"pred_xstart"``
            and ``"mean"`` entries.
        """
        # ``**None`` raises TypeError, so treat a missing dict as "no kwargs".
        if model_kwargs is None:
            model_kwargs = {}
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """Sample x_{t-1} from the model at the given timestep.

        Args:
            model: The model to sample from.
            x: Current tensor x_t.
            t: Tensor of timestep indices.
            clip_denoised: If True, clamp the predicted x_0 to [-1, 1].
            denoised_fn: Optional callable applied to the predicted x_0.
            cond_fn: Optional guidance function, applied via ``condition_mean``.
            model_kwargs: Optional keyword arguments for model and ``cond_fn``.

        Returns:
            Dict with ``"sample"`` and ``"pred_xstart"``.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = torch.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """Run the full ancestral sampling chain and return the final sample.

        Arguments are forwarded unchanged to ``p_sample_loop_progressive``.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """Yield the ``p_sample`` output for every timestep, last to first.

        Args:
            model: The model to sample from.
            shape: Tuple or list giving the sample shape; ``shape[0]`` is
                used as the batch size.
            noise: Optional starting tensor; standard normal noise otherwise.
            device: Device for new tensors; defaults to the device of the
                model's first parameter.
            progress: If True, wrap the timestep loop in a tqdm progress bar.
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = torch.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = torch.tensor([i] * shape[0], device=device)
            with torch.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                yield out
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """Sample x_{t-1} with the DDIM update rule.

        ``eta`` scales the injected noise; with ``eta == 0`` the computed
        ``sigma`` is zero and the step is deterministic.

        Returns:
            Dict with ``"sample"`` and ``"pred_xstart"``.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Re-derive eps in case pred_xstart was clipped or modified above.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = torch.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
            + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """Compute x_{t+1} from x_t with the deterministic DDIM reverse step.

        Only ``eta == 0.0`` is accepted.

        Returns:
            Dict with ``"sample"`` and ``"pred_xstart"``.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Same expression as _predict_eps_from_xstart; reuse the helper.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = out["pred_xstart"] * torch.sqrt(alpha_bar_next) + torch.sqrt(1 - alpha_bar_next) * eps

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """Run the full DDIM sampling chain and return the final sample.

        Arguments are forwarded unchanged to ``ddim_sample_loop_progressive``.
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """Yield the ``ddim_sample`` output for every timestep, last to first.

        Takes the same arguments as ``p_sample_loop_progressive`` plus ``eta``.
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = torch.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = torch.tensor([i] * shape[0], device=device)
            with torch.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """Compute one term of the variational bound, in bits per dimension.

        Returns:
            Dict with ``"output"`` (tensor of shape ``(B,)``: decoder NLL
            where ``t == 0``, KL otherwise) and ``"pred_xstart"``.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        # Dividing by log(2) converts nats to bits.
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = torch.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """Compute training losses for a single batch of timesteps.

        Args:
            model: The model to evaluate.
            x_start: Noiseless inputs.
            t: Tensor of timestep indices.
            model_kwargs: Optional keyword arguments for the model.
            noise: Optional noise to diffuse ``x_start`` with.

        Returns:
            Dict containing ``"loss"`` and, for MSE objectives, ``"mse"``
            (plus ``"vb"`` when the variance is learned).

        Raises:
            NotImplementedError: If ``self.loss_type`` is not handled.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, t, **model_kwargs)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = torch.split(model_output, C, dim=1)
                # The mean half is detached so that the variational-bound
                # term only trains the variance channels.
                frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                terms["loss"] = terms["mse"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms

    def _prior_bpd(self, x_start):
        """KL between q(x_T | x_0) and a standard normal, in bits per dimension.

        Returns:
            Tensor with one value per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """Evaluate the variational bound and related metrics over all timesteps.

        Timesteps are visited from last to first, so column 0 of each
        per-timestep tensor corresponds to ``t = num_timesteps - 1``.

        Returns:
            Dict with ``"total_bpd"`` and ``"prior_bpd"`` (one value per batch
            element) and ``"vb"``, ``"xstart_mse"`` and ``"mse"`` (one column
            per timestep).
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = torch.tensor([t] * batch_size, device=device)
            noise = torch.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with torch.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = torch.stack(vb, dim=1)
        xstart_mse = torch.stack(xstart_mse, dim=1)
        mse = torch.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }
def _extract_into_tensor(arr, timesteps, broadcast_shape): | |
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() | |
while len(res.shape) < len(broadcast_shape): | |
res = res[..., None] | |
return res + torch.zeros(broadcast_shape, device=timesteps.device) | |
############################### Denoising Diffusion Probabilistic Model################################### | |
class DDPMSampler:
    """DDPM ancestral sampler using a schedule that is linear in sqrt(beta).

    Args:
        generator: ``torch.Generator`` used for every noise draw.
        num_training_steps: Number of timesteps in the training schedule.
        beta_start: First beta of the schedule.
        beta_end: Last beta of the schedule.
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Betas are linear in sqrt(beta) and then squared.
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        # torch.cumprod takes the axis as ``dim``.
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator

        self.num_train_timesteps = num_training_steps
        # Until set_inference_timesteps() is called, every training step is
        # visited, so the inference step count equals the training step count.
        self.num_inference_steps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        """Select ``num_inference_steps`` evenly spaced timesteps, in descending order."""
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep one inference stride earlier (may be negative)."""
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Return the posterior variance for the step ``timestep -> previous``.

        The result is clamped to at least 1e-20.
        """
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        # Before the first timestep the cumulative alpha product is 1.
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        variance = torch.clamp(variance, min=1e-20)
        return variance

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image.
        More noise (strength ~ 1) means that the output will be further from the input image.
        Less noise (strength ~ 0) means that the output will be closer to the input image.

        The first ``num_inference_steps - int(num_inference_steps * strength)``
        entries of ``self.timesteps`` are dropped.
        """
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """Take one reverse-diffusion step.

        Args:
            timestep: Current timestep.
            latents: Current noisy sample.
            model_output: Noise predicted by the model for ``latents``.

        Returns:
            The sample at the previous timestep; noise is added only when
            ``timestep > 0``.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # Estimate x_0 from the predicted noise.
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # Coefficients that weight x_0 and x_t in the posterior mean.
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        noise_term = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            noise_term = (self._get_variance(t) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + noise_term
        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Diffuse ``original_samples`` forward to ``timesteps``.

        Args:
            original_samples: Clean samples.
            timesteps: Timestep indices, one per sample.

        Returns:
            ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`` with
            noise drawn from ``self.generator``.
        """
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append singleton axes so the coefficient broadcasts per sample.
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples