AkitoP commited on
Commit
c3fbe2e
·
verified ·
1 Parent(s): 2cb4f1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +383 -383
app.py CHANGED
@@ -1,384 +1,384 @@
1
- import os
2
- import sys
3
- # to avoid the modified user.pth file
4
- cnhubert_base_path = "GPT_SoVITS\pretrained_models\chinese-hubert-base"
5
- bert_path = "GPT_SoVITS\pretrained_models\chinese-roberta-wwm-ext-large"
6
- os.environ["version"] = 'v2'
7
- now_dir = os.getcwd()
8
- sys.path.insert(0, now_dir)
9
- import gradio as gr
10
- from transformers import AutoModelForMaskedLM, AutoTokenizer
11
- import numpy as np
12
- from pathlib import Path
13
- import os,librosa,torch, audiosegment
14
- from scipy.io.wavfile import write as wavwrite
15
- from GPT_SoVITS.feature_extractor import cnhubert
16
- cnhubert.cnhubert_base_path=cnhubert_base_path
17
- from GPT_SoVITS.module.models import SynthesizerTrn
18
- from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
19
- from GPT_SoVITS.text import cleaned_text_to_sequence
20
- from GPT_SoVITS.text.cleaner import clean_text
21
- from time import time as ttime
22
- from GPT_SoVITS.module.mel_processing import spectrogram_torch
23
- import tempfile
24
- from tools.my_utils import load_audio
25
- import os
26
- import json
27
-
28
- ################ End strange import and user.pth modification ################
29
-
30
- # import pyopenjtalk
31
- # cwd = os.getcwd()
32
- # if os.path.exists(os.path.join(cwd,'user.dic')):
33
- # pyopenjtalk.update_global_jtalk_with_user_dict(os.path.join(cwd, 'user.dic'))
34
-
35
-
36
- import logging
37
- logging.getLogger('httpx').setLevel(logging.WARNING)
38
- logging.getLogger('httpcore').setLevel(logging.WARNING)
39
- logging.getLogger('multipart').setLevel(logging.WARNING)
40
-
41
- device = "cuda" if torch.cuda.is_available() else "cpu"
42
- #device = "cpu"
43
- is_half = False
44
-
45
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
46
- bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
47
- if(is_half==True):bert_model=bert_model.half().to(device)
48
- else:bert_model=bert_model.to(device)
49
- # bert_model=bert_model.to(device)
50
- def get_bert_feature(text, word2ph): # Bert(不是HuBERT的特征计算)
51
- with torch.no_grad():
52
- inputs = tokenizer(text, return_tensors="pt")
53
- for i in inputs:
54
- inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
55
- res = bert_model(**inputs, output_hidden_states=True)
56
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
57
- assert len(word2ph) == len(text)
58
- phone_level_feature = []
59
- for i in range(len(word2ph)):
60
- repeat_feature = res[i].repeat(word2ph[i], 1)
61
- phone_level_feature.append(repeat_feature)
62
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
63
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
64
- return phone_level_feature.T
65
-
66
- loaded_sovits_model = [] # [(path, dict, model)]
67
- loaded_gpt_model = []
68
- ssl_model = cnhubert.get_model()
69
- if (is_half == True):
70
- ssl_model = ssl_model.half().to(device)
71
- else:
72
- ssl_model = ssl_model.to(device)
73
-
74
-
75
- def load_model(sovits_path, gpt_path):
76
- global ssl_model
77
- global loaded_sovits_model
78
- global loaded_gpt_model
79
- vq_model = None
80
- t2s_model = None
81
- dict_s2 = None
82
- dict_s1 = None
83
- hps = None
84
- for path, dict_s2_, model in loaded_sovits_model:
85
- if path == sovits_path:
86
- vq_model = model
87
- dict_s2 = dict_s2_
88
- break
89
- for path, dict_s1_, model in loaded_gpt_model:
90
- if path == gpt_path:
91
- t2s_model = model
92
- dict_s1 = dict_s1_
93
- break
94
-
95
- if dict_s2 is None:
96
- dict_s2 = torch.load(sovits_path, map_location="cpu")
97
- hps = dict_s2["config"]
98
-
99
- if dict_s1 is None:
100
- dict_s1 = torch.load(gpt_path, map_location="cpu")
101
- config = dict_s1["config"]
102
- class DictToAttrRecursive:
103
- def __init__(self, input_dict):
104
- for key, value in input_dict.items():
105
- if isinstance(value, dict):
106
- # 如果值是字典,递归调用构造函数
107
- setattr(self, key, DictToAttrRecursive(value))
108
- else:
109
- setattr(self, key, value)
110
-
111
- hps = DictToAttrRecursive(hps)
112
- hps.model.semantic_frame_rate = "25hz"
113
-
114
-
115
- if not vq_model:
116
- vq_model = SynthesizerTrn(
117
- hps.data.filter_length // 2 + 1,
118
- hps.train.segment_size // hps.data.hop_length,
119
- n_speakers=hps.data.n_speakers,
120
- **hps.model)
121
- if (is_half == True):
122
- vq_model = vq_model.half().to(device)
123
- else:
124
- vq_model = vq_model.to(device)
125
- vq_model.eval()
126
- vq_model.load_state_dict(dict_s2["weight"], strict=False)
127
- loaded_sovits_model.append((sovits_path, dict_s2, vq_model))
128
- hz = 50
129
- max_sec = config['data']['max_sec']
130
- if not t2s_model:
131
- t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
132
- t2s_model.load_state_dict(dict_s1["weight"])
133
- if (is_half == True): t2s_model = t2s_model.half()
134
- t2s_model = t2s_model.to(device)
135
- t2s_model.eval()
136
- total = sum([param.nelement() for param in t2s_model.parameters()])
137
- loaded_gpt_model.append((gpt_path, dict_s1, t2s_model))
138
- return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
139
-
140
-
141
- def get_spepc(hps, filename):
142
- audio=load_audio(filename,int(hps.data.sampling_rate))
143
- audio = audio / np.max(np.abs(audio))
144
- audio=torch.FloatTensor(audio)
145
- audio_norm = audio
146
- # audio_norm = audio / torch.max(torch.abs(audio))
147
- audio_norm = audio_norm.unsqueeze(0)
148
- spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
149
- return spec
150
-
151
- def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
152
- def tts_fn(ref_wav_path, prompt_text, prompt_language, target_phone, text_language, target_text = None):
153
- t0 = ttime()
154
- prompt_text=prompt_text.strip()
155
- prompt_language=prompt_language
156
- with torch.no_grad():
157
- wav16k, sr = librosa.load(ref_wav_path, sr=16000, mono=False)
158
- direction = np.array([1,1])
159
- if wav16k.ndim == 2:
160
- power = np.sum(np.abs(wav16k) ** 2, axis=1)
161
- direction = power / np.sum(power)
162
- wav16k = (wav16k[0] + wav16k[1]) / 2
163
- #
164
- # maxx=0.95
165
- # tmp_max = np.abs(wav16k).max()
166
- # alpha=0.5
167
- # wav16k = (wav16k / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * wav16k
168
- #在这里归一化
169
- #print(max(np.abs(wav16k)))
170
- #wav16k = wav16k / np.max(np.abs(wav16k))
171
- #print(max(np.abs(wav16k)))
172
- # 添加0.3s的静音
173
- wav16k = np.concatenate([wav16k, np.zeros(int(hps.data.sampling_rate * 0.3)),])
174
- wav16k = torch.from_numpy(wav16k)
175
- wav16k = wav16k.float()
176
- if(is_half==True):wav16k=wav16k.half().to(device)
177
- else:wav16k=wav16k.to(device)
178
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
179
- codes = vq_model.extract_latent(ssl_content)
180
- prompt_semantic = codes[0, 0]
181
- t1 = ttime()
182
- phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
183
- phones1=cleaned_text_to_sequence(phones1)
184
- #texts=text.split("\n")
185
- audio_opt = []
186
- zero_wav=np.zeros((2, int(hps.data.sampling_rate*0.3)),dtype=np.float16 if is_half==True else np.float32)
187
- phones = get_phone_from_str_list(target_phone, text_language)
188
- for phones2 in phones:
189
- if(len(phones2) == 0):
190
- continue
191
- if(len(phones2) == 1 and phones2[0] == ""):
192
- continue
193
- #phones2, word2ph2, norm_text2 = clean_text(text, text_language)
194
- phones2 = cleaned_text_to_sequence(phones2)
195
- #if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
196
- bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
197
- #if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
198
- bert2 = torch.zeros((1024, len(phones2))).to(bert1)
199
- bert = torch.cat([bert1, bert2], 1)
200
-
201
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
202
- bert = bert.to(device).unsqueeze(0)
203
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
204
- prompt = prompt_semantic.unsqueeze(0).to(device)
205
- t2 = ttime()
206
- idx = 0
207
- cnt = 0
208
- while idx == 0 and cnt < 2:
209
- with torch.no_grad():
210
- # pred_semantic = t2s_model.model.infer
211
- pred_semantic,idx = t2s_model.model.infer_panel(
212
- all_phoneme_ids,
213
- all_phoneme_len,
214
- prompt,
215
- bert,
216
- # prompt_phone_len=ph_offset,
217
- top_k=config['inference']['top_k'],
218
- early_stop_num=hz * max_sec)
219
- t3 = ttime()
220
- cnt+=1
221
- if idx == 0:
222
- return "Error: Generation failure: bad zero prediction.", None
223
- pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
224
- refer = get_spepc(hps, ref_wav_path)#.to(device)
225
- if(is_half==True):refer=refer.half().to(device)
226
- else:refer=refer.to(device)
227
- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
228
- audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
229
- # direction乘上,变双通道
230
- # 强制0.5
231
- direction = np.array([1, 1])
232
- audio = np.expand_dims(audio, 0) * direction[:, np.newaxis]
233
- audio_opt.append(audio)
234
- audio_opt.append(zero_wav)
235
- t4 = ttime()
236
-
237
- audio = (hps.data.sampling_rate,(np.concatenate(audio_opt, axis=1)*32768).astype(np.int16).T)
238
- prefix_1 = prompt_text[:8].replace(" ", "_").replace("\n", "_").replace("?","_").replace("!","_").replace(",","_")
239
- prefix_2 = target_text[:8].replace(" ", "_").replace("\n", "_").replace("?","_").replace("!","_").replace(",","_")
240
- filename = tempfile.mktemp(suffix=".wav",prefix=f"{prefix_1}_{prefix_2}_")
241
- #audiosegment.from_numpy_array(audio[1].T, framerate=audio[0]).export(filename, format="WAV")
242
- wavwrite(filename, audio[0], audio[1])
243
- return "Success", audio, filename
244
- return tts_fn
245
-
246
-
247
- def get_str_list_from_phone(text, text_language):
248
- # raw文本过g2p得到音素列表,再转成字符串
249
- # 注意,这里的text是一个段落,可能包含多个句子
250
- # 段落间\n分割,音素间空格分割
251
- print(text)
252
- texts=text.split("\n")
253
- phone_list = []
254
- for text in texts:
255
- phones2, word2ph2, norm_text2 = clean_text(text, text_language)
256
- phone_list.append(" ".join(phones2))
257
- return "\n".join(phone_list)
258
-
259
- def get_phone_from_str_list(str_list:str, language:str = 'ja'):
260
- # 从音素字符串中得到音素列表
261
- # 注意,这里的text是一个段落,可能包含多个句子
262
- # 段落间\n分割,音素间空格分割
263
- sentences = str_list.split("\n")
264
- phones = []
265
- for sentence in sentences:
266
- phones.append(sentence.split(" "))
267
- return phones
268
-
269
- splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
270
- def split(todo_text):
271
- todo_text = todo_text.replace("……", "。").replace("——", ",")
272
- if (todo_text[-1] not in splits): todo_text += "。"
273
- i_split_head = i_split_tail = 0
274
- len_text = len(todo_text)
275
- todo_texts = []
276
- while (1):
277
- if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
278
- if (todo_text[i_split_head] in splits):
279
- i_split_head += 1
280
- todo_texts.append(todo_text[i_split_tail:i_split_head])
281
- i_split_tail = i_split_head
282
- else:
283
- i_split_head += 1
284
- return todo_texts
285
-
286
-
287
- def change_reference_audio(prompt_text, transcripts):
288
- return transcripts[prompt_text]
289
-
290
-
291
- models = []
292
- models_info = json.load(open("./models/models_info.json", "r", encoding="utf-8"))
293
-
294
-
295
-
296
- for i, info in models_info.items():
297
- title = info['title']
298
- cover = info['cover']
299
- gpt_weight = info['gpt_weight']
300
- sovits_weight = info['sovits_weight']
301
- example_reference = info['example_reference']
302
- transcripts = {}
303
- transcript_path = info["transcript_path"]
304
- path = os.path.dirname(transcript_path)
305
- with open(transcript_path, 'r', encoding='utf-8') as file:
306
- for line in file:
307
- line = line.strip().replace("\\", "/")
308
- items = line.split("|")
309
- wav,t = items[0], items[-1]
310
- wav = os.path.basename(wav)
311
- transcripts[t] = os.path.join(os.path.join(path,"reference_audio"), wav)
312
-
313
- vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
314
-
315
-
316
- models.append(
317
- (
318
- i,
319
- title,
320
- cover,
321
- transcripts,
322
- example_reference,
323
- create_tts_fn(
324
- vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
325
- )
326
- )
327
- )
328
- with gr.Blocks() as app:
329
- gr.Markdown(
330
- "# <center> GPT-SoVITS Demo\n"
331
- )
332
- with gr.Tabs():
333
- for (name, title, cover, transcripts, example_reference, tts_fn) in models:
334
- with gr.TabItem(name):
335
- with gr.Row():
336
- gr.Markdown(
337
- '<div align="center">'
338
- f'<a><strong>{title}</strong></a>'
339
- '</div>')
340
- with gr.Row():
341
- with gr.Column():
342
- prompt_text = gr.Dropdown(
343
- label="Transcript of the Reference Audio",
344
- value=example_reference if example_reference in transcripts else list(transcripts.keys())[0],
345
- choices=list(transcripts.keys())
346
- )
347
- inp_ref_audio = gr.Audio(
348
- label="Reference Audio",
349
- type="filepath",
350
- interactive=False,
351
- value=transcripts[example_reference] if example_reference in transcripts else list(transcripts.values())[0]
352
- )
353
- transcripts_state = gr.State(value=transcripts)
354
- prompt_text.change(
355
- fn=change_reference_audio,
356
- inputs=[prompt_text, transcripts_state],
357
- outputs=[inp_ref_audio]
358
- )
359
- prompt_language = gr.State(value="ja")
360
- with gr.Column():
361
- text = gr.Textbox(label="Input Text", value="私はお兄ちゃんのだいだいだーいすきな妹なんだから、言うことなんでも聞いてくれますよね!")
362
- text_language = gr.Dropdown(
363
- label="Language",
364
- choices=["ja"],
365
- value="ja"
366
- )
367
- clean_button = gr.Button("Clean Text", variant="primary")
368
- inference_button = gr.Button("Generate", variant="primary")
369
- cleaned_text = gr.Textbox(label="Cleaned Text")
370
- output = gr.Audio(label="Output Audio")
371
- output_file = gr.File(label="Output Audio File")
372
- om = gr.Textbox(label="Output Message")
373
- clean_button.click(
374
- fn=get_str_list_from_phone,
375
- inputs=[text, text_language],
376
- outputs=[cleaned_text]
377
- )
378
- inference_button.click(
379
- fn=tts_fn,
380
- inputs=[inp_ref_audio, prompt_text, prompt_language, cleaned_text, text_language, text],
381
- outputs=[om, output, output_file]
382
- )
383
-
384
  app.launch(share=True)
 
