ChrisPreston committed on
Commit
c16ad86
·
1 Parent(s): 2332176

Delete infer_tools/infer_tool_beta.py

Browse files
Files changed (1) hide show
  1. infer_tools/infer_tool_beta.py +0 -229
infer_tools/infer_tool_beta.py DELETED
@@ -1,229 +0,0 @@
1
- import json
2
- import os
3
- import time
4
- from io import BytesIO
5
- from pathlib import Path
6
-
7
- import librosa
8
- import numpy as np
9
- import soundfile
10
- import torch
11
-
12
- import utils
13
- from infer_tools.f0_static import compare_pitch, static_f0_time
14
- from modules.diff.diffusion import GaussianDiffusion
15
- from modules.diff.net import DiffNet
16
- from modules.vocoders.nsf_hifigan import NsfHifiGAN
17
- from preprocessing.hubertinfer import HubertEncoder
18
- from preprocessing.process_pipeline import File2Batch, get_pitch_parselmouth
19
- from utils.hparams import hparams, set_hparams
20
- from utils.pitch_utils import denorm_f0, norm_interp_f0
21
-
22
-
23
def timeit(func):
    """Decorator that prints the elapsed time of every call to ``func``.

    The wrapped function's return value is passed through unchanged. The
    wrapper keeps ``func``'s name and docstring so that decorated functions
    remain identifiable in logs and tracebacks.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that calls ``func`` with the same arguments, prints the
        elapsed seconds, and returns ``func``'s result.
    """
    from functools import wraps

    @wraps(func)
    def run(*args, **kwargs):
        # perf_counter is monotonic, so the measured interval cannot go
        # negative if the system clock is adjusted during the call.
        start = time.perf_counter()
        res = func(*args, **kwargs)
        print('executing \'%s\' costed %.3fs' % (func.__name__, time.perf_counter() - start))
        return res

    return run
31
-
32
-
33
def format_wav(audio_path):
    """Write a ``.wav`` copy of ``audio_path`` next to the original file.

    Nothing happens when the path already has the exact suffix ``.wav``.
    Otherwise the audio is loaded through ``librosa.load`` (``mono=True``,
    ``sr=None``) and written with ``soundfile.write`` to the same path with
    its suffix replaced by ``.wav``.

    Args:
        audio_path: Path of the audio file to convert.

    Returns:
        None.
    """
    source = Path(audio_path)
    if source.suffix != '.wav':
        samples, sample_rate = librosa.load(audio_path, mono=True, sr=None)
        soundfile.write(source.with_suffix(".wav"), samples, sample_rate)
38
-
39
-
40
def fill_a_to_b(a, b):
    """Pad list ``a`` in place with copies of its first element.

    When ``a`` is shorter than ``b``, ``a[0]`` is appended repeatedly until
    both have the same length. ``a`` is left untouched otherwise.

    Args:
        a: The list to pad (mutated in place).
        b: The sequence whose length is the target.

    Returns:
        None.
    """
    missing = len(b) - len(a)
    if missing > 0:
        a.extend([a[0]] * missing)
44
-
45
-
46
def get_end_file(dir_path, end):
    """Recursively collect files under ``dir_path`` whose names end with ``end``.

    Files and directories whose names start with ``.`` are ignored, and
    ignored directories are not descended into. Returned paths use ``/`` as
    the separator.

    Args:
        dir_path: Root directory to walk.
        end: Suffix passed to ``str.endswith`` for each file name.

    Returns:
        A list of matching file paths.
    """
    matches = []
    for root, dirs, files in os.walk(dir_path):
        # Prune hidden directories in place so os.walk skips their contents.
        dirs[:] = [name for name in dirs if not name.startswith('.')]
        matches.extend(
            os.path.join(root, name).replace("\\", "/")
            for name in files
            if not name.startswith('.') and name.endswith(end)
        )
    return matches
