File size: 7,662 Bytes
93f4bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0684a78
93f4bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ce0876
93f4bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import json
import os
import time
from io import BytesIO
from pathlib import Path

import librosa
import numpy as np
import soundfile
import torch

import utils
from infer_tools.f0_static import compare_pitch, static_f0_time
from modules.diff.diffusion import GaussianDiffusion
from modules.diff.net import DiffNet
from modules.vocoders.nsf_hifigan import NsfHifiGAN
from preprocessing.hubertinfer import HubertEncoder
from preprocessing.process_pipeline import File2Batch, get_pitch_parselmouth
from utils.hparams import hparams, set_hparams
from utils.pitch_utils import denorm_f0, norm_interp_f0


def timeit(func):
    """Decorate *func* so that every successful call prints its duration.

    The wrapper forwards all positional and keyword arguments, returns the
    wrapped function's result unchanged, and keeps the wrapped function's
    metadata (``__name__``, ``__doc__``, ...) via ``functools.wraps``.
    Nothing is printed when *func* raises; the exception propagates.

    Args:
        func: The callable to time.

    Returns:
        A wrapper with the same call signature as *func*.
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(func)
    def run(*args, **kwargs):
        # perf_counter is monotonic, so the reported duration cannot be
        # distorted by wall-clock adjustments that happen mid-call.
        started = time.perf_counter()
        res = func(*args, **kwargs)
        print('executing \'%s\' costed %.3fs' % (func.__name__, time.perf_counter() - started))
        return res

    return run


def format_wav(audio_path):
    """Write a ``.wav`` copy of *audio_path* next to it unless it already is one.

    Files whose suffix is exactly ``.wav`` are left alone.  Anything else is
    loaded with librosa as mono at its native sample rate (``sr=None``) and
    written with soundfile to the same path with the suffix replaced by
    ``.wav``.

    Args:
        audio_path: Path of the audio file to convert.

    Returns:
        None.
    """
    source = Path(audio_path)
    if source.suffix != '.wav':
        samples, sample_rate = librosa.load(audio_path, mono=True, sr=None)
        soundfile.write(source.with_suffix(".wav"), samples, sample_rate)


def fill_a_to_b(a, b):
    """Pad list *a* in place with copies of its first element up to ``len(b)``.

    Nothing happens when *a* is already at least as long as *b*.  An empty
    *a* that needs padding raises ``IndexError`` because there is no first
    element to repeat.

    Args:
        a: List to extend in place.
        b: Sized object whose length is the target length.

    Returns:
        None; *a* is mutated.
    """
    shortfall = len(b) - len(a)
    if shortfall > 0:
        a.extend([a[0]] * shortfall)


def get_end_file(dir_path, end):
    """Recursively collect files under *dir_path* whose names end with *end*.

    Files and directories whose names start with ``.`` are skipped; hidden
    directories are not descended into.  Returned paths use forward slashes.

    Args:
        dir_path: Root directory to walk.
        end: Suffix (or tuple of suffixes) passed to ``str.endswith``.

    Returns:
        List of matching paths in ``os.walk`` order.
    """
    matches = []
    for folder, subdirs, names in os.walk(dir_path):
        # Prune in place so os.walk never enters hidden directories.
        subdirs[:] = [sub for sub in subdirs if sub[0] != '.']
        matches.extend(
            os.path.join(folder, name).replace("\\", "/")
            for name in names
            if name[0] != '.' and name.endswith(end)
        )
    return matches


def mkdir(paths: list):
    """Create each directory in *paths* that does not exist yet.

    Paths are handled in order with ``os.mkdir``, so parent directories must
    already exist.  Anything that already exists at a path is left alone.

    Args:
        paths: Directory paths to create.

    Returns:
        None.
    """
    for target in paths:
        if os.path.exists(target):
            continue
        os.mkdir(target)


class Svc:
    """Inference wrapper bundling a HuBERT encoder, a diffusion model and a vocoder.

    Construction loads the hyper-parameters, builds a ``GaussianDiffusion``
    model around a ``DiffNet`` denoiser, restores its checkpoint and creates
    an ``NsfHifiGAN`` vocoder.  The model lives on ``self.dev`` and every
    tensor fed to it during inference is moved to that same device.
    """

    def __init__(self, project_name, config_name, hubert_gpu, model_path, onnx=False):
        """Load hyper-parameters, build the model and restore the checkpoint.

        Args:
            project_name: Passed to ``set_hparams`` as ``exp_name``; also used
                as the item name for in-memory (``BytesIO``) inputs.
            config_name: Passed to ``set_hparams`` as ``config``.
            hubert_gpu: Stored in ``hparams['hubert_gpu']`` before the HuBERT
                encoder is created.
            model_path: Checkpoint path handed to ``utils.load_ckpt``.
            onnx: Forwarded to ``HubertEncoder``.
        """
        self.project_name = project_name
        # Maps hparams['diff_decoder_type'] to a factory for the denoiser.
        self.DIFF_DECODERS = {
            'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
        }

        self.model_path = model_path
        self.dev = torch.device("cpu")

        # set_hparams populates the module-level `hparams` used below.
        self._ = set_hparams(config=config_name, exp_name=self.project_name, infer=True,
                             reset=True, hparams_str='', print_hparams=False)

        hparams['hubert_gpu'] = hubert_gpu
        self.hubert = HubertEncoder(hparams['hubert_path'], onnx=onnx)
        self.model = GaussianDiffusion(
            phone_encoder=self.hubert,
            out_dims=hparams['audio_num_mel_bins'],
            denoise_fn=self.DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
        )
        utils.load_ckpt(self.model, self.model_path, 'model', force=True, strict=True)
        self.model.to(self.dev)
        self.vocoder = NsfHifiGAN()

    def infer(self, in_path, key, acc, spk_id=0, use_crepe=True, singer=False):
        """Run the full pipeline on one input and return pitch curves and audio.

        Args:
            in_path: Input audio (path string or ``BytesIO``), passed to ``pre``.
            key: Pitch shift; ``key / 12`` is added to the normalised f0.  The
                f0 is compared against ``log2(f0_max)`` below, so this
                presumably means semitones in a log2 domain (TODO confirm).
            acc: Stored in ``hparams['pndm_speedup']`` by ``pre``.
            spk_id: Speaker id placed in the processed item.
            use_crepe: Forwarded to the preprocessing pipeline.
            singer: When true, ``after_infer`` also saves mel/f0 ``.npy`` files.

        Returns:
            The ``(f0_gt, f0_pred, wav_pred)`` tuple from ``after_infer``.
        """
        batch = self.pre(in_path, acc, spk_id, use_crepe)
        batch['f0'] = batch['f0'] + (key / 12)
        # Zero out frames that were shifted above the configured f0 ceiling.
        batch['f0'][batch['f0'] > np.log2(hparams['f0_max'])] = 0

        @timeit
        def diff_infer():
            spk_embed = batch.get('spk_ids') if hparams['use_spk_id'] else batch.get('spk_embed')
            if spk_embed is None:
                spk_embed = torch.LongTensor([0])
            # Compare against None explicitly: the truth value of a
            # multi-element tensor is ambiguous and raises.
            energy = batch.get('energy')
            if energy is not None:
                energy = energy.to(self.dev)
            # Inputs go to the device the model was placed on in __init__.
            return self.model(
                hubert=batch['hubert'].to(self.dev),
                spk_embed_id=spk_embed.to(self.dev),
                mel2ph=batch['mel2ph'].to(self.dev),
                f0=batch['f0'].to(self.dev),
                energy=energy,
                ref_mels=batch["mels"].to(self.dev),
                infer=True,
            )

        outputs = diff_infer()
        batch['outputs'] = outputs['mel_out']
        batch['mel2ph_pred'] = outputs['mel2ph']
        batch['f0_gt'] = denorm_f0(batch['f0'], batch['uv'], hparams)
        batch['f0_pred'] = outputs.get('f0_denorm')
        return self.after_infer(batch, singer, in_path)

    @timeit
    def after_infer(self, prediction, singer, in_path):
        """Strip padding from the prediction and vocode it into a waveform.

        Args:
            prediction: Batch dict; tensor values are converted to NumPy
                arrays in place.
            singer: When true, save the predicted mel and f0 as ``.npy`` files.
            in_path: Input path used to derive the ``.npy`` file names.
                NOTE(review): the ``[:-4]`` slice assumes a string path with a
                four-character extension such as ``.wav``.

        Returns:
            ``(f0_gt, f0_pred, wav_pred)``; ``f0_pred`` is ``None`` when the
            model produced no ``f0_denorm`` output.
        """
        # Only existing keys are reassigned, which is safe while iterating.
        for k, v in prediction.items():
            if isinstance(v, torch.Tensor):
                prediction[k] = v.cpu().numpy()

        # remove paddings
        mel_gt = prediction["mels"]
        mel_gt_mask = np.abs(mel_gt).sum(-1) > 0

        mel_pred = prediction["outputs"]
        mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
        mel_pred = mel_pred[mel_pred_mask]
        mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])

        f0_gt = prediction.get("f0_gt")
        f0_pred = prediction.get("f0_pred")
        if f0_pred is not None:
            # All f0 post-processing stays under this guard so a missing
            # prediction no longer fails on len(None).
            f0_gt = f0_gt[mel_gt_mask]
            if len(f0_pred) > len(mel_pred_mask):
                f0_pred = f0_pred[:len(mel_pred_mask)]
            f0_pred = f0_pred[mel_pred_mask]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if singer:
            data_path = in_path.replace("batch", "singer_data")
            mel_path = data_path[:-4] + "_mel.npy"
            f0_path = data_path[:-4] + "_f0.npy"
            np.save(mel_path, mel_pred)
            if f0_pred is not None:
                np.save(f0_path, f0_pred)
        wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
        return f0_gt, f0_pred, wav_pred

    def pre(self, wav_fn, accelerate, spk_id=0, use_crepe=True):
        """Turn one audio input into a model-ready batch.

        Args:
            wav_fn: ``BytesIO`` object or a ``/``-separated path string.
            accelerate: Stored in ``hparams['pndm_speedup']``.
            spk_id: Speaker id placed in the processed item.
            use_crepe: Forwarded to ``File2Batch.temporary_dict2processed_input``.

        Returns:
            The batch produced by ``File2Batch.processed_input2batch``.
        """
        if isinstance(wav_fn, BytesIO):
            item_name = self.project_name
        else:
            base_name = wav_fn.split('/')[-1]
            name_parts = base_name.split('.')
            # The segment before the extension names the item; a file name
            # without any dot falls back to the whole name instead of raising.
            item_name = name_parts[-2] if len(name_parts) > 1 else base_name
        temp_dict = {'wav_fn': wav_fn, 'spk_id': spk_id, 'id': 0}

        temp_dict = File2Batch.temporary_dict2processed_input(item_name, temp_dict, self.hubert, infer=True,
                                                              use_crepe=use_crepe)
        hparams['pndm_speedup'] = accelerate
        batch = File2Batch.processed_input2batch([getitem(temp_dict)])
        return batch

    def evaluate_key(self, wav_path, key, auto_key):
        """Suggest transpositions for *wav_path* and optionally pick the best.

        Requires ``hparams['f0_static']`` (a JSON string).  Each shift in
        ``range(-12, 12)`` is scored with ``compare_pitch`` and the five
        highest-scoring shifts are printed.

        Args:
            wav_path: Audio passed to the vocoder's ``wav2spec``.
            key: Caller-supplied shift, returned unless overridden.
            auto_key: When true and statistics exist, return the top shift.

        Returns:
            The top-ranked shift when *auto_key* applies, otherwise *key*.
        """
        if "f0_static" in hparams:
            f0_static = json.loads(hparams['f0_static'])
            wav, mel = self.vocoder.wav2spec(wav_path)
            input_f0 = get_pitch_parselmouth(wav, mel, hparams)[0]
            pitch_time_temp = static_f0_time(input_f0)
            # NOTE(review): range(-12, 12) excludes +12; confirm whether the
            # asymmetric range is intentional.
            eval_dict = {
                trans_key: compare_pitch(f0_static, pitch_time_temp, trans_key=trans_key)
                for trans_key in range(-12, 12)
            }
            sort_key = sorted(eval_dict, key=eval_dict.get, reverse=True)[:5]
            print(f"推荐移调:{sort_key}")
            if auto_key:
                print(f"自动变调已启用,您的输入key被{sort_key[0]}key覆盖,控制参数为auto_key")
                return sort_key[0]
        else:
            print("config缺少f0_static,无法使用自动变调,可通过infer_tools/data_static添加")
        return key


def getitem(item):
    """Convert one processed item dict into a sample dict of tensors.

    Frame-level fields (``mel``, ``mel2ph``, ``f0``, ``pitch``) are cut to
    ``hparams['max_frames']`` entries and ``hubert`` is cut to
    ``hparams['max_input_tokens']`` entries.  ``mel2ph`` is ``None`` when the
    item has no such key.  ``energy`` is copied through only when
    ``hparams['use_energy_embed']`` is truthy.

    Args:
        item: Mapping with ``mel``, ``f0``, ``hubert``, ``pitch``, ``id``,
            ``spk_id`` and ``item_name``; ``mel2ph`` and ``energy`` optional.

    Returns:
        Sample dict.
    """
    frame_cap = hparams['max_frames']
    mel = torch.Tensor(item['mel'])[:frame_cap]
    if 'mel2ph' in item:
        mel2ph = torch.LongTensor(item['mel2ph'])[:frame_cap]
    else:
        mel2ph = None
    f0, uv = norm_interp_f0(item["f0"][:frame_cap], hparams)
    hubert_feats = torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
    pitch_ids = torch.LongTensor(item.get("pitch"))[:frame_cap]
    sample = {
        "id": item['id'],
        "spk_id": item['spk_id'],
        "item_name": item['item_name'],
        "hubert": hubert_feats,
        "mel": mel,
        "pitch": pitch_ids,
        "f0": f0,
        "uv": uv,
        "mel2ph": mel2ph,
        # Frames whose mel bins are all zero count as padding.
        "mel_nonpadding": mel.abs().sum(-1) > 0,
    }
    if hparams['use_energy_embed']:
        sample['energy'] = item['energy']
    return sample