# Spaces: Running on Zero
import argparse
import json
from pathlib import Path

import k_diffusion
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from tqdm import tqdm

from stable_diffusion.ldm.modules.attention import CrossAttention
from stable_diffusion.ldm.util import instantiate_from_config
from metrics.clip_similarity import ClipSimilarity
################################################################################
# Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
# https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
def append_dims(x, target_dims):
    """Append trailing singleton dimensions to ``x`` until it has ``target_dims`` dimensions.

    Args:
        x: Array-like exposing ``ndim`` and supporting ``Ellipsis``/``None`` indexing
            (e.g. a ``torch.Tensor``).
        target_dims: Total number of dimensions the result should have.

    Returns:
        ``x`` indexed so that ``target_dims - x.ndim`` size-1 axes follow its
        existing axes; ``x`` is returned with its shape unchanged when it already
        has ``target_dims`` dimensions.

    Raises:
        ValueError: If ``x`` has more than ``target_dims`` dimensions.
    """
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    # (..., None, None, ...) keeps every existing axis and adds one new axis per None.
    trailing_axes = (None,) * dims_to_append
    return x[(...,) + trailing_axes]
def to_d(x, sigma, denoised):
    """Convert a denoiser output to a Karras ODE derivative.

    Args:
        x: Current noisy sample.
        sigma: Noise level(s) for ``x``; trailing singleton axes are appended so
            it broadcasts against ``x``.
        denoised: Denoiser prediction for ``x``.

    Returns:
        ``(x - denoised) / sigma``, with ``sigma`` expanded to ``x.ndim`` dimensions.
    """
    sigma_broadcast = append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_broadcast
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """Compute the noise levels used by one ancestral sampling step.

    Args:
        sigma_from: Noise level at the start of the step.
        sigma_to: Target noise level at the end of the step.
        eta: Scale applied to the amount of fresh noise. The default of ``1.0``
            gives the unscaled ancestral step; a falsy value (``0``) disables
            noise injection entirely.

    Returns:
        A ``(sigma_down, sigma_up)`` pair: the noise level to step down to, and
        the amount of noise to add afterwards. They satisfy
        ``sigma_down**2 + sigma_up**2 == sigma_to**2``.
    """
    if not eta:
        # No noise is re-injected, so step straight to the target level.
        return sigma_to, 0.0
    noise_scale = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    # Never add more noise than the target level itself allows.
    sigma_up = min(sigma_to, eta * noise_scale)
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
    """Ancestral sampling with Euler method steps and a prompt-to-prompt switch.

    Before each step, every ``CrossAttention`` module inside ``model`` gets its
    ``prompt_to_prompt`` attribute set to ``True`` while the fraction of
    completed steps is below ``prompt2prompt_threshold`` and ``False`` afterwards.

    Args:
        model: Denoiser called as ``model(x, sigma, **extra_args)``; must expose
            ``modules()``.
        x: Initial noisy sample batch.
        sigmas: Sequence of noise levels; ``len(sigmas) - 1`` steps are taken.
        prompt2prompt_threshold: Fraction of the schedule during which
            ``prompt_to_prompt`` is enabled.
        **extra_args: Forwarded unchanged to ``model`` on every step.

    Returns:
        The sample batch after the final step.
    """
    s_in = x.new_ones([x.shape[0]])
    num_steps = len(sigmas) - 1
    # Step i is mapped onto [0, 1] by dividing by the last step index. Clamp the
    # divisor to 1 so a single-step schedule does not divide by zero.
    last_step_index = max(num_steps - 1, 1)
    # The module tree does not change while sampling, so scan it only once.
    cross_attention_layers = [m for m in model.modules() if isinstance(m, CrossAttention)]

    for i in range(num_steps):
        prompt_to_prompt = prompt2prompt_threshold > i / last_step_index
        for layer in cross_attention_layers:
            layer.prompt_to_prompt = prompt_to_prompt

        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Make noise the same across all samples in batch.
            x = x + torch.randn_like(x[:1]) * sigma_up
    return x
