fargerm committed
Commit 5403777
1 Parent(s): b601058

Update app.py

Files changed (1)
  1. app.py +34 -18
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
 from transformers import MarianMTModel, MarianTokenizer
 import soundfile as sf
+from datasets import load_dataset
 
 # Device setup
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -28,7 +29,6 @@ tts_pipe = pipeline("text-to-speech", "microsoft/speecht5_tts")
 
 # Load speaker embeddings
 def get_speaker_embedding():
-    from datasets import load_dataset
     dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
     # Use the first sample's embedding as an example
     speaker_embedding = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)
@@ -45,38 +45,54 @@ def load_translation_model(lang_code):
 
 st.title("TextLangAudioGenerator")
 
-# Text input box
-text_input = st.text_area("Enter text in English")
+# Text input
+text_input = st.text_area("Enter text in English", key="text_input")
 
-# Language selection dropdown
+# Select target language
 target_lang = st.selectbox(
     "Select target language",
     ["fr", "zh", "it", "ur", "hi"],  # Add more language codes as needed
-    format_func=lambda x: {"fr": "French", "zh": "Chinese", "it": "Italian", "ur": "Urdu", "hi": "Hindi"}.get(x, x)
+    format_func=lambda x: {"fr": "French", "zh": "Chinese", "it": "Italian", "ur": "Urdu", "hi": "Hindi"}.get(x, x),
+    key="target_lang"
 )
 
-# Buttons for actions
-if st.button("Translate and Generate Audio"):
+# Initialize session state for storing results
+if "translated_text" not in st.session_state:
+    st.session_state.translated_text = ""
+if "audio_path" not in st.session_state:
+    st.session_state.audio_path = ""
+
+# Submit button
+if st.button("Submit"):
     if text_input and target_lang:
         # Load translation model
         model, tokenizer = load_translation_model(target_lang)
         inputs = tokenizer(text_input, return_tensors="pt")
         translated = model.generate(**inputs)
-        translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
-        st.write(f"Translated text: {translated_text}")
+        st.session_state.translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
+        st.write(f"Translated text: {st.session_state.translated_text}")
 
+# Listen to Translated Audio button
+if st.button("Listen to Translated Audio"):
+    if st.session_state.translated_text:
         # Generate TTS
-        speech = tts_pipe(translated_text, forward_params={"speaker_embeddings": speaker_embedding})
-        audio_path = "translated_speech.wav"
-        sf.write(audio_path, speech["audio"], samplerate=speech["sampling_rate"])
-
-        st.audio(audio_path, format="audio/wav")
+        speech = tts_pipe(st.session_state.translated_text, forward_params={"speaker_embeddings": speaker_embedding})
+        st.session_state.audio_path = "translated_speech.wav"
+        sf.write(st.session_state.audio_path, speech["audio"], samplerate=speech["sampling_rate"])
+        st.audio(st.session_state.audio_path, format="audio/wav")
     else:
-        st.error("Please enter text and select a target language.")
+        st.error("Please submit the text first.")
 
-# Optional: Add a button for additional functionality if needed
+# Reset button
 if st.button("Reset"):
-    st.text_area("Enter text in English", value="")
-    st.selectbox("Select target language", ["fr", "zh", "it", "ur", "hi"])
+    st.session_state.translated_text = ""
+    st.session_state.audio_path = ""
+    st.experimental_rerun()  # Reload the app to reset the inputs
+
+# Display current state of translated text and audio
+if st.session_state.translated_text:
+    st.write(f"Translated text: {st.session_state.translated_text}")
+if st.session_state.audio_path:
+    st.audio(st.session_state.audio_path, format="audio/wav")
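The substance of this commit is the move to st.session_state. Streamlit re-executes the whole script on every widget interaction, so a plain local such as the old translated_text vanished on the rerun triggered by the next click; persisting results in st.session_state is what lets the new "Listen to Translated Audio" button see the text produced under "Submit". A minimal sketch of the pattern, independent of this app:

import streamlit as st

# Each button click triggers a full script rerun: locals are lost, session_state survives.
if "result" not in st.session_state:
    st.session_state.result = ""

if st.button("Compute"):
    st.session_state.result = "computed on this rerun"

if st.button("Show"):
    st.write(st.session_state.result)  # still set, even though Compute ran on an earlier rerun

One caveat: st.experimental_rerun(), used in the Reset handler, is deprecated in newer Streamlit releases in favor of st.rerun(), and on versions where it has been removed entirely the Reset button will fail.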
 
 
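The last hunk header references load_translation_model(lang_code), whose body lies outside the diff context. Given the MarianMTModel and MarianTokenizer imports at the top of app.py, a typical implementation maps the language code onto an OPUS-MT checkpoint; the naming scheme below is an assumption for illustration, not something this commit shows:

from transformers import MarianMTModel, MarianTokenizer

def load_translation_model(lang_code):
    # Assumed checkpoint family: Helsinki-NLP publishes opus-mt-en-{fr,zh,it,ur,hi} models.
    model_name = f"Helsinki-NLP/opus-mt-en-{lang_code}"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return model, tokenizer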
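The TTS path itself is unchanged apart from reading its input out of session state. For reference, the calls it relies on can be exercised outside Streamlit; this standalone sketch uses the same checkpoints as app.py and the same output contract (the text-to-speech pipeline returns a dict with "audio" and "sampling_rate" keys):

import soundfile as sf
import torch
from datasets import load_dataset
from transformers import pipeline

# SpeechT5 TTS conditioned on a CMU ARCTIC x-vector as the speaker voice.
tts_pipe = pipeline("text-to-speech", "microsoft/speecht5_tts")
dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)  # shape (1, 512)

speech = tts_pipe("Hello world", forward_params={"speaker_embeddings": speaker_embedding})
sf.write("translated_speech.wav", speech["audio"], samplerate=speech["sampling_rate"])

Worth noting: microsoft/speecht5_tts was trained on English speech, so synthesizing the translated French, Chinese, Italian, Urdu, or Hindi text will likely produce poor audio; a multilingual TTS model would suit the output side of this app better.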