Yusen committed on
Commit
c0bbdd5
·
1 Parent(s): 536d40c

Upload 9 files

Browse files
vdecoder/__init__.py ADDED
File without changes
vdecoder/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes). View file
 
vdecoder/hifigan/__pycache__/env.cpython-310.pyc ADDED
Binary file (837 Bytes). View file
 
vdecoder/hifigan/__pycache__/models.cpython-310.pyc ADDED
Binary file (15.1 kB). View file
 
vdecoder/hifigan/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.37 kB). View file
 
vdecoder/hifigan/env.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+
5
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    The instance ``__dict__`` is aliased to the mapping itself, so ``d.key``
    and ``d["key"]`` always refer to the same entry.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the dict so both access styles stay in sync.
        self.__dict__ = self
9
+
10
+
11
def build_env(config, config_name, path):
    """Copy the file ``config`` to ``path/config_name``.

    Nothing is done when ``config`` is already spelled exactly like the
    destination path. Otherwise ``path`` is created if it does not exist
    and the file is copied there.
    """
    destination = os.path.join(path, config_name)
    if config == destination:
        return
    os.makedirs(path, exist_ok=True)
    shutil.copyfile(config, destination)
vdecoder/hifigan/models.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from .env import AttrDict
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
9
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
10
+ from .utils import init_weights, get_padding
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
def load_model(model_path, device='cuda'):
    """Build a ``Generator`` from a checkpoint and the ``config.json`` next to it.

    Args:
        model_path: path to the checkpoint file; ``config.json`` is read from
            the same directory.
        device: device the generator is moved to and the checkpoint tensors
            are mapped onto.

    Returns:
        ``(generator, h)``: the generator in eval mode with weight norm
        removed, and the parsed config as an ``AttrDict``.
    """
    config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
    with open(config_file) as f:
        json_config = json.load(f)

    # The parsed config is also published as the module-level ``h``.
    global h
    h = AttrDict(json_config)

    generator = Generator(h).to(device)

    # map_location lets a checkpoint saved on a GPU be loaded on a host
    # without CUDA (e.g. device='cpu').
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    cp_dict = torch.load(model_path, map_location=device)
    generator.load_state_dict(cp_dict['generator'])
    generator.eval()
    generator.remove_weight_norm()
    del cp_dict
    return generator, h
