# pix2pix-zero-01/src/inversion.py
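"""Invert a real image into diffusion noise space (pix2pix-zero, inversion step).

The script captions the input image with BLIP, runs DDIM inversion with a
Stable Diffusion model to recover the initial noise map, and saves both the
inverted latent and the caption for the later editing stage.

Example invocation (these paths are the argparse defaults below):

    python src/inversion.py --input_image assets/test_images/cat_a.png --results_folder output/test_cat
"""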
import os
import argparse
from glob import glob

import torch
from PIL import Image

from lavis.models import load_model_and_preprocess
from utils.ddim_inv import DDIMInversion
from utils.scheduler import DDIMInverseScheduler

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--use_float_16', action='store_true')
    args = parser.parse_args()

    # make the output folders
    os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
    os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)

    # run in half precision only when requested; float32 is the safer default
    torch_dtype = torch.float16 if args.use_float_16 else torch.float32

    # load the BLIP model for automatic captioning
    model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))

    # make the DDIM inversion pipeline and swap in the inverse scheduler
    pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
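    # The inverse scheduler runs the deterministic DDIM update backwards,
    # stepping the image latent toward Gaussian noise so that denoising the
    # result with the same prompt approximately reproduces the input image.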

    # if the input is a folder, collect all the images as a list
    if os.path.isdir(args.input_image):
        l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png")))
    else:
        l_img_paths = [args.input_image]

    for img_path in l_img_paths:
        # use img_path (not args.input_image) so every image in a folder is processed
        bname = os.path.basename(img_path).split(".")[0]
        # convert to RGB in case of RGBA or grayscale PNGs
        img = Image.open(img_path).convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)

        # caption the image with BLIP; the caption serves as the inversion prompt
        _image = vis_processors["eval"](img).unsqueeze(0).cuda()
        prompt_str = model_blip.generate({"image": _image})[0]

        # invert the image; guidance_scale=1 disables classifier-free guidance,
        # which keeps the deterministic inversion faithful to the input
        x_inv, x_inv_image, x_dec_img = pipe(
            prompt_str,
            guidance_scale=1,
            num_inversion_steps=args.num_ddim_steps,
            img=img,
            torch_dtype=torch_dtype,
        )

        # save the inverted latent
        torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
        # save the prompt string
        with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
            f.write(prompt_str)
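
# How a downstream editing script might reload the saved artifacts (a minimal
# sketch; the actual pix2pix-zero editing entry point may differ):
#   x_inv = torch.load(os.path.join(args.results_folder, "inversion", f"{bname}.pt"))
#   with open(os.path.join(args.results_folder, "prompt", f"{bname}.txt")) as f:
#       prompt_str = f.read()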