55
-
56
-
57
def mkdir(paths: list):
    """Create every directory in ``paths`` that does not exist yet.

    Missing parent directories are created as well, and a directory that
    appears between the existence check and the creation is tolerated.
    Paths that already exist are skipped.

    Args:
        paths: Directory paths to create.

    Returns:
        None.
    """
    for path in paths:
        if not os.path.exists(path):
            # makedirs handles nested paths; exist_ok covers a concurrent create.
            os.makedirs(path, exist_ok=True)
61
-
62
-
63
class Svcb:
    """Inference wrapper around a diffusion model, a HuBERT encoder and a vocoder.

    Construction loads the hyper-parameters for ``project_name`` via
    ``set_hparams``, builds a ``HubertEncoder``, a ``GaussianDiffusion`` model
    (restored from ``model_path`` and moved to CUDA) and an ``NsfHifiGAN``
    vocoder.
    """

    def __init__(self, project_name, config_name, hubert_gpu, model_path, onnx=False):
        """Load the configuration and all models needed for inference.

        Args:
            project_name: Experiment name passed to ``set_hparams``; also used
                as the item name for in-memory (``BytesIO``) inputs.
            config_name: Config path passed to ``set_hparams``.
            hubert_gpu: Value stored in ``hparams['hubert_gpu']``.
            model_path: Checkpoint path given to ``utils.load_ckpt``.
            onnx: Forwarded to ``HubertEncoder``.
        """
        self.project_name = project_name
        self.DIFF_DECODERS = {
            'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
        }

        self.model_path = model_path
        self.dev = torch.device("cuda")

        self._ = set_hparams(config=config_name, exp_name=self.project_name, infer=True,
                             reset=True, hparams_str='', print_hparams=False)

        self.mel_bins = hparams['audio_num_mel_bins']
        hparams['hubert_gpu'] = hubert_gpu
        self.hubert = HubertEncoder(hparams['hubert_path'], onnx=onnx)
        self.model = GaussianDiffusion(
            phone_encoder=self.hubert,
            out_dims=self.mel_bins, denoise_fn=self.DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
        )
        utils.load_ckpt(self.model, self.model_path, 'model', force=True, strict=True)
        self.model.cuda()
        self.vocoder = NsfHifiGAN()

    def _recommend_keys(self, input_f0):
        """Rank transpositions in ``range(-12, 12)`` and print the best five.

        Each candidate is scored with ``compare_pitch`` against the
        ``f0_static`` entry of ``hparams``; higher scores rank first.

        Args:
            input_f0: F0 sequence handed to ``static_f0_time``.

        Returns:
            The five best-scoring transposition keys, best first.
        """
        f0_static = json.loads(hparams['f0_static'])
        pitch_time_temp = static_f0_time(input_f0)
        eval_dict = {
            trans_key: compare_pitch(f0_static, pitch_time_temp, trans_key=trans_key)
            for trans_key in range(-12, 12)
        }
        sort_key = sorted(eval_dict, key=eval_dict.get, reverse=True)[:5]
        print(f"推荐移调:{sort_key}")
        return sort_key

    def infer_autokey(self, in_path, key, acc, spk_id=0, use_crepe=False):
        """Pre-process ``in_path`` and, when possible, pick the key automatically.

        Args:
            in_path: Input audio (path or ``BytesIO``), forwarded to ``pre``.
            key: Caller-supplied key, returned unchanged when ``hparams`` has
                no ``f0_static`` entry.
            acc: Acceleration factor forwarded to ``pre``.
            spk_id: Speaker id forwarded to ``pre``.
            use_crepe: Forwarded to ``pre``.

        Returns:
            A ``(key, in_path, batch)`` tuple.
        """
        batch, temp_dict = self.pre(in_path, acc, spk_id, use_crepe)
        input_f0 = temp_dict['f0']
        if "f0_static" in hparams.keys():
            sort_key = self._recommend_keys(input_f0)
            print(f"自动变调已启用,您的输入key被{sort_key[0]}key覆盖,控制参数为auto_key")
            # NOTE(review): keys above 6 are shifted by a further +6; the intent
            # of this adjustment is not evident from this file - confirm.
            if sort_key[0] > 6:
                key = sort_key[0] + 6
            else:
                key = sort_key[0]
        return key, in_path, batch

    def infer(self, in_path, key, batch, singer=False):
        """Run the diffusion model on ``batch`` and vocode the result.

        Args:
            in_path: Input path, only used by ``after_infer`` to derive the
                output location when ``singer`` is true.
            key: Transposition; ``key / 12`` is added to ``batch['f0']``.
            batch: Batch produced by ``pre`` (mutated in place).
            singer: When true, ``after_infer`` also saves mel and f0 arrays.

        Returns:
            The ``(f0_gt, f0_pred, wav_pred)`` tuple from ``after_infer``.
        """
        batch['f0'] = batch['f0'] + (key / 12)
        # Values pushed above the configured maximum are zeroed.
        batch['f0'][batch['f0'] > np.log2(hparams['f0_max'])] = 0

        @timeit
        def diff_infer():
            spk_embed = batch.get('spk_embed') if not hparams['use_spk_id'] else batch.get('spk_ids')
            # Compare against None explicitly: truth-testing a tensor raises for
            # multi-element tensors and would drop a legitimate all-zero value.
            energy = batch.get('energy')
            if energy is not None:
                energy = energy.cuda()
            if spk_embed is None:
                spk_embed = torch.LongTensor([0])
            diff_outputs = self.model(
                hubert=batch['hubert'].cuda(), spk_embed_id=spk_embed.cuda(), mel2ph=batch['mel2ph'].cuda(),
                f0=batch['f0'].cuda(), energy=energy, ref_mels=batch["mels"].cuda(), infer=True)
            return diff_outputs

        outputs = diff_infer()
        batch['outputs'] = outputs['mel_out']
        batch['mel2ph_pred'] = outputs['mel2ph']
        batch['f0_gt'] = denorm_f0(batch['f0'], batch['uv'], hparams)
        batch['f0_pred'] = outputs.get('f0_denorm')
        return self.after_infer(batch, singer, in_path)

    @timeit
    def after_infer(self, prediction, singer, in_path):
        """Strip padding from the model outputs and synthesise the waveform.

        Args:
            prediction: Batch dict containing at least ``mels`` and
                ``outputs``; tensor values are converted to NumPy in place.
            singer: When true, the predicted mel and f0 are saved as
                ``*_mel.npy`` / ``*_f0.npy`` next to a path derived from
                ``in_path``.
            in_path: Input path used to derive the save location.

        Returns:
            A ``(f0_gt, f0_pred, wav_pred)`` tuple.
        """
        for k, v in prediction.items():
            if isinstance(v, torch.Tensor):
                prediction[k] = v.cpu().numpy()

        # Remove paddings: frames whose absolute sum is zero are dropped.
        mel_gt = prediction["mels"]
        mel_gt_mask = np.abs(mel_gt).sum(-1) > 0

        mel_pred = prediction["outputs"]
        mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
        mel_pred = mel_pred[mel_pred_mask]
        mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])

        f0_gt = prediction.get("f0_gt")
        f0_pred = prediction.get("f0_pred")
        if f0_pred is not None:
            f0_gt = f0_gt[mel_gt_mask]
            if len(f0_pred) > len(mel_pred_mask):
                f0_pred = f0_pred[:len(mel_pred_mask)]
            f0_pred = f0_pred[mel_pred_mask]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if singer:
            data_path = in_path.replace("batch", "singer_data")
            mel_path = data_path[:-4] + "_mel.npy"
            f0_path = data_path[:-4] + "_f0.npy"
            np.save(mel_path, mel_pred)
            np.save(f0_path, f0_pred)
        wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
        return f0_gt, f0_pred, wav_pred

    def pre(self, wav_fn, accelerate, spk_id=0, use_crepe=True):
        """Turn an input audio file into a model-ready batch.

        Args:
            wav_fn: Audio path (``/``-separated string) or a ``BytesIO``.
            accelerate: Stored in ``hparams['pndm_speedup']``.
            spk_id: Speaker id stored in the item.
            use_crepe: Forwarded to ``File2Batch.temporary_dict2processed_input``.

        Returns:
            A ``(batch, temp_dict)`` tuple: the collated batch and the
            processed item it was built from.
        """
        if isinstance(wav_fn, BytesIO):
            item_name = self.project_name
        else:
            name_parts = wav_fn.split('/')[-1].split('.')
            # Second-to-last dot-separated part is the name; fall back to the
            # whole file name when there is no extension (previously IndexError).
            item_name = name_parts[-2] if len(name_parts) > 1 else name_parts[0]
        temp_dict = {'wav_fn': wav_fn, 'spk_id': spk_id, 'id': 0}

        temp_dict = File2Batch.temporary_dict2processed_input(item_name, temp_dict, self.hubert, infer=True,
                                                              use_crepe=use_crepe)
        hparams['pndm_speedup'] = accelerate
        batch = File2Batch.processed_input2batch([getitem(temp_dict)])
        return batch, temp_dict

    def evaluate_key(self, wav_path, key, auto_key):
        """Print recommended transpositions for ``wav_path`` and pick a key.

        Args:
            wav_path: Audio passed to the vocoder's ``wav2spec``.
            key: Caller-supplied key, returned unless auto-key applies.
            auto_key: When true and ``hparams`` has ``f0_static``, the best
                recommended key is returned instead of ``key``.

        Returns:
            The key to use for inference.
        """
        if "f0_static" in hparams.keys():
            wav, mel = self.vocoder.wav2spec(wav_path)
            input_f0 = get_pitch_parselmouth(wav, mel, hparams)[0]
            sort_key = self._recommend_keys(input_f0)
            if auto_key:
                print(f"自动变调已启用,您的输入key被{sort_key[0]}key覆盖,控制参数为auto_key")
                return sort_key[0]
        else:
            print("config缺少f0_staic,无法使用自动变调,可通过infer_tools/data_static添加")
        return key
