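"""Adapters for OpenAI guided-diffusion checkpoints.

Registers the released ImageNet / LSUN model configurations, provides
sampling (img_gen) and super-resolution (img_upsamp) entry points, and wraps
the loaded U-Net plus optional noise-aware classifier as a ScoreAdapter.
"""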
from pathlib import Path
from math import sin, pi, sqrt
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict

from guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    NUM_CLASSES,
    create_classifier,
    classifier_defaults,
    sr_create_model_and_diffusion,
    sr_model_and_diffusion_defaults,
)

from adapt import ScoreAdapter
from my.registry import Registry

PRETRAINED_REGISTRY = Registry("pretrained")

device = torch.device("cuda")  # this module assumes a CUDA-capable machine
def load_ckpt(path, **kwargs):
    # with bf.BlobFile(path, "rb") as f:
    #     data = f.read()
    return torch.load(path, **kwargs)
def pick_out_cfgs(src, target_ks):
    return {k: src[k] for k in target_ks}
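
# Example (sketch): keep only the keys that create_model_and_diffusion() accepts,
# mirroring how this helper is used below:
#   model_cfgs = pick_out_cfgs(cfgs, model_and_diffusion_defaults().keys())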
@PRETRAINED_REGISTRY.register
def m_imgnet_64():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        dropout=0.1,
        image_size=64,
        learn_sigma=True,
        noise_schedule="cosine",
        num_channels=192,
        num_head_channels=64,
        num_res_blocks=3,
        resblock_updown=True,
        use_new_attention_order=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_depth=4,
        classifier_scale=1.0,
        model_path="models/64x64_diffusion.pt",
        classifier_path="models/64x64_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_128():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        image_size=128,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_heads=4,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_scale=0.5,
        model_path="models/128x128_diffusion.pt",
        classifier_path="models/128x128_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_256():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        image_size=256,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_scale=1.0,
        model_path="models/256x256_diffusion.pt",
        classifier_path="models/256x256_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_256_uncond():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=False,
        diffusion_steps=1000,
        image_size=256,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_scale=10.0,
        model_path="models/256x256_diffusion_uncond.pt",
        classifier_path="models/256x256_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_512():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        image_size=512,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=False,
        use_scale_shift_norm=True,
        classifier_scale=4.0,
        model_path="models/512x512_diffusion.pt",
        classifier_path="models/512x512_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_64_256(base_samples="64_samples.npz"):
    # upsampler configs take large_size/small_size instead of image_size
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        large_size=256,
        small_size=64,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=192,
        num_heads=4,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        model_path="models/64_256_upsampler.pt",
        base_samples=base_samples,
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_128_512(base_samples="128_samples.npz"):
    return dict(
        attention_resolutions="32,16",
        class_cond=True,
        diffusion_steps=1000,
        large_size=512,
        small_size=128,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=192,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        model_path="models/128_512_upsampler.pt",
        base_samples=base_samples,
    )


@PRETRAINED_REGISTRY.register
def m_lsun_256(category="bedroom"):
    return dict(
        attention_resolutions="32,16,8",
        class_cond=False,
        diffusion_steps=1000,
        dropout=0.1,
        image_size=256,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        model_path=f"models/lsun_{category}.pt",
    )
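
# Sketch: configs are retrieved from the registry by function name, as done in
# GuidedDDPM.__init__ below, e.g.
#   cfgs = PRETRAINED_REGISTRY.get("m_imgnet_64")()
# LSUN configs additionally take a category: m_lsun_256(category="bedroom").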
def img_gen(specific_cfgs, num_samples=16, batch_size=16, load_only=False, ckpt_root=Path("")):
    cfgs = EasyDict(
        clip_denoised=True,
        num_samples=num_samples,
        batch_size=batch_size,
        use_ddim=False,
        model_path="",
        classifier_path="",
        classifier_scale=1.0,
    )
    cfgs.update(model_and_diffusion_defaults())
    cfgs.update(classifier_defaults())
    cfgs.update(specific_cfgs)

    use_classifier_guidance = bool(cfgs.classifier_path)
    class_aware = cfgs.class_cond or use_classifier_guidance

    model, diffusion = create_model_and_diffusion(
        **pick_out_cfgs(cfgs, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        load_ckpt(str(ckpt_root / cfgs.model_path), map_location="cpu")
    )
    model.to(device)
    if cfgs.use_fp16:
        model.convert_to_fp16()
    model.eval()

    def model_fn(x, t, y=None):
        # drop the label when the U-Net itself is unconditional
        return model(x, t, y if cfgs.class_cond else None)

    classifier = None
    cond_fn = None
    if use_classifier_guidance:
        classifier = create_classifier(
            **pick_out_cfgs(cfgs, classifier_defaults().keys())
        )
        classifier.load_state_dict(
            load_ckpt(str(ckpt_root / cfgs.classifier_path), map_location="cpu")
        )
        classifier.to(device)
        if cfgs.classifier_use_fp16:
            classifier.convert_to_fp16()
        classifier.eval()

        def cond_fn(x, t, y=None):
            # gradient of log p(y | x_t) w.r.t. x_t, scaled by the guidance strength
            assert y is not None
            with torch.enable_grad():
                x_in = x.detach().requires_grad_(True)
                logits = classifier(x_in, t)
                log_probs = F.log_softmax(logits, dim=-1)
                selected = log_probs[range(len(logits)), y.view(-1)]
                return torch.autograd.grad(selected.sum(), x_in)[0] * cfgs.classifier_scale

    if load_only:
        return model, classifier

    all_images = []
    all_labels = []
    while len(all_images) * cfgs.batch_size < cfgs.num_samples:
        model_kwargs = {}
        if class_aware:
            classes = torch.randint(
                low=0, high=NUM_CLASSES, size=(cfgs.batch_size,), device=device
            )
            model_kwargs["y"] = classes
        sample_fn = (
            diffusion.p_sample_loop if not cfgs.use_ddim else diffusion.ddim_sample_loop
        )
        sample = sample_fn(
            model_fn,
            (cfgs.batch_size, 3, cfgs.image_size, cfgs.image_size),
            clip_denoised=cfgs.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            device=device,
            progress=True,
        )
        # map from [-1, 1] to uint8 [0, 255], NCHW -> NHWC
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()
        all_images.append(sample.cpu().numpy())
        if class_aware:
            all_labels.append(classes.cpu().numpy())

    arr = np.concatenate(all_images, axis=0)
    arr = arr[:cfgs.num_samples]
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = Path("./out") / f"samples_{shape_str}.npz"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if class_aware:
        all_labels = np.concatenate(all_labels, axis=0)
        all_labels = all_labels[:cfgs.num_samples]
        np.savez(out_path, arr, all_labels)
    else:
        np.savez(out_path, arr)
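
# Usage sketch (the ckpt_root path is illustrative; it must contain the
# models/ checkpoint files referenced by the config):
#   img_gen(m_imgnet_64(), num_samples=16, batch_size=16,
#           ckpt_root=Path("path/to/guided_ddpm"))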
def img_upsamp(specific_cfgs, num_samples=16, batch_size=16, load_only=False):
    """Note: unlike img_gen, no checkpoint root is threaded through here, so
    cfgs.model_path must resolve as given; this will break, but is an easy fix."""
    cfgs = EasyDict(
        clip_denoised=True,
        num_samples=num_samples,
        batch_size=batch_size,
        use_ddim=False,
        base_samples="",
        model_path="",
    )
    cfgs.update(sr_model_and_diffusion_defaults())
    cfgs.update(specific_cfgs)

    model, diffusion = sr_create_model_and_diffusion(
        **pick_out_cfgs(cfgs, sr_model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(load_ckpt(cfgs.model_path, map_location="cpu"))
    model.to(device)
    if cfgs.use_fp16:
        model.convert_to_fp16()
    model.eval()

    if load_only:
        return model

    data = load_low_res_samples(
        cfgs.base_samples, cfgs.batch_size, cfgs.class_cond
    )

    all_images = []
    while len(all_images) * cfgs.batch_size < cfgs.num_samples:
        model_kwargs = next(data)
        model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
        samples = diffusion.p_sample_loop(
            model,
            (cfgs.batch_size, 3, cfgs.large_size, cfgs.large_size),
            clip_denoised=cfgs.clip_denoised,
            model_kwargs=model_kwargs,
            progress=True,
        )
        # map from [-1, 1] to uint8 [0, 255], NCHW -> NHWC
        samples = ((samples + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        samples = samples.permute(0, 2, 3, 1)
        samples = samples.contiguous()
        all_images.append(samples.cpu().numpy())

    arr = np.concatenate(all_images, axis=0)
    arr = arr[:cfgs.num_samples]
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = Path("./out") / f"samples_{shape_str}.npz"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(out_path, arr)
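
# Usage sketch: upsample previously generated 64x64 samples to 256x256
# (the .npz filename is illustrative; point base_samples at img_gen's output):
#   img_upsamp(m_imgnet_64_256(base_samples="out/samples_16x64x64x3.npz"))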
def load_low_res_samples(base_samples, batch_size, class_cond):
    """Infinite generator over the low-resolution base samples, yielding
    batches of {"low_res": float NCHW tensor in [-1, 1]}, plus {"y": labels}
    when class_cond is set."""
    obj = np.load(base_samples)
    image_arr = obj["arr_0"]
    if class_cond:
        label_arr = obj["arr_1"]

    buffer = []
    label_buffer = []
    while True:
        for i in range(len(image_arr)):
            buffer.append(image_arr[i])
            if class_cond:
                label_buffer.append(label_arr[i])
            if len(buffer) == batch_size:
                batch = torch.from_numpy(np.stack(buffer)).float()
                batch = batch / 127.5 - 1.0  # uint8 [0, 255] -> [-1, 1]
                batch = batch.permute(0, 3, 1, 2)  # NHWC -> NCHW
                res = {}
                res["low_res"] = batch
                if class_cond:
                    res["y"] = torch.from_numpy(np.stack(label_buffer))
                yield res
                buffer, label_buffer = [], []
def class_cond_info(imgnet_cat):

    def rand_cond_fn(batch_size):
        cats = torch.randint(
            low=0, high=NUM_CLASSES, size=(batch_size,), device=device
        )
        return {"y": cats}

    def class_specific_cond(batch_size):
        cats = torch.tensor([imgnet_cat, ] * batch_size, device=device)
        return {"y": cats}

    # -1 means: draw a random ImageNet class for each sample
    if imgnet_cat == -1:
        return rand_cond_fn
    else:
        return class_specific_cond
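
# Sketch: class_cond_info(-1)(4) draws 4 random ImageNet labels, while
# class_cond_info(207)(4) pins every sample to class 207 (a dog class;
# the specific id here is illustrative).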
def _sqrt(x):
    if isinstance(x, float):
        return sqrt(x)
    else:
        assert isinstance(x, torch.Tensor)
        return torch.sqrt(x)
class GuidedDDPM(ScoreAdapter):
    def __init__(self, model, lsun_cat, imgnet_cat):
        print(PRETRAINED_REGISTRY)
        cfgs = PRETRAINED_REGISTRY.get(model)(
            **({"category": lsun_cat} if model.startswith("m_lsun") else {})
        )

        self.unet, self.classifier = img_gen(
            cfgs, load_only=True, ckpt_root=self.checkpoint_root() / "guided_ddpm"
        )

        H, W = cfgs['image_size'], cfgs['image_size']
        self._data_shape = (3, H, W)

        if cfgs['class_cond'] or (self.classifier is not None):
            cond_func = class_cond_info(imgnet_cat)
        else:
            cond_func = lambda *args, **kwargs: {}
        self.cond_func = cond_func

        self._unet_is_cond = bool(cfgs['class_cond'])

        noise_schedule = cfgs['noise_schedule']
        assert noise_schedule in ("linear", "cosine")
        self.M = 1000
        if noise_schedule == "linear":
            self.us = self.linear_us(self.M)
            self._σ_min = 0.01
        else:
            self.us = self.cosine_us(self.M)
            self._σ_min = 0.0064
        self.noise_schedule = noise_schedule

        self._device = next(self.unet.parameters()).device

    def data_shape(self):
        return self._data_shape

    @property
    def σ_max(self):
        return self.us[0]

    @property
    def σ_min(self):
        return self.us[-1]

    def denoise(self, xs, σ, **model_kwargs):
        N = xs.shape[0]
        cond_t, σ = self.time_cond_vec(N, σ)
        # the DDPM U-Net expects the variance-preserving input x / sqrt(1 + σ²)
        output = self.unet(
            xs / _sqrt(1 + σ**2), cond_t, **model_kwargs
        )
        # learn_sigma models predict [ε, Σ]; not using the var pred
        n_hat = torch.split(output, xs.shape[1], dim=1)[0]
        Ds = xs - σ * n_hat
        return Ds

    def cond_info(self, batch_size):
        return self.cond_func(batch_size)

    def unet_is_cond(self):
        return self._unet_is_cond

    def use_cls_guidance(self):
        return (self.classifier is not None)

    def classifier_grad(self, xs, σ, ys):
        N = xs.shape[0]
        cond_t, σ = self.time_cond_vec(N, σ)
        with torch.enable_grad():
            x_in = xs.detach().requires_grad_(True)
            logits = self.classifier(x_in, cond_t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), ys.view(-1)]
            grad = torch.autograd.grad(selected.sum(), x_in)[0]

        grad = grad * (1 / sqrt(1 + σ**2))
        return grad

    def snap_t_to_nearest_tick(self, t):
        j = np.abs(t - self.us).argmin()
        return self.us[j], j

    def time_cond_vec(self, N, σ):
        if isinstance(σ, float):
            σ, j = self.snap_t_to_nearest_tick(σ)  # σ might change due to snapping
            # self.us is stored in decreasing σ order, so tick j maps to
            # DDPM timestep (M - 1) - j
            cond_t = (self.M - 1) - j
            cond_t = torch.tensor([cond_t] * N, device=self.device)
            return cond_t, σ
        else:
            assert isinstance(σ, torch.Tensor)
            σ = σ.reshape(-1).cpu().numpy()
            σs = []
            js = []
            for elem in σ:
                _σ, _j = self.snap_t_to_nearest_tick(elem)
                σs.append(_σ)
                js.append((self.M - 1) - _j)
            cond_t = torch.tensor(js, device=self.device)
            σs = torch.tensor(σs, device=self.device, dtype=torch.float32).reshape(-1, 1, 1, 1)
            return cond_t, σs
    @staticmethod
    def cosine_us(M=1000):
        assert M == 1000

        def α_bar(j):
            return sin(pi / 2 * j / (M * (0.008 + 1))) ** 2

        us = [0, ]
        for j in reversed(range(0, M)):  # [M-1, 0], inclusive
            u_j = sqrt(((us[-1] ** 2) + 1) / (max(α_bar(j) / α_bar(j + 1), 0.001)) - 1)
            us.append(u_j)
        us = np.array(us)
        us = us[1:]
        us = us[::-1]  # decreasing order: σ_max first
        return us

    @staticmethod
    def linear_us(M=1000):
        assert M == 1000
        β_start = 0.0001
        β_end = 0.02
        βs = np.linspace(β_start, β_end, M, dtype=np.float64)
        αs = np.cumprod(1 - βs)
        us = np.sqrt((1 - αs) / αs)  # σ_j = sqrt((1 - ᾱ_j) / ᾱ_j)
        us = us[::-1]  # decreasing order: σ_max first
        return us
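
# Usage sketch (illustrative; assumes the 64x64 checkpoints exist under
# <checkpoint_root>/guided_ddpm/models/):
#   adapter = GuidedDDPM("m_imgnet_64", lsun_cat=None, imgnet_cat=-1)
#   xs = adapter.σ_max * torch.randn(4, *adapter.data_shape(), device=adapter.device)
#   Ds = adapter.denoise(xs, float(adapter.σ_max), **adapter.cond_info(4))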