import gradio as gr
import numpy as np
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from gtts import gTTS
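# Assumed dependencies: pip install openai-whisper transformers sentencepiece gtts gradio
# ("whisper" here is OpenAI's openai-whisper package, not the unrelated "whisper" on PyPI)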

# Load Whisper STT model
whisper_model = whisper.load_model("base")

# Load the SMaLL-100 translation model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("alirezamsh/small100")
model = AutoModelForSeq2SeqLM.from_pretrained("alirezamsh/small100")
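# Note: the SMaLL-100 model card ships a custom SMALL100Tokenizer
# (tokenization_small_100.py); AutoTokenizer falls back to the stock M2M100
# tokenizer, so the target-first convention is reproduced via src_lang below.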

def translate_speech(audio, target_lang):
    # Gradio "numpy" audio is a (sample_rate, int16 data) tuple; normalize to float32 in [-1, 1]
    sample_rate, audio = audio
    audio = audio.astype(np.float32) / 32768.0
    if audio.ndim > 1:
        audio = audio.mean(axis=1)  # downmix stereo to mono
    if sample_rate != whisper.audio.SAMPLE_RATE:
        # crude linear resample to Whisper's 16 kHz; scipy/librosa give higher quality
        target_len = int(len(audio) * whisper.audio.SAMPLE_RATE / sample_rate)
        idx = np.linspace(0, len(audio) - 1, target_len)
        audio = np.interp(idx, np.arange(len(audio)), audio).astype(np.float32)
    audio = whisper.pad_or_trim(audio)  # pad/trim to Whisper's 30-second window
    mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)  # most probable source language
    options = whisper.DecodingOptions(language=detected_lang, fp16=False)
    result = whisper.decode(whisper_model, mel, options)
    text = result.text

    # Translate the transcript. SMaLL-100 is target-language-first: the target
    # code is prepended to the *source* text, so assigning it to src_lang on the
    # stock M2M100 tokenizer mimics the custom SMALL100Tokenizer's behavior
    tokenizer.src_lang = target_lang
    encoded_text = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(**encoded_text)
    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    # Text-to-speech (TTS)
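    # (gTTS synthesizes via Google Translate's TTS endpoint, so this step needs network access)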
    tts = gTTS(text=translated_text, lang=target_lang)
    audio_path = "translated_audio.mp3"
    tts.save(audio_path)

    return audio_path
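
# Quick smoke test for translate_speech, as a sketch (the synthetic tone is a
# hypothetical stand-in; Whisper needs real speech to produce meaningful text):
#   sr = 16000
#   tone = (0.1 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr)).astype(np.float32)
#   pcm16 = (tone * 32767).astype(np.int16)
#   print(translate_speech((sr, pcm16), "fr"))  # writes translated_audio.mp3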


def translate_speech_interface(audio, target_lang):
    # Return the MP3 path directly; a "filepath"-typed Audio output serves the file
    return translate_speech(audio, target_lang)

# Gradio 4.x components (the gr.inputs/gr.outputs namespaces were removed in Gradio 3)
audio_recording = gr.Audio(sources=["microphone"], type="numpy", label="Record your speech")
lang_choices = ["ru", "fr", "en", "de"]
lang_dropdown = gr.Dropdown(lang_choices, label="Select Language to Translate")
output_audio = gr.Audio(type="filepath", label="Translated Audio")

iface = gr.Interface(
    fn=translate_speech_interface,
    inputs=[audio_recording, lang_dropdown],
    outputs=output_audio,
    title="Speech Translator",
)
iface.launch()