ZacLiu commited on
Commit
872be97
·
1 Parent(s): 5943e4a

Add application file

Browse files
Files changed (1) hide show
  1. app.py +318 -0
app.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import re
3
+ import imp
4
+ import time
5
+ import json
6
+ import base64
7
+ import requests
8
+ import gradio as gr
9
+ import ui_functions as uifn
10
+ from css_and_js import js, call_JS
11
+ from PIL import Image, PngImagePlugin, ImageChops
12
+
13
+ url_host = "http://flagstudio.baai.ac.cn"
14
+ token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiZjAxOGMxMzJiYTUyNDBjMzk5NTMzYTI5YjBmMzZiODMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6IjVjMmQzMjdiLWI5Y2MtNDhiZS1hZWQ4LTllMjQ4MDk4NzMxYyIsIm5iZiI6MTY2OTAwNjE5NywiZXhwIjoxOTg0MzY2MTk3LCJpYXQiOjE2NjkwMDYxOTd9.9B3MDk8wA6iWH5puXjcD19tJJ4Ox7mdpRyWZs5Kwt70"
15
+
16
+ def read_content(file_path: str) -> str:
17
+ """read the content of target file
18
+ """
19
+ with open(file_path, 'r', encoding='utf-8') as f:
20
+ content = f.read()
21
+
22
+ return content
23
+
24
+ def filter_content(raw_style: str):
25
+ if "(" in raw_style:
26
+ i = raw_style.index("(")
27
+ else :
28
+ i = -1
29
+
30
+ if i == -1:
31
+ return raw_style
32
+ else :
33
+ return raw_style[:i]
34
+
35
+ def upload_image(img):
36
+ url = url_host + "/api/v1/image/get-upload-link"
37
+ headers = {"token": token}
38
+ r = requests.post(url, json={}, headers=headers)
39
+ if r.status_code != 200:
40
+ raise gr.Error(r.reason)
41
+ head_res = r.json()
42
+ if head_res["code"] != 0:
43
+ raise gr.Error("Unknown error")
44
+ image_id = head_res["data"]["image_id"]
45
+ image_url = head_res["data"]["url"]
46
+ image_headers = head_res["data"]["headers"]
47
+
48
+ imgBytes = io.BytesIO()
49
+ img.save(imgBytes, "PNG")
50
+ imgBytes = imgBytes.getvalue()
51
+
52
+ r = requests.put(image_url, data=imgBytes, headers=image_headers)
53
+ if r.status_code != 200:
54
+ raise gr.Error(r.reason)
55
+ return image_id, image_url
56
+
57
+ def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
58
+ data = {
59
+ "type": "gen-image",
60
+ "gen_image_num": image_num,
61
+ "parameters": {
62
+ "width": width, # output height width
63
+ "height": height, # output image height
64
+ "prompts": [prompt],
65
+ }
66
+ }
67
+ data["parameters"]["seed"] = int(seed)
68
+ if img is not None:
69
+ # Upload image
70
+ image_id, image_url = upload_image(img)
71
+ data["parameters"]["init_image"] = {
72
+ "image_id": image_id,
73
+ "url": image_url,
74
+ "width": img.width,
75
+ "height": img.height,
76
+ }
77
+ if mask is not None:
78
+ # Upload mask
79
+ extrama = mask.convert("L").getextrema()
80
+ if extrama[1] > 0:
81
+ mask_id, mask_url = upload_image(mask)
82
+ data["parameters"]["mask_image"] = {
83
+ "image_id": mask_id,
84
+ "url": mask_url,
85
+ "width": mask.width,
86
+ "height": mask.height,
87
+ }
88
+ headers = {"token": token}
89
+ # Send create task request
90
+ # url = "http://flagstudio.baai.ac.cn/api/v1/task/create"
91
+ url = url_host+"/api/v1/task/create"
92
+ r = requests.post(url, json=data, headers=headers)
93
+ if r.status_code != 200:
94
+ raise gr.Error(r.reason)
95
+ create_res = r.json()
96
+ task_id = create_res["data"]["task_id"]
97
+
98
+ # Get result
99
+ url = url_host+"/api/v1/task/status"
100
+ while True:
101
+ r = requests.post(url, json=create_res["data"], headers=headers)
102
+ if r.status_code != 200:
103
+ raise gr.Error(r.reason)
104
+ res = r.json()
105
+ if res["code"] == 6002:
106
+ # Running
107
+ time.sleep(1)
108
+ continue
109
+ elif res["code"] == 0:
110
+ # Finished
111
+ images = []
112
+ for img_info in res["data"]["images"]:
113
+ img_res = requests.get(img_info["url"])
114
+ images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
115
+ return images
116
+ else:
117
+ raise gr.Error(f"Error code: {res['code']}")
118
+
119
+ def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
120
+ if filter_content(class_draw) != "国画":
121
+ if filter_content(class_draw) != "通用":
122
+ raw_text = raw_text + f",{filter_content(class_draw)}"
123
+
124
+ for sty in style_draw:
125
+ raw_text = raw_text + f",{filter_content(sty)}"
126
+ elif filter_content(class_draw) == "国画":
127
+ raw_text = raw_text + ",国画,水墨画,大作,黑白,高清,传统"
128
+ print(f"raw text is {raw_text}")
129
+
130
+ images = post_reqest(seed, raw_text, w, h, int(batch_size))
131
+
132
+ return images
133
+
134
+
135
+ def img2img(prompt, image_and_mask):
136
+ return post_reqest(0, prompt, 512, 512, 1, image_and_mask["image"], image_and_mask["mask"])
137
+
138
+
139
+ examples = [
140
+ '水墨蝴蝶和牡丹花,国画',
141
+ '苍劲有力的墨竹,国画',
142
+ '暴风雨中的灯塔',
143
+ '机械小松鼠,科学幻想',
144
+ '中国水墨山水画,国画',
145
+ "Lighthouse in the storm",
146
+ "A dog",
147
+ "Landscape by 张大千",
148
+ "A tiger 长��兔子耳朵",
149
+ "A baby bird 铅笔素描",
150
+ ]
151
+
152
+ if __name__ == "__main__":
153
+ block = gr.Blocks(css=read_content('style.css'))
154
+
155
+ with block:
156
+ gr.HTML(read_content("header.html"))
157
+ with gr.Tabs(elem_id='tabss') as tabs:
158
+
159
+ with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):
160
+
161
+ with gr.Group():
162
+ with gr.Box():
163
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
164
+ text = gr.Textbox(
165
+ label="Prompt",
166
+ show_label=False,
167
+ max_lines=1,
168
+ placeholder="Input text(输入文字)",
169
+ interactive=True,
170
+ ).style(
171
+ border=(True, False, True, True),
172
+ rounded=(True, False, False, True),
173
+ container=False,
174
+ )
175
+
176
+ btn = gr.Button("Generate image").style(
177
+ margin=False,
178
+ rounded=(True, True, True, True),
179
+ )
180
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
181
+ class_draw = gr.Radio(choices=["通用(general)","国画(traditional Chinese painting)",], value="通用(general)", show_label=True, label='生成类型(type)')
182
+ # class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
183
+ # "照片,摄影(picture photography)", "油画(oil painting)",
184
+ # "铅笔素描(pencil sketch)", "CG",
185
+ # "水彩画(watercolor painting)", "水墨画(ink and wash)",
186
+ # "插画(illustrations)", "3D", "图生图(img2img)"],
187
+ # label="生成类型(type)",
188
+ # show_label=True,
189
+ # value="通用(general)")
190
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
191
+ style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
192
+ "概念艺术(concept art)", "Warming lighting",
193
+ "Dramatic lighting", "Natural lighting",
194
+ "虚幻引擎(unreal engine)", "4k", "8k",
195
+ "充满细节(full details)"],
196
+ label="画面风格(style)",
197
+ show_label=True,
198
+ )
199
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
200
+ # sample_size = gr.Slider(minimum=1,
201
+ # maximum=4,
202
+ # step=1,
203
+ # label="生成数量(number)",
204
+ # show_label=True,
205
+ # interactive=True,
206
+ # )
207
+ sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
208
+ seed = gr.Number(0, label='seed', interactive=True)
209
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
210
+ w = gr.Slider(512,1024,value=512, step=64, label="width")
211
+ h = gr.Slider(512,1024,value=512, step=64, label="height")
212
+
213
+ gallery = gr.Gallery(
214
+ label="Generated images", show_label=False, elem_id="gallery"
215
+ ).style(grid=[2,2])
216
+ gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
217
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
218
+ img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图',show_label=True,value="图片1(img1)")
219
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
220
+ output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
221
+ margin=False,
222
+ rounded=(True, True, True, True),
223
+ )
224
+
225
+ with gr.Row():
226
+ prompt = gr.Markdown("提示(Prompt):", visible=False)
227
+ with gr.Row():
228
+ move_prompt_zh = gr.Markdown("请移至图生图部分进行编辑(拉到顶部)", visible=False)
229
+ with gr.Row():
230
+ move_prompt_en = gr.Markdown("Please move to the img2img section for editing(Pull to the top)", visible=False)
231
+
232
+
233
+
234
+ text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
235
+ btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
236
+
237
+ sample_size.change(
238
+ fn=uifn.change_img_choices,
239
+ inputs=[sample_size],
240
+ outputs=[img_choices]
241
+ )
242
+
243
+ with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
244
+ with gr.Row(elem_id="prompt_row"):
245
+ img2img_prompt = gr.Textbox(label="Prompt",
246
+ elem_id='img2img_prompt_input',
247
+ placeholder="神奇的森林,流淌的河流.",
248
+ lines=1,
249
+ max_lines=1,
250
+ value="",
251
+ show_label=False).style()
252
+
253
+ img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
254
+ elem_id="img2img_mask_btn")
255
+ img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
256
+ gr.Markdown('#### 输入图像')
257
+ with gr.Row().style(equal_height=False):
258
+ #with gr.Column():
259
+ img2img_image_mask = gr.Image(
260
+ value=None,
261
+ source="upload",
262
+ interactive=True,
263
+ tool="sketch",
264
+ type='pil',
265
+ elem_id="img2img_mask",
266
+ image_mode="RGBA"
267
+ )
268
+ gr.Markdown('#### 编辑后的图片')
269
+ with gr.Row():
270
+ output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
271
+ grid=[4,4,4] )
272
+ with gr.Row():
273
+ gr.Markdown('提示(prompt):')
274
+ with gr.Row():
275
+ gr.Markdown('请选择一张图像掩盖掉一部分区域,并输入文本描述')
276
+ with gr.Row():
277
+ gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
278
+ gr.Markdown('# 编辑设置',visible=False)
279
+
280
+
281
+ output_txt2img_copy_to_input_btn.click(
282
+ uifn.copy_img_to_input,
283
+ [gallery, img_choices],
284
+ [tabs, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
285
+ )
286
+
287
+
288
+ img2img_func = img2img
289
+ img2img_inputs = [img2img_prompt, img2img_image_mask]
290
+ img2img_outputs = [output_img2img_gallery]
291
+
292
+ img2img_btn_mask.click(
293
+ img2img_func,
294
+ img2img_inputs,
295
+ img2img_outputs
296
+ )
297
+
298
+ def img2img_submit_params():
299
+ return (img2img_func,
300
+ img2img_inputs,
301
+ img2img_outputs)
302
+
303
+ img2img_btn_editor.click(*img2img_submit_params())
304
+
305
+ # GENERATE ON ENTER
306
+ img2img_prompt.submit(None, None, None,
307
+ _js=call_JS("clickFirstVisibleButton",
308
+ rowId="prompt_row"))
309
+
310
+ gr.HTML(read_content("footer.html"))
311
+ # gr.Image('./contributors.png')
312
+
313
+ block.queue(max_size=50, concurrency_count=20).launch(
314
+ # share=True,
315
+ show_error=True,
316
+ server_name="0.0.0.0",
317
+ server_port=43523,
318
+ )