Emu2 / demo /generation_frontend.py
ryanzhangfan's picture
fix typo
562648a
raw
history blame
9.96 kB
# -*- coding: utf-8 -*-
# ===================================================
#
# Author : Fan Zhang
# Email : zhangfan@baai.ac.cn
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
# Create On : 2023-12-11 15:35
# Last Modified : 2023-12-20 16:39
# File Name : generation_frontend.py
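# Description : Gradio frontend for the Emu generation demo; collects text/image prompts and forwards them to the generation controller.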
#
# ===================================================
import base64
import json
import io
import time
from PIL import Image
import requests
import gradio as gr
from .constants import EVA_IMAGE_SIZE
from .meta import ConvMeta, Role, DataMeta
from .utils import frontend_logger as logging
from .constants import LICENSE, TERM_OF_USE
# Base URL of the generation controller service. Empty until build_generation()
# assigns args.controller_url; generate() appends "/v1/mmg" to it.
CONTROLLER_URL: str = ""
def _submit_outputs(meta):
    """Return submit()'s output tuple: refreshed chat view and state plus reset input widgets.

    The order matches the ``outputs`` list submit is wired to in build_generation:
    chatbot, state, grounding checkbox (unchecked), left/top sliders (0),
    right/bottom sliders (EVA_IMAGE_SIZE), image box (cleared), textbox (cleared).
    """
    return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""


def submit(
    meta,
    enable_grd,
    left,
    top,
    right,
    bottom,
    image,
    text,
):
    """Append the user's pending text and/or image to the conversation.

    With grounding enabled, text, image and the box ``[left, top, right, bottom]``
    are appended together as a single user turn. Without grounding, exactly one of
    text or image is appended. If neither is given, nothing is appended.

    Args:
        meta: conversation state (``ConvMeta``), or ``None`` on first use.
        enable_grd: whether grounding generation is enabled.
        left, top, right, bottom: grounding box coordinates from the sliders.
        image: the uploaded image, or ``None``.
        text: the prompt text (``""`` when empty).

    Returns:
        The tuple built by ``_submit_outputs``: updated chat, state, and reset inputs.

    Raises:
        gr.Error: grounding is enabled but neither text nor image was given, or
            both text and image were given without grounding. The error is raised
            before the conversation is modified, so the user's inputs are kept.
    """
    if meta is None:
        meta = ConvMeta()
    has_text = bool(text)
    has_image = image is not None
    # gr.Error only reaches the UI when *raised*; merely constructing it is a no-op.
    if enable_grd and not has_text and not has_image:
        logging.info(f"{meta.log_id}: invalid input: no valid data for grounding input")
        raise gr.Error("text or image must be given if enable grounding generation")
    if not enable_grd and has_text and has_image:
        logging.info(f"{meta.log_id}: invalid input: give text and image simultaneously for single modality input")
        raise gr.Error("Do not submit text and image data at the same time!!!")

    meta.pop_error()
    # After a generation has been produced, the next input starts a new conversation.
    if meta.has_gen:
        meta.clear()

    if enable_grd:
        meta.append(Role.USER, DataMeta.build(text=text, image=image, coordinate=[left, top, right, bottom]))
    elif has_image:
        meta.append(Role.USER, DataMeta.build(image=image))
    elif has_text:
        meta.append(Role.USER, DataMeta.build(text=text))
    return _submit_outputs(meta)
def clear_history(meta):
    """Wipe the conversation stored in the session state.

    A new ``ConvMeta`` is created when the session has no state yet.

    Returns:
        ``(chatbot_history, meta)``: the chat display after clearing, and the state object.
    """
    conversation = meta if meta is not None else ConvMeta()
    conversation.clear()
    history = conversation.format_chatbot()
    return history, conversation
def _encode_prompt(prompt):
    """Convert a formatted prompt into the controller's request format.

    Args:
        prompt: iterable of prompt items (from ``ConvMeta.format_prompt()``).
            PIL images are uploaded as files; every other item is sent as text.

    Returns:
        ``(prompt_list, files)``.
        ``prompt_list`` holds ``["TEXT", item]`` / ``["IMAGE", key]`` pairs.
        ``files`` maps each image key to a ``(filename, PNG stream, mime type)``
        tuple for ``requests``, or is ``None`` when the prompt has no images.
    """
    prompt_list, files = [], {}
    for idx, item in enumerate(prompt):
        if isinstance(item, Image.Image):
            # The key embeds the item position, so it is unique within one request.
            key = f"[<IMAGE{idx}>]"
            prompt_list.append(["IMAGE", key])
            stream = io.BytesIO()
            item.save(stream, format="PNG")
            stream.seek(0)  # rewind so the upload starts at the first byte
            files[key] = (key, stream, "image/png")
        else:
            prompt_list.append(["TEXT", item])
    return prompt_list, (files if files else None)


def _post_generation(log_id, prompt_list, files, classifier_free_guidance, steps):
    """POST a generation request to ``CONTROLLER_URL + "/v1/mmg"``.

    Exceptions from the request are not propagated. They are converted into a
    synthetic ``requests.Response`` with status code 1099, whose body is the
    exception text, so the caller reports them like any other HTTP failure.
    """
    try:
        return requests.post(
            CONTROLLER_URL + "/v1/mmg",
            files=files,
            data={
                "log_id": log_id,
                "prompt": json.dumps(prompt_list),
                "classifier_free_guidance": classifier_free_guidance,
                "steps": steps,
            },
        )
    except Exception as ex:  # deliberate catch-all: any failure becomes a user-visible error
        rsp = requests.Response()
        rsp.status_code = 1099
        rsp._content = str(ex).encode()
        return rsp


def _reply_from_response(log_id, rsp):
    """Turn the controller's response into the assistant's ``DataMeta`` turn.

    A JSON body with ``code == 0`` carries a base64-encoded image in ``data``.
    Any other code makes ``data`` an error message. A non-200 status, or a body
    that cannot be decoded, yields an error turn instead of raising.
    """
    if rsp.status_code != requests.codes.ok:
        return DataMeta.build(text=f"GENERATE FAILED: http failed with code {rsp.status_code}, msg: {rsp.text}", is_error=True)
    try:
        content = json.loads(rsp.text)
        code, data = content["code"], content["data"]
        image = Image.open(io.BytesIO(base64.b64decode(data))) if code == 0 else None
    except (ValueError, KeyError, TypeError, OSError) as ex:
        # ValueError covers json.JSONDecodeError and binascii.Error (bad base64);
        # OSError covers PIL.UnidentifiedImageError for undecodable image bytes.
        logging.info(f"{log_id}: malformed generation response: {ex!r}")
        return DataMeta.build(text=f"GENERATE FAILED: malformed response from controller: {ex}", is_error=True)
    if code == 0:
        return DataMeta.build(image=image, resize=False)
    return DataMeta.build(text=f"GENERATE FAILED: {data}", is_error=True)


def generate(meta, classifier_free_guidance, steps):
    """Send the current conversation to the controller and append the reply.

    Args:
        meta: conversation state (``ConvMeta``), or ``None`` on first use.
        classifier_free_guidance: guidance scale forwarded to the controller.
        steps: number of diffusion steps forwarded to the controller.

    Returns:
        ``(chatbot_history, meta)``. On any failure an error turn is appended
        instead of an image; this function does not raise for failed requests.
    """
    if meta is None:
        meta = ConvMeta()
    meta.pop_error()
    # NOTE(review): presumably drops the previous assistant turn so it is not
    # part of the new prompt; confirm against ConvMeta.pop().
    meta.pop()
    prompt_list, files = _encode_prompt(meta.format_prompt())
    logging.info(f"{meta.log_id}: construct generation request with prompt {prompt_list}")

    start = time.perf_counter()  # monotonic clock for measuring elapsed time
    rsp = _post_generation(meta.log_id, prompt_list, files, classifier_free_guidance, steps)
    elapsed_ms = (time.perf_counter() - start) * 1000
    logging.info(f"{meta.log_id}: get response with status code: {rsp.status_code}, time: {elapsed_ms:.3f}ms")

    meta.append(Role.ASSISTANT, _reply_from_response(meta.log_id, rsp))
    return meta.format_chatbot(), meta
def push_examples(examples, meta):
    """Replace the conversation with a clicked example.

    Args:
        examples: the clicked dataset sample. One item means a text prompt;
            two items mean ``(image path, text prompt)``. Any other length only
            clears the conversation.
        meta: conversation state (``ConvMeta``), or ``None`` on first use.

    Returns:
        ``(chatbot_history, meta)``.
    """
    conv = ConvMeta() if meta is None else meta
    conv.clear()
    sample_size = len(examples)
    if sample_size == 2:
        image_path, text_prompt = examples
        conv.append(Role.USER, DataMeta.build(image=Image.open(image_path)))
        conv.append(Role.USER, DataMeta.build(text=text_prompt))
    elif sample_size == 1:
        (text_prompt,) = examples
        conv.append(Role.USER, DataMeta.build(text=text_prompt))
    return conv.format_chatbot(), conv
def build_generation(args):
    """Build the Gradio UI for the Emu generation demo.

    Side effect: stores ``args.controller_url`` in the module-level
    ``CONTROLLER_URL``, which ``generate`` uses as the controller base URL.

    Args:
        args: object with a ``controller_url`` attribute (presumably parsed
            command-line arguments; confirm with the caller).

    Returns:
        The constructed ``gr.Blocks`` app. Launching it is left to the caller.
    """
    global CONTROLLER_URL
    CONTROLLER_URL = args.controller_url

    with gr.Blocks(title="Emu", theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
        state = gr.State()

        with gr.Row():
            # Left column: image input plus grounding and diffusion controls.
            with gr.Column(scale=2):
                with gr.Row():
                    imagebox = gr.Image(type="pil")
                with gr.Row():
                    with gr.Accordion("Grounding Parameters", open=True, visible=True):
                        enable_grd = gr.Checkbox(label="Enable")
                        # Box sliders range over 0..EVA_IMAGE_SIZE; the default box spans the full range.
                        left = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="left")
                        top = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="top")
                        right = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="right")
                        bottom = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="bottom")
                with gr.Row():
                    with gr.Accordion("Diffusion Parameters", open=True, visible=True):
                        cfg = gr.Slider(minimum=1, maximum=30, value=3, step=0.5, interactive=True, label="classifier free guidance")
                        steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, interactive=True, label="steps")

            # Right column: conversation view, prompt entry, actions and examples.
            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Emu Chatbot",
                    visible=True,
                    height=720,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and add to prompt",
                            visible=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=60):
                        add_btn = gr.Button(value="Add")
                with gr.Row(visible=True):
                    clear_btn = gr.Button(value="🗑️ Clear History")
                    generate_btn = gr.Button(value="Generate")
                with gr.Row():
                    examples_t2i = gr.Dataset(
                        components=[gr.Textbox(visible=False)],
                        label="Examples",
                        samples=[
                            ["impressionist painting of an astronaut in a jungle"],
                            ["A poised woman with short, curly hair and a warm smile, dressed in elegant attire, standing in front of a historic stone bridge in a serene park at sunset."],
                        ],
                    )
                with gr.Row():
                    examples_it2i = gr.Dataset(
                        components=[gr.Image(type="pil", visible=False), gr.Textbox(visible=False)],
                        samples=[
                            ["./examples/dog2.jpg", "wearing a red hat on the beach."],
                            ["./examples/balloon.webp", "floats above the forest"],
                        ],
                    )

        gr.Markdown(TERM_OF_USE)
        gr.Markdown(LICENSE)

        # Pressing Enter in the textbox and clicking "Add" both run submit with the same wiring.
        submit_inputs = [state, enable_grd, left, top, right, bottom, imagebox, textbox]
        # Order must match the tuple submit() returns.
        submit_outputs = [chatbot, state, enable_grd, left, top, right, bottom, imagebox, textbox]

        clear_btn.click(clear_history, inputs=state, outputs=[chatbot, state])
        textbox.submit(submit, inputs=submit_inputs, outputs=submit_outputs)
        add_btn.click(submit, inputs=submit_inputs, outputs=submit_outputs)
        generate_btn.click(generate, inputs=[state, cfg, steps], outputs=[chatbot, state])
        # Clicking an example replaces the conversation with that example's prompt.
        for examples in (examples_t2i, examples_it2i):
            examples.click(push_examples, inputs=[examples, state], outputs=[chatbot, state])
    return demo