try:
    import spaces
except:

    class NoneSpaces:
        def __init__(self):
            pass

        def GPU(self, fn):
            return fn

    spaces = NoneSpaces()

import os
import logging
import sys

import numpy as np

from modules.devices import devices
from modules.synthesize_audio import synthesize_audio

logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


import gradio as gr

import torch

from modules.ssml import parse_ssml
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments

from modules.speaker import speaker_mgr
from modules.data import styles_mgr

from modules.api.utils import calc_spk_style
import modules.generate_audio as generate

from modules.normalization import text_normalize
from modules import refiner, config, models

from modules.utils import env, audio
from modules.SentenceSplitter import SentenceSplitter

# fix: If the system proxy is enabled in the Windows system, you need to skip these
os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"

torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")

webui_config = {
    "tts_max": 1000,
    "ssml_max": 5000,
    "spliter_threshold": 100,
    "max_batch_size": 8,
}


def get_speakers():
    return speaker_mgr.list_speakers()


def get_styles():
    return styles_mgr.list_items()


def segments_length_limit(segments, total_max: int):
    ret_segments = []
    total_len = 0
    for seg in segments:
        if "text" not in seg:
            continue
        total_len += len(seg["text"])
        if total_len > total_max:
            break
        ret_segments.append(seg)
    return ret_segments


@torch.inference_mode()
@spaces.GPU
def synthesize_ssml(ssml: str, batch_size=4):
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 8

    ssml = ssml.strip()

    if ssml == "":
        return None

    segments = parse_ssml(ssml)
    max_len = webui_config["ssml_max"]
    segments = segments_length_limit(segments, max_len)

    if len(segments) == 0:
        return None

    models.load_chat_tts()
    synthesize = SynthesizeSegments(batch_size=batch_size)
    audio_segments = synthesize.synthesize_segments(segments)
    combined_audio = combine_audio_segments(audio_segments)

    return audio.pydub_to_np(combined_audio)


@torch.inference_mode()
@spaces.GPU
def tts_generate(
    text,
    temperature,
    top_p,
    top_k,
    spk,
    infer_seed,
    use_decoder,
    prompt1,
    prompt2,
    prefix,
    style,
    disable_normalize=False,
    batch_size=4,
):
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 4

    max_len = webui_config["tts_max"]
    text = text.strip()[0:max_len]

    if text == "":
        return None

    if style == "*auto":
        style = None

    if isinstance(top_k, float):
        top_k = int(top_k)

    params = calc_spk_style(spk=spk, style=style)
    spk = params.get("spk", spk)

    infer_seed = infer_seed or params.get("seed", infer_seed)
    temperature = temperature or params.get("temperature", temperature)
    prefix = prefix or params.get("prefix", prefix)
    prompt1 = prompt1 or params.get("prompt1", "")
    prompt2 = prompt2 or params.get("prompt2", "")

    infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.int64)
    infer_seed = int(infer_seed)

    if not disable_normalize:
        text = text_normalize(text)

    models.load_chat_tts()
    sample_rate, audio_data = synthesize_audio(
        text=text,
        temperature=temperature,
        top_P=top_p,
        top_K=top_k,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
        batch_size=batch_size,
    )

    audio_data = audio.audio_to_int16(audio_data)
    return sample_rate, audio_data


@torch.inference_mode()
@spaces.GPU
def refine_text(text: str, prompt: str):
    text = text_normalize(text)
    return refiner.refine_text(text, prompt=prompt)


def read_local_readme():
    with open("README.md", "r", encoding="utf-8") as file:
        content = file.read()
        content = content[content.index("# ") :]
        return content


# 演示示例文本
sample_texts = [
    {
        "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙  [lbreak]",
    },
    {
        "text": "天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖 [lbreak]",
    },
    {
        "text": "公司的年度总结会议将在下周三举行,请各部门提前准备好相关材料,确保会议顺利进行 [lbreak]",
    },
    {
        "text": "今天的午餐菜单包括烤鸡、沙拉和蔬菜汤,大家可以根据自己的口味选择适合的菜品 [lbreak]",
    },
    {
        "text": "请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯 [lbreak]",
    },
    {
        "text": "图书馆新到了一批书籍,涵盖了文学、科学和历史等多个领域,欢迎大家前来借阅 [lbreak]",
    },
    {
        "text": "电影中梁朝伟扮演的陈永仁的编号27149 [lbreak]",
    },
    {
        "text": "这块黄金重达324.75克 [lbreak]",
    },
    {
        "text": "我们班的最高总分为583分 [lbreak]",
    },
    {
        "text": "12~23 [lbreak]",
    },
    {
        "text": "-1.5~2 [lbreak]",
    },
    {
        "text": "她出生于86年8月18日,她弟弟出生于1995年3月1日 [lbreak]",
    },
    {
        "text": "等会请在12:05请通知我 [lbreak]",
    },
    {
        "text": "今天的最低气温达到-10°C [lbreak]",
    },
    {
        "text": "现场有7/12的观众投出了赞成票 [lbreak]",
    },
    {
        "text": "明天有62%的概率降雨 [lbreak]",
    },
    {
        "text": "随便来几个价格12块5,34.5元,20.1万 [lbreak]",
    },
    {
        "text": "这是固话0421-33441122 [lbreak]",
    },
    {
        "text": "这是手机+86 18544139121 [lbreak]",
    },
]

