from pathlib import Path
import json
from math import sqrt
from abc import ABCMeta, abstractmethod

import numpy as np
import torch

class ScoreAdapter(metaclass=ABCMeta):
    """Adapts a pretrained denoising diffusion model to a common score-query interface."""

    @abstractmethod
    def denoise(self, xs, σ, **kwargs):
        pass

    def score(self, xs, σ, **kwargs):
        # Tweedie's formula: ∇_x log p_σ(x) = (D(x, σ) - x) / σ²,
        # where D is the (learned) denoiser at noise level σ.
        Ds = self.denoise(xs, σ, **kwargs)
        grad_log_p_t = (Ds - xs) / (σ ** 2)
        return grad_log_p_t

    @abstractmethod
    def data_shape(self):
        # shape of one sample, e.g. (3, 256, 256) for 256×256 RGB images
        pass

    def samps_centered(self):
        # whether samples are centered in [-1, 1] rather than [0, 1]
        return True

    @property
    @abstractmethod
    def σ_max(self):
        pass

    @property
    @abstractmethod
    def σ_min(self):
        pass

    def cond_info(self, batch_size):
        # conditioning inputs for a batch, e.g. {"y": class_labels}; default: unconditional
        return {}

    @abstractmethod
    def unet_is_cond(self):
        # whether the underlying UNet takes the conditioning inputs directly
        return False

    @abstractmethod
    def use_cls_guidance(self):
        # whether to apply classifier guidance during sampling
        return False

    def classifier_grad(self, xs, σ, ys):
        raise NotImplementedError()

    @abstractmethod
    def snap_t_to_nearest_tick(self, t):
        # discrete-time models snap t onto their time grid and return (t, index);
        # continuous-time models can fall back on this default
        return t, None

    @property
    def device(self):
        return self._device

    def checkpoint_root(self):
        """the path at which the pretrained checkpoints are stored"""
        with Path(__file__).resolve().with_name("env.json").open("r") as f:
            root = json.load(f)
        return root

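
# A minimal illustrative adapter, not part of the original interface: for data drawn
# from N(0, σ_d²·I) the ideal denoiser is available in closed form,
# D(x, σ) = σ_d²·x / (σ_d² + σ²), so score() above recovers the exact Gaussian score
# ∇ log N(0, (σ_d² + σ²)·I). The name, σ range, and data shape are arbitrary choices
# made for this sketch.
class ToyGaussianAdapter(ScoreAdapter):
    def __init__(self, σ_data=0.5, device="cpu"):
        self.σ_data = σ_data
        self._device = torch.device(device)

    def denoise(self, xs, σ, **kwargs):
        return (self.σ_data ** 2) / (self.σ_data ** 2 + σ ** 2) * xs

    def data_shape(self):
        return (3, 32, 32)

    @property
    def σ_max(self):
        return 80.0

    @property
    def σ_min(self):
        return 0.002

    def unet_is_cond(self):
        return False

    def use_cls_guidance(self):
        return False

    def snap_t_to_nearest_tick(self, t):
        # continuous noise levels: nothing to snap to
        return t, None
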
def karras_t_schedule(ρ=7, N=10, σ_max=80, σ_min=0.002):
    # Karras et al. (2022) time discretization: interpolate linearly between
    # σ_max^(1/ρ) and σ_min^(1/ρ), then raise to the power ρ. With ρ > 1 the
    # ticks cluster near σ_min, so ts[0] == σ_max and ts[N-1] == σ_min.
    ts = []
    for i in range(N):
        t = (
            σ_max ** (1 / ρ) + (i / (N - 1)) * (σ_min ** (1 / ρ) - σ_max ** (1 / ρ))
        ) ** ρ
        ts.append(t)
    return ts

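
# Illustrative sanity check (not in the original file): the schedule's endpoints are
# the requested extremes, up to floating-point error.
#   ts = karras_t_schedule(ρ=7, N=10, σ_max=80, σ_min=0.002)
#   assert abs(ts[0] - 80) < 1e-6 and abs(ts[-1] - 0.002) < 1e-6
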
def power_schedule(σ_max, σ_min, num_stages):
    # geometric (log-linearly spaced) noise levels from σ_max down to σ_min
    σs = np.exp(np.linspace(np.log(σ_max), np.log(σ_min), num_stages))
    return σs

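
# For instance (illustrative): power_schedule(80, 0.002, 3) returns roughly
# array([80, 0.4, 0.002]), since the geometric midpoint is sqrt(80 * 0.002) = 0.4.
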
class Karras:
    # Stochastic sampler in the style of Karras et al. (2022): Euler steps along the
    # probability-flow ODE, with optional 2nd-order Heun correction and optional
    # Langevin-like churn that temporarily raises the noise level.

    @classmethod
    @torch.no_grad()
    def inference(
        cls, model, batch_size, num_t, *,
        σ_max=80, cls_scaling=1,
        init_xs=None, heun=True,
        langevin=False,
        S_churn=80, S_min=0.05, S_max=50, S_noise=1.003,
    ):
        σ_max = min(σ_max, model.σ_max)
        σ_min = model.σ_min
        ts = karras_t_schedule(ρ=7, N=num_t, σ_max=σ_max, σ_min=σ_min)
        assert len(ts) == num_t
        ts = [model.snap_t_to_nearest_tick(t)[0] for t in ts]
        ts.append(0)  # final step integrates down to σ = 0, i.e. clean samples
        σ_max = ts[0]  # snapping may have shifted the first tick

        cond_inputs = model.cond_info(batch_size)

        def compute_step(xs, σ):
            grad_log_p_t = model.score(
                xs, σ, **(cond_inputs if model.unet_is_cond() else {})
            )
            if model.use_cls_guidance():
                grad_cls = model.classifier_grad(xs, σ, cond_inputs["y"])
                grad_cls = grad_cls * cls_scaling
                grad_log_p_t += grad_cls
            # probability-flow ODE derivative: dx/dσ = -σ·∇ log p_σ(x)
            d_i = -1 * σ * grad_log_p_t
            return d_i

        if init_xs is not None:
            xs = init_xs.to(model.device)
        else:
            # draw from the prior x ~ N(0, σ_max²·I)
            xs = σ_max * torch.randn(
                batch_size, *model.data_shape(), device=model.device
            )

        yield xs

        for i in range(num_t):
            t_i = ts[i]

            if langevin and (S_min < t_i < S_max):
                xs, t_i = cls.noise_backward_in_time(
                    model, xs, t_i, S_noise, S_churn / num_t
                )

            Δt = ts[i + 1] - t_i

            # Euler step from t_i down to ts[i + 1]
            d_1 = compute_step(xs, σ=t_i)
            xs_1 = xs + Δt * d_1

            if (not heun) or (ts[i + 1] == 0):
                xs = xs_1
            else:
                # Heun correction: average the derivatives at both ends of the step
                d_2 = compute_step(xs_1, σ=ts[i + 1])
                xs = xs + Δt * (d_1 + d_2) / 2

            yield xs

    @staticmethod
    def noise_backward_in_time(model, xs, t_i, S_noise, S_churn_i):
        # "churn": raise the noise level from t_i to t_i_hat ≥ t_i by injecting fresh
        # Gaussian noise; the scale sqrt(t_i_hat² - t_i²) keeps the marginal variance
        # consistent with level t_i_hat (for S_noise ≈ 1)
        n = S_noise * torch.randn_like(xs)
        γ_i = min(sqrt(2) - 1, S_churn_i)
        t_i_hat = t_i * (1 + γ_i)
        t_i_hat = model.snap_t_to_nearest_tick(t_i_hat)[0]
        xs = xs + n * sqrt(t_i_hat ** 2 - t_i ** 2)
        return xs, t_i_hat

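
# Illustrative usage, assuming the hypothetical ToyGaussianAdapter sketched above:
# inference() is a generator that yields the initial noise and then the iterate after
# each of the num_t steps.
#
#     model = ToyGaussianAdapter()
#     *_, final_xs = Karras.inference(model, batch_size=4, num_t=20, heun=True)
#     print(final_xs.shape)  # torch.Size([4, 3, 32, 32])
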
def test():
    pass


if __name__ == "__main__":
    test()