import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import MarianMTModel, MarianTokenizer
import soundfile as sf
from datasets import load_dataset

# Device setup
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


# Load Whisper model (cached so it is not reloaded on every Streamlit rerun).
# Note: this ASR pipeline is not called by the UI below; it is kept so the
# app can be extended with speech input (see the sketch at the end of the file).
@st.cache_resource
def load_whisper_pipeline():
    whisper_model_id = "openai/whisper-large-v3"
    whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
    whisper_processor = AutoProcessor.from_pretrained(whisper_model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=whisper_model,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


whisper_pipe = load_whisper_pipeline()


# Load TTS model
@st.cache_resource
def load_tts_pipeline():
    return pipeline("text-to-speech", "microsoft/speecht5_tts")


tts_pipe = load_tts_pipeline()


# Load speaker embeddings
@st.cache_resource
def get_speaker_embedding():
    dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    # Index 7306 is a commonly used US English speaker embedding from this dataset
    speaker_embedding = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)
    return speaker_embedding


speaker_embedding = get_speaker_embedding()


# Load translation model (cached per language code)
@st.cache_resource
def load_translation_model(lang_code):
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer


st.title("TextLangAudioGenerator")

# Text input
text_input = st.text_area("Enter text in English", key="text_input")

# Select target language
target_lang = st.selectbox(
    "Select target language",
    ["fr", "zh", "it", "ur", "hi"],  # Add more language codes as needed
    format_func=lambda x: {
        "fr": "French",
        "zh": "Chinese",
        "it": "Italian",
        "ur": "Urdu",
        "hi": "Hindi",
    }.get(x, x),
    key="target_lang",
)

# Initialize session state for storing results
if "translated_text" not in st.session_state:
    st.session_state.translated_text = ""
if "audio_path" not in st.session_state:
    st.session_state.audio_path = ""

# Submit button
if st.button("Submit"):
    if text_input and target_lang:
        # Load translation model and translate the input text
        model, tokenizer = load_translation_model(target_lang)
        inputs = tokenizer(text_input, return_tensors="pt")
        translated = model.generate(**inputs)
        st.session_state.translated_text = tokenizer.decode(
            translated[0], skip_special_tokens=True
        )
        st.session_state.audio_path = ""  # Clear previous audio path
        # The translated text is rendered once, in the display section below.
    else:
        st.error("Please enter text and select a target language.")

# Listen to Translated Audio button
if st.button("Listen to Translated Audio"):
    if st.session_state.translated_text:
        # Generate TTS audio from the translated text
        speech = tts_pipe(
            st.session_state.translated_text,
            forward_params={"speaker_embeddings": speaker_embedding},
        )
        st.session_state.audio_path = "translated_speech.wav"
        sf.write(
            st.session_state.audio_path,
            speech["audio"],
            samplerate=speech["sampling_rate"],
        )
    else:
        st.error("Please submit the text first.")


# Reset button: clear state in an on_click callback, which runs before the
# widgets are re-instantiated on the next run. Re-creating the widgets inside
# the button branch (as before) raises a duplicate-widget-key error, and
# st.experimental_rerun() is deprecated in current Streamlit.
def reset_app():
    st.session_state.translated_text = ""
    st.session_state.audio_path = ""
    st.session_state.text_input = ""
    st.session_state.target_lang = "fr"


st.button("Reset", on_click=reset_app)

# Display the current results (kept in one place so nothing is rendered twice)
if st.session_state.translated_text:
    st.write(f"Translated text: {st.session_state.translated_text}")
if st.session_state.audio_path:
    st.audio(st.session_state.audio_path, format="audio/wav")
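
# The Whisper pipeline loaded above is never called by the UI. Below is a
# minimal sketch of wiring it up as an optional speech input; the uploader
# label, the "asr_input" key, and the "Transcribe" button are illustrative
# assumptions, not part of the original app. Passing raw bytes to the ASR
# pipeline requires ffmpeg to be available for decoding.
audio_file = st.file_uploader(
    "Or upload English audio to transcribe", type=["wav", "mp3"], key="asr_input"
)
if audio_file is not None and st.button("Transcribe"):
    result = whisper_pipe(audio_file.read())  # raw bytes are decoded by the pipeline
    st.write(f"Transcription: {result['text']}")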