################################################################################
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the model described by ``config.model`` and load checkpoint weights.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint containing a ``"state_dict"`` entry (and
            optionally ``"global_step"``, which is printed when present).
        vae_ckpt: Optional path to a second checkpoint. When given, every weight
            whose key starts with ``"first_stage_model."`` is replaced by the
            entry of the same name (prefix removed) from that checkpoint's
            ``"state_dict"``.
        verbose: When true, print the missing and unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The instantiated model with the weights loaded non-strictly.

    Raises:
        KeyError: If a checkpoint has no ``"state_dict"`` entry, or the VAE
            checkpoint lacks a key needed to replace a ``first_stage_model``
            weight.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints obtained from a trusted source.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        # NOTE(security): same pickle caveat as above.
        vae_state_dict = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        vae_prefix = "first_stage_model."
        state_dict = {
            key: vae_state_dict[key[len(vae_prefix):]] if key.startswith(vae_prefix) else value
            for key, value in state_dict.items()
        }

    model = instantiate_from_config(config.model)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if verbose and len(missing_keys) > 0:
        print("missing keys:")
        print(missing_keys)
    if verbose and len(unexpected_keys) > 0:
        print("unexpected keys:")
        print(unexpected_keys)
    return model
class CFGDenoiser(nn.Module):
    """Wrap a denoiser so that each call applies classifier-free guidance."""

    def __init__(self, model):
        """Store ``model`` as ``inner_model``; it is called as ``model(x, sigma, cond=...)``."""
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cfg_scale):
        """Run the wrapped denoiser once on a doubled batch and blend the two halves.

        Args:
            x: Noisy sample batch.
            sigma: Noise levels for ``x``.
            uncond: Unconditional conditioning batch.
            cond: Conditional conditioning batch.
            cfg_scale: Guidance scale; ``1`` returns the conditional prediction,
                larger values extrapolate further away from the unconditional one.

        Returns:
            ``uncond_out + (cond_out - uncond_out) * cfg_scale``.
        """
        # Duplicate the inputs so one forward pass yields both predictions:
        # the first half pairs with ``uncond``, the second half with ``cond``.
        x_in = torch.cat([x, x])
        sigma_in = torch.cat([sigma, sigma])
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond_out + (cond_out - uncond_out) * cfg_scale
def to_pil(image: torch.Tensor) -> Image.Image:
    """Convert a channel-first image tensor to a PIL image.

    Args:
        image: Tensor laid out as ``(c, h, w)``. Values are multiplied by 255
            before the cast to ``uint8``, so they are presumably expected to lie
            in ``[0, 1]`` (the caller in this file clamps to that range first).

    Returns:
        A PIL image built from the ``(h, w, c)`` ``uint8`` array.
    """
    channels_last = rearrange(image.cpu().numpy(), "c h w -> h w c")
    pixels = 255.0 * channels_last
    # astype(np.uint8) truncates the fractional part rather than rounding.
    return Image.fromarray(pixels.astype(np.uint8))
