Spaces:
Runtime error
Runtime error
import io | |
import re | |
import imp | |
import time | |
import json | |
import base64 | |
import requests | |
import gradio as gr | |
import ui_functions as uifn | |
from css_and_js import js, call_JS | |
from PIL import Image, PngImagePlugin, ImageChops | |
import os

# Base URL of the FlagStudio API used by every request in this module.
url_host = "https://flagstudio.baai.ac.cn"
# Credential sent as the "token" header on every API call.
# NOTE(review): a real token is committed in source. It can now be supplied via
# the FLAGSTUDIO_TOKEN environment variable; the literal below is kept only as a
# fallback so existing deployments keep working, and should be rotated/removed.
token = os.environ.get("FLAGSTUDIO_TOKEN") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiZjAxOGMxMzJiYTUyNDBjMzk5NTMzYTI5YjBmMzZiODMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6IjVjMmQzMjdiLWI5Y2MtNDhiZS1hZWQ4LTllMjQ4MDk4NzMxYyIsIm5iZiI6MTY2OTAwNjE5NywiZXhwIjoxOTg0MzY2MTk3LCJpYXQiOjE2NjkwMDYxOTd9.9B3MDk8wA6iWH5puXjcD19tJJ4Ox7mdpRyWZs5Kwt70"
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8.

    Used to load the CSS and the header/footer HTML snippets for the UI.
    """
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()
def filter_content(raw_style: str):
    """Strip the parenthesised annotation from a UI label.

    Returns everything before the first ASCII ``"("``. For example,
    ``"国画(traditional Chinese painting)"`` becomes ``"国画"``. A label
    without ``"("`` is returned unchanged.
    """
    head, _sep, _rest = raw_style.partition("(")
    return head
def upload_image(img, timeout=60):
    """Upload a PIL image to FlagStudio and return its identifiers.

    This is a two-step exchange:

    1. POST to ``/api/v1/image/get-upload-link``, which returns an
       ``image_id``, an upload ``url`` and the ``headers`` to send with it.
    2. PUT the PNG-encoded image bytes to that URL.

    Args:
        img: PIL image to upload; it is encoded as PNG.
        timeout: seconds handed to ``requests`` for each HTTP call. Without
            it, a stalled connection would block the worker indefinitely.

    Returns:
        ``(image_id, image_url)`` from the upload-link response.

    Raises:
        gr.Error: on a non-200 HTTP status or a non-zero API ``code``.
        requests.Timeout: if a call exceeds ``timeout``.
    """
    headers = {"token": token}
    r = requests.post(url_host + "/api/v1/image/get-upload-link",
                      json={}, headers=headers, timeout=timeout)
    if r.status_code != 200:
        raise gr.Error(r.reason)
    head_res = r.json()
    if head_res["code"] != 0:
        raise gr.Error("Unknown error")
    link = head_res["data"]
    image_id = link["image_id"]
    image_url = link["url"]
    image_headers = link["headers"]

    buffer = io.BytesIO()
    img.save(buffer, "PNG")
    r = requests.put(image_url, data=buffer.getvalue(),
                     headers=image_headers, timeout=timeout)
    if r.status_code != 200:
        raise gr.Error(r.reason)
    return image_id, image_url
def _build_task_payload(seed, prompt, width, height, img, mask):
    """Assemble the JSON body for ``/api/v1/task/create``, uploading img/mask first."""
    params = {
        "width": width,    # output image width
        "height": height,  # output image height
        "prompts": [prompt],
        "seed": int(seed),
    }
    if img is not None:
        image_id, image_url = upload_image(img)
        params["init_image"] = {
            "image_id": image_id,
            "url": image_url,
            "width": img.width,
            "height": img.height,
        }
        if mask is not None:
            # getextrema() on an "L" image gives (min, max). A max of 0 means
            # every pixel is black, so the mask carries no region; skip it.
            _lowest, highest = mask.convert("L").getextrema()
            if highest > 0:
                mask_id, mask_url = upload_image(mask)
                params["mask_image"] = {
                    "image_id": mask_id,
                    "url": mask_url,
                    "width": mask.width,
                    "height": mask.height,
                }
    return {"type": "gen-image", "parameters": params}


def _create_tasks(payload, image_num, headers, timeout):
    """Submit ``payload`` ``image_num`` times; return each task's ``data`` handle."""
    url = url_host + "/api/v1/task/create"
    tasks = []
    for _ in range(image_num):
        r = requests.post(url, json=payload, headers=headers, timeout=timeout)
        if r.status_code != 200:
            raise gr.Error(r.reason)
        create_res = r.json()
        if create_res['code'] == 3002:
            raise gr.Error("Inappropriate prompt detected.")
        if create_res['code'] != 0:
            raise gr.Error("Unknown error")
        tasks.append(create_res["data"])
    return tasks


