# Hugging Face Space page header (scrape residue): alvanlii's picture
# Duplicate from alvanlii/pix2pix_zero
# Commit: 7e0bf18
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
from PIL import Image
import torch
import gradio as gr
from lavis.models import load_model_and_preprocess
from diffusers import DDIMScheduler
from src.utils.ddim_inv import DDIMInversion
from src.utils.edit_directions import construct_direction
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_pipeline import EditingPipeline
def main():
    """Load the pix2pix-zero models and launch the Gradio editing demo.

    Loads a captioning model (``blip_caption``), an ``EditingPipeline`` and a
    ``DDIMInversion`` pipeline for ``MODEL_NAME`` onto ``DEVICE``, defines the
    callbacks that edit real and synthetic images, builds a two-tab Gradio
    UI around them and starts the server on ``0.0.0.0``.
    """
    NUM_DDIM_STEPS = 50
    TORCH_DTYPE = torch.float16
    XA_GUIDANCE = 0.1
    DIR_SCALE = 1.0
    MODEL_NAME = 'CompVis/stable-diffusion-v1-4'
    NEGATIVE_GUIDANCE_SCALE = 5.0
    # Every model and tensor below is placed on this single device.
    # NOTE(review): a CPU fallback via torch.cuda.is_available() used to be
    # sketched here but was disabled; presumably float16 inference is not
    # expected to work on CPU -- confirm before re-enabling.
    DEVICE = "cuda"

    model_blip, vis_processors, _ = load_model_and_preprocess(
        name="blip_caption", model_type="base_coco", is_eval=True, device=DEVICE
    )
    pipe = EditingPipeline.from_pretrained(
        MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None
    ).to(DEVICE)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    inv_pipe = DDIMInversion.from_pretrained(
        MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None
    ).to(DEVICE)
    inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)

    # Parallel lists: TASKS[i] is the edit-direction key for the UI label
    # TASK_OPTIONS[i]. The radio buttons return the index (type="index").
    TASKS = ["dog2cat","cat2dog","horse2zebra","zebra2horse","horse2llama","dog2capy"]
    TASK_OPTIONS = ["Dog to Cat", "Cat to Dog", "Horse to Zebra", "Zebra to Horse", "Horse to Llama", "Dog to Capy"]

    def edit_real_image(
        og_img,
        task,
        seed,
        xa_guidance,
        num_ddim_steps,
        dir_scale
    ):
        """Caption, invert and edit an uploaded image.

        Args:
            og_img: PIL image supplied by the UI (``None`` if nothing was
                uploaded).
            task: Index into ``TASKS`` selecting the edit direction.
            seed: Seed for the CUDA random number generator.
            xa_guidance: Value forwarded to the pipeline as ``guidance_amount``.
            num_ddim_steps: Number of inference steps for the editing pipeline.
            dir_scale: Multiplier applied to the edit direction.

        Returns:
            Tuple ``(prompt_str, edited_image)``: the caption generated for
            the input and the first edited image returned by the pipeline.

        Raises:
            gr.Error: If no input image was provided.
        """
        if og_img is None:
            # Without this guard the resize below fails with an opaque
            # AttributeError when "Generate" is clicked before uploading.
            raise gr.Error("Please upload an input image first.")

        # gr.Number(precision=1) delivers floats, so normalise to int.
        torch.cuda.manual_seed(int(seed))

        # do inversion first, get inversion and generated prompt
        curr_img = og_img.resize((512,512), Image.Resampling.LANCZOS)
        _image = vis_processors["eval"](curr_img).unsqueeze(0).to(DEVICE)
        prompt_str = model_blip.generate({"image": _image})[0]
        # NOTE(review): inversion always runs with NUM_DDIM_STEPS, not the
        # user-selected num_ddim_steps -- confirm this is intentional.
        x_inv, _, _ = inv_pipe(
            prompt_str,
            guidance_scale=1,
            num_inversion_steps=NUM_DDIM_STEPS,
            img=curr_img,
            torch_dtype=TORCH_DTYPE
        )

        task_str = TASKS[task]
        # The first return value (presumably the reconstruction) is not shown
        # in the real-image tab, so it is discarded.
        _, edit_pil = pipe(
            prompt_str,
            num_inference_steps=int(num_ddim_steps),
            x_in=x_inv[0].unsqueeze(0),
            edit_dir=construct_direction(task_str)*dir_scale,
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt=prompt_str # use the unedited prompt for the negative prompt
        )
        return prompt_str, edit_pil[0]

    def edit_real_image_example():
        """Run the real-image edit on a bundled cat picture with default settings.

        Returns:
            Values for the real-image tab components, in the order they are
            listed as outputs of the "Load Example" button.
        """
        test_img = Image.open("./assets/test_images/cats/cat_4.png")
        seed = 42
        task = 1
        prompt_str, edited_img = edit_real_image(test_img, task, seed, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE)
        # Derive the label from the task index so the two cannot drift apart.
        return test_img, seed, TASK_OPTIONS[task], prompt_str, edited_img, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE

    def edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps):
        """Generate an image from a prompt and edit it.

        Args:
            seed: Seed for the CUDA random number generator.
            task: Index into ``TASKS`` selecting the edit direction.
            prompt_str: Text prompt passed to the editing pipeline.
            xa_guidance: Value forwarded to the pipeline as ``guidance_amount``.
            num_ddim_steps: Number of inference steps for the editing pipeline.

        Returns:
            Tuple of the first image from each of the two image lists
            returned by the pipeline (shown as input and output in the UI).
        """
        # gr.Number(precision=1) delivers floats, so normalise to int.
        torch.cuda.manual_seed(int(seed))
        # Random starting latent, created on the same device as the pipeline.
        x = torch.randn((1,4,64,64), device=DEVICE)

        task_str = TASKS[task]
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=int(num_ddim_steps),
            x_in=x,
            edit_dir=construct_direction(task_str),
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt="" # use the empty string for the negative prompt
        )
        return rec_pil[0], edit_pil[0]

    def edit_synth_image_example():
        """Run the synthetic edit with a fixed prompt and default settings.

        Returns:
            Values for the synthetic tab components, in the order they are
            listed as outputs of the "Load Example" button.
        """
        seed = 42
        task = 1
        xa_guidance = XA_GUIDANCE
        num_ddim_steps = NUM_DDIM_STEPS
        prompt_str = "A cute white cat sitting on top of the fridge"
        recon_img, edited_img = edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps)
        return seed, TASK_OPTIONS[task], xa_guidance, num_ddim_steps, prompt_str, recon_img, edited_img

    # NOTE(review): the Row/Tab nesting below is reconstructed from statement
    # order -- confirm it matches the intended layout.
    with gr.Blocks() as demo:
        # The Markdown body is kept flush-left so the rendered text is unchanged.
        gr.Markdown("""
### Zero-shot Image-to-Image Translation (https://github.com/pix2pixzero/pix2pix-zero)
Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, Jun-Yan Zhu <br/>
- For real images:
- Upload an image of a dog, cat or horse,
- Choose one of the task options to turn it into another animal!
- Changing Parameters:
- Increase direction scale is it is not cat (or another animal) enough.
- If the quality is not high enough, increase num ddim steps.
- Increase cross attention guidance to preserve original image structures. <br/>
- For synthetic images:
- Enter a prompt about dogs/cats/horses
- Choose a task option
""")
        with gr.Tab("Real Image"):
            with gr.Row():
                seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
                real_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
                real_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
                real_edit_dir_scale = gr.Number(value=DIR_SCALE, label="Edit Direction Scale", interactive=True)
                real_generate_button = gr.Button("Generate")
                real_load_sample_button = gr.Button("Load Example")
            with gr.Row():
                task_name = gr.Radio(
                    label='Task Name',
                    choices=TASK_OPTIONS,
                    value=TASK_OPTIONS[0],
                    type="index",
                    show_label=True,
                    interactive=True,
                )
            with gr.Row():
                recon_text = gr.Textbox(lines=1, label="Reconstructed Text", interactive=False)
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                output_image = gr.Image(label="Output Image", type="pil", interactive=False)

        with gr.Tab("Synthetic Images"):
            with gr.Row():
                synth_seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
                synth_prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
                synth_generate_button = gr.Button("Generate")
                synth_load_sample_button = gr.Button("Load Example")
            with gr.Row():
                synth_task_name = gr.Radio(
                    label='Task Name',
                    choices=TASK_OPTIONS,
                    value=TASK_OPTIONS[0],
                    type="index",
                    show_label=True,
                    interactive=True,
                )
                synth_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
                synth_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
            with gr.Row():
                synth_input_image = gr.Image(label="Input Image", type="pil", interactive=False)
                synth_output_image = gr.Image(label="Output Image", type="pil", interactive=False)

        real_generate_button.click(
            fn=edit_real_image,
            inputs=[
                input_image, task_name, seed, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale
            ],
            outputs=[recon_text, output_image]
        )
        real_load_sample_button.click(
            fn=edit_real_image_example,
            inputs=[],
            outputs=[input_image, seed, task_name, recon_text, output_image, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale]
        )
        synth_generate_button.click(
            fn=edit_synthetic_image,
            inputs=[synth_seed, synth_task_name, synth_prompt, synth_xa_guidance, synth_num_ddim_steps],
            outputs=[synth_input_image, synth_output_image]
        )
        # The example seed belongs to the synthetic tab's own seed field
        # (synth_seed), not the real-image tab's `seed` component.
        synth_load_sample_button.click(
            fn=edit_synth_image_example,
            inputs=[],
            outputs=[synth_seed, synth_task_name, synth_xa_guidance, synth_num_ddim_steps, synth_prompt, synth_input_image, synth_output_image]
        )

    # Queue with concurrency_count=1 so requests are handled one at a time.
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")
if __name__ == "__main__":
    # Start the demo only when this file is executed as a script.
    main()