import torch
import gradio as gr
import torchaudio
from datetime import datetime
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voice, load_voices

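# Preset voices bundled with Tortoise. "random" is not a stored speaker: it
# samples fresh conditioning latents, so it produces a different voice each run.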
VOICE_OPTIONS = [
    "angie",
    "applejack",
    "atkins",
    "barack_obama",
    "daniel",
    "daws",
    "deniro",
    "dortice",
    "dreams",
    "emma",
    "empire",
    "freeman",
    "geralt",
    "grace",
    "halle",
    "jane_eyre",
    "jlaw",
    "kennard",
    "lescault",
    "lj",
    "mol",
    "mouse",
    "myself",
    "pat",
    "pat2",
    "rainbow",
    "sanjita",
    "snakes",
    "tim_reynolds",
    "tom",
    "weaver",
    "william",
    "random",
]

def inference(
    text,
    voice,
    emotion,
    preset,
):
    # Kept as a list so the input text could later be split into smaller chunks.
    texts = [text]

    # Tortoise supports light prompt engineering: a bracketed emotion cue
    # prepended to the text primes the delivery, and Tortoise's redaction
    # feature strips the bracketed words from the generated speech.
    emotion_prompts = {
        "Angry": "[I am so angry]",
        "Sad": "[I am so sad]",
        "Happy": "[I am so happy]",
        "Scared": "[I am so scared]",
    }
    if emotion in emotion_prompts:
        text = emotion_prompts[emotion] + text

    # Tortoise's convention for combining voices is "&" (e.g. "pat&william").
    voices = voice.split("&")

    if len(voices) == 1:
        voice_samples, conditioning_latents = load_voice(voice)
    else:
        voice_samples, conditioning_latents = load_voices(voices)

    audio_frames = []

    for chunk in texts:
        for audio_frame in tts.tts_with_preset(
            chunk,
            voice_samples=voice_samples,
            conditioning_latents=conditioning_latents,
            preset=preset,
            k=1,
        ):
            # Detach each generated frame and move it off the GPU as it arrives.
            audio_frames.append(audio_frame.detach().cpu())

    complete_audio = torch.cat(audio_frames, dim=0)

    # Tortoise generates audio at 24 kHz.
    yield (24000, complete_audio.numpy())
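
# Illustrative only: a minimal sketch (not used by the Gradio app) of calling
# inference() directly and saving the result with torchaudio. It assumes the
# module-level `tts` model has already been constructed (see the __main__ block
# below); the text, voice, and output path here are arbitrary examples.
def save_sample_to_wav(out_path="sample_output.wav"):
    sample_rate, audio = next(
        inference("Hello from Tortoise.", "jane_eyre", "Happy", "ultra_fast")
    )
    waveform = torch.from_numpy(audio).reshape(1, -1).float()  # (channels, samples)
    torchaudio.save(out_path, waveform, sample_rate)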

def main():
    title = "TTS "
    
    text = gr.Textbox(
        lines=4,
        label="Text:",
    )

    voice = gr.Dropdown(
        VOICE_OPTIONS, value="jane_eyre", label="Select voice:", type="value"
    )

    emotion = gr.Radio(
        ["Angry", "Sad", "Happy", "Scared"],
        label="Select emotion:",
        type="value",
    )

    preset = gr.Radio(
        ["ultra_fast", "fast", "standard", "high_quality"],
        label="Select preset:",
        type="value",
        value="ultra_fast",
    )

    output_audio = gr.Audio(label="Streaming audio:", streaming=True, autoplay=True)
    interface = gr.Interface(
        fn=inference,
        inputs=[
            text,
            voice,
            emotion,
            preset,
        ],
        title=title,
        outputs=[output_audio],
    )
    interface.queue().launch()

if __name__ == "__main__":
    tts = TextToSpeech(kv_cache=True, use_deepspeed=True, half=True)

    with open("Tortoise_TTS_Runs_Scripts.log", "a") as f:
        f.write(
            f"\n\n-------------------------Tortoise TTS Scripts Logs, {datetime.now()}-------------------------\n"
        )

    main()