
Eddycrack864 committed
Commit 240c2c0
1 Parent(s): 62998a2

Upload separate.py

Files changed (1)
  1. separate.py +1459 -0
separate.py ADDED
@@ -0,0 +1,1459 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING
3
+ from demucs.apply import apply_model, demucs_segments
4
+ from demucs.hdemucs import HDemucs
5
+ from demucs.model_v2 import auto_load_demucs_model_v2
6
+ from demucs.pretrained import get_model as _gm
7
+ from demucs.utils import apply_model_v1
8
+ from demucs.utils import apply_model_v2
9
+ from lib_v5.tfc_tdf_v3 import TFC_TDF_net, STFT
10
+ from lib_v5 import spec_utils
11
+ from lib_v5.vr_network import nets
12
+ from lib_v5.vr_network import nets_new
13
+ from lib_v5.vr_network.model_param_init import ModelParameters
14
+ from pathlib import Path
15
+ from gui_data.constants import *
16
+ from gui_data.error_handling import *
17
+ from scipy import signal
18
+ import audioread
19
+ import gzip
20
+ import librosa
21
+ import math
22
+ import numpy as np
23
+ import onnxruntime as ort
24
+ import os
25
+ import torch
26
+ import warnings
27
+ import pydub
28
+ import soundfile as sf
29
+ import lib_v5.mdxnet as MdxnetSet
31
+ #import random
32
+ from onnx import load
33
+ from onnx2pytorch import ConvertModel
34
+ import gc
35
+
36
+ if TYPE_CHECKING:
37
+ from UVR import ModelData
38
+
39
+ # if not is_macos:
40
+ # import torch_directml
41
+
42
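+ # Probe the available accelerators once at import time (MPS is only checked on macOS).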
+ mps_available = torch.backends.mps.is_available() if is_macos else False
43
+ cuda_available = torch.cuda.is_available()
44
+
45
+ # def get_gpu_info():
46
+ # directml_device, directml_available = DIRECTML_DEVICE, False
47
+
48
+ # if not is_macos:
49
+ # directml_available = torch_directml.is_available()
50
+
51
+ # if directml_available:
52
+ # directml_device = str(torch_directml.device()).partition(":")[0]
53
+
54
+ # return directml_device, directml_available
55
+
56
+ # DIRECTML_DEVICE, directml_available = get_gpu_info()
57
+
58
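+ # Run Python garbage collection, then release cached GPU memory (MPS on macOS, CUDA otherwise).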
+ def clear_gpu_cache():
59
+ gc.collect()
60
+ if is_macos:
61
+ torch.mps.empty_cache()
62
+ else:
63
+ torch.cuda.empty_cache()
64
+
65
+ warnings.filterwarnings("ignore")
66
+ cpu = torch.device('cpu')
67
+
68
+ class SeperateAttributes:
69
+ def __init__(self, model_data: ModelData,
70
+ process_data: dict,
71
+ main_model_primary_stem_4_stem=None,
72
+ main_process_method=None,
73
+ is_return_dual=True,
74
+ main_model_primary=None,
75
+ vocal_stem_path=None,
76
+ master_inst_source=None,
77
+ master_vocal_source=None):
78
+
79
+ self.list_all_models: list
80
+ self.process_data = process_data
81
+ self.progress_value = 0
82
+ self.set_progress_bar = process_data['set_progress_bar']
83
+ self.write_to_console = process_data['write_to_console']
84
+ if vocal_stem_path:
85
+ self.audio_file, self.audio_file_base = vocal_stem_path
86
+ self.audio_file_base_voc_split = lambda stem, split:os.path.join(self.export_path, f'{self.audio_file_base.replace("_(Vocals)", "")}_({stem}_{split}).wav')
87
+ else:
88
+ self.audio_file = process_data['audio_file']
89
+ self.audio_file_base = process_data['audio_file_base']
90
+ self.audio_file_base_voc_split = None
91
+ self.export_path = process_data['export_path']
92
+ self.cached_source_callback = process_data['cached_source_callback']
93
+ self.cached_model_source_holder = process_data['cached_model_source_holder']
94
+ self.is_4_stem_ensemble = process_data['is_4_stem_ensemble']
95
+ self.list_all_models = process_data['list_all_models']
96
+ self.process_iteration = process_data['process_iteration']
97
+ self.is_return_dual = is_return_dual
98
+ self.is_pitch_change = model_data.is_pitch_change
99
+ self.semitone_shift = model_data.semitone_shift
100
+ self.is_match_frequency_pitch = model_data.is_match_frequency_pitch
101
+ self.overlap = model_data.overlap
102
+ self.overlap_mdx = model_data.overlap_mdx
103
+ self.overlap_mdx23 = model_data.overlap_mdx23
104
+ self.is_mdx_combine_stems = model_data.is_mdx_combine_stems
105
+ self.is_mdx_c = model_data.is_mdx_c
106
+ self.mdx_c_configs = model_data.mdx_c_configs
107
+ self.mdxnet_stem_select = model_data.mdxnet_stem_select
108
+ self.mixer_path = model_data.mixer_path
109
+ self.model_samplerate = model_data.model_samplerate
110
+ self.model_capacity = model_data.model_capacity
111
+ self.is_vr_51_model = model_data.is_vr_51_model
112
+ self.is_pre_proc_model = model_data.is_pre_proc_model
113
+ self.is_secondary_model_activated = model_data.is_secondary_model_activated if not self.is_pre_proc_model else False
114
+ self.is_secondary_model = model_data.is_secondary_model if not self.is_pre_proc_model else True
115
+ self.process_method = model_data.process_method
116
+ self.model_path = model_data.model_path
117
+ self.model_name = model_data.model_name
118
+ self.model_basename = model_data.model_basename
119
+ self.wav_type_set = model_data.wav_type_set
120
+ self.mp3_bit_set = model_data.mp3_bit_set
121
+ self.save_format = model_data.save_format
122
+ self.is_gpu_conversion = model_data.is_gpu_conversion
123
+ self.is_normalization = model_data.is_normalization
124
+ self.is_primary_stem_only = model_data.is_primary_stem_only if not self.is_secondary_model else model_data.is_primary_model_primary_stem_only
125
+ self.is_secondary_stem_only = model_data.is_secondary_stem_only if not self.is_secondary_model else model_data.is_primary_model_secondary_stem_only
126
+ self.is_ensemble_mode = model_data.is_ensemble_mode
127
+ self.secondary_model = model_data.secondary_model #
128
+ self.primary_model_primary_stem = model_data.primary_model_primary_stem
129
+ self.primary_stem_native = model_data.primary_stem_native
130
+ self.primary_stem = model_data.primary_stem #
131
+ self.secondary_stem = model_data.secondary_stem #
132
+ self.is_invert_spec = model_data.is_invert_spec #
133
+ self.is_deverb_vocals = model_data.is_deverb_vocals
134
+ self.is_mixer_mode = model_data.is_mixer_mode #
135
+ self.secondary_model_scale = model_data.secondary_model_scale #
136
+ self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix #
137
+ self.primary_source_map = {}
138
+ self.secondary_source_map = {}
139
+ self.primary_source = None
140
+ self.secondary_source = None
141
+ self.secondary_source_primary = None
142
+ self.secondary_source_secondary = None
143
+ self.main_model_primary_stem_4_stem = main_model_primary_stem_4_stem
144
+ self.main_model_primary = main_model_primary
145
+ self.ensemble_primary_stem = model_data.ensemble_primary_stem
146
+ self.is_multi_stem_ensemble = model_data.is_multi_stem_ensemble
147
+ self.is_other_gpu = False
148
+ self.is_deverb = True
149
+ self.DENOISER_MODEL = model_data.DENOISER_MODEL
150
+ self.DEVERBER_MODEL = model_data.DEVERBER_MODEL
151
+ self.is_source_swap = False
152
+ self.vocal_split_model = model_data.vocal_split_model
153
+ self.is_vocal_split_model = model_data.is_vocal_split_model
154
+ self.master_vocal_path = None
155
+ self.set_master_inst_source = None
156
+ self.master_inst_source = master_inst_source
157
+ self.master_vocal_source = master_vocal_source
158
+ self.is_save_inst_vocal_splitter = isinstance(master_inst_source, np.ndarray) and model_data.is_save_inst_vocal_splitter
159
+ self.is_inst_only_voc_splitter = model_data.is_inst_only_voc_splitter
160
+ self.is_karaoke = model_data.is_karaoke
161
+ self.is_bv_model = model_data.is_bv_model
162
+ self.is_bv_model_rebalenced = model_data.bv_model_rebalance and self.is_vocal_split_model
163
+ self.is_sec_bv_rebalance = model_data.is_sec_bv_rebalance
164
+ self.stem_path_init = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
165
+ self.deverb_vocal_opt = model_data.deverb_vocal_opt
166
+ self.is_save_vocal_only = model_data.is_save_vocal_only
167
+ self.device = cpu
168
+ self.run_type = ['CPUExecutionProvider']
169
+ self.is_opencl = False
170
+ self.device_set = model_data.device_set
171
+ self.is_use_opencl = model_data.is_use_opencl
172
+
173
+ if self.is_inst_only_voc_splitter or self.is_sec_bv_rebalance:
174
+ self.is_primary_stem_only = False
175
+ self.is_secondary_stem_only = False
176
+
177
+ if main_model_primary and self.is_multi_stem_ensemble:
178
+ self.primary_stem, self.secondary_stem = main_model_primary, secondary_stem(main_model_primary)
179
+
180
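+ # Device selection: prefer MPS on Apple hardware, then CUDA (optionally pinned to a device index); otherwise stay on CPU.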
+ if self.is_gpu_conversion >= 0:
181
+ if mps_available:
182
+ self.device, self.is_other_gpu = 'mps', True
183
+ else:
184
+ device_prefix = None
185
+ if self.device_set != DEFAULT:
186
+ device_prefix = CUDA_DEVICE#DIRECTML_DEVICE if self.is_use_opencl and directml_available else CUDA_DEVICE
187
+
188
+ # if directml_available and self.is_use_opencl:
189
+ # self.device = torch_directml.device() if not device_prefix else f'{device_prefix}:{self.device_set}'
190
+ # self.is_other_gpu = True
191
+ if cuda_available:# and not self.is_use_opencl:
192
+ self.device = CUDA_DEVICE if not device_prefix else f'{device_prefix}:{self.device_set}'
193
+ self.run_type = ['CUDAExecutionProvider']
194
+
195
+ if model_data.process_method == MDX_ARCH_TYPE:
196
+ self.is_mdx_ckpt = model_data.is_mdx_ckpt
197
+ self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename)
198
+ self.is_denoise = model_data.is_denoise#
199
+ self.is_denoise_model = model_data.is_denoise_model#
200
+ self.is_mdx_c_seg_def = model_data.is_mdx_c_seg_def#
201
+ self.mdx_batch_size = model_data.mdx_batch_size
202
+ self.compensate = model_data.compensate
203
+ self.mdx_segment_size = model_data.mdx_segment_size
204
+
205
+ if self.is_mdx_c:
206
+ if not self.is_4_stem_ensemble:
207
+ self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
208
+ self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
209
+ else:
210
+ self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
211
+
212
+ self.check_label_secondary_stem_runs()
213
+ self.n_fft = model_data.mdx_n_fft_scale_set
214
+ self.chunks = model_data.chunks
215
+ self.margin = model_data.margin
216
+ self.adjust = 1
217
+ self.dim_c = 4
218
+ self.hop = 1024
219
+
220
+ if model_data.process_method == DEMUCS_ARCH_TYPE:
221
+ self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None
222
+ self.secondary_model_4_stem = model_data.secondary_model_4_stem
223
+ self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale
224
+ self.is_chunk_demucs = model_data.is_chunk_demucs
225
+ self.segment = model_data.segment
226
+ self.demucs_version = model_data.demucs_version
227
+ self.demucs_source_list = model_data.demucs_source_list
228
+ self.demucs_source_map = model_data.demucs_source_map
229
+ self.is_demucs_combine_stems = model_data.is_demucs_combine_stems
230
+ self.demucs_stem_count = model_data.demucs_stem_count
231
+ self.pre_proc_model = model_data.pre_proc_model
232
+ self.device = cpu if self.is_other_gpu and not self.demucs_version in [DEMUCS_V3, DEMUCS_V4] else self.device
233
+
234
+ self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
235
+ self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
236
+
237
+ if (self.is_multi_stem_ensemble or self.is_4_stem_ensemble) and not self.is_secondary_model:
238
+ self.is_return_dual = False
239
+
240
+ if self.is_multi_stem_ensemble and main_model_primary:
241
+ self.is_4_stem_ensemble = False
242
+ if main_model_primary in self.demucs_source_map.keys():
243
+ self.primary_stem = main_model_primary
244
+ self.secondary_stem = secondary_stem(main_model_primary)
245
+ elif secondary_stem(main_model_primary) in self.demucs_source_map.keys():
246
+ self.primary_stem = secondary_stem(main_model_primary)
247
+ self.secondary_stem = main_model_primary
248
+
249
+ if self.is_secondary_model and not process_data['is_ensemble_master']:
250
+ if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM:
251
+ self.primary_stem = VOCAL_STEM
252
+ self.secondary_stem = INST_STEM
253
+ else:
254
+ self.primary_stem = model_data.primary_model_primary_stem
255
+ self.secondary_stem = secondary_stem(self.primary_stem)
256
+
257
+ self.shifts = model_data.shifts
258
+ self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True
259
+ self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename)
260
+
261
+ if model_data.process_method == VR_ARCH_TYPE:
262
+ self.check_label_secondary_stem_runs()
263
+ self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename)
264
+ self.mp = model_data.vr_model_param
265
+ self.high_end_process = model_data.is_high_end_process
266
+ self.is_tta = model_data.is_tta
267
+ self.is_post_process = model_data.is_post_process
268
+ self.is_gpu_conversion = model_data.is_gpu_conversion
269
+ self.batch_size = model_data.batch_size
270
+ self.window_size = model_data.window_size
271
+ self.input_high_end_h = None
272
+ self.input_high_end = None
273
+ self.post_process_threshold = model_data.post_process_threshold
274
+ self.aggressiveness = {'value': model_data.aggression_setting,
275
+ 'split_bin': self.mp.param['band'][1]['crop_stop'],
276
+ 'aggr_correction': self.mp.param.get('aggr_correction')}
277
+
278
+ def check_label_secondary_stem_runs(self):
279
+
280
+ # For ensemble master that's not a 4-stem ensemble, and not mdx_c
281
+ if self.process_data['is_ensemble_master'] and not self.is_4_stem_ensemble and not self.is_mdx_c:
282
+ if self.ensemble_primary_stem != self.primary_stem:
283
+ self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
284
+
285
+ # For secondary models
286
+ if self.is_pre_proc_model or self.is_secondary_model:
287
+ self.is_primary_stem_only = False
288
+ self.is_secondary_stem_only = False
289
+
290
+ def start_inference_console_write(self):
291
+ if self.is_secondary_model and not self.is_pre_proc_model and not self.is_vocal_split_model:
292
+ self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename))
293
+
294
+ if self.is_pre_proc_model:
295
+ self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename))
296
+
297
+ if self.is_vocal_split_model:
298
+ self.write_to_console(INFERENCE_STEP_2_VOC_S(self.process_method, self.model_basename))
299
+
300
+ def running_inference_console_write(self, is_no_write=False):
301
+ self.write_to_console(DONE, base_text='') if not is_no_write else None
302
+ self.set_progress_bar(0.05) if not is_no_write else None
303
+
304
+ if self.is_secondary_model and not self.is_pre_proc_model and not self.is_vocal_split_model:
305
+ self.write_to_console(INFERENCE_STEP_1_SEC)
306
+ elif self.is_pre_proc_model:
307
+ self.write_to_console(INFERENCE_STEP_1_PRE)
308
+ elif self.is_vocal_split_model:
309
+ self.write_to_console(INFERENCE_STEP_1_VOC_S)
310
+ else:
311
+ self.write_to_console(INFERENCE_STEP_1)
312
+
313
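+ # Advance the shared progress bar; model inference is mapped onto roughly the first 80% of the bar.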
+ def running_inference_progress_bar(self, length, is_match_mix=False):
314
+ if not is_match_mix:
315
+ self.progress_value += 1
316
+
317
+ if (0.8/length*self.progress_value) >= 0.8:
318
+ length = self.progress_value + 1
319
+
320
+ self.set_progress_bar(0.1, (0.8/length*self.progress_value))
321
+
322
+ def load_cached_sources(self):
323
+
324
+ if self.is_secondary_model and not self.is_pre_proc_model:
325
+ self.write_to_console(INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename))
326
+ elif self.is_pre_proc_model:
327
+ self.write_to_console(INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename))
328
+ else:
329
+ self.write_to_console(INFERENCE_STEP_2_PRIMARY_CACHED, "")
330
+
331
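+ # Cache the separated sources only if this model appears more than once in the overall model list, so later passes can reuse them.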
+ def cache_source(self, secondary_sources):
332
+
333
+ model_occurrences = self.list_all_models.count(self.model_basename)
334
+
335
+ if not model_occurrences <= 1:
336
+ if self.process_method == MDX_ARCH_TYPE:
337
+ self.cached_model_source_holder(MDX_ARCH_TYPE, secondary_sources, self.model_basename)
338
+
339
+ if self.process_method == VR_ARCH_TYPE:
340
+ self.cached_model_source_holder(VR_ARCH_TYPE, secondary_sources, self.model_basename)
341
+
342
+ if self.process_method == DEMUCS_ARCH_TYPE:
343
+ self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename)
344
+
345
+ def process_vocal_split_chain(self, sources: dict):
346
+
347
+ def is_valid_vocal_split_condition(master_vocal_source):
348
+ """Checks if conditions for vocal split processing are met."""
349
+ conditions = [
350
+ isinstance(master_vocal_source, np.ndarray),
351
+ self.vocal_split_model,
352
+ not self.is_ensemble_mode,
353
+ not self.is_karaoke,
354
+ not self.is_bv_model
355
+ ]
356
+ return all(conditions)
357
+
358
+ # Retrieve sources from the dictionary with default fallbacks
359
+ master_inst_source = sources.get(INST_STEM, None)
360
+ master_vocal_source = sources.get(VOCAL_STEM, None)
361
+
362
+ # Process the vocal split chain if conditions are met
363
+ if is_valid_vocal_split_condition(master_vocal_source):
364
+ process_chain_model(
365
+ self.vocal_split_model,
366
+ self.process_data,
367
+ vocal_stem_path=self.master_vocal_path,
368
+ master_vocal_source=master_vocal_source,
369
+ master_inst_source=master_inst_source
370
+ )
371
+
372
+ def process_secondary_stem(self, stem_source, secondary_model_source=None, model_scale=None):
373
+ if not self.is_secondary_model:
374
+ if self.is_secondary_model_activated and isinstance(secondary_model_source, np.ndarray):
375
+ secondary_model_scale = model_scale if model_scale else self.secondary_model_scale
376
+ stem_source = spec_utils.average_dual_sources(stem_source, secondary_model_source, secondary_model_scale)
377
+
378
+ return stem_source
379
+
380
+ def final_process(self, stem_path, source, secondary_source, stem_name, samplerate):
381
+ source = self.process_secondary_stem(source, secondary_source)
382
+ self.write_audio(stem_path, source, samplerate, stem_name=stem_name)
383
+
384
+ return {stem_name: source}
385
+
386
+ def write_audio(self, stem_path: str, stem_source, samplerate, stem_name=None):
387
+
388
+ def save_audio_file(path, source):
389
+ source = spec_utils.normalize(source, self.is_normalization)
390
+ sf.write(path, source, samplerate, subtype=self.wav_type_set)
391
+
392
+ if is_not_ensemble:
393
+ save_format(path, self.save_format, self.mp3_bit_set)
394
+
395
+ def save_voc_split_instrumental(stem_name, stem_source, is_inst_invert=False):
396
+ inst_stem_name = "Instrumental (With Lead Vocals)" if stem_name == LEAD_VOCAL_STEM else "Instrumental (With Backing Vocals)"
397
+ inst_stem_path_name = LEAD_VOCAL_STEM_I if stem_name == LEAD_VOCAL_STEM else BV_VOCAL_STEM_I
398
+ inst_stem_path = self.audio_file_base_voc_split(INST_STEM, inst_stem_path_name)
399
+ stem_source = -stem_source if is_inst_invert else stem_source
400
+ inst_stem_source = spec_utils.combine_arrarys([self.master_inst_source, stem_source], is_swap=True)
401
+ save_with_message(inst_stem_path, inst_stem_name, inst_stem_source)
402
+
403
+ def save_voc_split_vocal(stem_name, stem_source):
404
+ voc_split_stem_name = LEAD_VOCAL_STEM_LABEL if stem_name == LEAD_VOCAL_STEM else BV_VOCAL_STEM_LABEL
405
+ voc_split_stem_path = self.audio_file_base_voc_split(VOCAL_STEM, stem_name)
406
+ save_with_message(voc_split_stem_path, voc_split_stem_name, stem_source)
407
+
408
+ def save_with_message(stem_path, stem_name, stem_source):
409
+ is_deverb = self.is_deverb_vocals and (
410
+ self.deverb_vocal_opt == stem_name or
411
+ (self.deverb_vocal_opt == 'ALL' and
412
+ (stem_name == VOCAL_STEM or stem_name == LEAD_VOCAL_STEM_LABEL or stem_name == BV_VOCAL_STEM_LABEL)))
413
+
414
+ self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}')
415
+
416
+ if is_deverb and is_not_ensemble:
417
+ deverb_vocals(stem_path, stem_source)
418
+
419
+ save_audio_file(stem_path, stem_source)
420
+ self.write_to_console(DONE, base_text='')
421
+
422
+ def deverb_vocals(stem_path:str, stem_source):
423
+ self.write_to_console(INFERENCE_STEP_DEVERBING, base_text='')
424
+ stem_source_deverbed, stem_source_2 = vr_denoiser(stem_source, self.device, is_deverber=True, model_path=self.DEVERBER_MODEL)
425
+ save_audio_file(stem_path.replace(".wav", "_deverbed.wav"), stem_source_deverbed)
426
+ save_audio_file(stem_path.replace(".wav", "_reverb_only.wav"), stem_source_2)
427
+
428
+ is_bv_model_lead = (self.is_bv_model_rebalenced and self.is_vocal_split_model and stem_name == LEAD_VOCAL_STEM)
429
+ is_bv_rebalance_lead = (self.is_bv_model_rebalenced and self.is_vocal_split_model and stem_name == BV_VOCAL_STEM)
430
+ is_no_vocal_save = self.is_inst_only_voc_splitter and (stem_name == VOCAL_STEM or stem_name == BV_VOCAL_STEM or stem_name == LEAD_VOCAL_STEM) or is_bv_model_lead
431
+ is_not_ensemble = (not self.is_ensemble_mode or self.is_vocal_split_model)
432
+ is_do_not_save_inst = (self.is_save_vocal_only and self.is_sec_bv_rebalance and stem_name == INST_STEM)
433
+
434
+ if is_bv_rebalance_lead:
435
+ master_voc_source = spec_utils.match_array_shapes(self.master_vocal_source, stem_source, is_swap=True)
436
+ bv_rebalance_lead_source = stem_source-master_voc_source
437
+
438
+ if not is_bv_model_lead and not is_do_not_save_inst:
439
+ if self.is_vocal_split_model or not self.is_secondary_model:
440
+ if self.is_vocal_split_model and not self.is_inst_only_voc_splitter:
441
+ save_voc_split_vocal(stem_name, stem_source)
442
+ if is_bv_rebalance_lead:
443
+ save_voc_split_vocal(LEAD_VOCAL_STEM, bv_rebalance_lead_source)
444
+ else:
445
+ if not is_no_vocal_save:
446
+ save_with_message(stem_path, stem_name, stem_source)
447
+
448
+ if self.is_save_inst_vocal_splitter and not self.is_save_vocal_only:
449
+ save_voc_split_instrumental(stem_name, stem_source)
450
+ if is_bv_rebalance_lead:
451
+ save_voc_split_instrumental(LEAD_VOCAL_STEM, bv_rebalance_lead_source, is_inst_invert=True)
452
+
453
+ self.set_progress_bar(0.95)
454
+
455
+ if stem_name == VOCAL_STEM:
456
+ self.master_vocal_path = stem_path
457
+
458
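+ # Shift the separated source back by the configured semitone amount and trim/pad it to the original mix length.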
+ def pitch_fix(self, source, sr_pitched, org_mix):
459
+ semitone_shift = self.semitone_shift
460
+ source = spec_utils.change_pitch_semitones(source, sr_pitched, semitone_shift=semitone_shift)[0]
461
+ source = spec_utils.match_array_shapes(source, org_mix)
462
+ return source
463
+
464
+ def match_frequency_pitch(self, mix):
465
+ source = mix
466
+ if self.is_match_frequency_pitch and self.is_pitch_change:
467
+ source, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
468
+ source = self.pitch_fix(source, sr_pitched, mix)
469
+
470
+ return source
471
+
472
+ class SeperateMDX(SeperateAttributes):
473
+
474
+ def seperate(self):
475
+ samplerate = 44100
476
+
477
+ if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
478
+ mix, source = self.primary_sources
479
+ self.load_cached_sources()
480
+ else:
481
+ self.start_inference_console_write()
482
+
483
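+ # Model loading: .ckpt checkpoints are restored through MdxnetSet.ConvTDFNet; ONNX models run with onnxruntime
+ # when the segment size matches dim_t, and fall back to an onnx2pytorch conversion otherwise.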
+ if self.is_mdx_ckpt:
484
+ model_params = torch.load(self.model_path, map_location=lambda storage, loc: storage)['hyper_parameters']
485
+ self.dim_c, self.hop = model_params['dim_c'], model_params['hop_length']
486
+ separator = MdxnetSet.ConvTDFNet(**model_params)
487
+ self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval()
488
+ else:
489
+ if self.mdx_segment_size == self.dim_t and not self.is_other_gpu:
490
+ ort_ = ort.InferenceSession(self.model_path, providers=self.run_type)
491
+ self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0]
492
+ else:
493
+ self.model_run = ConvertModel(load(self.model_path))
494
+ self.model_run.to(self.device).eval()
495
+
496
+ self.running_inference_console_write()
497
+ mix = prepare_mix(self.audio_file)
498
+
499
+ source = self.demix(mix)
500
+
501
+ if not self.is_vocal_split_model:
502
+ self.cache_source((mix, source))
503
+ self.write_to_console(DONE, base_text='')
504
+
505
+ mdx_net_cut = self.primary_stem in MDX_NET_FREQ_CUT and self.is_match_frequency_pitch
506
+
507
+ if self.is_secondary_model_activated and self.secondary_model:
508
+ self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem)
509
+
510
+ if not self.is_primary_stem_only:
511
+ secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
512
+ if not isinstance(self.secondary_source, np.ndarray):
513
+ raw_mix = self.demix(self.match_frequency_pitch(mix), is_match_mix=True) if mdx_net_cut else self.match_frequency_pitch(mix)
514
+ self.secondary_source = spec_utils.invert_stem(raw_mix, source) if self.is_invert_spec else mix.T-source.T
515
+
516
+ self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, samplerate)
517
+
518
+ if not self.is_secondary_stem_only:
519
+ primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
520
+
521
+ if not isinstance(self.primary_source, np.ndarray):
522
+ self.primary_source = source.T
523
+
524
+ self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
525
+
526
+ clear_gpu_cache()
527
+
528
+ secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
529
+
530
+ self.process_vocal_split_chain(secondary_sources)
531
+
532
+ if self.is_secondary_model or self.is_pre_proc_model:
533
+ return secondary_sources
534
+
535
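+ # Derive STFT and chunking parameters (bin count, trim, chunk/generation sizes) from n_fft, hop and segment size.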
+ def initialize_model_settings(self):
536
+ self.n_bins = self.n_fft//2+1
537
+ self.trim = self.n_fft//2
538
+ self.chunk_size = self.hop * (self.mdx_segment_size-1)
539
+ self.gen_size = self.chunk_size-2*self.trim
540
+ self.stft = STFT(self.n_fft, self.hop, self.dim_f, self.device)
541
+
542
+ def demix(self, mix, is_match_mix=False):
543
+ self.initialize_model_settings()
544
+
545
+ org_mix = mix
546
+ tar_waves_ = []
547
+
548
+ if is_match_mix:
549
+ chunk_size = self.hop * (256-1)
550
+ overlap = 0.02
551
+ else:
552
+ chunk_size = self.chunk_size
553
+ overlap = self.overlap_mdx
554
+
555
+ if self.is_pitch_change:
556
+ mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
557
+
558
+ gen_size = chunk_size-2*self.trim
559
+
560
+ pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
561
+ mixture = np.concatenate((np.zeros((2, self.trim), dtype='float32'), mix, np.zeros((2, pad), dtype='float32')), 1)
562
+
563
+ step = self.chunk_size - self.n_fft if overlap == DEFAULT else int((1 - overlap) * chunk_size)
564
+ result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
565
+ divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
566
+ total = 0
567
+ total_chunks = (mixture.shape[-1] + step - 1) // step
568
+
569
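+ # Overlap-add loop: run each padded, Hann-windowed chunk through the model and accumulate outputs and window weights.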
+ for i in range(0, mixture.shape[-1], step):
570
+ total += 1
571
+ start = i
572
+ end = min(i + chunk_size, mixture.shape[-1])
573
+
574
+ chunk_size_actual = end - start
575
+
576
+ if overlap == 0:
577
+ window = None
578
+ else:
579
+ window = np.hanning(chunk_size_actual)
580
+ window = np.tile(window[None, None, :], (1, 2, 1))
581
+
582
+ mix_part_ = mixture[:, start:end]
583
+ if end != i + chunk_size:
584
+ pad_size = (i + chunk_size) - end
585
+ mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype='float32')), axis=-1)
586
+
587
+ mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.device)
588
+ mix_waves = mix_part.split(self.mdx_batch_size)
589
+
590
+ with torch.no_grad():
591
+ for mix_wave in mix_waves:
592
+ self.running_inference_progress_bar(total_chunks, is_match_mix=is_match_mix)
593
+
594
+ tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
595
+
596
+ if window is not None:
597
+ tar_waves[..., :chunk_size_actual] *= window
598
+ divider[..., start:end] += window
599
+ else:
600
+ divider[..., start:end] += 1
601
+
602
+ result[..., start:end] += tar_waves[..., :end-start]
603
+
604
+ tar_waves = result / divider
605
+ tar_waves_.append(tar_waves)
606
+
607
+ tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim:-self.trim]
608
+ tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :mix.shape[-1]]
609
+
610
+ source = tar_waves[:,0:None]
611
+
612
+ if self.is_pitch_change and not is_match_mix:
613
+ source = self.pitch_fix(source, sr_pitched, org_mix)
614
+
615
+ source = source if is_match_mix else source*self.compensate
616
+
617
+ if self.is_denoise_model and not is_match_mix:
618
+ if NO_STEM in self.primary_stem_native or self.primary_stem_native == INST_STEM:
619
+ if org_mix.shape[1] != source.shape[1]:
620
+ source = spec_utils.match_array_shapes(source, org_mix)
621
+ source = org_mix - vr_denoiser(org_mix-source, self.device, model_path=self.DENOISER_MODEL)
622
+ else:
623
+ source = vr_denoiser(source, self.device, model_path=self.DENOISER_MODEL)
624
+
625
+ return source
626
+
627
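+ # Run one chunk through the network; with denoising enabled, the spectrum and its negation are both predicted and averaged.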
+ def run_model(self, mix, is_match_mix=False):
628
+
629
+ spek = self.stft(mix.to(self.device))*self.adjust
630
+ spek[:, :, :3, :] *= 0
631
+
632
+ if is_match_mix:
633
+ spec_pred = spek.cpu().numpy()
634
+ else:
635
+ spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)
636
+
637
+ return self.stft.inverse(torch.tensor(spec_pred).to(self.device)).cpu().detach().numpy()
638
+
639
+ class SeperateMDXC(SeperateAttributes):
640
+
641
+ def seperate(self):
642
+ samplerate = 44100
643
+ sources = None
644
+
645
+ if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
646
+ mix, sources = self.primary_sources
647
+ self.load_cached_sources()
648
+ else:
649
+ self.start_inference_console_write()
650
+ self.running_inference_console_write()
651
+ mix = prepare_mix(self.audio_file)
652
+ sources = self.demix(mix)
653
+ if not self.is_vocal_split_model:
654
+ self.cache_source((mix, sources))
655
+ self.write_to_console(DONE, base_text='')
656
+
657
+ stem_list = [self.mdx_c_configs.training.target_instrument] if self.mdx_c_configs.training.target_instrument else [i for i in self.mdx_c_configs.training.instruments]
658
+
659
+ if self.is_secondary_model:
660
+ if self.is_pre_proc_model:
661
+ self.mdxnet_stem_select = stem_list[0]
662
+ else:
663
+ self.mdxnet_stem_select = self.main_model_primary_stem_4_stem if self.main_model_primary_stem_4_stem else self.primary_model_primary_stem
664
+ self.primary_stem = self.mdxnet_stem_select
665
+ self.secondary_stem = secondary_stem(self.mdxnet_stem_select)
666
+ self.is_primary_stem_only, self.is_secondary_stem_only = False, False
667
+
668
+ is_all_stems = self.mdxnet_stem_select == ALL_STEMS
669
+ is_not_ensemble_master = not self.process_data['is_ensemble_master']
670
+ is_not_single_stem = not len(stem_list) <= 2
671
+ is_not_secondary_model = not self.is_secondary_model
672
+ is_ensemble_4_stem = self.is_4_stem_ensemble and is_not_single_stem
673
+
674
+ if (is_all_stems and is_not_ensemble_master and is_not_single_stem and is_not_secondary_model) or is_ensemble_4_stem and not self.is_pre_proc_model:
675
+ for stem in stem_list:
676
+ primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem}).wav')
677
+ self.primary_source = sources[stem].T
678
+ self.write_audio(primary_stem_path, self.primary_source, samplerate, stem_name=stem)
679
+
680
+ if stem == VOCAL_STEM and not self.is_sec_bv_rebalance:
681
+ self.process_vocal_split_chain({VOCAL_STEM:stem})
682
+ else:
683
+ if len(stem_list) == 1:
684
+ source_primary = sources
685
+ else:
686
+ source_primary = sources[stem_list[0]] if self.is_multi_stem_ensemble and len(stem_list) == 2 else sources[self.mdxnet_stem_select]
687
+ if self.is_secondary_model_activated and self.secondary_model:
688
+ self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model,
689
+ self.process_data,
690
+ main_process_method=self.process_method,
691
+ main_model_primary=self.primary_stem)
692
+
693
+ if not self.is_primary_stem_only:
694
+ secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
695
+ if not isinstance(self.secondary_source, np.ndarray):
696
+
697
+ if self.is_mdx_combine_stems and len(stem_list) >= 2:
698
+ if len(stem_list) == 2:
699
+ secondary_source = sources[self.secondary_stem]
700
+ else:
701
+ sources.pop(self.primary_stem)
702
+ next_stem = next(iter(sources))
703
+ secondary_source = np.zeros_like(sources[next_stem])
704
+ for v in sources.values():
705
+ secondary_source += v
706
+
707
+ self.secondary_source = secondary_source.T
708
+ else:
709
+ self.secondary_source, raw_mix = source_primary, self.match_frequency_pitch(mix)
710
+ self.secondary_source = spec_utils.to_shape(self.secondary_source, raw_mix.shape)
711
+
712
+ if self.is_invert_spec:
713
+ self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source)
714
+ else:
715
+ self.secondary_source = (-self.secondary_source.T+raw_mix.T)
716
+
717
+ self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, samplerate)
718
+
719
+ if not self.is_secondary_stem_only:
720
+ primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
721
+ if not isinstance(self.primary_source, np.ndarray):
722
+ self.primary_source = source_primary.T
723
+
724
+ self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
725
+
726
+ clear_gpu_cache()
727
+
728
+ secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
729
+ self.process_vocal_split_chain(secondary_sources)
730
+
731
+ if self.is_secondary_model or self.is_pre_proc_model:
732
+ return secondary_sources
733
+
734
+ def demix(self, mix):
735
+ sr_pitched = 44100
736
+ org_mix = mix
737
+ if self.is_pitch_change:
738
+ mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
739
+
740
+ model = TFC_TDF_net(self.mdx_c_configs, device=self.device)
741
+ model.load_state_dict(torch.load(self.model_path, map_location=cpu))
742
+ model.to(self.device).eval()
743
+ mix = torch.tensor(mix, dtype=torch.float32)
744
+
745
+ try:
746
+ S = model.num_target_instruments
747
+ except Exception as e:
748
+ S = model.module.num_target_instruments
749
+
750
+ mdx_segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
751
+
752
+ batch_size = self.mdx_batch_size
753
+ chunk_size = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
754
+ overlap = self.overlap_mdx23
755
+
756
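+ # Pad the mix, cut it into overlapping chunks with unfold, and process them in batches; the summed output is divided by the overlap count.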
+ hop_size = chunk_size // overlap
757
+ mix_shape = mix.shape[1]
758
+ pad_size = hop_size - (mix_shape - chunk_size) % hop_size
759
+ mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
760
+
761
+ chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
762
+ batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
763
+
764
+ X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix)
765
+ X = X.to(self.device)
766
+
767
+ with torch.no_grad():
768
+ cnt = 0
769
+ for batch in batches:
770
+ self.running_inference_progress_bar(len(batches))
771
+ x = model(batch.to(self.device))
772
+
773
+ for w in x:
774
+ X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w
775
+ cnt += 1
776
+
777
+ estimated_sources = X[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
778
+ del X
779
+ pitch_fix = lambda s:self.pitch_fix(s, sr_pitched, org_mix)
780
+
781
+ if S > 1:
782
+ sources = {k: pitch_fix(v) if self.is_pitch_change else v for k, v in zip(self.mdx_c_configs.training.instruments, estimated_sources.cpu().detach().numpy())}
783
+ del estimated_sources
784
+ if self.is_denoise_model:
785
+ if VOCAL_STEM in sources.keys() and INST_STEM in sources.keys():
786
+ sources[VOCAL_STEM] = vr_denoiser(sources[VOCAL_STEM], self.device, model_path=self.DENOISER_MODEL)
787
+ if sources[VOCAL_STEM].shape[1] != org_mix.shape[1]:
788
+ sources[VOCAL_STEM] = spec_utils.match_array_shapes(sources[VOCAL_STEM], org_mix)
789
+ sources[INST_STEM] = org_mix - sources[VOCAL_STEM]
790
+
791
+ return sources
792
+ else:
793
+ est_s = estimated_sources.cpu().detach().numpy()
794
+ del estimated_sources
795
+ return pitch_fix(est_s) if self.is_pitch_change else est_s
796
+
797
+ class SeperateDemucs(SeperateAttributes):
798
+ def seperate(self):
799
+ samplerate = 44100
800
+ source = None
801
+ model_scale = None
802
+ stem_source = None
803
+ stem_source_secondary = None
804
+ inst_mix = None
805
+ inst_source = None
806
+ is_no_write = False
807
+ is_no_piano_guitar = False
808
+ is_no_cache = False
809
+
810
+ if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and not self.pre_proc_model:
811
+ source = self.primary_sources
812
+ self.load_cached_sources()
813
+ else:
814
+ self.start_inference_console_write()
815
+ is_no_cache = True
816
+
817
+ mix = prepare_mix(self.audio_file)
818
+
819
+ if is_no_cache:
820
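+ # Load the matching Demucs generation: v1 from a (possibly gzipped) pickle, v2 via the auto-loader, v3/v4 from the pretrained repo.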
+ if self.demucs_version == DEMUCS_V1:
821
+ if str(self.model_path).endswith(".gz"):
822
+ self.model_path = gzip.open(self.model_path, "rb")
823
+ klass, args, kwargs, state = torch.load(self.model_path)
824
+ self.demucs = klass(*args, **kwargs)
825
+ self.demucs.to(self.device)
826
+ self.demucs.load_state_dict(state)
827
+ elif self.demucs_version == DEMUCS_V2:
828
+ self.demucs = auto_load_demucs_model_v2(self.demucs_source_list, self.model_path)
829
+ self.demucs.to(self.device)
830
+ self.demucs.load_state_dict(torch.load(self.model_path))
831
+ self.demucs.eval()
832
+ else:
833
+ self.demucs = HDemucs(sources=self.demucs_source_list)
834
+ self.demucs = _gm(name=os.path.splitext(os.path.basename(self.model_path))[0],
835
+ repo=Path(os.path.dirname(self.model_path)))
836
+ self.demucs = demucs_segments(self.segment, self.demucs)
837
+ self.demucs.to(self.device)
838
+ self.demucs.eval()
839
+
840
+ if self.pre_proc_model:
841
+ if self.primary_stem not in [VOCAL_STEM, INST_STEM]:
842
+ is_no_write = True
843
+ self.write_to_console(DONE, base_text='')
844
+ mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
845
+ inst_mix = prepare_mix(mix_no_voc[INST_STEM])
846
+ self.process_iteration()
847
+ self.running_inference_console_write(is_no_write=is_no_write)
848
+ inst_source = self.demix_demucs(inst_mix)
849
+ self.process_iteration()
850
+
851
+ self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None
852
+
853
+ if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
854
+ source = self.primary_sources
855
+ else:
856
+ source = self.demix_demucs(mix)
857
+
858
+ self.write_to_console(DONE, base_text='')
859
+
860
+ del self.demucs
861
+ clear_gpu_cache()
862
+
863
+ if isinstance(inst_source, np.ndarray):
864
+ source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])
865
+ inst_source[self.demucs_source_map[VOCAL_STEM]] = source_reshape
866
+ source = inst_source
867
+
868
+ if isinstance(source, np.ndarray):
869
+
870
+ if len(source) == 2:
871
+ self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
872
+ else:
873
+ self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER if len(source) == 6 else DEMUCS_4_SOURCE_MAPPER
874
+
875
+ if len(source) == 6 and self.process_data['is_ensemble_master'] or len(source) == 6 and self.is_secondary_model:
876
+ is_no_piano_guitar = True
877
+ six_stem_other_source = list(source)
878
+ six_stem_other_source = [i for n, i in enumerate(source) if n in [self.demucs_source_map[OTHER_STEM], self.demucs_source_map[GUITAR_STEM], self.demucs_source_map[PIANO_STEM]]]
879
+ other_source = np.zeros_like(six_stem_other_source[0])
880
+ for i in six_stem_other_source:
881
+ other_source += i
882
+ source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source)
883
+ source[self.demucs_source_map[OTHER_STEM]] = source_reshape
884
+
885
+ if not self.is_vocal_split_model:
886
+ self.cache_source(source)
887
+
888
+ if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble and not self.is_return_dual:
889
+ for stem_name, stem_value in self.demucs_source_map.items():
890
+ if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
891
+ if self.secondary_model_4_stem[stem_value]:
892
+ model_scale = self.secondary_model_4_stem_scale[stem_value]
893
+ stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name, is_source_load=True, is_return_dual=False)
894
+ if isinstance(stem_source_secondary, np.ndarray):
895
+ stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value].T
896
+ elif type(stem_source_secondary) is dict:
897
+ stem_source_secondary = stem_source_secondary[stem_name]
898
+
899
+ stem_source_secondary = None if stem_value >= 4 else stem_source_secondary
900
+ stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav')
901
+ stem_source = source[stem_value].T
902
+
903
+ stem_source = self.process_secondary_stem(stem_source, secondary_model_source=stem_source_secondary, model_scale=model_scale)
904
+ self.write_audio(stem_path, stem_source, samplerate, stem_name=stem_name)
905
+
906
+ if stem_name == VOCAL_STEM and not self.is_sec_bv_rebalance:
907
+ self.process_vocal_split_chain({VOCAL_STEM:stem_source})
908
+
909
+ if self.is_secondary_model:
910
+ return source
911
+ else:
912
+ if self.is_secondary_model_activated and self.secondary_model:
913
+ self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
914
+
915
+ if not self.is_primary_stem_only:
916
+ def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False):
917
+ secondary_source = self.secondary_source if not is_inst_mixture else None
918
+ secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav')
919
+ secondary_source_secondary = None
920
+
921
+ if not isinstance(secondary_source, np.ndarray):
922
+ if self.is_demucs_combine_stems:
923
+ source = list(source)
924
+ if is_inst_mixture:
925
+ source = [i for n, i in enumerate(source) if not n in [self.demucs_source_map[self.primary_stem], self.demucs_source_map[VOCAL_STEM]]]
926
+ else:
927
+ source.pop(self.demucs_source_map[self.primary_stem])
928
+
929
+ source = source[:len(source) - 2] if is_no_piano_guitar else source
930
+ secondary_source = np.zeros_like(source[0])
931
+ for i in source:
932
+ secondary_source += i
933
+ secondary_source = secondary_source.T
934
+ else:
935
+ if not isinstance(raw_mixture, np.ndarray):
936
+ raw_mixture = prepare_mix(self.audio_file)
937
+
938
+ secondary_source = source[self.demucs_source_map[self.primary_stem]]
939
+
940
+ if self.is_invert_spec:
941
+ secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source)
942
+ else:
943
+ raw_mixture = spec_utils.reshape_sources(secondary_source, raw_mixture)
944
+ secondary_source = (-secondary_source.T+raw_mixture.T)
945
+
946
+ if not is_inst_mixture:
947
+ self.secondary_source = secondary_source
948
+ secondary_source_secondary = self.secondary_source_secondary
949
+ self.secondary_source = self.process_secondary_stem(secondary_source, secondary_source_secondary)
950
+ self.secondary_source_map = {self.secondary_stem: self.secondary_source}
951
+
952
+ self.write_audio(secondary_stem_path, secondary_source, samplerate, stem_name=sec_stem_name)
953
+
954
+ secondary_save(self.secondary_stem, source, raw_mixture=mix)
955
+
956
+ if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
957
+ secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_mix, is_inst_mixture=True)
958
+
959
+ if not self.is_secondary_stem_only:
960
+ primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
961
+ if not isinstance(self.primary_source, np.ndarray):
962
+ self.primary_source = source[self.demucs_source_map[self.primary_stem]].T
963
+
964
+ self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
965
+
966
+ secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
967
+
968
+ self.process_vocal_split_chain(secondary_sources)
969
+
970
+ if self.is_secondary_model:
971
+ return secondary_sources
972
+
973
+ def demix_demucs(self, mix):
974
+
975
+ org_mix = mix
976
+
977
+ if self.is_pitch_change:
978
+ mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
979
+
980
+ processed = {}
981
+ mix = torch.tensor(mix, dtype=torch.float32)
982
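+ # Standardize the mix against its mono reference (zero mean, unit std); the inverse scaling is applied to the model output below.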
+ ref = mix.mean(0)
983
+ mix = (mix - ref.mean()) / ref.std()
984
+ mix_infer = mix
985
+
986
+ with torch.no_grad():
987
+ if self.demucs_version == DEMUCS_V1:
988
+ sources = apply_model_v1(self.demucs,
989
+ mix_infer.to(self.device),
990
+ self.shifts,
991
+ self.is_split_mode,
992
+ set_progress_bar=self.set_progress_bar)
993
+ elif self.demucs_version == DEMUCS_V2:
994
+ sources = apply_model_v2(self.demucs,
995
+ mix_infer.to(self.device),
996
+ self.shifts,
997
+ self.is_split_mode,
998
+ self.overlap,
999
+ set_progress_bar=self.set_progress_bar)
1000
+ else:
1001
+ sources = apply_model(self.demucs,
1002
+ mix_infer[None],
1003
+ self.shifts,
1004
+ self.is_split_mode,
1005
+ self.overlap,
1006
+ static_shifts=1 if self.shifts == 0 else self.shifts,
1007
+ set_progress_bar=self.set_progress_bar,
1008
+ device=self.device)[0]
1009
+
1010
+ sources = (sources * ref.std() + ref.mean()).cpu().numpy()
1011
+ sources[[0,1]] = sources[[1,0]]
1012
+ processed[mix] = sources[:,:,0:None].copy()
1013
+ sources = list(processed.values())
1014
+ sources = [s[:,:,0:None] for s in sources]
1015
+ #sources = [self.pitch_fix(s[:,:,0:None], sr_pitched, org_mix) if self.is_pitch_change else s[:,:,0:None] for s in sources]
1016
+ sources = np.concatenate(sources, axis=-1)
1017
+
1018
+ if self.is_pitch_change:
1019
+ sources = np.stack([self.pitch_fix(stem, sr_pitched, org_mix) for stem in sources])
1020
+
1021
+ return sources
1022
+
1023
+ class SeperateVR(SeperateAttributes):
1024
+
1025
+ def seperate(self):
1026
+ if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
1027
+ y_spec, v_spec = self.primary_sources
1028
+ self.load_cached_sources()
1029
+ else:
1030
+ self.start_inference_console_write()
1031
+
1032
+ device = self.device
1033
+
1034
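+ # Infer the VR network architecture from the checkpoint size in KB by matching the closest known model size.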
+ nn_arch_sizes = [
1035
+ 31191, # default
1036
+ 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]
1037
+ vr_5_1_models = [56817, 218409]
1038
+ model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
1039
+ nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size))
1040
+
1041
+ if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
1042
+ self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2,
1043
+ nn_arch_size,
1044
+ nout=self.model_capacity[0],
1045
+ nout_lstm=self.model_capacity[1])
1046
+ self.is_vr_51_model = True
1047
+ else:
1048
+ self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size)
1049
+
1050
+ self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu))
1051
+ self.model_run.to(device)
1052
+
1053
+ self.running_inference_console_write()
1054
+
1055
+ y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
1056
+ if not self.is_vocal_split_model:
1057
+ self.cache_source((y_spec, v_spec))
1058
+ self.write_to_console(DONE, base_text='')
1059
+
1060
+ if self.is_secondary_model_activated and self.secondary_model:
1061
+ self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem)
1062
+
1063
+ if not self.is_secondary_stem_only:
1064
+ primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
1065
+ if not isinstance(self.primary_source, np.ndarray):
1066
+ self.primary_source = self.spec_to_wav(y_spec).T
1067
+ if not self.model_samplerate == 44100:
1068
+ self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
1069
+
1070
+ self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, 44100)
1071
+
1072
+ if not self.is_primary_stem_only:
1073
+ secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
1074
+ if not isinstance(self.secondary_source, np.ndarray):
1075
+ self.secondary_source = self.spec_to_wav(v_spec).T
1076
+ if not self.model_samplerate == 44100:
1077
+ self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
1078
+
1079
+ self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, 44100)
1080
+
1081
+ clear_gpu_cache()
1082
+ secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
1083
+
1084
+ self.process_vocal_split_chain(secondary_sources)
1085
+
1086
+ if self.is_secondary_model:
1087
+ return secondary_sources
1088
+
1089
+ def loading_mix(self):
1090
+
1091
+ X_wave, X_spec_s = {}, {}
1092
+
1093
+ bands_n = len(self.mp.param['band'])
1094
+
1095
+ audio_file = spec_utils.write_array_to_mem(self.audio_file, subtype=self.wav_type_set)
1096
+ is_mp3 = audio_file.endswith('.mp3') if isinstance(audio_file, str) else False
1097
+
1098
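+ # Build per-band spectrograms from the highest band down, resampling each lower band from the band above it.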
+ for d in range(bands_n, 0, -1):
1099
+ bp = self.mp.param['band'][d]
1100
+
1101
+ if OPERATING_SYSTEM == 'Darwin':
1102
+ wav_resolution = 'polyphase' if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else bp['res_type']
1103
+ else:
1104
+ wav_resolution = bp['res_type']
1105
+
1106
+ if d == bands_n: # high-end band
1107
+ X_wave[d], _ = librosa.load(audio_file, bp['sr'], False, dtype=np.float32, res_type=wav_resolution)
1108
+ X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], self.mp, band=d, is_v51_model=self.is_vr_51_model)
1109
+
1110
+ if not np.any(X_wave[d]) and is_mp3:
1111
+ X_wave[d] = rerun_mp3(audio_file, bp['sr'])
1112
+
1113
+ if X_wave[d].ndim == 1:
1114
+ X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
1115
+ else: # lower bands
1116
+ X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=wav_resolution)
1117
+ X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], self.mp, band=d, is_v51_model=self.is_vr_51_model)
1118
+
1119
+ if d == bands_n and self.high_end_process != 'none':
1120
+ self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start'])
1121
+ self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :]
1122
+
1123
+ X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp, is_v51_model=self.is_vr_51_model)
1124
+
1125
+ del X_wave, X_spec_s, audio_file
1126
+
1127
+ return X_spec
1128
+
1129
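+ # Split the padded magnitude spectrogram into windows, predict masks batch by batch, and optionally average with a shifted TTA pass.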
+ def inference_vr(self, X_spec, device, aggressiveness):
1130
+ def _execute(X_mag_pad, roi_size):
1131
+ X_dataset = []
1132
+ patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
1133
+ total_iterations = patches//self.batch_size if not self.is_tta else (patches//self.batch_size)*2
1134
+ for i in range(patches):
1135
+ start = i * roi_size
1136
+ X_mag_window = X_mag_pad[:, :, start:start + self.window_size]
1137
+ X_dataset.append(X_mag_window)
1138
+
1139
+ X_dataset = np.asarray(X_dataset)
1140
+ self.model_run.eval()
1141
+ with torch.no_grad():
1142
+ mask = []
1143
+ for i in range(0, patches, self.batch_size):
1144
+ self.progress_value += 1
1145
+ if self.progress_value >= total_iterations:
1146
+ self.progress_value = total_iterations
1147
+ self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value)
1148
+ X_batch = X_dataset[i: i + self.batch_size]
1149
+ X_batch = torch.from_numpy(X_batch).to(device)
1150
+ pred = self.model_run.predict_mask(X_batch)
1151
+ if not pred.size()[3] > 0:
1152
+ raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
1153
+ pred = pred.detach().cpu().numpy()
1154
+ pred = np.concatenate(pred, axis=2)
1155
+ mask.append(pred)
1156
+ if len(mask) == 0:
1157
+ raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
1158
+
1159
+ mask = np.concatenate(mask, axis=2)
1160
+ return mask
1161
+
1162
+ def postprocess(mask, X_mag, X_phase):
1163
+ is_non_accom_stem = False
1164
+ for stem in NON_ACCOM_STEMS:
1165
+ if stem == self.primary_stem:
1166
+ is_non_accom_stem = True
1167
+
1168
+ mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)
1169
+
1170
+ if self.is_post_process:
1171
+ mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)
1172
+
1173
+ y_spec = mask * X_mag * np.exp(1.j * X_phase)
1174
+ v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
1175
+
1176
+ return y_spec, v_spec
1177
+
1178
+ X_mag, X_phase = spec_utils.preprocess(X_spec)
1179
+ n_frame = X_mag.shape[2]
1180
+ pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
1181
+ X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
1182
+ X_mag_pad /= X_mag_pad.max()
1183
+ mask = _execute(X_mag_pad, roi_size)
1184
+
1185
+ if self.is_tta:
1186
+ pad_l += roi_size // 2
1187
+ pad_r += roi_size // 2
1188
+ X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
1189
+ X_mag_pad /= X_mag_pad.max()
1190
+ mask_tta = _execute(X_mag_pad, roi_size)
1191
+ mask_tta = mask_tta[:, :, roi_size // 2:]
1192
+ mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
1193
+ else:
1194
+ mask = mask[:, :, :n_frame]
1195
+
1196
+ y_spec, v_spec = postprocess(mask, X_mag, X_phase)
1197
+
1198
+ return y_spec, v_spec
1199
+
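+    # Note on inference_vr: the magnitude spectrogram is padded so it splits into windows of
+    # self.window_size frames, the network predicts a mask per window, and the masks are
+    # concatenated along the time axis. With TTA enabled a second pass is run shifted by half
+    # the stride (roi_size // 2) and the two masks are averaged 50/50.
+    # Minimal sketch of the windowing (hypothetical numbers, assuming spec_utils.make_padding
+    # keeps the usual vocal-remover convention roi_size = window_size - 2 * offset):
+    #
+    #     window_size=512, offset=32  ->  roi_size=448
+    #     patches = (padded_frames - 2 * offset) // roi_size   # mirrors the line in _execute
+    #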
1200
+ def spec_to_wav(self, spec):
1201
+ if self.high_end_process.startswith('mirroring') and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h:
1202
+ input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
1203
+ wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_, is_v51_model=self.is_vr_51_model)
1204
+ else:
1205
+ wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, is_v51_model=self.is_vr_51_model)
1206
+
1207
+ return wav
1208
+
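+    # Note on spec_to_wav: when a 'mirroring' high-end mode is active, the cached
+    # self.input_high_end band is mirrored back into the estimated spectrogram before the
+    # inverse transform, intended to restore high-frequency content the band-split model
+    # does not process itself.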
1209
+ def process_secondary_model(secondary_model: ModelData,
1210
+ process_data,
1211
+ main_model_primary_stem_4_stem=None,
1212
+ is_source_load=False,
1213
+ main_process_method=None,
1214
+ is_pre_proc_model=False,
1215
+ is_return_dual=True,
1216
+ main_model_primary=None):
1217
+
1218
+ if not is_pre_proc_model:
1219
+ process_iteration = process_data['process_iteration']
1220
+ process_iteration()
1221
+
1222
+ if secondary_model.process_method == VR_ARCH_TYPE:
1223
+ seperator = SeperateVR(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, main_model_primary=main_model_primary)
1224
+ if secondary_model.process_method == MDX_ARCH_TYPE:
1225
+ if secondary_model.is_mdx_c:
1226
+ seperator = SeperateMDXC(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, is_return_dual=is_return_dual, main_model_primary=main_model_primary)
1227
+ else:
1228
+ seperator = SeperateMDX(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, main_model_primary=main_model_primary)
1229
+ if secondary_model.process_method == DEMUCS_ARCH_TYPE:
1230
+ seperator = SeperateDemucs(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, is_return_dual=is_return_dual, main_model_primary=main_model_primary)
1231
+
1232
+ secondary_sources = seperator.seperate()
1233
+
1234
+ if type(secondary_sources) is dict and not is_source_load and not is_pre_proc_model:
1235
+ return gather_sources(secondary_model.primary_model_primary_stem, secondary_stem(secondary_model.primary_model_primary_stem), secondary_sources)
1236
+ else:
1237
+ return secondary_sources
1238
+
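+# Illustrative call for process_secondary_model (hypothetical variables; `vr_model` is a
+# ModelData configured as a secondary model, `process_data` the dict built by the caller):
+#
+#     primary_src, secondary_src = process_secondary_model(
+#         vr_model, process_data, main_process_method=MDX_ARCH_TYPE)
+#
+# When the separator returns a dict of sources (and this is not a pre-process or
+# source-load call), the result is split into (primary, secondary) via gather_sources.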
1239
+ def process_chain_model(secondary_model: ModelData,
1240
+ process_data,
1241
+ vocal_stem_path,
1242
+ master_vocal_source,
1243
+ master_inst_source=None):
1244
+
1245
+ process_iteration = process_data['process_iteration']
1246
+ process_iteration()
1247
+
1248
+ if secondary_model.bv_model_rebalance:
1249
+ vocal_source = spec_utils.reduce_mix_bv(master_inst_source, master_vocal_source, reduction_rate=secondary_model.bv_model_rebalance)
1250
+ else:
1251
+ vocal_source = master_vocal_source
1252
+
1253
+ vocal_stem_path = [vocal_source, os.path.splitext(os.path.basename(vocal_stem_path))[0]]
1254
+
1255
+ if secondary_model.process_method == VR_ARCH_TYPE:
1256
+ seperator = SeperateVR(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
1257
+ if secondary_model.process_method == MDX_ARCH_TYPE:
1258
+ if secondary_model.is_mdx_c:
1259
+ seperator = SeperateMDXC(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
1260
+ else:
1261
+ seperator = SeperateMDX(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
1262
+ if secondary_model.process_method == DEMUCS_ARCH_TYPE:
1263
+ seperator = SeperateDemucs(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
1264
+
1265
+ secondary_sources = seperator.seperate()
1266
+
1267
+ if type(secondary_sources) is dict:
1268
+ return secondary_sources
1269
+ else:
1270
+ return None
1271
+
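+# Note on process_chain_model: it feeds an already-separated vocal stem back through a
+# second model. When bv_model_rebalance is set, the vocal source is first re-mixed against
+# the instrumental source (spec_utils.reduce_mix_bv) before separation; the function returns
+# the new source dict, or None when no source dict comes back.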
1272
+ def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict):
1273
+
1274
+ source_primary = False
1275
+ source_secondary = False
1276
+
1277
+ for key, value in secondary_sources.items():
1278
+ if key in primary_stem_name:
1279
+ source_primary = value
1280
+ if key in secondary_stem_name:
1281
+ source_secondary = value
1282
+
1283
+ return source_primary, source_secondary
1284
+
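+# Minimal sketch of gather_sources (hypothetical arrays): keys are matched by substring
+# against the stem names, so
+#
+#     gather_sources('Vocals', 'Instrumental',
+#                    {'Vocals': vocals_array, 'Instrumental': inst_array})
+#
+# returns (vocals_array, inst_array); a stem with no matching key comes back as False.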
1285
+ def prepare_mix(mix):
1286
+
1287
+ audio_path = mix
1288
+
1289
+ if not isinstance(mix, np.ndarray):
1290
+ mix, sr = librosa.load(mix, mono=False, sr=44100)
1291
+ else:
1292
+ mix = mix.T
1293
+
1294
+ if isinstance(audio_path, str):
1295
+ if not np.any(mix) and audio_path.endswith('.mp3'):
1296
+ mix = rerun_mp3(audio_path)
1297
+
1298
+ if mix.ndim == 1:
1299
+ mix = np.asfortranarray([mix,mix])
1300
+
1301
+ return mix
1302
+
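+# Minimal sketch of prepare_mix (hypothetical inputs): a file path is loaded as stereo
+# 44.1 kHz audio, an ndarray is transposed to (channels, samples), and mono input is
+# duplicated into two channels:
+#
+#     mix = prepare_mix('/path/to/song.wav')        # -> shape (2, n_samples)
+#     mix = prepare_mix(np.zeros((44100, 2)))       # -> shape (2, 44100)
+#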
1303
+ def rerun_mp3(audio_file, sample_rate=44100):
1304
+
1305
+ with audioread.audio_open(audio_file) as f:
1306
+ track_length = int(f.duration)
1307
+
1308
+ return librosa.load(audio_file, duration=track_length, mono=False, sr=sample_rate)[0]
1309
+
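+# rerun_mp3 is a fallback for mp3 files that decode to silence above: it re-reads the file
+# through audioread, trimming to the container's reported duration, and returns only the
+# audio array from librosa.load.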
1310
+ def save_format(audio_path, save_format, mp3_bit_set):
1311
+
1312
+    if save_format != WAV:
1313
+
1314
+ if OPERATING_SYSTEM == 'Darwin':
1315
+ FFMPEG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'ffmpeg')
1316
+ pydub.AudioSegment.converter = FFMPEG_PATH
1317
+
1318
+ musfile = pydub.AudioSegment.from_wav(audio_path)
1319
+
1320
+ if save_format == FLAC:
1321
+ audio_path_flac = audio_path.replace(".wav", ".flac")
1322
+ musfile.export(audio_path_flac, format="flac")
1323
+
1324
+ if save_format == MP3:
1325
+ audio_path_mp3 = audio_path.replace(".wav", ".mp3")
1326
+ try:
1327
+ musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set, codec="libmp3lame")
1328
+ except Exception as e:
1329
+ print(e)
1330
+ musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set)
1331
+
1332
+ try:
1333
+ os.remove(audio_path)
1334
+ except Exception as e:
1335
+ print(e)
1336
+
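+# Illustrative call for save_format (hypothetical path): convert an exported WAV in place
+# and remove the original; mp3_bit_set is a pydub bitrate string such as '320k':
+#
+#     save_format('/path/to/output.wav', MP3, '320k')
+#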
1337
+ def pitch_shift(mix):
1338
+ new_sr = 31183
1339
+
1340
+    # Resample the mix from 44100 Hz down to new_sr (31183 Hz)
1341
+ resampled_audio = signal.resample_poly(mix, new_sr, 44100)
1342
+
1343
+ return resampled_audio
1344
+
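+# Note on pitch_shift: 31183 Hz is roughly 44100 / 2**(6/12), so treating the resampled
+# signal as if it were still 44.1 kHz audio shifts it by about six semitones; the helper
+# itself only changes the sample rate (it does not time-stretch).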
1345
+ def list_to_dictionary(lst):
1346
+ dictionary = {item: index for index, item in enumerate(lst)}
1347
+ return dictionary
1348
+
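+# Example: list_to_dictionary(['Vocals', 'Drums', 'Bass']) -> {'Vocals': 0, 'Drums': 1, 'Bass': 2}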
1349
+ def vr_denoiser(X, device, hop_length=1024, n_fft=2048, cropsize=256, is_deverber=False, model_path=None):
1350
+ batchsize = 4
1351
+
1352
+ if is_deverber:
1353
+ nout, nout_lstm = 64, 128
1354
+ mp = ModelParameters(os.path.join('lib_v5', 'vr_network', 'modelparams', '4band_v3.json'))
1355
+ n_fft = mp.param['bins'] * 2
1356
+ else:
1357
+ mp = None
1358
+ hop_length=1024
1359
+ nout, nout_lstm = 16, 128
1360
+
1361
+ model = nets_new.CascadedNet(n_fft, nout=nout, nout_lstm=nout_lstm)
1362
+ model.load_state_dict(torch.load(model_path, map_location=cpu))
1363
+ model.to(device)
1364
+
1365
+ if mp is None:
1366
+ X_spec = spec_utils.wave_to_spectrogram_old(X, hop_length, n_fft)
1367
+ else:
1368
+ X_spec = loading_mix(X.T, mp)
1369
+
1370
+    # Pre-process: split the spectrogram into magnitude and phase
1371
+ X_mag = np.abs(X_spec)
1372
+ X_phase = np.angle(X_spec)
1373
+
1374
+    # Separation: run the network over padded, windowed patches of the magnitude spectrogram
1375
+ n_frame = X_mag.shape[2]
1376
+ pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, cropsize, model.offset)
1377
+ X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
1378
+ X_mag_pad /= X_mag_pad.max()
1379
+
1380
+ X_dataset = []
1381
+ patches = (X_mag_pad.shape[2] - 2 * model.offset) // roi_size
1382
+ for i in range(patches):
1383
+ start = i * roi_size
1384
+ X_mag_crop = X_mag_pad[:, :, start:start + cropsize]
1385
+ X_dataset.append(X_mag_crop)
1386
+
1387
+ X_dataset = np.asarray(X_dataset)
1388
+
1389
+ model.eval()
1390
+
1391
+ with torch.no_grad():
1392
+ mask = []
1393
+        # To reduce overhead, a DataLoader is not used; patches are batched manually.
1394
+ for i in range(0, patches, batchsize):
1395
+ X_batch = X_dataset[i: i + batchsize]
1396
+ X_batch = torch.from_numpy(X_batch).to(device)
1397
+
1398
+ pred = model.predict_mask(X_batch)
1399
+
1400
+ pred = pred.detach().cpu().numpy()
1401
+ pred = np.concatenate(pred, axis=2)
1402
+ mask.append(pred)
1403
+
1404
+ mask = np.concatenate(mask, axis=2)
1405
+
1406
+ mask = mask[:, :, :n_frame]
1407
+
1408
+    # Post-process: apply the predicted mask and rebuild the complex spectrogram(s)
1409
+ if is_deverber:
1410
+ v_spec = mask * X_mag * np.exp(1.j * X_phase)
1411
+ y_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
1412
+ else:
1413
+ v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
1414
+
1415
+ if mp is None:
1416
+ wave = spec_utils.spectrogram_to_wave_old(v_spec, hop_length=1024)
1417
+ else:
1418
+ wave = spec_utils.cmb_spectrogram_to_wave(v_spec, mp, is_v51_model=True).T
1419
+
1420
+ wave = spec_utils.match_array_shapes(wave, X)
1421
+
1422
+ if is_deverber:
1423
+ wave_2 = spec_utils.cmb_spectrogram_to_wave(y_spec, mp, is_v51_model=True).T
1424
+ wave_2 = spec_utils.match_array_shapes(wave_2, X)
1425
+ return wave, wave_2
1426
+ else:
1427
+ return wave
1428
+
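+# Illustrative use of vr_denoiser (hypothetical model paths): `X` is the mixture array and
+# `device` a torch device. In the default denoise mode a single cleaned waveform is returned;
+# with is_deverber=True the 4band_v3 params are used and two waveforms come back, one per
+# side of the predicted mask:
+#
+#     clean = vr_denoiser(X, device, model_path='models/denoise.pth')
+#     out_a, out_b = vr_denoiser(X, device, is_deverber=True, model_path='models/deverb.pth')
+#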
1429
+ def loading_mix(X, mp):
1430
+
1431
+ X_wave, X_spec_s = {}, {}
1432
+
1433
+ bands_n = len(mp.param['band'])
1434
+
1435
+ for d in range(bands_n, 0, -1):
1436
+ bp = mp.param['band'][d]
1437
+
1438
+ if OPERATING_SYSTEM == 'Darwin':
1439
+ wav_resolution = 'polyphase' if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else bp['res_type']
1440
+ else:
1441
+            wav_resolution = 'polyphase'  # bp['res_type'] is bypassed here in favor of a fixed polyphase resampler
1442
+
1443
+ if d == bands_n: # high-end band
1444
+ X_wave[d] = X
1445
+
1446
+ else: # lower bands
1447
+            X_wave[d] = librosa.resample(X_wave[d+1], orig_sr=mp.param['band'][d+1]['sr'], target_sr=bp['sr'], res_type=wav_resolution)
1448
+
1449
+ X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], mp, band=d, is_v51_model=True)
1450
+
1451
+ # if d == bands_n and is_high_end_process:
1452
+ # input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start'])
1453
+ # input_high_end = X_spec_s[d][:, bp['n_fft']//2-input_high_end_h:bp['n_fft']//2, :]
1454
+
1455
+ X_spec = spec_utils.combine_spectrograms(X_spec_s, mp)
1456
+
1457
+ del X_wave, X_spec_s
1458
+
1459
+ return X_spec