1
+ import os
2
+ import sys
3
+ # to avoid the modified user.pth file
4
+ cnhubert_base_path = "GPT_SoVITS\pretrained_models\chinese-hubert-base"
5
+ bert_path = "GPT_SoVITS\pretrained_models\chinese-roberta-wwm-ext-large"
6
+ os.environ["version"] = 'v2'
7
+ now_dir = os.getcwd()
8
+ sys.path.insert(0, now_dir)
9
+ import gradio as gr
10
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
11
+ import numpy as np
12
+ from pathlib import Path
13
+ import os,librosa,torch
14
+ from scipy.io.wavfile import write as wavwrite
15
+ from GPT_SoVITS.feature_extractor import cnhubert
16
+ cnhubert.cnhubert_base_path=cnhubert_base_path
17
+ from GPT_SoVITS.module.models import SynthesizerTrn
18
+ from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
19
+ from GPT_SoVITS.text import cleaned_text_to_sequence
20
+ from GPT_SoVITS.text.cleaner import clean_text
21
+ from time import time as ttime
22
+ from GPT_SoVITS.module.mel_processing import spectrogram_torch
23
+ import tempfile
24
+ from tools.my_utils import load_audio
25
+ import os
26
+ import json
27
+
28
+ ################ End strange import and user.pth modification ################
29
+
30
+ # import pyopenjtalk
31
+ # cwd = os.getcwd()
32
+ # if os.path.exists(os.path.join(cwd,'user.dic')):
33
+ # pyopenjtalk.update_global_jtalk_with_user_dict(os.path.join(cwd, 'user.dic'))
34
+
35
+
36
+ import logging
37
+ logging.getLogger('httpx').setLevel(logging.WARNING)
38
+ logging.getLogger('httpcore').setLevel(logging.WARNING)
39
+ logging.getLogger('multipart').setLevel(logging.WARNING)
40
+
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ #device = "cpu"
43
+ is_half = False
44
+
45
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
46
+ bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
47
+ if(is_half==True):bert_model=bert_model.half().to(device)
48
+ else:bert_model=bert_model.to(device)
49
+ # bert_model=bert_model.to(device)
50
+ def get_bert_feature(text, word2ph): # Bert(不是HuBERT的特征计算)
51
+ with torch.no_grad():
52
+ inputs = tokenizer(text, return_tensors="pt")
53
+ for i in inputs:
54
+ inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
55
+ res = bert_model(**inputs, output_hidden_states=True)
56
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
57
+ assert len(word2ph) == len(text)
58
+ phone_level_feature = []
59
+ for i in range(len(word2ph)):
60
+ repeat_feature = res[i].repeat(word2ph[i], 1)
61
+ phone_level_feature.append(repeat_feature)
62
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
63
+ # if(is_half==True):phone_level_feature=phone_level_feature.half()
64
+ return phone_level_feature.T
65
+
66
+ loaded_sovits_model = [] # [(path, dict, model)]
67
+ loaded_gpt_model = []
68
+ ssl_model = cnhubert.get_model()
69
+ if (is_half == True):
70
+ ssl_model = ssl_model.half().to(device)
71
+ else:
72
+ ssl_model = ssl_model.to(device)
73
+
74
+
75
+ def load_model(sovits_path, gpt_path):
76
+ global ssl_model
77
+ global loaded_sovits_model
78
+ global loaded_gpt_model
79
+ vq_model = None
80
+ t2s_model = None
81
+ dict_s2 = None
82
+ dict_s1 = None
83
+ hps = None
84
+ for path, dict_s2_, model in loaded_sovits_model:
85
+ if path == sovits_path:
86
+ vq_model = model
87
+ dict_s2 = dict_s2_
88
+ break
89
+ for path, dict_s1_, model in loaded_gpt_model:
90
+ if path == gpt_path:
91
+ t2s_model = model
92
+ dict_s1 = dict_s1_
93
+ break
94
+
95
+ if dict_s2 is None:
96
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
97
+ hps = dict_s2["config"]
98
+
99
+ if dict_s1 is None:
100
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
101
+ config = dict_s1["config"]
102
+ class DictToAttrRecursive:
103
+ def __init__(self, input_dict):
104
+ for key, value in input_dict.items():
105
+ if isinstance(value, dict):
106
+ # 如果值是字典,递归调用构造函数
107
+ setattr(self, key, DictToAttrRecursive(value))
108
+ else:
109
+ setattr(self, key, value)
110
+
111
+ hps = DictToAttrRecursive(hps)
112
+ hps.model.semantic_frame_rate = "25hz"
113
+
114
+
115
+ if not vq_model:
116
+ vq_model = SynthesizerTrn(
117
+ hps.data.filter_length // 2 + 1,
118
+ hps.train.segment_size // hps.data.hop_length,
119
+ n_speakers=hps.data.n_speakers,
120
+ **hps.model)
121
+ if (is_half == True):
122
+ vq_model = vq_model.half().to(device)
123
+ else:
124
+ vq_model = vq_model.to(device)
125
+ vq_model.eval()
126
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
127
+ loaded_sovits_model.append((sovits_path, dict_s2, vq_model))
128
+ hz = 50
129
+ max_sec = config['data']['max_sec']
130
+ if not t2s_model:
131
+ t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
132
+ t2s_model.load_state_dict(dict_s1["weight"])
133
+ if (is_half == True): t2s_model = t2s_model.half()
134
+ t2s_model = t2s_model.to(device)
135
+ t2s_model.eval()
136
+ total = sum([param.nelement() for param in t2s_model.parameters()])
137
+ loaded_gpt_model.append((gpt_path, dict_s1, t2s_model))
138
+ return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
139
+
140
+
141
+ def get_spepc(hps, filename):
142
+ audio=load_audio(filename,int(hps.data.sampling_rate))
143
+ audio = audio / np.max(np.abs(audio))
144
+ audio=torch.FloatTensor(audio)
145
+ audio_norm = audio
146
+ # audio_norm = audio / torch.max(torch.abs(audio))
147
+ audio_norm = audio_norm.unsqueeze(0)
148
+ spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
149
+ return spec
150
+
151
+ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
152
+ def tts_fn(ref_wav_path, prompt_text, prompt_language, target_phone, text_language, target_text = None):
153
+ t0 = ttime()
154
+ prompt_text=prompt_text.strip()
155
+ prompt_language=prompt_language
156
+ with torch.no_grad():
157
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000, mono=False)
158
+ direction = np.array([1,1])
159
+ if wav16k.ndim == 2:
160
+ power = np.sum(np.abs(wav16k) ** 2, axis=1)
161
+ direction = power / np.sum(power)
162
+ wav16k = (wav16k[0] + wav16k[1]) / 2
163
+ #
164
+ # maxx=0.95
165
+ # tmp_max = np.abs(wav16k).max()
166
+ # alpha=0.5
167
+ # wav16k = (wav16k / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * wav16k
168
+ #在这里归一化
169
+ #print(max(np.abs(wav16k)))
170
+ #wav16k = wav16k / np.max(np.abs(wav16k))
171
+ #print(max(np.abs(wav16k)))
172
+ # 添加0.3s的静音
173
+ wav16k = np.concatenate([wav16k, np.zeros(int(hps.data.sampling_rate * 0.3)),])
174
+ wav16k = torch.from_numpy(wav16k)
175
+ wav16k = wav16k.float()
176
+ if(is_half==True):wav16k=wav16k.half().to(device)
177
+ else:wav16k=wav16k.to(device)
178
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
179
+ codes = vq_model.extract_latent(ssl_content)
180
+ prompt_semantic = codes[0, 0]
181
+ t1 = ttime()
182
+ phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
183
+ phones1=cleaned_text_to_sequence(phones1)
184
+ #texts=text.split("\n")
185
+ audio_opt = []
186
+ zero_wav=np.zeros((2, int(hps.data.sampling_rate*0.3)),dtype=np.float16 if is_half==True else np.float32)
187
+ phones = get_phone_from_str_list(target_phone, text_language)
188
+ for phones2 in phones:
189
+ if(len(phones2) == 0):
190
+ continue
191
+ if(len(phones2) == 1 and phones2[0] == ""):
192
+ continue
193
+ #phones2, word2ph2, norm_text2 = clean_text(text, text_language)
194
+ phones2 = cleaned_text_to_sequence(phones2)
195
+ #if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
196
+ bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
197
+ #if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
198
+ bert2 = torch.zeros((1024, len(phones2))).to(bert1)
199
+ bert = torch.cat([bert1, bert2], 1)
200
+
201
+ all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
202
+ bert = bert.to(device).unsqueeze(0)
203
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
204
+ prompt = prompt_semantic.unsqueeze(0).to(device)
205
+ t2 = ttime()
206
+ idx = 0
207
+ cnt = 0
208
+ while idx == 0 and cnt < 2:
209
+ with torch.no_grad():
210
+ # pred_semantic = t2s_model.model.infer
211
+ pred_semantic,idx = t2s_model.model.infer_panel(
212
+ all_phoneme_ids,
213
+ all_phoneme_len,
214
+ prompt,
215
+ bert,
216
+ # prompt_phone_len=ph_offset,
217
+ top_k=config['inference']['top_k'],
218
+ early_stop_num=hz * max_sec)
219
+ t3 = ttime()
220
+ cnt+=1
221
+ if idx == 0:
222
+ return "Error: Generation failure: bad zero prediction.", None
223
+ pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
224
+ refer = get_spepc(hps, ref_wav_path)#.to(device)
225
+ if(is_half==True):refer=refer.half().to(device)
226
+ else:refer=refer.to(device)
227
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
228
+ audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
229
+ # direction乘上,变双通道
230
+ # 强制0.5
231
+ direction = np.array([1, 1])
232
+ audio = np.expand_dims(audio, 0) * direction[:, np.newaxis]
233
+ audio_opt.append(audio)
234
+ audio_opt.append(zero_wav)
235
+ t4 = ttime()
236
+
237
+ audio = (hps.data.sampling_rate,(np.concatenate(audio_opt, axis=1)*32768).astype(np.int16).T)
238
+ prefix_1 = prompt_text[:8].replace(" ", "_").replace("\n", "_").replace("?","_").replace("!","_").replace(",","_")
239
+ prefix_2 = target_text[:8].replace(" ", "_").replace("\n", "_").replace("?","_").replace("!","_").replace(",","_")
240
+ filename = tempfile.mktemp(suffix=".wav",prefix=f"{prefix_1}_{prefix_2}_")
241
+ #audiosegment.from_numpy_array(audio[1].T, framerate=audio[0]).export(filename, format="WAV")
242
+ wavwrite(filename, audio[0], audio[1])
243
+ return "Success", audio, filename
244
+ return tts_fn
245
+
246
+
247
+ def get_str_list_from_phone(text, text_language):
248
+ # raw文本过g2p得到音素列表,再转成字符串
249
+ # 注意,这里的text是一个段落,可能包含多个句子
250
+ # 段落间\n分割,音素间空格分割
251
+ print(text)
252
+ texts=text.split("\n")
253
+ phone_list = []
254
+ for text in texts:
255
+ phones2, word2ph2, norm_text2 = clean_text(text, text_language)
256
+ phone_list.append(" ".join(phones2))
257
+ return "\n".join(phone_list)
258
+
259
+ def get_phone_from_str_list(str_list:str, language:str = 'ja'):
260
+ # 从音素字符串中得到音素列表
261
+ # 注意,这里的text是一个段落,可能包含多个句子
262
+ # 段落间\n分割,音素间空格分割
263
+ sentences = str_list.split("\n")
264
+ phones = []
265
+ for sentence in sentences:
266
+ phones.append(sentence.split(" "))
267
+ return phones
268
+
269
+ splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
270
+ def split(todo_text):
271
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
272
+ if (todo_text[-1] not in splits): todo_text += "。"
273
+ i_split_head = i_split_tail = 0
274
+ len_text = len(todo_text)
275
+ todo_texts = []
276
+ while (1):
277
+ if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
278
+ if (todo_text[i_split_head] in splits):
279
+ i_split_head += 1
280
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
281
+ i_split_tail = i_split_head
282
+ else:
283
+ i_split_head += 1
284
+ return todo_texts
285
+
286
+
287
+ def change_reference_audio(prompt_text, transcripts):
288
+ return transcripts[prompt_text]
289
+
290
+
291
+ models = []
292
+ models_info = json.load(open("./models/models_info.json", "r", encoding="utf-8"))
293
+
294
+
295
+
296
+ for i, info in models_info.items():
297
+ title = info['title']
298
+ cover = info['cover']
299
+ gpt_weight = info['gpt_weight']
300
+ sovits_weight = info['sovits_weight']
301
+ example_reference = info['example_reference']
302
+ transcripts = {}
303
+ transcript_path = info["transcript_path"]
304
+ path = os.path.dirname(transcript_path)
305
+ with open(transcript_path, 'r', encoding='utf-8') as file:
306
+ for line in file:
307
+ line = line.strip().replace("\\", "/")
308
+ items = line.split("|")
309
+ wav,t = items[0], items[-1]
310
+ wav = os.path.basename(wav)
311
+ transcripts[t] = os.path.join(os.path.join(path,"reference_audio"), wav)
312
+
313
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
314
+
315
+
316
+ models.append(
317
+ (
318
+ i,
319
+ title,
320
+ cover,
321
+ transcripts,
322
+ example_reference,
323
+ create_tts_fn(
324
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
325
+ )
326
+ )
327
+ )
328
+ with gr.Blocks() as app:
329
+ gr.Markdown(
330
+ "# <center> GPT-SoVITS Demo\n"
331
+ )
332
+ with gr.Tabs():
333
+ for (name, title, cover, transcripts, example_reference, tts_fn) in models:
334
+ with gr.TabItem(name):
335
+ with gr.Row():
336
+ gr.Markdown(
337
+ '<div align="center">'
338
+ f'<a><strong>{title}</strong></a>'
339
+ '</div>')
340
+ with gr.Row():
341
+ with gr.Column():
342
+ prompt_text = gr.Dropdown(
343
+ label="Transcript of the Reference Audio",
344
+ value=example_reference if example_reference in transcripts else list(transcripts.keys())[0],
345
+ choices=list(transcripts.keys())
346
+ )
347
+ inp_ref_audio = gr.Audio(
348
+ label="Reference Audio",
349
+ type="filepath",
350
+ interactive=False,
351
+ value=transcripts[example_reference] if example_reference in transcripts else list(transcripts.values())[0]
352
+ )
353
+ transcripts_state = gr.State(value=transcripts)
354
+ prompt_text.change(
355
+ fn=change_reference_audio,
356
+ inputs=[prompt_text, transcripts_state],
357
+ outputs=[inp_ref_audio]
358
+ )
359
+ prompt_language = gr.State(value="ja")
360
+ with gr.Column():
361
+ text = gr.Textbox(label="Input Text", value="私はお兄ちゃんのだいだいだーいすきな妹なんだから、言うことなんでも聞いてくれますよね!")
362
+ text_language = gr.Dropdown(
363
+ label="Language",
364
+ choices=["ja"],
365
+ value="ja"
366
+ )
367
+ clean_button = gr.Button("Clean Text", variant="primary")
368
+ inference_button = gr.Button("Generate", variant="primary")
369
+ cleaned_text = gr.Textbox(label="Cleaned Text")
370
+ output = gr.Audio(label="Output Audio")
371
+ output_file = gr.File(label="Output Audio File")
372
+ om = gr.Textbox(label="Output Message")
373
+ clean_button.click(
374
+ fn=get_str_list_from_phone,
375
+ inputs=[text, text_language],
376
+ outputs=[cleaned_text]
377
+ )
378
+ inference_button.click(
379
+ fn=tts_fn,
380
+ inputs=[inp_ref_audio, prompt_text, prompt_language, cleaned_text, text_language, text],
381
+ outputs=[om, output, output_file]
382
+ )
383
+
384
  app.launch(share=True)