Hecheng0625 commited on
Commit
451c794
β€’
1 Parent(s): 2c533c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -19,9 +19,23 @@ from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
19
 
20
  from transformers import SeamlessM4TFeatureExtractor
21
 
 
 
22
  processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
23
 
24
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def g2p_(text, language):
@@ -281,6 +295,9 @@ def maskgct_inference(
281
  speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
282
  speech = librosa.load(prompt_speech_path, sr=24000)[0]
283
 
 
 
 
284
  combine_semantic_code, _ = text2semantic(
285
  device,
286
  speech_16k,
@@ -352,7 +369,7 @@ iface = gr.Interface(
352
  fn=inference,
353
  inputs=[
354
  gr.Audio(label="Upload Prompt Wav", type="filepath"),
355
- gr.Textbox(label="Prompt Text"),
356
  gr.Textbox(label="Target Text"),
357
  gr.Number(
358
  label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1
 
19
 
20
  from transformers import SeamlessM4TFeatureExtractor
21
 
22
+ import whisperx
23
+
24
  processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
25
 
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
27
+
28
+ whisper_model = whisperx.load_model("small", "cuda", compute_type="int8")
29
+
30
+ @torch.no_grad()
31
+ def get_prompt_text(speech_16k):
32
+ asr_result = whisper_model.transcribe(speech_16k)
33
+ print("asr_result:", asr_result)
34
+ language = asr_result["language"]
35
+ #text = asr_result["text"] # whisper asr result
36
+ text = asr_result["segments"][0]["text"]
37
+ print("prompt text:", text)
38
+ return text, language
39
 
40
 
41
  def g2p_(text, language):
 
295
  speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
296
  speech = librosa.load(prompt_speech_path, sr=24000)[0]
297
 
298
+ if prompt_text is None:
299
+ prompt_text, language = get_prompt_text(prompt_speech_path)
300
+
301
  combine_semantic_code, _ = text2semantic(
302
  device,
303
  speech_16k,
 
369
  fn=inference,
370
  inputs=[
371
  gr.Audio(label="Upload Prompt Wav", type="filepath"),
372
+ gr.Textbox(label="Prompt Text, if None, the system will use an ASR model to detect prompt text and prompt language.", value=None),
373
  gr.Textbox(label="Target Text"),
374
  gr.Number(
375
  label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1