import os
import logging

# logging.basicConfig(
#     level=os.getenv("LOG_LEVEL", "INFO"),
#     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# )

from modules.devices import devices
from modules.utils import env
from modules.webui import webui_config
from modules.webui.app import webui_init, create_interface
from modules import generate_audio
from modules import config

if __name__ == "__main__":
    # Script entry point: load the env file, parse the CLI flags, resolve every
    # setting through env.get_env_or_arg, then build and launch the Gradio app.
    import argparse
    import dotenv

    # Load variables from the file named by ENV_FILE (".env.webui" when unset)
    # before any setting is resolved below.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )

    parser = argparse.ArgumentParser(description="Gradio App")
    parser.add_argument("--server_name", type=str, help="server name")
    parser.add_argument("--server_port", type=int, help="server port")
    parser.add_argument(
        "--share", action="store_true", help="share the gradio interface"
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--auth", type=str, help="username:password for authentication")
    parser.add_argument(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--tts_max_len",
        type=int,
        help="Max length of text for TTS",
    )
    parser.add_argument(
        "--ssml_max_len",
        type=int,
        help="Max length of text for SSML",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        help="Max batch size for TTS",
    )
    # NOTE(review): the argparse default of 64 duplicates the fallback passed to
    # get_and_update_env below. If env.get_env_or_arg prefers any non-None CLI
    # value, an environment-provided value could never take effect -- confirm
    # against the helper before changing.
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        type=str.lower,
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    # webui_Experimental
    parser.add_argument(
        "--webui_experimental",
        action="store_true",
        help="Enable webui_experimental features",
    )

    parser.add_argument(
        "--language",
        type=str,
        help="Set the default language for the webui",
    )
    args = parser.parse_args()

    def get_and_update_env(cli_args, key, default, value_type):
        """Resolve one setting and record it in ``config.runtime_env_vars``.

        Args:
            cli_args: Parsed argparse namespace, forwarded to
                ``env.get_env_or_arg``.
            key: Setting name; also the key written to
                ``config.runtime_env_vars``.
            default: Fallback value forwarded to ``env.get_env_or_arg``.
            value_type: Type argument forwarded to ``env.get_env_or_arg``.

        Returns:
            The value returned by ``env.get_env_or_arg``.
        """
        val = env.get_env_or_arg(cli_args, key, default, value_type)
        config.runtime_env_vars[key] = val
        return val

    # Settings consumed directly by this script when launching the server.
    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)

    # These values are not read again in this script; the calls are kept, in
    # their original order, because each one records its result in
    # config.runtime_env_vars. Leaving them unbound also avoids shadowing the
    # builtin ``compile``.
    get_and_update_env(args, "half", False, bool)
    get_and_update_env(args, "off_tqdm", False, bool)
    get_and_update_env(args, "lru_size", 64, int)
    get_and_update_env(args, "device_id", None, str)
    get_and_update_env(args, "use_cpu", [], list)
    get_and_update_env(args, "compile", False, bool)
    get_and_update_env(args, "language", "zh-CN", str)

    # Limits and feature flags stored on the shared webui_config module.
    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    if auth:
        # Split on the FIRST colon only so a password that itself contains ":"
        # stays intact, and reject a value without a colon up front instead of
        # handing a malformed tuple to launch().
        username, separator, password = auth.partition(":")
        if not separator:
            parser.error("auth must be given as username:password")
        auth = (username, password)

    webui_init()
    demo = create_interface()

    # Project-level setup calls, kept in their original order ahead of launch.
    generate_audio.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
    )