# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************

"""Gradio user interface for DragDiffusion.

Two tabs are built inside ``demo``:

* "Editing Real Image": draw a mask on an input image, train a LoRA
  (``train_lora_interface``), click points and run the drag edit
  (``run_drag``).
* "Editing Generated Image": generate an image (``gen_img``), then mask,
  click points and run the drag edit (``run_drag_gen``).

Running this file as a script launches ``demo`` with ``share=True``.
NOTE: ``gr.Image(tool="sketch")`` is a Gradio 3.x argument (removed in
Gradio 4), so this UI assumes a Gradio 3.x install.
"""

import os

import gradio as gr

from utils.ui_utils import get_points, undo_points
from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag
from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen

LENGTH = 480  # length of the square area displaying/editing images
LOCAL_MODELS_DIR = "local_pretrained_models"  # optional folder of local models


def _list_local_models(models_dir=LOCAL_MODELS_DIR):
    """Return ``models_dir/<name>`` for every sub-directory of *models_dir*.

    The paths are offered as extra choices in the model/VAE dropdowns. The
    result is sorted so the dropdown order is stable. It is empty when
    *models_dir* does not exist, so the UI still starts without that folder.
    """
    if not os.path.isdir(models_dir):
        return []
    return sorted(
        os.path.join(models_dir, name)
        for name in os.listdir(models_dir)
        if os.path.isdir(os.path.join(models_dir, name))
    )


# Scanned once and shared by the model/VAE dropdowns of both tabs.
local_models_choice = _list_local_models()

