tianleliphoebe committed on
Commit
42acef2
1 Parent(s): 97b8fe0

update NSFW

Browse files
Files changed (1) hide show
  1. serve/vote_utils.py +0 -1575
serve/vote_utils.py CHANGED
@@ -1,1575 +0,0 @@
1
- import datetime
2
- import time
3
- import json
4
- import uuid
5
- import gradio as gr
6
- import regex as re
7
- from pathlib import Path
8
- from .utils import *
9
- from .log_utils import build_logger
10
- from .constants import IMAGE_DIR, VIDEO_DIR
11
- import imageio
12
- from diffusers.utils import load_image
13
- import torch
14
-
15
# One logger per arena task; the "m" suffix = multi (side-by-side and battle
# modes), no suffix = single-model direct chat.
ig_logger = build_logger("gradio_web_server_image_generation", "gr_web_image_generation.log") # ig = image generation, loggers for single model direct chat
igm_logger = build_logger("gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log") # igm = image generation multi, loggers for side-by-side and battle
ie_logger = build_logger("gradio_web_server_image_editing", "gr_web_image_editing.log") # ie = image editing, loggers for single model direct chat
iem_logger = build_logger("gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log") # iem = image editing multi, loggers for side-by-side and battle
vg_logger = build_logger("gradio_web_server_video_generation", "gr_web_video_generation.log") # vg = video generation, loggers for single model direct chat
vgm_logger = build_logger("gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log") # vgm = video generation multi, loggers for side-by-side and battle
21
-
22
def save_any_image(image_file, file_path):
    """Write *image_file* to *file_path* as JPEG.

    *image_file* may be a path/URL string (loaded via diffusers'
    ``load_image``) or an already-loaded PIL image object.
    """
    image = load_image(image_file) if isinstance(image_file, str) else image_file
    image.save(file_path, 'JPEG')
28
-
29
def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request):
    """Append a single-model image-generation vote to the conversation log,
    mirror it to the log server, and persist the generated image.

    Args:
        state: ImageStateIG holding conv_id, prompt and the generated image.
        vote_type: "upvote" / "downvote" / "flag".
        model_selector: model name shown in the UI dropdown.
        request: gradio request, used only to log the caller IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())
    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    # Robustness: the generation dir may not exist yet on a fresh deployment.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # BUGFIX: JPEG bytes must go through a binary-mode handle; 'w' (text mode)
    # makes PIL's save raise TypeError.
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(output_file)
44
-
45
def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request):
    """Append a side-by-side image-generation vote to the conversation log,
    mirror it to the log server, and persist both generated images.

    Args:
        states: pair of ImageStateIG objects (model A, model B).
        vote_type: "leftvote" / "rightvote" / "tievote" / "bothbad_vote" / "share".
        model_selectors: the two model names shown in the UI.
        request: gradio request, used only to log the caller IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())
    for state in states:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)
61
-
62
def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request):
    """Append a single-model image-editing vote to the conversation log and
    persist both the edited output and the source image.

    Args:
        state: ImageStateIE with conv_id, prompts, source and output images.
        vote_type: "upvote" / "downvote" / "flag".
        model_selector: model name shown in the UI dropdown.
        request: gradio request, used only to log the caller IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
    source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    with open(source_file, 'wb') as sf:
        save_any_image(state.source_image, sf)
    save_image_file_on_log_server(output_file)
    save_image_file_on_log_server(source_file)
81
-
82
def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request):
    """Append a side-by-side image-editing vote to the conversation log and
    persist, for each side, the edited output plus its source image.

    Args:
        states: pair of ImageStateIE objects (model A, model B).
        vote_type: "leftvote" / "rightvote" / "tievote" / "bothbad_vote" / "share".
        model_selectors: the two model names shown in the UI.
        request: gradio request, used only to log the caller IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())
    for state in states:
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
        source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        with open(source_file, 'wb') as sf:
            save_any_image(state.source_image, sf)
        save_image_file_on_log_server(output_file)
        save_image_file_on_log_server(source_file)