32
+
33
+
34
class ResBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, undilated conv) pairs.

    Each pair is preceded by leaky-ReLU activations and wrapped in a skip
    connection. All convolutions are weight-normalised.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h

        def make_conv(rate):
            # Stride-1, weight-normalised conv; padding comes from get_padding.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        # First conv of every pair: uses the first three entries of ``dilation``.
        self.convs1 = nn.ModuleList([make_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        # Second conv of every pair: no dilation.
        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Run the three residual stages in sequence and return the result."""
        for dilated, plain in zip(self.convs1, self.convs2):
            residual = x
            out = dilated(F.leaky_relu(x, LRELU_SLOPE))
            out = plain(F.leaky_relu(out, LRELU_SLOPE))
            x = out + residual
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every conv."""
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
72
+
73
+
74
class ResBlock2(torch.nn.Module):
    """Residual block with two dilated, weight-normalised convolutions.

    Each convolution is preceded by a leaky ReLU and wrapped in a skip
    connection.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        # One conv per entry among the first two dilation rates.
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[idx],
                               padding=get_padding(kernel_size, dilation[idx])))
            for idx in range(2)
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        """Apply both residual stages in sequence and return the result."""
        for conv in self.convs:
            residual = x
            x = conv(F.leaky_relu(x, LRELU_SLOPE)) + residual
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every conv."""
        for layer in self.convs:
            remove_weight_norm(layer)
96
+
97
+
98
def padDiff(x):
    """Return ``x[t + 1] - x[t]`` along the second-to-last axis.

    The output is one step shorter than ``x`` along that axis.
    """
    # Shift one step towards the start: drop the first entry, append a zero.
    shifted = F.pad(x, (0, 0, -1, 1), 'constant', 0)
    delta = shifted - x
    # The last entry was computed against the appended zero; trim it off.
    return F.pad(delta, (0, 0, 0, -1), 'constant', 0)
100
+
101
class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # one channel for the fundamental plus one per overtone
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse

    def _f02uv(self, f0):
        """Return a float32 mask: 1 where f0 exceeds the voiced threshold, else 0."""
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # for normal case

            # To prevent torch.cumsum numerical overflow,
            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
            # Buffer tmp_over_one_idx indicates the time step to add -1.
            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
            tmp_over_one = torch.cumsum(rad_values, 1) % 1
            tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
                              * 2 * np.pi)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv, noise = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        output noise: the additive noise, same shape as sine_tensor
        """
        with torch.no_grad():
            # Frequency of the fundamental and of every overtone:
            # f0 * [1, 2, ..., harmonic_num + 1], built directly on f0's device.
            multipliers = torch.arange(1, self.harmonic_num + 2, dtype=torch.float32,
                                       device=f0.device).view(1, 1, -1)
            fn = torch.multiply(f0, multipliers)

            # generate sine waveforms
            sine_waves = self._f02sine(fn) * self.sine_amp

            # generate uv signal
            uv = self._f02uv(f0)

            # noise: for unvoiced should be similar to sine_amp
            #        std = self.sine_amp/3 -> max value ~ self.sine_amp
            #        for voiced regions is self.noise_std
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)

            # first: set the unvoiced part to 0 by uv
            # then: additive noise
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
225
+
226
+
227
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced regions is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # produces one sine channel per harmonic
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # collapses the harmonic channels into a single excitation channel
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # harmonic branch: the generator's own noise output is not used here
        harmonics, uv, _generator_noise = self.l_sin_gen(x)
        merged = self.l_linear(harmonics)
        sine_merge = self.l_tanh(merged)

        # noise branch: Gaussian noise shaped like uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
274
+
275
+
276
class Generator(torch.nn.Module):
    """Upsampling generator driven by input features, an F0 contour and an
    optional global conditioning tensor.

    ``h`` is the config mapping; the keys read here are
    ``resblock_kernel_sizes``, ``resblock_dilation_sizes``, ``upsample_rates``,
    ``upsample_kernel_sizes``, ``upsample_initial_channel``, ``inter_channels``,
    ``sampling_rate``, ``resblock`` and ``gin_channels``.
    """

    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h

        self.num_kernels = len(h["resblock_kernel_sizes"])
        self.num_upsamples = len(h["upsample_rates"])
        # F0 is upsampled by the product of all upsample rates
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
        self.m_source = SourceModuleHnNSF(
            sampling_rate=h["sampling_rate"],
            harmonic_num=8)
        self.noise_convs = nn.ModuleList()
        self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
        resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])):
            # channel count halves at every upsampling stage
            c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))
            if i + 1 < len(h["upsample_rates"]):
                # stride the source signal by the product of the remaining
                # upsample rates before adding it at this stage
                stride_f0 = np.prod(h["upsample_rates"][i + 1:])
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
            else:
                # last stage: no remaining upsampling, so a 1x1 conv is used
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h["upsample_initial_channel"] // (2 ** (i + 1))
            for k, d in zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"]):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1)

    def forward(self, x, f0, g=None):
        """Generate a waveform tensor (values in [-1, 1] via tanh).

        Args:
            x: input features fed to ``conv_pre``.
            f0: F0 contour; a channel axis is inserted and it is upsampled
                by the total upsampling factor.
            g: optional global conditioning passed through ``self.cond`` and
                added after ``conv_pre``; skipped when ``None``.
        """
        # add a channel axis, upsample, then swap the last two axes for the
        # source module — presumably (batch, samples, 1); TODO confirm
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
        har_source, noi_source, uv = self.m_source(f0)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        # g defaults to None; only add the conditioning when it is provided
        if g is not None:
            x = x + self.cond(g)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # inject the harmonic source at this stage
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            # average the outputs of this stage's parallel residual blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # final activation uses F.leaky_relu's default slope, not LRELU_SLOPE
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from the upsamplers, residual blocks and pre/post convs."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
350
+
351
+
352
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the 1-D signal into a 2-D grid of width
    ``period`` and applies a stack of 2-D convolutions along the folded axis.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # padding of the strided layers is computed for a kernel of 5,
        # independent of ``kernel_size``
        strided_padding = (get_padding(5, 1), 0)
        channel_plan = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=strided_padding))
            for c_in, c_out in channel_plan
        ]
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for a (b, c, t) input."""
        fmap = []

        # 1d to 2d: reflect-pad so the length is a multiple of the period
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = F.leaky_relu(layer(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
386
+
387
+
388
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorP`` instances, one per period."""

    def __init__(self, periods=None):
        super().__init__()
        if periods is None:
            periods = [2, 3, 5, 7, 11]
        self.periods = periods
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period) for period in self.periods]
        )

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` with every sub-discriminator.

        Returns four lists, one entry per sub-discriminator: scores for ``y``,
        scores for ``y_hat``, feature maps for ``y``, feature maps for ``y_hat``.
        """
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for disc in self.discriminators:
            score_r, maps_r = disc(y)
            score_g, maps_g = disc(y_hat)
            y_d_rs.append(score_r)
            fmap_rs.append(maps_r)
            y_d_gs.append(score_g)
            fmap_gs.append(maps_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
410
+
411
+
412
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of 1-D (grouped) convolutions over the signal."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, kernel, stride, groups, padding)
        layer_specs = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, kernel, step, groups=n_groups, padding=pad))
            for c_in, c_out, kernel, step, n_groups, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)``."""
        fmap = []
        for layer in self.convs:
            x = F.leaky_relu(layer(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
438
+
439
+
440
class MultiScaleDiscriminator(torch.nn.Module):
    """Three ``DiscriminatorS`` instances applied to progressively average-pooled input.

    The first sub-discriminator uses spectral norm; the other two use weight norm.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` at each scale.

        Returns four lists, one entry per scale: scores for ``y``, scores for
        ``y_hat``, feature maps for ``y``, feature maps for ``y_hat``.
        """
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for index, disc in enumerate(self.discriminators):
            if index > 0:
                # pooling is cumulative: each later scale pools the previous one
                pool = self.meanpools[index - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            score_r, maps_r = disc(y)
            score_g, maps_g = disc(y_hat)
            y_d_rs.append(score_r)
            fmap_rs.append(maps_r)
            y_d_gs.append(score_g)
            fmap_gs.append(maps_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
470
+
471
+
472
def feature_loss(fmap_r, fmap_g):
    """Return twice the sum of mean absolute differences between paired feature maps.

    ``fmap_r`` and ``fmap_g`` are parallel lists (one entry per discriminator)
    of lists of tensors (one per layer). Returns ``0`` when there are no pairs.
    """
    total = 0
    for real_maps, fake_maps in zip(fmap_r, fmap_g):
        for real, fake in zip(real_maps, fake_maps):
            total += torch.mean(torch.abs(real - fake))

    return total * 2
479
+
480
+
481
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss.

    For each output pair, the first term is ``mean((1 - dr) ** 2)`` and the
    second is ``mean(dg ** 2)``.

    Returns:
        ``(loss, r_losses, g_losses)``: the summed loss and the per-pair
        terms as Python floats.
    """
    total = 0
    r_losses, g_losses = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out) ** 2)
        fake_term = torch.mean(fake_out ** 2)
        total += (real_term + fake_term)
        r_losses.append(real_term.item())
        g_losses.append(fake_term.item())

    return total, r_losses, g_losses
493
+
494
+
495
def generator_loss(disc_outputs):
    """Least-squares generator loss.

    Returns:
        ``(loss, gen_losses)``: the sum of ``mean((1 - dg) ** 2)`` over all
        discriminator outputs, and the list of per-output terms (tensors).
    """
    total = 0
    gen_losses = []
    for fake_out in disc_outputs:
        term = torch.mean((1 - fake_out) ** 2)
        gen_losses.append(term)
        total += term

    return total, gen_losses
vdecoder/hifigan/nvSTFT.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ os.environ["LRU_CACHE_CAPACITY"] = "3"
4
+ import random
5
+ import torch
6
+ import torch.utils.data
7
+ import numpy as np
8
+ import librosa
9
+ from librosa.util import normalize
10
+ from librosa.filters import mel as librosa_mel_fn
11
+ from scipy.io.wavfile import read
12
+ import soundfile as sf
13
+
14
def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
    """Read an audio file into a normalised 1-D float tensor.

    The file is read with ``soundfile``; only the first channel is kept. The
    samples are divided by a magnitude chosen from their dtype/peak value and,
    if ``target_sr`` is given and differs from the file's rate, resampled.

    Args:
        full_path: path of the audio file.
        target_sr: desired sampling rate, or ``None`` to keep the file's rate.
        return_empty_on_exception: when true, a read failure or inf/NaN
            samples yield ``([], rate)`` instead of raising.

    Returns:
        ``(data, sampling_rate)``.
    """
    sampling_rate = None
    try:
        data, sampling_rate = sf.read(full_path, always_2d=True)
    except Exception as ex:
        print(f"'{full_path}' failed to load.\nException:")
        print(ex)
        if not return_empty_on_exception:
            raise Exception(ex)
        return [], sampling_rate or target_sr or 32000

    # keep only the first channel
    if data.ndim > 1:
        data = data[:, 0]
    # guards against the slice above having been taken on the wrong dimension
    assert len(data) > 2

    if np.issubdtype(data.dtype, np.integer):
        # integer audio: scale by the magnitude of the dtype's minimum value
        max_mag = -np.iinfo(data.dtype).min
    else:
        # float audio: infer the original range from the peak magnitude
        # (32-bit int range, 16-bit int range, or already in [-1, 1])
        peak = max(np.amax(data), -np.amin(data))
        if peak > (2**15):
            max_mag = (2**31)+1
        elif peak > 1.01:
            max_mag = (2**15)+1
        else:
            max_mag = 1.0

    data = torch.FloatTensor(data.astype(np.float32))/max_mag

    # resampling would crash on inf/NaN input; optionally bail out first
    has_bad_values = (torch.isinf(data) | torch.isnan(data)).any()
    if has_bad_values and return_empty_on_exception:
        return [], sampling_rate or target_sr or 32000

    needs_resample = target_sr is not None and sampling_rate != target_sr
    if needs_resample:
        resampled = librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)
        data = torch.from_numpy(resampled)
        sampling_rate = target_sr

    return data, sampling_rate
45
+
46
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Return ``log(max(x, clip_val) * C)`` element-wise (NumPy version)."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
48
+
49
def dynamic_range_decompression(x, C=1):
    """Return ``exp(x) / C`` element-wise (NumPy version)."""
    expanded = np.exp(x)
    return expanded / C
51
+
52
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(clamp(x, min=clip_val) * C)`` element-wise (torch version)."""
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
54
+
55
def dynamic_range_decompression_torch(x, C=1):
    """Return ``exp(x) / C`` element-wise (torch version)."""
    expanded = torch.exp(x)
    return expanded / C
57
+
58
class STFT():
    """Log-mel spectrogram extractor.

    The mel filterbank and the Hann window are cached per device (the
    filterbank additionally per ``fmax``) so they are built only once.
    """

    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr

        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        # caches: mel_basis keyed by "<fmax>_<device>", hann_window by "<device>"
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, center=False):
        """Return the log-mel spectrogram of the waveform batch ``y``.

        ``y`` is reflect-padded by ``(n_fft - hop_length) / 2`` on both sides
        before the STFT. A message is printed if samples fall outside [-1, 1].
        """
        sampling_rate = self.target_sr
        n_mels = self.n_mels
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        fmin = self.fmin
        fmax = self.fmax
        clip_val = self.clip_val

        if torch.min(y) < -1.:
            print('min value is ', torch.min(y))
        if torch.max(y) > 1.:
            print('max value is ', torch.max(y))

        # Look the caches up with the same keys they are stored under;
        # testing the bare ``fmax`` never matched and rebuilt them every call.
        mel_key = str(fmax)+'_'+str(y.device)
        window_key = str(y.device)
        if mel_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
            self.mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        if window_key not in self.hann_window:
            self.hann_window[window_key] = torch.hann_window(self.win_size).to(y.device)

        pad = int((n_fft-hop_length)/2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
        y = y.squeeze(1)

        # return_complex must be given explicitly for real input on current
        # PyTorch; view_as_real restores the trailing (real, imag) axis the
        # magnitude computation below expects.
        spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[window_key],
                          center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
        # magnitude, with a small epsilon inside the sqrt
        spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
        spec = torch.matmul(self.mel_basis[mel_key], spec)
        spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
        return spec

    def __call__(self, audiopath):
        """Load ``audiopath`` at the target sampling rate and return its log-mel spectrogram."""
        audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
        spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
        return spect
110
+
111
+ stft = STFT()  # module-level extractor created at import time with STFT's default settings
vdecoder/hifigan/utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import matplotlib
4
+ import torch
5
+ from torch.nn.utils import weight_norm
6
+ matplotlib.use("Agg")
7
+ import matplotlib.pylab as plt
8
+
9
+
10
def plot_spectrogram(spectrogram):
    """Draw ``spectrogram`` as an image with a colorbar and return the figure.

    The figure is rendered once and then closed in pyplot before being returned.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation='none',
    )
    plt.colorbar(image, ax=axes)

    figure.canvas.draw()
    plt.close()

    return figure
20
+
21
+
22
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise ``m.weight`` from N(mean, std) if ``m``'s class name contains "Conv".

    Modules of any other class are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
26
+
27
+
28
def apply_weight_norm(m):
    """Apply ``weight_norm`` to ``m`` if its class name contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        weight_norm(m)
32
+
33
+
34
def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``.

    The result is truncated towards zero.
    """
    span = kernel_size*dilation - dilation
    return int(span/2)
36
+
37
+
38
def load_checkpoint(filepath, device):
    """Load and return the checkpoint stored at ``filepath``, mapped onto ``device``.

    Asserts that ``filepath`` is an existing file and prints progress messages.
    """
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    state = torch.load(filepath, map_location=device)
    print("Complete.")
    return state
44
+
45
+
46
def save_checkpoint(filepath, obj):
    """Serialise ``obj`` to ``filepath`` with ``torch.save``, printing progress messages."""
    start_message = "Saving checkpoint to {}".format(filepath)
    print(start_message)
    torch.save(obj, filepath)
    print("Complete.")
50
+
51
+
52
def del_old_checkpoints(cp_dir, prefix, n_models=2):
    """Delete all but the newest ``n_models`` checkpoints in ``cp_dir``.

    Checkpoints are files named ``prefix`` followed by exactly eight
    characters; "newest" means last in lexicographic order of the path.

    Args:
        cp_dir: directory holding the checkpoints.
        prefix: filename prefix of the checkpoints to consider.
        n_models: number of most recent checkpoints to keep; ``0`` removes
            all of them (negative values are treated as ``0``).
    """
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = sorted(glob.glob(pattern))  # sorted by iteration suffix

    # Count the stale files explicitly: the slice ``cp_list[:-n_models]`` is
    # empty for n_models == 0 and would leave every checkpoint in place.
    n_stale = len(cp_list) - max(n_models, 0)
    if n_stale <= 0:
        return
    for cp in cp_list[:n_stale]:
        # empty the file before unlinking it (on Colab, deleted files are
        # moved to trash, so this avoids keeping the contents around)
        with open(cp, 'w'):
            pass
        os.unlink(cp)
60
+
61
+
62
def scan_checkpoint(cp_dir, prefix):
    """Return the latest checkpoint path in ``cp_dir``, or ``None`` if there is none.

    Checkpoints are files named ``prefix`` followed by exactly eight
    characters; the latest is the lexicographically greatest path.
    """
    matches = glob.glob(os.path.join(cp_dir, prefix + '????????'))
    if not matches:
        return None
    return max(matches)
68
+