with gr.Blocks() as demo:
    # layout definition
    with gr.Row():
        gr.Markdown("""
        # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435)
        """)

    # UI components for editing real images
    with gr.Tab(label="Editing Real Image"):
        mask = gr.State(value=None)  # store mask
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store original input image
        with gr.Row():
            with gr.Column():
                gr.Markdown("Draw Mask")
                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                                  show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
                train_lora_button = gr.Button("Train LoRA")
            with gr.Column():
                gr.Markdown("Click Points")
                input_image = gr.Image(type="numpy", label="Click Points",
                                       show_label=True, height=LENGTH, width=LENGTH)  # for points clicking
                undo_button = gr.Button("Undo point")
            with gr.Column():
                gr.Markdown("Editing Results")
                output_image = gr.Image(type="numpy", label="Editing Results",
                                        show_label=True, height=LENGTH, width=LENGTH)
                with gr.Row():
                    run_button = gr.Button("Run")
                    clear_all_button = gr.Button("Clear All")

        # general parameters
        with gr.Row():
            prompt = gr.Textbox(label="Prompt")
            lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
            lora_status_bar = gr.Textbox(label="display LoRA training status")

        # algorithm specific parameters
        with gr.Tab("Drag Config"):
            with gr.Row():
                n_pix_step = gr.Number(
                    value=40,
                    label="number of pixel steps",
                    info="Number of gradient descent (motion supervision) steps on latent.",
                    precision=0)
                lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
                inversion_strength = gr.Slider(
                    0, 1.0,
                    value=0.75,
                    label="inversion strength",
                    info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
                latent_lr = gr.Number(value=0.01, label="latent lr")
                start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
                start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)

        with gr.Tab("Base Model Config"):
            with gr.Row():
                model_path = gr.Dropdown(
                    value="runwayml/stable-diffusion-v1-5",
                    label="Diffusion Model Path",
                    choices=[
                        "runwayml/stable-diffusion-v1-5",
                    ] + local_models_choice,
                )
                vae_path = gr.Dropdown(
                    value="default",
                    label="VAE choice",
                    choices=["default", "stabilityai/sd-vae-ft-mse"] + local_models_choice,
                )

        with gr.Tab("LoRA Parameters"):
            with gr.Row():
                lora_step = gr.Number(value=200, label="LoRA training steps", precision=0)
                lora_lr = gr.Number(value=0.0002, label="LoRA learning rate")
                lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)

    # UI components for editing generated images
    with gr.Tab(label="Editing Generated Image"):
        mask_gen = gr.State(value=None)  # store mask
        selected_points_gen = gr.State([])  # store points
        original_image_gen = gr.State(value=None)  # store the diffusion-generated image
        intermediate_latents_gen = gr.State(value=None)  # store the intermediate diffusion latent during generation
        with gr.Row():
            with gr.Column():
                gr.Markdown("Draw Mask")
                canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                                      show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
                gen_img_button = gr.Button("Generate Image")
            with gr.Column():
                gr.Markdown("Click Points")
                input_image_gen = gr.Image(type="numpy", label="Click Points",
                                           show_label=True, height=LENGTH, width=LENGTH)  # for points clicking
                undo_button_gen = gr.Button("Undo point")
            with gr.Column():
                gr.Markdown("Editing Results")
                output_image_gen = gr.Image(type="numpy", label="Editing Results",
                                            show_label=True, height=LENGTH, width=LENGTH)
                with gr.Row():
                    run_button_gen = gr.Button("Run")
                    clear_all_button_gen = gr.Button("Clear All")

        # general parameters
        with gr.Row():
            pos_prompt_gen = gr.Textbox(label="Positive Prompt")
            neg_prompt_gen = gr.Textbox(label="Negative Prompt")

        with gr.Tab("Generation Config"):
            with gr.Row():
                model_path_gen = gr.Dropdown(
                    value="runwayml/stable-diffusion-v1-5",
                    label="Diffusion Model Path",
                    choices=[
                        "runwayml/stable-diffusion-v1-5",
                        "gsdf/Counterfeit-V2.5",
                        "emilianJR/majicMIX_realistic",
                        "SG161222/Realistic_Vision_V2.0",
                        "stablediffusionapi/landscapesupermix",
                        "huangzhe0803/ArchitectureRealMix",
                        "stablediffusionapi/interiordesignsuperm",
                    ] + local_models_choice,
                )
                vae_path_gen = gr.Dropdown(
                    value="default",
                    label="VAE choice",
                    choices=["default", "stabilityai/sd-vae-ft-mse"] + local_models_choice,
                )
                lora_path_gen = gr.Textbox(value="", label="LoRA path")
                gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0)
                height = gr.Number(value=512, label="Height", precision=0)
                width = gr.Number(value=512, label="Width", precision=0)
                guidance_scale = gr.Number(value=7.5, label="CFG Scale")
                scheduler_name_gen = gr.Dropdown(
                    value="DDIM",
                    label="Scheduler",
                    choices=[
                        "DDIM",
                        "DPM++2M",
                        "DPM++2M_karras",
                    ],
                )
                n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0)

        with gr.Tab(label="Drag Config"):
            with gr.Row():
                n_pix_step_gen = gr.Number(
                    value=40,
                    label="Number of Pixel Steps",
                    info="Number of gradient descent (motion supervision) steps on latent.",
                    precision=0)
                lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
                inversion_strength_gen = gr.Slider(
                    0, 1.0,
                    value=0.75,
                    label="Inversion Strength",
                    info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
                latent_lr_gen = gr.Number(value=0.01, label="latent lr")
                start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
                start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)

            # Checkboxes/settings for exporting a GIF of the dragging process.
            # NOTE: although rendered in this tab, create_gif_checkbox and
            # gif_interval are also read by the real-image Run button below.
            with gr.Row():
                create_gif_checkbox = gr.Checkbox(label="create_GIF", value=False)
                create_tracking_point_checkbox = gr.Checkbox(label="create_tracking_point", value=False)
                gif_interval = gr.Number(
                    value=10,
                    label="interval_GIF",
                    precision=0,
                    info="The interval of the GIF, i.e. the number of steps between each frame of the GIF.")
                gif_fps = gr.Number(
                    value=1,
                    label="fps_GIF",
                    precision=0,
                    info="The fps of the GIF, i.e. the number of frames per second of the GIF.")

    # Hidden constant input that hands the display size LENGTH to the
    # clear/generate handlers (one shared component instead of three copies).
    length_input = gr.Number(value=LENGTH, visible=False, precision=0)

    # event definition
    # event for dragging user-input real image
    canvas.edit(
        store_img,
        [canvas],
        [original_image, selected_points, input_image, mask],
    )
    input_image.select(
        get_points,
        [input_image, selected_points],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, mask],
        [input_image, selected_points],
    )
    train_lora_button.click(
        train_lora_interface,
        [original_image, prompt, model_path, vae_path, lora_path,
         lora_step, lora_lr, lora_rank],
        [lora_status_bar],
    )
    # NOTE(review): unlike run_drag_gen, run_drag is not given
    # create_tracking_point_checkbox / gif_fps — confirm against its signature.
    run_button.click(
        run_drag,
        [original_image, input_image, mask, prompt, selected_points,
         inversion_strength, lam, latent_lr, n_pix_step,
         model_path, vae_path, lora_path,
         start_step, start_layer,
         create_gif_checkbox, gif_interval,
         ],
        [output_image],
    )
    clear_all_button.click(
        clear_all,
        [length_input],
        [canvas, input_image, output_image, selected_points, original_image, mask],
    )

    # event for dragging generated image
    canvas_gen.edit(
        store_img_gen,
        [canvas_gen],
        [original_image_gen, selected_points_gen, input_image_gen, mask_gen],
    )
    input_image_gen.select(
        get_points,
        [input_image_gen, selected_points_gen],
        [input_image_gen],
    )
    gen_img_button.click(
        gen_img,
        [
            length_input,
            height,
            width,
            n_inference_step_gen,
            scheduler_name_gen,
            gen_seed,
            guidance_scale,
            pos_prompt_gen,
            neg_prompt_gen,
            model_path_gen,
            vae_path_gen,
            lora_path_gen,
        ],
        [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen],
    )
    undo_button_gen.click(
        undo_points,
        [original_image_gen, mask_gen],
        [input_image_gen, selected_points_gen],
    )
    run_button_gen.click(
        run_drag_gen,
        [
            n_inference_step_gen,
            scheduler_name_gen,
            original_image_gen,  # the original image generated by the diffusion model
            input_image_gen,  # image with clicking, masking, etc.
            intermediate_latents_gen,
            guidance_scale,
            mask_gen,
            pos_prompt_gen,
            neg_prompt_gen,
            selected_points_gen,
            inversion_strength_gen,
            lam_gen,
            latent_lr_gen,
            n_pix_step_gen,
            model_path_gen,
            vae_path_gen,
            lora_path_gen,
            start_step_gen,
            start_layer_gen,
            create_gif_checkbox,
            create_tracking_point_checkbox,
            gif_interval,
            gif_fps,
        ],
        [output_image_gen],
    )
    clear_all_button_gen.click(
        clear_all_gen,
        [length_input],
        [canvas_gen,
         input_image_gen,
         output_image_gen,
         selected_points_gen,
         original_image_gen,
         mask_gen,
         intermediate_latents_gen,
         ],
    )


if __name__ == "__main__":
    demo.queue().launch(share=True, debug=True)