import logging
from time import perf_counter

from baseHandler import BaseHandler
from funasr import AutoModel
import numpy as np
from rich.console import Console
import torch

# Set the root logger's message format at import time (stdlib behavior:
# this is a no-op if the root logger already has handlers).
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

logger = logging.getLogger(__name__)  # module-scoped logger

console = Console()  # rich console used below to echo the recognized user text


class ParaformerSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Paraformer model.
    The default for this model is set to Chinese.
    This model was contributed by @wuhongsheng.
    """

    def setup(
        self,
        model_name="paraformer-zh",
        device="cuda",
        gen_kwargs=None,
    ):
        """Load the FunASR model on ``device`` and run a warmup pass.

        Args:
            model_name: FunASR model identifier. If it contains ``/`` (e.g. a
                hub-style ``"org/name"``), only the last component is passed
                to ``AutoModel``.
            device: Device string forwarded to ``AutoModel`` and later used to
                decide whether to release the MPS cache after inference.
            gen_kwargs: Accepted for interface compatibility; not used by this
                handler. Defaults to ``None`` instead of a shared mutable ``{}``.
        """
        logger.info("Paraformer model requested: %s", model_name)
        # Keep only the part after the last "/"; identical to the original
        # split-based logic, including names with no "/".
        model_name = model_name.rsplit("/", 1)[-1]
        self.device = device
        self.model = AutoModel(model=model_name, device=device)
        self.warmup()

    def _transcribe(self, audio):
        """Run the model on ``audio`` and return the first result's text.

        Leading/trailing whitespace is stripped and all remaining spaces are
        removed from the decoded text.
        """
        return self.model.generate(audio)[0]["text"].strip().replace(" ", "")

    def warmup(self):
        """Run one dummy inference on 512 samples of silence before real use."""
        logger.info("Warming up %s", self.__class__.__name__)

        # A single pass is enough: this handler does not compile the model
        # or capture CUDA graphs, so there is no second warmup phase.
        n_steps = 1
        dummy_input = np.zeros(512, dtype=np.float32)
        for _ in range(n_steps):
            self._transcribe(dummy_input)

    def process(self, spoken_prompt):
        """Transcribe one utterance and yield the recognized text.

        Args:
            spoken_prompt: Audio passed directly to ``AutoModel.generate``
                (presumably a float32 waveform array — confirm against the
                upstream stage that feeds this handler).

        Yields:
            The transcription with all spaces removed.
        """
        logger.debug("infering paraformer...")

        # Module-level timestamp marking when this utterance entered STT;
        # presumably consumed elsewhere for latency reporting — confirm.
        global pipeline_start
        pipeline_start = perf_counter()

        pred_text = self._transcribe(spoken_prompt)

        # torch.mps.empty_cache() is only meaningful (and only safe) on the
        # MPS backend; calling it unconditionally breaks CUDA/CPU runs.
        if str(self.device).startswith("mps"):
            torch.mps.empty_cache()

        logger.debug("finished paraformer inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text