# Scraped from a Hugging Face Spaces page (Space status at the time: "Sleeping").
import tempfile

import soundfile as sf
import streamlit as st
import torch
from datasets import load_dataset
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import MarianMTModel, MarianTokenizer
# Device setup: use the first CUDA GPU in half precision when one is
# available, otherwise fall back to the CPU in full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
# Load Whisper model | |
whisper_model_id = "openai/whisper-large-v3" | |
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained( | |
whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True | |
).to(device) | |
whisper_processor = AutoProcessor.from_pretrained(whisper_model_id) | |
whisper_pipe = pipeline( | |
"automatic-speech-recognition", | |
model=whisper_model, | |
tokenizer=whisper_processor.tokenizer, | |
feature_extractor=whisper_processor.feature_extractor, | |
torch_dtype=torch_dtype, | |
device=device, | |
) | |
# Load TTS model.
# NOTE(review): microsoft/speecht5_tts appears to be an English-only
# checkpoint, but it is fed French/Chinese/Italian/Urdu/Hindi translations in
# this app. Confirm the output is intelligible for non-English (especially
# non-Latin-script) text.


@st.cache_resource
def _load_tts_pipeline(model_id="microsoft/speecht5_tts"):
    """Build the text-to-speech pipeline once per server process.

    ``st.cache_resource`` keeps the loaded pipeline across Streamlit reruns.
    Without it, every widget interaction re-executes this module and reloads
    the model.

    Args:
        model_id: Hugging Face model id of the TTS checkpoint.

    Returns:
        A ``transformers`` text-to-speech pipeline.
    """
    return pipeline("text-to-speech", model_id)


tts_pipe = _load_tts_pipeline()
# Load speaker embeddings
@st.cache_resource
def get_speaker_embedding(index: int = 7306) -> torch.Tensor:
    """Return one CMU ARCTIC x-vector to use as the TTS speaker embedding.

    Cached with ``st.cache_resource`` so the dataset is loaded once per
    server process instead of on every Streamlit rerun.

    Args:
        index: Row of the ``validation`` split whose ``"xvector"`` is used.
            Defaults to 7306, the voice this app has been using.

    Returns:
        The x-vector as a tensor with a leading batch axis added via
        ``unsqueeze(0)``.
    """
    dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    # Row ``index`` (not the first row) selects the speaker's voice.
    return torch.tensor(dataset[index]["xvector"]).unsqueeze(0)


speaker_embedding = get_speaker_embedding()
# Load translation model
@st.cache_resource
def load_translation_model(lang_code):
    """Load the Helsinki-NLP English -> ``lang_code`` MarianMT model.

    Cached with ``st.cache_resource`` and keyed on ``lang_code``. Each
    language pair is therefore loaded once per server process, instead of on
    every click of "Submit".

    Args:
        lang_code: Target language code appended to ``opus-mt-en-``,
            e.g. ``"fr"``.

    Returns:
        ``(model, tokenizer)`` tuple.
    """
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
st.title("TextLangAudioGenerator")

# English source text; the key mirrors the value into st.session_state.
text_input = st.text_area("Enter text in English", key="text_input")

# Target language codes (each must have a Helsinki-NLP/opus-mt-en-<code>
# model) mapped to the names shown in the dropdown. Add more codes as needed.
_LANGUAGE_NAMES = {
    "fr": "French",
    "zh": "Chinese",
    "it": "Italian",
    "ur": "Urdu",
    "hi": "Hindi",
}

target_lang = st.selectbox(
    "Select target language",
    list(_LANGUAGE_NAMES),
    format_func=lambda code: _LANGUAGE_NAMES.get(code, code),
    key="target_lang",
)
# Initialize session state for storing results. Each slot is seeded with an
# empty string only on the session's first run, so later reruns keep
# whatever was stored.
for _state_key in ("translated_text", "audio_path"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
# Submit button: translate the English text into the selected language.
if st.button("Submit"):
    if text_input.strip() and target_lang:
        model, tokenizer = load_translation_model(target_lang)
        # Truncate to the tokenizer's maximum input length. Otherwise very
        # long text makes generate() fail instead of translating the prefix.
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
        translated = model.generate(**inputs)
        st.session_state.translated_text = tokenizer.decode(
            translated[0], skip_special_tokens=True
        )
        # Audio generated earlier belongs to the previous translation.
        st.session_state.audio_path = ""
        # The text is rendered by the state-display section at the end of the
        # script. Writing it here as well showed it twice.
    else:
        st.error("Please enter text and select a target language.")
# Listen to Translated Audio button: synthesize speech for the stored
# translation and remember where the WAV file was written.
if st.button("Listen to Translated Audio"):
    if st.session_state.translated_text:
        speech = tts_pipe(
            st.session_state.translated_text,
            forward_params={"speaker_embeddings": speaker_embedding},
        )
        # Each browser session gets its own output file, created once and
        # reused (not deleted) for the session's lifetime. A single shared
        # "translated_speech.wav" in the working directory let concurrent
        # users overwrite each other's audio between reruns.
        if "_tts_audio_file" not in st.session_state:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                st.session_state["_tts_audio_file"] = tmp.name
        out_path = st.session_state["_tts_audio_file"]
        sf.write(out_path, speech["audio"], samplerate=speech["sampling_rate"])
        # Publish the path only after the write succeeded.
        st.session_state.audio_path = out_path
        # The player is rendered by the state-display section at the end of
        # the script. Rendering it here as well produced two players.
    else:
        st.error("Please submit the text first.")
# Reset button: clear the results and restore both inputs to their defaults.
def _reset_app():
    """Clear the stored results and reset the input widgets.

    Runs as the Reset button's ``on_click`` callback, i.e. before the next
    script run. That is the point where widget-backed keys (``text_input``,
    ``target_lang``) may still be assigned.

    Assigning them after their widgets are drawn raises
    ``StreamlitAPIException``. Re-declaring the widgets with the same key
    raises a duplicate-widget-ID error.
    """
    st.session_state.translated_text = ""
    st.session_state.audio_path = ""
    st.session_state.text_input = ""
    # "fr" is the first, i.e. default, option of the language selectbox.
    st.session_state.target_lang = "fr"


# A button callback already triggers a rerun, so the deprecated
# st.experimental_rerun() is not needed.
st.button("Reset", on_click=_reset_app)
# Display current state of translated text and audio. Both are read from
# session state, so they persist across reruns. The text stays visible after
# audio has been generated.
if st.session_state.translated_text:
    st.write(f"Translated text: {st.session_state.translated_text}")
if st.session_state.audio_path:
    st.audio(st.session_state.audio_path, format="audio/wav")