add new vocoder model and denoiser
- infer_onnx.py  +37 -5
- mel_spec_22khz_v2.onnx  +3 -0
infer_onnx.py
CHANGED
@@ -31,7 +31,7 @@ def process_text(i: int, text: str, device: torch.device):
 
 MODEL_PATH_MATCHA_MEL="matcha_multispeaker_cat_opset_15_10_steps.onnx"
 MODEL_PATH_MATCHA="matcha_hifigan_multispeaker_cat.onnx"
-MODEL_PATH_VOCOS="
+MODEL_PATH_VOCOS="mel_spec_22khz_v2.onnx"
 CONFIG_PATH="config_22khz.yaml"
 
 sess_options = onnxruntime.SessionOptions()
@@ -40,7 +40,7 @@ model_vocos = onnxruntime.InferenceSession(str(MODEL_PATH_VOCOS), sess_options=s
 model_matcha = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA), sess_options=sess_options, providers=["CPUExecutionProvider"])
 
 
-def vocos_inference(mel):
+def vocos_inference(mel,denoise):
 
     with open(CONFIG_PATH, "r") as f:
         config = yaml.safe_load(f)
@@ -63,6 +63,37 @@ def vocos_inference(mel):
     spectrogram = mag * (x + 1j * y)
     window = torch.hann_window(win_length)
 
+    if denoise:
+        # Vocoder bias
+        mel_rand = torch.zeros_like(torch.tensor(mel))
+        mag_bias, x_bias, y_bias = model_vocos.run(
+            None,
+            {
+                "mels": mel_rand.float().numpy()
+            },
+        )
+
+        # complex spectrogram from vocos output
+        spectrogram_bias = mag_bias * (x_bias + 1j * y_bias)
+
+        # Denoising
+        spec = torch.view_as_real(torch.tensor(spectrogram))
+        # get magnitude of vocos spectrogram
+        mag_spec = torch.sqrt(spec.pow(2).sum(-1))
+
+        # get magnitude of bias spectrogram
+        spec_bias = torch.view_as_real(torch.tensor(spectrogram_bias))
+        mag_spec_bias = torch.sqrt(spec_bias.pow(2).sum(-1))
+
+        # substract
+        strength = 0.0005
+        mag_spec_denoised = mag_spec - mag_spec_bias * strength
+        mag_spec_denoised = torch.clamp(mag_spec_denoised, 0.0)
+
+        # return to complex spectrogram from magnitude
+        angle = torch.atan2(spec[..., -1], spec[..., 0])
+        spectrogram = torch.complex(mag_spec_denoised * torch.cos(angle), mag_spec_denoised * torch.sin(angle))
+
     # Inverse stft
     pad = (win_length - hop_length) // 2
     spectrogram = torch.tensor(spectrogram)
@@ -92,7 +123,7 @@ def vocos_inference(mel):
     return y
 
 
-def tts(text:str, spk_id:int, temperature:float, length_scale:float):
+def tts(text:str, spk_id:int, temperature:float, length_scale:float, denoise:bool):
     sid = np.array([int(spk_id)]) if spk_id is not None else None
     text_matcha , text_lengths = process_text(0,text,"cpu")
 
@@ -111,7 +142,7 @@ def tts(text:str, spk_id:int, temperature:float, length_scale:float):
 
     vocos_t0 = perf_counter()
     # vocos inference
-    wavs_vocos = vocos_inference(mel)
+    wavs_vocos = vocos_inference(mel,denoise)
     vocos_infer_secs = perf_counter() - vocos_t0
     print("Vocos inference time", vocos_infer_secs)
 
@@ -193,7 +224,8 @@ vits2_inference = gr.Interface(
             step=0.01,
            label="Length scale",
            info=f"Controls speech pace, larger values for slower pace and smaller values for faster pace",
-        )
+        ),
+        gr.Checkbox(label="Denoise", info="Removes model bias from vocos"),
     ],
     outputs=[gr.Audio(label="Matcha vocos", interactive=False, type="filepath"),
              gr.Audio(label="Matcha hifigan", interactive=False, type="filepath")]
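The denoising branch added above is the usual "vocoder bias" trick: run the vocoder once on an all-zero mel to estimate the spectral floor it produces on its own, then subtract a small fraction of that bias magnitude from the synthesized spectrogram while keeping the original phase. Below is a minimal standalone sketch of that spectral-subtraction step; the function name and tensor shapes are illustrative, not part of the repo, and `strength` just mirrors the constant used in vocos_inference.

import torch

def denoise_spectrogram(spectrogram: torch.Tensor,
                        spectrogram_bias: torch.Tensor,
                        strength: float = 0.0005) -> torch.Tensor:
    """Spectral subtraction: remove a scaled bias magnitude, keep the original phase.

    Both inputs are complex tensors, e.g. (batch, freq_bins, frames);
    `spectrogram_bias` is the vocoder output for an all-zero mel.
    """
    # magnitude and phase of the vocoder output
    mag = torch.abs(spectrogram)
    phase = torch.angle(spectrogram)

    # magnitude of the bias estimate
    mag_bias = torch.abs(spectrogram_bias)

    # subtract the scaled bias and clamp so magnitudes stay non-negative
    mag_denoised = torch.clamp(mag - strength * mag_bias, min=0.0)

    # rebuild the complex spectrogram with the unchanged phase
    return torch.polar(mag_denoised, phase)

The result then goes through the same inverse STFT as the undenoised spectrogram, so the only change is a slightly lower noise floor in silent regions.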
mel_spec_22khz_v2.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b02c479881f89a8320024436e986f64b11e82b1fd48046d4b695c5fd9fb84e7
+size 53883652
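mel_spec_22khz_v2.onnx is tracked with Git LFS, so only the pointer above lives in the repo; the roughly 54 MB model is fetched on checkout. A hedged sketch of loading the new vocoder with onnxruntime and rebuilding the complex spectrogram, assuming the single "mels" input and (mag, x, y) outputs used by vocos_inference (the mel dimensions below are illustrative, not taken from the model):

import numpy as np
import onnxruntime

sess_options = onnxruntime.SessionOptions()
model_vocos = onnxruntime.InferenceSession(
    "mel_spec_22khz_v2.onnx",
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)

# illustrative input: 1 utterance, 80 mel bins, 200 frames (bin count is an assumption)
mel = np.random.randn(1, 80, 200).astype(np.float32)

mag, x, y = model_vocos.run(None, {"mels": mel})
spectrogram = mag * (x + 1j * y)  # complex spectrogram, as built in vocos_inference
print(spectrogram.shape)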