komodel / modules /diffusion /karras /karras_diffusion.py
RMSnow's picture
add backend inference and inferface output
0883aa1
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Based on: https://github.com/crowsonkb/k-diffusion
"""
import random
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
# from piq import LPIPS
from utils.ssim import SSIM
from modules.diffusion.karras.random_utils import get_generator
def mean_flat(tensor):
    """Average a tensor over every dimension except the first (batch) one."""
    non_batch_axes = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_axes)
def append_dims(x, target_dims):
    """Pad `x` with trailing singleton axes until it has `target_dims` dimensions.

    Raises:
        ValueError: if `x` already has more than `target_dims` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Ellipsis keeps all existing axes; each None adds one trailing axis.
        trailing = (None,) * missing
        return x[(...,) + trailing]
    raise ValueError(
        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
    )
def append_zero(x):
    """Return `x` with a single zero (same dtype/device as `x`) appended."""
    zero_tail = x.new_zeros([1])
    return th.cat((x, zero_tail))
def get_weightings(weight_schedule, snrs, sigma_data):
    """Map signal-to-noise ratios to loss weights for the named schedule.

    Supported schedules: "snr", "snr+1", "karras", "truncated-snr", "uniform".

    Raises:
        NotImplementedError: for any other `weight_schedule`.
    """
    if weight_schedule == "snr":
        return snrs
    if weight_schedule == "snr+1":
        return snrs + 1
    if weight_schedule == "karras":
        return snrs + 1.0 / sigma_data**2
    if weight_schedule == "truncated-snr":
        return th.clamp(snrs, min=1.0)
    if weight_schedule == "uniform":
        return th.ones_like(snrs)
    raise NotImplementedError()