ssml_example1 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="narration-relaxed">
        下面是一个 ChatTTS 用于合成多角色多情感的有声书示例[lbreak]
    </voice>
    <voice spk="Bob" seed="42" style="narration-relaxed">
        黛玉冷笑道:[lbreak]
    </voice>
    <voice spk="female2" seed="42" style="angry">
        我说呢 [uv_break] ,亏了绊住,不然,早就飞起来了[lbreak]
    </voice>
    <voice spk="Bob" seed="42" style="narration-relaxed">
        宝玉道:[lbreak]
    </voice>
    <voice spk="Alice" seed="42" style="unfriendly">
        “只许和你玩 [uv_break] ,替你解闷。不过偶然到他那里,就说这些闲话。”[lbreak]
    </voice>
    <voice spk="female2" seed="42" style="angry">
        “好没意思的话![uv_break] 去不去,关我什么事儿? 又没叫你替我解闷儿 [uv_break],还许你不理我呢” [lbreak]
    </voice>
    <voice spk="Bob" seed="42" style="narration-relaxed">
        说着,便赌气回房去了 [lbreak]
    </voice>
</speak>
"""
ssml_example2 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="narration-relaxed">
        使用 prosody 控制生成文本的语速语调和音量,示例如下 [lbreak]

        <prosody>
            无任何限制将会继承父级voice配置进行生成 [lbreak]
        </prosody>
        <prosody rate="1.5">
            设置 rate 大于1表示加速,小于1为减速 [lbreak]
        </prosody>
        <prosody pitch="6">
            设置 pitch 调整音调,设置为6表示提高6个半音 [lbreak]
        </prosody>
        <prosody volume="2">
            设置 volume 调整音量,设置为2表示提高2个分贝 [lbreak]
        </prosody>

        在 voice 中无prosody包裹的文本即为默认生成状态下的语音 [lbreak]
    </voice>
</speak>
"""
ssml_example3 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="narration-relaxed">
        使用 break 标签将会简单的 [lbreak]
        
        <break time="500" />

        插入一段空白到生成结果中 [lbreak]
    </voice>
</speak>
"""

ssml_example4 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="excited">
        temperature for sampling (may be overridden by style or speaker) [lbreak]
        <break time="500" />
        温度值用于采样,这个值有可能被 style 或者 speaker 覆盖  [lbreak]
        <break time="500" />
        temperature for sampling ,这个值有可能被 style 或者 speaker 覆盖  [lbreak]
        <break time="500" />
        温度值用于采样,(may be overridden by style or speaker) [lbreak]
    </voice>
</speak>
"""

default_ssml = """
<speak version="0.1">
  <voice spk="Bob" seed="42" style="narration-relaxed">
    这里是一个简单的 SSML 示例 [lbreak] 
  </voice>
</speak>
"""


