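"""TextLangAudioGenerator: a Streamlit app that translates English text into a
selected target language with MarianMT and reads the translation aloud with
SpeechT5 text-to-speech. A Whisper ASR pipeline is also loaded, though the UI
below does not use it yet.

Run with: streamlit run app.py  (assuming this file is saved as app.py)
"""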
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import MarianMTModel, MarianTokenizer
import soundfile as sf
from datasets import load_dataset
import os

# Device setup
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
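# float16 halves memory use and speeds up inference on GPU; CPU stays in float32.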

# Load Whisper model (ASR). @st.cache_resource keeps the weights in memory
# across Streamlit reruns instead of reloading them on every interaction.
# Note: this pipeline is initialized but not yet wired into the UI below.
@st.cache_resource
def load_whisper_pipeline():
    whisper_model_id = "openai/whisper-large-v3"
    whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
    whisper_processor = AutoProcessor.from_pretrained(whisper_model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=whisper_model,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

whisper_pipe = load_whisper_pipeline()

# Load TTS model (SpeechT5), likewise cached across reruns
@st.cache_resource
def load_tts_pipeline():
    return pipeline("text-to-speech", model="microsoft/speecht5_tts")

tts_pipe = load_tts_pipeline()

# Load speaker embeddings (cached; the x-vector dataset download is slow)
@st.cache_resource
def get_speaker_embedding():
    dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    # Sample 7306 is the US English female voice commonly used in SpeechT5
    # examples; unsqueeze adds the batch dimension the model expects.
    speaker_embedding = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)
    return speaker_embedding

speaker_embedding = get_speaker_embedding()

# Load translation model (cached once per language code)
@st.cache_resource
def load_translation_model(lang_code):
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer

st.title("TextLangAudioGenerator")

# Text input
text_input = st.text_area("Enter text in English", key="text_input")

# Select target language
target_lang = st.selectbox(
    "Select target language",
    ["fr", "zh", "it", "ur", "hi"],  # Add more language codes as needed
    format_func=lambda x: {"fr": "French", "zh": "Chinese", "it": "Italian", "ur": "Urdu", "hi": "Hindi"}.get(x, x),
    key="target_lang"
)
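# Each code maps to a MarianMT checkpoint named Helsinki-NLP/opus-mt-en-<code>,
# e.g. "fr" loads Helsinki-NLP/opus-mt-en-fr (see load_translation_model above).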

# Initialize session state for storing results
if "translated_text" not in st.session_state:
    st.session_state.translated_text = ""
if "audio_path" not in st.session_state:
    st.session_state.audio_path = ""

# Submit button
if st.button("Submit"):
    if text_input and target_lang:
        # Load translation model
        model, tokenizer = load_translation_model(target_lang)
        inputs = tokenizer(text_input, return_tensors="pt")
        translated = model.generate(**inputs)
        st.session_state.translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
        st.session_state.audio_path = ""  # Clear previous audio path
        st.write(f"Translated text: {st.session_state.translated_text}")
    else:
        st.error("Please enter text and select a target language.")

# Listen to Translated Audio button
if st.button("Listen to Translated Audio"):
    if st.session_state.translated_text:
        # Synthesize speech; forward_params passes the speaker embedding through to SpeechT5
        speech = tts_pipe(st.session_state.translated_text, forward_params={"speaker_embeddings": speaker_embedding})
        st.session_state.audio_path = "translated_speech.wav"
        sf.write(st.session_state.audio_path, speech["audio"], samplerate=speech["sampling_rate"])
        
        # Ensure the audio file exists and display the player
        if os.path.exists(st.session_state.audio_path):
            st.audio(st.session_state.audio_path, format="audio/wav")
        else:
            st.error("Failed to generate audio. Please try again.")
    else:
        st.error("Please submit the text first.")

# Reset button: a callback runs before the next script run, so the result
# fields and widget keys can be cleared without raising DuplicateWidgetID.
def reset_app():
    st.session_state.translated_text = ""
    st.session_state.audio_path = ""
    # Deleting the widget keys resets the inputs on the next run
    for key in ("text_input", "target_lang"):
        st.session_state.pop(key, None)

st.button("Reset", on_click=reset_app)

# Display current state of translated text and audio
if st.session_state.translated_text and not st.session_state.audio_path:
    st.write(f"Translated text: {st.session_state.translated_text}")
elif st.session_state.audio_path:
    st.audio(st.session_state.audio_path, format="audio/wav")