class KarrasDenoiser:
    """Preconditioned denoiser with diffusion and consistency training losses.

    The wrapped network is called as `model(c_in * x_t, rescaled_t, condition)`
    and its output is combined with the noisy input through the `c_skip` /
    `c_out` scalings computed from the noise level `sigma`.
    """

    def __init__(
        self,
        sigma_data: float = 0.5,
        sigma_max=80.0,
        sigma_min=0.002,
        rho=7.0,
        weight_schedule="karras",
        distillation=False,
        loss_norm="l2",
    ):
        """Store the noise-schedule and loss configuration.

        Args:
            sigma_data: data scale used in the scalings and "karras" weighting.
            sigma_max: largest noise level of the schedule.
            sigma_min: smallest noise level of the schedule.
            rho: exponent of the Karras noise schedule.
            weight_schedule: name understood by `get_weightings`.
            distillation: if True, `denoise` uses the boundary-condition
                scalings instead of the plain ones.
            loss_norm: "l1", "l2" or "ssim" (used by `consistency_losses`).
        """
        self.sigma_data = sigma_data
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.weight_schedule = weight_schedule
        self.distillation = distillation
        self.loss_norm = loss_norm
        # NOTE: an LPIPS loss option is not available here; only the SSIM
        # loss module is instantiated on demand.
        if loss_norm == "ssim":
            self.ssim_loss = SSIM()
        self.rho = rho
        self.num_timesteps = 40

    def get_snr(self, sigmas):
        """Return the signal-to-noise ratio `sigmas ** -2`."""
        return sigmas**-2

    def get_sigmas(self, sigmas):
        """Identity mapping: noise levels are already expressed as sigmas."""
        return sigmas

    def get_scalings(self, sigma):
        """Return `(c_skip, c_out, c_in)` for noise level `sigma`."""
        total_var = sigma**2 + self.sigma_data**2
        c_skip = self.sigma_data**2 / total_var
        c_out = sigma * self.sigma_data / total_var**0.5
        c_in = 1 / total_var**0.5
        return c_skip, c_out, c_in

    def get_scalings_for_boundary_condition(self, sigma):
        """Return `(c_skip, c_out, c_in)` shifted by `sigma_min`.

        At `sigma == sigma_min` this yields `c_skip == 1` and `c_out == 0`,
        so the denoised output equals the input.
        """
        shifted = sigma - self.sigma_min
        total_std = (sigma**2 + self.sigma_data**2) ** 0.5
        c_skip = self.sigma_data**2 / (shifted**2 + self.sigma_data**2)
        c_out = shifted * self.sigma_data / total_std
        c_in = 1 / total_std
        return c_skip, c_out, c_in

    def training_losses(self, model, x_start, sigmas, condition=None, noise=None):
        """Compute the weighted denoising MSE for one batch.

        Args:
            model: network called through `self.denoise`.
            x_start: clean samples.
            sigmas: per-sample noise levels (one value per batch element).
            condition: forwarded unchanged to the model.
            noise: optional pre-drawn noise; defaults to `randn_like(x_start)`.

        Returns:
            Dict with "mse" and "loss" (the same per-sample tensor).
        """
        if noise is None:
            noise = th.randn_like(x_start)

        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        _, denoised = self.denoise(model, x_t, sigmas, condition)

        snrs = self.get_snr(sigmas)
        weights = append_dims(
            get_weightings(self.weight_schedule, snrs, self.sigma_data), dims
        )

        terms = {}
        terms["mse"] = mean_flat(weights * (denoised - x_start) ** 2)
        terms["loss"] = terms["mse"]
        return terms

    def consistency_losses(
        self,
        model,
        x_start,
        num_scales,
        condition=None,
        target_model=None,
        teacher_model=None,
        teacher_diffusion=None,
        noise=None,
    ):
        """Compute the consistency loss between adjacent noise levels.

        The online `model` is evaluated at noise level `t`; `target_model` is
        evaluated (without gradients) at the next, smaller level `t2` on a
        sample obtained by one ODE step from `t` to `t2`. With a teacher the
        step is a Heun step using the teacher's denoised output; without one
        it is an Euler step that uses `x_start` as the denoised estimate.

        Args:
            model: online network being trained.
            x_start: clean samples.
            num_scales: number of discretisation levels of the schedule.
            condition: forwarded unchanged to every network call.
            target_model: required target network.
            teacher_model: optional teacher network; when given,
                `teacher_diffusion` must provide a compatible `denoise`.
            teacher_diffusion: denoiser wrapper used with `teacher_model`.
            noise: optional pre-drawn noise; defaults to `randn_like(x_start)`.

        Returns:
            Dict with a single key "loss".

        Raises:
            NotImplementedError: if `target_model` is None.
            ValueError: if `self.loss_norm` is not "l1", "l2" or "ssim".
        """
        if noise is None:
            noise = th.randn_like(x_start)

        dims = x_start.ndim
        # Identity checks (not truthiness): a module defining __len__ may be
        # falsy even though it is a valid network.
        has_teacher = teacher_model is not None

        def denoise_fn(x, t):
            return self.denoise(model, x, t, condition)[1]

        if target_model is None:
            raise NotImplementedError("Must have a target model")

        @th.no_grad()
        def target_denoise_fn(x, t):
            return self.denoise(target_model, x, t, condition)[1]

        @th.no_grad()
        def reference_denoise(x, t, x0):
            # Without a teacher the clean sample stands in for its prediction.
            if not has_teacher:
                return x0
            return teacher_diffusion.denoise(teacher_model, x, t, condition)[1]

        @th.no_grad()
        def heun_solver(samples, t, next_t, x0):
            x = samples
            d = (x - reference_denoise(x, t, x0)) / append_dims(t, dims)
            samples = x + d * append_dims(next_t - t, dims)
            next_d = (samples - reference_denoise(samples, next_t, x0)) / append_dims(
                next_t, dims
            )
            return x + (d + next_d) * append_dims((next_t - t) / 2, dims)

        @th.no_grad()
        def euler_solver(samples, t, next_t, x0):
            x = samples
            d = (x - reference_denoise(x, t, x0)) / append_dims(t, dims)
            return x + d * append_dims(next_t - t, dims)

        # Draw one schedule index per sample; index + 1 is the adjacent level.
        indices = th.randint(
            0, num_scales - 1, (x_start.shape[0],), device=x_start.device
        )
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        min_inv_rho = self.sigma_min ** (1 / self.rho)

        t = max_inv_rho + indices / (num_scales - 1) * (min_inv_rho - max_inv_rho)
        t = t**self.rho
        t2 = max_inv_rho + (indices + 1) / (num_scales - 1) * (
            min_inv_rho - max_inv_rho
        )
        t2 = t2**self.rho

        x_t = x_start + noise * append_dims(t, dims)

        # Restore the RNG state before the target pass so that it sees the
        # same random draws (e.g. dropout) as the online pass.
        dropout_state = th.get_rng_state()
        distiller = denoise_fn(x_t, t)

        if has_teacher:
            x_t2 = heun_solver(x_t, t, t2, x_start).detach()
        else:
            x_t2 = euler_solver(x_t, t, t2, x_start).detach()

        th.set_rng_state(dropout_state)
        distiller_target = target_denoise_fn(x_t2, t2).detach()

        snrs = self.get_snr(t)
        weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)

        if self.loss_norm == "l1":
            diffs = th.abs(distiller - distiller_target)
            loss = mean_flat(diffs) * weights
        elif self.loss_norm == "l2":
            # NOTE: this branch is a plain batch-mean MSE; `weights` is not
            # applied and the result is a scalar, unlike the other branches.
            loss = F.mse_loss(distiller, distiller_target)
        elif self.loss_norm == "ssim":
            loss = self.ssim_loss(distiller, distiller_target) * weights
        else:
            raise ValueError(f"Unknown loss norm {self.loss_norm}")

        terms = {}
        terms["loss"] = loss
        return terms

    def denoise(self, model, x_t, sigmas, condition):
        """Run the preconditioned network on noisy input `x_t`.

        Args:
            model: network called as `model(c_in * x_t, rescaled_t, condition)`.
            x_t: noisy samples.
            sigmas: per-sample noise levels.
            condition: forwarded unchanged to the model.

        Returns:
            `(model_output, denoised)` where
            `denoised = c_out * model_output + c_skip * x_t`.
        """
        if self.distillation:
            scalings = self.get_scalings_for_boundary_condition(sigmas)
        else:
            scalings = self.get_scalings(sigmas)
        c_skip, c_out, c_in = [append_dims(s, x_t.ndim) for s in scalings]

        # Log-sigma time embedding input; the tiny offset guards log(0).
        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
        model_output = model(c_in * x_t, rescaled_t, condition)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised
