import logging
from typing import Optional

from fastapi import HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from modules.api.Api import APIManager
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import speaker_mgr

logger = logging.getLogger(__name__)


class XTTS_V2_Settings:
    """Mutable holder for the synthesis settings used by the xtts_v2 endpoints.

    Every instance starts with the defaults below; callers read and overwrite
    the attributes directly.
    """

    def __init__(self):
        # Options that mirror the xtts_v2 settings payload.
        xtts_options = {
            "stream_chunk_size": 100,
            "temperature": 0.3,
            "speed": 1,
            # TODO: these two are not used at the moment, but the GPT stage
            # could take them, so consider wiring them in.
            "length_penalty": 0.5,
            "repetition_penalty": 1.0,
            "top_p": 0.7,
            "top_k": 20,
            "enable_text_splitting": True,
        }
        # Extra options that xtts_v2 does not have but this system needs.
        extra_options = {
            "batch_size": 4,
            "eos": "[uv_break]",
            "infer_seed": 42,
            "use_decoder": True,
            "prompt1": "",
            "prompt2": "",
            "prefix": "",
            "spliter_threshold": 100,
            "style": "",
        }

        # Assign in declaration order so the instance ends up with the same
        # attribute layout as plain `self.x = ...` statements would give.
        for option_group in (xtts_options, extra_options):
            for attr_name, default_value in option_group.items():
                setattr(self, attr_name, default_value)


class TTSSettingsRequest(BaseModel):
    """Request body for ``POST /v1/xtts_v2/set_tts_settings``.

    The first group of fields is required. The second group is optional;
    ``None`` (the default) means the field was not provided.
    """

    # This stream chunk size currently doubles as spliter_threshold.
    stream_chunk_size: int
    temperature: float
    speed: float
    length_penalty: float
    repetition_penalty: float
    top_p: float
    top_k: int
    enable_text_splitting: bool

    # Declared Optional so the annotation matches the None default and an
    # explicit JSON null is accepted as well as an omitted field.
    batch_size: Optional[int] = None
    eos: Optional[str] = None
    infer_seed: Optional[int] = None
    use_decoder: Optional[bool] = None
    prompt1: Optional[str] = None
    prompt2: Optional[str] = None
    prefix: Optional[str] = None
    spliter_threshold: Optional[int] = None
    style: Optional[str] = None


class SynthesisRequest(BaseModel):
    """Request body for ``POST /v1/xtts_v2/tts_to_audio``."""

    # Text to synthesize.
    text: str
    # Despite the name, this carries the speaker id; it is used to look the
    # speaker up in speaker_mgr.
    speaker_wav: str
    # Required by the request schema; the tts_to_audio handler reads it but
    # does not pass it on to synthesis.
    language: str


