from typing import Union

from fastapi import HTTPException
from pydantic import BaseModel

from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.api.impl.handler.SSMLHandler import SSMLHandler
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import Speaker, speaker_mgr


class SynthesisInput(BaseModel):
    """Content to synthesize: either plain ``text`` or an ``ssml`` document."""

    # Plain text to synthesize; takes precedence over ``ssml`` when non-empty.
    text: Union[str, None] = None
    # SSML markup; used only when ``text`` is empty or missing.
    ssml: Union[str, None] = None


class VoiceSelectionParams(BaseModel):
    """Voice selection and ChatTTS sampling parameters for a synthesize request."""

    # Accepted for Google API compatibility; the synthesize endpoint in this
    # module does not currently use it.
    languageCode: str = "ZH-CN"

    # Speaker name; must be a speaker registered with ``speaker_mgr``.
    name: str = "female2"
    # Style name, resolved together with ``name`` via ``calc_spk_style``.
    style: str = ""
    # ChatTTS sampling parameters (forwarded to ``ChatTTSConfig``).
    temperature: float = 0.3
    topP: float = 0.7
    topK: int = 20
    # Inference seed (forwarded to ``InferConfig.seed``).
    seed: int = 42

    # End-of-sentence marker (forwarded to ``InferConfig.eos``).
    eos: str = "[uv_break]"


class AudioConfig(BaseModel):
    """Output-audio and chunking settings for a synthesize request."""

    # Output encoding; also determines the MIME type of the returned data URI.
    audioEncoding: AudioFormat = AudioFormat.mp3
    # Post-processing adjustments (forwarded to ``AdjustConfig``).
    speakingRate: float = 1
    pitch: float = 0
    volumeGainDb: float = 0
    # NOTE(review): accepted for Google API compatibility, but it appears not
    # to be applied to the generated audio — confirm before relying on it.
    sampleRateHertz: int = 24000
    # Inference batching / text-splitting settings (forwarded to ``InferConfig``).
    batchSize: int = 4
    spliterThreshold: int = 100


class GoogleTextSynthesizeRequest(BaseModel):
    """Request body for ``POST /v1/text:synthesize``."""

    input: SynthesisInput
    voice: VoiceSelectionParams
    audioConfig: AudioConfig
    # Optional enhancer settings, passed through to the synthesis handler
    # as-is. Declared nullable so an explicit JSON ``null`` validates and the
    # generated schema marks the field as optional.
    enhancerConfig: Union[EnhancerConfig, None] = None


class GoogleTextSynthesizeResponse(BaseModel):
    """Response body for ``POST /v1/text:synthesize``."""

    # Synthesized audio as a ``data:<mime>;base64,<payload>`` URI string.
    audioContent: str


