MagicQuill / app.py
LIU, Zichen
add gradio footer
78fe60c
raw
history blame
13.3 kB
import subprocess
import shlex
# Install the bundled MagicQuill Gradio component wheel at import time, before
# `gradio_magicquill` is imported further down. The result of the pip call is
# not inspected here, so a failed install only surfaces at that later import.
subprocess.run(
    shlex.split(
        "pip install ./gradio_magicquill-0.0.1-py3-none-any.whl"
    )
)
import gradio as gr
from gradio_magicquill import MagicQuill
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import base64
import io
from fastapi import FastAPI, Request
import uvicorn
from MagicQuill import folder_paths
from MagicQuill.scribble_color_edit import ScribbleColorEditModel
from gradio_client import Client, handle_file
from huggingface_hub import snapshot_download
import tempfile
import cv2
import os
import requests
# Download the "LiuZichen/MagicQuill-models" model repository into the local
# "models" directory when this module is imported.
# NOTE(review): presumably this must run before ScribbleColorEditModel() is
# constructed so the checkpoints exist on disk - confirm.
snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Gradio client pointed at "LiuZichen/DrawNGuess"; used by guess_prompt_handler.
client = Client("LiuZichen/DrawNGuess", hf_token=HF_TOKEN)
# Single module-level editing model instance; used by generate().
scribbleColorEditModel = ScribbleColorEditModel()
def tensor_to_numpy(tensor):
    """Convert a torch tensor to a uint8 NumPy array scaled by 255.

    The tensor is detached, moved to the CPU, multiplied by 255 and cast to
    ``np.uint8``. Any value that is not a ``torch.Tensor`` is returned
    unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    scaled = tensor.detach().cpu().numpy() * 255
    return scaled.astype(np.uint8)
def tensor_to_base64(tensor):
    """Encode an image tensor as a base64 PNG string (no data-URL prefix).

    Args:
        tensor: Image tensor; a leading dimension of size 1 is squeezed away
            and the values are multiplied by 255 before conversion to bytes.

    Returns:
        The PNG file contents, base64-encoded and decoded to a UTF-8 ``str``.
    """
    scaled = tensor.squeeze(0) * 255.
    # Clamp before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or be undefined) instead of saturating at black/white.
    scaled = scaled.clamp(0, 255)
    pil_image = Image.fromarray(scaled.cpu().byte().numpy())
    buffered = io.BytesIO()
    pil_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def read_base64_image(base64_image):
    """Decode a base64 image data URL into a PIL image.

    Args:
        base64_image: String starting with one of
            ``data:image/png;base64,``, ``data:image/jpeg;base64,`` or
            ``data:image/webp;base64,`` followed by the base64 payload.

    Returns:
        The decoded PIL image after ``ImageOps.exif_transpose`` has been
        applied to it.

    Raises:
        ValueError: If the string does not start with a supported prefix.
    """
    supported_prefixes = (
        "data:image/png;base64,",
        "data:image/jpeg;base64,",
        "data:image/webp;base64,",
    )
    # str.startswith accepts a tuple, replacing three identical branches.
    if not base64_image.startswith(supported_prefixes):
        raise ValueError("Unsupported image format.")
    payload = base64_image.split(",")[1]
    image_data = base64.b64decode(payload)
    image = Image.open(io.BytesIO(image_data))
    return ImageOps.exif_transpose(image)
def create_alpha_mask(base64_image):
    """Build an inverted-alpha mask from a base64 image.

    Returns a float32 CPU tensor of shape ``(1, height, width)`` holding
    ``1 - alpha / 255`` per pixel. Images without an ``'A'`` band yield an
    all-zero mask of the same shape.
    """
    image = read_base64_image(base64_image)
    if 'A' not in image.getbands():
        return torch.zeros((1, image.height, image.width), dtype=torch.float32, device="cpu")
    alpha = np.array(image.getchannel('A')).astype(np.float32) / 255.0
    inverted = 1.0 - torch.from_numpy(alpha)
    return inverted.unsqueeze(0)
def load_and_preprocess_image(base64_image, convert_to='RGB', has_alpha=False):
    """Decode a base64 image into a float32 tensor with a leading batch axis.

    The image is converted to the ``convert_to`` mode, scaled to the range
    0..1 by dividing by 255, and returned with shape ``(1, H, W, ...)``.
    ``has_alpha`` is accepted but not used by this function.
    """
    pil_image = read_base64_image(base64_image).convert(convert_to)
    pixels = np.array(pil_image).astype(np.float32) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0)
def load_and_resize_image(base64_image, convert_to='RGB', max_size=512):
    """Decode a base64 image, rescale it, and return it as a float32 tensor.

    The image is scaled by ``max_size / min(width, height)`` with LANCZOS
    resampling, and the new dimensions are truncated with ``int()``. The
    rescale is unconditional, so smaller images are enlarged as well as
    larger ones shrunk. The result is divided by 255 and returned with a
    leading batch axis.
    """
    pil_image = read_base64_image(base64_image).convert(convert_to)
    src_width, src_height = pil_image.size
    # No "only if larger than max_size" guard is applied here: every image is
    # rescaled.
    ratio = max_size / min(src_width, src_height)
    target_size = (int(src_width * ratio), int(src_height * ratio))
    resized = pil_image.resize(target_size, Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0)
def prepare_images_and_masks(total_mask, original_image, add_color_image, add_edge_image, remove_edge_image):
    """Decode the frontend's base64 layers into image and mask tensors.

    Falsy optional layers fall back as follows: a missing colour layer reuses
    the original image tensor, and a missing edge layer becomes an all-zero
    mask shaped like the total mask.

    Returns:
        Tuple ``(add_color_image_tensor, original_image_tensor, total_mask,
        add_edge_mask, remove_edge_mask)``.
    """
    mask_tensor = create_alpha_mask(total_mask)
    original_tensor = load_and_preprocess_image(original_image)
    color_tensor = load_and_preprocess_image(add_color_image) if add_color_image else original_tensor

    if add_edge_image:
        add_edge_mask = create_alpha_mask(add_edge_image)
    else:
        add_edge_mask = torch.zeros_like(mask_tensor)

    if remove_edge_image:
        remove_edge_mask = create_alpha_mask(remove_edge_image)
    else:
        remove_edge_mask = torch.zeros_like(mask_tensor)

    return color_tensor, original_tensor, mask_tensor, add_edge_mask, remove_edge_mask
def guess_prompt_handler(original_image, add_color_image, add_edge_image):
    """Ask the remote prompt-guessing service for a prompt.

    The three layers are decoded, written to temporary PNG files and passed
    to ``client.predict``. The temporary files are removed afterwards, even
    when writing or the remote call fails.

    Args:
        original_image: Base64 data URL of the original image.
        add_color_image: Base64 data URL of the colour layer; when falsy the
            original image is used instead.
        add_edge_image: Base64 data URL of the added-edge layer; when falsy
            an all-zero mask is used instead.

    Returns:
        Whatever ``client.predict`` returns.
    """
    original_image_tensor = load_and_preprocess_image(original_image)
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(add_color_image)
    else:
        add_color_image_tensor = original_image_tensor

    # Image tensors are laid out (1, H, W, C): axis 1 is the height and axis 2
    # the width. The fallback mask must be (1, H, W) to match what
    # create_alpha_mask produces.
    height, width = original_image_tensor.shape[1], original_image_tensor.shape[2]
    if add_edge_image:
        add_edge_mask = create_alpha_mask(add_edge_image)
    else:
        add_edge_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")

    # Convert the tensors to uint8 NumPy arrays.
    original_image_numpy = tensor_to_numpy(original_image_tensor.squeeze(0))
    add_color_image_numpy = tensor_to_numpy(add_color_image_tensor.squeeze(0))
    add_edge_mask_numpy = tensor_to_numpy(add_edge_mask.squeeze(0).unsqueeze(-1))
    # OpenCV writes images in BGR channel order.
    original_image_numpy = cv2.cvtColor(original_image_numpy, cv2.COLOR_RGB2BGR)
    add_color_image_numpy = cv2.cvtColor(add_color_image_numpy, cv2.COLOR_RGB2BGR)

    arrays = (original_image_numpy, add_color_image_numpy, add_edge_mask_numpy)
    temp_paths = []
    try:
        for array in arrays:
            # Reserve a named temporary file and close the handle right away,
            # so OpenCV can write to the path on every platform.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png", mode='w+b') as handle:
                temp_paths.append(handle.name)
            cv2.imwrite(handle.name, array)

        # Call the API, passing the temporary file paths.
        return client.predict(*(handle_file(path) for path in temp_paths))
    finally:
        # Always delete the temporary files, including on failure.
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def generate(ckpt_name, total_mask, original_image, add_color_image, add_edge_image, remove_edge_image, positive_prompt, negative_prompt, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
    """Run one edit through the shared model and return the result as base64.

    The base64 layers are decoded with ``prepare_images_and_masks``, the
    prompt and edge strength are adjusted for the "only removing edges" case,
    and ``scribbleColorEditModel.process`` is called. Only the final image of
    the four values returned by the model is used.

    Returns:
        The final image as a base64-encoded PNG string (no data-URL prefix).
    """
    (
        add_color_image,
        original_image,
        total_mask,
        add_edge_mask,
        remove_edge_mask,
    ) = prepare_images_and_masks(total_mask, original_image, add_color_image, add_edge_image, remove_edge_image)
    progress = None

    # NOTE(review): the nesting of these conditions was reconstructed from a
    # copy of the file that had lost its indentation; confirm that the edge
    # strength reduction belongs inside both conditions.
    only_removing_edges = (
        fine_edge == 'disable'
        and torch.sum(remove_edge_mask).item() > 0
        and torch.sum(add_edge_mask).item() == 0
    )
    if only_removing_edges:
        if positive_prompt == "":
            positive_prompt = "empty scene"
        edge_strength = edge_strength / 3.

    model_outputs = scribbleColorEditModel.process(
        ckpt_name,
        original_image,
        add_color_image,
        positive_prompt,
        negative_prompt,
        total_mask,
        add_edge_mask,
        remove_edge_mask,
        grow_size,
        stroke_as_edge,
        fine_edge,
        edge_strength,
        color_strength,
        inpaint_strength,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        progress
    )
    # The model returns (latent samples, final image, lineart, colour output).
    _latent_samples, final_image, _lineart_output, _color_output = model_outputs
    return tensor_to_base64(final_image)
def generate_image_handler(x, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
    """Handle the "Run" button: generate an image and store it in ``x``.

    ``x`` is the MagicQuill component value. Its ``'from_frontend'`` entry
    supplies the base64 layers and ``'from_backend'['prompt']`` supplies the
    positive prompt. A seed of -1 is replaced by a random 32-bit value. The
    result is written to ``x['from_backend']['generated_image']`` and the
    same ``x`` object is returned.
    """
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    frontend = x['from_frontend']
    prompt = x['from_backend']['prompt']
    # Stroke-as-edge is always on; the UI does not expose a control for it.
    stroke_as_edge = "enable"

    x["from_backend"]["generated_image"] = generate(
        ckpt_name,
        frontend['total_mask'],
        frontend['original_image'],
        frontend['add_color_image'],
        frontend['add_edge_image'],
        frontend['remove_edge_image'],
        prompt,
        negative_prompt,
        grow_size,
        stroke_as_edge,
        fine_edge,
        edge_strength,
        color_strength,
        inpaint_strength,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
    )
    return x
# Gradio UI: the MagicQuill canvas, a "Run" button and a collapsed accordion
# of generation parameters, all wired to generate_image_handler.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from
# a copy of this file that had lost its indentation - confirm the layout.

# Custom CSS applied to rows tagged with elem_classes="row".
css = '''
.row {
width: 90%;
margin: auto;
}
'''
with gr.Blocks(css=css) as demo:
    # Row 1: the MagicQuill custom component.
    with gr.Row(elem_classes="row"):
        ms = MagicQuill()
    # Row 2: run button on the left, parameter accordion on the right.
    with gr.Row(elem_classes="row"):
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column():
            with gr.Accordion("parameters", open=False):
                # Checkpoint choices come from the "checkpoints" folder listing.
                ckpt_name = gr.Dropdown(
                    label="Base Model Name",
                    choices=folder_paths.get_filename_list("checkpoints"),
                    value='SD1.5/realisticVisionV60B1_v51VAE.safetensors',
                    interactive=True
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    interactive=True
                )
                # Stroke-as-edge control is disabled; generate_image_handler
                # hard-codes the value to "enable".
                # stroke_as_edge = gr.Radio(
                #     label="Stroke as Edge",
                #     choices=['enable', 'disable'],
                #     value='enable',
                #     interactive=True
                # )
                fine_edge = gr.Radio(
                    label="Fine Edge",
                    choices=['enable', 'disable'],
                    value='disable',
                    interactive=True
                )
                grow_size = gr.Slider(
                    label="Grow Size",
                    minimum=0,
                    maximum=100,
                    value=15,
                    step=1,
                    interactive=True
                )
                edge_strength = gr.Slider(
                    label="Edge Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.6,
                    step=0.01,
                    interactive=True
                )
                color_strength = gr.Slider(
                    label="Color Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.6,
                    step=0.01,
                    interactive=True
                )
                inpaint_strength = gr.Slider(
                    label="Inpaint Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=1.0,
                    step=0.01,
                    interactive=True
                )
                # A seed of -1 makes generate_image_handler pick a random seed.
                seed = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    interactive=True
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=50,
                    value=20,
                    step=1,
                    interactive=True
                )
                cfg = gr.Slider(
                    label="CFG",
                    minimum=0.0,
                    maximum=20.0,
                    value=5.0,
                    step=0.1,
                    interactive=True
                )
                sampler_name = gr.Dropdown(
                    label="Sampler Name",
                    choices=["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ddim", "uni_pc", "uni_pc_bh2"],
                    value='euler_ancestral',
                    interactive=True
                )
                scheduler = gr.Dropdown(
                    label="Scheduler",
                    choices=["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"],
                    value='karras',
                    interactive=True
                )
    # The input order must match generate_image_handler's parameter order. The
    # handler's return value is written back to the MagicQuill component, and
    # concurrency_limit=1 is set on this event.
    btn.click(generate_image_handler, inputs=[ms, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler], outputs=ms, concurrency_limit=1)
# Enable the request queue with a maximum size of 20.
demo.queue(max_size=20, status_update_rate=0.1)
# FastAPI application that hosts the JSON endpoints defined below; the Gradio
# UI is mounted onto it afterwards.
app = FastAPI()
@app.post("/magic_quill/guess_prompt")
async def guess_prompt(request: Request):
    """Return a guessed prompt for the layers posted as JSON.

    The request body must be a JSON object with the keys ``original_image``,
    ``add_color_image`` and ``add_edge_image``. The response is whatever
    ``guess_prompt_handler`` returns.
    """
    payload = await request.json()
    return guess_prompt_handler(
        payload['original_image'],
        payload['add_color_image'],
        payload['add_edge_image'],
    )
@app.post("/magic_quill/process_background_img")
async def process_background_img(request: Request):
    """Resize an uploaded background image and return it as a PNG data URL.

    The JSON request body is passed directly to ``load_and_resize_image`` as
    the base64 image. The response is the resized image prefixed with
    ``data:image/png;base64,``.
    """
    encoded_image = await request.json()
    resized_tensor = load_and_resize_image(encoded_image)
    png_base64 = tensor_to_base64(resized_tensor)
    # add more processing here
    return "data:image/png;base64," + png_base64
# Mount the Gradio UI at "/" on the FastAPI app so the UI and the JSON
# endpoints above are served together.
app = gr.mount_gradio_app(app, demo, "/")
if __name__ == "__main__":
    # Serve the combined app with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
    # demo.launch()