import time
from threading import Semaphore

import diffusers
import gradio as gr
import spaces
import torch

from utils import patch_attention_proc, remove_patch
# Globals
css = """
h1 {
    text-align: center;
    display: block;
}
"""
# Pipeline: DreamShaper (a Stable Diffusion fine-tune) in fp16 on GPU.
pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.safety_checker = None

# Serializes generations so two simultaneous button presses cannot race on the shared pipe.
semaphore = Semaphore()
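# Note: @spaces.GPU below is Hugging Face ZeroGPU's decorator; on ZeroGPU Spaces it
# attaches a GPU to the process for the duration of each decorated call, and it has
# no effect when running outside that environment.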
@spaces.GPU
def generate_baseline(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    # `method` is unused here; it is accepted only so both buttons share one input list.
    with semaphore:  # released even if generation raises
        torch.manual_seed(seed)
        start_time_base = time.time()
        remove_patch(pipe)  # ensure no token-merging patch is active for the baseline run
        base_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        end_time_base = time.time()
        result = f"Baseline Runtime: {end_time_base - start_time_base:.2f} sec"
    return base_img, result
@spaces.GPU
def generate_merged(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    with semaphore:  # released even if generation raises
        # "todo" downsamples keys/values spatially; "tome" merges all tokens by similarity.
        merge_method = "downsample" if method == "todo" else "similarity"
        merge_tokens = "keys/values" if method == "todo" else "all"

        # With downsample factor d, the merged fraction is 1 - 1/d**2:
        # d=2 -> 0.75, d=3 -> ~0.89, d=4 -> 0.9375. Using the 1024 settings as
        # defaults also guards against an unexpected size leaving them undefined.
        downsample_factor, ratio = 2, 0.75
        downsample_factor_level_2, ratio_level_2 = 1, 0.0
        if height_width == 1536:
            downsample_factor, ratio = 3, 0.89
        elif height_width == 2048:
            downsample_factor, ratio = 4, 0.9375

        token_merge_args = {
            "ratio": ratio,
            "merge_tokens": merge_tokens,
            "merge_method": merge_method,
            "downsample_method": "nearest",
            "downsample_factor": downsample_factor,
            "timestep_threshold_switch": 0.0,
            "timestep_threshold_stop": 0.0,
            "downsample_factor_level_2": downsample_factor_level_2,
            "ratio_level_2": ratio_level_2,
        }

        # Swap the UNet's attention processors for token-merging variants.
        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)

        torch.manual_seed(seed)
        start_time_merge = time.time()
        merged_img = pipe(prompt,
                          num_inference_steps=steps, height=height_width, width=height_width,
                          negative_prompt=negative_prompt,
                          guidance_scale=guidance_scale).images[0]
        end_time_merge = time.time()
        result = f"{'ToDo' if method == 'todo' else 'ToMe'} Runtime: {end_time_merge - start_time_merge:.2f} sec"
    return merged_img, result
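
# Illustrative sketch only (the real logic lives in utils.patch_attention_proc):
# what "downsample"-style merging of keys/values means. Tokens form an (h, w)
# grid; ToDo shrinks only K/V spatially while queries stay at full resolution,
# so attention still emits one output vector per original token. The helper
# name and the (batch, h*w, channels) layout are assumptions for illustration.
def _downsample_kv_sketch(tokens, h, w, factor):
    b, n, c = tokens.shape
    grid = tokens.transpose(1, 2).reshape(b, c, h, w)  # sequence -> spatial grid
    # Nearest-neighbor downsampling matches downsample_method="nearest" above;
    # a factor d keeps only 1/d**2 of the key/value tokens.
    small = torch.nn.functional.interpolate(grid, scale_factor=1.0 / factor, mode="nearest")
    return small.flatten(2).transpose(1, 2)  # back to (batch, n // factor**2, channels)
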
with gr.Blocks(css=css) as demo:
    gr.Markdown("# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images")
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
    with gr.Row():
        method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose your desired method (default: todo)")
        height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose your desired height/width (default: 1024)")
    with gr.Row():
        guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
        steps = gr.Number(label="steps", value=20, precision=0)
        seed = gr.Number(label="seed", value=1, precision=0)
    with gr.Row():
        with gr.Column():
            base_result = gr.Textbox(label="Baseline Runtime")
            base_image = gr.Image(label="baseline_image", type="pil", interactive=False)
            gen_base = gr.Button("Generate Baseline")
            gen_base.click(generate_baseline, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                                      guidance_scale, method], outputs=[base_image, base_result])
        with gr.Column():
            output_result = gr.Textbox(label="Runtime")
            output_image = gr.Image(label="image", type="pil", interactive=False)
            gen_merged = gr.Button("Generate")
            gen_merged.click(generate_merged, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                                      guidance_scale, method], outputs=[output_image, output_result])

demo.launch(share=True)