def _build_arg_parser():
    """Build the command-line parser for the dataset generation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "out_dir",
        type=str,
        help="Path to output dataset directory.",
    )
    parser.add_argument(
        "prompts_file",
        type=str,
        help="Path to prompts .jsonl file.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of sampling steps.",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=100,
        help="Number of samples to generate per prompt (before CLIP filtering).",
    )
    parser.add_argument(
        "--max-out-samples",
        type=int,
        default=4,
        help="Max number of output samples to save per prompt (after CLIP filtering).",
    )
    parser.add_argument(
        "--n-partitions",
        type=int,
        default=1,
        help="Number of total partitions.",
    )
    parser.add_argument(
        "--partition",
        type=int,
        default=0,
        help="Partition index.",
    )
    parser.add_argument(
        "--min-p2p",
        type=float,
        default=0.1,
        help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--max-p2p",
        type=float,
        default=0.9,
        help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--min-cfg",
        type=float,
        default=7.5,
        help="Min classifier free guidance scale.",
    )
    parser.add_argument(
        "--max-cfg",
        type=float,
        default=15,
        help="Max classifier free guidance scale.",
    )
    parser.add_argument(
        "--clip-threshold",
        type=float,
        default=0.2,
        help="CLIP threshold for text-image similarity of each image.",
    )
    parser.add_argument(
        "--clip-dir-threshold",
        type=float,
        default=0.2,
        help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
    )
    parser.add_argument(
        "--clip-img-threshold",
        type=float,
        default=0.7,
        help="CLIP threshold for image-image similarity.",
    )
    return parser


def main():
    """Generate CLIP-filtered image pairs for one partition of a prompts file.

    For every prompt (a JSON object with ``"input"`` and ``"output"`` captions)
    in the selected partition, the script:

    1. writes the prompt to ``<out_dir>/<index:07d>/prompt.json``;
    2. draws ``--n-samples`` image pairs, each from a fresh random seed with a
       random prompt-to-prompt threshold and guidance scale, and scores every
       pair with ``ClipSimilarity``;
    3. keeps the pairs passing all CLIP thresholds, ranks them by directional
       similarity, and saves up to ``--max-out-samples`` of them as
       ``<seed>_0.jpg`` / ``<seed>_1.jpg`` plus one line each in
       ``metadata.jsonl``.

    Requires a CUDA device and the checkpoint/config paths hard-coded below.
    """
    parser = _build_arg_parser()
    opt = parser.parse_args()

    # Fail early with a readable message instead of an IndexError/ValueError
    # from the partitioning code further down.
    if opt.n_partitions < 1:
        parser.error(f"--n-partitions must be at least 1, got {opt.n_partitions}")
    if not 0 <= opt.partition < opt.n_partitions:
        parser.error(f"--partition must be in [0, {opt.n_partitions}), got {opt.partition}")

    global_seed = torch.randint(1 << 32, ()).item()
    print(f"Global seed: {global_seed}")
    seed_everything(global_seed)

    model = load_model_from_config(
        OmegaConf.load("configs/stable-diffusion/v1-inference.yaml"),
        ckpt="models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
        vae_ckpt="models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
    )
    model.cuda().eval()
    model_wrap = k_diffusion.external.CompVisDenoiser(model)
    # The guidance wrapper only holds a reference to model_wrap, so a single
    # instance can be reused for every sample.
    model_wrap_cfg = CFGDenoiser(model_wrap)

    clip_similarity = ClipSimilarity().cuda()

    out_dir = Path(opt.out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    with open(opt.prompts_file) as fp:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        prompts = [json.loads(line) for line in fp if line.strip()]

    print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
    # Split the indices rather than the (index, prompt) pairs so NumPy never has
    # to coerce prompt dicts into an object array.
    partition_indices = np.array_split(np.arange(len(prompts)), opt.n_partitions)[opt.partition]
    partition_prompts = [(int(index), prompts[int(index)]) for index in partition_indices]

    with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
        uncond = model.get_learned_conditioning(2 * [""])
        sigmas = model_wrap.get_sigmas(opt.steps)

        for i, prompt in tqdm(partition_prompts, desc="Prompts"):
            prompt_dir = out_dir.joinpath(f"{i:07d}")
            prompt_dir.mkdir(exist_ok=True)

            with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
                json.dump(prompt, fp)

            cond = model.get_learned_conditioning([prompt["input"], prompt["output"]])
            results = {}

            with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
                while len(results) < opt.n_samples:
                    seed = torch.randint(1 << 32, ()).item()
                    if seed in results:
                        # Seeds key the results, so redraw on a collision.
                        continue
                    torch.manual_seed(seed)

                    # Both images of a pair start from the same initial noise.
                    x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
                    x = repeat(x, "1 ... -> n ...", n=2)

                    p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
                    cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
                    extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
                    samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    x0 = x_samples_ddim[0]
                    x1 = x_samples_ddim[1]

                    clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
                        x0[None], x1[None], [prompt["input"]], [prompt["output"]]
                    )

                    results[seed] = dict(
                        image_0=to_pil(x0),
                        image_1=to_pil(x1),
                        p2p_threshold=p2p_threshold,
                        cfg_scale=cfg_scale,
                        clip_sim_0=clip_sim_0[0].item(),
                        clip_sim_1=clip_sim_1[0].item(),
                        clip_sim_dir=clip_sim_dir[0].item(),
                        clip_sim_image=clip_sim_image[0].item(),
                    )

                    progress_bar.update()

            # CLIP filter to get best samples for each prompt.
            metadata = [
                (result["clip_sim_dir"], seed)
                for seed, result in results.items()
                if result["clip_sim_image"] >= opt.clip_img_threshold
                and result["clip_sim_dir"] >= opt.clip_dir_threshold
                and result["clip_sim_0"] >= opt.clip_threshold
                and result["clip_sim_1"] >= opt.clip_threshold
            ]
            # Highest directional similarity first; ties fall back to the seed.
            metadata.sort(reverse=True)
            selected = metadata[: opt.max_out_samples]

            if selected:
                # Append mode keeps entries written by earlier runs for this prompt.
                with open(prompt_dir.joinpath("metadata.jsonl"), "a") as fp:
                    for _, seed in selected:
                        result = results[seed]
                        # Images are saved as files; only scalar fields go into the JSON line.
                        image_0 = result.pop("image_0")
                        image_1 = result.pop("image_1")
                        image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
                        image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
                        fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")

    print("Done.")
# Run the generation script only when this file is executed directly.
if __name__ == "__main__":
    main()