# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

from PIL import Image
import torch
import gradio as gr

from lavis.models import load_model_and_preprocess
from diffusers import DDIMScheduler
from src.utils.ddim_inv import DDIMInversion
from src.utils.edit_directions import construct_direction
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_pipeline import EditingPipeline
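
# Gradio demo for pix2pix-zero (zero-shot image-to-image translation): a real image is
# captioned with BLIP, inverted to noise with DDIM, and then regenerated along a
# precomputed edit direction (e.g. cat -> dog) while preserving the image structure.
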
def main():
    NUM_DDIM_STEPS = 50
    TORCH_DTYPE = torch.float16
    XA_GUIDANCE = 0.1
    DIR_SCALE = 1.0
    MODEL_NAME = 'CompVis/stable-diffusion-v1-4'
    NEGATIVE_GUIDANCE_SCALE = 5.0
    DEVICE = "cuda"
    # if torch.cuda.is_available():
    #     DEVICE = "cuda"
    # else:
    #     DEVICE = "cpu"
    # print(f"Using {DEVICE}")
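    # BLIP captioning model: its caption of the input image is used as the text prompt
    # for both the DDIM inversion and the subsequent edit.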
    model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=DEVICE)
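    # Editing pipeline built on Stable Diffusion v1.4, with a standard DDIM scheduler for sampling.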
    pipe = EditingPipeline.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
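    # Separate pipeline that performs DDIM inversion, mapping a real image back to an
    # initial noise latent conditioned on the generated caption.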
    inv_pipe = DDIMInversion.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
    inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)
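    # Supported edits; construct_direction() is assumed to look up a precomputed
    # text-embedding edit direction (source -> target) for each task string.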
    TASKS = ["dog2cat", "cat2dog", "horse2zebra", "zebra2horse", "horse2llama", "dog2capy"]
    TASK_OPTIONS = ["Dog to Cat", "Cat to Dog", "Horse to Zebra", "Zebra to Horse", "Horse to Llama", "Dog to Capy"]
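    # Real-image editing: caption the upload with BLIP, invert it to noise, then
    # regenerate it with the prompt shifted along the chosen edit direction.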
    def edit_real_image(
        og_img,
        task,
        seed,
        xa_guidance,
        num_ddim_steps,
        dir_scale
    ):
        torch.cuda.manual_seed(seed)

        # do inversion first, get inversion and generated prompt
        curr_img = og_img.resize((512, 512), Image.Resampling.LANCZOS)
        _image = vis_processors["eval"](curr_img).unsqueeze(0).to(DEVICE)
        prompt_str = model_blip.generate({"image": _image})[0]
        x_inv, _, _ = inv_pipe(
            prompt_str,
            guidance_scale=1,
            num_inversion_steps=NUM_DDIM_STEPS,
            img=curr_img,
            torch_dtype=TORCH_DTYPE
        )
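        # Regenerate from the inverted latent, shifting the prompt embedding along the
        # task's edit direction (scaled by dir_scale). guidance_amount controls the
        # cross-attention guidance that helps preserve the original image structure.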
        task_str = TASKS[task]
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=num_ddim_steps,
            x_in=x_inv[0].unsqueeze(0),
            edit_dir=construct_direction(task_str) * dir_scale,
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt=prompt_str  # use the unedited prompt for the negative prompt
        )
        return prompt_str, edit_pil[0]
    def edit_real_image_example():
        test_img = Image.open("./assets/test_images/cats/cat_4.png")
        seed = 42
        task = 1
        prompt_str, edited_img = edit_real_image(test_img, task, seed, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE)
        return test_img, seed, "Cat to Dog", prompt_str, edited_img, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE
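    # Synthetic-image editing: sample an image from the prompt alone, then re-run the
    # same latent with the prompt shifted along the edit direction.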
    def edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps):
        torch.cuda.manual_seed(seed)
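        # Random starting latent in Stable Diffusion's latent space
        # (1x4x64x64 corresponds to a 512x512 output image).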
        x = torch.randn((1, 4, 64, 64), device=DEVICE)
        task_str = TASKS[task]
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=num_ddim_steps,
            x_in=x,
            edit_dir=construct_direction(task_str),
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt=""  # use the empty string for the negative prompt
        )
        return rec_pil[0], edit_pil[0]

    def edit_synth_image_example():
        seed = 42
        task = 1
        xa_guidance = XA_GUIDANCE
        num_ddim_steps = NUM_DDIM_STEPS
        prompt_str = "A cute white cat sitting on top of the fridge"
        recon_img, edited_img = edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps)
        return seed, "Cat to Dog", xa_guidance, num_ddim_steps, prompt_str, recon_img, edited_img
    with gr.Blocks() as demo:
        gr.Markdown("""
        ### Zero-shot Image-to-Image Translation (https://github.com/pix2pixzero/pix2pix-zero)
        Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, Jun-Yan Zhu <br/>
        - For real images:
            - Upload an image of a dog, cat, or horse.
            - Choose one of the task options to turn it into another animal!
        - Changing parameters:
            - Increase the direction scale if the result is not cat-like (or the target animal) enough.
            - If the quality is not high enough, increase the number of DDIM steps.
            - Increase the cross-attention guidance to better preserve the structure of the original image. <br/>
        - For synthetic images:
            - Enter a prompt about dogs/cats/horses.
            - Choose a task option.
        """)
with gr.Tab("Real Image"): | |
with gr.Row(): | |
seed = gr.Number(value=42, precision=1, label="Seed", interactive=True) | |
real_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True) | |
real_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True) | |
real_edit_dir_scale = gr.Number(value=DIR_SCALE, label="Edit Direction Scale", interactive=True) | |
real_generate_button = gr.Button("Generate") | |
real_load_sample_button = gr.Button("Load Example") | |
with gr.Row(): | |
task_name = gr.Radio( | |
label='Task Name', | |
choices=TASK_OPTIONS, | |
value=TASK_OPTIONS[0], | |
type="index", | |
show_label=True, | |
interactive=True, | |
) | |
with gr.Row(): | |
recon_text = gr.Textbox(lines=1, label="Reconstructed Text", interactive=False) | |
with gr.Row(): | |
input_image = gr.Image(label="Input Image", type="pil", interactive=True) | |
output_image = gr.Image(label="Output Image", type="pil", interactive=False) | |
with gr.Tab("Synthetic Images"): | |
with gr.Row(): | |
synth_seed = gr.Number(value=42, precision=1, label="Seed", interactive=True) | |
synth_prompt = gr.Textbox(lines=1, label="Prompt", interactive=True) | |
synth_generate_button = gr.Button("Generate") | |
synth_load_sample_button = gr.Button("Load Example") | |
with gr.Row(): | |
synth_task_name = gr.Radio( | |
label='Task Name', | |
choices=TASK_OPTIONS, | |
value=TASK_OPTIONS[0], | |
type="index", | |
show_label=True, | |
interactive=True, | |
) | |
synth_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True) | |
synth_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True) | |
with gr.Row(): | |
synth_input_image = gr.Image(label="Input Image", type="pil", interactive=False) | |
synth_output_image = gr.Image(label="Output Image", type="pil", interactive=False) | |
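        # Wire up the buttons: "Generate" runs the corresponding edit function,
        # "Load Example" fills the tab with a precomputed example.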
        real_generate_button.click(
            fn=edit_real_image,
            inputs=[
                input_image, task_name, seed, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale
            ],
            outputs=[recon_text, output_image]
        )
        real_load_sample_button.click(
            fn=edit_real_image_example,
            inputs=[],
            outputs=[input_image, seed, task_name, recon_text, output_image, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale]
        )
        synth_generate_button.click(
            fn=edit_synthetic_image,
            inputs=[synth_seed, synth_task_name, synth_prompt, synth_xa_guidance, synth_num_ddim_steps],
            outputs=[synth_input_image, synth_output_image]
        )
        synth_load_sample_button.click(
            fn=edit_synth_image_example,
            inputs=[],
            outputs=[synth_seed, synth_task_name, synth_xa_guidance, synth_num_ddim_steps, synth_prompt, synth_input_image, synth_output_image]
        )

    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")

if __name__ == "__main__":
    main()