def karras_sample(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=True,
    callback=None,
    # model_kwargs=None,
    condition=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    generator=None,
    ts=None,
):
    """Draw samples of the given `shape` with the selected sampler.

    Initial noise is `generator.randn(*shape) * sigma_max`, drawn right after
    the global torch seed is reset to 42. The "progdist" sampler gets one
    extra schedule entry, "heun"/"dpm" receive the churn settings, and
    "multistep" receives `ts` together with `diffusion.rho` (not the `rho`
    argument). When `generator` is None, `get_generator("dummy")` is used.

    Returns:
        The sampler output clamped to [-1, 1].

    Raises:
        KeyError: if `sampler` is not a known sampler name.
    """
    if generator is None:
        generator = get_generator("dummy")

    num_sigmas = steps + 1 if sampler == "progdist" else steps
    sigmas = get_sigmas_karras(num_sigmas, sigma_min, sigma_max, rho, device=device)

    th.manual_seed(42)
    x_T = generator.randn(*shape, device=device) * sigma_max
    # Trailing axis so each sigmas[i] is a 1-element tensor.
    sigmas = sigmas.unsqueeze(-1)

    samplers = {
        "heun": sample_heun,
        "dpm": sample_dpm,
        "ancestral": sample_euler_ancestral,
        "onestep": sample_onestep,
        "progdist": sample_progdist,
        "euler": sample_euler,
        "multistep": stochastic_iterative_sampler,
    }
    sample_fn = samplers[sampler]

    sampler_args = {}
    if sampler in ("heun", "dpm"):
        sampler_args = {
            "s_churn": s_churn,
            "s_tmin": s_tmin,
            "s_tmax": s_tmax,
            "s_noise": s_noise,
        }
    elif sampler == "multistep":
        sampler_args = {
            "ts": ts,
            "t_min": sigma_min,
            "t_max": sigma_max,
            "rho": diffusion.rho,
            "steps": steps,
        }

    def denoiser(x_t, sigma):
        denoised = diffusion.denoise(model, x_t, sigma, condition)[1]
        return denoised.clamp(-1, 1) if clip_denoised else denoised

    x_0 = sample_fn(
        denoiser,
        x_T,
        sigmas,
        generator,
        progress=progress,
        callback=callback,
        **sampler_args,
    )
    return x_0.clamp(-1, 1)
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Build the Karras et al. (2022) noise schedule.

    Returns `n` sigmas decreasing from `sigma_max` to `sigma_min`, followed by
    a final zero, moved to `device`.
    """
    start = sigma_max ** (1 / rho)
    stop = sigma_min ** (1 / rho)
    fraction = th.linspace(0, 1, n)
    # Interpolate linearly in sigma**(1/rho) space, then map back.
    schedule = (start + fraction * (stop - start)) ** rho
    return append_zero(schedule).to(device)
def to_d(x, sigma, denoised):
    """Turn a denoiser output into the Karras ODE derivative `(x - denoised) / sigma`."""
    sigma_broadcast = append_dims(sigma, x.ndim)
    residual = x - denoised
    return residual / sigma_broadcast
def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral step from `sigma_from` to `sigma_to`.

    Returns:
        `(sigma_down, sigma_up)`: the noise level to step down to
        deterministically and the amount of fresh noise to add afterwards.
    """
    var_from = sigma_from**2
    var_to = sigma_to**2
    sigma_up = (var_to * (var_from - var_to) / var_from) ** 0.5
    sigma_down = (var_to - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, generator, progress=False, callback=None):
    """Ancestral sampler: an Euler step down, then fresh noise, per sigma pair."""
    batch_ones = x.new_ones([x.shape[0]])
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * batch_ones)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1])
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigmas[i],
                "sigma_hat": sigmas[i],
                "denoised": denoised,
            }
            callback(info)
        slope = to_d(x, sigma_cur, denoised)
        # Deterministic Euler move to sigma_down ...
        x = x + slope * (sigma_down - sigma_cur)
        # ... followed by re-noising with sigma_up.
        x = x + generator.randn_like(x) * sigma_up
    return x
