Spaces:
Sleeping
Sleeping
File size: 4,215 Bytes
899d4f3 fbb29cc 5abc527 5403777 d50e22f 242815e fbb29cc 242815e fbb29cc dd27423 fbb29cc 5444a62 fbb29cc 9bc5a61 2a45f93 fbb29cc c90f796 5abc527 5403777 5abc527 5403777 b601058 5403777 b601058 5abc527 5403777 b601058 54b1b7d fbb29cc 5403777 ca4c4e9 5403777 6e7b03e 5abc527 5403777 fbb29cc 5403777 d50e22f b601058 5403777 b601058 5403777 b601058 5403777 6e7b03e 5403777 ca4c4e9 5403777 ca4c4e9 d50e22f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import MarianMTModel, MarianTokenizer
import soundfile as sf
from datasets import load_dataset
import os
# Device setup: prefer GPU with half precision when CUDA is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

whisper_model_id = "openai/whisper-large-v3"


@st.cache_resource
def _load_whisper():
    """Load the Whisper ASR model, processor, and pipeline.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model would be re-instantiated (and potentially
    re-downloaded) on each rerun. ``st.cache_resource`` keeps one copy for
    the lifetime of the process.

    Returns:
        tuple: (model, processor, asr_pipeline)
    """
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(whisper_model_id)
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return model, processor, asr_pipeline


@st.cache_resource
def _load_tts():
    """Load the SpeechT5 text-to-speech pipeline once per process."""
    return pipeline("text-to-speech", "microsoft/speecht5_tts")


# Module-level names preserved for any downstream code that references them.
whisper_model, whisper_processor, whisper_pipe = _load_whisper()
tts_pipe = _load_tts()
# Load speaker embeddings
@st.cache_resource
def get_speaker_embedding():
    """Return a fixed speaker x-vector for SpeechT5 TTS.

    Cached with ``st.cache_resource`` so the CMU-Arctic x-vector dataset is
    downloaded only once per process instead of on every Streamlit rerun.

    Returns:
        torch.Tensor: x-vector of shape (1, embedding_dim).
    """
    dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    # Index 7306 (not the first sample) is the embedding this app uses as
    # its single fixed voice.
    speaker_embedding = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)
    return speaker_embedding


speaker_embedding = get_speaker_embedding()
# Load translation model
@st.cache_resource
def load_translation_model(lang_code):
    """Load the Helsinki-NLP MarianMT model translating English -> lang_code.

    Cached per ``lang_code`` with ``st.cache_resource``: without caching the
    model and tokenizer would be re-instantiated on every Submit click,
    because Streamlit reruns the whole script on each interaction.

    Args:
        lang_code: target-language code, e.g. "fr", "zh", "hi".

    Returns:
        tuple: (MarianMTModel, MarianTokenizer)
    """
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
st.title("TextLangAudioGenerator")

# Text input
text_input = st.text_area("Enter text in English", key="text_input")

# Target-language picker: codes are the option values, the mapping below
# supplies the human-readable labels shown in the dropdown.
_LANG_NAMES = {"fr": "French", "zh": "Chinese", "it": "Italian", "ur": "Urdu", "hi": "Hindi"}
target_lang = st.selectbox(
    "Select target language",
    list(_LANG_NAMES),  # add more language codes as needed
    format_func=lambda code: _LANG_NAMES.get(code, code),
    key="target_lang",
)
# Initialize session state so results survive Streamlit reruns.
for _state_key in ("translated_text", "audio_path"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
# Submit button: translate the entered English text into the chosen language.
if st.button("Submit"):
    if not (text_input and target_lang):
        st.error("Please enter text and select a target language.")
    else:
        model, tokenizer = load_translation_model(target_lang)
        encoded = tokenizer(text_input, return_tensors="pt")
        output_ids = model.generate(**encoded)
        st.session_state.translated_text = tokenizer.decode(
            output_ids[0], skip_special_tokens=True
        )
        # Drop any audio generated for a previous translation.
        st.session_state.audio_path = ""
        st.write(f"Translated text: {st.session_state.translated_text}")
# Listen button: synthesize speech for the stored translation and play it.
if st.button("Listen to Translated Audio"):
    if not st.session_state.translated_text:
        st.error("Please submit the text first.")
    else:
        speech = tts_pipe(
            st.session_state.translated_text,
            forward_params={"speaker_embeddings": speaker_embedding},
        )
        wav_path = "translated_speech.wav"
        st.session_state.audio_path = wav_path
        sf.write(wav_path, speech["audio"], samplerate=speech["sampling_rate"])
        # Only show the player when the file actually landed on disk.
        if os.path.exists(wav_path):
            st.audio(wav_path, format="audio/wav")
        else:
            st.error("Failed to generate audio. Please try again.")
# Reset button: clear stored results and widget state, then rerun so the
# widgets come back with their default (empty) values.
if st.button("Reset"):
    st.session_state.translated_text = ""
    st.session_state.audio_path = ""
    # BUG FIX: the original re-instantiated st.text_area / st.selectbox here
    # with the same keys ("text_input", "target_lang") as the widgets already
    # rendered above, which raises a DuplicateWidgetID error in Streamlit.
    # The supported way to reset a widget is to delete its key from session
    # state before the next run renders it.
    for widget_key in ("text_input", "target_lang"):
        st.session_state.pop(widget_key, None)
    st.experimental_rerun()  # reload the app so the inputs render fresh
# Show the latest result: the audio player when generated speech exists,
# otherwise the translated text (if any). Equivalent to the original
# two-branch check — audio wins whenever an audio path is stored.
if st.session_state.audio_path:
    st.audio(st.session_state.audio_path, format="audio/wav")
elif st.session_state.translated_text:
    st.write(f"Translated text: {st.session_state.translated_text}")
|