# Source: Hugging Face Space app.py (user: fargerm), revision ca4c4e9 (verified), 3.99 kB.
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import MarianMTModel, MarianTokenizer
import soundfile as sf
from datasets import load_dataset
# Device / dtype setup: GPU with fp16 when CUDA is available, else CPU with fp32.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


@st.cache_resource
def _load_whisper():
    """Load the Whisper ASR model, processor and pipeline once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the multi-GB Whisper checkpoint would be re-loaded on
    every button click.

    Returns:
        (model, processor, asr_pipeline) for ``openai/whisper-large-v3``.
    """
    model_id = "openai/whisper-large-v3"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    asr = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return model, processor, asr


@st.cache_resource
def _load_tts():
    """Load the SpeechT5 text-to-speech pipeline once per server process."""
    return pipeline("text-to-speech", "microsoft/speecht5_tts")


# Module-level names kept identical to the original script for compatibility.
whisper_model_id = "openai/whisper-large-v3"
whisper_model, whisper_processor, whisper_pipe = _load_whisper()
tts_pipe = _load_tts()
# Load speaker embeddings
@st.cache_resource
def get_speaker_embedding():
    """Return the speaker x-vector tensor used to select the SpeechT5 voice.

    Cached with ``st.cache_resource`` so the dataset is not re-downloaded on
    every Streamlit rerun.

    Returns:
        A tensor of shape ``(1, D)`` (batch dimension added by ``unsqueeze``;
        D is the x-vector size — presumably 512, per the SpeechT5 examples).
    """
    dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    # Index 7306 — the specific sample used in the SpeechT5 TTS examples
    # (not the first sample, as an older comment here claimed).
    return torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)


speaker_embedding = get_speaker_embedding()
# Load translation model
@st.cache_resource
def load_translation_model(lang_code):
    """Load (and cache) the Helsinki-NLP MarianMT model for en -> ``lang_code``.

    Without caching, the checkpoint is re-downloaded/re-instantiated on every
    Streamlit rerun — i.e. on every Submit click.

    Args:
        lang_code: Target-language ISO code, e.g. ``"fr"`` or ``"hi"``.

    Returns:
        A ``(model, tokenizer)`` tuple for ``Helsinki-NLP/opus-mt-en-<lang_code>``.
    """
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
st.title("TextLangAudioGenerator")

# English source text entered by the user.
text_input = st.text_area("Enter text in English", key="text_input")

# Human-readable labels for the supported target-language codes.
_LANG_NAMES = {"fr": "French", "zh": "Chinese", "it": "Italian", "ur": "Urdu", "hi": "Hindi"}
target_lang = st.selectbox(
    "Select target language",
    ["fr", "zh", "it", "ur", "hi"],  # Add more language codes as needed
    format_func=lambda code: _LANG_NAMES.get(code, code),
    key="target_lang",
)

# Guarantee that the result keys used further down always exist in session state.
for _key in ("translated_text", "audio_path"):
    if _key not in st.session_state:
        st.session_state[_key] = ""
# Submit button
# Submit: translate the English text into the selected target language and
# stash the result in session state (rendered by the display section at the
# bottom of the script — it previously also rendered here, showing up twice).
if st.button("Submit"):
    if text_input and target_lang:
        model, tokenizer = load_translation_model(target_lang)
        # truncation=True: clip over-long inputs to the model's maximum
        # length instead of letting generate() fail on them.
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
        translated = model.generate(**inputs)
        st.session_state.translated_text = tokenizer.decode(
            translated[0], skip_special_tokens=True
        )
        st.session_state.audio_path = ""  # Invalidate audio from any previous translation
    else:
        st.error("Please enter text and select a target language.")
# Listen to Translated Audio button
# Listen to Translated Audio: synthesize speech for the stored translation.
if st.button("Listen to Translated Audio"):
    if not st.session_state.translated_text:
        st.error("Please submit the text first.")
    else:
        # Run TTS with the fixed speaker x-vector selecting the voice.
        speech = tts_pipe(
            st.session_state.translated_text,
            forward_params={"speaker_embeddings": speaker_embedding},
        )
        out_path = "translated_speech.wav"
        sf.write(out_path, speech["audio"], samplerate=speech["sampling_rate"])
        st.session_state.audio_path = out_path
        st.audio(out_path, format="audio/wav")
# Reset button
# Reset: clear stored results and widget state, then rerun with empty inputs.
if st.button("Reset"):
    st.session_state.translated_text = ""
    st.session_state.audio_path = ""
    # Drop the widget-backed keys so the text area / selectbox are recreated
    # empty on the next run. The old code instantiated a second st.text_area
    # and st.selectbox here with keys already used earlier in this same script
    # run, which raises a DuplicateWidgetID error in Streamlit.
    for _key in ("text_input", "target_lang"):
        st.session_state.pop(_key, None)
    st.experimental_rerun()  # Reload the app to reset the inputs
# Final render: show whichever result is currently available — audio takes
# precedence over plain translated text.
if st.session_state.audio_path:
    st.audio(st.session_state.audio_path, format="audio/wav")
elif st.session_state.translated_text:
    st.write(f"Translated text: {st.session_state.translated_text}")