206
-
207
-
208
def getitem(item):
    """Build a single sample dict from a processed item.

    Frame-level entries are truncated to ``hparams['max_frames']`` and the
    HuBERT features to ``hparams['max_input_tokens']``. ``energy`` is copied
    over only when ``hparams['use_energy_embed']`` is truthy.

    Args:
        item: Mapping with at least ``mel``, ``f0``, ``hubert``, ``pitch``,
            ``id``, ``spk_id`` and ``item_name``; ``mel2ph`` is optional.

    Returns:
        A dict holding the converted/truncated fields for one sample.
    """
    frame_cap = hparams['max_frames']
    mel = torch.Tensor(item['mel'])[:frame_cap]
    if 'mel2ph' in item:
        mel2ph = torch.LongTensor(item['mel2ph'])[:frame_cap]
    else:
        mel2ph = None
    f0, uv = norm_interp_f0(item["f0"][:frame_cap], hparams)
    token_cap = hparams['max_input_tokens']
    hubert = torch.Tensor(item['hubert'][:token_cap])
    pitch = torch.LongTensor(item.get("pitch"))[:frame_cap]
    sample = dict(
        id=item['id'],
        spk_id=item['spk_id'],
        item_name=item['item_name'],
        hubert=hubert,
        mel=mel,
        pitch=pitch,
        f0=f0,
        uv=uv,
        mel2ph=mel2ph,
        # True for frames whose absolute sum is non-zero.
        mel_nonpadding=mel.abs().sum(-1) > 0,
    )
    if hparams['use_energy_embed']:
        sample['energy'] = item['energy']
    return sample