# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

from PIL import Image
import torch
import gradio as gr
from lavis.models import load_model_and_preprocess
from diffusers import DDIMScheduler

from src.utils.ddim_inv import DDIMInversion
from src.utils.edit_directions import construct_direction
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_pipeline import EditingPipeline


def main():
    NUM_DDIM_STEPS = 50
    TORCH_DTYPE = torch.float16
    XA_GUIDANCE = 0.1
    DIR_SCALE = 1.0
    MODEL_NAME = 'CompVis/stable-diffusion-v1-4'
    NEGATIVE_GUIDANCE_SCALE = 5.0
    DEVICE = "cuda"
    # if torch.cuda.is_available():
    #     DEVICE = "cuda"
    # else:
    #     DEVICE = "cpu"
    # print(f"Using {DEVICE}")

    # BLIP captioning model, used to generate a prompt for the uploaded image.
    model_blip, vis_processors, _ = load_model_and_preprocess(
        name="blip_caption", model_type="base_coco", is_eval=True, device=DEVICE
    )

    # Editing pipeline: samples with the edit direction applied during denoising.
    pipe = EditingPipeline.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # DDIM inversion pipeline: maps a real image back to an initial latent.
    inv_pipe = DDIMInversion.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
    inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)

    TASKS = ["dog2cat", "cat2dog", "horse2zebra", "zebra2horse", "horse2llama", "dog2capy"]
    TASK_OPTIONS = ["Dog to Cat", "Cat to Dog", "Horse to Zebra", "Zebra to Horse", "Horse to Llama", "Dog to Capy"]

    def edit_real_image(og_img, task, seed, xa_guidance, num_ddim_steps, dir_scale):
        torch.cuda.manual_seed(seed)

        # Do the inversion first: caption the image with BLIP to get a prompt,
        # then invert the image into a latent conditioned on that prompt.
        curr_img = og_img.resize((512, 512), Image.Resampling.LANCZOS)
        _image = vis_processors["eval"](curr_img).unsqueeze(0).to(DEVICE)
        prompt_str = model_blip.generate({"image": _image})[0]
        x_inv, _, _ = inv_pipe(
            prompt_str,
            guidance_scale=1,
            num_inversion_steps=NUM_DDIM_STEPS,
            img=curr_img,
            torch_dtype=TORCH_DTYPE,
        )

        task_str = TASKS[task]
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=num_ddim_steps,
            x_in=x_inv[0].unsqueeze(0),
            edit_dir=construct_direction(task_str) * dir_scale,
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt=prompt_str,  # use the unedited prompt for the negative prompt
        )
        return prompt_str, edit_pil[0]

    def edit_real_image_example():
        test_img = Image.open("./assets/test_images/cats/cat_4.png")
        seed = 42
        task = 1
        prompt_str, edited_img = edit_real_image(test_img, task, seed, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE)
        return test_img, seed, "Cat to Dog", prompt_str, edited_img, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE

    def edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps):
        torch.cuda.manual_seed(seed)
        # Sample a random initial latent instead of inverting a real image.
        x = torch.randn((1, 4, 64, 64), device=DEVICE)
        task_str = TASKS[task]
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=num_ddim_steps,
            x_in=x,
            edit_dir=construct_direction(task_str),
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt="",  # use the empty string for the negative prompt
        )
        return rec_pil[0], edit_pil[0]

    def edit_synth_image_example():
        seed = 42
        task = 1
        xa_guidance = XA_GUIDANCE
        num_ddim_steps = NUM_DDIM_STEPS
        prompt_str = "A cute white cat sitting on top of the fridge"
        recon_img, edited_img = edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps)
        return seed, "Cat to Dog", xa_guidance, num_ddim_steps, prompt_str, recon_img, edited_img

    with gr.Blocks() as demo:
        gr.Markdown("""
        ### Zero-shot Image-to-Image Translation
        (https://github.com/pix2pixzero/pix2pix-zero)
        Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, Jun-Yan Zhu
        - For real images:
            - Upload an image of a dog, cat, or horse.
            - Choose one of the task options to turn it into another animal!
            - Changing parameters:
                - Increase the edit direction scale if the result is not cat-like (or other target animal) enough.
                - Increase the number of DDIM steps if the output quality is not high enough.
                - Increase the cross-attention guidance to better preserve the structure of the original image.
        - For synthetic images:
            - Enter a prompt about dogs/cats/horses.
            - Choose a task option.
        """)

        with gr.Tab("Real Image"):
            with gr.Row():
                seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
                real_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
                real_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
                real_edit_dir_scale = gr.Number(value=DIR_SCALE, label="Edit Direction Scale", interactive=True)
                real_generate_button = gr.Button("Generate")
                real_load_sample_button = gr.Button("Load Example")
            with gr.Row():
                task_name = gr.Radio(
                    label='Task Name',
                    choices=TASK_OPTIONS,
                    value=TASK_OPTIONS[0],
                    type="index",
                    show_label=True,
                    interactive=True,
                )
            with gr.Row():
                recon_text = gr.Textbox(lines=1, label="Reconstructed Text", interactive=False)
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                output_image = gr.Image(label="Output Image", type="pil", interactive=False)

        with gr.Tab("Synthetic Images"):
            with gr.Row():
                synth_seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
                synth_prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
                synth_generate_button = gr.Button("Generate")
                synth_load_sample_button = gr.Button("Load Example")
            with gr.Row():
                synth_task_name = gr.Radio(
                    label='Task Name',
                    choices=TASK_OPTIONS,
                    value=TASK_OPTIONS[0],
                    type="index",
                    show_label=True,
                    interactive=True,
                )
                synth_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
                synth_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
            with gr.Row():
                synth_input_image = gr.Image(label="Input Image", type="pil", interactive=False)
                synth_output_image = gr.Image(label="Output Image", type="pil", interactive=False)

        real_generate_button.click(
            fn=edit_real_image,
            inputs=[input_image, task_name, seed, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale],
            outputs=[recon_text, output_image],
        )

        real_load_sample_button.click(
            fn=edit_real_image_example,
            inputs=[],
            outputs=[input_image, seed, task_name, recon_text, output_image, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale],
        )

        synth_generate_button.click(
            fn=edit_synthetic_image,
            inputs=[synth_seed, synth_task_name, synth_prompt, synth_xa_guidance, synth_num_ddim_steps],
            outputs=[synth_input_image, synth_output_image],
        )

        synth_load_sample_button.click(
            fn=edit_synth_image_example,
            inputs=[],
            outputs=[synth_seed, synth_task_name, synth_xa_guidance, synth_num_ddim_steps, synth_prompt, synth_input_image, synth_output_image],
        )

    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")


if __name__ == "__main__":
    main()