@th.no_grad()
def sample_midpoint_ancestral(model, x, ts, generator, progress=False, callback=None):
    """Integrate `model` over `ts` with midpoint steps of size `1 / len(ts)`.

    `generator` is accepted for signature parity with the other samplers but
    is not used.
    """
    batch_ones = x.new_ones([x.shape[0]])
    step_size = 1 / len(ts)
    half_step = step_size / 2
    if progress:
        from tqdm.auto import tqdm

        ts = tqdm(ts)

    for tn in ts:
        dn = model(x, tn * batch_ones)
        x_mid = x + half_step * dn
        dn_2 = model(x_mid, (tn + half_step) * batch_ones)
        x = x + step_size * dn_2
        if callback is not None:
            callback({"x": x, "tn": tn, "dn": dn, "dn_2": dn_2})
    return x
@th.no_grad()
def sample_heun(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Heun (second-order) sampler, Algorithm 2 of Karras et al. (2022).

    A noise draw is taken from `generator` on every step, even when no churn
    is applied. The final step onto sigma == 0 falls back to a plain Euler
    step.
    """
    batch_ones = x.new_ones([x.shape[0]])
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        sigma_cur = sigmas[i]
        sigma_next = sigmas[i + 1]

        # Churn is only applied inside the [s_tmin, s_tmax] window.
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        else:
            gamma = 0.0

        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma_cur**2) ** 0.5

        denoised = denoiser(x, sigma_hat * batch_ones)
        slope = to_d(x, sigma_hat, denoised)
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigmas[i],
                "sigma_hat": sigma_hat,
                "denoised": denoised,
            }
            callback(info)

        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Last step: first-order Euler update.
            x = x + slope * dt
        else:
            # Heun correction: average the slopes at both ends of the step.
            x_pred = x + slope * dt
            denoised_next = denoiser(x_pred, sigma_next * batch_ones)
            slope_next = to_d(x_pred, sigma_next, denoised_next)
            x = x + (slope + slope_next) / 2 * dt
    return x
@th.no_grad()
def sample_euler(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
):
    """Deterministic first-order (Euler) sampler over the `sigmas` schedule.

    `generator` is accepted for signature parity with the other samplers but
    is not used.
    """
    batch_ones = x.new_ones([x.shape[0]])
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        sigma_cur = sigmas[i]
        denoised = denoiser(x, sigma_cur * batch_ones)
        slope = to_d(x, sigma_cur, denoised)
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigmas[i],
                "denoised": denoised,
            }
            callback(info)
        x = x + slope * (sigmas[i + 1] - sigma_cur)
    return x
@th.no_grad()
def sample_dpm(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Midpoint sampler inspired by DPM-Solver-2 and Algorithm 2 of Karras et al. (2022).

    A noise draw is taken from `generator` on every step, even when no churn
    is applied.
    """
    batch_ones = x.new_ones([x.shape[0]])
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        sigma_cur = sigmas[i]
        sigma_next = sigmas[i + 1]

        # Churn is only applied inside the [s_tmin, s_tmax] window.
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        else:
            gamma = 0.0

        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma_cur**2) ** 0.5

        denoised = denoiser(x, sigma_hat * batch_ones)
        slope = to_d(x, sigma_hat, denoised)
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigmas[i],
                "sigma_hat": sigma_hat,
                "denoised": denoised,
            }
            callback(info)

        # Midpoint chosen according to a rho=3 Karras schedule.
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        dt_to_mid = sigma_mid - sigma_hat
        dt_full = sigma_next - sigma_hat

        x_mid = x + slope * dt_to_mid
        denoised_mid = denoiser(x_mid, sigma_mid * batch_ones)
        slope_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + slope_mid * dt_full
    return x