def create_tts_interface():
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]

    styles = ["*auto"] + [s.get("name") for s in get_styles()]

    history = []

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🎛️Sampling")
                temperature_input = gr.Slider(
                    0.01, 2.0, value=0.3, step=0.01, label="Temperature"
                )
                top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
                top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
                batch_size_input = gr.Slider(
                    1,
                    webui_config["max_batch_size"],
                    value=4,
                    step=1,
                    label="Batch Size",
                )

            with gr.Row():
                with gr.Group():
                    gr.Markdown("🎭Style")
                    gr.Markdown("- 后缀为 `_p` 表示带prompt,效果更强但是影响质量")
                    style_input_dropdown = gr.Dropdown(
                        choices=styles,
                        # label="Choose Style",
                        interactive=True,
                        show_label=False,
                        value="*auto",
                    )
            with gr.Row():
                with gr.Group():
                    gr.Markdown("🗣️Speaker (Name or Seed)")
                    spk_input_text = gr.Textbox(
                        label="Speaker (Text or Seed)",
                        value="female2",
                        show_label=False,
                    )
                    spk_input_dropdown = gr.Dropdown(
                        choices=speaker_names,
                        # label="Choose Speaker",
                        interactive=True,
                        value="female : female2",
                        show_label=False,
                    )
                    spk_rand_button = gr.Button(
                        value="🎲",
                        # tooltip="Random Seed",
                        variant="secondary",
                    )
                    spk_input_dropdown.change(
                        fn=lambda x: x.startswith("*")
                        and "-1"
                        or x.split(":")[-1].strip(),
                        inputs=[spk_input_dropdown],
                        outputs=[spk_input_text],
                    )
                    spk_rand_button.click(
                        lambda x: str(torch.randint(0, 2**32 - 1, (1,)).item()),
                        inputs=[spk_input_text],
                        outputs=[spk_input_text],
                    )
            with gr.Group():
                gr.Markdown("💃Inference Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(
                    value="🎲",
                    # tooltip="Random Seed",
                    variant="secondary",
                )
            use_decoder_input = gr.Checkbox(
                value=True, label="Use Decoder", visible=False
            )
            with gr.Group():
                gr.Markdown("🔧Prompt engineering")
                prompt1_input = gr.Textbox(label="Prompt 1")
                prompt2_input = gr.Textbox(label="Prompt 2")
                prefix_input = gr.Textbox(label="Prefix")

            infer_seed_rand_button.click(
                lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
                inputs=[infer_seed_input],
                outputs=[infer_seed_input],
            )
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Group():
                        input_title = gr.Markdown(
                            "📝Text Input",
                            elem_id="input-title",
                        )
                        gr.Markdown(
                            f"- 字数限制{webui_config['tts_max']:,}字,超过部分截断"
                        )
                        gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
                        gr.Markdown(
                            "- If the input text is all in English, it is recommended to check disable_normalize"
                        )
                        text_input = gr.Textbox(
                            show_label=False,
                            label="Text to Speech",
                            lines=10,
                            placeholder="输入文本或选择示例",
                            elem_id="text-input",
                        )
                        # TODO 字数统计,其实实现很好写,但是就是会触发loading...并且还要和后端交互...
                        # text_input.change(
                        #     fn=lambda x: (
                        #         f"📝Text Input ({len(x)} char)"
                        #         if x
                        #         else (
                        #             "📝Text Input (0 char)"
                        #             if not x
                        #             else "📝Text Input (0 char)"
                        #         )
                        #     ),
                        #     inputs=[text_input],
                        #     outputs=[input_title],
                        # )
                        with gr.Row():
                            contorl_tokens = [
                                "[laugh]",
                                "[uv_break]",
                                "[v_break]",
                                "[lbreak]",
                            ]

                            for tk in contorl_tokens:
                                t_btn = gr.Button(tk)
                                t_btn.click(
                                    lambda text, tk=tk: text + " " + tk,
                                    inputs=[text_input],
                                    outputs=[text_input],
                                )
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("🎶Refiner")
                        refine_prompt_input = gr.Textbox(
                            label="Refine Prompt",
                            value="[oral_2][laugh_0][break_6]",
                        )
                        refine_button = gr.Button("✍️Refine Text")
                        # TODO 分割句子,使用当前配置拼接为SSML,然后发送到SSML tab
                        # send_button = gr.Button("📩Split and send to SSML")

                    with gr.Group():
                        gr.Markdown("🔊Generate")
                        disable_normalize_input = gr.Checkbox(
                            value=False, label="Disable Normalize"
                        )
                        tts_button = gr.Button(
                            "🔊Generate Audio",
                            variant="primary",
                            elem_classes="big-button",
                        )

            with gr.Group():
                gr.Markdown("🎄Examples")
                sample_dropdown = gr.Dropdown(
                    choices=[sample["text"] for sample in sample_texts],
                    show_label=False,
                    value=None,
                    interactive=True,
                )
                sample_dropdown.change(
                    fn=lambda x: x,
                    inputs=[sample_dropdown],
                    outputs=[text_input],
                )

            with gr.Group():
                gr.Markdown("🎨Output")
                tts_output = gr.Audio(label="Generated Audio")

    refine_button.click(
        refine_text,
        inputs=[text_input, refine_prompt_input],
        outputs=[text_input],
    )

    tts_button.click(
        tts_generate,
        inputs=[
            text_input,
            temperature_input,
            top_p_input,
            top_k_input,
            spk_input_text,
            infer_seed_input,
            use_decoder_input,
            prompt1_input,
            prompt2_input,
            prefix_input,
            style_input_dropdown,
            disable_normalize_input,
            batch_size_input,
        ],
        outputs=tts_output,
    )


