import tempfile
from pathlib import Path

import streamlit as st
import torch
from TTS.api import TTS

# Set Page Title
st.title("Text to Speech with Overflow")

# Coqui TTS model identifier used for synthesis.
MODEL_NAME = "tts_models/en/ljspeech/overflow"


@st.cache_resource
def load_tts(model_name: str, device: str) -> TTS:
    """Load the TTS model *model_name* once and move it to *device*.

    Streamlit re-executes this whole script on every widget interaction;
    caching the resource prevents the model from being reloaded on each
    rerun. NOTE(review): the cached instance is shared across all sessions
    of this server process — confirm concurrent inference on one model is
    acceptable for the deployment.
    """
    return TTS(model_name).to(device)


# Get device: use the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Init TTS (loaded once, then served from Streamlit's resource cache).
tts = load_tts(MODEL_NAME, device)

# Input form for the user to enter text
user_input = st.text_area("Enter the text:")

if st.button("Synthesize"):
    # Ignore surrounding whitespace so blank input is not sent to the model.
    text = user_input.strip()
    if not text:
        st.warning("Please enter some text to synthesize.")
    else:
        # Synthesize into a per-request temporary directory rather than a
        # fixed "output.wav" in the working directory: concurrent sessions
        # would otherwise overwrite each other's file before it is read back.
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = Path(tmp_dir) / "output.wav"
            with st.spinner("Synthesizing..."):
                tts.tts_to_file(text=text, file_path=str(wav_path))
            # Read the result once; the bytes outlive the temp directory.
            audio_bytes = wav_path.read_bytes()

        # Display the audio
        st.audio(audio_bytes, format="audio/wav")

        # Offer the same bytes as a downloadable file
        st.markdown("### Download the audio")
        st.download_button(
            label="Click to download",
            data=audio_bytes,
            key="download_audio",
            file_name="output.wav",
            mime="audio/wav",
        )