def setup(app: APIManager):
    """Register the xtts_v2-compatible endpoints on ``app``.

    Routes added:
      * ``GET  /v1/xtts_v2/speakers``
      * ``POST /v1/xtts_v2/tts_to_audio``
      * ``GET  /v1/xtts_v2/tts_stream``
      * ``POST /v1/xtts_v2/set_tts_settings``

    All routes share the single ``XTTS_V2_Settings`` instance created here, so
    values stored through ``set_tts_settings`` are used by every synthesis
    request handled afterwards.
    """
    XTTSV2 = XTTS_V2_Settings()

    def build_handler(text: str, voice_id: str) -> TTSHandler:
        """Create a ``TTSHandler`` for ``text`` from the current settings.

        ``voice_id`` is first looked up with ``get_speaker_by_id`` and then,
        as a fallback, with ``get_speaker``.

        Raises:
            HTTPException: 400 when neither lookup returns a speaker.
        """
        spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
            voice_id
        )
        if spk is None:
            raise HTTPException(status_code=400, detail="Invalid speaker id")

        tts_config = ChatTTSConfig(
            style=XTTSV2.style,
            temperature=XTTSV2.temperature,
            top_k=XTTSV2.top_k,
            top_p=XTTSV2.top_p,
            prefix=XTTSV2.prefix,
            prompt1=XTTSV2.prompt1,
            prompt2=XTTSV2.prompt2,
        )
        infer_config = InferConfig(
            batch_size=XTTSV2.batch_size,
            spliter_threshold=XTTSV2.spliter_threshold,
            eos=XTTSV2.eos,
            seed=XTTSV2.infer_seed,
        )
        adjust_config = AdjustConfig(
            speed_rate=XTTSV2.speed,
        )
        # TODO: support enhancer; for now the config is built with its defaults.
        enhancer_config = EnhancerConfig()

        return TTSHandler(
            text_content=text,
            spk=spk,
            tts_config=tts_config,
            infer_config=infer_config,
            adjust_config=adjust_config,
            enhancer_config=enhancer_config,
        )

    @app.get("/v1/xtts_v2/speakers")
    async def speakers():
        """Return every known speaker as ``name`` / ``voice_id`` / ``preview_url``."""
        return [
            {
                "name": spk.name,
                "voice_id": spk.id,
                # TODO: maybe put a "/v1/tts" endpoint URL here
                "preview_url": "",
            }
            for spk in speaker_mgr.list_speakers()
        ]

    @app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
    async def tts_to_audio(request: SynthesisRequest):
        """Synthesize ``request.text`` and return it as an ``audio/mpeg`` response.

        ``request.speaker_wav`` carries the speaker id. ``request.language`` is
        part of the request schema but is not used here.
        """
        handler = build_handler(request.text, request.speaker_wav)
        buffer = handler.enqueue_to_buffer(AudioFormat.mp3)
        return StreamingResponse(buffer, media_type="audio/mpeg")

    @app.get("/v1/xtts_v2/tts_stream")
    async def tts_stream(
        request: Request,
        text: str = Query(),
        speaker_wav: str = Query(),
        language: str = Query(),
    ):
        """Stream the synthesized MP3 for ``text`` chunk by chunk.

        ``speaker_wav`` carries the speaker id. ``language`` is a required
        query parameter but is not used here.
        """
        handler = build_handler(text, speaker_wav)

        async def generator():
            # NOTE(review): the stream is consumed with a plain `for`; if
            # producing a chunk blocks, the event loop blocks too - confirm.
            for chunk in handler.enqueue_to_stream(AudioFormat.mp3):
                # Stop producing audio once the client has gone away.
                if await request.is_disconnected():
                    break

                yield chunk

        return StreamingResponse(generator(), media_type="audio/mpeg")

    @app.post("/v1/xtts_v2/set_tts_settings")
    async def set_tts_settings(request: TTSSettingsRequest):
        """Validate ``request`` and store its values in the shared settings.

        Raises:
            HTTPException: 400 when a value is out of range, 500 on any
                unexpected error.
        """
        try:
            if request.stream_chunk_size < 50:
                raise HTTPException(
                    status_code=400, detail="stream_chunk_size must be at least 50"
                )
            # Negative values are rejected for these fields, checked in order.
            for field in (
                "temperature",
                "speed",
                "length_penalty",
                "repetition_penalty",
                "top_p",
                "top_k",
            ):
                if getattr(request, field) < 0:
                    raise HTTPException(
                        status_code=400, detail=f"{field} must be greater than 0"
                    )

            # stream_chunk_size doubles as the splitter threshold unless an
            # explicit spliter_threshold is supplied further down.
            XTTSV2.stream_chunk_size = request.stream_chunk_size
            XTTSV2.spliter_threshold = request.stream_chunk_size

            XTTSV2.temperature = request.temperature
            XTTSV2.speed = request.speed
            XTTSV2.length_penalty = request.length_penalty
            XTTSV2.repetition_penalty = request.repetition_penalty
            XTTSV2.top_p = request.top_p
            XTTSV2.top_k = request.top_k
            XTTSV2.enable_text_splitting = request.enable_text_splitting

            # TODO: checker
            # For the size-like fields a falsy value (None or 0) is treated as
            # "not provided" and the stored value is kept.
            if request.batch_size:
                XTTSV2.batch_size = request.batch_size
            if request.spliter_threshold:
                XTTSV2.spliter_threshold = request.spliter_threshold

            # For these fields False, 0 and "" are meaningful values, so only
            # None counts as "not provided".
            for field in (
                "eos",
                "infer_seed",
                "use_decoder",
                "prompt1",
                "prompt2",
                "prefix",
                "style",
            ):
                value = getattr(request, field)
                if value is not None:
                    setattr(XTTSV2, field, value)

            return {"message": "Settings successfully applied"}
        except HTTPException:
            # Validation errors go to the client unchanged.
            raise
        except Exception as e:
            logger.exception("Failed to apply xtts_v2 settings")
            raise HTTPException(status_code=500, detail=str(e)) from e