GenAI-Arena / serve /vote_utils.py
tianleliphoebe's picture
add NSFW guard
45c2aa5
raw
history blame
64.6 kB
import datetime
import time
import json
import uuid
import gradio as gr
import regex as re
from pathlib import Path
from .utils import *
from .log_utils import build_logger
from .constants import IMAGE_DIR, VIDEO_DIR
import imageio
from diffusers.utils import load_image
import torch
# Per-task loggers. Prefix legend:
#   ig  = image generation, single-model direct chat
#   igm = image generation multi (side-by-side and battle)
#   ie  = image editing, single-model direct chat
#   iem = image editing multi (side-by-side and battle)
#   vg  = video generation, single-model direct chat
#   vgm = video generation multi (side-by-side and battle)
ig_logger = build_logger(
    "gradio_web_server_image_generation", "gr_web_image_generation.log"
)
igm_logger = build_logger(
    "gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log"
)
ie_logger = build_logger(
    "gradio_web_server_image_editing", "gr_web_image_editing.log"
)
iem_logger = build_logger(
    "gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log"
)
vg_logger = build_logger(
    "gradio_web_server_video_generation", "gr_web_video_generation.log"
)
vgm_logger = build_logger(
    "gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log"
)
def save_any_image(image_file, file_path):
    """Save an image in JPEG format.

    Args:
        image_file: Either a path/URL string, which is loaded with diffusers'
            ``load_image``, or an already-loaded image object that exposes
            ``save`` (e.g. a ``PIL.Image.Image``).
        file_path: Destination path, or a file object open for binary writing.
            It is passed straight through to ``Image.save``.
    """
    image = load_image(image_file) if isinstance(image_file, str) else image_file
    # JPEG cannot store alpha or palette data. Pillow raises
    # "cannot write mode RGBA as JPEG" for such images, so convert them to RGB.
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    image.save(file_path, 'JPEG')
def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image-generation result.

    Appends one JSON line (timestamp, vote type, model, serialized state and
    client IP) to the conversation log and mirrors it to the log server. It
    then saves ``state.output`` as ``{IMAGE_DIR}/generation/<conv_id>.jpg``
    and uploads that file.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Pass the path rather than a text-mode ('w') handle: JPEG data is binary.
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(output_file)
def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model (side-by-side / battle) image-generation result.

    Appends one JSON line covering every model and state to the conversation
    log and mirrors it to the log server. Each state's image is then saved as
    ``{IMAGE_DIR}/generation/<conv_id>.jpg`` and uploaded. A state whose
    output is missing (``None`` or ``''``, e.g. the side that was blocked by
    the NSFW filter) is skipped instead of crashing the handler.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "models": list(model_selectors),
        "states": [x.dict() for x in states],
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output = state.output
        if output is None or (isinstance(output, str) and not output):
            continue
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Pass the path rather than a text-mode ('w') handle: JPEG data is binary.
        save_any_image(output, output_file)
        save_image_file_on_log_server(output_file)
def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image-editing result.

    Appends one JSON line to the conversation log and mirrors it to the log
    server. It then saves the edited image as ``{IMAGE_DIR}/edition/<conv_id>.jpg``
    and the source image as ``{IMAGE_DIR}/edition/<conv_id>_source.jpg``, and
    uploads both files.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
    source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Pass paths rather than text-mode ('w') handles: JPEG data is binary.
    save_any_image(state.output, output_file)
    save_any_image(state.source_image, source_file)
    save_image_file_on_log_server(output_file)
    save_image_file_on_log_server(source_file)
def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model (side-by-side / battle) image-editing result.

    Appends one JSON line covering every model and state to the conversation
    log and mirrors it to the log server. For each state, the edited image
    (``<conv_id>.jpg``) and the source image (``<conv_id>_source.jpg``) are
    saved under ``{IMAGE_DIR}/edition/`` and uploaded. Missing images
    (``None`` or ``''``) are skipped.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "models": list(model_selectors),
        "states": [x.dict() for x in states],
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
        source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        for image, path in ((state.output, output_file), (state.source_image, source_file)):
            if image is None or (isinstance(image, str) and not image):
                continue
            # Pass the path rather than a text-mode ('w') handle: JPEG data is binary.
            save_any_image(image, path)
            save_image_file_on_log_server(path)
def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model video-generation result.

    Appends one JSON line to the conversation log and mirrors it to the log
    server. It then stores the video as ``{VIDEO_DIR}/generation/<conv_id>.mp4``
    and uploads it. Outputs from ``fal*`` models, and any string output, are
    treated as URLs and downloaded. Anything else is written as frames with
    ``imageio`` (same rule as ``vote_last_response_vgm``).
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    if state.model_name.startswith('fal') or isinstance(state.output, str):
        # Remote result: download the already-encoded video bytes.
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    else:
        vg_logger.info(f"video shape: {tuple(state.output.shape)}")
        # imageio wants frames as [num_frames, height, width, num_channels].
        # When the last dim is not 3, the output is assumed to be channel-first
        # [num_frames, num_channels, height, width] and is permuted.
        if state.output.shape[-1] != 3:
            state.output = state.output.permute(0, 2, 3, 1)
        imageio.mimwrite(output_file, state.output, fps=8, quality=9)
    save_video_file_on_log_server(output_file)