def create_ssml_interface():
    examples = [
        ssml_example1,
        ssml_example2,
        ssml_example3,
        ssml_example4,
    ]

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝SSML Input")
                gr.Markdown(f"- 最长{webui_config['ssml_max']:,}字符,超过会被截断")
                gr.Markdown("- 尽量保证使用相同的 seed")
                gr.Markdown(
                    "- 关于SSML可以看这个 [文档](https://github.com/lenML/ChatTTS-Forge/blob/main/docs/SSML.md)"
                )
                ssml_input = gr.Textbox(
                    label="SSML Input",
                    lines=10,
                    value=default_ssml,
                    placeholder="输入 SSML 或选择示例",
                    elem_id="ssml_input",
                    show_label=False,
                )
                ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
        with gr.Column(scale=1):
            with gr.Group():
                # 参数
                gr.Markdown("🎛️Parameters")
                # batch size
                batch_size_input = gr.Slider(
                    label="Batch Size",
                    value=4,
                    minimum=1,
                    maximum=webui_config["max_batch_size"],
                    step=1,
                )
            with gr.Group():
                gr.Markdown("🎄Examples")
                gr.Examples(
                    examples=examples,
                    inputs=[ssml_input],
                )

    ssml_output = gr.Audio(label="Generated Audio")

    ssml_button.click(
        synthesize_ssml,
        inputs=[ssml_input, batch_size_input],
        outputs=ssml_output,
    )

    return ssml_input


# NOTE: 这个其实是需要GPU的...但是spaces会自动卸载,所以不太好使,具体处理在text_normalize中兼容
# @spaces.GPU
def split_long_text(long_text_input):
    spliter = SentenceSplitter(webui_config["spliter_threshold"])
    sentences = spliter.parse(long_text_input)
    sentences = [text_normalize(s) for s in sentences]
    data = []
    for i, text in enumerate(sentences):
        data.append([i, text, len(text)])
    return data


def merge_dataframe_to_ssml(dataframe, spk, style, seed):
    if style == "*auto":
        style = None
    if spk == "-1" or spk == -1:
        spk = None
    if seed == -1 or seed == "-1":
        seed = None

    ssml = ""
    indent = " " * 2

    for i, row in dataframe.iterrows():
        ssml += f"{indent}<voice"
        if spk:
            ssml += f' spk="{spk}"'
        if style:
            ssml += f' style="{style}"'
        if seed:
            ssml += f' seed="{seed}"'
        ssml += ">\n"
        ssml += f"{indent}{indent}{text_normalize(row[1])}\n"
        ssml += f"{indent}</voice>\n"
    return f"<speak version='0.1'>\n{ssml}</speak>"


