# instruct-pix2pix/dataset_creation/generate_img_dataset.py
# Author: timbrooks -- commit 2afcb7e ("Add InstructPix2Pix"), 10.7 kB
import argparse
import json
from pathlib import Path
import k_diffusion
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from tqdm import tqdm
from stable_diffusion.ldm.modules.attention import CrossAttention
from stable_diffusion.ldm.util import instantiate_from_config
from metrics.clip_similarity import ClipSimilarity
################################################################################
# Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
# https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    # Keep every existing axis (Ellipsis) and add one new axis per ``None``.
    index = (Ellipsis,) + (None,) * missing
    return x[index]
def to_d(x, sigma, denoised):
    """Turn a denoiser prediction into the Karras ODE derivative ``(x - denoised) / sigma``."""
    residual = x - denoised
    # Give sigma trailing singleton axes so it broadcasts against x.
    sigma_expanded = append_dims(sigma, x.ndim)
    return residual / sigma_expanded
def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral sampling step into its deterministic and stochastic parts.

    Returns:
        ``(sigma_down, sigma_up)``: the noise level to step down to and the
        amount of fresh noise to add afterwards.
    """
    from_sq = sigma_from**2
    to_sq = sigma_to**2
    candidate_up = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    # The added noise is capped at the target noise level.
    sigma_up = min(sigma_to, candidate_up)
    sigma_down = (to_sq - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
    """Ancestral sampling with Euler method steps and a prompt-to-prompt switch.

    Before each step, every ``CrossAttention`` module found in ``model`` gets its
    ``prompt_to_prompt`` attribute set: ``True`` while the fraction of the
    schedule completed so far is below ``prompt2prompt_threshold``, ``False``
    afterwards.

    Args:
        model: Denoiser called as ``model(x, sigma, **extra_args)``; must expose
            ``modules()``.
        x: Initial noisy latents; the first dimension is the batch.
        sigmas: Noise schedule; ``len(sigmas) - 1`` steps are taken.
        prompt2prompt_threshold: Fraction of the schedule during which
            ``prompt_to_prompt`` is enabled.
        **extra_args: Forwarded unchanged to ``model`` on every step.

    Returns:
        The latents after the last step.
    """
    s_in = x.new_ones([x.shape[0]])
    num_steps = len(sigmas) - 1
    # Step i is a fraction i / (num_steps - 1) of the way through the schedule.
    # Clamp the divisor so a single-step schedule (two sigmas) does not raise
    # ZeroDivisionError; its only step then sits at fraction 0.
    progress_divisor = max(num_steps - 1, 1)
    # Collect the attention layers once up front instead of re-walking the
    # module tree on every step.
    attention_layers = [m for m in model.modules() if isinstance(m, CrossAttention)]
    for i in range(num_steps):
        prompt_to_prompt = prompt2prompt_threshold > i / progress_divisor
        for layer in attention_layers:
            layer.prompt_to_prompt = prompt_to_prompt
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Make noise the same across all samples in batch.
            x = x + torch.randn_like(x[:1]) * sigma_up
    return x
################################################################################
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate ``config.model`` and load checkpoint weights into it.

    Args:
        config: Configuration object whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Path to the main checkpoint; its ``"state_dict"`` entry is loaded.
        vae_ckpt: Optional path to a VAE checkpoint. When given, every weight
            whose key starts with ``"first_stage_model."`` is replaced by the
            VAE checkpoint's entry of the same name without that prefix.
        verbose: If true, print the missing / unexpected keys reported by the
            non-strict ``load_state_dict`` call.

    Returns:
        The instantiated model with the weights loaded.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load may unpickle arbitrary objects; only use trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_state_dict = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        vae_prefix = "first_stage_model."
        patched = {}
        for key, weight in state_dict.items():
            if key.startswith(vae_prefix):
                # Swap in the standalone VAE's weight for this first-stage parameter.
                weight = vae_state_dict[key[len(vae_prefix):]]
            patched[key] = weight
        state_dict = patched

    model = instantiate_from_config(config.model)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if len(missing_keys) > 0 and verbose:
        print("missing keys:")
        print(missing_keys)
    if len(unexpected_keys) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected_keys)
    return model
class CFGDenoiser(nn.Module):
    """Wrap a denoiser so a single call applies classifier-free guidance."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cfg_scale):
        """Run the wrapped model on a doubled batch and blend the two halves.

        The first half of the doubled batch is evaluated with ``uncond`` and
        the second with ``cond``; the result is
        ``uncond_out + (cond_out - uncond_out) * cfg_scale``.
        """
        doubled_x = torch.cat((x, x))
        doubled_sigma = torch.cat((sigma, sigma))
        stacked_cond = torch.cat((uncond, cond))
        prediction = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        uncond_out, cond_out = prediction.chunk(2)
        guidance = cond_out - uncond_out
        return uncond_out + guidance * cfg_scale