def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model (side-by-side / battle) video-generation result.

    Appends one JSON line covering every model and state to the conversation
    log and mirrors it to the log server. Each state's video is then stored as
    ``{VIDEO_DIR}/generation/<conv_id>.mp4`` and uploaded. Outputs from
    ``fal*`` models, and any string output, are downloaded as URLs. Anything
    else is written as frames with ``imageio``.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "models": list(model_selectors),
        "states": [x.dict() for x in states],
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal') or isinstance(state.output, str):
            # Remote result: download the already-encoded video bytes.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            vgm_logger.info(f"video shape: {tuple(state.output.shape)}")
            # imageio wants frames as [num_frames, height, width, num_channels].
            # When the last dim is not 3, the output is assumed to be channel-first
            # and is permuted.
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
## Image Generation (IG) Single Model Direct Chat
def upvote_last_response_ig(state, model_selector, request: gr.Request):
    """Handle an upvote in single-model image generation.

    Logs the click, records the vote, and returns ``""`` followed by three
    ``disable_btn`` values for the UI outputs.
    """
    client_ip = get_ip(request)
    ig_logger.info(f"upvote. ip: {client_ip}")
    vote_last_response_ig(state, "upvote", model_selector, request)
    return ("", *([disable_btn] * 3))
def downvote_last_response_ig(state, model_selector, request: gr.Request):
    """Handle a downvote in single-model image generation.

    Logs the click, records the vote, and returns ``""`` followed by three
    ``disable_btn`` values for the UI outputs.
    """
    client_ip = get_ip(request)
    ig_logger.info(f"downvote. ip: {client_ip}")
    vote_last_response_ig(state, "downvote", model_selector, request)
    return ("", *([disable_btn] * 3))
def flag_last_response_ig(state, model_selector, request: gr.Request):
    """Handle a flag in single-model image generation.

    Logs the click, records it as a "flag" vote, and returns ``""`` followed
    by three ``disable_btn`` values for the UI outputs.
    """
    client_ip = get_ip(request)
    ig_logger.info(f"flag. ip: {client_ip}")
    vote_last_response_ig(state, "flag", model_selector, request)
    return ("", *([disable_btn] * 3))
## Image Generation Multi (IGM) Side-by-Side and Battle
def leftvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "leftvote" (model A preferred) for image-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two visible Markdown
    components naming the models. When ``model_selector0`` is empty
    (presumably the anonymous battle tab), each name is the second
    ``'_'``-separated part of the model name, prefixed with
    ``"### Model A/B: "``. Otherwise the full model names are shown.
    """
    igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        name_a = gr.Markdown(state0.model_name, visible=True)
        name_b = gr.Markdown(state1.model_name, visible=True)
    else:
        name_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        name_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    return ("", disable_btn, disable_btn, disable_btn, disable_btn, name_a, name_b)
def rightvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "rightvote" (model B preferred) for image-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two visible Markdown
    components naming the models. When ``model_selector0`` is empty
    (presumably the anonymous battle tab), each name is the second
    ``'_'``-separated part of the model name, prefixed with
    ``"### Model A/B: "``. Otherwise the full model names are shown.
    """
    igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    else:
        names = (
            gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True),
        )
    return ("",) + (disable_btn,) * 4 + names
def tievote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "tievote" for image-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two visible Markdown
    components naming the models. The names are shortened and prefixed
    (``"### Model A/B: "``) when ``model_selector0`` is empty; otherwise the
    full model names are shown.
    """
    igm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        names = tuple(
            gr.Markdown(f"### Model {side}: {st.model_name.split('_')[1]}", visible=True)
            for side, st in (("A", state0), ("B", state1))
        )
    else:
        names = tuple(gr.Markdown(st.model_name, visible=True) for st in (state0, state1))
    return ("",) + (disable_btn,) * 4 + names
def bothbad_vote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "bothbad_vote" for image-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two visible Markdown
    components naming the models. The names are shortened and prefixed
    (``"### Model A/B: "``) when ``model_selector0`` is empty; otherwise the
    full model names are shown.
    """
    igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    anonymous = model_selector0 == ""
    if anonymous:
        left = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        right = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    else:
        left = gr.Markdown(state0.model_name, visible=True)
        right = gr.Markdown(state1.model_name, visible=True)
    buttons = tuple([disable_btn] * 4)
    return ("",) + buttons + (left, right)
