None1145 committed
Commit eea5553 · verified · Parent(s): ef1f503

Update app.py

Files changed (1)
  1. app.py +312 -306
app.py CHANGED
@@ -1,307 +1,313 @@
- import os
- from huggingface_hub import snapshot_download
- snapshot_download(repo_id="None1145/GPT-SoVITS-Lappland-the-Decadenza", cache_dir="./Models")
- snapshot_download(repo_id="None1145/GPT-SoVITS-Base", cache_dir="./PretrainedModels")
- cnhubert_base_path = "PretrainedModels/chinese-hubert-base"
- bert_path = "PretrainedModels/chinese-roberta-wwm-ext-large"
-
- import gradio as gr
- from transformers import AutoModelForMaskedLM, AutoTokenizer
- import sys,torch,numpy as np
- from pathlib import Path
- import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
- # torch.backends.cuda.sdp_kernel("flash")
- # torch.backends.cuda.enable_flash_sdp(True)
- # torch.backends.cuda.enable_mem_efficient_sdp(True) # Not available if torch version is lower than 2.0
- # torch.backends.cuda.enable_math_sdp(True)
- from random import shuffle
- from AR.utils import get_newest_ckpt
- from glob import glob
- from tqdm import tqdm
- from feature_extractor import cnhubert
- cnhubert.cnhubert_base_path=cnhubert_base_path
- from io import BytesIO
- from module.models import SynthesizerTrn
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
- from AR.utils.io import load_yaml_config
- from text import cleaned_text_to_sequence
- from text.cleaner import text_to_sequence, clean_text
- from time import time as ttime
- from module.mel_processing import spectrogram_torch
- from my_utils import load_audio
- import re
-
- import logging
- logging.getLogger('httpx').setLevel(logging.WARNING)
- logging.getLogger('httpcore').setLevel(logging.WARNING)
- logging.getLogger('multipart').setLevel(logging.WARNING)
-
- device = "cpu"
- is_half = False
-
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
- bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
- if(is_half==True):bert_model=bert_model.half().to(device)
- else:bert_model=bert_model.to(device)
- # bert_model=bert_model.to(device)
- def get_bert_feature(text, word2ph):
-     with torch.no_grad():
-         inputs = tokenizer(text, return_tensors="pt")
-         for i in inputs:
-             inputs[i] = inputs[i].to(device)  ##### inputs are long ints, so precision is not a concern; it follows bert_model
-         res = bert_model(**inputs, output_hidden_states=True)
-         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
-     assert len(word2ph) == len(text)
-     phone_level_feature = []
-     for i in range(len(word2ph)):
-         repeat_feature = res[i].repeat(word2ph[i], 1)
-         phone_level_feature.append(repeat_feature)
-     phone_level_feature = torch.cat(phone_level_feature, dim=0)
-     # if(is_half==True):phone_level_feature=phone_level_feature.half()
-     return phone_level_feature.T
-
-
- def load_model(sovits_path, gpt_path):
-     n_semantic = 1024
-     dict_s2 = torch.load(sovits_path, map_location="cpu")
-     hps = dict_s2["config"]
-
-     class DictToAttrRecursive:
-         def __init__(self, input_dict):
-             for key, value in input_dict.items():
-                 if isinstance(value, dict):
-                     # if the value is a dict, recursively call the constructor
-                     setattr(self, key, DictToAttrRecursive(value))
-                 else:
-                     setattr(self, key, value)
-
-     hps = DictToAttrRecursive(hps)
-     hps.model.semantic_frame_rate = "25hz"
-     dict_s1 = torch.load(gpt_path, map_location="cpu")
-     config = dict_s1["config"]
-     ssl_model = cnhubert.get_model()
-     if (is_half == True):
-         ssl_model = ssl_model.half().to(device)
-     else:
-         ssl_model = ssl_model.to(device)
-
-     vq_model = SynthesizerTrn(
-         hps.data.filter_length // 2 + 1,
-         hps.train.segment_size // hps.data.hop_length,
-         n_speakers=hps.data.n_speakers,
-         **hps.model)
-     if (is_half == True):
-         vq_model = vq_model.half().to(device)
-     else:
-         vq_model = vq_model.to(device)
-     vq_model.eval()
-     vq_model.load_state_dict(dict_s2["weight"], strict=False)
-     hz = 50
-     max_sec = config['data']['max_sec']
-     # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
-     t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
-     t2s_model.load_state_dict(dict_s1["weight"])
-     if (is_half == True): t2s_model = t2s_model.half()
-     t2s_model = t2s_model.to(device)
-     t2s_model.eval()
-     total = sum([param.nelement() for param in t2s_model.parameters()])
-     print("Number of parameter: %.2fM" % (total / 1e6))
-     return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
-
-
- def get_spepc(hps, filename):
-     audio=load_audio(filename,int(hps.data.sampling_rate))
-     audio=torch.FloatTensor(audio)
-     audio_norm = audio
-     audio_norm = audio_norm.unsqueeze(0)
-     spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
-     return spec
-
-
- def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
-     def tts_fn(ref_wav_path, prompt_text, prompt_language, text, text_language):
-         t0 = ttime()
-         prompt_text=prompt_text.strip("\n")
-         prompt_language,text=prompt_language,text.strip("\n")
-         print(text)
-         if len(text) > 50:
-             return f"Error: Text is too long, ({len(text)}>50)", None
-         with torch.no_grad():
-             wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # Paimon
-             wav16k = torch.from_numpy(wav16k)
-             if(is_half==True):wav16k=wav16k.half().to(device)
-             else:wav16k=wav16k.to(device)
-             ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
-             codes = vq_model.extract_latent(ssl_content)
-             prompt_semantic = codes[0, 0]
-         t1 = ttime()
-         phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
-         phones1=cleaned_text_to_sequence(phones1)
-         texts=text.split("\n")
-         audio_opt = []
-         zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
-         for text in texts:
-             phones2, word2ph2, norm_text2 = clean_text(text, text_language)
-             phones2 = cleaned_text_to_sequence(phones2)
-             if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
-             else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
-             if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
-             else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
-             bert = torch.cat([bert1, bert2], 1)
-
-             all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
-             bert = bert.to(device).unsqueeze(0)
-             all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-             prompt = prompt_semantic.unsqueeze(0).to(device)
-             t2 = ttime()
-             with torch.no_grad():
-                 # pred_semantic = t2s_model.model.infer(
-                 pred_semantic,idx = t2s_model.model.infer_panel(
-                     all_phoneme_ids,
-                     all_phoneme_len,
-                     prompt,
-                     bert,
-                     # prompt_phone_len=ph_offset,
-                     top_k=config['inference']['top_k'],
-                     early_stop_num=hz * max_sec)
-             t3 = ttime()
-             # print(pred_semantic.shape,idx)
-             pred_semantic = pred_semantic[:,-idx:].unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze
-             refer = get_spepc(hps, ref_wav_path)#.to(device)
-             if(is_half==True):refer=refer.half().to(device)
-             else:refer=refer.to(device)
-             # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
-             audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]  ### try reconstructing without including the prompt part
-             audio_opt.append(audio)
-             audio_opt.append(zero_wav)
-         t4 = ttime()
-         print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-         return "Success", (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16))
-     return tts_fn
-
-
- splits={"","。","?","!",",",".","?","!","~",":",":","—","…",}  # ellipsis not handled separately
- def split(todo_text):
-     todo_text = todo_text.replace("……", "。").replace("——", ",")
-     if (todo_text[-1] not in splits): todo_text += "。"
-     i_split_head = i_split_tail = 0
-     len_text = len(todo_text)
-     todo_texts = []
-     while (1):
-         if (i_split_head >= len_text): break  # the text always ends with punctuation, so just break; the last segment was already appended
-         if (todo_text[i_split_head] in splits):
-             i_split_head += 1
-             todo_texts.append(todo_text[i_split_tail:i_split_head])
-             i_split_tail = i_split_head
-         else:
-             i_split_head += 1
-     return todo_texts
-
-
- def change_reference_audio(prompt_text, transcripts):
-     return transcripts[prompt_text]
-
- models = []
- models_info = {}
- models_folder_path = "./Models/None1145"
- folder_names = [name for name in os.listdir(models_folder_path) if os.path.isdir(os.path.join(models_folder_path, name))]
- for folder_name in folder_names:
-     speaker = folder_name[11:]
-     models_info[speaker] = {}
-     models_info[speaker]["title"] = speaker
-     pattern = re.compile(r"s(\d+)\.pth$")
-     max_value = -1
-     max_file = None
-     sovits_path = f"{models_folder_path}/{folder_name}/SoVITS_weights"
-     for filename in os.listdir(sovits_path):
-         match = pattern.search(filename)
-         if match:
-             value = int(match.group(1))
-             if value > max_value:
-                 max_value = value
-                 max_file = filename
-     models_info[speaker]["sovits_weight"] = f"{sovits_path}/{max_file}"
-     pattern = re.compile(r"e(\d+)\.ckpt$")
-     max_value = -1
-     max_file = None
-     gpt_path = f"{models_folder_path}/{folder_name}/GPT_weights"
-     for filename in os.listdir(gpt_path):
-         match = pattern.search(filename)
-         if match:
-             value = int(match.group(1))
-             if value > max_value:
-                 max_value = value
-                 max_file = filename
-     models_info[speaker]["gpt_weight"] = f"{gpt_path}/{max_file}"
-     data_path = f"{models_folder_path}/{folder_name}/Data"
-     models_info[speaker]["transcript"] = {}
-     with open(f"{data_path}/{speaker}.list", "r", encoding="utf-8") as f:
-         for line in f.read().split("\n"):
-             wav = f"{models_folder_path}/{folder_name}/Data/{line.split("|")[0].split("/")[1]}"
-             text = line.split("|")[3]
-             models_info[speaker]["transcript"][text] = wav
-     models_info[speaker]["example_reference"] = text
- for speaker in models_info:
-     speaker_info = models_info[speaker]
-     title = speaker_info["title"]
-     sovits_weight = speaker_info["sovits_weight"]
-     gpt_weight = speaker_info["gpt_weight"]
-     vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
-     models.append(
-         (
-             speaker,
-             title,
-             speaker_info["transcript"],
-             speaker_info["example_reference"],
-             create_tts_fn(
-                 vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
-             )
-         )
-     )
-
- with gr.Blocks() as app:
-     with gr.Tabs():
-         for (name, title, transcript, example_reference, tts_fn) in models:
-             with gr.TabItem(name):
-                 with gr.Row():
-                     gr.Markdown(
-                         '<div align="center">'
-                         f'<a><strong>{title}</strong></a>'
-                         '</div>')
-                 with gr.Row():
-                     with gr.Column():
-                         prompt_text = gr.Dropdown(
-                             label="Transcript of the Reference Audio",
-                             value=example_reference,
-                             choices=list(transcript.keys())
-                         )
-                         inp_ref_audio = gr.Audio(
-                             label="Reference Audio",
-                             type="filepath",
-                             interactive=False,
-                             value=transcript[example_reference]
-                         )
-                         transcripts_state = gr.State(value=transcript)
-                         prompt_text.change(
-                             fn=change_reference_audio,
-                             inputs=[prompt_text, transcripts_state],
-                             outputs=[inp_ref_audio]
-                         )
-                         prompt_language = gr.State(value="zh")
-                     with gr.Column():
-                         text = gr.Textbox(label="Input Text", value="你好。")
-                         text_language = gr.Dropdown(
-                             label="Language",
-                             choices=["zh", "en", "ja"],
-                             value="ja"
-                         )
-                         inference_button = gr.Button("Generate", variant="primary")
-                         om = gr.Textbox(label="Output Message")
-                         output = gr.Audio(label="Output Audio")
-                         inference_button.click(
-                             fn=tts_fn,
-                             inputs=[inp_ref_audio, prompt_text, prompt_language, text, text_language],
-                             outputs=[om, output]
-                         )
-
+ import os
+ from huggingface_hub import snapshot_download
+ print("Models...")
+ snapshot_download(repo_id="None1145/GPT-SoVITS-Lappland-the-Decadenza", cache_dir="./Models")
+ print("Models!!!")
+ print("PretrainedModels...")
+ snapshot_download(repo_id="None1145/GPT-SoVITS-Base", cache_dir="./PretrainedModels")
+ print("PretrainedModels!!!")
+ cnhubert_base_path = "PretrainedModels/chinese-hubert-base"
+ bert_path = "PretrainedModels/chinese-roberta-wwm-ext-large"
+
+ import gradio as gr
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ import sys,torch,numpy as np
+ from pathlib import Path
+ import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
+ # torch.backends.cuda.sdp_kernel("flash")
+ # torch.backends.cuda.enable_flash_sdp(True)
+ # torch.backends.cuda.enable_mem_efficient_sdp(True) # Not available if torch version is lower than 2.0
+ # torch.backends.cuda.enable_math_sdp(True)
+ from random import shuffle
+ from AR.utils import get_newest_ckpt
+ from glob import glob
+ from tqdm import tqdm
+ from feature_extractor import cnhubert
+ cnhubert.cnhubert_base_path=cnhubert_base_path
+ from io import BytesIO
+ from module.models import SynthesizerTrn
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+ from AR.utils.io import load_yaml_config
+ from text import cleaned_text_to_sequence
+ from text.cleaner import text_to_sequence, clean_text
+ from time import time as ttime
+ from module.mel_processing import spectrogram_torch
+ from my_utils import load_audio
+ import re
+
+ import logging
+ logging.getLogger('httpx').setLevel(logging.WARNING)
+ logging.getLogger('httpcore').setLevel(logging.WARNING)
+ logging.getLogger('multipart').setLevel(logging.WARNING)
+
+ device = "cpu"
+ is_half = False
+
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
+ bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
+ if(is_half==True):bert_model=bert_model.half().to(device)
+ else:bert_model=bert_model.to(device)
+ # bert_model=bert_model.to(device)
+ def get_bert_feature(text, word2ph):
+     with torch.no_grad():
+         inputs = tokenizer(text, return_tensors="pt")
+         for i in inputs:
+             inputs[i] = inputs[i].to(device)  ##### inputs are long ints, so precision is not a concern; it follows bert_model
+         res = bert_model(**inputs, output_hidden_states=True)
+         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+     assert len(word2ph) == len(text)
+     phone_level_feature = []
+     for i in range(len(word2ph)):
+         repeat_feature = res[i].repeat(word2ph[i], 1)
+         phone_level_feature.append(repeat_feature)
+     phone_level_feature = torch.cat(phone_level_feature, dim=0)
+     # if(is_half==True):phone_level_feature=phone_level_feature.half()
+     return phone_level_feature.T
+
+
+ def load_model(sovits_path, gpt_path):
+     n_semantic = 1024
+     dict_s2 = torch.load(sovits_path, map_location="cpu")
+     hps = dict_s2["config"]
+
+     class DictToAttrRecursive:
+         def __init__(self, input_dict):
+             for key, value in input_dict.items():
+                 if isinstance(value, dict):
+                     # if the value is a dict, recursively call the constructor
+                     setattr(self, key, DictToAttrRecursive(value))
+                 else:
+                     setattr(self, key, value)
+
+     hps = DictToAttrRecursive(hps)
+     hps.model.semantic_frame_rate = "25hz"
+     dict_s1 = torch.load(gpt_path, map_location="cpu")
+     config = dict_s1["config"]
+     ssl_model = cnhubert.get_model()
+     if (is_half == True):
+         ssl_model = ssl_model.half().to(device)
+     else:
+         ssl_model = ssl_model.to(device)
+
+     vq_model = SynthesizerTrn(
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         **hps.model)
+     if (is_half == True):
+         vq_model = vq_model.half().to(device)
+     else:
+         vq_model = vq_model.to(device)
+     vq_model.eval()
+     vq_model.load_state_dict(dict_s2["weight"], strict=False)
+     hz = 50
+     max_sec = config['data']['max_sec']
+     # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
+     t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
+     t2s_model.load_state_dict(dict_s1["weight"])
+     if (is_half == True): t2s_model = t2s_model.half()
+     t2s_model = t2s_model.to(device)
+     t2s_model.eval()
+     total = sum([param.nelement() for param in t2s_model.parameters()])
+     print("Number of parameter: %.2fM" % (total / 1e6))
+     return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
+
+
+ def get_spepc(hps, filename):
+     audio=load_audio(filename,int(hps.data.sampling_rate))
+     audio=torch.FloatTensor(audio)
+     audio_norm = audio
+     audio_norm = audio_norm.unsqueeze(0)
+     spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
+     return spec
+
+
+ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
+     def tts_fn(ref_wav_path, prompt_text, prompt_language, text, text_language):
+         t0 = ttime()
+         prompt_text=prompt_text.strip("\n")
+         prompt_language,text=prompt_language,text.strip("\n")
+         print(text)
+         if len(text) > 50:
+             return f"Error: Text is too long, ({len(text)}>50)", None
+         with torch.no_grad():
+             wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # Paimon
+             wav16k = torch.from_numpy(wav16k)
+             if(is_half==True):wav16k=wav16k.half().to(device)
+             else:wav16k=wav16k.to(device)
+             ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
+             codes = vq_model.extract_latent(ssl_content)
+             prompt_semantic = codes[0, 0]
+         t1 = ttime()
+         phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
+         phones1=cleaned_text_to_sequence(phones1)
+         texts=text.split("\n")
+         audio_opt = []
+         zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
+         for text in texts:
+             phones2, word2ph2, norm_text2 = clean_text(text, text_language)
+             phones2 = cleaned_text_to_sequence(phones2)
+             if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
+             else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
+             if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
+             else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
+             bert = torch.cat([bert1, bert2], 1)
+
+             all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+             bert = bert.to(device).unsqueeze(0)
+             all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+             prompt = prompt_semantic.unsqueeze(0).to(device)
+             t2 = ttime()
+             with torch.no_grad():
+                 # pred_semantic = t2s_model.model.infer(
+                 pred_semantic,idx = t2s_model.model.infer_panel(
+                     all_phoneme_ids,
+                     all_phoneme_len,
+                     prompt,
+                     bert,
+                     # prompt_phone_len=ph_offset,
+                     top_k=config['inference']['top_k'],
+                     early_stop_num=hz * max_sec)
+             t3 = ttime()
+             # print(pred_semantic.shape,idx)
+             pred_semantic = pred_semantic[:,-idx:].unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze
+             refer = get_spepc(hps, ref_wav_path)#.to(device)
+             if(is_half==True):refer=refer.half().to(device)
+             else:refer=refer.to(device)
+             # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
+             audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]  ### try reconstructing without including the prompt part
+             audio_opt.append(audio)
+             audio_opt.append(zero_wav)
+         t4 = ttime()
+         print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+         return "Success", (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16))
+     return tts_fn
+
186
+
187
+ splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
188
+ def split(todo_text):
189
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
190
+ if (todo_text[-1] not in splits): todo_text += "。"
191
+ i_split_head = i_split_tail = 0
192
+ len_text = len(todo_text)
193
+ todo_texts = []
194
+ while (1):
195
+ if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
196
+ if (todo_text[i_split_head] in splits):
197
+ i_split_head += 1
198
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
199
+ i_split_tail = i_split_head
200
+ else:
201
+ i_split_head += 1
202
+ return todo_texts
203
+
204
+
205
+ def change_reference_audio(prompt_text, transcripts):
206
+ return transcripts[prompt_text]
207
+
208
+ models = []
209
+ models_info = {}
210
+ models_folder_path = "./Models/None1145"
211
+ folder_names = [name for name in os.listdir(models_folder_path) if os.path.isdir(os.path.join(models_folder_path, name))]
212
+ for folder_name in folder_names:
213
+ speaker = folder_name[11:]
214
+ models_info[speaker] = {}
215
+ models_info[speaker]["title"] = speaker
216
+ pattern = re.compile(r"s(\d+)\.pth$")
217
+ max_value = -1
218
+ max_file = None
219
+ sovits_path = f"{models_folder_path}/{folder_name}/SoVITS_weights"
220
+ for filename in os.listdir(sovits_path):
221
+ match = pattern.search(filename)
222
+ if match:
223
+ value = int(match.group(1))
224
+ if value > max_value:
225
+ max_value = value
226
+ max_file = filename
227
+ models_info[speaker]["sovits_weight"] = f"{sovits_path}/{max_file}"
228
+ pattern = re.compile(r"e(\d+)\.ckpt$")
229
+ max_value = -1
230
+ max_file = None
231
+ gpt_path = f"{models_folder_path}/{folder_name}/GPT_weights"
232
+ for filename in os.listdir(gpt_path):
233
+ match = pattern.search(filename)
234
+ if match:
235
+ value = int(match.group(1))
236
+ if value > max_value:
237
+ max_value = value
238
+ max_file = filename
239
+ models_info[speaker]["gpt_weight"] = f"{gpt_path}/{max_file}"
240
+ data_path = f"{models_folder_path}/{folder_name}/Data"
241
+ models_info[speaker]["transcript"] = {}
242
+ with open(f"{data_path}/{speaker}.list", "r", encoding="utf-8") as f:
243
+ for line in f.read().split("\n"):
244
+ wav = f"{models_folder_path}/{folder_name}/Data/{line.split('|')[0].split('/')[1]}"
245
+ text = line.split("|")[3]
246
+ models_info[speaker]["transcript"][text] = wav
247
+ models_info[speaker]["example_reference"] = text
248
+ print(models_info)
249
+ for speaker in models_info:
250
+ speaker_info = models_info[speaker]
251
+ title = speaker_info["title"]
252
+ sovits_weight = speaker_info["sovits_weight"]
253
+ gpt_weight = speaker_info["gpt_weight"]
254
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
255
+ models.append(
256
+ (
257
+ speaker,
258
+ title,
259
+ speaker_info["transcript"],
260
+ speaker_info["example_reference"],
261
+ create_tts_fn(
262
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
263
+ )
264
+ )
265
+ )
266
+ print(models)
267
+
268
+ with gr.Blocks() as app:
269
+ with gr.Tabs():
270
+ for (name, title, transcript, example_reference, tts_fn) in models:
271
+ with gr.TabItem(name):
272
+ with gr.Row():
273
+ gr.Markdown(
274
+ '<div align="center">'
275
+ f'<a><strong>{title}</strong></a>'
276
+ '</div>')
277
+ with gr.Row():
278
+ with gr.Column():
279
+ prompt_text = gr.Dropdown(
280
+ label="Transcript of the Reference Audio",
281
+ value=example_reference,
282
+ choices=list(transcript.keys())
283
+ )
284
+ inp_ref_audio = gr.Audio(
285
+ label="Reference Audio",
286
+ type="filepath",
287
+ interactive=False,
288
+ value=transcript[example_reference]
289
+ )
290
+ transcripts_state = gr.State(value=transcript)
291
+ prompt_text.change(
292
+ fn=change_reference_audio,
293
+ inputs=[prompt_text, transcripts_state],
294
+ outputs=[inp_ref_audio]
295
+ )
296
+ prompt_language = gr.State(value="zh")
297
+ with gr.Column():
298
+ text = gr.Textbox(label="Input Text", value="你好。")
299
+ text_language = gr.Dropdown(
300
+ label="Language",
301
+ choices=["zh", "en", "ja"],
302
+ value="ja"
303
+ )
304
+ inference_button = gr.Button("Generate", variant="primary")
305
+ om = gr.Textbox(label="Output Message")
306
+ output = gr.Audio(label="Output Audio")
307
+ inference_button.click(
308
+ fn=tts_fn,
309
+ inputs=[inp_ref_audio, prompt_text, prompt_language, text, text_language],
310
+ outputs=[om, output]
311
+ )
312
+
313
  app.queue().launch()