def to_pil(image: torch.Tensor) -> Image.Image:
    """Convert a ``(c, h, w)`` tensor into a PIL image.

    Values are multiplied by 255 and cast to ``uint8`` (fractions are
    truncated, not rounded). Presumably the input is expected to lie in
    [0, 1] -- the caller in this file clamps to that range first.
    """
    channels_last = rearrange(image.cpu().numpy(), "c h w -> h w c")
    pixels = (255.0 * channels_last).astype(np.uint8)
    return Image.fromarray(pixels)
def _parse_args():
    """Parse the command line and validate the partition options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "out_dir",
        type=str,
        help="Path to output dataset directory.",
    )
    parser.add_argument(
        "prompts_file",
        type=str,
        help="Path to prompts .jsonl file.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of sampling steps.",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=100,
        help="Number of samples to generate per prompt (before CLIP filtering).",
    )
    parser.add_argument(
        "--max-out-samples",
        type=int,
        default=4,
        help="Max number of output samples to save per prompt (after CLIP filtering).",
    )
    parser.add_argument(
        "--n-partitions",
        type=int,
        default=1,
        help="Number of total partitions.",
    )
    parser.add_argument(
        "--partition",
        type=int,
        default=0,
        help="Partition index.",
    )
    parser.add_argument(
        "--min-p2p",
        type=float,
        default=0.1,
        help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--max-p2p",
        type=float,
        default=0.9,
        help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--min-cfg",
        type=float,
        default=7.5,
        help="Min classifier free guidance scale.",
    )
    parser.add_argument(
        "--max-cfg",
        type=float,
        default=15,
        help="Max classifier free guidance scale.",
    )
    parser.add_argument(
        "--clip-threshold",
        type=float,
        default=0.2,
        help="CLIP threshold for text-image similarity of each image.",
    )
    parser.add_argument(
        "--clip-dir-threshold",
        type=float,
        default=0.2,
        help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
    )
    parser.add_argument(
        "--clip-img-threshold",
        type=float,
        default=0.7,
        help="CLIP threshold for image-image similarity.",
    )
    opt = parser.parse_args()

    # Fail fast with a usage error instead of an IndexError (or a silently
    # accepted negative index) after the models have already been loaded.
    if opt.n_partitions < 1:
        parser.error("--n-partitions must be at least 1.")
    if not 0 <= opt.partition < opt.n_partitions:
        parser.error(f"--partition must be in [0, {opt.n_partitions - 1}], got {opt.partition}.")
    return opt


def main():
    """Generate paired images for every prompt in this partition and keep the best ones.

    For each prompt (a JSON object with ``"input"`` and ``"output"`` captions)
    the script:

    1. samples ``--n-samples`` image pairs from shared initial noise, drawing a
       random seed, prompt-to-prompt threshold and guidance scale per pair;
    2. scores each pair with ``ClipSimilarity``;
    3. keeps the pairs passing all CLIP thresholds, ranks them by directional
       similarity, and writes up to ``--max-out-samples`` of them as
       ``<seed>_0.jpg`` / ``<seed>_1.jpg`` plus one line each in
       ``metadata.jsonl`` under ``out_dir/<prompt index>/``.
    """
    opt = _parse_args()

    global_seed = torch.randint(1 << 32, ()).item()
    print(f"Global seed: {global_seed}")
    seed_everything(global_seed)

    model = load_model_from_config(
        OmegaConf.load("configs/stable-diffusion/v1-inference.yaml"),
        ckpt="models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
        vae_ckpt="models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
    )
    model.cuda().eval()
    model_wrap = k_diffusion.external.CompVisDenoiser(model)
    # The guidance wrapper holds no per-sample state, so build it once.
    model_wrap_cfg = CFGDenoiser(model_wrap)

    clip_similarity = ClipSimilarity().cuda()

    out_dir = Path(opt.out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    with open(opt.prompts_file) as fp:
        prompts = [json.loads(line) for line in fp]

    print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
    # Split the prompt *indices* (same section sizes as splitting the prompts
    # themselves) so NumPy never has to coerce (int, dict) pairs into an array.
    partition_indices = [
        int(idx) for idx in np.array_split(np.arange(len(prompts)), opt.n_partitions)[opt.partition]
    ]

    with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
        uncond = model.get_learned_conditioning(2 * [""])
        sigmas = model_wrap.get_sigmas(opt.steps)

        for i in tqdm(partition_indices, desc="Prompts"):
            prompt = prompts[i]
            prompt_dir = out_dir.joinpath(f"{i:07d}")
            prompt_dir.mkdir(exist_ok=True)

            with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
                json.dump(prompt, fp)

            cond = model.get_learned_conditioning([prompt["input"], prompt["output"]])
            results = {}

            with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
                while len(results) < opt.n_samples:
                    seed = torch.randint(1 << 32, ()).item()
                    if seed in results:
                        continue
                    torch.manual_seed(seed)

                    # One noise sample, repeated so both captions start from
                    # the same latent.
                    x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
                    x = repeat(x, "1 ... -> n ...", n=2)

                    p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
                    cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
                    extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
                    samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    # Map decoder output from [-1, 1] to [0, 1].
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    x0 = x_samples_ddim[0]
                    x1 = x_samples_ddim[1]

                    clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
                        x0[None], x1[None], [prompt["input"]], [prompt["output"]]
                    )

                    results[seed] = dict(
                        image_0=to_pil(x0),
                        image_1=to_pil(x1),
                        p2p_threshold=p2p_threshold,
                        cfg_scale=cfg_scale,
                        clip_sim_0=clip_sim_0[0].item(),
                        clip_sim_1=clip_sim_1[0].item(),
                        clip_sim_dir=clip_sim_dir[0].item(),
                        clip_sim_image=clip_sim_image[0].item(),
                    )

                    progress_bar.update()

            # CLIP filter to get best samples for each prompt.
            metadata = [
                (result["clip_sim_dir"], seed)
                for seed, result in results.items()
                if result["clip_sim_image"] >= opt.clip_img_threshold
                and result["clip_sim_dir"] >= opt.clip_dir_threshold
                and result["clip_sim_0"] >= opt.clip_threshold
                and result["clip_sim_1"] >= opt.clip_threshold
            ]
            metadata.sort(reverse=True)

            selected = metadata[: opt.max_out_samples]
            # Only touch metadata.jsonl when at least one sample survived, and
            # open it once per prompt rather than once per sample.
            if selected:
                with open(prompt_dir.joinpath("metadata.jsonl"), "a") as fp:
                    for _, seed in selected:
                        result = results[seed]
                        image_0 = result.pop("image_0")
                        image_1 = result.pop("image_1")
                        image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
                        image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
                        fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")

    print("Done.")
# Run the dataset generation only when this file is executed as a script.
if __name__ == "__main__":
    main()