import torchaudio

from whisperspeech.pipeline import Pipeline
import argparse

def parse_args(argv=None):
    """Parse the command-line arguments of the text-to-audio script.

    Args:
        argv (list[str] | None): Argument strings to parse instead of the
            process command line. When None (the default), argparse reads
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace: Namespace with a single attribute, ``text``,
        holding the value given to the required ``--text`` option.
    """
    parser = argparse.ArgumentParser(description="Convert text to audio.")
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="The text to convert to audio.",
    )
    # Forwarding argv lets callers and tests supply arguments directly.
    return parser.parse_args(argv)

def convert_text_to_audio(pipe: Pipeline, text: str):
    """Synthesize speech for ``text`` with the given pipeline.

    Args:
        pipe (Pipeline): The text-to-speech pipeline whose ``generate``
            method is called.
        text (str): The text to synthesize.

    Returns:
        torch.Tensor: The audio produced by ``pipe.generate(text)``.
    """
    audio = pipe.generate(text)
    return audio


def convert_text_to_audio_file(pipe: Pipeline, text: str, output_path: str):
    """Synthesize speech for ``text`` and write it to ``output_path``.

    Args:
        pipe (Pipeline): The text-to-speech pipeline whose
            ``generate_to_file`` method is called.
        text (str): The text to synthesize.
        output_path (str): Destination path of the audio file.
    """
    # The pipeline expects the destination first and the text second.
    write_audio = pipe.generate_to_file
    write_audio(output_path, text)


class TTSProcessor:
    """Convenience wrapper around a WhisperSpeech ``Pipeline``.

    The pipeline is created once in ``__init__`` and stored on ``self.pipe``;
    every method forwards to it.
    """

    # Model reference used when the caller does not supply one.
    DEFAULT_S2A_REF = "collabora/whisperspeech:s2a-q4-tiny-en+pl.model"

    def __init__(self, device: str, s2a_ref: str = DEFAULT_S2A_REF):
        """Initialize the TTS Processor with a specified device.

        Args:
            device (str): Device identifier forwarded to ``Pipeline``
                (the script entry point passes ``"cuda"``).
            s2a_ref (str): Model reference forwarded to ``Pipeline`` as
                ``s2a_ref``. Defaults to ``DEFAULT_S2A_REF``, the value that
                was previously hard-coded.
        """
        self.pipe = Pipeline(s2a_ref=s2a_ref, device=device)

    def get_reference_voice_embedding(self, path: str):
        """Get the reference voice embedding from the given audio file.

        Args:
            path (str): The path to the audio file.

        Returns:
            torch.Tensor: The result of ``pipe.extract_spk_emb(path)`` after
            ``.cpu()`` has been called on it.
        """
        embedding = self.pipe.extract_spk_emb(path)
        return embedding.cpu()

    def convert_text_to_audio(self, text: str, speaker=None):
        """Convert text to audio.

        Args:
            text (str): The text to convert to audio.
            speaker: Forwarded unchanged to ``pipe.generate`` as ``speaker``.
                Presumably an embedding such as the one returned by
                ``get_reference_voice_embedding`` -- confirm against the
                pipeline. Defaults to None.

        Returns:
            torch.Tensor: The generated audio.
        """
        return self.pipe.generate(text, speaker=speaker)

    def convert_text_to_audio_file(self, text: str, output_path: str, speaker=None):
        """Convert text to audio and save it to a file.

        Args:
            text (str): The text to convert to audio.
            output_path (str): The path to save the audio file.
            speaker: Forwarded unchanged to ``pipe.generate_to_file`` as
                ``speaker``. Defaults to None.
        """
        # The pipeline takes the destination path before the text.
        self.pipe.generate_to_file(output_path, text, speaker=speaker)


def text_to_filename(text: str) -> str:
    """Derive a file stem (no directory, no extension) from the input text.

    The text is lower-cased, single spaces become underscores, and one
    trailing period is dropped. Path separators are replaced with
    underscores so the stem can never point into another directory.

    Args:
        text (str): The text that will be converted to audio.

    Returns:
        str: The file stem, or ``"audio"`` when nothing usable remains
        (for example for ``""`` or ``"."``).
    """
    stem = "_".join(text.lower().split(" "))
    # endswith() is safe on an empty string, unlike indexing with [-1].
    if stem.endswith("."):
        stem = stem[:-1]
    # A separator inside the stem would address a (likely missing) sub-directory.
    for separator in ("/", "\\"):
        stem = stem.replace(separator, "_")
    return stem or "audio"


if __name__ == "__main__":
    import os

    args = parse_args()
    processor = TTSProcessor("cuda")
    text = args.text.lower()
    text_split = text_to_filename(text)
    print(text_split)
    output_dir = "./examples"
    # Saving fails if the target directory is missing, so create it up front.
    os.makedirs(output_dir, exist_ok=True)
    path = f"{output_dir}/{text_split}.wav"
    processor.convert_text_to_audio_file(text, path)