102
-
103
-
104
def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request):
    """Append a single-model video-generation vote to the conversation log and
    persist the generated video as mp4.

    BUGFIX / consistency with vote_last_response_vgm: the old else-branch
    unconditionally treated state.output as a tensor, crashing with an
    AttributeError when a non-'fal' hosted backend returned a URL string.
    Now tensors and URL outputs are both handled, mirroring the vgm variant.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    if state.model_name.startswith('fal'):
        # fal-hosted models return a downloadable URL.
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    elif isinstance(state.output, torch.Tensor):
        # Local models return a tensor; expected layout is
        # [num_frames, height, width, channels] — permute channel-first output.
        if state.output.shape[-1] != 3:
            state.output = state.output.permute(0, 2, 3, 1)
        imageio.mimwrite(output_file, state.output, fps=8, quality=9)
    else:
        # Any other hosted backend: treat the output as a URL.
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    save_video_file_on_log_server(output_file)
130
-
131
-
132
-
133
def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request):
    """Append a side-by-side video-generation vote to the conversation log and
    persist both generated videos as mp4.

    Output handling per state: 'fal'-prefixed models return a URL that is
    downloaded; torch.Tensor outputs are written with imageio; anything else
    is assumed to be a downloadable URL as well.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())
    for state in states:
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            # fal-hosted models return a downloadable URL.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        elif isinstance(state.output, torch.Tensor):
            print("======== video shape: ========")
            print(state.output.shape)
            # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        else:
            # Fallback: other hosted backends also hand back a URL.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        save_video_file_on_log_server(output_file)
