Spaces:
Runtime error
Runtime error
ikechan8370
commited on
Commit
•
b772f7c
1
Parent(s):
7e90749
feat: add support for gpu
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
# coding=utf-8
|
2 |
import time
|
3 |
import gradio as gr
|
4 |
import utils
|
@@ -6,14 +5,16 @@ import commons
|
|
6 |
from models import SynthesizerTrn
|
7 |
from text import text_to_sequence
|
8 |
from torch import no_grad, LongTensor
|
|
|
9 |
|
10 |
hps_ms = utils.get_hparams_from_file(r'./model/config.json')
|
|
|
11 |
net_g_ms = SynthesizerTrn(
|
12 |
len(hps_ms.symbols),
|
13 |
hps_ms.data.filter_length // 2 + 1,
|
14 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
15 |
n_speakers=hps_ms.data.n_speakers,
|
16 |
-
**hps_ms.model)
|
17 |
_ = net_g_ms.eval()
|
18 |
speakers = hps_ms.speakers
|
19 |
model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
|
@@ -30,7 +31,7 @@ def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
|
|
30 |
if not len(text):
|
31 |
return "输入文本不能为空!", None, None
|
32 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
33 |
-
if len(text) >
|
34 |
return f"输入文字过长!{len(text)}>100", None, None
|
35 |
if language == 0:
|
36 |
text = f"[ZH]{text}[ZH]"
|
@@ -44,7 +45,7 @@ def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
|
|
44 |
x_tst_lengths = LongTensor([stn_tst.size(0)])
|
45 |
speaker_id = LongTensor([speaker_id])
|
46 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
47 |
-
length_scale=length_scale)[0][0, 0].data.float().numpy()
|
48 |
|
49 |
return "生成成功!", (22050, audio), f"生成耗时 {round(time.perf_counter()-start, 2)} s"
|
50 |
|
@@ -116,8 +117,8 @@ if __name__ == '__main__':
|
|
116 |
download = gr.Button("Download Audio")
|
117 |
btn.click(vits, inputs=[input_text, lang, sid, ns, nsw, ls], outputs=[o1, o2, o3], api_name="generate")
|
118 |
download.click(None, [], [], _js=download_audio_js.format())
|
119 |
-
btn2.click(search_speaker, inputs=[search], outputs=[sid]
|
120 |
-
lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls]
|
121 |
with gr.TabItem("可用人物一览"):
|
122 |
gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
|
123 |
-
app.queue(concurrency_count=1).launch()
|
|
|
|
|
1 |
import time
|
2 |
import gradio as gr
|
3 |
import utils
|
|
|
5 |
from models import SynthesizerTrn
|
6 |
from text import text_to_sequence
|
7 |
from torch import no_grad, LongTensor
|
8 |
+
import torch
|
9 |
|
10 |
hps_ms = utils.get_hparams_from_file(r'./model/config.json')
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
net_g_ms = SynthesizerTrn(
|
13 |
len(hps_ms.symbols),
|
14 |
hps_ms.data.filter_length // 2 + 1,
|
15 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
16 |
n_speakers=hps_ms.data.n_speakers,
|
17 |
+
**hps_ms.model).to(device)
|
18 |
_ = net_g_ms.eval()
|
19 |
speakers = hps_ms.speakers
|
20 |
model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
|
|
|
31 |
if not len(text):
|
32 |
return "输入文本不能为空!", None, None
|
33 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
34 |
+
if len(text) > 500:
|
35 |
return f"输入文字过长!{len(text)}>100", None, None
|
36 |
if language == 0:
|
37 |
text = f"[ZH]{text}[ZH]"
|
|
|
45 |
x_tst_lengths = LongTensor([stn_tst.size(0)])
|
46 |
speaker_id = LongTensor([speaker_id])
|
47 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
48 |
+
length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
|
49 |
|
50 |
return "生成成功!", (22050, audio), f"生成耗时 {round(time.perf_counter()-start, 2)} s"
|
51 |
|
|
|
117 |
download = gr.Button("Download Audio")
|
118 |
btn.click(vits, inputs=[input_text, lang, sid, ns, nsw, ls], outputs=[o1, o2, o3], api_name="generate")
|
119 |
download.click(None, [], [], _js=download_audio_js.format())
|
120 |
+
btn2.click(search_speaker, inputs=[search], outputs=[sid])
|
121 |
+
lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
|
122 |
with gr.TabItem("可用人物一览"):
|
123 |
gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
|
124 |
+
app.queue(concurrency_count=1).launch()
|