# 长文本处理
# 可以输入长文本,并选择切割方法,切割之后可以将拼接的SSML发送到SSML tab
# 根据 。 句号切割,切割之后显示到 data table
def create_long_content_tab(ssml_input, tabs):
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]

    styles = ["*auto"] + [s.get("name") for s in get_styles()]

    with gr.Row():
        with gr.Column(scale=1):
            # 选择说话人 选择风格 选择seed
            with gr.Group():
                gr.Markdown("🗣️Speaker")
                spk_input_text = gr.Textbox(
                    label="Speaker (Text or Seed)",
                    value="female2",
                    show_label=False,
                )
                spk_input_dropdown = gr.Dropdown(
                    choices=speaker_names,
                    interactive=True,
                    value="female : female2",
                    show_label=False,
                )
                spk_rand_button = gr.Button(
                    value="🎲",
                    variant="secondary",
                )
            with gr.Group():
                gr.Markdown("🎭Style")
                style_input_dropdown = gr.Dropdown(
                    choices=styles,
                    interactive=True,
                    show_label=False,
                    value="*auto",
                )
            with gr.Group():
                gr.Markdown("🗣️Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(
                    value="🎲",
                    variant="secondary",
                )

            send_btn = gr.Button("📩Send to SSML", variant="primary")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝Long Text Input")
                gr.Markdown("- 此页面用于处理超长文本")
                gr.Markdown("- 切割后,可以选择说话人、风格、seed,然后发送到SSML")
                long_text_input = gr.Textbox(
                    label="Long Text Input",
                    lines=10,
                    placeholder="输入长文本",
                    elem_id="long-text-input",
                    show_label=False,
                )
                long_text_split_button = gr.Button("🔪Split Text")

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("🎨Output")
                long_text_output = gr.DataFrame(
                    headers=["index", "text", "length"],
                    datatype=["number", "str", "number"],
                    elem_id="long-text-output",
                    interactive=False,
                    wrap=True,
                    value=[],
                )

    spk_input_dropdown.change(
        fn=lambda x: x.startswith("*") and "-1" or x.split(":")[-1].strip(),
        inputs=[spk_input_dropdown],
        outputs=[spk_input_text],
    )
    spk_rand_button.click(
        lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
        inputs=[spk_input_text],
        outputs=[spk_input_text],
    )
    infer_seed_rand_button.click(
        lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
        inputs=[infer_seed_input],
        outputs=[infer_seed_input],
    )
    long_text_split_button.click(
        split_long_text,
        inputs=[long_text_input],
        outputs=[long_text_output],
    )

    infer_seed_rand_button.click(
        lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
        inputs=[infer_seed_input],
        outputs=[infer_seed_input],
    )

    send_btn.click(
        merge_dataframe_to_ssml,
        inputs=[
            long_text_output,
            spk_input_text,
            style_input_dropdown,
            infer_seed_input,
        ],
        outputs=[ssml_input],
    )

    def change_tab():
        return gr.Tabs(selected="ssml")

    send_btn.click(change_tab, inputs=[], outputs=[tabs])


def create_readme_tab():
    readme_content = read_local_readme()
    gr.Markdown(readme_content)


def create_app_footer():
    gradio_version = gr.__version__
    git_tag = config.versions.git_tag
    git_commit = config.versions.git_commit
    git_branch = config.versions.git_branch
    python_version = config.versions.python_version
    torch_version = config.versions.torch_version

    config.versions.gradio_version = gradio_version

    gr.Markdown(
        f"""
🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit}) | branch: `{git_branch}` | python: `{python_version}` | torch: `{torch_version}`
        """
    )


def create_interface():

    js_func = """
    function refresh() {
        const url = new URL(window.location);

        if (url.searchParams.get('__theme') !== 'dark') {
            url.searchParams.set('__theme', 'dark');
            window.location.href = url.href;
        }
    }
    """

    head_js = """
    <script>
    </script>
    """

    with gr.Blocks(js=js_func, head=head_js, title="ChatTTS Forge WebUI") as demo:
        css = """
        <style>
        .big-button {
            height: 80px;
        }
        #input_title div.eta-bar {
            display: none !important; transform: none !important;
        }
        footer {
            display: none !important;
        }
        </style>
        """

        gr.HTML(css)
        with gr.Tabs() as tabs:
            with gr.TabItem("TTS"):
                create_tts_interface()

            with gr.TabItem("SSML", id="ssml"):
                ssml_input = create_ssml_interface()

            with gr.TabItem("Long Text"):
                create_long_content_tab(ssml_input, tabs=tabs)

            with gr.TabItem("README"):
                create_readme_tab()

        create_app_footer()
    return demo


if __name__ == "__main__":
    import argparse
    import dotenv

    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )

    parser = argparse.ArgumentParser(description="Gradio App")
    parser.add_argument("--server_name", type=str, help="server name")
    parser.add_argument("--server_port", type=int, help="server port")
    parser.add_argument(
        "--share", action="store_true", help="share the gradio interface"
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--auth", type=str, help="username:password for authentication")
    parser.add_argument(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--tts_max_len",
        type=int,
        help="Max length of text for TTS",
    )
    parser.add_argument(
        "--ssml_max_len",
        type=int,
        help="Max length of text for SSML",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        help="Max batch size for TTS",
    )
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        type=str.lower,
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")

    args = parser.parse_args()

    def get_and_update_env(*args):
        val = env.get_env_or_arg(*args)
        key = args[1]
        config.runtime_env_vars[key] = val
        return val

    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    half = get_and_update_env(args, "half", False, bool)
    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
    lru_size = get_and_update_env(args, "lru_size", 64, int)
    device_id = get_and_update_env(args, "device_id", None, str)
    use_cpu = get_and_update_env(args, "use_cpu", [], list)
    compile = get_and_update_env(args, "compile", False, bool)

    webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)

    demo = create_interface()

    if auth:
        auth = tuple(auth.split(":"))

    generate.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
    )