def _wait_for_images(tasks, headers, max_wait, timeout):
    """Poll ``/api/v1/task/status`` until all tasks finish; return downloaded images."""
    url = url_host + "/api/v1/task/status"
    deadline = None if max_wait is None else time.monotonic() + max_wait
    pending = list(tasks)
    images = []
    while pending:
        still_running = []
        # Scan newest-first; finished images are appended in this order.
        for task in reversed(pending):
            r = requests.post(url, json=task, headers=headers, timeout=timeout)
            if r.status_code != 200:
                raise gr.Error(r.reason)
            res = r.json()
            code = res["code"]
            if code == 6002:
                # Still running: poll again next round.
                still_running.append(task)
            elif code == 6005:
                raise gr.Error("NSFW image detected.")
            elif code == 0:
                # Finished: download every result image.
                for img_info in res["data"]["images"]:
                    img_res = requests.get(img_info["url"], timeout=timeout)
                    if img_res.status_code != 200:
                        raise gr.Error(img_res.reason)
                    images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
            else:
                raise gr.Error(f"Error code: {res['code']}")
        if not still_running:
            break
        if deadline is not None and time.monotonic() >= deadline:
            raise gr.Error("Timed out waiting for image generation.")
        pending = still_running[::-1]  # restore creation order
        time.sleep(1)
    return images


def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None,
                *, max_wait=600, request_timeout=60):
    """Generate images through the FlagStudio task API.

    Creates ``image_num`` "gen-image" tasks with identical parameters. It then
    polls their status about once per second until every task has finished,
    downloading each task's result images as it completes.

    Args:
        seed: generation seed; converted with ``int()``.
        prompt: the full text prompt.
        width: requested output width in pixels.
        height: requested output height in pixels.
        image_num: number of tasks (one create request each) to submit.
        img: optional PIL init image for image-to-image; uploaded first.
        mask: optional PIL mask. It is only considered together with ``img``
            and only sent when it is not entirely black.
        max_wait: seconds to keep polling before giving up. ``None`` polls
            without limit.
        request_timeout: per-request ``requests`` timeout in seconds.

    Returns:
        A list of RGB PIL images. Within one polling round, later-created
        tasks' images come first.

    Raises:
        gr.Error: on any of the following:
            - an HTTP failure;
            - a rejected prompt (API code 3002);
            - an NSFW result (6005);
            - any other unexpected code;
            - exceeding ``max_wait``.
    """
    headers = {"token": token}
    payload = _build_task_payload(seed, prompt, width, height, img, mask)
    tasks = _create_tasks(payload, image_num, headers, request_timeout)
    return _wait_for_images(tasks, headers, max_wait, request_timeout)
def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
    """Turn the text-to-image UI selections into a prompt and generate images.

    Behaviour depends on the generation type:

    - "国画" (traditional Chinese painting): a fixed list of ink-wash keywords
      is appended and the selected styles are ignored.
    - Any other type: the type name (unless it is the generic "通用") and each
      selected style are appended, comma-separated. Their parenthesised
      English glosses are stripped first.
    """
    category = filter_content(class_draw)
    if category == "国画":
        raw_text = raw_text + ",国画,水墨画,大作,黑白,高清,传统"
    else:
        extras = [] if category == "通用" else [category]
        extras.extend(filter_content(sty) for sty in style_draw)
        raw_text = raw_text + "".join("," + extra for extra in extras)
    print(f"raw text is {raw_text}")
    return post_reqest(seed, raw_text, w, h, int(batch_size))
def img2img(prompt, image_and_mask, seed=0):
    """Run image-to-image generation on the uploaded (and optionally masked) image.

    The output size is rescaled so the shorter side of the input becomes
    512 px, preserving the aspect ratio (the other side is truncated to int).

    Args:
        prompt: text prompt.
        image_and_mask: value of the sketch-tool image component. It is
            accessed as a mapping with "image" and "mask" entries, both PIL
            images; it is ``None`` when nothing was uploaded.
        seed: generation seed; defaults to 0 as before.

    Returns:
        The list of PIL images from ``post_reqest``.

    Raises:
        gr.Error: if no image was uploaded, or if generation fails.
    """
    if not image_and_mask or image_and_mask.get("image") is None:
        # Previously this surfaced as an opaque TypeError/KeyError.
        raise gr.Error("Please upload an image first.")
    image = image_and_mask["image"]
    mask = image_and_mask.get("mask")
    if image.width <= image.height:
        width = 512
        height = int((width / image.width) * image.height)
    else:
        height = 512
        width = int((height / image.height) * image.width)
    return post_reqest(seed, prompt, width, height, 1, image, mask)