## Image Editing (IE) Single Model Direct Chat
def upvote_last_response_ie(state, model_selector, request: gr.Request):
    """Handle an upvote in single-model image editing.

    Logs the click and records the vote. Returns three empty strings around a
    fresh 512x512 PIL-typed ``gr.Image`` (order: "", "", image, ""),
    followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ie_logger.info(f"upvote. ip: {client_ip}")
    vote_last_response_ie(state, "upvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
def downvote_last_response_ie(state, model_selector, request: gr.Request):
    """Handle a downvote in single-model image editing.

    Logs the click and records the vote. Returns three empty strings around a
    fresh 512x512 PIL-typed ``gr.Image`` (order: "", "", image, ""),
    followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ie_logger.info(f"downvote. ip: {client_ip}")
    vote_last_response_ie(state, "downvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
def flag_last_response_ie(state, model_selector, request: gr.Request):
    """Handle a flag in single-model image editing.

    Logs the click and records it as a "flag" vote. Returns three empty
    strings around a fresh 512x512 PIL-typed ``gr.Image`` (order: "", "",
    image, ""), followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ie_logger.info(f"flag. ip: {client_ip}")
    vote_last_response_ie(state, "flag", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
## Image Editing Multi (IEM) Side-by-Side and Battle
def leftvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "leftvote" (model A preferred) for image-editing side-by-side/battle.

    Returns two Markdown name components, then ``"", "", <fresh gr.Image>, ""``,
    then four ``disable_btn`` values. The names are visible and shortened
    (``"### Model A/B: <second '_' part>"``) when ``model_selector0`` is empty.
    Otherwise they hold the full model names and are hidden.
    """
    iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (gr.Markdown(state0.model_name, visible=False),
                 gr.Markdown(state1.model_name, visible=False))
    else:
        names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                 gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    cleared_inputs = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + cleared_inputs + (disable_btn,) * 4
def rightvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "rightvote" (model B preferred) for image-editing side-by-side/battle.

    Returns two Markdown name components, then ``"", "", <fresh gr.Image>, ""``,
    then four ``disable_btn`` values. The names are visible and shortened when
    ``model_selector0`` is empty; otherwise they hold the full names and are hidden.
    """
    iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (gr.Markdown(state0.model_name, visible=False),
                 gr.Markdown(state1.model_name, visible=False))
    else:
        names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                 gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    cleared_inputs = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + cleared_inputs + (disable_btn,) * 4
def tievote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "tievote" for image-editing side-by-side/battle.

    Returns two Markdown name components, then ``"", "", <fresh gr.Image>, ""``,
    then four ``disable_btn`` values. The names are visible and shortened when
    ``model_selector0`` is empty; otherwise they hold the full names and are hidden.
    """
    iem_logger.info(f"tievote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    pairs = (("A", state0), ("B", state1))
    if model_selector0 == "":
        names = tuple(
            gr.Markdown(f"### Model {side}: {st.model_name.split('_')[1]}", visible=True)
            for side, st in pairs
        )
    else:
        names = tuple(gr.Markdown(st.model_name, visible=False) for _, st in pairs)
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
def bothbad_vote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "bothbad_vote" for image-editing side-by-side/battle.

    Returns two Markdown name components, then ``"", "", <fresh gr.Image>, ""``,
    then four ``disable_btn`` values. The names are visible and shortened when
    ``model_selector0`` is empty; otherwise they hold the full names and are hidden.
    """
    iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    pairs = (("A", state0), ("B", state1))
    if model_selector0 == "":
        names = tuple(
            gr.Markdown(f"### Model {side}: {st.model_name.split('_')[1]}", visible=True)
            for side, st in pairs
        )
    else:
        names = tuple(gr.Markdown(st.model_name, visible=False) for _, st in pairs)
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
## Video Generation (VG) Single Model Direct Chat
def upvote_last_response_vg(state, model_selector, request: gr.Request):
    """Handle an upvote in single-model video generation.

    Logs the click, records the vote, and returns ``""`` followed by three
    ``disable_btn`` values for the UI outputs.
    """
    client_ip = get_ip(request)
    vg_logger.info(f"upvote. ip: {client_ip}")
    vote_last_response_vg(state, "upvote", model_selector, request)
    return ("", *([disable_btn] * 3))
def downvote_last_response_vg(state, model_selector, request: gr.Request):
    """Handle a downvote in single-model video generation.

    Logs the click, records the vote, and returns ``""`` followed by three
    ``disable_btn`` values for the UI outputs.
    """
    client_ip = get_ip(request)
    vg_logger.info(f"downvote. ip: {client_ip}")
    vote_last_response_vg(state, "downvote", model_selector, request)
    return ("", *([disable_btn] * 3))
def flag_last_response_vg(state, model_selector, request: gr.Request):
    """Handle a flag in single-model video generation.

    Logs the click, records it as a "flag" vote, and returns ``""`` followed
    by three ``disable_btn`` values for the UI outputs.
    """
    client_ip = get_ip(request)
    vg_logger.info(f"flag. ip: {client_ip}")
    vote_last_response_vg(state, "flag", model_selector, request)
    return ("", *([disable_btn] * 3))
## Video Generation Multi (VGM) Side-by-Side and Battle
def leftvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "leftvote" (model A preferred) for video-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two Markdown name
    components. The names are visible and shortened (``"### Model A/B: ..."``)
    when ``model_selector0`` is empty; otherwise they hold the full names and
    are hidden.
    """
    vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        name_a = gr.Markdown(state0.model_name, visible=False)
        name_b = gr.Markdown(state1.model_name, visible=False)
    else:
        name_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        name_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    return ("", disable_btn, disable_btn, disable_btn, disable_btn, name_a, name_b)
def rightvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "rightvote" (model B preferred) for video-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two Markdown name
    components. The names are visible and shortened when ``model_selector0``
    is empty; otherwise they hold the full names and are hidden.
    """
    vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        name_a = gr.Markdown(state0.model_name, visible=False)
        name_b = gr.Markdown(state1.model_name, visible=False)
    else:
        name_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        name_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    return ("", disable_btn, disable_btn, disable_btn, disable_btn, name_a, name_b)
def tievote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "tievote" for video-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two Markdown name
    components. The names are visible and shortened when ``model_selector0``
    is empty; otherwise they hold the full names and are hidden.
    """
    vgm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    pairs = (("A", state0), ("B", state1))
    if model_selector0 == "":
        names = tuple(
            gr.Markdown(f"### Model {side}: {st.model_name.split('_')[1]}", visible=True)
            for side, st in pairs
        )
    else:
        names = tuple(gr.Markdown(st.model_name, visible=False) for _, st in pairs)
    return ("",) + (disable_btn,) * 4 + names
def bothbad_vote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "bothbad_vote" for video-generation side-by-side/battle.

    Returns ``""``, four ``disable_btn`` values, then two Markdown name
    components. The names are visible and shortened when ``model_selector0``
    is empty; otherwise they hold the full names and are hidden.
    """
    vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    pairs = (("A", state0), ("B", state1))
    if model_selector0 == "":
        names = tuple(
            gr.Markdown(f"### Model {side}: {st.model_name.split('_')[1]}", visible=True)
            for side, st in pairs
        )
    else:
        names = tuple(gr.Markdown(st.model_name, visible=False) for _, st in pairs)
    return ("",) + (disable_btn,) * 4 + names
# Client-side JavaScript callback, presumably wired to the share button's
# `js=` hook (TODO confirm at the call site). It screenshots the DOM element
# with id "share-region-named" using html2canvas (assumed to be loaded
# globally on the page) and triggers a download of it as "chatbot-arena.png".
# It returns its four arguments unchanged, presumably so that they pass
# through to the Python handler (share_click_igm / share_click_iem also take
# four inputs).
share_js = """
function (a, b, c, d) {
const captureElement = document.querySelector('#share-region-named');
html2canvas(captureElement)
.then(canvas => {
canvas.style.display = 'none'
document.body.appendChild(canvas)
return canvas
})
.then(canvas => {
const image = canvas.toDataURL('image/png')
const a = document.createElement('a')
a.setAttribute('download', 'chatbot-arena.png')
a.setAttribute('href', image)
a.click()
canvas.remove()
});
return [a, b, c, d];
}
"""
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click in image generation.

    When both side states exist, the share is also recorded as a "share" vote.
    """
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_igm(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click in image editing.

    When both side states exist, the share is also recorded as a "share" vote.
    """
    iem_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_iem(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
## All Generation Gradio Interface
class ImageStateIG:
    """Per-conversation state for image generation.

    Holds a random hex ``conv_id`` (uuid4), the model name, the prompt and
    the generated output. ``dict()`` returns the JSON-loggable subset, which
    excludes the output.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None  # set once a prompt has been submitted
        self.output = None  # generated image; not included in dict()

    def dict(self):
        """Return ``conv_id``, ``model_name`` and ``prompt``, in that order."""
        fields = ("conv_id", "model_name", "prompt")
        return {name: getattr(self, name) for name in fields}
class ImageStateIE:
    """Per-conversation state for image editing.

    Holds a random hex ``conv_id`` (uuid4), the model name, the three edit
    prompts, the source image and the edited output. ``dict()`` returns the
    JSON-loggable subset, which excludes both images.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        # Prompts describing the edit; filled in by the generate_ie* handlers.
        self.source_prompt = None
        self.target_prompt = None
        self.instruct_prompt = None
        # Images are kept for saving to disk but are not logged by dict().
        self.source_image = None
        self.output = None

    def dict(self):
        """Return the loggable fields in a fixed order (no image data)."""
        fields = ("conv_id", "model_name", "source_prompt", "target_prompt", "instruct_prompt")
        return {name: getattr(self, name) for name in fields}
class VideoStateVG:
    """Per-conversation state for video generation.

    Holds a random hex ``conv_id`` (uuid4), the model name, the prompt and
    the generated output. ``dict()`` returns the JSON-loggable subset, which
    excludes the output.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None  # set once a prompt has been submitted
        self.output = None  # generated video (tensor or URL); not included in dict()

    def dict(self):
        """Return ``conv_id``, ``model_name`` and ``prompt``, in that order."""
        fields = ("conv_id", "model_name", "prompt")
        return {name: getattr(self, name) for name in fields}
def generate_ig(gen_func, state, text, model_name, request: gr.Request):
    """Generate an image for one model and log the result.

    This is a generator. It yields ``(state, generated_image)`` once for the
    UI, then appends a "chat" entry to the conversation log (mirrored to the
    log server) and saves the image as ``{IMAGE_DIR}/generation/<conv_id>.jpg``.

    Args:
        gen_func: Callable ``gen_func(text, model_name)``. Returns the image,
            or ``''`` when the prompt was blocked by the NSFW filter.
        state: Previous state; ignored and replaced by a fresh ``ImageStateIG``.
        text: The prompt.
        model_name: Model to run.
        request: Gradio request, used for the client IP.

    Raises:
        gr.Error: if the prompt or model name is empty, or the prompt was blocked.
            (``gr.Warning`` is a helper function, not an exception class, so
            ``raise gr.Warning(...)`` raised a TypeError.)
    """
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(text, model_name)
    if isinstance(generated_image, str) and not generated_image:
        raise gr.Error("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.prompt = text
    state.output = generated_image
    yield state, generated_image

    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": model_name,
        "gen_params": {},
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "ip": ip,
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Pass the path rather than a text-mode ('w') handle: JPEG data is binary.
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(output_file)
def generate_ig_museum(gen_func, state, model_name, request: gr.Request):
    """Fetch an image and its prompt for one model from ``gen_func`` and log them.

    This is a generator. It yields ``(state, image, prompt)`` once for the UI,
    then appends a "chat" entry to the conversation log (mirrored to the log
    server) and saves the image as ``{IMAGE_DIR}/generation/<conv_id>.jpg``.

    Args:
        gen_func: Callable ``gen_func(model_name)`` returning ``(image, prompt)``.
        state: Previous state; ignored and replaced by a fresh ``ImageStateIG``.
        model_name: Model whose sample should be shown.
        request: Gradio request, used for the client IP.

    Raises:
        gr.Error: if ``model_name`` is empty.
    """
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image, text = gen_func(model_name)
    state.prompt = text
    state.output = generated_image
    yield state, generated_image, text

    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": model_name,
        "gen_params": {},
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "ip": ip,
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Pass the path rather than a text-mode ('w') handle: JPEG data is binary.
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(output_file)
def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate images from the same prompt with two named models and log both.

    This is a generator. It yields ``(state0, state1, image0, image1)`` once
    for the UI. It then appends one "chat" entry per model to the
    conversation log (mirrored to the log server) and saves each image as
    ``{IMAGE_DIR}/generation/<conv_id>.jpg``.

    Args:
        gen_func: Callable ``gen_func(text, name0, name1)`` returning
            ``(image0, image1)``. A side blocked by the NSFW filter is ``''``.
        state0, state1: Previous states; replaced by fresh ``ImageStateIG``s.
        text: The prompt.
        model_name0, model_name1: Model names, possibly carrying a
            ``"### Model A: "`` / ``"### Model B: "`` display prefix.
        request: Gradio request, used for the client IP.

    Raises:
        gr.Error: if an input is empty or both sides were blocked by the NSFW filter.
    """
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Error("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the "### Model (A|B): " display prefix from the model names.
    model_name0 = model_name0.replace("### Model A: ", "")
    model_name1 = model_name1.replace("### Model B: ", "")
    generated_image0, generated_image1 = gen_func(text, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Error("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for state, image, name in ((state0, generated_image0, model_name0),
                               (state1, generated_image1, model_name1)):
        state.prompt = text
        state.output = image
        state.model_name = name
    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        # Only one side may have been NSFW-blocked (output ''); there is no image
        # to save for it, and load_image('') would raise.
        if isinstance(state.output, str) and not state.output:
            continue
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Fetch a shared prompt and one image per named model from ``gen_func`` and log them.

    This is a generator. It yields ``(state0, state1, image0, image1, prompt)``
    once for the UI. It then appends one "chat" entry per model to the
    conversation log (mirrored to the log server) and saves each image as
    ``{IMAGE_DIR}/generation/<conv_id>.jpg``.

    Args:
        gen_func: Callable ``gen_func(name0, name1)`` returning
            ``(image0, image1, prompt)``.
        state0, state1: Previous states; replaced by fresh ``ImageStateIG``s.
        model_name0, model_name1: Model names, possibly with a
            ``"### Model A/B: "`` display prefix.
        request: Gradio request, used for the client IP.

    Raises:
        gr.Error: if either model name is empty.
    """
    if not model_name0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Error("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the "### Model (A|B): " display prefix from the model names.
    model_name0 = model_name0.replace("### Model A: ", "")
    model_name1 = model_name1.replace("### Model B: ", "")
    generated_image0, generated_image1, text = gen_func(model_name0, model_name1)
    for state, image, name in ((state0, generated_image0, model_name0),
                               (state1, generated_image1, model_name1)):
        state.prompt = text
        state.output = image
        state.model_name = name
    yield state0, state1, generated_image0, generated_image1, text

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Pass the path rather than a text-mode ('w') handle: JPEG data is binary.
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate images for an anonymous battle and log both sides.

    ``gen_func`` picks the two models itself; they are passed in as ``""``.
    This is a generator. It yields the two states, the two images and two
    hidden Markdown components holding ``"### Model A/B: <name>"``. It then
    appends one "chat" entry per model to the conversation log (mirrored to
    the log server) and saves each image as
    ``{IMAGE_DIR}/generation/<conv_id>.jpg``.

    Args:
        gen_func: Callable ``gen_func(text, "", "")`` returning
            ``(image0, image1, name0, name1)``. A side blocked by the NSFW
            filter is ``''``.
        state0, state1: Previous states; replaced by fresh ``ImageStateIG``s.
        text: The prompt.
        model_name0, model_name1: Only used to seed the new states; the names
            returned by ``gen_func`` are what get stored and logged.
        request: Gradio request, used for the client IP.

    Raises:
        gr.Error: if the prompt is empty or both sides were NSFW-blocked.
    """
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, "", "")
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Error("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for state, image, name in ((state0, generated_image0, model_name0),
                               (state1, generated_image1, model_name1)):
        state.prompt = text
        state.output = image
        state.model_name = name
    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        # Only one side may have been NSFW-blocked (output ''); there is no image
        # to save for it, and load_image('') would raise.
        if isinstance(state.output, str) and not state.output:
            continue
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_igm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Fetch an anonymous-battle pair (two images, their models, a shared prompt) and log it.

    ``gen_func`` picks the two models itself; they are passed in as ``""``.
    This is a generator. It yields the two states, the two images, the prompt
    and two hidden Markdown components holding ``"### Model A/B: <name>"``.
    It then appends one "chat" entry per model to the conversation log
    (mirrored to the log server) and saves each image as
    ``{IMAGE_DIR}/generation/<conv_id>.jpg``.

    Args:
        gen_func: Callable ``gen_func("", "")`` returning
            ``(image0, image1, name0, name1, prompt)``.
        state0, state1: Previous states; replaced by fresh ``ImageStateIG``s.
        model_name0, model_name1: Only used to seed the new states.
        request: Gradio request, used for the client IP.
    """
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image0, generated_image1, model_name0, model_name1, text = gen_func("", "")
    for state, image, name in ((state0, generated_image0, model_name0),
                               (state1, generated_image1, model_name1)):
        state.prompt = text
        state.output = image
        state.model_name = name
    yield state0, state1, generated_image0, generated_image1, text, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Pass the path rather than a text-mode ('w') handle: JPEG data is binary.
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request):
    """Edit an image with one model and log the result.

    This is a generator. It yields ``(state, edited_image)`` once for the UI,
    then appends a "chat" entry to the conversation log (mirrored to the log
    server). It saves the source image and the output as
    ``{IMAGE_DIR}/edition/<conv_id>_src.jpg`` and ``..._out.jpg``.

    Args:
        gen_func: Callable ``gen_func(source_text, target_text, instruct_text,
            source_image, model_name)``. Returns the edited image, or ``''``
            when the input was blocked by the NSFW filter.
        state: Previous state; ignored and replaced by a fresh ``ImageStateIE``.
        source_text, target_text, instruct_text: The edit prompts.
        source_image: Image to edit.
        model_name: Model to run.
        request: Gradio request, used for the client IP.

    Raises:
        gr.Error: if any input is empty or the input was NSFW-blocked.
    """
    if not source_text:
        raise gr.Error("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Error("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Error("Instruction prompt cannot be empty.")
    # Explicit None check: the truthiness of an image/array object is not a
    # reliable emptiness test.
    if source_image is None:
        raise gr.Error("Source image cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    ie_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name)
    if isinstance(generated_image, str) and not generated_image:
        raise gr.Error("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    yield state, generated_image

    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": model_name,
        "gen_params": {},
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "ip": ip,
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # Pass paths rather than text-mode ('w') handles: JPEG data is binary.
    save_any_image(state.source_image, src_img_file)
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_ie_museum(gen_func, state, model_name, request: gr.Request):
    """Fetch an image-editing sample (source, output, prompts) for one model and log it.

    This is a generator. It yields ``(state, output, source, source_text,
    target_text, instruct_text)`` once for the UI, then appends a "chat"
    entry to the conversation log (mirrored to the log server). It saves the
    source and output images as ``{IMAGE_DIR}/edition/<conv_id>_src.jpg`` and
    ``..._out.jpg``.

    Args:
        gen_func: Callable ``gen_func(model_name)`` returning ``(source_image,
            output_image, source_text, target_text, instruct_text)``.
        state: Previous state; ignored and replaced by a fresh ``ImageStateIE``.
        model_name: Model whose sample should be shown.
        request: Gradio request, used for the client IP.

    Raises:
        gr.Error: if ``model_name`` is empty.
    """
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    ie_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    source_image, generated_image, source_text, target_text, instruct_text = gen_func(model_name)
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    yield state, generated_image, source_image, source_text, target_text, instruct_text

    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": model_name,
        "gen_params": {},
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "ip": ip,
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # Pass paths rather than text-mode ('w') handles: JPEG data is binary.
    save_any_image(state.source_image, src_img_file)
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Run a side-by-side image-editing comparison between two named models.

    ``gen_func`` receives the prompts, the source image and both model names.
    It returns ``(generated_image0, generated_image1)``; an output of ``''``
    means the NSFW filter blocked it. The UI update is yielded first; the
    conversation records and images are persisted afterwards.

    Args:
        gen_func: Callable performing the two edits.
        state0, state1: Incoming UI states; replaced by fresh ``ImageStateIE``s.
        source_text, target_text, instruct_text: Editing prompts (all required).
        source_image: Image to edit (required).
        model_name0, model_name1: Model names. A leading "### Model A: " /
            "### Model B: " display prefix is stripped.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, generated_image0, generated_image1)``.

    Raises:
        gr.Warning: On a missing input, or when both outputs were NSFW-blocked.
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Side-by-side image *editing* belongs in the editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for state, model_name, generated_image in (
        (state0, model_name0, generated_image0),
        (state1, model_name1, generated_image1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass the path (save_any_image's `file_path` argument) so the JPEG is
        # written in binary mode rather than through a text-mode file handle.
        save_any_image(state.source_image, src_img_file)
        save_image_file_on_log_server(src_img_file)
        # Only this side was NSFW-blocked (both-blocked raised above): its
        # output is '' and there is no image to save.
        if isinstance(state.output, str) and not state.output:
            continue
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_iem_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Show a stored ("museum") side-by-side image-editing sample for two models.

    ``gen_func(model_name0, model_name1)`` must return
    ``(source_image, generated_image0, generated_image1, source_text, target_text, instruct_text)``.
    The UI update is yielded first; the conversation records and images are
    persisted afterwards.

    Args:
        gen_func: Callable returning a museum sample for both models.
        state0, state1: Incoming UI states; replaced by fresh ``ImageStateIE``s.
        model_name0, model_name1: Model names. A leading "### Model A: " /
            "### Model B: " display prefix is stripped.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, generated_image0, generated_image1, source_image,
        source_text, target_text, instruct_text)``.

    Raises:
        gr.Warning: If either model name is empty.
    """
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Side-by-side image *editing* belongs in the editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text = gen_func(model_name0, model_name1)
    for state, model_name, generated_image in (
        (state0, model_name0, generated_image0),
        (state1, model_name1, generated_image1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass the path (save_any_image's `file_path` argument) so the JPEG is
        # written in binary mode rather than through a text-mode file handle.
        save_any_image(state.source_image, src_img_file)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Run an anonymous image-editing battle; ``gen_func`` picks the two models.

    ``gen_func`` is called with empty model names. It returns
    ``(generated_image0, generated_image1, model_name0, model_name1)``; an
    output of ``''`` means the NSFW filter blocked it. The model names are
    yielded inside hidden Markdown components, then the records and images
    are persisted.

    Args:
        gen_func: Callable selecting the models and performing the edits.
        state0, state1: Incoming UI states; replaced by fresh ``ImageStateIE``s.
        source_text, target_text, instruct_text: Editing prompts (all required).
        source_image: Image to edit (required).
        model_name0, model_name1: Used only to construct the initial states.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, generated_image0, generated_image1, markdown_a, markdown_b)``.

    Raises:
        gr.Warning: On a missing input, or when both outputs were NSFW-blocked.
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Side-by-side image *editing* belongs in the editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Empty names let gen_func choose the (anonymous) models itself.
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, "", "")
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for state, model_name, generated_image in (
        (state0, model_name0, generated_image0),
        (state1, model_name1, generated_image1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass the path (save_any_image's `file_path` argument) so the JPEG is
        # written in binary mode rather than through a text-mode file handle.
        save_any_image(state.source_image, src_img_file)
        save_image_file_on_log_server(src_img_file)
        # Only this side was NSFW-blocked (both-blocked raised above): its
        # output is '' and there is no image to save.
        if isinstance(state.output, str) and not state.output:
            continue
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Show a stored anonymous image-editing battle; ``gen_func`` picks the models.

    ``gen_func`` is called with empty model names and returns
    ``(source_image, generated_image0, generated_image1, source_text,
    target_text, instruct_text, model_name0, model_name1)``. The model names
    are yielded inside hidden Markdown components, then the records and
    images are persisted.

    Args:
        gen_func: Callable selecting the models and returning a museum sample.
        state0, state1: Incoming UI states; replaced by fresh ``ImageStateIE``s.
        model_name0, model_name1: Used only to construct the initial states.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, generated_image0, generated_image1, source_image,
        source_text, target_text, instruct_text, markdown_a, markdown_b)``.
    """
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Side-by-side image *editing* belongs in the editing-multi log.
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Empty names let gen_func choose the (anonymous) models itself.
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text, model_name0, model_name1 = gen_func("", "")
    for state, model_name, generated_image in (
        (state0, model_name0, generated_image0),
        (state1, model_name1, generated_image1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass the path (save_any_image's `file_path` argument) so the JPEG is
        # written in binary mode rather than through a text-mode file handle.
        save_any_image(state.source_image, src_img_file)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_vg(gen_func, state, text, model_name, request: gr.Request):
    """Generate a video for ``text`` with one model, then log and persist it.

    ``gen_func(text, model_name)`` returns ``''`` when the NSFW filter blocks
    the prompt. For models whose name starts with ``'fal'`` the result is
    fetched with ``requests.get``. Otherwise it is treated as a frame tensor
    and encoded with ``imageio``. Unlike the image helpers, this yields only
    once, after the .mp4 has been written.

    Args:
        gen_func: Callable performing the generation.
        state: Incoming UI state; replaced by a fresh ``VideoStateVG``.
        text: Prompt (required).
        model_name: Model to use (required).
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state, output_file)``, where ``output_file`` is the saved .mp4 path.

    Raises:
        gr.Warning: On empty input, or when the prompt was NSFW-blocked.
        requests.HTTPError: If downloading a fal video fails.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = VideoStateVG(model_name)
    ip = get_ip(request)
    vg_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_video = gen_func(text, model_name)
    if generated_video == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.prompt = text
    state.output = generated_video
    state.model_name = model_name

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": ip,
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    if model_name.startswith('fal'):
        # Bound the wait so a stalled download cannot hang the worker. Fail
        # loudly rather than saving an HTTP error body as an .mp4.
        response = requests.get(state.output, timeout=60)
        response.raise_for_status()
        with open(output_file, 'wb') as outfile:
            outfile.write(response.content)
    else:
        vg_logger.info(f"video shape: {tuple(state.output.shape)}")
        # imageio needs frames as [num_frames, height, width, num_channels].
        # A last dim other than 3 is assumed to be channels-first
        # [frames, channels, H, W], so move channels last. TODO confirm
        # against the non-fal model outputs.
        if state.output.shape[-1] != 3:
            state.output = state.output.permute(0, 2, 3, 1)
        imageio.mimwrite(output_file, state.output, fps=8, quality=9)
    save_video_file_on_log_server(output_file)
    yield state, output_file
def generate_vg_museum(gen_func, state, model_name, request: gr.Request):
    """Show a stored ("museum") video sample for a single model.

    ``gen_func(model_name)`` returns ``(generated_video, text)``, where
    ``generated_video`` is downloaded with ``requests.get``. The record is
    logged and the .mp4 is saved before the single yield.

    Args:
        gen_func: Callable returning a museum sample for ``model_name``.
        state: Incoming UI state; replaced by a fresh ``VideoStateVG``.
        model_name: Model whose sample is shown.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state, output_file, text)``.

    Raises:
        requests.HTTPError: If downloading the video fails.
    """
    state = VideoStateVG(model_name)
    ip = get_ip(request)
    vg_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_video, text = gen_func(model_name)
    state.prompt = text
    state.output = generated_video
    state.model_name = model_name

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": ip,
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Bound the wait so a stalled download cannot hang the worker. Fail loudly
    # rather than saving an HTTP error body as an .mp4.
    response = requests.get(state.output, timeout=60)
    response.raise_for_status()
    with open(output_file, 'wb') as outfile:
        outfile.write(response.content)
    save_video_file_on_log_server(output_file)
    yield state, output_file, text
def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Run a side-by-side video-generation comparison between two named models.

    ``gen_func(text, model_name0, model_name1)`` returns
    ``(generated_video0, generated_video1)``; an output of ``''`` means the
    NSFW filter blocked it. ``'fal'`` model outputs are downloaded with
    ``requests.get``; other outputs are encoded as frame tensors with
    ``imageio``. Yields once, after the videos are saved.

    Args:
        gen_func: Callable performing both generations.
        state0, state1: Incoming UI states; replaced by fresh ``VideoStateVG``s.
        text: Prompt (required).
        model_name0, model_name1: Model names. A leading "### Model A: " /
            "### Model B: " display prefix is stripped.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, video_file0, video_file1)``. A video file is
        ``None`` for a side whose output alone was NSFW-blocked.

    Raises:
        gr.Warning: On empty input, or when both outputs were NSFW-blocked.
        requests.HTTPError: If downloading a fal video fails.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    # Side-by-side *video* generation belongs in the video-multi log.
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1 = gen_func(text, model_name0, model_name1)
    if generated_video0 == '' and generated_video1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for state, model_name, generated_video in (
        (state0, model_name0, generated_video0),
        (state1, model_name1, generated_video1),
    ):
        state.prompt = text
        state.output = generated_video
        state.model_name = model_name
    vgm_logger.info(f"models: {state0.model_name}, {state1.model_name}")

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    video_files = []
    for state in (state0, state1):
        # Only this side was NSFW-blocked (both-blocked raised above): there is
        # nothing to download or encode, so show nothing for it.
        if isinstance(state.output, str) and not state.output:
            video_files.append(None)
            continue
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            # Bound the wait and fail loudly instead of saving an error body.
            response = requests.get(state.output, timeout=60)
            response.raise_for_status()
            with open(output_file, 'wb') as outfile:
                outfile.write(response.content)
        else:
            vgm_logger.info(f"{state.model_name} video shape: {tuple(state.output.shape)}")
            # imageio needs frames as [num_frames, height, width, num_channels].
            # A last dim other than 3 is assumed to be channels-first
            # [frames, channels, H, W]. TODO confirm against model outputs.
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
        video_files.append(output_file)
    yield state0, state1, video_files[0], video_files[1]
def generate_vgm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Show a stored ("museum") side-by-side video sample for two models.

    ``gen_func(model_name0, model_name1)`` returns
    ``(generated_video0, generated_video1, text)``. Both videos are
    downloaded with ``requests.get``. Yields once, after saving.

    Args:
        gen_func: Callable returning a museum sample for both models.
        state0, state1: Incoming UI states; replaced by fresh ``VideoStateVG``s.
        model_name0, model_name1: Model names. A leading "### Model A: " /
            "### Model B: " display prefix is stripped.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, video_file0, video_file1, text)``.

    Raises:
        requests.HTTPError: If downloading a video fails.
    """
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    # Side-by-side *video* generation belongs in the video-multi log.
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1, text = gen_func(model_name0, model_name1)
    for state, model_name, generated_video in (
        (state0, model_name0, generated_video0),
        (state1, model_name1, generated_video1),
    ):
        state.prompt = text
        state.output = generated_video
        state.model_name = model_name
    vgm_logger.info(f"models: {state0.model_name}, {state1.model_name}")

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    video_files = []
    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Bound the wait and fail loudly instead of saving an error body.
        response = requests.get(state.output, timeout=60)
        response.raise_for_status()
        with open(output_file, 'wb') as outfile:
            outfile.write(response.content)
        save_video_file_on_log_server(output_file)
        video_files.append(output_file)
    yield state0, state1, video_files[0], video_files[1], text
def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Run an anonymous video-generation battle; ``gen_func`` picks the two models.

    ``gen_func`` is called with empty model names. It returns
    ``(generated_video0, generated_video1, model_name0, model_name1)``; an
    output of ``''`` means the NSFW filter blocked it. ``'fal'`` outputs are
    downloaded; other outputs are encoded with ``imageio``. Yields once,
    after saving, with the model names in hidden Markdown components.

    Args:
        gen_func: Callable selecting the models and performing the generation.
        state0, state1: Incoming UI states; replaced by fresh ``VideoStateVG``s.
        text: Prompt (required).
        model_name0, model_name1: Used only to construct the initial states.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, video_file0, video_file1, markdown_a, markdown_b)``.
        A video file is ``None`` for a side whose output alone was NSFW-blocked.

    Raises:
        gr.Warning: On empty prompt, or when both outputs were NSFW-blocked.
        requests.HTTPError: If downloading a fal video fails.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Empty names let gen_func choose the (anonymous) models itself.
    generated_video0, generated_video1, model_name0, model_name1 = gen_func(text, "", "")
    if generated_video0 == '' and generated_video1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    for state, model_name, generated_video in (
        (state0, model_name0, generated_video0),
        (state1, model_name1, generated_video1),
    ):
        state.prompt = text
        state.output = generated_video
        state.model_name = model_name

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    video_files = []
    for state in (state0, state1):
        # Only this side was NSFW-blocked (both-blocked raised above): there is
        # nothing to download or encode, so show nothing for it.
        if isinstance(state.output, str) and not state.output:
            video_files.append(None)
            continue
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            # Bound the wait and fail loudly instead of saving an error body.
            response = requests.get(state.output, timeout=60)
            response.raise_for_status()
            with open(output_file, 'wb') as outfile:
                outfile.write(response.content)
        else:
            vgm_logger.info(f"video shape: {tuple(state.output.shape)}")
            # imageio needs frames as [num_frames, height, width, num_channels].
            # A last dim other than 3 is assumed to be channels-first
            # [frames, channels, H, W]. TODO confirm against model outputs.
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
        video_files.append(output_file)
    yield state0, state1, video_files[0], video_files[1], \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
def generate_vgm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Show a stored anonymous video battle; ``gen_func`` picks the models.

    ``gen_func`` is called with empty model names and returns
    ``(generated_video0, generated_video1, model_name0, model_name1, text)``.
    Both videos are downloaded with ``requests.get``. Yields once, after
    saving, with the model names in hidden Markdown components.

    Args:
        gen_func: Callable selecting the models and returning a museum sample.
        state0, state1: Incoming UI states; replaced by fresh ``VideoStateVG``s.
        model_name0, model_name1: Used only to construct the initial states.
        request: Gradio request, used to record the client IP.

    Yields:
        ``(state0, state1, video_file0, video_file1, text, markdown_a, markdown_b)``.

    Raises:
        requests.HTTPError: If downloading a video fails.
    """
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Empty names let gen_func choose the (anonymous) models itself.
    generated_video0, generated_video1, model_name0, model_name1, text = gen_func("", "")
    for state, model_name, generated_video in (
        (state0, model_name0, generated_video0),
        (state1, model_name1, generated_video1),
    ):
        state.prompt = text
        state.output = generated_video
        state.model_name = model_name

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, model A first.
        for state in (state0, state1):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    video_files = []
    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Bound the wait and fail loudly instead of saving an error body.
        response = requests.get(state.output, timeout=60)
        response.raise_for_status()
        with open(output_file, 'wb') as outfile:
            outfile.write(response.content)
        save_video_file_on_log_server(output_file)
        video_files.append(output_file)
    yield state0, state1, video_files[0], video_files[1], text, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)