import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import base64
import requests
from io import BytesIO
from region_control import MultiDiffusion, get_views, preprocess_mask
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
# Number of colour/prompt rows the UI pre-allocates; process_sketch returns
# exactly this many row-visibility updates and this many swatch updates.
MAX_COLORS = 12
# Shared MultiDiffusion pipeline, built once at import time.
# NOTE(review): the arguments look like (device, Stable Diffusion version) —
# confirm against region_control.MultiDiffusion.
sd = MultiDiffusion("cuda", "2.0")
# Markup handed to gr.HTML as the sketch canvas mount point. The id must stay
# "canvas-root": the get_js_colors snippet looks the element up by that id and
# reads its `_data` property.
# NOTE(review): the original literal was an unterminated string (its markup
# was lost), so this is a minimal reconstruction — confirm the element type
# and any inline styling the externally loaded sketch-canvas script expects.
canvas_html = "<div id='canvas-root'></div>"
# JavaScript executed on page load (passed to demo.load). It fetches the
# sketch-canvas script as text, wraps it in a Blob URL and appends it to
# <head> as a module <script>. The fetch promise is not awaited and has no
# error handler.
load_js = """
async () => {
const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
fetch(url)
.then(res => res.text())
.then(text => {
const script = document.createElement('script');
script.type = "module"
script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
document.head.appendChild(script);
});
}
"""
# JavaScript that ignores its argument and returns, as a one-item array, the
# `_data` property of the element with id "canvas-root".
# NOTE(review): presumably `_data` holds the {'image', 'colors'} payload that
# process_sketch reads — confirm against the sketch-canvas script.
get_js_colors = """
async (canvasData) => {
const canvasEl = document.getElementById("canvas-root");
return [canvasEl._data]
}
"""
# JavaScript that calls `_updateCanvas` with 512x512, 768x512 or 512x768 for
# the aspect names 'square', 'horizontal' and 'vertical'. Not referenced by
# any active code in this file.
# NOTE(review): the names are lower-case here but capitalised in update_css —
# reconcile before wiring either one up.
set_canvas_size ="""
async (aspect) => {
if(aspect ==='square'){
_updateCanvas(512,512)
}
if(aspect ==='horizontal'){
_updateCanvas(768,512)
}
if(aspect ==='vertical'){
_updateCanvas(512,768)
}
}
"""
def process_sketch(canvas_data, binary_matrixes):
    """Convert the sketch sent by the browser into region masks and UI updates.

    Args:
        canvas_data: Mapping with an ``'image'`` entry (a base64 data URL; the
            part after the first comma is decoded) and a ``'colors'`` entry
            (an iterable of ``#rrggbb`` strings).
        binary_matrixes: Session list of region masks. It is emptied and
            refilled in place, then returned.

    Returns:
        A flat list for the Gradio outputs: one update revealing the
        post-sketch column, the mask list, ``MAX_COLORS`` row-visibility
        updates and ``MAX_COLORS`` swatch updates.

    Raises:
        KeyError: If ``canvas_data`` lacks ``'image'`` or ``'colors'``.
    """
    # Start from a clean slate: without this, pressing the button a second
    # time would append a duplicate set of masks to the session state.
    binary_matrixes.clear()

    encoded = canvas_data['image'].split(',')[1]
    sketch = np.array(Image.open(BytesIO(base64.b64decode(encoded))))

    swatches = []
    for hex_color in canvas_data['colors']:
        digits = hex_color.lstrip('#')
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        if (r, g, b) == (255, 255, 255):
            # Pure white is skipped (presumably the canvas background).
            continue
        if len(swatches) == MAX_COLORS:
            # The UI only has MAX_COLORS prompt rows; extra colours would
            # produce masks that no prompt could be attached to.
            break
        binary_matrixes.append(create_binary_matrix(sketch, (r, g, b)))
        # NOTE(review): swatch markup reconstructed (the original value was an
        # empty string); it uses the `color-bg-item` class that the page CSS
        # sizes so each prompt row shows the colour it belongs to.
        swatches.append(gr.update(
            value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'
        ))

    # Show exactly one row per detected colour and keep the rest hidden.
    # (The previous loop bound was derived from the MAX_COLORS-long placeholder
    # list, which raised IndexError whenever fewer colours had been drawn.)
    shown = len(swatches)
    visibilities = [gr.update(visible=n < shown) for n in range(MAX_COLORS)]
    colors = swatches + [gr.update(value='') for _ in range(MAX_COLORS - shown)]
    return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
