from io import BytesIO

import soundfile as sf
import torch

from .config import model_mms_tts_eng, tokenizer_mms_tts_eng

# Sample rate (Hz) used when encoding synthesized audio to WAV.
SAMPLING_RATE: int = 16_000

class T2A:
    """Synthesize speech from text with the MMS-TTS English model.

    The text is tokenized once at construction time. Calling the instance
    runs the model and returns the generated audio encoded as WAV bytes.
    """

    def __init__(self, input_text: str):
        """Tokenize *input_text* for later synthesis.

        Args:
            input_text: Text to synthesize.

        Raises:
            ValueError: If *input_text* is ``None`` or an empty /
                whitespace-only string. Validating here gives a clear error
                up front instead of an opaque tokenizer/model failure later.
        """
        if input_text is None:
            raise ValueError("Input text is None. Please provide text")
        if isinstance(input_text, str) and not input_text.strip():
            raise ValueError("Input text is empty. Please provide text")
        self.inputs = tokenizer_mms_tts_eng(input_text, return_tensors="pt")

    def __call__(self) -> bytes:
        """Run the TTS model and return the first generated clip as WAV bytes.

        Returns:
            The complete WAV file contents (header + samples) as ``bytes``.

        Raises:
            Exception: If ``self.inputs`` has been reset to ``None``.
        """
        # Guard kept for callers that clear ``inputs`` after construction.
        if self.inputs is None:
            raise Exception("Input text is None. Please provide text")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output_model = model_mms_tts_eng(**self.inputs)

        # transformers' VitsModel returns a VitsModelOutput whose audio field
        # is "waveform"; fall back to "audio" in case the configured model
        # wrapper uses that key instead.
        key = "waveform" if "waveform" in output_model else "audio"
        audio = output_model[key][0]  # first (only) item of the batch

        # soundfile converts its input with np.asarray, which fails for
        # tensors living on an accelerator; move to CPU NumPy explicitly.
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu().numpy()

        # Prefer the model's own output rate; fall back to the module default.
        sampling_rate = getattr(
            getattr(model_mms_tts_eng, "config", None),
            "sampling_rate",
            SAMPLING_RATE,
        )

        with BytesIO() as buffer:
            sf.write(buffer, audio, sampling_rate, format="wav")
            # Read the bytes before the buffer is closed by the context exit.
            return buffer.getvalue()