@th.no_grad()
def sample_onestep(
    distiller,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    """Generate in one call: evaluate `distiller` on `x` at the first sigma.

    `generator`, `progress` and `callback` are accepted for signature parity
    with the other samplers but are not used.
    """
    first_sigma = sigmas[0]
    per_sample_sigma = first_sigma * x.new_ones([x.shape[0]])
    return distiller(x, per_sample_sigma)
@th.no_grad()
def stochastic_iterative_sampler(
    distiller,
    x,
    sigmas,
    generator,
    ts,
    progress=False,
    callback=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    """Multistep sampling: alternate denoising with `distiller` and re-noising.

    Consecutive entries of `ts` are schedule indices in `[0, steps - 1]`. For
    each pair the sample is denoised at the first index and re-noised to the
    level of the second (clipped to `[t_min, t_max]`). `sigmas`, `progress`
    and `callback` are accepted for signature parity but are not used.
    """
    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)
    batch_ones = x.new_ones([x.shape[0]])

    def noise_level(index):
        # Karras schedule evaluated at a (possibly fractional) index.
        return (t_max_rho + index / (steps - 1) * (t_min_rho - t_max_rho)) ** rho

    for k in range(len(ts) - 1):
        t = noise_level(ts[k])
        x0 = distiller(x, t * batch_ones)
        next_t = noise_level(ts[k + 1])
        next_t = np.clip(next_t, t_min, t_max)
        x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
    return x
@th.no_grad()
def sample_progdist(
    denoiser,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    """Euler sampler over `sigmas` with the final (zero) entry dropped.

    `generator` is accepted for signature parity with the other samplers but
    is not used.
    """
    batch_ones = x.new_ones([x.shape[0]])
    sigmas = sigmas[:-1]  # skip the zero sigma
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        sigma_cur = sigmas[i]
        denoised = denoiser(x, sigma_cur * batch_ones)
        slope = to_d(x, sigma_cur, denoised)
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigma_cur,
                "denoised": denoised,
            }
            callback(info)
        x = x + slope * (sigmas[i + 1] - sigma_cur)
    return x
# @th.no_grad()
# def iterative_colorization(
# distiller,
# images,
# x,
# ts,
# t_min=0.002,
# t_max=80.0,
# rho=7.0,
# steps=40,
# generator=None,
# ):
# def obtain_orthogonal_matrix():
# vector = np.asarray([0.2989, 0.5870, 0.1140])
# vector = vector / np.linalg.norm(vector)
# matrix = np.eye(3)
# matrix[:, 0] = vector
# matrix = np.linalg.qr(matrix)[0]
# if np.sum(matrix[:, 0]) < 0:
# matrix = -matrix
# return matrix
# Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
# mask = th.zeros(*x.shape[1:], device=dist_util.dev())
# mask[0, ...] = 1.0
# def replacement(x0, x1):
# x0 = th.einsum("bchw,cd->bdhw", x0, Q)
# x1 = th.einsum("bchw,cd->bdhw", x1, Q)
# x_mix = x0 * mask + x1 * (1.0 - mask)
# x_mix = th.einsum("bdhw,cd->bchw", x_mix, Q)
# return x_mix
# t_max_rho = t_max ** (1 / rho)
# t_min_rho = t_min ** (1 / rho)
# s_in = x.new_ones([x.shape[0]])
# images = replacement(images, th.zeros_like(images))
# for i in range(len(ts) - 1):
# t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
# x0 = distiller(x, t * s_in)
# x0 = th.clamp(x0, -1.0, 1.0)
# x0 = replacement(images, x0)
# next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
# next_t = np.clip(next_t, t_min, t_max)
# x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
# return x, images
# @th.no_grad()
# def iterative_inpainting(
# distiller,
# images,
# x,
# ts,
# t_min=0.002,
# t_max=80.0,
# rho=7.0,
# steps=40,
# generator=None,
# ):
# from PIL import Image, ImageDraw, ImageFont
# image_size = x.shape[-1]
# # create a blank image with a white background
# img = Image.new("RGB", (image_size, image_size), color="white")
# # get a drawing context for the image
# draw = ImageDraw.Draw(img)
# # load a font
# font = ImageFont.truetype("arial.ttf", 250)
# # draw the letter "C" in black
# draw.text((50, 0), "S", font=font, fill=(0, 0, 0))
# # convert the image to a numpy array
# img_np = np.array(img)
# img_np = img_np.transpose(2, 0, 1)
# img_th = th.from_numpy(img_np).to(dist_util.dev())
# mask = th.zeros(*x.shape, device=dist_util.dev())
# mask = mask.reshape(-1, 7, 3, image_size, image_size)
# mask[::2, :, img_th > 0.5] = 1.0
# mask[1::2, :, img_th < 0.5] = 1.0
# mask = mask.reshape(-1, 3, image_size, image_size)
# def replacement(x0, x1):
# x_mix = x0 * mask + x1 * (1 - mask)
# return x_mix
# t_max_rho = t_max ** (1 / rho)
# t_min_rho = t_min ** (1 / rho)
# s_in = x.new_ones([x.shape[0]])
# images = replacement(images, -th.ones_like(images))
# for i in range(len(ts) - 1):
# t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
# x0 = distiller(x, t * s_in)
# x0 = th.clamp(x0, -1.0, 1.0)
# x0 = replacement(images, x0)
# next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
# next_t = np.clip(next_t, t_min, t_max)
# x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
# return x, images
# @th.no_grad()
# def iterative_superres(
# distiller,
# images,
# x,
# ts,
# t_min=0.002,
# t_max=80.0,
# rho=7.0,
# steps=40,
# generator=None,
# ):
# patch_size = 8
# def obtain_orthogonal_matrix():
# vector = np.asarray([1] * patch_size**2)
# vector = vector / np.linalg.norm(vector)
# matrix = np.eye(patch_size**2)
# matrix[:, 0] = vector
# matrix = np.linalg.qr(matrix)[0]
# if np.sum(matrix[:, 0]) < 0:
# matrix = -matrix
# return matrix
# Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
# image_size = x.shape[-1]
# def replacement(x0, x1):
# x0_flatten = (
# x0.reshape(-1, 3, image_size, image_size)
# .reshape(
# -1,
# 3,
# image_size // patch_size,
# patch_size,
# image_size // patch_size,
# patch_size,
# )
# .permute(0, 1, 2, 4, 3, 5)
# .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
# )
# x1_flatten = (
# x1.reshape(-1, 3, image_size, image_size)
# .reshape(
# -1,
# 3,
# image_size // patch_size,
# patch_size,
# image_size // patch_size,
# patch_size,
# )
# .permute(0, 1, 2, 4, 3, 5)
# .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
# )
# x0 = th.einsum("bcnd,de->bcne", x0_flatten, Q)
# x1 = th.einsum("bcnd,de->bcne", x1_flatten, Q)
# x_mix = x0.new_zeros(x0.shape)
# x_mix[..., 0] = x0[..., 0]
# x_mix[..., 1:] = x1[..., 1:]
# x_mix = th.einsum("bcne,de->bcnd", x_mix, Q)
# x_mix = (
# x_mix.reshape(
# -1,
# 3,
# image_size // patch_size,
# image_size // patch_size,
# patch_size,
# patch_size,
# )
# .permute(0, 1, 2, 4, 3, 5)
# .reshape(-1, 3, image_size, image_size)
# )
# return x_mix
# def average_image_patches(x):
# x_flatten = (
# x.reshape(-1, 3, image_size, image_size)
# .reshape(
# -1,
# 3,
# image_size // patch_size,
# patch_size,
# image_size // patch_size,
# patch_size,
# )
# .permute(0, 1, 2, 4, 3, 5)
# .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
# )
# x_flatten[..., :] = x_flatten.mean(dim=-1, keepdim=True)
# return (
# x_flatten.reshape(
# -1,
# 3,
# image_size // patch_size,
# image_size // patch_size,
# patch_size,
# patch_size,
# )
# .permute(0, 1, 2, 4, 3, 5)
# .reshape(-1, 3, image_size, image_size)
# )
# t_max_rho = t_max ** (1 / rho)
# t_min_rho = t_min ** (1 / rho)
# s_in = x.new_ones([x.shape[0]])
# images = average_image_patches(images)
# for i in range(len(ts) - 1):
# t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
# x0 = distiller(x, t * s_in)
# x0 = th.clamp(x0, -1.0, 1.0)
# x0 = replacement(images, x0)
# next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
# next_t = np.clip(next_t, t_min, t_max)
# x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
# return x, images