from typing import Union

import gradio as gr
import numpy as np
import torch
import torch.profiler

from modules import refiner
from modules.api.impl.handler.SSMLHandler import SSMLHandler
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.api.utils import calc_spk_style
from modules.data import styles_mgr
from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
from modules.normalization import text_normalize
from modules.SentenceSplitter import SentenceSplitter
from modules.speaker import Speaker, speaker_mgr
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLSegment, create_ssml_parser
from modules.utils import audio
from modules.utils.hf import spaces
from modules.webui import webui_config


def get_speakers():
    return speaker_mgr.list_speakers()


def get_speaker_names() -> tuple[list[Speaker], list[str]]:
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    speaker_names = [get_speaker_show_name(speaker) for speaker in speakers]
    speaker_names.sort(key=lambda x: x.startswith("*") and "-1" or x)

    return speakers, speaker_names


def get_styles():
    return styles_mgr.list_items()


def load_spk_info(file):
    if file is None:
        return "empty"
    try:

        spk: Speaker = Speaker.from_file(file)
        infos = spk.to_json()
        return f"""
- name: {infos.name}
- gender: {infos.gender}
- describe: {infos.describe}
    """.strip()
    except:
        return "load failed"


def segments_length_limit(
    segments: list[Union[SSMLBreak, SSMLSegment]], total_max: int
) -> list[Union[SSMLBreak, SSMLSegment]]:
    ret_segments = []
    total_len = 0
    for seg in segments:
        if isinstance(seg, SSMLBreak):
            ret_segments.append(seg)
            continue
        total_len += len(seg["text"])
        if total_len > total_max:
            break
        ret_segments.append(seg)
    return ret_segments


@torch.inference_mode()
@spaces.GPU(duration=120)
def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
    return _apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance)


@torch.inference_mode()
@spaces.GPU(duration=120)
def synthesize_ssml(
    ssml: str,
    batch_size=4,
    enable_enhance=False,
    enable_denoise=False,
    eos: str = "[uv_break]",
    spliter_thr: int = 100,
    pitch: float = 0,
    speed_rate: float = 1,
    volume_gain_db: float = 0,
    normalize: bool = True,
    headroom: float = 1,
    progress=gr.Progress(track_tqdm=True),
):
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 8

    ssml = ssml.strip()

    if ssml == "":
        raise gr.Error("SSML is empty, please input some SSML")

    parser = create_ssml_parser()
    segments = parser.parse(ssml)
    max_len = webui_config.ssml_max
    segments = segments_length_limit(segments, max_len)

    if len(segments) == 0:
        raise gr.Error("No valid segments in SSML")

    infer_config = InferConfig(
        batch_size=batch_size,
        spliter_threshold=spliter_thr,
        eos=eos,
        # NOTE: SSML not support `infer_seed` contorl
        # seed=42,
    )
    adjust_config = AdjustConfig(
        pitch=pitch,
        speed_rate=speed_rate,
        volume_gain_db=volume_gain_db,
        normalize=normalize,
        headroom=headroom,
    )
    enhancer_config = EnhancerConfig(
        enabled=enable_denoise or enable_enhance or False,
        lambd=0.9 if enable_denoise else 0.1,
    )

    handler = SSMLHandler(
        ssml_content=ssml,
        infer_config=infer_config,
        adjust_config=adjust_config,
        enhancer_config=enhancer_config,
    )

    audio_data, sr = handler.enqueue()

    # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
    audio_data = audio.audio_to_int16(audio_data)

    return sr, audio_data


# @torch.inference_mode()
@spaces.GPU(duration=120)
def tts_generate(
    text,
    temperature=0.3,
    top_p=0.7,
    top_k=20,
    spk=-1,
    infer_seed=-1,
    use_decoder=True,
    prompt1="",
    prompt2="",
    prefix="",
    style="",
    disable_normalize=False,
    batch_size=4,
    enable_enhance=False,
    enable_denoise=False,
    spk_file=None,
    spliter_thr: int = 100,
    eos: str = "[uv_break]",
    pitch: float = 0,
    speed_rate: float = 1,
    volume_gain_db: float = 0,
    normalize: bool = True,
    headroom: float = 1,
    progress=gr.Progress(track_tqdm=True),
):
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 4

    max_len = webui_config.tts_max
    text = text.strip()[0:max_len]

    if text == "":
        raise gr.Error("Text is empty, please input some text")

    if style == "*auto":
        style = ""

    if isinstance(top_k, float):
        top_k = int(top_k)

    params = calc_spk_style(spk=spk, style=style)
    spk = params.get("spk", spk)

    infer_seed = infer_seed or params.get("seed", infer_seed)
    temperature = temperature or params.get("temperature", temperature)
    prefix = prefix or params.get("prefix", prefix)
    prompt1 = prompt1 or params.get("prompt1", "")
    prompt2 = prompt2 or params.get("prompt2", "")

    infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
    infer_seed = int(infer_seed)

    if isinstance(spk, int):
        spk = Speaker.from_seed(spk)

    if spk_file:
        try:
            spk: Speaker = Speaker.from_file(spk_file)
        except Exception:
            raise gr.Error("Failed to load speaker file")

        if not isinstance(spk.emb, torch.Tensor):
            raise gr.Error("Speaker file is not supported")

    tts_config = ChatTTSConfig(
        style=style,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        prefix=prefix,
        prompt1=prompt1,
        prompt2=prompt2,
    )
    infer_config = InferConfig(
        batch_size=batch_size,
        spliter_threshold=spliter_thr,
        eos=eos,
        seed=infer_seed,
    )
    adjust_config = AdjustConfig(
        pitch=pitch,
        speed_rate=speed_rate,
        volume_gain_db=volume_gain_db,
        normalize=normalize,
        headroom=headroom,
    )
    enhancer_config = EnhancerConfig(
        enabled=enable_denoise or enable_enhance or False,
        lambd=0.9 if enable_denoise else 0.1,
    )

    handler = TTSHandler(
        text_content=text,
        spk=spk,
        tts_config=tts_config,
        infer_config=infer_config,
        adjust_config=adjust_config,
        enhancer_config=enhancer_config,
    )

    audio_data, sample_rate = handler.enqueue()

    # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
    audio_data = audio.audio_to_int16(audio_data)
    return sample_rate, audio_data


@torch.inference_mode()
@spaces.GPU(duration=120)
def refine_text(
    text: str,
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
):
    text = text_normalize(text)
    return refiner.refine_text(text, prompt=prompt)


@torch.inference_mode()
@spaces.GPU(duration=120)
def split_long_text(long_text_input, spliter_threshold=100, eos=""):
    spliter = SentenceSplitter(threshold=spliter_threshold)
    sentences = spliter.parse(long_text_input)
    sentences = [text_normalize(s) + eos for s in sentences]
    data = []
    for i, text in enumerate(sentences):
        token_length = spliter.count_tokens(text)
        data.append([i, text, token_length])
    return data