163
-
164
-
165
- ## Image Generation (IG) Single Model Direct Chat
166
def upvote_last_response_ig(state, model_selector, request: gr.Request):
    """Record an upvote for single-model image generation; clears the prompt
    box and disables the three vote buttons."""
    ig_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_ig(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3

def downvote_last_response_ig(state, model_selector, request: gr.Request):
    """Record a downvote for single-model image generation."""
    ig_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_ig(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response_ig(state, model_selector, request: gr.Request):
    """Flag the last single-model image-generation response."""
    ig_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_ig(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3
184
-
185
- ## Image Generation Multi (IGM) Side-by-Side and Battle
186
-
187
def _igm_model_name_markdowns(state0, state1, model_selector0):
    """Build the post-vote model-name reveal markdowns for IGM handlers.

    Battle (anonymous) mode leaves the selector blank, so the real model names
    (stored as "<prefix>_<name>") are revealed from the states; named mode
    simply echoes the stored model names.
    """
    if model_selector0 == "":
        return (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    return (gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True))


def leftvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: model A wins. Logs the vote, disables buttons, reveals names."""
    igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _igm_model_name_markdowns(state0, state1, model_selector0)


def rightvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: model B wins. (BUGFIX: removed stray debug print of the selector.)"""
    igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _igm_model_name_markdowns(state0, state1, model_selector0)


def tievote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: tie between the two models."""
    igm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _igm_model_name_markdowns(state0, state1, model_selector0)


def bothbad_vote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: both generations are bad."""
    igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _igm_model_name_markdowns(state0, state1, model_selector0)
247
-
248
- ## Image Editing (IE) Single Model Direct Chat
249
-
250
def upvote_last_response_ie(state, model_selector, request: gr.Request):
    """Record an upvote for single-model image editing; resets the prompt
    fields and source-image widget and disables the three vote buttons."""
    ie_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_ie(state, "upvote", model_selector, request)
    return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3

def downvote_last_response_ie(state, model_selector, request: gr.Request):
    """Record a downvote for single-model image editing."""
    ie_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_ie(state, "downvote", model_selector, request)
    return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3

def flag_last_response_ie(state, model_selector, request: gr.Request):
    """Flag the last single-model image-editing response."""
    ie_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_ie(state, "flag", model_selector, request)
    return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
267
-
268
- ## Image Editing Multi (IEM) Side-by-Side and Battle
269
def _iem_model_name_markdowns(state0, state1, model_selector0):
    """Build the post-vote model-name markdowns for IEM handlers.

    Anonymous battle (blank selector) reveals the real names; named mode keeps
    the markdowns hidden (visible=False), matching the original behavior.
    """
    if model_selector0 == "":
        return (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    return (gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False))


def leftvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: model A's edit wins. Logs the vote, resets inputs, disables buttons."""
    iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    names = _iem_model_name_markdowns(state0, state1, model_selector0)
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4

def rightvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: model B's edit wins."""
    iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    names = _iem_model_name_markdowns(state0, state1, model_selector0)
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4

def tievote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: tie between the two edits."""
    iem_logger.info(f"tievote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    names = _iem_model_name_markdowns(state0, state1, model_selector0)
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4

def bothbad_vote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: both edits are bad."""
    iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    names = _iem_model_name_markdowns(state0, state1, model_selector0)
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
332
-
333
-
334
- ## Video Generation (VG) Single Model Direct Chat
335
def upvote_last_response_vg(state, model_selector, request: gr.Request):
    """Record an upvote for single-model video generation; clears the prompt
    box and disables the three vote buttons."""
    vg_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_vg(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3

def downvote_last_response_vg(state, model_selector, request: gr.Request):
    """Record a downvote for single-model video generation."""
    vg_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_vg(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response_vg(state, model_selector, request: gr.Request):
    """Flag the last single-model video-generation response."""
    vg_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_vg(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3
353
-
354
- ## Video Generation Multi (VGM) Side-by-Side and Battle
355
-
356
def _vgm_model_name_markdowns(state0, state1, model_selector0):
    """Build the post-vote model-name markdowns for VGM handlers.

    Anonymous battle (blank selector) reveals the real names; named mode keeps
    the markdowns hidden (visible=False), matching the original behavior.
    """
    if model_selector0 == "":
        return (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    return (gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False))


def leftvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: model A's video wins. Logs the vote and disables the buttons."""
    vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _vgm_model_name_markdowns(state0, state1, model_selector0)


def rightvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: model B's video wins."""
    vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _vgm_model_name_markdowns(state0, state1, model_selector0)

def tievote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: tie between the two videos."""
    vgm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _vgm_model_name_markdowns(state0, state1, model_selector0)


def bothbad_vote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Vote: both videos are bad."""
    vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 4 + _vgm_model_name_markdowns(state0, state1, model_selector0)
419
-
420
# Client-side JS hooked to the share button: screenshots the
# '#share-region-named' element with html2canvas and triggers a PNG download,
# then passes the four gradio inputs through unchanged.
share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-named');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
440
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share event for the image-generation battle; skipped until both
    sides have produced a result."""
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_igm(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )

def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share event for the image-editing battle; skipped until both
    sides have produced a result."""
    iem_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_iem(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
453
-
454
- ## All Generation Gradio Interface
455
-
456
class ImageStateIG:
    """Per-conversation state for single-model image generation."""

    def __init__(self, model_name):
        # Random hex id; also used to name the saved .jpg artifact.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None  # generated image; deliberately excluded from dict()

    def dict(self):
        """Return the JSON-serializable summary that is written to the logs."""
        return {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "prompt": self.prompt,
        }
470
-
471
class ImageStateIE:
    """Per-conversation state for single-model image editing."""

    def __init__(self, model_name):
        # Random hex id; also used to name the saved output/source .jpg files.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.source_prompt = None
        self.target_prompt = None
        self.instruct_prompt = None
        self.source_image = None  # input image; excluded from dict()
        self.output = None        # edited image; excluded from dict()

    def dict(self):
        """Return the JSON-serializable summary that is written to the logs."""
        return {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "source_prompt": self.source_prompt,
            "target_prompt": self.target_prompt,
            "instruct_prompt": self.instruct_prompt,
        }
490
-
491
class VideoStateVG:
    """Per-conversation state for single-model video generation."""

    def __init__(self, model_name):
        # Random hex id; also used to name the saved .mp4 artifact.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None  # tensor or URL; deliberately excluded from dict()

    def dict(self):
        """Return the JSON-serializable summary that is written to the logs."""
        return {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "prompt": self.prompt,
        }
505
-
506
-
507
def generate_ig(gen_func, state, text, model_name, request: gr.Request):
    """Generator for single-model image generation.

    Yields (state, generated_image) for the UI, then logs the chat record and
    persists the image after the yield resumes.

    Raises:
        gr.Warning: on empty prompt/model name, or when the NSFW filter blocks
        the prompt (gen_func signals that with an empty string).
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIG(model_name)  # fresh state per generation
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(text, model_name)
    if generated_image == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name

    yield state, generated_image

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(output_file)
547
-
548
def generate_ig_museum(gen_func, state, model_name, request: gr.Request):
    """Generator for single-model image generation in "museum" mode, where
    gen_func supplies both a cached image and its prompt.

    Yields (state, generated_image, text), then logs and persists the image.
    """
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIG(model_name)  # fresh state per generation
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image, text = gen_func(model_name)
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name

    yield state, generated_image, text

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(output_file)
584
-
585
def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generator for side-by-side (named) image generation.

    Yields (state0, state1, image0, image1), then logs one chat record per
    model and persists both images.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the "### Model (A|B): " prefix the UI adds to selector labels.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(text, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, identical apart from model/state.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)
647
-
648
def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Generator for side-by-side (named) image generation in "museum" mode,
    where gen_func supplies cached images plus the shared prompt.

    Yields (state0, state1, image0, image1, text), then logs one chat record
    per model and persists both images.
    """
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the "### Model (A|B): " prefix the UI adds to selector labels.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1, text

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, identical apart from model/state.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)
706
-
707
-
708
def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generator for anonymous-battle image generation: the two models are
    chosen by gen_func, which returns their real names alongside the images.

    Yields (state0, state1, image0, image1, hidden Model A/B markdowns), then
    logs one chat record per model and persists both images.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Blank names ask gen_func to sample the battle pair itself.
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1

    # Markdowns stay hidden until a vote reveals the model names.
    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, identical apart from model/state.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)
766
-
767
def generate_igm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Generator for anonymous-battle image generation in "museum" mode:
    gen_func samples the model pair and returns cached images, the real model
    names, and the shared prompt.

    Yields (state0, state1, image0, image1, text, hidden Model A/B markdowns),
    then logs one chat record per model and persists both images.
    """
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Blank names ask gen_func to sample the battle pair itself.
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1

    # Markdowns stay hidden until a vote reveals the model names.
    yield state0, state1, generated_image0, generated_image1, text,\
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, identical apart from model/state.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # BUGFIX: binary mode for JPEG bytes (was text-mode 'w').
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)
823
-
824
def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request):
    """Run a single-model image edit, log it, and persist both images.

    Validates all inputs, edits ``source_image`` via ``gen_func``, yields the
    result, then appends a chat record to the conversation log and saves the
    source/output JPEGs under ``IMAGE_DIR``. The incoming ``state`` argument
    is ignored and recreated.

    Raises:
        gr.Warning: on any empty input, or when the NSFW filter blocks the
            prompt (signalled by ``gen_func`` returning '').
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    # NOTE: was ig_logger (image generation); ie_logger is the editing log.
    ie_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name)
    if generated_image == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name

    yield state, generated_image

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # JPEG data is binary: open in 'wb' ('w' text mode raises on bytes).
    with open(src_img_file, 'wb') as f:
        save_any_image(state.source_image, f)
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_ie_museum(gen_func, state, model_name, request: gr.Request):
    """Replay a pre-computed museum edit for ``model_name``.

    ``gen_func`` returns a cached (source, output, prompts) tuple; the run is
    logged and both JPEGs are saved under ``IMAGE_DIR``. The incoming
    ``state`` argument is ignored and recreated.

    Yields:
        (state, generated_image, source_image, source_text, target_text,
        instruct_text)
    """
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    # NOTE: was ig_logger (image generation); ie_logger is the editing log.
    ie_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    source_image, generated_image, source_text, target_text, instruct_text = gen_func(model_name)
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name

    yield state, generated_image, source_image, source_text, target_text, instruct_text

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # JPEG data is binary: open in 'wb' ('w' text mode raises on bytes).
    with open(src_img_file, 'wb') as f:
        save_any_image(state.source_image, f)
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Run a named side-by-side image edit for two models.

    Validates the inputs, edits ``source_image`` with both models via
    ``gen_func``, yields the pair, then logs both runs and saves each model's
    source/output JPEGs under ``IMAGE_DIR``. The incoming ``state0``/``state1``
    arguments are ignored and recreated.

    Raises:
        gr.Warning: on any empty input, or when the NSFW filter blocks both
            outputs (both returned as '').
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # NOTE: was igm_logger (image generation multi); iem_logger is the
    # editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the markdown heading the UI prepends to the model names.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for st in (state0, state1):
        st.source_prompt = source_text
        st.target_prompt = target_text
        st.instruct_prompt = instruct_text
        st.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.output = generated_image1
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary: open in 'wb' ('w' text mode raises on bytes).
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Replay a pre-computed side-by-side museum edit for two named models.

    ``gen_func`` returns the shared source image, both edited outputs and the
    prompts; both runs are logged and the JPEGs saved under ``IMAGE_DIR``.
    The incoming ``state0``/``state1`` arguments are ignored and recreated.

    Yields:
        (state0, state1, image0, image1, source_image, source_text,
        target_text, instruct_text)
    """
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # NOTE: was igm_logger; iem_logger is the editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the markdown heading the UI prepends to the model names.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text = gen_func(model_name0, model_name1)
    for st in (state0, state1):
        st.source_prompt = source_text
        st.target_prompt = target_text
        st.instruct_prompt = instruct_text
        st.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.output = generated_image1
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary: open in 'wb' ('w' text mode raises on bytes).
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Run an anonymous side-by-side image edit (battle mode).

    The incoming model names are blanked and ``gen_func`` picks the two models
    itself, returning the edited images and the chosen names. Both runs are
    logged and each model's source/output JPEGs saved under ``IMAGE_DIR``.
    The incoming ``state0``/``state1`` arguments are ignored and recreated.

    Yields:
        (state0, state1, image0, image1, markdown_a, markdown_b) — the
        markdown headers are created hidden so the battle stays anonymous.
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # NOTE: was igm_logger; iem_logger is the editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Anonymous battle: the sampler chooses the models and reports the names.
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for st in (state0, state1):
        st.source_prompt = source_text
        st.target_prompt = target_text
        st.instruct_prompt = instruct_text
        st.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.output = generated_image1
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary: open in 'wb' ('w' text mode raises on bytes).
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Sample an anonymous side-by-side museum edit pair.

    The incoming model names are blanked and ``gen_func`` samples a cached
    edit pair plus the model names. Both runs are logged and the JPEGs saved
    under ``IMAGE_DIR``. The incoming ``state0``/``state1`` arguments are
    ignored and recreated.

    Yields:
        (state0, state1, image0, image1, source_image, source_text,
        target_text, instruct_text, markdown_a, markdown_b) — the markdown
        headers are created hidden so the battle stays anonymous.
    """
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # NOTE: was igm_logger; iem_logger is the editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Anonymous battle: the sampler chooses the models and reports the names.
    model_name0 = ""
    model_name1 = ""
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text, model_name0, model_name1 = gen_func(model_name0, model_name1)
    for st in (state0, state1):
        st.source_prompt = source_text
        st.target_prompt = target_text
        st.instruct_prompt = instruct_text
        st.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.output = generated_image1
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary: open in 'wb' ('w' text mode raises on bytes).
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_vg(gen_func, state, text, model_name, request: gr.Request):
    """Generate a video for a single named model, log it, and save the clip.

    ``gen_func`` returns either a URL (models whose name starts with 'fal')
    or a frame tensor. The clip is written to ``VIDEO_DIR`` as mp4, uploaded
    to the log server, and the run appended to the conversation log. The
    incoming ``state`` argument is ignored and recreated.

    Yields:
        (state, local_mp4_path) — yielded only once the file is on disk.

    Raises:
        gr.Warning: on empty input, or when the NSFW filter blocks the
            prompt (signalled by ``gen_func`` returning '').
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = VideoStateVG(model_name)
    ip = get_ip(request)
    vg_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_video = gen_func(text, model_name)
    if generated_video == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.prompt = text
    state.output = generated_video
    state.model_name = model_name

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    if model_name.startswith('fal'):
        # fal-hosted models return a URL; download the rendered clip.
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    else:
        # Local models return a frame tensor; imageio expects [T, H, W, C],
        # so permute if the channel axis comes first. (Debug prints replaced
        # with logger calls.)
        vg_logger.info(f"video shape: {state.output.shape}")
        if state.output.shape[-1] != 3:
            state.output = state.output.permute(0, 2, 3, 1)
        imageio.mimwrite(output_file, state.output, fps=8, quality=9)

    save_video_file_on_log_server(output_file)
    yield state, output_file
def generate_vg_museum(gen_func, state, model_name, request: gr.Request):
    """Replay a pre-generated museum video for ``model_name``.

    ``gen_func`` returns a cached (clip URL, prompt) pair; the run is logged,
    the clip downloaded under ``VIDEO_DIR``, and the file pushed to the log
    server. The incoming ``state`` argument is ignored and recreated.

    Yields:
        (state, local_mp4_path, prompt)
    """
    state = VideoStateVG(model_name)
    vg_logger.info(f"generate. ip: {get_ip(request)}")
    started = time.time()

    clip_url, prompt = gen_func(model_name)
    state.prompt = prompt
    state.output = clip_url
    state.model_name = model_name

    finished = time.time()

    record = {
        "tstamp": round(finished, 4),
        "type": "chat",
        "model": model_name,
        "gen_params": {},
        "start": round(started, 4),
        "finish": round(finished, 4),
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(record) + "\n")
        append_json_item_on_log_server(record, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    response = requests.get(state.output)
    with open(output_file, 'wb') as sink:
        sink.write(response.content)

    save_video_file_on_log_server(output_file)
    yield state, output_file, prompt
def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate one video per model for a named side-by-side battle.

    Both clips are logged and written to ``VIDEO_DIR`` — downloaded when the
    model name starts with 'fal' (URL output), otherwise encoded from the
    returned frame tensor. The incoming ``state0``/``state1`` arguments are
    ignored and recreated.

    Yields:
        (state0, state1, mp4_path0, mp4_path1) — yielded once both files are
        on disk.

    Raises:
        gr.Warning: on empty inputs, or when the NSFW filter blocks both
            outputs (both returned as '').
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    # NOTE: was igm_logger (image log); vgm_logger is the video-multi log.
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the "### Model A/B: " markdown headings from the model names.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1 = gen_func(text, model_name0, model_name1)
    if generated_video0 == '' and generated_video1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    # Debug prints replaced with logger calls.
    vgm_logger.info(f"model names: {state0.model_name}, {state1.model_name}")

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            # fal-hosted models return a URL; download the rendered clip.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            # Local models return a frame tensor; imageio expects [T, H, W, C].
            vgm_logger.info(f"video shape: {state.output.shape}")
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'
def generate_vgm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Replay a pre-generated side-by-side museum video pair for two models.

    ``gen_func`` returns the two cached clip URLs and the shared prompt; both
    runs are logged, the clips downloaded under ``VIDEO_DIR`` and pushed to
    the log server. The incoming ``state0``/``state1`` arguments are ignored
    and recreated.

    Yields:
        (state0, state1, mp4_path0, mp4_path1, prompt)
    """
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    # NOTE: was igm_logger (image log); vgm_logger is the video-multi log.
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the "### Model A/B: " markdown headings from the model names.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    # Debug prints replaced with logger calls.
    vgm_logger.info(f"model names: {state0.model_name}, {state1.model_name}")

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        # Museum clips are served by URL; download each rendered clip.
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)

        save_video_file_on_log_server(output_file)
    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text
def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate an anonymous side-by-side video pair (battle mode).

    The incoming model names are blanked and ``gen_func`` picks the two
    models itself, returning the clips and the chosen names. Both runs are
    logged and the clips saved under ``VIDEO_DIR`` — downloaded for 'fal'
    models (URL output), otherwise encoded from the returned frame tensor.
    The incoming ``state0``/``state1`` arguments are ignored and recreated.

    Yields:
        (state0, state1, mp4_path0, mp4_path1, markdown_a, markdown_b) — the
        markdown headers are created hidden so the battle stays anonymous.

    Raises:
        gr.Warning: on empty prompt, or when the NSFW filter blocks both
            outputs (both returned as '').
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Anonymous battle: the sampler chooses the models and reports the names.
    model_name0 = ""
    model_name1 = ""
    generated_video0, generated_video1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    if generated_video0 == '' and generated_video1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            # fal-hosted models return a URL; download the rendered clip.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            # Local models return a frame tensor; imageio expects [T, H, W, C].
            # (Debug print replaced with a logger call.)
            vgm_logger.info(f"video shape: {state.output.shape}")
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)

    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
def generate_vgm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Sample an anonymous side-by-side museum video pair.

    The incoming model names are blanked and ``gen_func`` samples a cached
    clip pair, the model names, and the shared prompt. Both runs are logged,
    the clips downloaded under ``VIDEO_DIR``, and pushed to the log server.
    The incoming ``state0``/``state1`` arguments are ignored and recreated.

    Yields:
        (state0, state1, mp4_path0, mp4_path1, prompt, markdown_a,
        markdown_b) — the markdown headers are created hidden so the battle
        stays anonymous.
    """
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    vgm_logger.info(f"generate. ip: {get_ip(request)}")
    started = time.time()

    # The museum sampler picks both models; blank the incoming names first.
    model_name0, model_name1 = "", ""
    clip0, clip1, model_name0, model_name1, prompt = gen_func(model_name0, model_name1)

    state0.prompt = prompt
    state0.output = clip0
    state0.model_name = model_name0
    state1.prompt = prompt
    state1.output = clip1
    state1.model_name = model_name1

    finished = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for name, st in ((model_name0, state0), (model_name1, state1)):
            record = {
                "tstamp": round(finished, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(started, 4),
                "finish": round(finished, 4),
                "state": st.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(record) + "\n")
            append_json_item_on_log_server(record, get_conv_log_filename())

    for st in (state0, state1):
        clip_path = f'{VIDEO_DIR}/generation/{st.conv_id}.mp4'
        os.makedirs(os.path.dirname(clip_path), exist_ok=True)

        response = requests.get(st.output)
        with open(clip_path, 'wb') as sink:
            sink.write(response.content)

        save_video_file_on_log_server(clip_path)

    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', prompt, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)