JackismyShephard commited on
Commit
831b161
·
1 Parent(s): 7cbdcbc

use pipe abstraction for inference

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -2,13 +2,17 @@ import gradio as gr
2
  import numpy as np
3
  import torch
4
 
5
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
6
 
7
- checkpoint_base = "microsoft/speecht5_tts"
8
  checkpoint_finetuned = "JackismyShephard/speecht5_tts-finetuned-nst-da"
9
- processor = SpeechT5Processor.from_pretrained(checkpoint_base)
10
- model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint_finetuned)
11
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
 
 
 
 
 
12
 
13
  speaker_embeddings = {
14
  "F23": "embeddings/female_23_vestjylland.npy",
@@ -26,12 +30,6 @@ def predict(text, speaker):
26
 
27
  text = replace_danish_letters(text)
28
 
29
- inputs = processor(text=text, return_tensors="pt")
30
-
31
- # limit input length
32
- input_ids = inputs["input_ids"]
33
- # input_ids = input_ids[..., : model.config.max_text_positions]
34
-
35
  speaker_id = speaker[:3]
36
 
37
  speaker_embedding_path = speaker_embeddings[speaker_id]
@@ -40,10 +38,10 @@ def predict(text, speaker):
40
 
41
  speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
42
 
43
- speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
 
44
 
45
- speech = speech.numpy()
46
- return (16000, speech)
47
 
48
 
49
  def replace_danish_letters(text):
 
2
  import numpy as np
3
  import torch
4
 
5
+ from transformers import pipeline
6
 
 
7
  checkpoint_finetuned = "JackismyShephard/speecht5_tts-finetuned-nst-da"
8
+
9
+ pipe = pipeline(
10
+ "text-to-speech",
11
+ model=checkpoint_finetuned,
12
+ use_fast=True,
13
+ device=0 if torch.cuda.is_available() else "cpu",
14
+ )
15
+
16
 
17
  speaker_embeddings = {
18
  "F23": "embeddings/female_23_vestjylland.npy",
 
30
 
31
  text = replace_danish_letters(text)
32
 
 
 
 
 
 
 
33
  speaker_id = speaker[:3]
34
 
35
  speaker_embedding_path = speaker_embeddings[speaker_id]
 
38
 
39
  speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
40
 
41
+ forward_params = {"speaker_embeddings": speaker_embedding}
42
+ speech = pipe(text, forward_params=forward_params)
43
 
44
+ return (speech["sampling_rate"], speech["audio"])
 
45
 
46
 
47
  def replace_danish_letters(text):