Commit 88a3a03 by None1145 (verified)
Parent(s): 08ba840

Update app.py

Files changed (1): app.py (+8 -13)
app.py CHANGED
@@ -119,7 +119,6 @@ def load_model(sovits_path, gpt_path):
     vq_model.load_state_dict(dict_s2["weight"], strict=False)
     hz = 50
     max_sec = config['data']['max_sec']
-    # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
     t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
     t2s_model.load_state_dict(dict_s1["weight"])
     if (is_half == True): t2s_model = t2s_model.half()
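Reviewer note: the deleted comment was the Lightning-native path (load_from_checkpoint); the surviving code builds the module and loads raw weights itself. A minimal sketch of that retained path, assuming dict_s1 = torch.load(gpt_path, map_location="cpu") earlier in load_model and the usual GPT-SoVITS import (both are assumptions, not shown in this hunk):

    import torch
    from AR.models.t2s_lightning_module import Text2SemanticLightningModule  # assumed import path

    dict_s1 = torch.load(gpt_path, map_location="cpu")  # GPT-stage checkpoint (assumption)
    config = dict_s1["config"]                          # assumption: config ships inside the ckpt
    t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
    t2s_model.load_state_dict(dict_s1["weight"])        # plain weights; no Lightning trainer state
    if is_half:
        t2s_model = t2s_model.half()                    # fp16 for inference
    t2s_model = t2s_model.to(device)
    t2s_model.eval()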
@@ -148,11 +147,11 @@ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
     if len(text) > 50:
         return f"Error: Text is too long, ({len(text)}>50)", None
     with torch.no_grad():
-        wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # Paimon
+        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
         wav16k = torch.from_numpy(wav16k)
         if(is_half==True):wav16k=wav16k.half().to(device)
         else:wav16k=wav16k.to(device)
-        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
+        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
         codes = vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
     t1 = ttime()
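Reviewer note on the hunk above: the two + lines only drop stale inline comments; the prompt-encoding flow is unchanged. An annotated copy of that flow, with tensor shapes given as assumptions from typical cnhubert-based GPT-SoVITS configs, not values printed from this app:

    wav16k, sr = librosa.load(ref_wav_path, sr=16000)   # reference audio, (T,) float32 at 16 kHz
    wav16k = torch.from_numpy(wav16k)                   # (T,) tensor; half() + device move as above
    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
                                                        # (1, hidden, frames) SSL features
    codes = vq_model.extract_latent(ssl_content)        # quantized codes, (n_codebooks, 1, frames)
    prompt_semantic = codes[0, 0]                       # (frames,) first codebook -> semantic prompt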
@@ -176,23 +175,19 @@ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
     prompt = prompt_semantic.unsqueeze(0).to(device)
     t2 = ttime()
     with torch.no_grad():
-        # pred_semantic = t2s_model.model.infer(
         pred_semantic,idx = t2s_model.model.infer_panel(
             all_phoneme_ids,
             all_phoneme_len,
             prompt,
             bert,
-            # prompt_phone_len=ph_offset,
             top_k=config['inference']['top_k'],
             early_stop_num=hz * max_sec)
     t3 = ttime()
-    # print(pred_semantic.shape,idx)
-    pred_semantic = pred_semantic[:,-idx:].unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze
+    pred_semantic = pred_semantic[:,-idx:].unsqueeze(0)
     refer = get_spepc(hps, ref_wav_path)#.to(device)
     if(is_half==True):refer=refer.half().to(device)
     else:refer=refer.to(device)
-    # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
-    audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]  # try reconstructing without the prompt part
+    audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]
     audio_opt.append(audio)
     audio_opt.append(zero_wav)
     t4 = ttime()
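Reviewer note: the surviving pred_semantic line keeps only the generated tail. As I read infer_panel, it returns the prompt tokens plus the continuation, with idx counting the newly generated steps (an assumption inferred from the slice, not verified against the model code):

    # pred_semantic: (1, prompt_len + idx) -> keep the generated part, add one more dim
    pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)          # (1, 1, idx)
    # decode with the target-text phonemes only (phones2), matching the translated
    # comment "try reconstructing without the prompt part"
    audio = vq_model.decode(
        pred_semantic,
        torch.LongTensor(phones2).to(device).unsqueeze(0),        # (1, n_phones) target phonemes
        refer,                                                    # reference spectrogram for timbre
    ).detach().cpu().numpy()[0, 0]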
@@ -201,7 +196,7 @@ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
     return tts_fn
 
 
-splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}  # ellipses are not considered
+splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")
     if (todo_text[-1] not in splits): todo_text += "。"
@@ -209,7 +204,7 @@ def split(todo_text):
     len_text = len(todo_text)
     todo_texts = []
     while (1):
-        if (i_split_head >= len_text): break  # the text always ends with punctuation, so just break; the last segment was appended in the previous pass
+        if (i_split_head >= len_text): break
         if (todo_text[i_split_head] in splits):
             i_split_head += 1
             todo_texts.append(todo_text[i_split_tail:i_split_head])
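Reviewer note: the two split() hunks above only drop Chinese comments (translated inline in this rendering). For readers skimming, here is the whole helper reconstructed from the visible fragments; the i_split_tail bookkeeping and the else branch are assumptions inferred from the append slice, not shown in the diff:

    splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}

    def split(todo_text):
        # Normalize doubled punctuation, then guarantee a trailing delimiter so
        # every segment, including the last, ends on a split character.
        todo_text = todo_text.replace("……", "。").replace("——", ",")
        if todo_text[-1] not in splits:
            todo_text += "。"
        i_split_head = i_split_tail = 0
        len_text = len(todo_text)
        todo_texts = []
        while 1:
            if i_split_head >= len_text:
                break  # text ends with punctuation, so the final segment is already appended
            if todo_text[i_split_head] in splits:
                i_split_head += 1
                todo_texts.append(todo_text[i_split_tail:i_split_head])
                i_split_tail = i_split_head
            else:
                i_split_head += 1
        return todo_texts

For example, split("你好,世界。Hello, world") would yield ["你好,", "世界。", "Hello,", " world。"].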
@@ -289,8 +284,8 @@ for speaker in models_info:
     sovits_weight = speaker_info["sovits_weight"]
     gpt_weight = speaker_info["gpt_weight"]
     model_id = "None1145/GPT-SoVITS-Base"
-    # vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
-    vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, f"./PretrainedModels/{model_id}/GPT.ckpt")
+    vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
+    # vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, f"./PretrainedModels/{model_id}/GPT.ckpt")
     models.append(
         (
             speaker,
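Reviewer note: this hunk reverts to each speaker's own gpt_weight; the shared ./PretrainedModels/{model_id}/GPT.ckpt path survives only as a comment, leaving model_id unused on the live path. A hypothetical sketch of the models_info entries the loop consumes (keys taken from the visible lines; speaker name and paths invented for illustration):

    models_info = {
        "example_speaker": {                                         # hypothetical key
            "sovits_weight": "./Models/example_speaker/SoVITS.pth",  # hypothetical path
            "gpt_weight": "./Models/example_speaker/GPT.ckpt",       # hypothetical path
        },
    }

    for speaker in models_info:
        speaker_info = models_info[speaker]  # assumption: lookup happens before the lines shown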
 