Plachta committed on
Commit bc452bd
1 Parent(s): 91da599

Upload 65 files

Files changed (32)
  1. .gitattributes +3 -0
  2. examples/reference/azuma_0.wav +0 -0
  3. examples/reference/trump_0.wav +3 -0
  4. examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav +3 -0
  5. examples/source/glados_0.wav +0 -0
  6. examples/source/jay_0.wav +3 -0
  7. modules/alias_free_torch/__pycache__/__init__.cpython-310.pyc +0 -0
  8. modules/alias_free_torch/__pycache__/act.cpython-310.pyc +0 -0
  9. modules/alias_free_torch/__pycache__/filter.cpython-310.pyc +0 -0
  10. modules/alias_free_torch/__pycache__/resample.cpython-310.pyc +0 -0
  11. modules/bigvgan/activations.py +120 -0
  12. modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  13. modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  14. modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  15. modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  16. modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  17. modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  18. modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  19. modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  20. modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  21. modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  22. modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  23. modules/bigvgan/bigvgan.py +492 -0
  24. modules/bigvgan/config.json +63 -0
  25. modules/bigvgan/env.py +18 -0
  26. modules/bigvgan/meldataset.py +354 -0
  27. modules/bigvgan/utils.py +99 -0
  28. modules/diffusion_transformer.py +2 -2
  29. modules/flow_matching.py +3 -1
  30. modules/hifigan/generator.py +454 -454
  31. modules/length_regulator.py +118 -102
  32. modules/rmvpe.py +600 -600
.gitattributes CHANGED
@@ -38,3 +38,6 @@ examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
+examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
examples/reference/azuma_0.wav ADDED
Binary file (629 kB)
examples/reference/trump_0.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:716becc9daf00351dfe324398edea9e8378f9453408b27612d92b6721f80ddbc
+size 1379484
examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87087ca5260ce96659b01a647edb30bb08527ed7d0c074fb5ae1e8338cc733e5
+size 2796016
examples/source/glados_0.wav ADDED
Binary file (640 kB)
examples/source/jay_0.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d30f1500acacb597c3b27d7a5937dd088b8029b27e9db8bf5982085f26f4457
+size 1270124
modules/alias_free_torch/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes)
modules/alias_free_torch/__pycache__/act.cpython-310.pyc ADDED
Binary file (1.03 kB)
modules/alias_free_torch/__pycache__/filter.cpython-310.pyc ADDED
Binary file (2.61 kB)
modules/alias_free_torch/__pycache__/resample.cpython-310.pyc ADDED
Binary file (1.89 kB)
modules/bigvgan/activations.py ADDED
@@ -0,0 +1,120 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+    '''
+    Implementation of a sine-based periodic activation function.
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = Snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default; higher values give higher frequency.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # Initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log-scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear-scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake := x + 1/a * sin^2(a * x)
+        '''
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+
+class SnakeBeta(nn.Module):
+    '''
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components.
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = SnakeBeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default; higher values give higher frequency.
+            beta is initialized to 1 by default; higher values give higher magnitude.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # Initialize alpha and beta
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log-scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+            self.beta = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear-scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+            self.beta = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta := x + 1/b * sin^2(a * x)
+        '''
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
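A minimal usage sketch for the two activations above, following the (B, C, T) shape convention from the docstrings. The import path assumes the repository root is on PYTHONPATH; that layout is an assumption, not something this diff pins down.

import torch
from modules.bigvgan.activations import Snake, SnakeBeta

x = torch.randn(4, 256, 100)  # (B, C, T)
snake = Snake(256, alpha_logscale=True)            # one alpha per channel
snake_beta = SnakeBeta(256, alpha_logscale=True)   # separate alpha and beta per channel
assert snake(x).shape == snake_beta(x).shape == x.shape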
modules/bigvgan/alias_free_activation/cuda/__init__.py ADDED
File without changes
modules/bigvgan/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,77 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+
+# Load fused CUDA kernel: this enables importing anti_alias_activation_cuda
+from alias_free_activation.cuda import load
+
+anti_alias_activation_cuda = load.load()
+
+
+class FusedAntiAliasActivation(torch.autograd.Function):
+    """
+    Assumes filter size 12, replication padding on upsampling/downsampling, and log-scale alpha/beta parameters as inputs.
+    The hyperparameters are hard-coded in the kernel to maximize speed.
+    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
+        activation_results = anti_alias_activation_cuda.forward(
+            inputs, up_ftr, down_ftr, alpha, beta
+        )
+
+        return activation_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        raise NotImplementedError
+        return output_grads, None, None
+
+
+class Activation1d(nn.Module):
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+        fused: bool = True,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+        self.fused = fused  # Whether to use the fused CUDA kernel or not
+
+    def forward(self, x):
+        if not self.fused:
+            x = self.upsample(x)
+            x = self.act(x)
+            x = self.downsample(x)
+            return x
+        else:
+            if self.act.__class__.__name__ == "Snake":
+                beta = self.act.alpha.data  # Snake uses the same params for alpha and beta
+            else:
+                beta = self.act.beta.data  # SnakeBeta uses different params for alpha and beta
+            alpha = self.act.alpha.data
+            if not self.act.alpha_logscale:
+                # Exp is baked into the CUDA kernel; cancel it out with a log
+                alpha = torch.log(alpha)
+                beta = torch.log(beta)
+
+            x = FusedAntiAliasActivation.apply(
+                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
+            )
+            return x
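Note that importing this module JIT-compiles the kernel via load.load(), and the fused path is forward-only (backward raises NotImplementedError). A parity sketch against the pure-torch path, assuming a CUDA GPU with nvcc/ninja available and modules/bigvgan on sys.path so the absolute imports above resolve:

import torch
from modules.bigvgan import activations
from modules.bigvgan.alias_free_activation.cuda.activation1d import Activation1d

act = activations.SnakeBeta(64, alpha_logscale=True)  # the kernel hard-codes exp on alpha/beta
fused = Activation1d(activation=act, fused=True).cuda()
plain = Activation1d(activation=act, fused=False).cuda()

x = torch.randn(1, 64, 4096, device="cuda")
with torch.no_grad():  # inference only; the fused backward is not implemented
    torch.testing.assert_close(fused(x), plain(x), rtol=1e-3, atol=1e-3)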
modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
+/* coding=utf-8
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/extension.h>
+
+extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
+}
modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
+/* coding=utf-8
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "type_shim.h"
+#include <assert.h>
+#include <cfloat>
+#include <limits>
+#include <stdint.h>
+#include <c10/macros/Macros.h>
+
+namespace
+{
+    // Hard-coded hyperparameters
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
+    constexpr int BUFFER_SIZE = 32;
+    constexpr int FILTER_SIZE = 12;
+    constexpr int HALF_FILTER_SIZE = 6;
+    constexpr int UPSAMPLE_REPLICATION_PAD = 5;         // 5 on each side, matching torch impl
+    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
+    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
+
+    template <typename input_t, typename output_t, typename acc_t>
+    __global__ void anti_alias_activation_forward(
+        output_t *dst,
+        const input_t *src,
+        const input_t *up_ftr,
+        const input_t *down_ftr,
+        const input_t *alpha,
+        const input_t *beta,
+        int batch_size,
+        int channels,
+        int seq_len)
+    {
+        // Up and downsample filters
+        input_t up_filter[FILTER_SIZE];
+        input_t down_filter[FILTER_SIZE];
+
+        // Load data from global memory, including extra indices reserved for replication paddings
+        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
+        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
+
+        // Output stores the downsampled output before writing to dst
+        output_t output[BUFFER_SIZE];
+
+        // blockDim/threadIdx = (128, 1, 1)
+        // gridDim/blockIdx = (seq_blocks, channels, batches)
+        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
+        int local_offset = threadIdx.x * BUFFER_SIZE;
+        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
+
+        // intermediates have double the seq_len
+        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
+        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
+
+        // Get the values needed for replication padding before moving the pointer
+        const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
+        input_t seq_left_most_value = right_most_pntr[0];
+        input_t seq_right_most_value = right_most_pntr[seq_len - 1];
+
+        // Move src and dst pointers
+        src += block_offset + local_offset;
+        dst += block_offset + local_offset;
+
+        // Alpha and beta values for snake activations. Applies exp by default
+        alpha = alpha + blockIdx.y;
+        input_t alpha_val = expf(alpha[0]);
+        beta = beta + blockIdx.y;
+        input_t beta_val = expf(beta[0]);
+
+        #pragma unroll
+        for (int it = 0; it < FILTER_SIZE; it += 1)
+        {
+            up_filter[it] = up_ftr[it];
+            down_filter[it] = down_ftr[it];
+        }
+
+        // Apply replication padding for upsampling, matching torch impl
+        #pragma unroll
+        for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
+        {
+            int element_index = seq_offset + it; // index for element
+            if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
+            {
+                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
+            }
+            if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
+            {
+                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
+            }
+            if ((element_index >= 0) && (element_index < seq_len))
+            {
+                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
+            }
+        }
+
+        // Apply the upsampling strided convolution and write to intermediates. It reserves
+        // DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
+        #pragma unroll
+        for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
+        {
+            input_t acc = 0.0;
+            int element_index = intermediate_seq_offset + it; // index for intermediate
+            #pragma unroll
+            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
+            {
+                if ((element_index + f_idx) >= 0)
+                {
+                    acc += up_filter[f_idx] * elements[it + f_idx];
+                }
+            }
+            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
+        }
+
+        // Apply the activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and
+        // DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
+        double no_div_by_zero = 0.000000001;
+        #pragma unroll
+        for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
+        {
+            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
+        }
+
+        // Apply replication padding before the downsampling conv from intermediates
+        #pragma unroll
+        for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
+        {
+            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
+        }
+        #pragma unroll
+        for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
+        {
+            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
+        }
+
+        // Apply the downsample strided convolution (assuming stride=2) from intermediates
+        #pragma unroll
+        for (int it = 0; it < BUFFER_SIZE; it += 1)
+        {
+            input_t acc = 0.0;
+            #pragma unroll
+            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
+            {
+                // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match the torch implementation
+                acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
+            }
+            output[it] = acc;
+        }
+
+        // Write output to dst
+        #pragma unroll
+        for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
+        {
+            int element_index = seq_offset + it;
+            if (element_index < seq_len)
+            {
+                dst[it] = output[it];
+            }
+        }
+    }
+
+    template <typename input_t, typename output_t, typename acc_t>
+    void dispatch_anti_alias_activation_forward(
+        output_t *dst,
+        const input_t *src,
+        const input_t *up_ftr,
+        const input_t *down_ftr,
+        const input_t *alpha,
+        const input_t *beta,
+        int batch_size,
+        int channels,
+        int seq_len)
+    {
+        if (seq_len == 0)
+        {
+            return;
+        }
+        else
+        {
+            // Use 128 threads per block to maximize GPU utilization
+            constexpr int threads_per_block = 128;
+            constexpr int seq_len_per_block = 4096;
+            int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
+            dim3 blocks(blocks_per_seq_len, channels, batch_size);
+            dim3 threads(threads_per_block, 1, 1);
+
+            anti_alias_activation_forward<input_t, output_t, acc_t>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
+        }
+    }
+}
+
+extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
+{
+    // Input is a 3d tensor with dimensions [batches, channels, seq_len]
+    const int batches = input.size(0);
+    const int channels = input.size(1);
+    const int seq_len = input.size(2);
+
+    // Output
+    auto act_options = input.options().requires_grad(false);
+
+    torch::Tensor anti_alias_activation_results =
+        torch::empty({batches, channels, seq_len}, act_options);
+
+    void *input_ptr = static_cast<void *>(input.data_ptr());
+    void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
+    void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
+    void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
+    void *beta_ptr = static_cast<void *>(beta.data_ptr());
+    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
+
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(
+        input.scalar_type(),
+        "dispatch anti alias activation_forward",
+        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
+            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
+            reinterpret_cast<const scalar_t *>(input_ptr),
+            reinterpret_cast<const scalar_t *>(up_filter_ptr),
+            reinterpret_cast<const scalar_t *>(down_filter_ptr),
+            reinterpret_cast<const scalar_t *>(alpha_ptr),
+            reinterpret_cast<const scalar_t *>(beta_ptr),
+            batches,
+            channels,
+            seq_len););
+    return anti_alias_activation_results;
+}
modules/bigvgan/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* This code is copied from NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ * with minor changes. */
+
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
modules/bigvgan/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+import os
+import pathlib
+import subprocess
+
+from torch.utils import cpp_extension
+
+"""
+Setting this param to a list has the problem of generating different compilation commands (with a different order of architectures) and leading to recompilation of fused kernels.
+Set it to an empty string to avoid recompilation, and assign arch flags explicitly in extra_cuda_cflags below.
+"""
+os.environ["TORCH_CUDA_ARCH_LIST"] = ""
+
+
+def load():
+    # Check if cuda 11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_80,code=sm_80")
+
+    # Build path
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    buildpath = srcpath / "build"
+    _create_build_dir(buildpath)
+
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=[
+                "-O3",
+            ],
+            extra_cuda_cflags=[
+                "-O3",
+                "-gencode",
+                "arch=compute_70,code=sm_70",
+                "--use_fast_math",
+            ]
+            + extra_cuda_flags
+            + cc_flag,
+            verbose=True,
+        )
+
+    extra_cuda_flags = [
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF_CONVERSIONS__",
+        "--expt-relaxed-constexpr",
+        "--expt-extended-lambda",
+    ]
+
+    sources = [
+        srcpath / "anti_alias_activation.cpp",
+        srcpath / "anti_alias_activation_cuda.cu",
+    ]
+    anti_alias_activation_cuda = _cpp_extention_load_helper(
+        "anti_alias_activation_cuda", sources, extra_cuda_flags
+    )
+
+    return anti_alias_activation_cuda
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output(
+        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+    )
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
modules/bigvgan/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include "compat.h"
+
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
+    switch (TYPE) \
+    { \
+    case at::ScalarType::Float: \
+    { \
+        using scalar_t = float; \
+        __VA_ARGS__; \
+        break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+        using scalar_t = at::Half; \
+        __VA_ARGS__; \
+        break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+        using scalar_t = at::BFloat16; \
+        __VA_ARGS__; \
+        break; \
+    } \
+    default: \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+    }
+
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+    switch (TYPEIN) \
+    { \
+    case at::ScalarType::Float: \
+    { \
+        using scalar_t_in = float; \
+        switch (TYPEOUT) \
+        { \
+        case at::ScalarType::Float: \
+        { \
+            using scalar_t_out = float; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+            using scalar_t_out = at::Half; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+            using scalar_t_out = at::BFloat16; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        default: \
+            AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+        } \
+        break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+        using scalar_t_in = at::Half; \
+        using scalar_t_out = at::Half; \
+        __VA_ARGS__; \
+        break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+        using scalar_t_in = at::BFloat16; \
+        using scalar_t_out = at::BFloat16; \
+        __VA_ARGS__; \
+        break; \
+    } \
+    default: \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
+    }
modules/bigvgan/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+from .filter import *
+from .resample import *
+from .act import *
modules/bigvgan/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,30 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+
+        return x
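A shape-preserving usage sketch: with the default 2x up/down ratios and kernel size 12, the wrapper returns a tensor the same length it was given (import path assumed as above):

import torch
from modules.bigvgan import activations
from modules.bigvgan.alias_free_activation.torch.act import Activation1d

act = Activation1d(activation=activations.Snake(32, alpha_logscale=True))
x = torch.randn(2, 32, 1000)     # (B, C, T)
assert act(x).shape == x.shape   # upsample -> activation -> downsample keeps T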
modules/bigvgan/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,101 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if "sinc" in dir(torch):
+    sinc = torch.sinc
+else:
+    # This code is adapted from adefossez's julius.core.sinc under the MIT License
+    # https://adefossez.github.io/julius/julius/core.html
+    # LICENSE is in incl_licenses directory.
+    def sinc(x: torch.Tensor):
+        """
+        Implementation of sinc, i.e. sin(pi * x) / (pi * x).
+        __Warning__: Unlike julius.sinc, the input is multiplied by `pi`!
+        """
+        return torch.where(
+            x == 0,
+            torch.tensor(1.0, device=x.device, dtype=x.dtype),
+            torch.sin(math.pi * x) / math.pi / x,
+        )
+
+
+# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(
+    cutoff, half_width, kernel_size
+):  # return filter [1, 1, kernel_size]
+    even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+
+    # For the Kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.0:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.0:
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+    else:
+        beta = 0.0
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+    # ratio = 0.5 / cutoff -> 2 * cutoff = 1 / ratio
+    if even:
+        time = torch.arange(-half_size, half_size) + 0.5
+    else:
+        time = torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize the filter to have sum = 1; otherwise there is a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+    filter = filter_.view(1, 1, kernel_size)
+
+    return filter
+
+
+class LowPassFilter1d(nn.Module):
+    def __init__(
+        self,
+        cutoff=0.5,
+        half_width=0.6,
+        stride: int = 1,
+        padding: bool = True,
+        padding_mode: str = "replicate",
+        kernel_size: int = 12,
+    ):
+        """
+        kernel_size should be an even number for the StyleGAN3 setup; this implementation also supports odd numbers.
+        """
+        super().__init__()
+        if cutoff < -0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # Input [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+
+        return out
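A small sketch of the two pieces above: the Kaiser-windowed sinc taps sum to 1, and with stride=2 the module halves the time axis (import path assumed as before):

import torch
from modules.bigvgan.alias_free_activation.torch.filter import (
    LowPassFilter1d,
    kaiser_sinc_filter1d,
)

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
assert taps.shape == (1, 1, 12)
assert torch.isclose(taps.sum(), torch.tensor(1.0))

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(1, 8, 64)
assert lpf(x).shape == (1, 8, 32)  # replication padding keeps T // stride exact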
modules/bigvgan/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,58 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = (
+            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        )
+        filter = kaiser_sinc_filter1d(
+            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+        )
+        self.register_buffer("filter", filter)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        x = F.pad(x, (self.pad, self.pad), mode="replicate")
+        x = self.ratio * F.conv_transpose1d(
+            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+        )
+        x = x[..., self.pad_left : -self.pad_right]
+
+        return x
+
+
+class DownSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.lowpass = LowPassFilter1d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size,
+        )
+
+    def forward(self, x):
+        xx = self.lowpass(x)
+
+        return xx
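A round-trip sketch: UpSample1d scales the time axis by `ratio` exactly, and DownSample1d undoes it (import path assumed as before):

import torch
from modules.bigvgan.alias_free_activation.torch.resample import DownSample1d, UpSample1d

x = torch.randn(2, 16, 500)   # (B, C, T)
up = UpSample1d(ratio=2)      # kernel_size defaults to 12 for ratio=2
down = DownSample1d(ratio=2)

y = up(x)
assert y.shape == (2, 16, 1000)
assert down(y).shape == x.shape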
modules/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,492 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import os
+import json
+from pathlib import Path
+from typing import Optional, Union, Dict
+
+import torch
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+from . import activations
+from .utils import init_weights, get_padding
+from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
+from .env import AttrDict
+
+from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+
+def load_hparams_from_json(path) -> AttrDict:
+    with open(path) as f:
+        data = f.read()
+    return AttrDict(json.loads(data))
+
+
+class AMPBlock1(torch.nn.Module):
+    """
+    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+    AMPBlock1 has an additional self.convs2 containing Conv1d layers with a fixed dilation=1, one applied after each layer in self.convs1.
+
+    Args:
+        h (AttrDict): Hyperparameters.
+        channels (int): Number of convolution channels.
+        kernel_size (int): Size of the convolution kernel. Default is 3.
+        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+    """
+
+    def __init__(
+        self,
+        h: AttrDict,
+        channels: int,
+        kernel_size: int = 3,
+        dilation: tuple = (1, 3, 5),
+        activation: str = None,
+    ):
+        super().__init__()
+
+        self.h = h
+
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        stride=1,
+                        dilation=d,
+                        padding=get_padding(kernel_size, d),
+                    )
+                )
+                for d in dilation
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        stride=1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                )
+                for _ in range(len(dilation))
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+        self.num_layers = len(self.convs1) + len(
+            self.convs2
+        )  # Total number of conv layers
+
+        # Select which Activation1d to use; lazy-load the CUDA version to ensure backward compatibility
+        if self.h.get("use_cuda_kernel", False):
+            from alias_free_activation.cuda.activation1d import (
+                Activation1d as CudaActivation1d,
+            )
+
+            Activation1d = CudaActivation1d
+        else:
+            Activation1d = TorchActivation1d
+
+        # Activation functions
+        if activation == "snake":
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=activations.Snake(
+                            channels, alpha_logscale=h.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        elif activation == "snakebeta":
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=activations.SnakeBeta(
+                            channels, alpha_logscale=h.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        else:
+            raise NotImplementedError(
+                "activation incorrectly specified. Check the config file and look for 'activation'."
+            )
+
+    def forward(self, x):
+        acts1, acts2 = self.activations[::2], self.activations[1::2]
+        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+            xt = a1(x)
+            xt = c1(xt)
+            xt = a2(xt)
+            xt = c2(xt)
+            x = xt + x
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class AMPBlock2(torch.nn.Module):
+    """
+    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1.
+
+    Args:
+        h (AttrDict): Hyperparameters.
+        channels (int): Number of convolution channels.
+        kernel_size (int): Size of the convolution kernel. Default is 3.
+        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+    """
+
+    def __init__(
+        self,
+        h: AttrDict,
+        channels: int,
+        kernel_size: int = 3,
+        dilation: tuple = (1, 3, 5),
+        activation: str = None,
+    ):
+        super().__init__()
+
+        self.h = h
+
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        stride=1,
+                        dilation=d,
+                        padding=get_padding(kernel_size, d),
+                    )
+                )
+                for d in dilation
+            ]
+        )
+        self.convs.apply(init_weights)
+
+        self.num_layers = len(self.convs)  # Total number of conv layers
+
+        # Select which Activation1d to use; lazy-load the CUDA version to ensure backward compatibility
+        if self.h.get("use_cuda_kernel", False):
+            from alias_free_activation.cuda.activation1d import (
+                Activation1d as CudaActivation1d,
+            )
+
+            Activation1d = CudaActivation1d
+        else:
+            Activation1d = TorchActivation1d
+
+        # Activation functions
+        if activation == "snake":
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=activations.Snake(
+                            channels, alpha_logscale=h.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        elif activation == "snakebeta":
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=activations.SnakeBeta(
+                            channels, alpha_logscale=h.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        else:
+            raise NotImplementedError(
+                "activation incorrectly specified. Check the config file and look for 'activation'."
+            )
+
+    def forward(self, x):
+        for c, a in zip(self.convs, self.activations):
+            xt = a(x)
+            xt = c(xt)
+            x = xt + x
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class BigVGAN(
+    torch.nn.Module,
+    PyTorchModelHubMixin,
+    library_name="bigvgan",
+    repo_url="https://github.com/NVIDIA/BigVGAN",
+    docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
+    pipeline_tag="audio-to-audio",
+    license="mit",
+    tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
+):
+    """
+    BigVGAN is a neural vocoder model that applies anti-aliased periodic activations in its residual blocks (resblocks).
+    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
+
+    Args:
+        h (AttrDict): Hyperparameters.
+        use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
+
+    Note:
+        - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
+        - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
+    """
+
+    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
+        super().__init__()
+        self.h = h
+        self.h["use_cuda_kernel"] = use_cuda_kernel
+
+        # Select which Activation1d to use; lazy-load the CUDA version to ensure backward compatibility
+        if self.h.get("use_cuda_kernel", False):
+            from alias_free_activation.cuda.activation1d import (
+                Activation1d as CudaActivation1d,
+            )
+
+            Activation1d = CudaActivation1d
+        else:
+            Activation1d = TorchActivation1d
+
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+
+        # Pre-conv
+        self.conv_pre = weight_norm(
+            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+        )
+
+        # Define which AMPBlock to use. BigVGAN uses AMPBlock1 by default
+        if h.resblock == "1":
+            resblock_class = AMPBlock1
+        elif h.resblock == "2":
+            resblock_class = AMPBlock2
+        else:
+            raise ValueError(
+                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
+            )
+
+        # Transposed conv-based upsamplers. Does not apply anti-aliasing
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        weight_norm(
+                            ConvTranspose1d(
+                                h.upsample_initial_channel // (2**i),
+                                h.upsample_initial_channel // (2 ** (i + 1)),
+                                k,
+                                u,
+                                padding=(k - u) // 2,
+                            )
+                        )
+                    ]
+                )
+            )
+
+        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+            ):
+                self.resblocks.append(
+                    resblock_class(h, ch, k, d, activation=h.activation)
+                )
+
+        # Post-conv
+        activation_post = (
+            activations.Snake(ch, alpha_logscale=h.snake_logscale)
+            if h.activation == "snake"
+            else (
+                activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+                if h.activation == "snakebeta"
+                else None
+            )
+        )
+        if activation_post is None:
+            raise NotImplementedError(
+                "activation incorrectly specified. Check the config file and look for 'activation'."
+            )
+
+        self.activation_post = Activation1d(activation=activation_post)
+
+        # Whether to use bias for the final conv_post. Defaults to True for backward compatibility
+        self.use_bias_at_final = h.get("use_bias_at_final", True)
+        self.conv_post = weight_norm(
+            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
+        )
+
+        # Weight initialization
+        for i in range(len(self.ups)):
+            self.ups[i].apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+        # Final tanh activation. Defaults to True for backward compatibility
+        self.use_tanh_at_final = h.get("use_tanh_at_final", True)
+
+    def forward(self, x):
+        # Pre-conv
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            # Upsampling
+            for i_up in range(len(self.ups[i])):
+                x = self.ups[i][i_up](x)
+            # AMP blocks
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        # Post-conv
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        # Final tanh activation
+        if self.use_tanh_at_final:
+            x = torch.tanh(x)
+        else:
+            x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]
+
+        return x
+
+    def remove_weight_norm(self):
+        try:
+            print("Removing weight norm...")
+            for l in self.ups:
+                for l_i in l:
+                    remove_weight_norm(l_i)
+            for l in self.resblocks:
+                l.remove_weight_norm()
+            remove_weight_norm(self.conv_pre)
+            remove_weight_norm(self.conv_post)
+        except ValueError:
+            print("[INFO] Model already removed weight norm. Skipping!")
+            pass
+
+    # Additional methods for huggingface_hub support
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Save weights and config.json from a Pytorch model to a local directory."""
+
+        model_path = save_directory / "bigvgan_generator.pt"
+        torch.save({"generator": self.state_dict()}, model_path)
+
+        config_path = save_directory / "config.json"
+        with open(config_path, "w") as config_file:
+            json.dump(self.h, config_file, indent=4)
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        *,
+        model_id: str,
+        revision: str,
+        cache_dir: str,
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: bool,
+        local_files_only: bool,
+        token: Union[str, bool, None],
+        map_location: str = "cpu",  # Additional argument
+        strict: bool = False,  # Additional argument
+        use_cuda_kernel: bool = False,
+        **model_kwargs,
+    ):
+        """Load Pytorch pretrained weights and return the loaded model."""
+
+        # Download and load the hyperparameters (h) used by BigVGAN
+        if os.path.isdir(model_id):
+            print("Loading config.json from local directory")
+            config_file = os.path.join(model_id, "config.json")
+        else:
+            config_file = hf_hub_download(
+                repo_id=model_id,
+                filename="config.json",
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                token=token,
+                local_files_only=local_files_only,
+            )
+        h = load_hparams_from_json(config_file)
+
+        # Instantiate BigVGAN using h
+        if use_cuda_kernel:
+            print(
+                "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
+            )
+            print(
+                "[WARNING] You need nvcc and ninja installed on your system, matching the CUDA version your PyTorch build uses, to build the kernel. If not, the model will fail to initialize or will generate incorrect waveforms!"
+            )
+            print(
+                "[WARNING] For details, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
+            )
+        model = cls(h, use_cuda_kernel=use_cuda_kernel)
+
+        # Download and load the pretrained generator weights
+        if os.path.isdir(model_id):
+            print("Loading weights from local directory")
+            model_file = os.path.join(model_id, "bigvgan_generator.pt")
+        else:
+            print(f"Loading weights from {model_id}")
+            model_file = hf_hub_download(
+                repo_id=model_id,
+                filename="bigvgan_generator.pt",
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                token=token,
+                local_files_only=local_files_only,
+            )
+
+        checkpoint_dict = torch.load(model_file, map_location=map_location)
+
+        try:
+            model.load_state_dict(checkpoint_dict["generator"])
+        except RuntimeError:
+            print(
+                "[INFO] The pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+            )
+            model.remove_weight_norm()
+            model.load_state_dict(checkpoint_dict["generator"])
+
+        return model
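An inference sketch using the hub mixin above. The checkpoint id below is one of NVIDIA's published BigVGAN-v2 repos and is an assumption here, not something this diff pins down; any repo exposing config.json and bigvgan_generator.pt works the same way:

import torch
from modules.bigvgan.bigvgan import BigVGAN

model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_22khz_80band_256x", use_cuda_kernel=False)
model.remove_weight_norm()  # fold weight norm for inference
model.eval()

mel = torch.randn(1, model.h.num_mels, 100)  # (B, num_mels, frames)
with torch.no_grad():
    wav = model(mel)  # (1, 1, frames * prod(upsample_rates))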
modules/bigvgan/config.json ADDED
@@ -0,0 +1,63 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 32,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [4, 4, 2, 2, 2, 2],
+    "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3, 7, 11],
+    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+
+    "use_tanh_at_final": false,
+    "use_bias_at_final": false,
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "use_cqtd_instead_of_mrd": true,
+    "cqtd_filters": 128,
+    "cqtd_max_filters": 1024,
+    "cqtd_filters_scale": 1,
+    "cqtd_dilations": [1, 2, 4],
+    "cqtd_hop_lengths": [512, 256, 256],
+    "cqtd_n_octaves": [9, 9, 9],
+    "cqtd_bins_per_octaves": [24, 36, 48],
+
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "use_multiscale_melloss": true,
+    "lambda_melloss": 15,
+
+    "clip_grad_norm": 500,
+
+    "segment_size": 65536,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 22050,
+
+    "fmin": 0,
+    "fmax": null,
+    "fmax_for_loss": null,
+
+    "normalize_volume": true,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
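One relationship worth checking in this config: the generator's total upsampling factor must equal the STFT hop, so each mel frame maps to exactly hop_size output samples. A quick sanity-check sketch (the file path assumes the repository root as the working directory):

import json
import math

from modules.bigvgan.env import AttrDict

h = AttrDict(json.load(open("modules/bigvgan/config.json")))
assert math.prod(h.upsample_rates) == h.hop_size  # 4*4*2*2*2*2 == 256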
modules/bigvgan/env.py ADDED
@@ -0,0 +1,18 @@
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import os
+import shutil
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def build_env(config, config_name, path):
+    t_path = os.path.join(path, config_name)
+    if config != t_path:
+        os.makedirs(path, exist_ok=True)
+        shutil.copyfile(config, os.path.join(path, config_name))
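AttrDict aliases the instance __dict__ to the dict itself, so keys and attributes are interchangeable; a short sketch:

from modules.bigvgan.env import AttrDict

h = AttrDict({"sampling_rate": 22050})
assert h.sampling_rate == h["sampling_rate"] == 22050
h.num_mels = 80  # attribute writes land in the dict too
assert h["num_mels"] == 80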
modules/bigvgan/meldataset.py ADDED
@@ -0,0 +1,354 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import math
+import os
+import random
+import torch
+import torch.utils.data
+import numpy as np
+from librosa.util import normalize
+from scipy.io.wavfile import read
+from librosa.filters import mel as librosa_mel_fn
+import pathlib
+from tqdm import tqdm
+
+MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1 to prevent int16 overflow (results in popping sound in corner cases)
+
+
+def load_wav(full_path, sr_target):
+    sampling_rate, data = read(full_path)
+    if sampling_rate != sr_target:
+        raise RuntimeError(
+            f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
+        )
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    return dynamic_range_compression_torch(magnitudes)
+
+
+def spectral_de_normalize_torch(magnitudes):
+    return dynamic_range_decompression_torch(magnitudes)
+
+
+mel_basis_cache = {}
+hann_window_cache = {}
+
+
+def mel_spectrogram(
+    y: torch.Tensor,
+    n_fft: int,
+    num_mels: int,
+    sampling_rate: int,
+    hop_size: int,
+    win_size: int,
+    fmin: int,
+    fmax: int = None,
+    center: bool = False,
+) -> torch.Tensor:
+    """
+    Calculate the mel spectrogram of an input signal.
+    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and a Hann window for the STFT (using torch.stft).
+
+    Args:
+        y (torch.Tensor): Input signal.
+        n_fft (int): FFT size.
+        num_mels (int): Number of mel bins.
+        sampling_rate (int): Sampling rate of the input signal.
+        hop_size (int): Hop size for STFT.
+        win_size (int): Window size for STFT.
+        fmin (int): Minimum frequency for the mel filterbank.
+        fmax (int): Maximum frequency for the mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn.
+        center (bool): Whether to pad the input to center the frames. Default is False.
+
+    Returns:
+        torch.Tensor: Mel spectrogram.
+    """
+    if torch.min(y) < -1.0:
+        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
+    if torch.max(y) > 1.0:
+        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
+
+    device = y.device
+    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
+
+    if key not in mel_basis_cache:
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
+        hann_window_cache[key] = torch.hann_window(win_size).to(device)
+
+    mel_basis = mel_basis_cache[key]
+    hann_window = hann_window_cache[key]
+
+    padding = (n_fft - hop_size) // 2
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (padding, padding), mode="reflect"
+    ).squeeze(1)
+
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window,
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
+    )
+    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
+
+    mel_spec = torch.matmul(mel_basis, spec)
+    mel_spec = spectral_normalize_torch(mel_spec)
+
+    return mel_spec
+
+
+def get_mel_spectrogram(wav, h):
+    """
+    Generate mel spectrogram from a waveform using given hyperparameters.
+
+    Args:
+        wav (torch.Tensor): Input waveform.
137
+ h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
138
+
139
+ Returns:
140
+ torch.Tensor: Mel spectrogram.
141
+ """
142
+ return mel_spectrogram(
143
+ wav,
144
+ h.n_fft,
145
+ h.num_mels,
146
+ h.sampling_rate,
147
+ h.hop_size,
148
+ h.win_size,
149
+ h.fmin,
150
+ h.fmax,
151
+ )
152
+
153
+
154
+ def get_dataset_filelist(a):
155
+ training_files = []
156
+ validation_files = []
157
+ list_unseen_validation_files = []
158
+
159
+ with open(a.input_training_file, "r", encoding="utf-8") as fi:
160
+ training_files = [
161
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
162
+ for x in fi.read().split("\n")
163
+ if len(x) > 0
164
+ ]
165
+ print(f"first training file: {training_files[0]}")
166
+
167
+ with open(a.input_validation_file, "r", encoding="utf-8") as fi:
168
+ validation_files = [
169
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
170
+ for x in fi.read().split("\n")
171
+ if len(x) > 0
172
+ ]
173
+ print(f"first validation file: {validation_files[0]}")
174
+
175
+ for i in range(len(a.list_input_unseen_validation_file)):
176
+ with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
177
+ unseen_validation_files = [
178
+ os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
179
+ for x in fi.read().split("\n")
180
+ if len(x) > 0
181
+ ]
182
+ print(
183
+ f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
184
+ )
185
+ list_unseen_validation_files.append(unseen_validation_files)
186
+
187
+ return training_files, validation_files, list_unseen_validation_files
188
+
189
+
190
+ class MelDataset(torch.utils.data.Dataset):
191
+ def __init__(
192
+ self,
193
+ training_files,
194
+ hparams,
195
+ segment_size,
196
+ n_fft,
197
+ num_mels,
198
+ hop_size,
199
+ win_size,
200
+ sampling_rate,
201
+ fmin,
202
+ fmax,
203
+ split=True,
204
+ shuffle=True,
205
+ n_cache_reuse=1,
206
+ device=None,
207
+ fmax_loss=None,
208
+ fine_tuning=False,
209
+ base_mels_path=None,
210
+ is_seen=True,
211
+ ):
212
+ self.audio_files = training_files
213
+ random.seed(1234)
214
+ if shuffle:
215
+ random.shuffle(self.audio_files)
216
+ self.hparams = hparams
217
+ self.is_seen = is_seen
218
+ if self.is_seen:
219
+ self.name = pathlib.Path(self.audio_files[0]).parts[0]
220
+ else:
221
+ self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
222
+
223
+ self.segment_size = segment_size
224
+ self.sampling_rate = sampling_rate
225
+ self.split = split
226
+ self.n_fft = n_fft
227
+ self.num_mels = num_mels
228
+ self.hop_size = hop_size
229
+ self.win_size = win_size
230
+ self.fmin = fmin
231
+ self.fmax = fmax
232
+ self.fmax_loss = fmax_loss
233
+ self.cached_wav = None
234
+ self.n_cache_reuse = n_cache_reuse
235
+ self._cache_ref_count = 0
236
+ self.device = device
237
+ self.fine_tuning = fine_tuning
238
+ self.base_mels_path = base_mels_path
239
+
240
+ print("[INFO] checking dataset integrity...")
241
+ for i in tqdm(range(len(self.audio_files))):
242
+ assert os.path.exists(
243
+ self.audio_files[i]
244
+ ), f"{self.audio_files[i]} not found"
245
+
246
+ def __getitem__(self, index):
247
+ filename = self.audio_files[index]
248
+ if self._cache_ref_count == 0:
249
+ audio, sampling_rate = load_wav(filename, self.sampling_rate)
250
+ audio = audio / MAX_WAV_VALUE
251
+ if not self.fine_tuning:
252
+ audio = normalize(audio) * 0.95
253
+ self.cached_wav = audio
254
+ if sampling_rate != self.sampling_rate:
255
+ raise ValueError(
256
+ f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
257
+ )
258
+ self._cache_ref_count = self.n_cache_reuse
259
+ else:
260
+ audio = self.cached_wav
261
+ self._cache_ref_count -= 1
262
+
263
+ audio = torch.FloatTensor(audio)
264
+ audio = audio.unsqueeze(0)
265
+
266
+ if not self.fine_tuning:
267
+ if self.split:
268
+ if audio.size(1) >= self.segment_size:
269
+ max_audio_start = audio.size(1) - self.segment_size
270
+ audio_start = random.randint(0, max_audio_start)
271
+ audio = audio[:, audio_start : audio_start + self.segment_size]
272
+ else:
273
+ audio = torch.nn.functional.pad(
274
+ audio, (0, self.segment_size - audio.size(1)), "constant"
275
+ )
276
+
277
+ mel = mel_spectrogram(
278
+ audio,
279
+ self.n_fft,
280
+ self.num_mels,
281
+ self.sampling_rate,
282
+ self.hop_size,
283
+ self.win_size,
284
+ self.fmin,
285
+ self.fmax,
286
+ center=False,
287
+ )
288
+ else: # Validation step
289
+ # Match audio length to self.hop_size * n for evaluation
290
+ if (audio.size(1) % self.hop_size) != 0:
291
+ audio = audio[:, : -(audio.size(1) % self.hop_size)]
292
+ mel = mel_spectrogram(
293
+ audio,
294
+ self.n_fft,
295
+ self.num_mels,
296
+ self.sampling_rate,
297
+ self.hop_size,
298
+ self.win_size,
299
+ self.fmin,
300
+ self.fmax,
301
+ center=False,
302
+ )
303
+ assert (
304
+ audio.shape[1] == mel.shape[2] * self.hop_size
305
+ ), f"audio shape {audio.shape} mel shape {mel.shape}"
306
+
307
+ else:
308
+ mel = np.load(
309
+ os.path.join(
310
+ self.base_mels_path,
311
+ os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
312
+ )
313
+ )
314
+ mel = torch.from_numpy(mel)
315
+
316
+ if len(mel.shape) < 3:
317
+ mel = mel.unsqueeze(0)
318
+
319
+ if self.split:
320
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
321
+
322
+ if audio.size(1) >= self.segment_size:
323
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
324
+ mel = mel[:, :, mel_start : mel_start + frames_per_seg]
325
+ audio = audio[
326
+ :,
327
+ mel_start
328
+ * self.hop_size : (mel_start + frames_per_seg)
329
+ * self.hop_size,
330
+ ]
331
+ else:
332
+ mel = torch.nn.functional.pad(
333
+ mel, (0, frames_per_seg - mel.size(2)), "constant"
334
+ )
335
+ audio = torch.nn.functional.pad(
336
+ audio, (0, self.segment_size - audio.size(1)), "constant"
337
+ )
338
+
339
+ mel_loss = mel_spectrogram(
340
+ audio,
341
+ self.n_fft,
342
+ self.num_mels,
343
+ self.sampling_rate,
344
+ self.hop_size,
345
+ self.win_size,
346
+ self.fmin,
347
+ self.fmax_loss,
348
+ center=False,
349
+ )
350
+
351
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
352
+
353
+ def __len__(self):
354
+ return len(self.audio_files)
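mel_spectrogram pads by (n_fft - hop_size) / 2 on each side and runs a center=False STFT, so a waveform whose length is a multiple of hop_size yields exactly length / hop_size frames; this is the invariant the dataset asserts above. A quick sketch checking it with the values from config.json (random audio stands in for a real file):

```python
import torch

from modules.bigvgan.meldataset import mel_spectrogram

wav = torch.randn(1, 65536).clamp(-1.0, 1.0)  # one segment_size worth of audio
mel = mel_spectrogram(
    wav, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=None,
)
print(mel.shape)  # torch.Size([1, 80, 256]); 65536 / 256 = 256 frames
```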
modules/bigvgan/utils.py ADDED
@@ -0,0 +1,99 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import glob
5
+ import os
6
+ import matplotlib
7
+ import torch
8
+ from torch.nn.utils import weight_norm
9
+
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pylab as plt
12
+ from .meldataset import MAX_WAV_VALUE
13
+ from scipy.io.wavfile import write
14
+
15
+
16
+ def plot_spectrogram(spectrogram):
17
+ fig, ax = plt.subplots(figsize=(10, 2))
18
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
19
+ plt.colorbar(im, ax=ax)
20
+
21
+ fig.canvas.draw()
22
+ plt.close()
23
+
24
+ return fig
25
+
26
+
27
+ def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
28
+ fig, ax = plt.subplots(figsize=(10, 2))
29
+ im = ax.imshow(
30
+ spectrogram,
31
+ aspect="auto",
32
+ origin="lower",
33
+ interpolation="none",
34
+ vmin=1e-6,
35
+ vmax=clip_max,
36
+ )
37
+ plt.colorbar(im, ax=ax)
38
+
39
+ fig.canvas.draw()
40
+ plt.close()
41
+
42
+ return fig
43
+
44
+
45
+ def init_weights(m, mean=0.0, std=0.01):
46
+ classname = m.__class__.__name__
47
+ if classname.find("Conv") != -1:
48
+ m.weight.data.normal_(mean, std)
49
+
50
+
51
+ def apply_weight_norm(m):
52
+ classname = m.__class__.__name__
53
+ if classname.find("Conv") != -1:
54
+ weight_norm(m)
55
+
56
+
57
+ def get_padding(kernel_size, dilation=1):
58
+ return int((kernel_size * dilation - dilation) / 2)
59
+
60
+
61
+ def load_checkpoint(filepath, device):
62
+ assert os.path.isfile(filepath)
63
+ print(f"Loading '{filepath}'")
64
+ checkpoint_dict = torch.load(filepath, map_location=device)
65
+ print("Complete.")
66
+ return checkpoint_dict
67
+
68
+
69
+ def save_checkpoint(filepath, obj):
70
+ print(f"Saving checkpoint to {filepath}")
71
+ torch.save(obj, filepath)
72
+ print("Complete.")
73
+
74
+
75
+ def scan_checkpoint(cp_dir, prefix, renamed_file=None):
76
+ # Fallback to original scanning logic first
77
+ pattern = os.path.join(cp_dir, prefix + "????????")
78
+ cp_list = glob.glob(pattern)
79
+
80
+ if len(cp_list) > 0:
81
+ last_checkpoint_path = sorted(cp_list)[-1]
82
+ print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
83
+ return last_checkpoint_path
84
+
85
+ # If no pattern-based checkpoints are found, check for renamed file
86
+ if renamed_file:
87
+ renamed_path = os.path.join(cp_dir, renamed_file)
88
+ if os.path.isfile(renamed_path):
89
+ print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
90
+ return renamed_path
91
+
92
+ return None
93
+
94
+
95
+ def save_audio(audio, path, sr):
96
+ # wav: torch with 1d shape
97
+ audio = audio * MAX_WAV_VALUE
98
+ audio = audio.cpu().numpy().astype("int16")
99
+ write(path, sr, audio)
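scan_checkpoint expects files named as the prefix plus an 8-digit step (e.g. g_00001000) and returns the most recent one, falling back to an optional renamed file. A sketch of the usual resume flow at training start-up (directory and prefix are hypothetical):

```python
from modules.bigvgan.utils import load_checkpoint, scan_checkpoint

cp_g = scan_checkpoint("cp_bigvgan", "g_")  # e.g. cp_bigvgan/g_00001000, or None
state_dict = load_checkpoint(cp_g, "cpu") if cp_g else None
```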
modules/diffusion_transformer.py CHANGED
@@ -106,7 +106,7 @@ class DiT(torch.nn.Module):
106
  self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
107
  self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
108
  model_args = ModelArgs(
109
- block_size=8192,#args.DiT.block_size,
110
  n_layer=args.DiT.depth,
111
  n_head=args.DiT.num_heads,
112
  dim=args.DiT.hidden_dim,
@@ -139,7 +139,7 @@ class DiT(torch.nn.Module):
139
  # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
140
  # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
141
 
142
- input_pos = torch.arange(8192)
143
  self.register_buffer("input_pos", input_pos)
144
 
145
  self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
 
106
  self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
107
  self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
108
  model_args = ModelArgs(
109
+ block_size=16384,#args.DiT.block_size,
110
  n_layer=args.DiT.depth,
111
  n_head=args.DiT.num_heads,
112
  dim=args.DiT.hidden_dim,
 
139
  # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
140
  # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
141
 
142
+ input_pos = torch.arange(16384)
143
  self.register_buffer("input_pos", input_pos)
144
 
145
  self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
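Both hunks make the same change: the attention block size and the registered input_pos buffer grow from 8192 to 16384, doubling the maximum sequence length the DiT can index positions for. A toy sketch (not the DiT itself) of how such a registered buffer is sliced per forward pass:

```python
import torch
import torch.nn as nn


class PosDemo(nn.Module):
    def __init__(self, max_len: int = 16384):
        super().__init__()
        # Registered once; follows the module across .to(device) calls.
        self.register_buffer("input_pos", torch.arange(max_len))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Slice the precomputed positions to the current sequence length.
        return self.input_pos[: x.size(1)]


print(PosDemo()(torch.zeros(1, 5, 8)))  # tensor([0, 1, 2, 3, 4])
```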
modules/flow_matching.py CHANGED
@@ -6,6 +6,8 @@ import torch.nn.functional as F
6
  from modules.diffusion_transformer import DiT
7
  from modules.commons import sequence_mask
8
 
 
 
9
  class BASECFM(torch.nn.Module, ABC):
10
  def __init__(
11
  self,
@@ -76,7 +78,7 @@ class BASECFM(torch.nn.Module, ABC):
76
  x[..., :prompt_len] = 0
77
  if self.zero_prompt_speech_token:
78
  mu[..., :prompt_len] = 0
79
- for step in range(1, len(t_span)):
80
  dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
81
  # Classifier-Free Guidance inference introduced in VoiceBox
82
  if inference_cfg_rate > 0:
 
6
  from modules.diffusion_transformer import DiT
7
  from modules.commons import sequence_mask
8
 
9
+ from tqdm import tqdm
10
+
11
  class BASECFM(torch.nn.Module, ABC):
12
  def __init__(
13
  self,
 
78
  x[..., :prompt_len] = 0
79
  if self.zero_prompt_speech_token:
80
  mu[..., :prompt_len] = 0
81
+ for step in tqdm(range(1, len(t_span))):
82
  dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
83
  # Classifier-Free Guidance inference introduced in VoiceBox
84
  if inference_cfg_rate > 0:
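The two hunks together only wrap the existing Euler solver loop in tqdm, so each of the len(t_span) - 1 integration steps reports progress during inference. A stripped-down sketch of that loop shape, with the estimator call stubbed out:

```python
import torch
from tqdm import tqdm

t_span = torch.linspace(0, 1, 11)  # 10 Euler steps
x = torch.zeros(4)                 # stand-in for the noisy state
t, dt = t_span[0], t_span[1] - t_span[0]

for step in tqdm(range(1, len(t_span))):
    dphi_dt = -x          # stub for self.estimator(x, prompt_x, ...)
    x = x + dt * dphi_dt  # Euler update
    t = t + dt
```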
modules/hifigan/generator.py CHANGED
@@ -1,454 +1,454 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """HIFI-GAN"""
16
-
17
- import typing as tp
18
- import numpy as np
19
- from scipy.signal import get_window
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
- from torch.nn import Conv1d
24
- from torch.nn import ConvTranspose1d
25
- from torch.nn.utils import remove_weight_norm
26
- from torch.nn.utils import weight_norm
27
- from torch.distributions.uniform import Uniform
28
-
29
- from torch import sin
30
- from torch.nn.parameter import Parameter
31
-
32
-
33
- """hifigan based generator implementation.
34
-
35
- This code is modified from https://github.com/jik876/hifi-gan
36
- ,https://github.com/kan-bayashi/ParallelWaveGAN and
37
- https://github.com/NVIDIA/BigVGAN
38
-
39
- """
40
- class Snake(nn.Module):
41
- '''
42
- Implementation of a sine-based periodic activation function
43
- Shape:
44
- - Input: (B, C, T)
45
- - Output: (B, C, T), same shape as the input
46
- Parameters:
47
- - alpha - trainable parameter
48
- References:
49
- - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
50
- https://arxiv.org/abs/2006.08195
51
- Examples:
52
- >>> a1 = snake(256)
53
- >>> x = torch.randn(256)
54
- >>> x = a1(x)
55
- '''
56
- def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
57
- '''
58
- Initialization.
59
- INPUT:
60
- - in_features: shape of the input
61
- - alpha: trainable parameter
62
- alpha is initialized to 1 by default, higher values = higher-frequency.
63
- alpha will be trained along with the rest of your model.
64
- '''
65
- super(Snake, self).__init__()
66
- self.in_features = in_features
67
-
68
- # initialize alpha
69
- self.alpha_logscale = alpha_logscale
70
- if self.alpha_logscale: # log scale alphas initialized to zeros
71
- self.alpha = Parameter(torch.zeros(in_features) * alpha)
72
- else: # linear scale alphas initialized to ones
73
- self.alpha = Parameter(torch.ones(in_features) * alpha)
74
-
75
- self.alpha.requires_grad = alpha_trainable
76
-
77
- self.no_div_by_zero = 0.000000001
78
-
79
- def forward(self, x):
80
- '''
81
- Forward pass of the function.
82
- Applies the function to the input elementwise.
83
- Snake ∶= x + 1/a * sin^2 (xa)
84
- '''
85
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
86
- if self.alpha_logscale:
87
- alpha = torch.exp(alpha)
88
- x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
89
-
90
- return x
91
-
92
- def get_padding(kernel_size, dilation=1):
93
- return int((kernel_size * dilation - dilation) / 2)
94
-
95
-
96
- def init_weights(m, mean=0.0, std=0.01):
97
- classname = m.__class__.__name__
98
- if classname.find("Conv") != -1:
99
- m.weight.data.normal_(mean, std)
100
-
101
-
102
-
103
- class ResBlock(torch.nn.Module):
104
- """Residual block module in HiFiGAN/BigVGAN."""
105
- def __init__(
106
- self,
107
- channels: int = 512,
108
- kernel_size: int = 3,
109
- dilations: tp.List[int] = [1, 3, 5],
110
- ):
111
- super(ResBlock, self).__init__()
112
- self.convs1 = nn.ModuleList()
113
- self.convs2 = nn.ModuleList()
114
-
115
- for dilation in dilations:
116
- self.convs1.append(
117
- weight_norm(
118
- Conv1d(
119
- channels,
120
- channels,
121
- kernel_size,
122
- 1,
123
- dilation=dilation,
124
- padding=get_padding(kernel_size, dilation)
125
- )
126
- )
127
- )
128
- self.convs2.append(
129
- weight_norm(
130
- Conv1d(
131
- channels,
132
- channels,
133
- kernel_size,
134
- 1,
135
- dilation=1,
136
- padding=get_padding(kernel_size, 1)
137
- )
138
- )
139
- )
140
- self.convs1.apply(init_weights)
141
- self.convs2.apply(init_weights)
142
- self.activations1 = nn.ModuleList([
143
- Snake(channels, alpha_logscale=False)
144
- for _ in range(len(self.convs1))
145
- ])
146
- self.activations2 = nn.ModuleList([
147
- Snake(channels, alpha_logscale=False)
148
- for _ in range(len(self.convs2))
149
- ])
150
-
151
- def forward(self, x: torch.Tensor) -> torch.Tensor:
152
- for idx in range(len(self.convs1)):
153
- xt = self.activations1[idx](x)
154
- xt = self.convs1[idx](xt)
155
- xt = self.activations2[idx](xt)
156
- xt = self.convs2[idx](xt)
157
- x = xt + x
158
- return x
159
-
160
- def remove_weight_norm(self):
161
- for idx in range(len(self.convs1)):
162
- remove_weight_norm(self.convs1[idx])
163
- remove_weight_norm(self.convs2[idx])
164
-
165
- class SineGen(torch.nn.Module):
166
- """ Definition of sine generator
167
- SineGen(samp_rate, harmonic_num = 0,
168
- sine_amp = 0.1, noise_std = 0.003,
169
- voiced_threshold = 0,
170
- flag_for_pulse=False)
171
- samp_rate: sampling rate in Hz
172
- harmonic_num: number of harmonic overtones (default 0)
173
- sine_amp: amplitude of sine-wavefrom (default 0.1)
174
- noise_std: std of Gaussian noise (default 0.003)
175
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
176
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
177
- Note: when flag_for_pulse is True, the first time step of a voiced
178
- segment is always sin(np.pi) or cos(0)
179
- """
180
-
181
- def __init__(self, samp_rate, harmonic_num=0,
182
- sine_amp=0.1, noise_std=0.003,
183
- voiced_threshold=0):
184
- super(SineGen, self).__init__()
185
- self.sine_amp = sine_amp
186
- self.noise_std = noise_std
187
- self.harmonic_num = harmonic_num
188
- self.sampling_rate = samp_rate
189
- self.voiced_threshold = voiced_threshold
190
-
191
- def _f02uv(self, f0):
192
- # generate uv signal
193
- uv = (f0 > self.voiced_threshold).type(torch.float32)
194
- return uv
195
-
196
- @torch.no_grad()
197
- def forward(self, f0):
198
- """
199
- :param f0: [B, 1, sample_len], Hz
200
- :return: [B, 1, sample_len]
201
- """
202
-
203
- F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
204
- for i in range(self.harmonic_num + 1):
205
- F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
206
-
207
- theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
208
- u_dist = Uniform(low=-np.pi, high=np.pi)
209
- phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
210
- phase_vec[:, 0, :] = 0
211
-
212
- # generate sine waveforms
213
- sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
214
-
215
- # generate uv signal
216
- uv = self._f02uv(f0)
217
-
218
- # noise: for unvoiced should be similar to sine_amp
219
- # std = self.sine_amp/3 -> max value ~ self.sine_amp
220
- # . for voiced regions is self.noise_std
221
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
222
- noise = noise_amp * torch.randn_like(sine_waves)
223
-
224
- # first: set the unvoiced part to 0 by uv
225
- # then: additive noise
226
- sine_waves = sine_waves * uv + noise
227
- return sine_waves, uv, noise
228
-
229
-
230
- class SourceModuleHnNSF(torch.nn.Module):
231
- """ SourceModule for hn-nsf
232
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
233
- add_noise_std=0.003, voiced_threshod=0)
234
- sampling_rate: sampling_rate in Hz
235
- harmonic_num: number of harmonic above F0 (default: 0)
236
- sine_amp: amplitude of sine source signal (default: 0.1)
237
- add_noise_std: std of additive Gaussian noise (default: 0.003)
238
- note that amplitude of noise in unvoiced is decided
239
- by sine_amp
240
- voiced_threshold: threhold to set U/V given F0 (default: 0)
241
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
242
- F0_sampled (batchsize, length, 1)
243
- Sine_source (batchsize, length, 1)
244
- noise_source (batchsize, length 1)
245
- uv (batchsize, length, 1)
246
- """
247
-
248
- def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
249
- add_noise_std=0.003, voiced_threshod=0):
250
- super(SourceModuleHnNSF, self).__init__()
251
-
252
- self.sine_amp = sine_amp
253
- self.noise_std = add_noise_std
254
-
255
- # to produce sine waveforms
256
- self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
257
- sine_amp, add_noise_std, voiced_threshod)
258
-
259
- # to merge source harmonics into a single excitation
260
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
261
- self.l_tanh = torch.nn.Tanh()
262
-
263
- def forward(self, x):
264
- """
265
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
266
- F0_sampled (batchsize, length, 1)
267
- Sine_source (batchsize, length, 1)
268
- noise_source (batchsize, length 1)
269
- """
270
- # source for harmonic branch
271
- with torch.no_grad():
272
- sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
273
- sine_wavs = sine_wavs.transpose(1, 2)
274
- uv = uv.transpose(1, 2)
275
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
276
-
277
- # source for noise branch, in the same shape as uv
278
- noise = torch.randn_like(uv) * self.sine_amp / 3
279
- return sine_merge, noise, uv
280
-
281
-
282
- class HiFTGenerator(nn.Module):
283
- """
284
- HiFTNet Generator: Neural Source Filter + ISTFTNet
285
- https://arxiv.org/abs/2309.09493
286
- """
287
- def __init__(
288
- self,
289
- in_channels: int = 80,
290
- base_channels: int = 512,
291
- nb_harmonics: int = 8,
292
- sampling_rate: int = 22050,
293
- nsf_alpha: float = 0.1,
294
- nsf_sigma: float = 0.003,
295
- nsf_voiced_threshold: float = 10,
296
- upsample_rates: tp.List[int] = [8, 8],
297
- upsample_kernel_sizes: tp.List[int] = [16, 16],
298
- istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
299
- resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
300
- resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
301
- source_resblock_kernel_sizes: tp.List[int] = [7, 11],
302
- source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
303
- lrelu_slope: float = 0.1,
304
- audio_limit: float = 0.99,
305
- f0_predictor: torch.nn.Module = None,
306
- ):
307
- super(HiFTGenerator, self).__init__()
308
-
309
- self.out_channels = 1
310
- self.nb_harmonics = nb_harmonics
311
- self.sampling_rate = sampling_rate
312
- self.istft_params = istft_params
313
- self.lrelu_slope = lrelu_slope
314
- self.audio_limit = audio_limit
315
-
316
- self.num_kernels = len(resblock_kernel_sizes)
317
- self.num_upsamples = len(upsample_rates)
318
- self.m_source = SourceModuleHnNSF(
319
- sampling_rate=sampling_rate,
320
- upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
321
- harmonic_num=nb_harmonics,
322
- sine_amp=nsf_alpha,
323
- add_noise_std=nsf_sigma,
324
- voiced_threshod=nsf_voiced_threshold)
325
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
326
-
327
- self.conv_pre = weight_norm(
328
- Conv1d(in_channels, base_channels, 7, 1, padding=3)
329
- )
330
-
331
- # Up
332
- self.ups = nn.ModuleList()
333
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
334
- self.ups.append(
335
- weight_norm(
336
- ConvTranspose1d(
337
- base_channels // (2**i),
338
- base_channels // (2**(i + 1)),
339
- k,
340
- u,
341
- padding=(k - u) // 2,
342
- )
343
- )
344
- )
345
-
346
- # Down
347
- self.source_downs = nn.ModuleList()
348
- self.source_resblocks = nn.ModuleList()
349
- downsample_rates = [1] + upsample_rates[::-1][:-1]
350
- downsample_cum_rates = np.cumprod(downsample_rates)
351
- for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
352
- source_resblock_dilation_sizes)):
353
- if u == 1:
354
- self.source_downs.append(
355
- Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
356
- )
357
- else:
358
- self.source_downs.append(
359
- Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
360
- )
361
-
362
- self.source_resblocks.append(
363
- ResBlock(base_channels // (2 ** (i + 1)), k, d)
364
- )
365
-
366
- self.resblocks = nn.ModuleList()
367
- for i in range(len(self.ups)):
368
- ch = base_channels // (2**(i + 1))
369
- for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
370
- self.resblocks.append(ResBlock(ch, k, d))
371
-
372
- self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
373
- self.ups.apply(init_weights)
374
- self.conv_post.apply(init_weights)
375
- self.reflection_pad = nn.ReflectionPad1d((1, 0))
376
- self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
377
- self.f0_predictor = f0_predictor
378
-
379
- def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
380
- f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
381
-
382
- har_source, _, _ = self.m_source(f0)
383
- return har_source.transpose(1, 2)
384
-
385
- def _stft(self, x):
386
- spec = torch.stft(
387
- x,
388
- self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
389
- return_complex=True)
390
- spec = torch.view_as_real(spec) # [B, F, TT, 2]
391
- return spec[..., 0], spec[..., 1]
392
-
393
- def _istft(self, magnitude, phase):
394
- magnitude = torch.clip(magnitude, max=1e2)
395
- real = magnitude * torch.cos(phase)
396
- img = magnitude * torch.sin(phase)
397
- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
398
- return inverse_transform
399
-
400
- def forward(self, x: torch.Tensor, f0=None) -> torch.Tensor:
401
- if f0 is None:
402
- f0 = self.f0_predictor(x)
403
- s = self._f02source(f0)
404
-
405
- s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
406
- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
407
-
408
- x = self.conv_pre(x)
409
- for i in range(self.num_upsamples):
410
- x = F.leaky_relu(x, self.lrelu_slope)
411
- x = self.ups[i](x)
412
-
413
- if i == self.num_upsamples - 1:
414
- x = self.reflection_pad(x)
415
-
416
- # fusion
417
- si = self.source_downs[i](s_stft)
418
- si = self.source_resblocks[i](si)
419
- x = x + si
420
-
421
- xs = None
422
- for j in range(self.num_kernels):
423
- if xs is None:
424
- xs = self.resblocks[i * self.num_kernels + j](x)
425
- else:
426
- xs += self.resblocks[i * self.num_kernels + j](x)
427
- x = xs / self.num_kernels
428
-
429
- x = F.leaky_relu(x)
430
- x = self.conv_post(x)
431
- magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
432
- phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
433
-
434
- x = self._istft(magnitude, phase)
435
- x = torch.clamp(x, -self.audio_limit, self.audio_limit)
436
- return x
437
-
438
- def remove_weight_norm(self):
439
- print('Removing weight norm...')
440
- for l in self.ups:
441
- remove_weight_norm(l)
442
- for l in self.resblocks:
443
- l.remove_weight_norm()
444
- remove_weight_norm(self.conv_pre)
445
- remove_weight_norm(self.conv_post)
446
- self.source_module.remove_weight_norm()
447
- for l in self.source_downs:
448
- remove_weight_norm(l)
449
- for l in self.source_resblocks:
450
- l.remove_weight_norm()
451
-
452
- @torch.inference_mode()
453
- def inference(self, mel: torch.Tensor, f0=None) -> torch.Tensor:
454
- return self.forward(x=mel, f0=f0)
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ import typing as tp
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+
29
+ from torch import sin
30
+ from torch.nn.parameter import Parameter
31
+
32
+
33
+ """hifigan based generator implementation.
34
+
35
+ This code is modified from https://github.com/jik876/hifi-gan
36
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
37
+ https://github.com/NVIDIA/BigVGAN
38
+
39
+ """
40
+ class Snake(nn.Module):
41
+ '''
42
+ Implementation of a sine-based periodic activation function
43
+ Shape:
44
+ - Input: (B, C, T)
45
+ - Output: (B, C, T), same shape as the input
46
+ Parameters:
47
+ - alpha - trainable parameter
48
+ References:
49
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
50
+ https://arxiv.org/abs/2006.08195
51
+ Examples:
52
+ >>> a1 = Snake(256)
53
+ >>> x = torch.randn(256)
54
+ >>> x = a1(x)
55
+ '''
56
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
57
+ '''
58
+ Initialization.
59
+ INPUT:
60
+ - in_features: shape of the input
61
+ - alpha: trainable parameter
62
+ alpha is initialized to 1 by default, higher values = higher-frequency.
63
+ alpha will be trained along with the rest of your model.
64
+ '''
65
+ super(Snake, self).__init__()
66
+ self.in_features = in_features
67
+
68
+ # initialize alpha
69
+ self.alpha_logscale = alpha_logscale
70
+ if self.alpha_logscale: # log scale alphas initialized to zeros
71
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
72
+ else: # linear scale alphas initialized to ones
73
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
74
+
75
+ self.alpha.requires_grad = alpha_trainable
76
+
77
+ self.no_div_by_zero = 0.000000001
78
+
79
+ def forward(self, x):
80
+ '''
81
+ Forward pass of the function.
82
+ Applies the function to the input elementwise.
83
+ Snake ∶= x + 1/a * sin^2 (xa)
84
+ '''
85
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
86
+ if self.alpha_logscale:
87
+ alpha = torch.exp(alpha)
88
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
89
+
90
+ return x
91
+
92
+ def get_padding(kernel_size, dilation=1):
93
+ return int((kernel_size * dilation - dilation) / 2)
94
+
95
+
96
+ def init_weights(m, mean=0.0, std=0.01):
97
+ classname = m.__class__.__name__
98
+ if classname.find("Conv") != -1:
99
+ m.weight.data.normal_(mean, std)
100
+
101
+
102
+
103
+ class ResBlock(torch.nn.Module):
104
+ """Residual block module in HiFiGAN/BigVGAN."""
105
+ def __init__(
106
+ self,
107
+ channels: int = 512,
108
+ kernel_size: int = 3,
109
+ dilations: tp.List[int] = [1, 3, 5],
110
+ ):
111
+ super(ResBlock, self).__init__()
112
+ self.convs1 = nn.ModuleList()
113
+ self.convs2 = nn.ModuleList()
114
+
115
+ for dilation in dilations:
116
+ self.convs1.append(
117
+ weight_norm(
118
+ Conv1d(
119
+ channels,
120
+ channels,
121
+ kernel_size,
122
+ 1,
123
+ dilation=dilation,
124
+ padding=get_padding(kernel_size, dilation)
125
+ )
126
+ )
127
+ )
128
+ self.convs2.append(
129
+ weight_norm(
130
+ Conv1d(
131
+ channels,
132
+ channels,
133
+ kernel_size,
134
+ 1,
135
+ dilation=1,
136
+ padding=get_padding(kernel_size, 1)
137
+ )
138
+ )
139
+ )
140
+ self.convs1.apply(init_weights)
141
+ self.convs2.apply(init_weights)
142
+ self.activations1 = nn.ModuleList([
143
+ Snake(channels, alpha_logscale=False)
144
+ for _ in range(len(self.convs1))
145
+ ])
146
+ self.activations2 = nn.ModuleList([
147
+ Snake(channels, alpha_logscale=False)
148
+ for _ in range(len(self.convs2))
149
+ ])
150
+
151
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
152
+ for idx in range(len(self.convs1)):
153
+ xt = self.activations1[idx](x)
154
+ xt = self.convs1[idx](xt)
155
+ xt = self.activations2[idx](xt)
156
+ xt = self.convs2[idx](xt)
157
+ x = xt + x
158
+ return x
159
+
160
+ def remove_weight_norm(self):
161
+ for idx in range(len(self.convs1)):
162
+ remove_weight_norm(self.convs1[idx])
163
+ remove_weight_norm(self.convs2[idx])
164
+
165
+ class SineGen(torch.nn.Module):
166
+ """ Definition of sine generator
167
+ SineGen(samp_rate, harmonic_num = 0,
168
+ sine_amp = 0.1, noise_std = 0.003,
169
+ voiced_threshold = 0,
170
+ flag_for_pulse=False)
171
+ samp_rate: sampling rate in Hz
172
+ harmonic_num: number of harmonic overtones (default 0)
173
+ sine_amp: amplitude of sine waveform (default 0.1)
174
+ noise_std: std of Gaussian noise (default 0.003)
175
+ voiced_threshold: F0 threshold for U/V classification (default 0)
176
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
177
+ Note: when flag_for_pulse is True, the first time step of a voiced
178
+ segment is always sin(np.pi) or cos(0)
179
+ """
180
+
181
+ def __init__(self, samp_rate, harmonic_num=0,
182
+ sine_amp=0.1, noise_std=0.003,
183
+ voiced_threshold=0):
184
+ super(SineGen, self).__init__()
185
+ self.sine_amp = sine_amp
186
+ self.noise_std = noise_std
187
+ self.harmonic_num = harmonic_num
188
+ self.sampling_rate = samp_rate
189
+ self.voiced_threshold = voiced_threshold
190
+
191
+ def _f02uv(self, f0):
192
+ # generate uv signal
193
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
194
+ return uv
195
+
196
+ @torch.no_grad()
197
+ def forward(self, f0):
198
+ """
199
+ :param f0: [B, 1, sample_len], Hz
200
+ :return: [B, 1, sample_len]
201
+ """
202
+
203
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
204
+ for i in range(self.harmonic_num + 1):
205
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
206
+
207
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
208
+ u_dist = Uniform(low=-np.pi, high=np.pi)
209
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
210
+ phase_vec[:, 0, :] = 0
211
+
212
+ # generate sine waveforms
213
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
214
+
215
+ # generate uv signal
216
+ uv = self._f02uv(f0)
217
+
218
+ # noise: for unvoiced should be similar to sine_amp
219
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
220
+ # . for voiced regions is self.noise_std
221
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
222
+ noise = noise_amp * torch.randn_like(sine_waves)
223
+
224
+ # first: set the unvoiced part to 0 by uv
225
+ # then: additive noise
226
+ sine_waves = sine_waves * uv + noise
227
+ return sine_waves, uv, noise
228
+
229
+
230
+ class SourceModuleHnNSF(torch.nn.Module):
231
+ """ SourceModule for hn-nsf
232
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
233
+ add_noise_std=0.003, voiced_threshod=0)
234
+ sampling_rate: sampling_rate in Hz
235
+ harmonic_num: number of harmonic above F0 (default: 0)
236
+ sine_amp: amplitude of sine source signal (default: 0.1)
237
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
238
+ note that amplitude of noise in unvoiced is decided
239
+ by sine_amp
240
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
241
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
242
+ F0_sampled (batchsize, length, 1)
243
+ Sine_source (batchsize, length, 1)
244
+ noise_source (batchsize, length 1)
245
+ uv (batchsize, length, 1)
246
+ """
247
+
248
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
249
+ add_noise_std=0.003, voiced_threshod=0):
250
+ super(SourceModuleHnNSF, self).__init__()
251
+
252
+ self.sine_amp = sine_amp
253
+ self.noise_std = add_noise_std
254
+
255
+ # to produce sine waveforms
256
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
257
+ sine_amp, add_noise_std, voiced_threshod)
258
+
259
+ # to merge source harmonics into a single excitation
260
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
261
+ self.l_tanh = torch.nn.Tanh()
262
+
263
+ def forward(self, x):
264
+ """
265
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
266
+ F0_sampled (batchsize, length, 1)
267
+ Sine_source (batchsize, length, 1)
268
+ noise_source (batchsize, length 1)
269
+ """
270
+ # source for harmonic branch
271
+ with torch.no_grad():
272
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
273
+ sine_wavs = sine_wavs.transpose(1, 2)
274
+ uv = uv.transpose(1, 2)
275
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
276
+
277
+ # source for noise branch, in the same shape as uv
278
+ noise = torch.randn_like(uv) * self.sine_amp / 3
279
+ return sine_merge, noise, uv
280
+
281
+
282
+ class HiFTGenerator(nn.Module):
283
+ """
284
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
285
+ https://arxiv.org/abs/2309.09493
286
+ """
287
+ def __init__(
288
+ self,
289
+ in_channels: int = 80,
290
+ base_channels: int = 512,
291
+ nb_harmonics: int = 8,
292
+ sampling_rate: int = 22050,
293
+ nsf_alpha: float = 0.1,
294
+ nsf_sigma: float = 0.003,
295
+ nsf_voiced_threshold: float = 10,
296
+ upsample_rates: tp.List[int] = [8, 8],
297
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
298
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
299
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
300
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
301
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
302
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
303
+ lrelu_slope: float = 0.1,
304
+ audio_limit: float = 0.99,
305
+ f0_predictor: torch.nn.Module = None,
306
+ ):
307
+ super(HiFTGenerator, self).__init__()
308
+
309
+ self.out_channels = 1
310
+ self.nb_harmonics = nb_harmonics
311
+ self.sampling_rate = sampling_rate
312
+ self.istft_params = istft_params
313
+ self.lrelu_slope = lrelu_slope
314
+ self.audio_limit = audio_limit
315
+
316
+ self.num_kernels = len(resblock_kernel_sizes)
317
+ self.num_upsamples = len(upsample_rates)
318
+ self.m_source = SourceModuleHnNSF(
319
+ sampling_rate=sampling_rate,
320
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
321
+ harmonic_num=nb_harmonics,
322
+ sine_amp=nsf_alpha,
323
+ add_noise_std=nsf_sigma,
324
+ voiced_threshod=nsf_voiced_threshold)
325
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
326
+
327
+ self.conv_pre = weight_norm(
328
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
329
+ )
330
+
331
+ # Up
332
+ self.ups = nn.ModuleList()
333
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
334
+ self.ups.append(
335
+ weight_norm(
336
+ ConvTranspose1d(
337
+ base_channels // (2**i),
338
+ base_channels // (2**(i + 1)),
339
+ k,
340
+ u,
341
+ padding=(k - u) // 2,
342
+ )
343
+ )
344
+ )
345
+
346
+ # Down
347
+ self.source_downs = nn.ModuleList()
348
+ self.source_resblocks = nn.ModuleList()
349
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
350
+ downsample_cum_rates = np.cumprod(downsample_rates)
351
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
352
+ source_resblock_dilation_sizes)):
353
+ if u == 1:
354
+ self.source_downs.append(
355
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
356
+ )
357
+ else:
358
+ self.source_downs.append(
359
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
360
+ )
361
+
362
+ self.source_resblocks.append(
363
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
364
+ )
365
+
366
+ self.resblocks = nn.ModuleList()
367
+ for i in range(len(self.ups)):
368
+ ch = base_channels // (2**(i + 1))
369
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
370
+ self.resblocks.append(ResBlock(ch, k, d))
371
+
372
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
373
+ self.ups.apply(init_weights)
374
+ self.conv_post.apply(init_weights)
375
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
376
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
377
+ self.f0_predictor = f0_predictor
378
+
379
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
380
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
381
+
382
+ har_source, _, _ = self.m_source(f0)
383
+ return har_source.transpose(1, 2)
384
+
385
+ def _stft(self, x):
386
+ spec = torch.stft(
387
+ x,
388
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
389
+ return_complex=True)
390
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
391
+ return spec[..., 0], spec[..., 1]
392
+
393
+ def _istft(self, magnitude, phase):
394
+ magnitude = torch.clip(magnitude, max=1e2)
395
+ real = magnitude * torch.cos(phase)
396
+ img = magnitude * torch.sin(phase)
397
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
398
+ return inverse_transform
399
+
400
+ def forward(self, x: torch.Tensor, f0=None) -> torch.Tensor:
401
+ if f0 is None:
402
+ f0 = self.f0_predictor(x)
403
+ s = self._f02source(f0)
404
+
405
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
406
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
407
+
408
+ x = self.conv_pre(x)
409
+ for i in range(self.num_upsamples):
410
+ x = F.leaky_relu(x, self.lrelu_slope)
411
+ x = self.ups[i](x)
412
+
413
+ if i == self.num_upsamples - 1:
414
+ x = self.reflection_pad(x)
415
+
416
+ # fusion
417
+ si = self.source_downs[i](s_stft)
418
+ si = self.source_resblocks[i](si)
419
+ x = x + si
420
+
421
+ xs = None
422
+ for j in range(self.num_kernels):
423
+ if xs is None:
424
+ xs = self.resblocks[i * self.num_kernels + j](x)
425
+ else:
426
+ xs += self.resblocks[i * self.num_kernels + j](x)
427
+ x = xs / self.num_kernels
428
+
429
+ x = F.leaky_relu(x)
430
+ x = self.conv_post(x)
431
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
432
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # note: the sin here is redundant
433
+
434
+ x = self._istft(magnitude, phase)
435
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
436
+ return x
437
+
438
+ def remove_weight_norm(self):
439
+ print('Removing weight norm...')
440
+ for l in self.ups:
441
+ remove_weight_norm(l)
442
+ for l in self.resblocks:
443
+ l.remove_weight_norm()
444
+ remove_weight_norm(self.conv_pre)
445
+ remove_weight_norm(self.conv_post)
446
+ # NOTE: self.m_source has no weight-normalized layers (and no self.source_module exists), so there is nothing to remove for the source module
447
+ for l in self.source_downs:
448
+ remove_weight_norm(l)
449
+ for l in self.source_resblocks:
450
+ l.remove_weight_norm()
451
+
452
+ @torch.inference_mode()
453
+ def inference(self, mel: torch.Tensor, f0=None) -> torch.Tensor:
454
+ return self.forward(x=mel, f0=f0)
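The Snake activation used throughout the ResBlocks computes x + sin²(αx)/α: near-identity for small inputs, plus a learned periodic ripple. A quick numeric check of the class above against that formula:

```python
import torch

from modules.hifigan.generator import Snake

act = Snake(in_features=4, alpha=1.0)  # linear-scale alpha, initialized to ones
x = torch.randn(2, 4, 10)              # (B, C, T)
expected = x + torch.sin(x) ** 2 / (1.0 + 1e-9)
print(torch.allclose(act(x), expected, atol=1e-6))  # True
```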
modules/length_regulator.py CHANGED
@@ -1,102 +1,118 @@
1
- from typing import Tuple
2
- import torch
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
- from modules.commons import sequence_mask
6
-
7
-
8
- class InterpolateRegulator(nn.Module):
9
- def __init__(
10
- self,
11
- channels: int,
12
- sampling_ratios: Tuple,
13
- is_discrete: bool = False,
14
- codebook_size: int = 1024, # for discrete only
15
- out_channels: int = None,
16
- groups: int = 1,
17
- token_dropout_prob: float = 0.5, # randomly drop out input tokens
18
- token_dropout_range: float = 0.5, # randomly drop out input tokens
19
- n_codebooks: int = 1, # number of codebooks
20
- quantizer_dropout: float = 0.0, # dropout for quantizer
21
- f0_condition: bool = False,
22
- n_f0_bins: int = 512,
23
- ):
24
- super().__init__()
25
- self.sampling_ratios = sampling_ratios
26
- out_channels = out_channels or channels
27
- model = nn.ModuleList([])
28
- if len(sampling_ratios) > 0:
29
- for _ in sampling_ratios:
30
- module = nn.Conv1d(channels, channels, 3, 1, 1)
31
- norm = nn.GroupNorm(groups, channels)
32
- act = nn.Mish()
33
- model.extend([module, norm, act])
34
- model.append(
35
- nn.Conv1d(channels, out_channels, 1, 1)
36
- )
37
- self.model = nn.Sequential(*model)
38
- self.embedding = nn.Embedding(codebook_size, channels)
39
- self.is_discrete = is_discrete
40
-
41
- self.mask_token = nn.Parameter(torch.zeros(1, channels))
42
-
43
- self.n_codebooks = n_codebooks
44
- if n_codebooks > 1:
45
- self.extra_codebooks = nn.ModuleList([
46
- nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
47
- ])
48
- self.token_dropout_prob = token_dropout_prob
49
- self.token_dropout_range = token_dropout_range
50
- self.quantizer_dropout = quantizer_dropout
51
-
52
- if f0_condition:
53
- self.f0_embedding = nn.Embedding(n_f0_bins, channels)
54
- self.f0_condition = f0_condition
55
- self.n_f0_bins = n_f0_bins
56
- self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
57
- self.f0_mask = nn.Parameter(torch.zeros(1, channels))
58
- else:
59
- self.f0_condition = False
60
-
61
- def forward(self, x, ylens=None, n_quantizers=None, f0=None):
62
- # apply token drop
63
- if self.training:
64
- n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
65
- dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
66
- n_dropout = int(x.shape[0] * self.quantizer_dropout)
67
- n_quantizers[:n_dropout] = dropout[:n_dropout]
68
- n_quantizers = n_quantizers.to(x.device)
69
- # decide whether to drop for each sample in batch
70
- else:
71
- n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
72
- if self.is_discrete:
73
- if self.n_codebooks > 1:
74
- assert len(x.size()) == 3
75
- x_emb = self.embedding(x[:, 0])
76
- for i, emb in enumerate(self.extra_codebooks):
77
- x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
78
- x = x_emb
79
- elif self.n_codebooks == 1:
80
- if len(x.size()) == 2:
81
- x = self.embedding(x)
82
- else:
83
- x = self.embedding(x[:, 0])
84
- # x in (B, T, D)
85
- mask = sequence_mask(ylens).unsqueeze(-1)
86
- x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
87
- if self.f0_condition:
88
- if f0 is None:
89
- x = x + self.f0_mask.unsqueeze(-1)
90
- else:
91
- quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
92
- if self.training:
93
- drop_f0 = torch.rand(quantized_f0.size(0)).to(f0.device) < self.quantizer_dropout
94
- else:
95
- drop_f0 = torch.zeros(quantized_f0.size(0)).to(f0.device).bool()
96
- f0_emb = self.f0_embedding(quantized_f0)
97
- f0_emb[drop_f0] = self.f0_mask
98
- f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
99
- x = x + f0_emb
100
- out = self.model(x).transpose(1, 2).contiguous()
101
- olens = ylens
102
- return out * mask, olens
 
1
+ from typing import Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from modules.commons import sequence_mask
6
+ import numpy as np
7
+
8
+ # f0_bin = 256
9
+ f0_max = 1100.0
10
+ f0_min = 50.0
11
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
12
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
13
+
14
+ def f0_to_coarse(f0, f0_bin):
15
+ f0_mel = 1127 * (1 + f0 / 700).log()
16
+ a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
17
+ b = f0_mel_min * a - 1.
18
+ f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
19
+ # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
20
+ f0_coarse = torch.round(f0_mel).long()
21
+ f0_coarse = f0_coarse * (f0_coarse > 0)
22
+ f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
23
+ f0_coarse = f0_coarse * (f0_coarse < f0_bin)
24
+ f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
25
+ return f0_coarse
26
+
27
+ class InterpolateRegulator(nn.Module):
28
+ def __init__(
29
+ self,
30
+ channels: int,
31
+ sampling_ratios: Tuple,
32
+ is_discrete: bool = False,
33
+ codebook_size: int = 1024, # for discrete only
34
+ out_channels: int = None,
35
+ groups: int = 1,
36
+ token_dropout_prob: float = 0.5, # randomly drop out input tokens
37
+ token_dropout_range: float = 0.5, # randomly drop out input tokens
38
+ n_codebooks: int = 1, # number of codebooks
39
+ quantizer_dropout: float = 0.0, # dropout for quantizer
40
+ f0_condition: bool = False,
41
+ n_f0_bins: int = 512,
42
+ ):
43
+ super().__init__()
44
+ self.sampling_ratios = sampling_ratios
45
+ out_channels = out_channels or channels
46
+ model = nn.ModuleList([])
47
+ if len(sampling_ratios) > 0:
48
+ for _ in sampling_ratios:
49
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
50
+ norm = nn.GroupNorm(groups, channels)
51
+ act = nn.Mish()
52
+ model.extend([module, norm, act])
53
+ model.append(
54
+ nn.Conv1d(channels, out_channels, 1, 1)
55
+ )
56
+ self.model = nn.Sequential(*model)
57
+ self.embedding = nn.Embedding(codebook_size, channels)
58
+ self.is_discrete = is_discrete
59
+
60
+ self.mask_token = nn.Parameter(torch.zeros(1, channels))
61
+
62
+ self.n_codebooks = n_codebooks
63
+ if n_codebooks > 1:
64
+ self.extra_codebooks = nn.ModuleList([
65
+ nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
66
+ ])
67
+ self.token_dropout_prob = token_dropout_prob
68
+ self.token_dropout_range = token_dropout_range
69
+ self.quantizer_dropout = quantizer_dropout
70
+
71
+ if f0_condition:
72
+ self.f0_embedding = nn.Embedding(n_f0_bins, channels)
73
+ self.f0_condition = f0_condition
74
+ self.n_f0_bins = n_f0_bins
75
+ self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
76
+ self.f0_mask = nn.Parameter(torch.zeros(1, channels))
77
+ else:
78
+ self.f0_condition = False
79
+
80
+ def forward(self, x, ylens=None, n_quantizers=None, f0=None):
81
+ # apply token drop
82
+ if self.training:
83
+ n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
84
+ dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
85
+ n_dropout = int(x.shape[0] * self.quantizer_dropout)
86
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
87
+ n_quantizers = n_quantizers.to(x.device)
88
+ # decide whether to drop for each sample in batch
89
+ else:
90
+ n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
91
+ if self.is_discrete:
92
+ if self.n_codebooks > 1:
93
+ assert len(x.size()) == 3
94
+ x_emb = self.embedding(x[:, 0])
95
+ for i, emb in enumerate(self.extra_codebooks):
96
+ x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
97
+ x = x_emb
98
+ elif self.n_codebooks == 1:
99
+ if len(x.size()) == 2:
100
+ x = self.embedding(x)
101
+ else:
102
+ x = self.embedding(x[:, 0])
103
+ # x in (B, T, D)
104
+ mask = sequence_mask(ylens).unsqueeze(-1)
105
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
106
+ if self.f0_condition:
107
+ if f0 is None:
108
+ x = x + self.f0_mask.unsqueeze(-1)
109
+ else:
110
+ # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
111
+ quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
112
+ quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
113
+ f0_emb = self.f0_embedding(quantized_f0)
114
+ f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
115
+ x = x + f0_emb
116
+ out = self.model(x).transpose(1, 2).contiguous()
117
+ olens = ylens
118
+ return out * mask, olens
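The newly added f0_to_coarse replaces the earlier torch.bucketize quantizer with mel-scale binning over roughly [50, 1100] Hz: unvoiced frames (f0 = 0) land in bin 1 and voiced frames spread across bins 2 … f0_bin - 1. A small sketch of the mapping (exact bins depend on rounding):

```python
import torch

from modules.length_regulator import f0_to_coarse

f0 = torch.tensor([0.0, 50.0, 440.0, 1100.0])  # Hz; 0.0 marks unvoiced frames
print(f0_to_coarse(f0, 512))  # roughly tensor([  1,   1, 245, 511])
```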
modules/rmvpe.py CHANGED
@@ -1,600 +1,600 @@
+ from io import BytesIO
+ import os
+ from typing import List, Optional, Tuple
+ import numpy as np
+ import torch
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from librosa.util import normalize, pad_center, tiny
+ from scipy.signal import get_window
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class STFT(torch.nn.Module):
+     def __init__(
+         self, filter_length=1024, hop_length=512, win_length=None, window="hann"
+     ):
+         """
+         This module implements an STFT using 1D convolution and 1D transpose convolutions.
+         This is a bit tricky, so there are some cases that probably won't work, as working
+         out the same sizes before and after in all overlap-add setups is tough. Right now,
+         this code should work with hop lengths that are half the filter length (50% overlap
+         between frames).
+
+         Keyword Arguments:
+             filter_length {int} -- Length of filters used (default: {1024})
+             hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
+             win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
+                 equals the filter length). (default: {None})
+             window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
+                 (default: {'hann'})
+         """
+         super(STFT, self).__init__()
+         self.filter_length = filter_length
+         self.hop_length = hop_length
+         self.win_length = win_length if win_length else filter_length
+         self.window = window
+         self.forward_transform = None
+         self.pad_amount = int(self.filter_length / 2)
+         fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+         cutoff = int((self.filter_length / 2 + 1))
+         fourier_basis = np.vstack(
+             [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+         )
+         forward_basis = torch.FloatTensor(fourier_basis)
+         inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))
+
+         assert filter_length >= self.win_length
+         # get window and zero center pad it to filter_length
+         fft_window = get_window(window, self.win_length, fftbins=True)
+         fft_window = pad_center(fft_window, size=filter_length)
+         fft_window = torch.from_numpy(fft_window).float()
+
+         # window the bases
+         forward_basis *= fft_window
+         inverse_basis = (inverse_basis.T * fft_window).T
+
+         self.register_buffer("forward_basis", forward_basis.float())
+         self.register_buffer("inverse_basis", inverse_basis.float())
+         self.register_buffer("fft_window", fft_window.float())
+
+     def transform(self, input_data, return_phase=False):
+         """Take input data (audio) to the STFT domain.
+
+         Arguments:
+             input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+         Returns:
+             magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                 num_frequencies, num_frames)
+             phase {tensor} -- Phase of STFT with shape (num_batch,
+                 num_frequencies, num_frames)
+         """
+         input_data = F.pad(
+             input_data,
+             (self.pad_amount, self.pad_amount),
+             mode="reflect",
+         )
+         forward_transform = input_data.unfold(
+             1, self.filter_length, self.hop_length
+         ).permute(0, 2, 1)
+         forward_transform = torch.matmul(self.forward_basis, forward_transform)
+         cutoff = int((self.filter_length / 2) + 1)
+         real_part = forward_transform[:, :cutoff, :]
+         imag_part = forward_transform[:, cutoff:, :]
+         magnitude = torch.sqrt(real_part**2 + imag_part**2)
+         if return_phase:
+             phase = torch.atan2(imag_part.data, real_part.data)
+             return magnitude, phase
+         else:
+             return magnitude
+
+     def inverse(self, magnitude, phase):
+         """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
+         by the `transform` function.
+
+         Arguments:
+             magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                 num_frequencies, num_frames)
+             phase {tensor} -- Phase of STFT with shape (num_batch,
+                 num_frequencies, num_frames)
+
+         Returns:
+             inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
+                 shape (num_batch, num_samples)
+         """
+         cat = torch.cat(
+             [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+         )
+         fold = torch.nn.Fold(
+             output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
+             kernel_size=(1, self.filter_length),
+             stride=(1, self.hop_length),
+         )
+         inverse_transform = torch.matmul(self.inverse_basis, cat)
+         inverse_transform = fold(inverse_transform)[
+             :, 0, 0, self.pad_amount : -self.pad_amount
+         ]
+         window_square_sum = (
+             self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
+         )
+         window_square_sum = fold(window_square_sum)[
+             :, 0, 0, self.pad_amount : -self.pad_amount
+         ]
+         inverse_transform /= window_square_sum
+         return inverse_transform
+
+     def forward(self, input_data):
+         """Take input data (audio) to the STFT domain and then back to audio.
+
+         Arguments:
+             input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+         Returns:
+             reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
+                 shape (num_batch, num_samples)
+         """
+         self.magnitude, self.phase = self.transform(input_data, return_phase=True)
+         reconstruction = self.inverse(self.magnitude, self.phase)
+         return reconstruction
+
+
+ from time import time as ttime
+
+
+ class BiGRU(nn.Module):
+     def __init__(self, input_features, hidden_features, num_layers):
+         super(BiGRU, self).__init__()
+         self.gru = nn.GRU(
+             input_features,
+             hidden_features,
+             num_layers=num_layers,
+             batch_first=True,
+             bidirectional=True,
+         )
+
+     def forward(self, x):
+         return self.gru(x)[0]
+
+
+ class ConvBlockRes(nn.Module):
+     def __init__(self, in_channels, out_channels, momentum=0.01):
+         super(ConvBlockRes, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=(1, 1),
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels, momentum=momentum),
+             nn.ReLU(),
+             nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=(1, 1),
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels, momentum=momentum),
+             nn.ReLU(),
+         )
+         # self.shortcut: Optional[nn.Module] = None
+         if in_channels != out_channels:
+             self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
+
+     def forward(self, x: torch.Tensor):
+         if not hasattr(self, "shortcut"):
+             return self.conv(x) + x
+         else:
+             return self.conv(x) + self.shortcut(x)
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         in_size,
+         n_encoders,
+         kernel_size,
+         n_blocks,
+         out_channels=16,
+         momentum=0.01,
+     ):
+         super(Encoder, self).__init__()
+         self.n_encoders = n_encoders
+         self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
+         self.layers = nn.ModuleList()
+         self.latent_channels = []
+         for i in range(self.n_encoders):
+             self.layers.append(
+                 ResEncoderBlock(
+                     in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
+                 )
+             )
+             self.latent_channels.append([out_channels, in_size])
+             in_channels = out_channels
+             out_channels *= 2
+             in_size //= 2
+         self.out_size = in_size
+         self.out_channel = out_channels
+
+     def forward(self, x: torch.Tensor):
+         concat_tensors: List[torch.Tensor] = []
+         x = self.bn(x)
+         for i, layer in enumerate(self.layers):
+             t, x = layer(x)
+             concat_tensors.append(t)
+         return x, concat_tensors
+
+
+ class ResEncoderBlock(nn.Module):
+     def __init__(
+         self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
+     ):
+         super(ResEncoderBlock, self).__init__()
+         self.n_blocks = n_blocks
+         self.conv = nn.ModuleList()
+         self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
+         for i in range(n_blocks - 1):
+             self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
+         self.kernel_size = kernel_size
+         if self.kernel_size is not None:
+             self.pool = nn.AvgPool2d(kernel_size=kernel_size)
+
+     def forward(self, x):
+         for i, conv in enumerate(self.conv):
+             x = conv(x)
+         if self.kernel_size is not None:
+             return x, self.pool(x)
+         else:
+             return x
+
+
+ class Intermediate(nn.Module):
+     def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
+         super(Intermediate, self).__init__()
+         self.n_inters = n_inters
+         self.layers = nn.ModuleList()
+         self.layers.append(
+             ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
+         )
+         for i in range(self.n_inters - 1):
+             self.layers.append(
+                 ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
+             )
+
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = layer(x)
+         return x
+
+
+ class ResDecoderBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
+         super(ResDecoderBlock, self).__init__()
+         out_padding = (0, 1) if stride == (1, 2) else (1, 1)
+         self.n_blocks = n_blocks
+         self.conv1 = nn.Sequential(
+             nn.ConvTranspose2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=stride,
+                 padding=(1, 1),
+                 output_padding=out_padding,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels, momentum=momentum),
+             nn.ReLU(),
+         )
+         self.conv2 = nn.ModuleList()
+         self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
+         for i in range(n_blocks - 1):
+             self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
+
+     def forward(self, x, concat_tensor):
+         x = self.conv1(x)
+         x = torch.cat((x, concat_tensor), dim=1)
+         for i, conv2 in enumerate(self.conv2):
+             x = conv2(x)
+         return x
+
+
+ class Decoder(nn.Module):
+     def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
+         super(Decoder, self).__init__()
+         self.layers = nn.ModuleList()
+         self.n_decoders = n_decoders
+         for i in range(self.n_decoders):
+             out_channels = in_channels // 2
+             self.layers.append(
+                 ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
+             )
+             in_channels = out_channels
+
+     def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
+         for i, layer in enumerate(self.layers):
+             x = layer(x, concat_tensors[-1 - i])
+         return x
+
+
+ class DeepUnet(nn.Module):
+     def __init__(
+         self,
+         kernel_size,
+         n_blocks,
+         en_de_layers=5,
+         inter_layers=4,
+         in_channels=1,
+         en_out_channels=16,
+     ):
+         super(DeepUnet, self).__init__()
+         self.encoder = Encoder(
+             in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
+         )
+         self.intermediate = Intermediate(
+             self.encoder.out_channel // 2,
+             self.encoder.out_channel,
+             inter_layers,
+             n_blocks,
+         )
+         self.decoder = Decoder(
+             self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x, concat_tensors = self.encoder(x)
+         x = self.intermediate(x)
+         x = self.decoder(x, concat_tensors)
+         return x
+
+
+ class E2E(nn.Module):
+     def __init__(
+         self,
+         n_blocks,
+         n_gru,
+         kernel_size,
+         en_de_layers=5,
+         inter_layers=4,
+         in_channels=1,
+         en_out_channels=16,
+     ):
+         super(E2E, self).__init__()
+         self.unet = DeepUnet(
+             kernel_size,
+             n_blocks,
+             en_de_layers,
+             inter_layers,
+             in_channels,
+             en_out_channels,
+         )
+         self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+         if n_gru:
+             self.fc = nn.Sequential(
+                 BiGRU(3 * 128, 256, n_gru),
+                 nn.Linear(512, 360),
+                 nn.Dropout(0.25),
+                 nn.Sigmoid(),
+             )
+         else:
+             # NOTE: upstream wrote nn.Linear(3 * nn.N_MELS, nn.N_CLASS) here, but
+             # torch.nn defines no N_MELS/N_CLASS; 128 mel bins and 360 pitch
+             # classes match the GRU branch above.
+             self.fc = nn.Sequential(
+                 nn.Linear(3 * 128, 360), nn.Dropout(0.25), nn.Sigmoid()
+             )
+
+     def forward(self, mel):
+         mel = mel.transpose(-1, -2).unsqueeze(1)
+         x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+         x = self.fc(x)
+         return x
+
+
+ from librosa.filters import mel
+
+
+ class MelSpectrogram(torch.nn.Module):
+     def __init__(
+         self,
+         is_half,
+         n_mel_channels,
+         sampling_rate,
+         win_length,
+         hop_length,
+         n_fft=None,
+         mel_fmin=0,
+         mel_fmax=None,
+         clamp=1e-5,
+     ):
+         super().__init__()
+         n_fft = win_length if n_fft is None else n_fft
+         self.hann_window = {}
+         mel_basis = mel(
+             sr=sampling_rate,
+             n_fft=n_fft,
+             n_mels=n_mel_channels,
+             fmin=mel_fmin,
+             fmax=mel_fmax,
+             htk=True,
+         )
+         mel_basis = torch.from_numpy(mel_basis).float()
+         self.register_buffer("mel_basis", mel_basis)
+         self.n_fft = n_fft  # already resolved to win_length above if None
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.sampling_rate = sampling_rate
+         self.n_mel_channels = n_mel_channels
+         self.clamp = clamp
+         self.is_half = is_half
+
+     def forward(self, audio, keyshift=0, speed=1, center=True):
+         factor = 2 ** (keyshift / 12)
+         n_fft_new = int(np.round(self.n_fft * factor))
+         win_length_new = int(np.round(self.win_length * factor))
+         hop_length_new = int(np.round(self.hop_length * speed))
+         keyshift_key = str(keyshift) + "_" + str(audio.device)
+         if keyshift_key not in self.hann_window:
+             self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
+                 audio.device
+             )
+         if "privateuseone" in str(audio.device):
+             if not hasattr(self, "stft"):
+                 self.stft = STFT(
+                     filter_length=n_fft_new,
+                     hop_length=hop_length_new,
+                     win_length=win_length_new,
+                     window="hann",
+                 ).to(audio.device)
+             magnitude = self.stft.transform(audio)
+         else:
+             fft = torch.stft(
+                 audio,
+                 n_fft=n_fft_new,
+                 hop_length=hop_length_new,
+                 win_length=win_length_new,
+                 window=self.hann_window[keyshift_key],
+                 center=center,
+                 return_complex=True,
+             )
+             magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+         if keyshift != 0:
+             size = self.n_fft // 2 + 1
+             resize = magnitude.size(1)
+             if resize < size:
+                 magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
+             magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
+         mel_output = torch.matmul(self.mel_basis, magnitude)
+         if self.is_half:
+             mel_output = mel_output.half()
+         log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
+         return log_mel_spec
+
+
+ class RMVPE:
+     def __init__(self, model_path: str, is_half, device=None, use_jit=False):
+         self.resample_kernel = {}
+         self.is_half = is_half
+         if device is None:
+             device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         self.device = device
+         self.mel_extractor = MelSpectrogram(
+             is_half, 128, 16000, 1024, 160, None, 30, 8000
+         ).to(device)
+         if "privateuseone" in str(device):
+             import onnxruntime as ort
+
+             ort_session = ort.InferenceSession(
+                 "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
+                 providers=["DmlExecutionProvider"],
+             )
+             self.model = ort_session
+         else:
+             if str(self.device) == "cuda":
+                 self.device = torch.device("cuda:0")
+
+             def get_default_model():
+                 model = E2E(4, 1, (2, 2))
+                 ckpt = torch.load(model_path, map_location="cpu")
+                 model.load_state_dict(ckpt)
+                 model.eval()
+                 if is_half:
+                     model = model.half()
+                 else:
+                     model = model.float()
+                 return model
+
+             self.model = get_default_model()
+             self.model = self.model.to(device)
+         cents_mapping = 20 * np.arange(360) + 1997.3794084376191
+         self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368
+
+     def mel2hidden(self, mel):
+         with torch.no_grad():
+             n_frames = mel.shape[-1]
+             n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
+             if n_pad > 0:
+                 mel = F.pad(mel, (0, n_pad), mode="constant")
+             if "privateuseone" in str(self.device):
+                 onnx_input_name = self.model.get_inputs()[0].name
+                 onnx_outputs_names = self.model.get_outputs()[0].name
+                 hidden = self.model.run(
+                     [onnx_outputs_names],
+                     input_feed={onnx_input_name: mel.cpu().numpy()},
+                 )[0]
+             else:
+                 mel = mel.half() if self.is_half else mel.float()
+                 hidden = self.model(mel)
+             return hidden[:, :n_frames]
+
+     def decode(self, hidden, thred=0.03):
+         cents_pred = self.to_local_average_cents(hidden, thred=thred)
+         f0 = 10 * (2 ** (cents_pred / 1200))
+         f0[f0 == 10] = 0
+         # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
+         return f0
+
+     def infer_from_audio(self, audio, thred=0.03):
+         # torch.cuda.synchronize()
+         # t0 = ttime()
+         if not torch.is_tensor(audio):
+             audio = torch.from_numpy(audio)
+         mel = self.mel_extractor(
+             audio.float().to(self.device).unsqueeze(0), center=True
+         )
+         # torch.cuda.synchronize()
+         # t1 = ttime()
+         hidden = self.mel2hidden(mel)
+         # torch.cuda.synchronize()
+         # t2 = ttime()
+         if "privateuseone" not in str(self.device):
+             hidden = hidden.squeeze(0).cpu().numpy()
+         else:
+             hidden = hidden[0]
+         if self.is_half:
+             hidden = hidden.astype("float32")
+
+         f0 = self.decode(hidden, thred=thred)
+         # torch.cuda.synchronize()
+         # t3 = ttime()
+         # print("rmvpe:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t3 - t0))
+         return f0
+
+     def to_local_average_cents(self, salience, thred=0.05):
+         # t0 = ttime()
+         center = np.argmax(salience, axis=1)  # (n_frames,) index of peak bin
+         salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
+         # t1 = ttime()
+         center += 4
+         todo_salience = []
+         todo_cents_mapping = []
+         starts = center - 4
+         ends = center + 5
+         for idx in range(salience.shape[0]):
+             todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
+             todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
+         # t2 = ttime()
+         todo_salience = np.array(todo_salience)  # (n_frames, 9)
+         todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
+         product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
+         weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
+         divided = product_sum / weight_sum  # (n_frames,) weighted average in cents
+         # t3 = ttime()
+         maxx = np.max(salience, axis=1)  # (n_frames,) peak salience per frame
+         divided[maxx <= thred] = 0
+         # t4 = ttime()
+         # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+         return divided
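
End to end, the wrapper above is a three-stage pipeline: a 128-bin log-mel spectrogram at 16 kHz (hop 160, i.e. 10 ms frames), the DeepUnet + BiGRU salience network over 360 twenty-cent pitch bins, and a decoder that takes a salience-weighted average of cents in a 9-bin window around each frame's argmax and maps it to Hertz via f0 = 10 * 2 ** (cents / 1200), zeroing frames whose peak salience falls below thred as unvoiced. A minimal usage sketch follows; the checkpoint path and input file are placeholders, not assets shipped in this commit:

import librosa
from modules.rmvpe import RMVPE

# hypothetical local RMVPE checkpoint; obtained separately
rmvpe = RMVPE("rmvpe.pt", is_half=False, device="cpu")

audio, _ = librosa.load("input.wav", sr=16000, mono=True)  # float32 mono at 16 kHz
f0 = rmvpe.infer_from_audio(audio, thred=0.03)  # numpy array, one Hz value per 10 ms frame
voiced = f0[f0 > 0]  # frames left at 0 are unvoiced
print(f0.shape, voiced[:5])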