# app.py — Streamlit speech-translation demo
# (source: Hugging Face Space by fargerm, commit d50e22f)
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import MarianMTModel, MarianTokenizer
import soundfile as sf
from datasets import load_dataset
import os
# Device setup — use GPU with fp16 when available, otherwise CPU with fp32.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load Whisper model
whisper_model_id = "openai/whisper-large-v3"


@st.cache_resource(show_spinner="Loading Whisper model...")
def _load_whisper():
    """Load the Whisper ASR model, processor, and pipeline once per process.

    Streamlit re-executes this whole script on every widget interaction;
    without st.cache_resource the multi-GB checkpoint would be reloaded on
    each rerun. Returns (model, processor, asr_pipeline).
    """
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(whisper_model_id)
    asr = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return model, processor, asr


# NOTE(review): whisper_pipe is not used in the visible UI code below;
# kept because other code (or a future feature) may rely on these names.
whisper_model, whisper_processor, whisper_pipe = _load_whisper()


@st.cache_resource(show_spinner="Loading TTS model...")
def _load_tts():
    """Load the SpeechT5 text-to-speech pipeline once per process."""
    return pipeline("text-to-speech", "microsoft/speecht5_tts")


# Load TTS model
tts_pipe = _load_tts()
# Load speaker embeddings
@st.cache_resource
def get_speaker_embedding():
    """Return a speaker x-vector tensor (shape (1, D)) for SpeechT5 TTS.

    Index 7306 of the CMU ARCTIC x-vector validation split is the female
    US-English voice used throughout the SpeechT5 examples — note it is NOT
    the first sample, contrary to what an earlier comment claimed. Cached so
    the dataset is downloaded only once per server process instead of on
    every Streamlit rerun.
    """
    dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)


speaker_embedding = get_speaker_embedding()
# Load translation model
@st.cache_resource(show_spinner="Loading translation model...")
def load_translation_model(lang_code):
    """Load the Helsinki-NLP Opus-MT English->`lang_code` model and tokenizer.

    Cached per language code so repeated Submit clicks (each of which reruns
    the whole script) do not re-download or re-instantiate the model.
    Returns (model, tokenizer).
    """
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
st.title("TextLangAudioGenerator")

# Free-text English input; keyed so its value lives in session state.
text_input = st.text_area("Enter text in English", key="text_input")

# Human-readable labels for the supported target-language codes.
# Add more language codes as needed.
_LANG_LABELS = {"fr": "French", "zh": "Chinese", "it": "Italian", "ur": "Urdu", "hi": "Hindi"}

# Target-language picker — stores the code, displays the label.
target_lang = st.selectbox(
    "Select target language",
    list(_LANG_LABELS),
    format_func=lambda code: _LANG_LABELS.get(code, code),
    key="target_lang",
)
# Ensure the result slots exist in session state so they survive reruns.
for _slot in ("translated_text", "audio_path"):
    if _slot not in st.session_state:
        st.session_state[_slot] = ""
# Submit button: translate the English text into the selected language.
if st.button("Submit"):
    if text_input and target_lang:
        # Load (cached) translation model for the chosen language.
        model, tokenizer = load_translation_model(target_lang)
        inputs = tokenizer(text_input, return_tensors="pt")
        translated = model.generate(**inputs)
        st.session_state.translated_text = tokenizer.decode(
            translated[0], skip_special_tokens=True
        )
        st.session_state.audio_path = ""  # stale audio no longer matches the new text
        # The translated text is rendered once by the status section at the
        # bottom of the script; writing it here as well produced a duplicate
        # "Translated text: ..." line on every submit.
    else:
        st.error("Please enter text and select a target language.")
# Listen button: synthesize speech for the current translation.
if st.button("Listen to Translated Audio"):
    spoken_text = st.session_state.translated_text
    if not spoken_text:
        st.error("Please submit the text first.")
    else:
        # Run TTS with the fixed speaker x-vector loaded at startup.
        speech = tts_pipe(spoken_text, forward_params={"speaker_embeddings": speaker_embedding})
        st.session_state.audio_path = "translated_speech.wav"
        sf.write(st.session_state.audio_path, speech["audio"], samplerate=speech["sampling_rate"])
        # Only show the player if the file actually landed on disk.
        if os.path.exists(st.session_state.audio_path):
            st.audio(st.session_state.audio_path, format="audio/wav")
        else:
            st.error("Failed to generate audio. Please try again.")
# Reset button: clear all stored results and widget values, then rerun.
if st.button("Reset"):
    st.session_state.translated_text = ""
    st.session_state.audio_path = ""
    # Calling st.text_area / st.selectbox a second time with the same keys
    # (as the previous code did) raises a DuplicateWidgetID error. The
    # supported reset pattern is to drop the widget keys from session state
    # so both inputs come back at their defaults on the next run.
    st.session_state.pop("text_input", None)
    st.session_state.pop("target_lang", None)
    # st.rerun() replaced the deprecated st.experimental_rerun(); fall back
    # for older Streamlit versions.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()  # Reload the app so the cleared inputs are rendered
# Status section: once audio exists, show the player; otherwise show the
# latest translation (if any). Equivalent branch order to the original —
# audio wins whenever an audio path is set.
if st.session_state.audio_path:
    st.audio(st.session_state.audio_path, format="audio/wav")
elif st.session_state.translated_text:
    st.write(f"Translated text: {st.session_state.translated_text}")