Spaces: Build error
Commit c16ad86 committed by ChrisPreston
Parent: 2332176

Delete infer_tools/infer_tool_beta.py

Files changed: infer_tools/infer_tool_beta.py (+0 -229)
infer_tools/infer_tool_beta.py
DELETED
@@ -1,229 +0,0 @@
-import json
-import os
-import time
-from io import BytesIO
-from pathlib import Path
-
-import librosa
-import numpy as np
-import soundfile
-import torch
-
-import utils
-from infer_tools.f0_static import compare_pitch, static_f0_time
-from modules.diff.diffusion import GaussianDiffusion
-from modules.diff.net import DiffNet
-from modules.vocoders.nsf_hifigan import NsfHifiGAN
-from preprocessing.hubertinfer import HubertEncoder
-from preprocessing.process_pipeline import File2Batch, get_pitch_parselmouth
-from utils.hparams import hparams, set_hparams
-from utils.pitch_utils import denorm_f0, norm_interp_f0
-
-
-def timeit(func):
-    def run(*args, **kwargs):
-        t = time.time()
-        res = func(*args, **kwargs)
-        print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t))
-        return res
-
-    return run
-
-
-def format_wav(audio_path):
-    if Path(audio_path).suffix == '.wav':
-        return
-    raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
-    soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
-
-
-def fill_a_to_b(a, b):
-    if len(a) < len(b):
-        for _ in range(0, len(b) - len(a)):
-            a.append(a[0])
-
-
-def get_end_file(dir_path, end):
-    file_lists = []
-    for root, dirs, files in os.walk(dir_path):
-        files = [f for f in files if f[0] != '.']
-        dirs[:] = [d for d in dirs if d[0] != '.']
-        for f_file in files:
-            if f_file.endswith(end):
-                file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
-    return file_lists
-
-
-def mkdir(paths: list):
-    for path in paths:
-        if not os.path.exists(path):
-            os.mkdir(path)
-
-
-class Svcb:
-    def __init__(self, project_name, config_name, hubert_gpu, model_path, onnx=False):
-        self.project_name = project_name
-        self.DIFF_DECODERS = {
-            'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
-        }
-
-        self.model_path = model_path
-        self.dev = torch.device("cuda")
-
-        self._ = set_hparams(config=config_name, exp_name=self.project_name, infer=True,
-                             reset=True, hparams_str='', print_hparams=False)
-
-        self.mel_bins = hparams['audio_num_mel_bins']
-        hparams['hubert_gpu'] = hubert_gpu
-        self.hubert = HubertEncoder(hparams['hubert_path'], onnx=onnx)
-        self.model = GaussianDiffusion(
-            phone_encoder=self.hubert,
-            out_dims=self.mel_bins, denoise_fn=self.DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
-            timesteps=hparams['timesteps'],
-            K_step=hparams['K_step'],
-            loss_type=hparams['diff_loss_type'],
-            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
-        )
-        utils.load_ckpt(self.model, self.model_path, 'model', force=True, strict=True)
-        self.model.cuda()
-        self.vocoder = NsfHifiGAN()
-
-    # def process_batch_f0(batch_f0, hparams):
-    #     pitch_num = collect_f0(batch_f0)
-    #     pitch_time = {}
-    #     sort_key = sorted(pitch_num.keys())
-    #     for key in sort_key:
-    #         pitch_time[key] = round(pitch_num[key] * hparams['hop_size'] / hparams['audio_sample_rate'], 2)
-    #     return pitch_time
-
-    def infer_autokey(self, in_path, key, acc, spk_id=0, use_crepe=False):
-        batch, temp_dict = self.pre(in_path, acc, spk_id, use_crepe)
-        input_f0 = temp_dict['f0']
-        if "f0_static" in hparams.keys():
-            f0_static = json.loads(hparams['f0_static'])
-            pitch_time_temp = static_f0_time(input_f0)
-            eval_dict = {}
-            for trans_key in range(-12, 12):
-                eval_dict[trans_key] = compare_pitch(f0_static, pitch_time_temp, trans_key=trans_key)
-            sort_key = sorted(eval_dict, key=eval_dict.get, reverse=True)[:5]
-            print(f"推荐移调:{sort_key}")
-            print(f"自动变调已启用,您的输入key被{sort_key[0]}key覆盖,控制参数为auto_key")
-            if sort_key[0] > 6:
-                key = sort_key[0] + 6
-            else:
-                key = sort_key[0]
-        return key, in_path, batch
-
-    # def infer(self, in_path, key, acc, spk_id=0, use_crepe=True, singer=False):
-    #     batch = self.pre(in_path, acc, spk_id, use_crepe)
-
-    def infer(self, in_path, key, batch, singer=False):
-        batch['f0'] = batch['f0'] + (key / 12)
-        batch['f0'][batch['f0'] > np.log2(hparams['f0_max'])] = 0
-
-        @timeit
-        def diff_infer():
-            spk_embed = batch.get('spk_embed') if not hparams['use_spk_id'] else batch.get('spk_ids')
-            energy = batch.get('energy').cuda() if batch.get('energy') else None
-            if spk_embed is None:
-                spk_embed = torch.LongTensor([0])
-            diff_outputs = self.model(
-                hubert=batch['hubert'].cuda(), spk_embed_id=spk_embed.cuda(), mel2ph=batch['mel2ph'].cuda(),
-                f0=batch['f0'].cuda(), energy=energy, ref_mels=batch["mels"].cuda(), infer=True)
-            return diff_outputs
-
-        outputs = diff_infer()
-        batch['outputs'] = outputs['mel_out']
-        batch['mel2ph_pred'] = outputs['mel2ph']
-        batch['f0_gt'] = denorm_f0(batch['f0'], batch['uv'], hparams)
-        batch['f0_pred'] = outputs.get('f0_denorm')
-        return self.after_infer(batch, singer, in_path)
-
-    @timeit
-    def after_infer(self, prediction, singer, in_path):
-        for k, v in prediction.items():
-            if type(v) is torch.Tensor:
-                prediction[k] = v.cpu().numpy()
-
-        # remove paddings
-        mel_gt = prediction["mels"]
-        mel_gt_mask = np.abs(mel_gt).sum(-1) > 0
-
-        mel_pred = prediction["outputs"]
-        mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
-        mel_pred = mel_pred[mel_pred_mask]
-        mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])
-
-        f0_gt = prediction.get("f0_gt")
-        f0_pred = prediction.get("f0_pred")
-        if f0_pred is not None:
-            f0_gt = f0_gt[mel_gt_mask]
-            if len(f0_pred) > len(mel_pred_mask):
-                f0_pred = f0_pred[:len(mel_pred_mask)]
-            f0_pred = f0_pred[mel_pred_mask]
-            torch.cuda.is_available() and torch.cuda.empty_cache()
-
-        if singer:
-            data_path = in_path.replace("batch", "singer_data")
-            mel_path = data_path[:-4] + "_mel.npy"
-            f0_path = data_path[:-4] + "_f0.npy"
-            np.save(mel_path, mel_pred)
-            np.save(f0_path, f0_pred)
-        wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
-        return f0_gt, f0_pred, wav_pred
-
-    def pre(self, wav_fn, accelerate, spk_id=0, use_crepe=True):
-        if isinstance(wav_fn, BytesIO):
-            item_name = self.project_name
-        else:
-            song_info = wav_fn.split('/')
-            item_name = song_info[-1].split('.')[-2]
-        temp_dict = {'wav_fn': wav_fn, 'spk_id': spk_id, 'id': 0}
-
-        temp_dict = File2Batch.temporary_dict2processed_input(item_name, temp_dict, self.hubert, infer=True,
-                                                              use_crepe=use_crepe)
-        hparams['pndm_speedup'] = accelerate
-        batch = File2Batch.processed_input2batch([getitem(temp_dict)])
-        return batch, temp_dict
-
-    def evaluate_key(self, wav_path, key, auto_key):
-        if "f0_static" in hparams.keys():
-            f0_static = json.loads(hparams['f0_static'])
-            wav, mel = self.vocoder.wav2spec(wav_path)
-            input_f0 = get_pitch_parselmouth(wav, mel, hparams)[0]
-            pitch_time_temp = static_f0_time(input_f0)
-            eval_dict = {}
-            for trans_key in range(-12, 12):
-                eval_dict[trans_key] = compare_pitch(f0_static, pitch_time_temp, trans_key=trans_key)
-            sort_key = sorted(eval_dict, key=eval_dict.get, reverse=True)[:5]
-            print(f"推荐移调:{sort_key}")
-            if auto_key:
-                print(f"自动变调已启用,您的输入key被{sort_key[0]}key覆盖,控制参数为auto_key")
-                return sort_key[0]
-        else:
-            print("config缺少f0_staic,无法使用自动变调,可通过infer_tools/data_static添加")
-        return key
-
-
-def getitem(item):
-    max_frames = hparams['max_frames']
-    spec = torch.Tensor(item['mel'])[:max_frames]
-    mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
-    f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
-    hubert = torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
-    pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
-    sample = {
-        "id": item['id'],
-        "spk_id": item['spk_id'],
-        "item_name": item['item_name'],
-        "hubert": hubert,
-        "mel": spec,
-        "pitch": pitch,
-        "f0": f0,
-        "uv": uv,
-        "mel2ph": mel2ph,
-        "mel_nonpadding": spec.abs().sum(-1) > 0,
-    }
-    if hparams['use_energy_embed']:
-        sample['energy'] = item['energy']
-    return sample
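
For context, below is a minimal sketch of how the Svcb class removed by this commit appears to have been driven, inferred only from the method signatures in the deleted file. The project name, config file, checkpoint path, input/output paths, and the acc value are placeholders for illustration, not values from the repository.

import soundfile

from infer_tools.infer_tool_beta import Svcb  # the module deleted in this commit
from utils.hparams import hparams

# Hypothetical setup; all paths/names below are placeholders.
svc = Svcb(project_name="my_project",          # placeholder project name
           config_name="config.yaml",          # placeholder config file
           hubert_gpu=True,
           model_path="model_ckpt.ckpt")       # placeholder checkpoint path

in_wav = "input.wav"                           # placeholder input audio
# Let the f0 statistics suggest a transposition, then run diffusion inference
# and write the vocoded result to disk.
auto_key, _, batch = svc.infer_autokey(in_wav, key=0, acc=20, use_crepe=False)
f0_gt, f0_pred, wav_pred = svc.infer(in_wav, auto_key, batch)
soundfile.write("output.wav", wav_pred, hparams['audio_sample_rate'])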