Update app.py
app.py (CHANGED)
@@ -119,7 +119,6 @@ def load_model(sovits_path, gpt_path):
 vq_model.load_state_dict(dict_s2["weight"], strict=False)
 hz = 50
 max_sec = config['data']['max_sec']
-# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
 t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
 t2s_model.load_state_dict(dict_s1["weight"])
 if (is_half == True): t2s_model = t2s_model.half()
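The surviving load path builds the GPT module directly and loads a plain state dict rather than going through Lightning's load_from_checkpoint. A minimal sketch of that pattern, assuming dict_s1 comes from torch.load(gpt_path, map_location="cpu") earlier in load_model and that the checkpoint stores its training config under a "config" key (only the "weight" access is visible in the hunk); the import path is likewise assumed:

import torch
from AR.models.t2s_lightning_module import Text2SemanticLightningModule  # assumed repo-local import path

dict_s1 = torch.load(gpt_path, map_location="cpu")    # assumed: done earlier in load_model
config = dict_s1["config"]                            # assumed key; config['data']['max_sec'] is read below
hz = 50                                               # semantic frames per second
max_sec = config["data"]["max_sec"]

t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])          # plain state-dict load, no Lightning checkpoint machinery
if is_half:
    t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device).eval()               # .to(device)/.eval() assumed; not shown in the hunk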
@@ -148,11 +147,11 @@ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
 if len(text) > 50:
 return f"Error: Text is too long, ({len(text)}>50)", None
 with torch.no_grad():
-wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+wav16k, sr = librosa.load(ref_wav_path, sr=16000)
 wav16k = torch.from_numpy(wav16k)
 if(is_half==True):wav16k=wav16k.half().to(device)
 else:wav16k=wav16k.to(device)
-ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
+ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
 codes = vq_model.extract_latent(ssl_content)
 prompt_semantic = codes[0, 0]
 t1 = ttime()
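These lines build the semantic prompt from the reference clip: load at 16 kHz, encode with the SSL model, then vector-quantize. A rough sketch with the tensor shapes spelled out, assuming a HuBERT-style SSL encoder (hop size 320 at 16 kHz, i.e. about 50 frames per second, which lines up with hz = 50 in load_model); ref_wav_path, ssl_model, vq_model and device are the surrounding file's names:

import librosa
import torch

# Mono float32 samples at 16 kHz; librosa resamples if the file's rate differs.
wav16k, sr = librosa.load(ref_wav_path, sr=16000)               # wav16k: (T,)
wav16k = torch.from_numpy(wav16k).to(device)                    # .half() first when is_half is set

# SSL hidden states: roughly (1, T // 320, H) for a HuBERT-style encoder,
# transposed to (1, H, frames) before quantization.
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)

# Quantize the SSL features; indexing [0, 0] leaves a 1-D sequence of
# semantic token ids that later becomes the prompt for infer_panel.
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]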
@@ -176,23 +175,19 @@ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
 prompt = prompt_semantic.unsqueeze(0).to(device)
 t2 = ttime()
 with torch.no_grad():
-# pred_semantic = t2s_model.model.infer(
 pred_semantic,idx = t2s_model.model.infer_panel(
 all_phoneme_ids,
 all_phoneme_len,
 prompt,
 bert,
-# prompt_phone_len=ph_offset,
 top_k=config['inference']['top_k'],
 early_stop_num=hz * max_sec)
 t3 = ttime()
-
-pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
+pred_semantic = pred_semantic[:,-idx:].unsqueeze(0)
 refer = get_spepc(hps, ref_wav_path)#.to(device)
 if(is_half==True):refer=refer.half().to(device)
 else:refer=refer.to(device)
-
-audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
+audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]
 audio_opt.append(audio)
 audio_opt.append(zero_wav)
 t4 = ttime()
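early_stop_num is a budget in semantic frames: hz = 50 frames per second times the configured max_sec, so the autoregressive decoder stops once it has produced the token count of the longest allowed clip. The slice pred_semantic[:, -idx:] then keeps only the idx newly generated tokens, dropping the prompt portion before decoding. A worked example with a hypothetical max_sec (the real value comes from config['data']['max_sec']):

# Hypothetical max_sec for illustration only; the real value is read from the checkpoint config.
hz = 50                        # semantic frames per second
max_sec = 30                   # hypothetical
early_stop_num = hz * max_sec  # 50 * 30 = 1500 -> at most 1500 generated semantic tokens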
@@ -201,7 +196,7 @@ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
 return tts_fn
 
 
-splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}
+splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}
 def split(todo_text):
 todo_text = todo_text.replace("……", "。").replace("——", ",")
 if (todo_text[-1] not in splits): todo_text += "。"
@@ -209,7 +204,7 @@ def split(todo_text):
 len_text = len(todo_text)
 todo_texts = []
 while (1):
-if (i_split_head >= len_text): break
+if (i_split_head >= len_text): break
 if (todo_text[i_split_head] in splits):
 i_split_head += 1
 todo_texts.append(todo_text[i_split_tail:i_split_head])
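split() walks the text one character at a time and cuts a segment every time it passes one of the separators in splits; the hunks only show fragments of that loop. A self-contained illustration of the same idea (an independent sketch, not the file's exact implementation):

SPLITS = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}

def split_on_punct(todo_text):
    # Normalize doubled punctuation, then make sure the text ends with a separator.
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if todo_text[-1] not in SPLITS:
        todo_text += "。"
    segments, tail = [], 0
    for head, ch in enumerate(todo_text):
        if ch in SPLITS:
            # Cut everything from the previous cut point through this separator.
            segments.append(todo_text[tail:head + 1])
            tail = head + 1
    return [s for s in segments if s.strip()]

print(split_on_punct("你好。今天天气不错,我们出去走走吧!"))
# ['你好。', '今天天气不错,', '我们出去走走吧!']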
@@ -289,8 +284,8 @@ for speaker in models_info:
 sovits_weight = speaker_info["sovits_weight"]
 gpt_weight = speaker_info["gpt_weight"]
 model_id = "None1145/GPT-SoVITS-Base"
-
-vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, f"./PretrainedModels/{model_id}/GPT.ckpt")
+vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
+# vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, f"./PretrainedModels/{model_id}/GPT.ckpt")
 models.append(
 (
 speaker,
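The functional change in this last hunk: each speaker now gets its own GPT checkpoint (gpt_weight) instead of the shared ./PretrainedModels/None1145/GPT-SoVITS-Base/GPT.ckpt, whose call is kept only as a comment. A minimal sketch of the loop after the change, assuming models_info maps speaker names to dicts carrying sovits_weight and gpt_weight, and that each models entry pairs the speaker with a tts function from create_tts_fn (the tail of models.append is not shown in the hunk):

# Sketch only; the real loop body around these lines is longer than the hunk shows.
models = []
for speaker in models_info:
    speaker_info = models_info[speaker]                # assumed: models_info is a dict of per-speaker configs
    sovits_weight = speaker_info["sovits_weight"]
    gpt_weight = speaker_info["gpt_weight"]
    # Per-speaker GPT weights instead of the shared base checkpoint:
    vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
    models.append(
        (
            speaker,
            create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec),  # assumed pairing
        )
    )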