Hecheng0625 commited on
Commit
36f9ba6
β€’
1 Parent(s): fc42837

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -37
app.py CHANGED
@@ -19,23 +19,41 @@ from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
19
 
20
  from transformers import SeamlessM4TFeatureExtractor
21
 
22
- # import whisperx
 
23
 
24
  processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
25
 
26
  device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
27
 
28
- # whisper_model = whisperx.load_model("small", "cuda", compute_type="int8")
29
 
30
- # @torch.no_grad()
31
- # def get_prompt_text(speech_16k):
32
- # asr_result = whisper_model.transcribe(speech_16k)
33
- # print("asr_result:", asr_result)
34
- # language = asr_result["language"]
35
- # #text = asr_result["text"] # whisper asr result
36
- # text = asr_result["segments"][0]["text"]
37
- # print("prompt text:", text)
38
- # return text, language
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def g2p_(text, language):
@@ -279,10 +297,7 @@ def load_models():
279
  @torch.no_grad()
280
  def maskgct_inference(
281
  prompt_speech_path,
282
- prompt_text,
283
  target_text,
284
- language="en",
285
- target_language="en",
286
  target_len=None,
287
  n_timesteps=25,
288
  cfg=2.5,
@@ -295,13 +310,17 @@ def maskgct_inference(
295
  speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
296
  speech = librosa.load(prompt_speech_path, sr=24000)[0]
297
 
298
- # if prompt_text is None:
299
- # prompt_text, language = get_prompt_text(prompt_speech_path)
300
-
 
 
 
 
301
  combine_semantic_code, _ = text2semantic(
302
  device,
303
  speech_16k,
304
- prompt_text,
305
  language,
306
  target_text,
307
  target_language,
@@ -326,21 +345,15 @@ def maskgct_inference(
326
  @spaces.GPU
327
  def inference(
328
  prompt_wav,
329
- prompt_text,
330
  target_text,
331
  target_len,
332
  n_timesteps,
333
- language,
334
- target_language,
335
  ):
336
  save_path = "./output/output.wav"
337
  os.makedirs("./output", exist_ok=True)
338
  recovered_audio = maskgct_inference(
339
  prompt_wav,
340
- prompt_text,
341
  target_text,
342
- language,
343
- target_language,
344
  target_len=target_len,
345
  n_timesteps=int(n_timesteps),
346
  device=device,
@@ -369,7 +382,6 @@ iface = gr.Interface(
369
  fn=inference,
370
  inputs=[
371
  gr.Audio(label="Upload Prompt Wav", type="filepath"),
372
- gr.Textbox(label="Prompt Text"),
373
  gr.Textbox(label="Target Text"),
374
  gr.Number(
375
  label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1
@@ -377,23 +389,13 @@ iface = gr.Interface(
377
  gr.Slider(
378
  label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
379
  ),
380
- gr.Dropdown(label="Language", choices=language_list, value="en"),
381
- gr.Dropdown(label="Target Language", choices=language_list, value="en"),
382
  ],
383
  outputs=gr.Audio(label="Generated Audio"),
384
  title="MaskGCT TTS Demo",
385
  description="""
386
- ## MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer
387
-
388
- [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750)
389
-
390
- [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct)
391
-
392
- [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct)
393
-
394
- [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct)
395
  """
396
  )
397
 
398
  # Launch the interface
399
- iface.launch(allowed_paths=["./output"])
 
19
 
20
  from transformers import SeamlessM4TFeatureExtractor
21
 
22
+ import whisper
23
+ import langid
24
 
25
  processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
26
 
27
  device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
28
 
29
+ whisper_model = whisper.load_model("turbo")
30
 
31
+ def detect_speech_language(speech_16k):
32
+ return whisper_model.detect_language(speech_16k)
33
+
34
+ def detect_text_language(text):
35
+ langid.classify(text)[0]
36
+
37
+ @torch.no_grad()
38
+ def get_prompt_text(speech_16k, language):
39
+ full_prompt_text = ""
40
+ shot_prompt_text = ""
41
+ short_prompt_end_ts = 0.0
42
+
43
+ asr_result = whisper_model.transcribe(speech_16k, language=language)
44
+ print("asr_result:", asr_result)
45
+ full_prompt_text = asr_result["text"] # whisper asr result
46
+ #text = asr_result["segments"][0]["text"] # whisperx asr result
47
+ shot_prompt_text = ""
48
+ short_prompt_end_ts = 0.0
49
+ for segment in asr_result["segments"]:
50
+ shot_prompt_text = shot_prompt_text + segment['text']
51
+ short_prompt_end_ts = segment['end']
52
+ if short_prompt_end_ts >= 4:
53
+ break
54
+ print("full prompt text:", full_prompt_text, " shot_prompt_text:", shot_prompt_text,
55
+ "short_prompt_end_ts:", short_prompt_end_ts)
56
+ return full_prompt_text, shot_prompt_text, short_prompt_end_ts
57
 
58
 
59
  def g2p_(text, language):
 
297
  @torch.no_grad()
298
  def maskgct_inference(
299
  prompt_speech_path,
 
300
  target_text,
 
 
301
  target_len=None,
302
  n_timesteps=25,
303
  cfg=2.5,
 
310
  speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
311
  speech = librosa.load(prompt_speech_path, sr=24000)[0]
312
 
313
+ prompt_language = detect_speech_language(speech_16k)
314
+ full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(prompt_speech_path,
315
+ prompt_language)
316
+ # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
317
+ speech = speech[0: int(shot_prompt_end_ts * 24000)]
318
+ speech_16k = speech_16k[0: int(shot_prompt_end_ts*16000)]
319
+ target_language = detect_text_language(target_text)
320
  combine_semantic_code, _ = text2semantic(
321
  device,
322
  speech_16k,
323
+ short_prompt_text,
324
  language,
325
  target_text,
326
  target_language,
 
345
  @spaces.GPU
346
  def inference(
347
  prompt_wav,
 
348
  target_text,
349
  target_len,
350
  n_timesteps,
 
 
351
  ):
352
  save_path = "./output/output.wav"
353
  os.makedirs("./output", exist_ok=True)
354
  recovered_audio = maskgct_inference(
355
  prompt_wav,
 
356
  target_text,
 
 
357
  target_len=target_len,
358
  n_timesteps=int(n_timesteps),
359
  device=device,
 
382
  fn=inference,
383
  inputs=[
384
  gr.Audio(label="Upload Prompt Wav", type="filepath"),
 
385
  gr.Textbox(label="Target Text"),
386
  gr.Number(
387
  label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1
 
389
  gr.Slider(
390
  label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
391
  ),
 
 
392
  ],
393
  outputs=gr.Audio(label="Generated Audio"),
394
  title="MaskGCT TTS Demo",
395
  description="""
396
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct)
 
 
 
 
 
 
 
 
397
  """
398
  )
399
 
400
  # Launch the interface
401
+ iface.launch(allowed_paths=["./output"])