def process_generation(binary_matrixes, master_prompt, *prompts):
    """Generate an image from the sketched region masks and their prompts.

    Args:
        binary_matrixes: Region masks, one per sketched colour.
        master_prompt: Prompt paired with the background mask (everything not
            covered by a region).
        *prompts: Per-region prompts; only the first ``len(binary_matrixes)``
            are used.

    Returns:
        Whatever ``sd.generate`` returns for the assembled masks and prompts.

    Raises:
        gr.Error: If no region mask is available.
    """
    # torch.cat refuses an empty list with an opaque RuntimeError; tell the
    # user what is actually missing instead.
    if not binary_matrixes:
        raise gr.Error("Please draw at least one coloured region before generating.")

    region_prompts = list(prompts[:len(binary_matrixes)])
    all_prompts = [master_prompt, *region_prompts]
    neg_prompts = [""] * len(all_prompts)

    # Masks are prepared at 1/8 of the 512x512 output size.
    # NOTE(review): presumably the latent resolution expected by sd.generate —
    # confirm against region_control.
    mask_side = 512 // 8
    fg_masks = torch.cat(
        [preprocess_mask(mask, mask_side, mask_side, "cuda") for mask in binary_matrixes]
    )
    # Background = whatever no region covers; where summed regions exceed 1
    # the difference goes negative, so floor it at 0.
    bg_mask = (1 - torch.sum(fg_masks, dim=0, keepdim=True)).clamp(min=0)
    masks = torch.cat([bg_mask, fg_masks])

    return sd.generate(masks, all_prompts, neg_prompts, 512, 512, 50, bootstrapping=20)
# Page stylesheet passed to gr.Blocks: centres the contents of elements with
# id "color-bg", makes "color-bg-item" elements full-width and 32px tall, and
# stretches the element with id "main_button" to full width.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
def update_css(aspect):
    """Build visibility updates for three per-aspect components.

    Args:
        aspect: One of ``'Square'``, ``'Horizontal'`` or ``'Vertical'``.

    Returns:
        A three-item list of Gradio updates in (square, horizontal, vertical)
        order, with only the entry matching ``aspect`` visible. Any other
        value yields ``None``.
    """
    layouts = ('Square', 'Horizontal', 'Vertical')
    if aspect not in layouts:
        return None
    chosen = layouts.index(aspect)
    return [gr.update(visible=(slot == chosen)) for slot in range(len(layouts))]
# UI definition and event wiring for the demo.
# NOTE(review): the component nesting below is reconstructed (the original
# indentation was not available) — confirm the layout renders as intended.
with gr.Blocks(css=css) as demo:
    # Per-session list of region masks: filled by process_sketch and read by
    # process_generation.
    binary_matrixes = gr.State([])
    gr.Markdown(
        "## Control your Stable Diffusion generation with Sketches\n"
        "This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).\n"
    )
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            # Hidden input slot; its value is supplied by the get_js_colors
            # snippet when the sketch button is clicked.
            canvas_data = gr.JSON(value={}, visible=False)
            canvas = gr.HTML(canvas_html)
            # NOTE: aspect-ratio selection is not wired up; update_css and
            # set_canvas_size are currently unused by this layout.
            button_run = gr.Button("I've finished my sketch", elem_id="main_button", interactive=True)

            prompts = []
            colors = []
            color_row = [None] * MAX_COLORS
            # Hidden until process_sketch reveals it.
            with gr.Column(visible=False) as post_sketch:
                general_prompt = gr.Textbox(label="General Prompt")
                # One hidden row (colour swatch + prompt box) per possible region.
                for n in range(MAX_COLORS):
                    with gr.Row(visible=False) as color_row[n]:
                        with gr.Box(elem_id="color-bg"):
                            colors.append(gr.HTML(''))
                        prompts.append(gr.Textbox(label="Prompt for this mask"))
                final_run_btn = gr.Button("Generate!")
        out_image = gr.Image(label="Result")
    gr.Markdown("\n![Examples](https://multidiffusion.github.io/pics/tight.jpg)\n")

    # `_js` belongs on the event listener, not on the Button constructor: the
    # snippet runs in the browser before the Python handler and provides the
    # canvas payload as the handler's first input.
    button_run.click(
        process_sketch,
        inputs=[canvas_data, binary_matrixes],
        outputs=[post_sketch, binary_matrixes, *color_row, *colors],
        _js=get_js_colors,
    )
    final_run_btn.click(
        process_generation,
        inputs=[binary_matrixes, general_prompt, *prompts],
        outputs=out_image,
    )
    # Inject the sketch-canvas script once the page has loaded.
    demo.load(None, None, None, _js=load_js)
demo.launch(debug=True)