import datetime
import time
import json
import os
import requests
import uuid
import gradio as gr
import regex as re
from pathlib import Path
from .utils import *
from .log_utils import build_logger
from .constants import IMAGE_DIR, VIDEO_DIR
import imageio
from diffusers.utils import load_image
import torch
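
# Vote logging and generation entry points for the Gradio arena: direct chat,
# side-by-side comparison, and anonymous battle across image generation,
# image editing, and video generation models.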

ig_logger = build_logger("gradio_web_server_image_generation", "gr_web_image_generation.log")  # ig = image generation; logger for single-model direct chat
igm_logger = build_logger("gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log")  # igm = image generation multi; logger for side-by-side and battle
ie_logger = build_logger("gradio_web_server_image_editing", "gr_web_image_editing.log")  # ie = image editing; logger for single-model direct chat
iem_logger = build_logger("gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log")  # iem = image editing multi; logger for side-by-side and battle
vg_logger = build_logger("gradio_web_server_video_generation", "gr_web_video_generation.log")  # vg = video generation; logger for single-model direct chat
vgm_logger = build_logger("gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log")  # vgm = video generation multi; logger for side-by-side and battle

def save_any_image(image_file, file_path):
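    """Save an image as JPEG: image_file may be a path/URL string (loaded via
    diffusers' load_image) or a PIL image; file_path may be a filename or a
    binary-mode file object."""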
    if isinstance(image_file, str):
        image = load_image(image_file)
        image.save(file_path, 'JPEG')
    else:
        image_file.save(file_path, 'JPEG')

def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(output_file)
        
def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    for state in states:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)

def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
    source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    with open(source_file, 'wb') as sf:
        save_any_image(state.source_image, sf)
    save_image_file_on_log_server(output_file)
    save_image_file_on_log_server(source_file)
        
def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    for state in states:
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
        source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        with open(source_file, 'wb') as sf:
            save_any_image(state.source_image, sf)
        save_image_file_on_log_server(output_file)
        save_image_file_on_log_server(source_file)


def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
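    # fal-hosted models return a URL to the rendered video; download it as-is.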
    if state.model_name.startswith('fal'):
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    else:
        print("======== video shape: ========")
        print(state.output.shape)
        # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
        if state.output.shape[-1] != 3:
            state.output = state.output.permute(0, 2, 3, 1)
        imageio.mimwrite(output_file, state.output, fps=8, quality=9)
    save_video_file_on_log_server(output_file)



def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    for state in states:
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        elif isinstance(state.output, torch.Tensor):
            print("======== video shape: ========")
            print(state.output.shape)
            # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        else:
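            # Any other non-tensor output is treated as a downloadable URL.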
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        save_video_file_on_log_server(output_file)
            

## Image Generation (IG) Single Model Direct Chat
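# Each handler records the vote, clears the text component, and disables the
# vote buttons (disable_btn comes from .utils).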
def upvote_last_response_ig(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    ig_logger.info(f"upvote. ip: {ip}")
    vote_last_response_ig(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3

def downvote_last_response_ig(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    ig_logger.info(f"downvote. ip: {ip}")
    vote_last_response_ig(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response_ig(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    ig_logger.info(f"flag. ip: {ip}")
    vote_last_response_ig(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3

## Image Generation Multi (IGM) Side-by-Side and Battle
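# In anonymous battle mode the model selectors are empty strings; after a vote
# the handlers reveal the model names, which are expected to look like
# "<source>_<model>" (hence split('_')[1]).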

def leftvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
        gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
        gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
                                             gr.Markdown(state1.model_name, visible=True))

def rightvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
                                             gr.Markdown(state1.model_name, visible=True))


def tievote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    igm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
        gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
        gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
                                             gr.Markdown(state1.model_name, visible=True))


def bothbad_vote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
                                             gr.Markdown(state1.model_name, visible=True))

## Image Editing (IE) Single Model Direct Chat
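# IE handlers additionally reset the editing inputs (two prompts, the
# source-image widget, and the instruction prompt) alongside the vote buttons.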

def upvote_last_response_ie(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    ie_logger.info(f"upvote. ip: {ip}")
    vote_last_response_ie(state, "upvote", model_selector, request)
    return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3

def downvote_last_response_ie(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    ie_logger.info(f"downvote. ip: {ip}")
    vote_last_response_ie(state, "downvote", model_selector, request)
    return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3

def flag_last_response_ie(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    ie_logger.info(f"flag. ip: {ip}")
    vote_last_response_ie(state, "flag", model_selector, request)
    return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3

## Image Editing Multi (IEM) Side-by-Side and Battle
def leftvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
        
def rightvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                 gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4

def tievote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    iem_logger.info(f"tievote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                 gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4

def bothbad_vote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                 gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
    return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4


## Video Generation (VG) Single Model Direct Chat
def upvote_last_response_vg(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    vg_logger.info(f"upvote. ip: {ip}")
    vote_last_response_vg(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3

def downvote_last_response_vg(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    vg_logger.info(f"downvote. ip: {ip}")
    vote_last_response_vg(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response_vg(state, model_selector, request: gr.Request):
    ip = get_ip(request)
    vg_logger.info(f"flag. ip: {ip}")
    vote_last_response_vg(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3

## Video Generation Multi (VGM) Side-by-Side and Battle

def leftvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (
        gr.Markdown(state0.model_name, visible=False),
        gr.Markdown(state1.model_name, visible=False))


def rightvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
        gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
        gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False))

def tievote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    vgm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
        gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
        gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False))


def bothbad_vote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
        gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
        gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False))

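# JavaScript hook for the share button: html2canvas renders the
# #share-region-named element to a canvas and downloads it as chatbot-arena.png.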
share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-named');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is not None and state1 is not None:
        vote_last_response_igm(
            [state0, state1], "share", [model_selector0, model_selector1], request
        )

def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request):
    iem_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is not None and state1 is not None:
        vote_last_response_iem(
            [state0, state1], "share", [model_selector0, model_selector1], request
        )
        
## All Generation Gradio Interface

class ImageStateIG:
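    """Per-conversation state for image generation: prompt and generated output."""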
    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None

    def dict(self):
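        # The generated image itself is persisted to disk separately; only
        # metadata goes into the JSON log entry.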
        base = {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "prompt": self.prompt
            }
        return base

class ImageStateIE:
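    """Per-conversation state for image editing: prompts, source image, and output."""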
    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.source_prompt = None
        self.target_prompt = None
        self.instruct_prompt = None
        self.source_image = None
        self.output = None

    def dict(self):
        base = {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "source_prompt": self.source_prompt,
            "target_prompt": self.target_prompt,
            "instruct_prompt": self.instruct_prompt
            }
        return base

class VideoStateVG:
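    """Per-conversation state for video generation: prompt and generated output."""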
    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None

    def dict(self):
        base = {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "prompt": self.prompt
            }
        return base


def generate_ig(gen_func, state, text, model_name, request: gr.Request):
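    """Generate one image with model_name, yield it to the UI, then append the
    conversation log and persist the image locally and on the log server."""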
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(text, model_name)
    if generated_image == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name
    
    yield state, generated_image
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        
    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(output_file)

def generate_ig_museum(gen_func, state, model_name, request: gr.Request):
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image, text = gen_func(model_name)
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name
    
    yield state, generated_image, text
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        
    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(output_file)

def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(text, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    
    for state in [state0, state1]:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)

def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1, text
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    
    for state in [state0, state1]:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)


def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
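    """Anonymous battle: gen_func samples the two models itself and returns
    their names alongside the generated images."""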
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    
    for state in [state0, state1]:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)
            
def generate_igm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1, text,\
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    
    for state in [state0, state1]:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(output_file)
            
def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request):
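    """Edit source_image according to the prompts, yield the result, then log
    the call and persist both the source and the edited images."""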
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name)
    if generated_image == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name
    
    yield state, generated_image
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        
    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    with open(src_img_file, 'wb') as f:
        save_any_image(state.source_image, f)
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)

def generate_ie_museum(gen_func, state, model_name, request: gr.Request):
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    source_image, generated_image, source_text, target_text, instruct_text = gen_func(model_name)
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name
    
    yield state, generated_image, source_image, source_text, target_text, instruct_text
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        
    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    with open(src_img_file, 'wb') as f:
        save_any_image(state.source_image, f)
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    with open(output_file, 'wb') as f:
        save_any_image(state.output, f)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)


def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
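    """Side-by-side image editing: run both models on the same inputs and log
    each result under its own conversation id."""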
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.source_prompt = source_text
    state0.target_prompt = target_text
    state0.instruct_prompt = instruct_text
    state0.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.source_prompt = source_text
    state1.target_prompt = target_text
    state1.instruct_prompt = instruct_text
    state1.source_image = source_image
    state1.output = generated_image1
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        
    for state in [state0, state1]:
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)

def generate_iem_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text = gen_func(model_name0, model_name1)
    state0.source_prompt = source_text
    state0.target_prompt = target_text
    state0.instruct_prompt = instruct_text
    state0.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.source_prompt = source_text
    state1.target_prompt = target_text
    state1.instruct_prompt = instruct_text
    state1.source_image = source_image
    state1.output = generated_image1
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        
    for state in [state0, state1]:
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
            

def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    if generated_image0 == '' and generated_image1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.source_prompt = source_text
    state0.target_prompt = target_text
    state0.instruct_prompt = instruct_text
    state0.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.source_prompt = source_text
    state1.target_prompt = target_text
    state1.instruct_prompt = instruct_text
    state1.source_image = source_image
    state1.output = generated_image1
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    for state in [state0, state1]:
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)

def generate_iem_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = ""
    model_name1 = ""
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text, model_name0, model_name1 = gen_func(model_name0, model_name1)
    state0.source_prompt = source_text
    state0.target_prompt = target_text
    state0.instruct_prompt = instruct_text
    state0.source_image = source_image
    state0.output = generated_image0
    state0.model_name = model_name0
    state1.source_prompt = source_text
    state1.target_prompt = target_text
    state1.instruct_prompt = instruct_text
    state1.source_image = source_image
    state1.output = generated_image1
    state1.model_name = model_name1
    
    yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
    
    finish_tstamp = time.time()
    # logger.info(f"===output===: {output}")
    
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    for state in [state0, state1]:
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        with open(src_img_file, 'wb') as f:
            save_any_image(state.source_image, f)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            save_any_image(state.output, f)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)

def generate_vg(gen_func, state, text, model_name, request: gr.Request):
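    """Generate a video, write it under VIDEO_DIR as mp4, and yield the saved
    file path once the upload to the log server has finished."""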
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = VideoStateVG(model_name)
    ip = get_ip(request)
    vg_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_video = gen_func(text, model_name)
    if generated_video == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state.prompt = text
    state.output = generated_video
    state.model_name = model_name

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    if model_name.startswith('fal'):
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    else:
        print("======== video shape: ========")
        print(state.output.shape)
        # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
        if state.output.shape[-1] != 3:
            state.output = state.output.permute(0, 2, 3, 1)
        imageio.mimwrite(output_file, state.output, fps=8, quality=9)

    save_video_file_on_log_server(output_file)
    yield state, output_file

def generate_vg_museum(gen_func, state, model_name, request: gr.Request):
    state = VideoStateVG(model_name)
    ip = get_ip(request)
    vg_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_video, text = gen_func(model_name)
    state.prompt = text
    state.output = generated_video
    state.model_name = model_name

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    r = requests.get(state.output)
    with open(output_file, 'wb') as outfile:
        outfile.write(r.content)

    save_video_file_on_log_server(output_file)
    yield state, output_file, text


def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
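    """Side-by-side video generation: run both models, log both results, and
    yield the saved mp4 paths."""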
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1 = gen_func(text, model_name0, model_name1)
    if generated_video0 == '' and generated_video1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1

    vgm_logger.info(f"model names: {state0.model_name}, {state1.model_name}")

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name0,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state0.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name1,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state1.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    for state in [state0, state1]:
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        if state.model_name.startswith('fal'):
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            print("======== video shape: ========")
            print(state.output)
            print(state.output.shape)
            # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'

def generate_vgm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1

    vgm_logger.info(f"model A: {state0.model_name}, model B: {state1.model_name}")

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        # Museum samples are served by URL; download the rendered video.
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)

        save_video_file_on_log_server(output_file)
    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text


def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
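    """Anonymous battle mode: gen_func samples the two competing models and
    returns their names along with the videos; the Model A/B headers are
    yielded with visible=False so the identities stay hidden in the UI."""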
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
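    # Anonymous battle: blank the incoming names and let gen_func sample the
    # competing pair, which it returns alongside the generated videos.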
    model_name0 = ""
    model_name1 = ""
    generated_video0, generated_video1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    if generated_video0 == '' and generated_video1 == '':
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            # fal-hosted models return a URL; download the rendered video.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            # Local models return a frame tensor; imageio expects
            # [num_frames, height, width, num_channels], so permute a
            # channels-first tensor to channels-last before writing
            # (assumes a CPU tensor that imageio can convert).
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)

    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

def generate_vgm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
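    """Anonymous battle on 'museum' samples: gen_func picks the model pair and
    returns two sample videos (as URLs) plus their prompt; the Model A/B
    headers are yielded with visible=False to keep the matchup anonymous."""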
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
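    # Anonymous battle: blank the incoming names and let gen_func sample the
    # competing pair, which it returns alongside the sample videos and prompt.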
    model_name0 = ""
    model_name1 = ""
    generated_video0, generated_video1, model_name0, model_name1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    

    finish_tstamp = time.time()

    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        # Museum samples are served by URL; download the rendered video.
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)

        save_video_file_on_log_server(output_file)

    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text,\
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)