async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
    """Synthesize speech for a Google-TTS-compatible request.

    Plain ``input.text`` is rendered with the requested speaker and ChatTTS
    sampling parameters; otherwise ``input.ssml`` is rendered by the SSML
    handler. The audio is returned base64-encoded inside a ``data:`` URI.

    Args:
        request: Parsed request body (input, voice, audio and enhancer config).

    Returns:
        ``{"audioContent": "data:<mime>;base64,<payload>"}``.

    Raises:
        HTTPException: 422 if ``voice.name`` is not a registered speaker or if
            neither text nor SSML is provided; 500 (chained to the original
            error) if synthesis fails. HTTPExceptions raised by the handlers
            are propagated unchanged.
    """
    import logging

    synthesis_input = request.input
    voice = request.voice
    audio_config = request.audioConfig

    # TODO: voice.languageCode is accepted but ignored; it should perhaps be
    # passed to the normalizer.

    # Use an explicit None check rather than `or` so a requested seed of 0 is
    # honoured instead of being silently replaced by 42.
    infer_seed = voice.seed if voice.seed is not None else 42
    eos = voice.eos or "[uv_break]"

    audio_format = audio_config.audioEncoding
    if isinstance(audio_format, str) and not isinstance(audio_format, AudioFormat):
        audio_format = AudioFormat(audio_format)

    # TODO: audio_config.sampleRateHertz is accepted for API compatibility but
    # is not applied to the output yet.

    # Although calc_spk_style can also resolve seed-style speaker names, this
    # endpoint only supports speakers present in the speaker list, so reject
    # unknown names before resolving anything.
    if speaker_mgr.get_speaker(voice.name) is None:
        raise HTTPException(
            status_code=422, detail="The specified voice name is not supported."
        )

    params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
    speaker = params.get("spk")
    if not isinstance(speaker, Speaker):
        raise HTTPException(
            status_code=422, detail="The specified voice name is not supported."
        )

    tts_config = ChatTTSConfig(
        style=params.get("style", ""),
        temperature=voice.temperature,
        top_k=voice.topK,
        top_p=voice.topP,
    )
    # Zero values for these settings fall back to usable defaults.
    infer_config = InferConfig(
        batch_size=audio_config.batchSize or 1,
        spliter_threshold=audio_config.spliterThreshold or 100,
        eos=eos,
        seed=infer_seed,
    )
    adjust_config = AdjustConfig(
        speaking_rate=audio_config.speakingRate or 1,
        pitch=audio_config.pitch,
        volume_gain_db=audio_config.volumeGainDb,
    )
    enhancer_config = request.enhancerConfig

    # Validate outside the try block so this client error is not logged as a
    # server failure with a traceback.
    if not synthesis_input.text and not synthesis_input.ssml:
        raise HTTPException(
            status_code=422, detail="Invalid input text or ssml specified."
        )

    # MP3's registered MIME type is audio/mpeg rather than audio/mp3.
    if audio_format == AudioFormat.mp3:
        mime_type = "audio/mpeg"
    else:
        mime_type = f"audio/{audio_format.value}"

    try:
        if synthesis_input.text:
            handler = TTSHandler(
                text_content=synthesis_input.text,
                spk=speaker,
                tts_config=tts_config,
                infer_config=infer_config,
                adjust_config=adjust_config,
                enhancer_config=enhancer_config,
            )
        else:
            # The SSML handler is not given the speaker or ChatTTS sampling
            # config here; presumably the SSML markup supplies them.
            handler = SSMLHandler(
                ssml_content=synthesis_input.ssml,
                infer_config=infer_config,
                adjust_config=adjust_config,
                enhancer_config=enhancer_config,
            )

        # NOTE(review): this synchronous call runs inside an `async def`
        # endpoint and so appears to block the event loop for the whole
        # synthesis — confirm whether offloading it to a thread pool is safe
        # for the underlying model before changing it.
        base64_string = handler.enqueue_to_base64(format=audio_format)
    except HTTPException:
        raise
    except Exception as e:
        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"audioContent": f"data:{mime_type};base64,{base64_string}"}


def setup(app: APIManager):
    """Register the Google-compatible text synthesize route on *app*.

    The route is exposed as ``POST /v1/text:synthesize`` and is served by
    ``google_text_synthesize``, with ``GoogleTextSynthesizeResponse`` as its
    response model.
    """
    # Markdown shown in the generated API docs (runtime text, keep verbatim).
    # English gloss of the Chinese notes: many fields exist only for Google
    # API compatibility; topP/topK/temperature are this system's parameters;
    # voice.name is the speaker name (or speaker seed); voice.seed is the
    # inference seed; the encoding only changes audioContent's binary format,
    # so every format returns JSON carrying base64 data.
    route_description = """
google api document: <br/>
[https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize](https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize)

- 多个属性在本系统中无用仅仅是为了兼容google api
- voice 中的 topP, topK, temperature 为本系统中的参数
- voice.name 即 speaker name (或者speaker seed)
- voice.seed 为 infer seed (可在webui中测试具体作用)

- 编码格式影响的是 audioContent 的二进制格式,所以所有format都是返回带有base64数据的json
        """
    register_route = app.post(
        "/v1/text:synthesize",
        response_model=GoogleTextSynthesizeResponse,
        description=route_description,
    )
    register_route(google_text_synthesize)