# app.py — commit 54b1b7d ("Update app.py"), 2.22 kB.
# Header captured from the hosting page's file view
# (fargerm's picture / verified / raw / history blame).
import io

import soundfile as sf
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import MarianMTModel, MarianTokenizer
# Device setup: use the first CUDA GPU with half precision when CUDA is
# available, otherwise run on the CPU with full precision.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
# Load Whisper model
whisper_model_id = "openai/whisper-large-v3"


@st.cache_resource
def _load_whisper():
    """Load the Whisper model, its processor and the ASR pipeline.

    Streamlit re-executes this script on every widget interaction.
    ``st.cache_resource`` keeps the loaded objects alive across those reruns
    so the checkpoint is not reloaded each time.

    Reads the module-level ``whisper_model_id``, ``device`` and
    ``torch_dtype``.

    Returns:
        A ``(model, processor, pipeline)`` tuple.
    """
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(whisper_model_id)
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return model, processor, asr_pipe


# NOTE(review): whisper_pipe is not referenced in the visible part of this
# script -- confirm it is still needed before paying for the model load.
whisper_model, whisper_processor, whisper_pipe = _load_whisper()
# Load TTS model
@st.cache_resource
def _load_tts_pipeline():
    """Build the SpeechT5 text-to-speech pipeline.

    Cached with ``st.cache_resource`` so the pipeline is constructed once and
    reused across Streamlit script reruns instead of being rebuilt on every
    widget interaction.
    """
    return pipeline("text-to-speech", "microsoft/speecht5_tts")


tts_pipe = _load_tts_pipeline()
# Load translation model
@st.cache_resource
def load_translation_model(lang_code: str) -> tuple[MarianMTModel, MarianTokenizer]:
    """Load the English-to-``lang_code`` Marian translation model and tokenizer.

    The checkpoint name is built as ``Helsinki-NLP/opus-mt-en-{lang_code}``.
    Results are cached per ``lang_code`` with ``st.cache_resource`` so that
    Streamlit script reruns reuse the already-loaded objects instead of
    loading them again on every interaction.

    Args:
        lang_code: Target-language suffix of the checkpoint name (e.g. "fr").

    Returns:
        A ``(model, tokenizer)`` tuple for the requested language pair.

    Raises:
        Any exception raised by ``from_pretrained`` propagates unchanged.
    """
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
st.title("TextLangAudioGenerator")

# Supported target languages: checkpoint language code -> display name.
# Add more language codes as needed; the selectbox options are derived from
# this mapping so codes and labels cannot drift apart.
LANGUAGE_NAMES = {
    "fr": "French",
    "zh": "Chinese",
    "it": "Italian",
    "ur": "Urdu",
    "hi": "Hindi",
}

# Text input
text_input = st.text_area("Enter text in English")

# Whitespace-only input has nothing to translate, so treat it like empty input.
if text_input and text_input.strip():
    # Select target language
    target_lang = st.selectbox(
        "Select target language",
        list(LANGUAGE_NAMES),
        format_func=lambda code: LANGUAGE_NAMES.get(code, code),
    )

    if target_lang:
        # Translate the English text with the model for the chosen language.
        model, tokenizer = load_translation_model(target_lang)
        inputs = tokenizer(text_input, return_tensors="pt")
        translated = model.generate(**inputs)
        translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
        st.write(f"Translated text: {translated_text}")

        # Generate TTS
        # NOTE(review): tts_pipe is built from "microsoft/speecht5_tts"; confirm
        # that checkpoint handles the non-English text produced above.
        speech = tts_pipe(translated_text)

        # Encode the WAV in memory rather than to a fixed file on disk: a
        # single shared path would be overwritten by concurrent sessions and
        # requires a writable working directory.
        wav_buffer = io.BytesIO()
        sf.write(
            wav_buffer,
            speech["audio"],
            samplerate=speech["sampling_rate"],
            format="WAV",  # required by soundfile when writing to a file-like object
        )
        st.audio(wav_buffer.getvalue(), format="audio/wav")