# Sample prompts listed under the text-to-image tab: Chinese prompts first,
# then English and mixed-language ones.
examples = [
    '水墨蝴蝶和牡丹花,国画',
    '苍劲有力的墨竹,国画',
    '暴风雨中的灯塔',
    '机械小松鼠,科学幻想',
    '中国水墨山水画,国画',
] + [
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
if __name__ == "__main__":
    # Build the two-tab UI (text-to-image, image-to-image) and launch it.
    block = gr.Blocks(css=read_content('style.css'))
    with block:
        gr.HTML(read_content("header.html"))
        with gr.Tabs(elem_id='tabss') as tabs:
            # ---------------- Text-to-image tab ----------------
            with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):
                with gr.Group():
                    with gr.Box():
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(
                                label="Prompt",
                                show_label=False,
                                max_lines=1,
                                placeholder="Input text(输入文字)",
                                interactive=True,
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )
                            btn = gr.Button("Generate image").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        class_draw = gr.Radio(choices=["通用(general)","国画(traditional Chinese painting)",], value="通用(general)", show_label=True, label='生成类型(type)')
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                       "概念艺术(concept art)", "Warming lighting",
                                                       "Dramatic lighting", "Natural lighting",
                                                       "虚幻引擎(unreal engine)", "4k", "8k",
                                                       "充满细节(full details)"],
                                                      label="画面风格(style)",
                                                      show_label=True,
                                                      )
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
                        seed = gr.Number(0, label='seed', interactive=True)
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        w = gr.Slider(512,1024,value=512, step=64, label="width")
                        h = gr.Slider(512,1024,value=512, step=64, label="height")
                    gallery = gr.Gallery(
                        label="Generated images", show_label=False, elem_id="gallery"
                    ).style(grid=[2,2])
                    # `fn` only runs when examples are cached. request_images
                    # needs 7 inputs but only `text` is wired here, so caching
                    # would crash at startup; keep it explicitly disabled.
                    gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery,
                                examples_per_page=100, cache_examples=False)
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图',show_label=True,value="图片1(img1)")
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
                            margin=False,
                            rounded=(True, True, True, True),
                        )
                    with gr.Row():
                        prompt = gr.Markdown("提示(Prompt):", visible=False)
                    with gr.Row():
                        move_prompt_zh = gr.Markdown("请移至图生图部分进行编辑(拉到顶部)", visible=False)
                    with gr.Row():
                        move_prompt_en = gr.Markdown("Please move to the img2img section for editing(Pull to the top)", visible=False)
                text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                sample_size.change(
                    fn=uifn.change_img_choices,
                    inputs=[sample_size],
                    outputs=[img_choices]
                )
            # ---------------- Image-to-image tab ----------------
            with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
                with gr.Row(elem_id="prompt_row"):
                    img2img_prompt = gr.Textbox(label="Prompt",
                                                elem_id='img2img_prompt_input',
                                                placeholder="神奇的森林,流淌的河流.",
                                                lines=1,
                                                max_lines=1,
                                                value="",
                                                show_label=False).style()
                    img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
                                                 elem_id="img2img_mask_btn")
                    img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
                gr.Markdown('#### 输入图像')
                with gr.Row().style(equal_height=False):
                    img2img_image_mask = gr.Image(
                        value=None,
                        source="upload",
                        interactive=True,
                        tool="sketch",
                        type='pil',
                        elem_id="img2img_mask",
                        image_mode="RGBA"
                    )
                gr.Markdown('#### 编辑后的图片')
                with gr.Row():
                    output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
                        grid=[4,4,4] )
                with gr.Row():
                    gr.Markdown('提示(prompt):')
                with gr.Row():
                    gr.Markdown('请选择一张图像掩盖掉一部分区域,并输入文本描述')
                with gr.Row():
                    gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
                gr.Markdown('# 编辑设置',visible=False)
                # Copy the chosen txt2img result into the img2img editor and switch tabs.
                output_txt2img_copy_to_input_btn.click(
                    uifn.copy_img_to_input,
                    [gallery, img_choices],
                    [tabs, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
                )
                img2img_func = img2img
                img2img_inputs = [img2img_prompt, img2img_image_mask]
                img2img_outputs = [output_img2img_gallery]
                img2img_btn_mask.click(img2img_func, img2img_inputs, img2img_outputs)
                img2img_btn_editor.click(img2img_func, img2img_inputs, img2img_outputs)
                # Generate on Enter: client-side JS clicks the first visible
                # button in the prompt row.
                img2img_prompt.submit(None, None, None,
                                      _js=call_JS("clickFirstVisibleButton",
                                                  rowId="prompt_row"))
        gr.HTML(read_content("footer.html"))
    block.queue(max_size=512, concurrency_count=256).launch()