Spaces: Running on Zero
alibabasglab committed: Upload 32 files
- models/mossformer2_sr/__init__.py +0 -0
- models/mossformer2_sr/__pycache__/__init__.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/__init__.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/conv_module.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/conv_module.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/env.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/fsmn.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/fsmn.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/generator.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/layer_norm.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/layer_norm.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_block.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_block.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_sr_wrapper.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer_block.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/snake.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/utils.cpython-312.pyc +0 -0
- models/mossformer2_sr/conv_module.py +388 -0
- models/mossformer2_sr/env.py +15 -0
- models/mossformer2_sr/fsmn.py +214 -0
- models/mossformer2_sr/generator.py +448 -0
- models/mossformer2_sr/layer_norm.py +126 -0
- models/mossformer2_sr/mossformer2.py +711 -0
- models/mossformer2_sr/mossformer2_block.py +735 -0
- models/mossformer2_sr/mossformer2_sr_wrapper.py +52 -0
- models/mossformer2_sr/snake.py +33 -0
- models/mossformer2_sr/utils.py +37 -0
models/mossformer2_sr/__init__.py
ADDED
File without changes
models/mossformer2_sr/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (198 Bytes)

models/mossformer2_sr/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (193 Bytes)

models/mossformer2_sr/__pycache__/conv_module.cpython-312.pyc
ADDED
Binary file (20.4 kB)

models/mossformer2_sr/__pycache__/conv_module.cpython-38.pyc
ADDED
Binary file (13.9 kB)

models/mossformer2_sr/__pycache__/env.cpython-312.pyc
ADDED
Binary file (1.19 kB)

models/mossformer2_sr/__pycache__/fsmn.cpython-312.pyc
ADDED
Binary file (12.4 kB)

models/mossformer2_sr/__pycache__/fsmn.cpython-38.pyc
ADDED
Binary file (8.51 kB)

models/mossformer2_sr/__pycache__/generator.cpython-312.pyc
ADDED
Binary file (24.6 kB)

models/mossformer2_sr/__pycache__/layer_norm.cpython-312.pyc
ADDED
Binary file (6.59 kB)

models/mossformer2_sr/__pycache__/layer_norm.cpython-38.pyc
ADDED
Binary file (4.24 kB)

models/mossformer2_sr/__pycache__/mossformer.cpython-38.pyc
ADDED
Binary file (16.3 kB)

models/mossformer2_sr/__pycache__/mossformer2.cpython-312.pyc
ADDED
Binary file (22.8 kB)

models/mossformer2_sr/__pycache__/mossformer2.cpython-38.pyc
ADDED
Binary file (15.9 kB)

models/mossformer2_sr/__pycache__/mossformer2_block.cpython-312.pyc
ADDED
Binary file (30.8 kB)

models/mossformer2_sr/__pycache__/mossformer2_block.cpython-38.pyc
ADDED
Binary file (23.5 kB)

models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-312.pyc
ADDED
Binary file (4.05 kB)

models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-38.pyc
ADDED
Binary file (3.6 kB)

models/mossformer2_sr/__pycache__/mossformer2_sr_wrapper.cpython-312.pyc
ADDED
Binary file (2.36 kB)

models/mossformer2_sr/__pycache__/mossformer_block.cpython-38.pyc
ADDED
Binary file (21.2 kB)

models/mossformer2_sr/__pycache__/snake.cpython-312.pyc
ADDED
Binary file (2.24 kB)

models/mossformer2_sr/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (2.46 kB)
models/mossformer2_sr/conv_module.py
ADDED
@@ -0,0 +1,388 @@
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.init as init
import torch.nn.functional as F

EPS = 1e-8

class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Arguments
    ---------
    dim : (int or list or torch.Size)
        Input shape from an expected input of size.
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x = N x C x K x S or N x C x L
        # gln: mean, var are N x 1 x 1
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x: N x C x K x S or N x C x L
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1).contiguous()
            # N x K x S x C == only channel norm
            x = super().forward(x)
            # N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            x = torch.transpose(x, 1, 2)
            # N x L x C == only channel norm
            x = super().forward(x)
            # N x C x L
            x = torch.transpose(x, 1, 2)
        return x


def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type."""
    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    if norm == "cln":
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == "ln":
        return nn.GroupNorm(1, dim, eps=1e-8)
    else:
        return nn.BatchNorm1d(dim)


class Swish(nn.Module):
    """
    Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks
    applied to a variety of challenging domains such as image classification and machine translation.
    """
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs * inputs.sigmoid()


class GLU(nn.Module):
    """
    The gating mechanism is called Gated Linear Units (GLU), first introduced for natural language processing
    in the paper "Language Modeling with Gated Convolutional Networks".
    """
    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()


class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)


class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weights are initialized with Xavier initialization and biases are initialized to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature as depthwise convolution.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input sequence
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    """
    A conv1d with kernel size == 1 is termed in the literature as pointwise convolution.
    This operation is often used to match dimensions.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input sequence
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by pointwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConvModule(nn.Module):
    """
    Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the
    convolution to aid training of deep models.
    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 17
        dropout_p (float, optional): probability of dropout
    Inputs: inputs
        inputs (batch, time, dim): Tensor containing the input sequence
    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the convolution module.
    """
    def __init__(
            self,
            in_channels: int,
            kernel_size: int = 17,
            expansion_factor: int = 2,
            dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs + self.sequential(inputs).transpose(1, 2)


class ConvModule_Dilated(nn.Module):
    """
    Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the
    convolution to aid training of deep models.
    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 17
        dropout_p (float, optional): probability of dropout
    Inputs: inputs
        inputs (batch, time, dim): Tensor containing the input sequence
    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the convolution module.
    """
    def __init__(
            self,
            in_channels: int,
            kernel_size: int = 17,
            expansion_factor: int = 2,
            dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule_Dilated, self).__init__()  # fixed: referenced the undefined name ConvModule_Gating
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs + self.sequential(inputs).transpose(1, 2)


class DilatedDenseNet(nn.Module):
    def __init__(self, depth=4, lorder=20, in_channels=64):
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
        self.twidth = lorder * 2 - 1
        self.kernel_size = (self.twidth, 1)
        for i in range(self.depth):
            dil = 2 ** i
            pad_length = lorder + (dil - 1) * (lorder - 1) - 1
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1), groups=self.in_channels, bias=False))
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x_per = x.permute(0, 3, 2, 1)
        skip = x_per
        for i in range(self.depth):
            out = getattr(self, 'pad{}'.format(i + 1))(skip)
            out = getattr(self, 'conv{}'.format(i + 1))(out)
            out = getattr(self, 'norm{}'.format(i + 1))(out)
            out = getattr(self, 'prelu{}'.format(i + 1))(out)
            skip = torch.cat([out, skip], dim=1)
        out1 = out.permute(0, 3, 2, 1)
        return out1.squeeze(1)


class FFConvM_Dilated(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass=nn.LayerNorm,
        dropout=0.1
    ):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            DilatedDenseNet(depth=2, lorder=17, in_channels=dim_out),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
    ):
        output = self.mdl(x)
        return output
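The modules above are small building blocks reused by the MossFormer2 blocks. The following is a minimal shape-check sketch, assuming the repository root is on PYTHONPATH; it is not part of the commit, and the tensor sizes are illustrative only.

# Illustrative only: exercises a few modules from conv_module.py with made-up sizes.
import torch
from models.mossformer2_sr.conv_module import GlobalLayerNorm, select_norm, GLU, ConvModule

x = torch.randn(2, 64, 100)                  # [batch, channels, length]
gln = GlobalLayerNorm(64, shape=3)           # global layer norm for a 3-D tensor
print(gln(x).shape)                          # torch.Size([2, 64, 100])

cln = select_norm("cln", 64, 3)              # cumulative (channel-wise) layer norm
print(cln(x).shape)                          # torch.Size([2, 64, 100])

glu = GLU(dim=1)                             # gating halves the channel dimension
print(glu(x).shape)                          # torch.Size([2, 32, 100])

conv = ConvModule(in_channels=512, kernel_size=17)   # residual depthwise conv block
seq = torch.randn(2, 100, 512)               # [batch, time, dim]
print(conv(seq).shape)                       # torch.Size([2, 100, 512])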
models/mossformer2_sr/env.py
ADDED
@@ -0,0 +1,15 @@
import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
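AttrDict is what lets the vocoder configuration be read with attribute access (for example h.upsample_initial_channel) in generator.py below, while build_env snapshots the config next to the checkpoints. A short usage sketch; the config file name, key, and output directory here are assumptions for illustration, not values from this commit.

# Illustrative only: load a JSON config into an AttrDict and copy it into a run directory.
import json
from models.mossformer2_sr.env import AttrDict, build_env

with open("config.json") as f:                        # hypothetical config file
    h = AttrDict(json.loads(f.read()))

print(h["upsample_initial_channel"] == h.upsample_initial_channel)  # dict- and attribute-style access agree
build_env("config.json", "config.json", "checkpoints/")             # copies config.json into checkpoints/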
models/mossformer2_sr/fsmn.py
ADDED
@@ -0,0 +1,214 @@
import torch.nn as nn
import torch.nn.functional as F
import torch as th
from torch.nn.parameter import Parameter
import numpy as np
import os

class UniDeepFsmn(nn.Module):
    """
    UniDeepFsmn is a neural network module that implements a single-deep feedforward sequence memory network (FSMN).

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        lorder (int): Length of the order for the convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to hidden size.
        project (nn.Linear): Linear layer to project hidden features to output dimensions.
        conv1 (nn.Conv2d): Convolutional layer for processing the output in a grouped manner.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize the layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim, bias=False)  # Convolution layer

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: The output tensor of the same shape as input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad for causal convolution
        out = x_per + self.conv1(y)  # Add original input to convolution output
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions
        return input + out1.squeeze()  # Return enhanced input


class UniDeepFsmn_dual(nn.Module):
    """
    UniDeepFsmn_dual is a neural network module that implements a dual-deep feedforward sequence memory network (FSMN).

    This class extends the UniDeepFsmn by adding a second convolution layer for richer feature extraction.

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        lorder (int): Length of the order for the convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to hidden size.
        project (nn.Linear): Linear layer to project hidden features to output dimensions.
        conv1 (nn.Conv2d): First convolutional layer for processing the output.
        conv2 (nn.Conv2d): Second convolutional layer for further processing the features.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn_dual, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize the layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim, bias=False)  # First convolution layer
        self.conv2 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim // 4, bias=False)  # Second convolution layer

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn_dual model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: The output tensor of the same shape as input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad for causal convolution
        conv1_out = x_per + self.conv1(y)  # Add original input to first convolution output
        z = F.pad(conv1_out, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad for second convolution
        out = conv1_out + self.conv2(z)  # Add output of second convolution
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions
        return input + out1.squeeze()  # Return enhanced input


class DilatedDenseNet(nn.Module):
    """
    DilatedDenseNet implements a dense network structure with dilated convolutions.

    This architecture enables wider receptive fields while maintaining a lower number of parameters.
    It consists of multiple convolutional layers with dilation rates that increase at each layer.

    Attributes:
        depth (int): Number of convolutional layers in the network.
        in_channels (int): Number of input channels for the first layer.
        pad (nn.ConstantPad2d): Padding layer to maintain dimensions.
        twidth (int): Width of the kernel used in convolution.
        kernel_size (tuple): Kernel size for convolution operations.
    """

    def __init__(self, depth=4, lorder=20, in_channels=64):
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)  # Padding for the input
        self.twidth = lorder * 2 - 1  # Width of the kernel
        self.kernel_size = (self.twidth, 1)  # Kernel size for convolutions

        # Initialize layers dynamically based on depth
        for i in range(self.depth):
            dil = 2 ** i  # Calculate dilation rate
            pad_length = lorder + (dil - 1) * (lorder - 1) - 1  # Calculate padding length
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))  # Padding for dilation
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1), groups=self.in_channels, bias=False))  # Convolution layer with dilation
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))  # Normalization layer
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))  # Activation layer

    def forward(self, x):
        """
        Forward pass for the DilatedDenseNet model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after applying dense layers.
        """
        skip = x  # Initialize skip connection
        for i in range(self.depth):
            out = getattr(self, 'pad{}'.format(i + 1))(skip)  # Apply padding
            out = getattr(self, 'conv{}'.format(i + 1))(out)  # Apply convolution
            out = getattr(self, 'norm{}'.format(i + 1))(out)  # Apply normalization
            out = getattr(self, 'prelu{}'.format(i + 1))(out)  # Apply PReLU activation
            skip = th.cat([out, skip], dim=1)  # Concatenate the output with the skip connection
        return out  # Return the final output


class UniDeepFsmn_dilated(nn.Module):
    """
    UniDeepFsmn_dilated combines the UniDeepFsmn architecture with a dilated dense network
    to enhance feature extraction while maintaining efficient computation.

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        depth (int): Depth of the dilated dense network.
        lorder (int): Length of the order for the convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to hidden size.
        project (nn.Linear): Linear layer to project hidden features to output dimensions.
        conv (DilatedDenseNet): Instance of the DilatedDenseNet for feature extraction.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None, depth=2):
        super(UniDeepFsmn_dilated, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.depth = depth
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv = DilatedDenseNet(depth=self.depth, lorder=lorder, in_channels=output_dim)  # Dilated dense network for feature extraction

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn_dilated model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: The output tensor of the same shape as input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        out = self.conv(x_per)  # Pass through the dilated dense network
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions

        return input + out1.squeeze()  # Return enhanced input
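All of these FSMN variants are residual, so the output shape matches the input shape. A quick illustrative check, not part of the commit, with made-up sizes:

# Illustrative only: the FSMN blocks act as residual memory layers over [batch, time, features].
import torch
from models.mossformer2_sr.fsmn import UniDeepFsmn, UniDeepFsmn_dilated

x = torch.randn(2, 100, 256)                               # [batch, time, features]
fsmn = UniDeepFsmn(256, 256, lorder=20, hidden_size=512)
print(fsmn(x).shape)                                       # torch.Size([2, 100, 256])

fsmn_dil = UniDeepFsmn_dilated(256, 256, lorder=20, hidden_size=512, depth=2)
print(fsmn_dil(x).shape)                                    # torch.Size([2, 100, 256])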
models/mossformer2_sr/generator.py
ADDED
@@ -0,0 +1,448 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from models.mossformer2_sr.utils import init_weights, get_padding
from models.mossformer2_sr.mossformer2 import MossFormer_MaskNet
from models.mossformer2_sr.snake import Snake1d
from typing import Optional, List, Union, Dict, Tuple
from models.mossformer2_sr.env import AttrDict
import typing
from torchaudio.transforms import Spectrogram, Resample

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)
        self.convs1_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels),
            Snake1d(channels)
        ])
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
        self.convs2_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels),
            Snake1d(channels)
        ])

    def forward(self, x):
        for c1, c2, act1, act2 in zip(self.convs1, self.convs2, self.convs1_activates, self.convs2_activates):
            # Snake activations are used instead of F.leaky_relu(x, LRELU_SLOPE)
            xt = act1(x)
            xt = c1(xt)
            xt = act2(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)
        self.convs_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels)
        ])

    def forward(self, x):
        for c, act in zip(self.convs, self.convs_activates):
            # Snake activation replaces F.leaky_relu(x, LRELU_SLOPE)
            xt = act(x)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        self.snakes = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.snakes.append(Snake1d(h.upsample_initial_channel//(2**i)))
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.snake_post = Snake1d(ch)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            # Snake activation replaces F.leaky_relu(x, LRELU_SLOPE)
            x = self.snakes[i](x)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = self.snake_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i-1](y)
                y_hat = self.meanpools[i-1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorB(nn.Module):
    def __init__(
        self,
        window_length: int,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = (
            (0.0, 0.1),
            (0.1, 0.25),
            (0.25, 0.5),
            (0.5, 0.75),
            (0.75, 1.0),
        ),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        self.spec_fn = Spectrogram(
            n_fft=window_length,
            hop_length=int(window_length * hop_factor),
            win_length=window_length,
            power=None,
        )
        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
                ),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        self.conv_post = weight_norm(
            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
        )

    def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x_bands = self.spectrogram(x.squeeze(1))
        fmap = []
        x = []

        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return x, fmap


# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiBandDiscriminator(nn.Module):
    def __init__(
        self,
        h,
    ):
        """
        Multi-band multi-scale STFT discriminator, with the architecture based on
        https://github.com/descriptinc/descript-audio-codec and the modified code
        adapted from https://github.com/gemelo-ai/vocos.
        """
        super().__init__()
        # fft_sizes (list[int]): Window lengths for the FFT. Defaults to [2048, 1024, 512] if not set in h.
        self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
        self.discriminators = nn.ModuleList(
            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
        )

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[List[torch.Tensor]],
        List[List[torch.Tensor]],
    ]:

        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses


class Mossformer(nn.Module):

    def __init__(self):
        super(Mossformer, self).__init__()
        self.mossformer = MossFormer_MaskNet(in_channels=80, out_channels=512, out_channels_final=80)

    def forward(self, input):
        out = self.mossformer(input)
        return out
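Generator follows the HiFi-GAN recipe with Snake activations in place of LeakyReLU; its constructor reads resblock, upsample_rates, upsample_kernel_sizes, upsample_initial_channel, resblock_kernel_sizes, and resblock_dilation_sizes from the config object. The sketch below is illustrative only: the values are placeholders chosen so the shapes work out, not the configuration shipped with this Space.

# Illustrative only: placeholder HiFi-GAN-style config fed through AttrDict.
import torch
from models.mossformer2_sr.env import AttrDict
from models.mossformer2_sr.generator import Generator

h = AttrDict({
    "resblock": "1",                              # selects ResBlock1
    "upsample_rates": [8, 8, 4, 2],               # 512x total upsampling (assumed)
    "upsample_kernel_sizes": [16, 16, 8, 4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
})

generator = Generator(h)
mel = torch.randn(1, 80, 120)                     # [batch, n_mels, frames]
with torch.no_grad():
    wav = generator(mel)                          # [batch, 1, frames * 512]
print(wav.shape)                                  # torch.Size([1, 1, 61440])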
models/mossformer2_sr/layer_norm.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
#!/usr/bin/env python -u
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
|
5 |
+
|
6 |
+
from __future__ import absolute_import
|
7 |
+
from __future__ import division
|
8 |
+
from __future__ import print_function
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
|
14 |
+
class CLayerNorm(nn.LayerNorm):
|
15 |
+
"""Channel-wise layer normalization."""
|
16 |
+
|
17 |
+
def __init__(self, *args, **kwargs):
|
18 |
+
super(CLayerNorm, self).__init__(*args, **kwargs)
|
19 |
+
|
20 |
+
    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # LayerNorm
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class ILayerNorm(nn.InstanceNorm1d):
    """Channel-wise layer normalization."""

    def __init__(self, *args, **kwargs):
        super(ILayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # LayerNorm
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class GLayerNorm(nn.Module):
    """Global Layer Normalization for TasNet."""

    def __init__(self, channels, eps=1e-5):
        super(GLayerNorm, self).__init__()
        self.eps = eps
        self.norm_dim = channels
        self.gamma = nn.Parameter(torch.Tensor(channels))
        self.beta = nn.Parameter(torch.Tensor(channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # Mean and variance [N, 1, 1]
        mean = torch.mean(sample, (1, 2), keepdim=True)
        var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
        sample = (sample - mean) / torch.sqrt(var + self.eps) * \
            self.gamma + self.beta
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""

    def __init__(self, channel_size):
        super(_LayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size),
                                  requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size),
                                 requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """ Assumes input of size `[batch, channel, *]`. """
        return (self.gamma * normed_x.transpose(1, -1) +
                self.beta).transpose(1, -1)


class GlobLayerNorm(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """ Applies forward pass.
        Works for any input size > 2D.
        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
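A minimal usage sketch of the global-normalization variants defined above (not part of the uploaded file; it assumes the module is importable as models.mossformer2_sr.layer_norm, i.e. the file path shown in this commit):

import torch
from models.mossformer2_sr.layer_norm import GLayerNorm, GlobLayerNorm

x = torch.randn(4, 64, 1000)                # [batch, channels, length]
gln = GLayerNorm(channels=64)               # mean/var over (T, C), per-channel affine
glob_ln = GlobLayerNorm(channel_size=64)    # mean/var over all non-batch dims
print(gln(x).shape, glob_ln(x).shape)       # both keep [4, 64, 1000]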
models/mossformer2_sr/mossformer2.py
ADDED
@@ -0,0 +1,711 @@
"""
modified from https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/dual_path.py
Author: Shengkui Zhao
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from models.mossformer2_sr.mossformer2_block import ScaledSinuEmbedding, MossformerBlock_GFSMN, MossformerBlock


EPS = 1e-8


class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Arguments
    ---------
    dim : (int or list or torch.Size)
        Input shape from an expected input of size.
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x = N x C x K x S or N x C x L
        # N x 1 x 1
        # cln: mean,var N x 1 x K x S
        # gln: mean,var N x 1 x 1
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor size [N, C, K, S] or [N, C, L]
        """
        # x: N x C x K x S or N x C x L
        # N x K x S x C
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1).contiguous()
            # N x K x S x C == only channel norm
            x = super().forward(x)
            # N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            x = torch.transpose(x, 1, 2)
            # N x L x C == only channel norm
            x = super().forward(x)
            # N x C x L
            x = torch.transpose(x, 1, 2)
        return x


def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type."""

    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    if norm == "cln":
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == "ln":
        return nn.GroupNorm(1, dim, eps=1e-8)
    else:
        return nn.BatchNorm1d(dim)


class Encoder(nn.Module):
    """Convolutional Encoder Layer.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = Encoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """

    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
        super(Encoder, self).__init__()
        self.conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            groups=1,
            bias=False,
        )
        self.in_channels = in_channels

    def forward(self, x):
        """Return the encoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, L].
        Return
        ------
        x : torch.Tensor
            Encoded tensor with dimensionality [B, N, T_out].

        where B = Batchsize
              L = Number of timepoints
              N = Number of filters
              T_out = Number of timepoints at the output of the encoder
        """
        # B x L -> B x 1 x L
        if self.in_channels == 1:
            x = torch.unsqueeze(x, dim=1)
        # B x 1 x L -> B x N x T_out
        x = self.conv1d(x)
        x = F.relu(x)

        return x


class Decoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.


    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super(Decoder, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L].
            where, B = Batchsize,
                   N = number of filters
                   L = time points
        """

        if x.dim() not in [2, 3]:
            raise RuntimeError(
                "{} only accepts 2/3-D tensor as input".format(self.__class__.__name__)
            )
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        if torch.squeeze(x).dim() == 1:
            x = torch.squeeze(x, dim=1)
        else:
            x = torch.squeeze(x)
        return x


class IdentityBlock:
    """This block is used when we want to have identity transformation within the Dual_path block.

    Example
    -------
    >>> x = torch.randn(10, 100)
    >>> IB = IdentityBlock()
    >>> xhat = IB(x)
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, x):
        return x


class MossFormerM(nn.Module):
    """This class implements the transformer encoder.

    Arguments
    ---------
    num_blocks : int
        Number of mossformer blocks to include.
    d_model : int
        The dimension of the input embedding.
    attn_dropout : float
        Dropout for the self-attention (Optional).
    group_size : int
        The chunk size.
    query_key_dim : int
        The attention vector dimension.
    expansion_factor : int
        The expansion factor for the linear projection in the conv module.
    causal : bool
        True for causal / False for non-causal.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = MossFormerM(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """
    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.,
        attn_dropout=0.1
    ):
        super().__init__()

        self.mossformerM = MossformerBlock_GFSMN(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        src,
    ):
        """
        Arguments
        ----------
        src : torch.Tensor
            Tensor shape [B, L, N],
            where, B = Batchsize,
                   L = time points
                   N = number of filters
            The sequence to the encoder layer (required).
        src_mask : tensor
            The mask for the src sequence (optional).
        src_key_padding_mask : tensor
            The mask for the src keys per batch (optional).
        """
        output = self.mossformerM(src)
        output = self.norm(output)

        return output


class MossFormerM2(nn.Module):
    """This class implements the transformer encoder.

    Arguments
    ---------
    num_blocks : int
        Number of mossformer blocks to include.
    d_model : int
        The dimension of the input embedding.
    attn_dropout : float
        Dropout for the self-attention (Optional).
    group_size : int
        The chunk size.
    query_key_dim : int
        The attention vector dimension.
    expansion_factor : int
        The expansion factor for the linear projection in the conv module.
    causal : bool
        True for causal / False for non-causal.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = MossFormerM2(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """
    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.,
        attn_dropout=0.1
    ):
        super().__init__()

        self.mossformerM = MossformerBlock(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        src,
    ):
        """
        Arguments
        ----------
        src : torch.Tensor
            Tensor shape [B, L, N],
            where, B = Batchsize,
                   L = time points
                   N = number of filters
            The sequence to the encoder layer (required).
        src_mask : tensor
            The mask for the src sequence (optional).
        src_key_padding_mask : tensor
            The mask for the src keys per batch (optional).
        """
        output = self.mossformerM(src)
        output = self.norm(output)

        return output


class Computation_Block(nn.Module):
    """Computation block for dual-path processing.

    Arguments
    ---------
    num_blocks : int
        Number of mossformer blocks in the intra model.
    out_channels : int
        Dimensionality of the intra model.
    norm : str
        Normalization type.
    skip_around_intra : bool
        Skip connection around the intra layer.

    Example
    ---------
    >>> comp_block = Computation_Block(num_blocks=2, out_channels=64)
    >>> x = torch.randn(10, 64, 100)
    >>> x = comp_block(x)
    >>> x.shape
    torch.Size([10, 64, 100])
    """

    def __init__(
        self,
        num_blocks,
        out_channels,
        norm="ln",
        skip_around_intra=True,
    ):
        super(Computation_Block, self).__init__()

        ## MossFormerM: MossFormer with recurrence (Gated FSMN)
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        ## MossFormerM2: the original MossFormer
        #self.intra_mdl = MossFormerM2(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra

        # Norm
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].


        Return
        ---------
        out: torch.Tensor
            Output tensor of dimension [B, N, S].
            where, B = Batchsize,
                   N = number of filters
                   S = sequence time index
        """
        B, N, S = x.shape
        # intra processing
        # [B, S, N]
        intra = x.permute(0, 2, 1).contiguous()  # .view(B, S, N)

        intra = self.intra_mdl(intra)

        # [B, N, S]
        intra = intra.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            intra = self.intra_norm(intra)

        # [B, N, S]
        if self.skip_around_intra:
            intra = intra + x

        out = intra
        return out


class MossFormer_MaskNet(nn.Module):
    """The dual path model which is the basis for dualpathrnn, sepformer, dptnet.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    out_channels_final : int
        Number of channels produced by the final 1x1 decoder convolution.
    num_blocks : int
        Number of layers of the Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.

    Example
    ---------
    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, 64, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        out_channels_final,
        num_blocks=24,
        norm="ln",
        num_spks=1,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer_MaskNet, self).__init__()
        self.num_spks = num_spks
        self.num_blocks = num_blocks
        self.norm = select_norm(norm, in_channels, 3)
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc

        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )

        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        self.conv1_decoder = nn.Conv1d(out_channels, out_channels_final, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [B, N, S] (the first speaker),
            where, B = Batchsize,
                   N = number of filters
                   S = the number of time frames
        """

        # before each line we indicate the shape after executing the line

        # [B, N, L]
        x = self.norm(x)

        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
            #    x.size(1) ** 0.5)
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            emb = emb.transpose(0, -1)
            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
            x = base + emb

        # [B, N, S]
        #for i in range(self.num_modules):
        #    x = self.dual_mdl[i](x)
        x = self.mdl(x)
        x = self.prelu(x)

        # [B, N*spks, S]
        x = self.conv1d_out(x)
        B, _, S = x.shape

        # [B*spks, N, S]
        x = x.view(B * self.num_spks, -1, S)

        # [B*spks, N, S]
        x = self.output(x) * self.output_gate(x)

        # [B*spks, N, S]
        x = self.conv1_decoder(x)

        # [B, spks, N, S]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)

        # [spks, B, N, S]
        x = x.transpose(0, 1)

        return x[0]


class MossFormer(nn.Module):
    def __init__(
        self,
        in_channels=512,
        out_channels=512,
        num_blocks=24,
        kernel_size=16,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer, self).__init__()
        self.num_spks = num_spks
        self.enc = Encoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=180)
        self.mask_net = MossFormer_MaskNet(
            in_channels=in_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            norm=norm,
            num_spks=num_spks,
            skip_around_intra=skip_around_intra,
            use_global_pos_enc=use_global_pos_enc,
            max_length=max_length,
        )
        self.dec = Decoder(
            in_channels=out_channels,
            out_channels=1,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            bias=False
        )

    def forward(self, input):
        x = self.enc(input)
        mask = self.mask_net(x)
        x = torch.stack([x] * self.num_spks)
        sep_x = x * mask

        # Decoding
        est_source = torch.cat(
            [
                self.dec(sep_x[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            dim=-1,
        )
        T_origin = input.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        out = []
        for spk in range(self.num_spks):
            out.append(est_source[:, :, spk])
        return out
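A rough shape-check sketch for the encoder and mask-net defined above (not part of the upload; it assumes the sibling models.mossformer2_se helpers and rotary_embedding_torch that the block module imports are installed, and the out_channels_final/num_blocks values below are illustrative only):

import torch
from models.mossformer2_sr.mossformer2 import Encoder, MossFormer_MaskNet

enc = Encoder(kernel_size=16, out_channels=512, in_channels=180)
masknet = MossFormer_MaskNet(in_channels=512, out_channels=512,
                             out_channels_final=256, num_blocks=2)

mel = torch.randn(1, 180, 200)   # [B, mel_bins, frames]
feats = enc(mel)                 # [B, 512, T]
out = masknet(feats)             # [B, 256, T] after the final 1x1 conv
print(feats.shape, out.shape)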
models/mossformer2_sr/mossformer2_block.py
ADDED
@@ -0,0 +1,735 @@
1 |
+
"""
|
2 |
+
This source code is modified by Shengkui Zhao based on https://github.com/lucidrains/FLASH-pytorch
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn, einsum
|
9 |
+
from einops import rearrange
|
10 |
+
from rotary_embedding_torch import RotaryEmbedding
|
11 |
+
from models.mossformer2_se.conv_module import ConvModule, GLU, FFConvM_Dilated
|
12 |
+
from models.mossformer2_se.fsmn import UniDeepFsmn, UniDeepFsmn_dilated
|
13 |
+
from torchinfo import summary
|
14 |
+
from models.mossformer2_se.layer_norm import CLayerNorm, GLayerNorm, GlobLayerNorm, ILayerNorm
|
15 |
+
|
16 |
+
# Helper functions
|
17 |
+
|
18 |
+
def identity(t, *args, **kwargs):
|
19 |
+
"""
|
20 |
+
Returns the input tensor unchanged.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
t (torch.Tensor): Input tensor.
|
24 |
+
*args: Additional arguments (ignored).
|
25 |
+
**kwargs: Additional keyword arguments (ignored).
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
torch.Tensor: The input tensor.
|
29 |
+
"""
|
30 |
+
return t
|
31 |
+
|
32 |
+
def append_dims(x, num_dims):
|
33 |
+
"""
|
34 |
+
Adds additional dimensions to the input tensor.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x (torch.Tensor): Input tensor.
|
38 |
+
num_dims (int): Number of dimensions to append.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
torch.Tensor: Tensor with appended dimensions.
|
42 |
+
"""
|
43 |
+
if num_dims <= 0:
|
44 |
+
return x
|
45 |
+
return x.view(*x.shape, *((1,) * num_dims)) # Reshape to append dimensions
|
46 |
+
|
47 |
+
def exists(val):
|
48 |
+
"""
|
49 |
+
Checks if a value exists (is not None).
|
50 |
+
|
51 |
+
Args:
|
52 |
+
val: The value to check.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
bool: True if value exists, False otherwise.
|
56 |
+
"""
|
57 |
+
return val is not None
|
58 |
+
|
59 |
+
def default(val, d):
|
60 |
+
"""
|
61 |
+
Returns a default value if the given value does not exist.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
val: The value to check.
|
65 |
+
d: Default value to return if val does not exist.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
The original value if it exists, otherwise the default value.
|
69 |
+
"""
|
70 |
+
return val if exists(val) else d
|
71 |
+
|
72 |
+
def padding_to_multiple_of(n, mult):
|
73 |
+
"""
|
74 |
+
Calculates the amount of padding needed to make a number a multiple of another.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
n (int): The number to pad.
|
78 |
+
mult (int): The multiple to match.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
int: The padding amount required to make n a multiple of mult.
|
82 |
+
"""
|
83 |
+
remainder = n % mult
|
84 |
+
if remainder == 0:
|
85 |
+
return 0
|
86 |
+
return mult - remainder # Return the required padding
|
87 |
+
|
88 |
+
# Scale Normalization class
|
89 |
+
|
90 |
+
class ScaleNorm(nn.Module):
|
91 |
+
"""
|
92 |
+
ScaleNorm implements a scaled normalization technique for neural network layers.
|
93 |
+
|
94 |
+
Attributes:
|
95 |
+
dim (int): Dimension of the input features.
|
96 |
+
eps (float): Small value to prevent division by zero.
|
97 |
+
g (nn.Parameter): Learnable parameter for scaling.
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self, dim, eps=1e-5):
|
101 |
+
super().__init__()
|
102 |
+
self.scale = dim ** -0.5 # Calculate scale factor
|
103 |
+
self.eps = eps # Set epsilon
|
104 |
+
self.g = nn.Parameter(torch.ones(1)) # Initialize scaling parameter
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
"""
|
108 |
+
Forward pass for the ScaleNorm layer.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
x (torch.Tensor): Input tensor.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
torch.Tensor: Scaled and normalized output tensor.
|
115 |
+
"""
|
116 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale # Compute norm
|
117 |
+
return x / norm.clamp(min=self.eps) * self.g # Normalize and scale
|
118 |
+
|
119 |
+
# Absolute positional encodings class
|
120 |
+
|
121 |
+
class ScaledSinuEmbedding(nn.Module):
|
122 |
+
"""
|
123 |
+
ScaledSinuEmbedding provides sinusoidal positional encodings for inputs.
|
124 |
+
|
125 |
+
Attributes:
|
126 |
+
scale (nn.Parameter): Learnable scale factor for the embeddings.
|
127 |
+
inv_freq (torch.Tensor): Inverse frequency used for sine and cosine calculations.
|
128 |
+
"""
|
129 |
+
|
130 |
+
def __init__(self, dim):
|
131 |
+
super().__init__()
|
132 |
+
self.scale = nn.Parameter(torch.ones(1,)) # Initialize scale
|
133 |
+
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) # Calculate inverse frequency
|
134 |
+
self.register_buffer('inv_freq', inv_freq) # Register as a buffer
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
"""
|
138 |
+
Forward pass for the ScaledSinuEmbedding layer.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
torch.Tensor: Positional encoding tensor of shape (batch_size, sequence_length, dim).
|
145 |
+
"""
|
146 |
+
n, device = x.shape[1], x.device # Extract sequence length and device
|
147 |
+
t = torch.arange(n, device=device).type_as(self.inv_freq) # Create time steps
|
148 |
+
sinu = einsum('i , j -> i j', t, self.inv_freq) # Calculate sine and cosine embeddings
|
149 |
+
emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1) # Concatenate sine and cosine embeddings
|
150 |
+
return emb * self.scale # Scale the embeddings
|
151 |
+
|
152 |
+
class OffsetScale(nn.Module):
|
153 |
+
"""
|
154 |
+
OffsetScale applies learned offsets and scales to the input tensor.
|
155 |
+
|
156 |
+
Attributes:
|
157 |
+
gamma (nn.Parameter): Learnable scale parameter for each head.
|
158 |
+
beta (nn.Parameter): Learnable offset parameter for each head.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(self, dim, heads=1):
|
162 |
+
super().__init__()
|
163 |
+
self.gamma = nn.Parameter(torch.ones(heads, dim)) # Initialize scale parameters
|
164 |
+
self.beta = nn.Parameter(torch.zeros(heads, dim)) # Initialize offset parameters
|
165 |
+
nn.init.normal_(self.gamma, std=0.02) # Normal initialization for gamma
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
"""
|
169 |
+
Forward pass for the OffsetScale layer.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
x (torch.Tensor): Input tensor.
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
List[torch.Tensor]: A list of tensors with applied offsets and scales for each head.
|
176 |
+
"""
|
177 |
+
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta # Apply scaling and offsets
|
178 |
+
return out.unbind(dim=-2) # Unbind heads into a list
|
179 |
+
|
180 |
+
# Feed-Forward Convolutional Module
|
181 |
+
|
182 |
+
class FFConvM(nn.Module):
|
183 |
+
"""
|
184 |
+
FFConvM is a feed-forward convolutional module with normalization and dropout.
|
185 |
+
|
186 |
+
Attributes:
|
187 |
+
dim_in (int): Input dimension of the features.
|
188 |
+
dim_out (int): Output dimension after processing.
|
189 |
+
norm_klass (nn.Module): Normalization class to be used.
|
190 |
+
dropout (float): Dropout probability.
|
191 |
+
"""
|
192 |
+
|
193 |
+
def __init__(
|
194 |
+
self,
|
195 |
+
dim_in,
|
196 |
+
dim_out,
|
197 |
+
norm_klass=nn.LayerNorm,
|
198 |
+
dropout=0.1
|
199 |
+
):
|
200 |
+
super().__init__()
|
201 |
+
self.mdl = nn.Sequential(
|
202 |
+
norm_klass(dim_in), # Normalize input
|
203 |
+
nn.Linear(dim_in, dim_out), # Linear transformation
|
204 |
+
nn.SiLU(), # Activation function
|
205 |
+
ConvModule(dim_out), # Convolution module
|
206 |
+
nn.Dropout(dropout) # Apply dropout
|
207 |
+
)
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
"""
|
211 |
+
Forward pass for the FFConvM module.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
x (torch.Tensor): Input tensor.
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
torch.Tensor: Output tensor after processing.
|
218 |
+
"""
|
219 |
+
output = self.mdl(x) # Pass through the model
|
220 |
+
return output
|
221 |
+
|
222 |
+
class FFM(nn.Module):
|
223 |
+
"""
|
224 |
+
FFM is a feed-forward module with normalization and dropout.
|
225 |
+
|
226 |
+
Attributes:
|
227 |
+
dim_in (int): Input dimension of the features.
|
228 |
+
dim_out (int): Output dimension after processing.
|
229 |
+
norm_klass (nn.Module): Normalization class to be used.
|
230 |
+
dropout (float): Dropout probability.
|
231 |
+
"""
|
232 |
+
|
233 |
+
def __init__(
|
234 |
+
self,
|
235 |
+
dim_in,
|
236 |
+
dim_out,
|
237 |
+
norm_klass=nn.LayerNorm,
|
238 |
+
dropout=0.1
|
239 |
+
):
|
240 |
+
super().__init__()
|
241 |
+
self.mdl = nn.Sequential(
|
242 |
+
norm_klass(dim_in), # Normalize input
|
243 |
+
nn.Linear(dim_in, dim_out), # Linear transformation
|
244 |
+
nn.SiLU(), # Activation function
|
245 |
+
nn.Dropout(dropout) # Apply dropout
|
246 |
+
)
|
247 |
+
|
248 |
+
def forward(self, x):
|
249 |
+
"""
|
250 |
+
Forward pass for the FFM module.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
x (torch.Tensor): Input tensor.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
torch.Tensor: Output tensor after processing.
|
257 |
+
"""
|
258 |
+
output = self.mdl(x) # Pass through the model
|
259 |
+
return output
|
260 |
+
|
261 |
+
class FLASH_ShareA_FFConvM(nn.Module):
|
262 |
+
"""
|
263 |
+
Fast Shared Dual Attention Mechanism with feed-forward convolutional blocks.
|
264 |
+
Published in paper: "MossFormer: Pushing the Performance Limit of Monaural Speech Separation
|
265 |
+
using Gated Single-Head Transformer with Convolution-Augmented Joint Self-Attentions", ICASSP 2023.
|
266 |
+
(https://arxiv.org/abs/2302.11824)
|
267 |
+
|
268 |
+
Args:
|
269 |
+
dim (int): Input dimension.
|
270 |
+
group_size (int, optional): Size of groups for processing. Defaults to 256.
|
271 |
+
query_key_dim (int, optional): Dimension of the query and key. Defaults to 128.
|
272 |
+
expansion_factor (float, optional): Factor to expand the hidden dimension. Defaults to 1.
|
273 |
+
causal (bool, optional): Whether to use causal masking. Defaults to False.
|
274 |
+
dropout (float, optional): Dropout rate. Defaults to 0.1.
|
275 |
+
rotary_pos_emb (optional): Rotary positional embeddings for attention. Defaults to None.
|
276 |
+
norm_klass (callable, optional): Normalization class to use. Defaults to nn.LayerNorm.
|
277 |
+
shift_tokens (bool, optional): Whether to shift tokens for attention calculation. Defaults to True.
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(
|
281 |
+
self,
|
282 |
+
*,
|
283 |
+
dim,
|
284 |
+
group_size=256,
|
285 |
+
query_key_dim=128,
|
286 |
+
expansion_factor=1.,
|
287 |
+
causal=False,
|
288 |
+
dropout=0.1,
|
289 |
+
rotary_pos_emb=None,
|
290 |
+
norm_klass=nn.LayerNorm,
|
291 |
+
shift_tokens=True
|
292 |
+
):
|
293 |
+
super().__init__()
|
294 |
+
hidden_dim = int(dim * expansion_factor)
|
295 |
+
self.group_size = group_size
|
296 |
+
self.causal = causal
|
297 |
+
self.shift_tokens = shift_tokens
|
298 |
+
|
299 |
+
# Initialize positional embeddings, dropout, and projections
|
300 |
+
self.rotary_pos_emb = rotary_pos_emb
|
301 |
+
self.dropout = nn.Dropout(dropout)
|
302 |
+
|
303 |
+
# Feed-forward layers
|
304 |
+
self.to_hidden = FFConvM(
|
305 |
+
dim_in=dim,
|
306 |
+
dim_out=hidden_dim,
|
307 |
+
norm_klass=norm_klass,
|
308 |
+
dropout=dropout,
|
309 |
+
)
|
310 |
+
self.to_qk = FFConvM(
|
311 |
+
dim_in=dim,
|
312 |
+
dim_out=query_key_dim,
|
313 |
+
norm_klass=norm_klass,
|
314 |
+
dropout=dropout,
|
315 |
+
)
|
316 |
+
|
317 |
+
# Offset and scale for query and key
|
318 |
+
self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
|
319 |
+
|
320 |
+
self.to_out = FFConvM(
|
321 |
+
dim_in=dim * 2,
|
322 |
+
dim_out=dim,
|
323 |
+
norm_klass=norm_klass,
|
324 |
+
dropout=dropout,
|
325 |
+
)
|
326 |
+
|
327 |
+
self.gateActivate = nn.Sigmoid()
|
328 |
+
|
329 |
+
def forward(self, x, *, mask=None):
|
330 |
+
"""
|
331 |
+
Forward pass for FLASH layer.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
x (Tensor): Input tensor of shape (batch, seq_len, features).
|
335 |
+
mask (Tensor, optional): Mask for attention. Defaults to None.
|
336 |
+
|
337 |
+
Returns:
|
338 |
+
Tensor: Output tensor after applying attention and projections.
|
339 |
+
"""
|
340 |
+
|
341 |
+
# Pre-normalization step
|
342 |
+
normed_x = x
|
343 |
+
residual = x # Save residual for skip connection
|
344 |
+
|
345 |
+
# Token shifting if enabled
|
346 |
+
if self.shift_tokens:
|
347 |
+
x_shift, x_pass = normed_x.chunk(2, dim=-1)
|
348 |
+
x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
|
349 |
+
normed_x = torch.cat((x_shift, x_pass), dim=-1)
|
350 |
+
|
351 |
+
# Initial projections
|
352 |
+
v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
|
353 |
+
qk = self.to_qk(normed_x)
|
354 |
+
|
355 |
+
# Offset and scale
|
356 |
+
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
|
357 |
+
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
|
358 |
+
|
359 |
+
# Output calculation with gating
|
360 |
+
out = (att_u * v) * self.gateActivate(att_v * u)
|
361 |
+
x = x + self.to_out(out) # Residual connection
|
362 |
+
return x
|
363 |
+
|
364 |
+
def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
|
365 |
+
"""
|
366 |
+
Calculate attention output using quadratic and linear attention mechanisms.
|
367 |
+
|
368 |
+
Args:
|
369 |
+
x (Tensor): Input tensor of shape (batch, seq_len, features).
|
370 |
+
quad_q (Tensor): Quadratic query representation.
|
371 |
+
lin_q (Tensor): Linear query representation.
|
372 |
+
quad_k (Tensor): Quadratic key representation.
|
373 |
+
lin_k (Tensor): Linear key representation.
|
374 |
+
v (Tensor): Value representation.
|
375 |
+
u (Tensor): Additional value representation.
|
376 |
+
mask (Tensor, optional): Mask for attention. Defaults to None.
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
Tuple[Tensor, Tensor]: Attention outputs for v and u.
|
380 |
+
"""
|
381 |
+
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
|
382 |
+
|
383 |
+
# Apply mask to linear keys if provided
|
384 |
+
if exists(mask):
|
385 |
+
lin_mask = rearrange(mask, '... -> ... 1')
|
386 |
+
lin_k = lin_k.masked_fill(~lin_mask, 0.)
|
387 |
+
|
388 |
+
# Rotate queries and keys with rotary positional embeddings
|
389 |
+
if exists(self.rotary_pos_emb):
|
390 |
+
quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
|
391 |
+
|
392 |
+
# Padding for group processing
|
393 |
+
padding = padding_to_multiple_of(n, g)
|
394 |
+
if padding > 0:
|
395 |
+
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value=0.), (quad_q, quad_k, lin_q, lin_k, v, u))
|
396 |
+
mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
|
397 |
+
mask = F.pad(mask, (0, padding), value=False)
|
398 |
+
|
399 |
+
# Group along sequence for attention
|
400 |
+
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
|
401 |
+
|
402 |
+
if exists(mask):
|
403 |
+
mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g)
|
404 |
+
|
405 |
+
# Calculate quadratic attention output
|
406 |
+
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
|
407 |
+
attn = F.relu(sim) ** 2 # ReLU activation
|
408 |
+
attn = self.dropout(attn)
|
409 |
+
|
410 |
+
# Apply mask to attention if provided
|
411 |
+
if exists(mask):
|
412 |
+
attn = attn.masked_fill(~mask, 0.)
|
413 |
+
|
414 |
+
# Apply causal mask if needed
|
415 |
+
if self.causal:
|
416 |
+
causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
|
417 |
+
attn = attn.masked_fill(causal_mask, 0.)
|
418 |
+
|
419 |
+
# Calculate output from attention
|
420 |
+
quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
|
421 |
+
quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
|
422 |
+
|
423 |
+
# Calculate linear attention output
|
424 |
+
if self.causal:
|
425 |
+
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
|
426 |
+
lin_kv = lin_kv.cumsum(dim=1) # Cumulative sum for linear attention
|
427 |
+
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
|
428 |
+
lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
|
429 |
+
|
430 |
+
lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
|
431 |
+
lin_ku = lin_ku.cumsum(dim=1) # Cumulative sum for linear attention
|
432 |
+
lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
|
433 |
+
lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
|
434 |
+
else:
|
435 |
+
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
|
436 |
+
lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
|
437 |
+
|
438 |
+
lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
|
439 |
+
lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
|
440 |
+
|
441 |
+
# Reshape and remove padding from outputs
|
442 |
+
return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v + lin_out_v, quad_out_u + lin_out_u))
|
443 |
+
|
444 |
+
class Gated_FSMN(nn.Module):
|
445 |
+
"""
|
446 |
+
Gated Frequency Selective Memory Network (FSMN) class.
|
447 |
+
|
448 |
+
This class implements a gated FSMN that combines two feedforward
|
449 |
+
convolutional networks with a frequency selective memory module.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
in_channels (int): Number of input channels.
|
453 |
+
out_channels (int): Number of output channels.
|
454 |
+
lorder (int): Order of the filter for FSMN.
|
455 |
+
hidden_size (int): Number of hidden units in the network.
|
456 |
+
"""
|
457 |
+
def __init__(self, in_channels, out_channels, lorder, hidden_size):
|
458 |
+
super().__init__()
|
459 |
+
# Feedforward network for the first branch (u)
|
460 |
+
self.to_u = FFConvM(
|
461 |
+
dim_in=in_channels,
|
462 |
+
dim_out=hidden_size,
|
463 |
+
norm_klass=nn.LayerNorm,
|
464 |
+
dropout=0.1,
|
465 |
+
)
|
466 |
+
# Feedforward network for the second branch (v)
|
467 |
+
self.to_v = FFConvM(
|
468 |
+
dim_in=in_channels,
|
469 |
+
dim_out=hidden_size,
|
470 |
+
norm_klass=nn.LayerNorm,
|
471 |
+
dropout=0.1,
|
472 |
+
)
|
473 |
+
# Frequency selective memory network
|
474 |
+
self.fsmn = UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
|
475 |
+
|
476 |
+
def forward(self, x):
|
477 |
+
"""
|
478 |
+
Forward pass for the Gated FSMN.
|
479 |
+
|
480 |
+
Args:
|
481 |
+
x (Tensor): Input tensor of shape (batch_size, in_channels, sequence_length).
|
482 |
+
|
483 |
+
Returns:
|
484 |
+
Tensor: Output tensor after applying gated FSMN operations.
|
485 |
+
"""
|
486 |
+
input = x
|
487 |
+
x_u = self.to_u(x) # Process input through the first branch
|
488 |
+
x_v = self.to_v(x) # Process input through the second branch
|
489 |
+
x_u = self.fsmn(x_u) # Apply FSMN to the output of the first branch
|
490 |
+
x = x_v * x_u + input # Combine outputs with the original input
|
491 |
+
return x
|
492 |
+
|
493 |
+
|
494 |
+
class Gated_FSMN_Block(nn.Module):
|
495 |
+
"""
|
496 |
+
A 1-D convolutional block that incorporates a gated FSMN.
|
497 |
+
|
498 |
+
This block consists of two convolutional layers, followed by a
|
499 |
+
gated FSMN and normalization layers.
|
500 |
+
|
501 |
+
Args:
|
502 |
+
dim (int): Dimensionality of the input.
|
503 |
+
inner_channels (int): Number of channels in the inner layers.
|
504 |
+
group_size (int): Size of the groups for normalization.
|
505 |
+
norm_type (str): Type of normalization to use ('scalenorm' or 'layernorm').
|
506 |
+
"""
|
507 |
+
def __init__(self, dim, inner_channels=256, group_size=256, norm_type='scalenorm'):
|
508 |
+
super(Gated_FSMN_Block, self).__init__()
|
509 |
+
# Choose normalization class based on the provided type
|
510 |
+
if norm_type == 'scalenorm':
|
511 |
+
norm_klass = ScaleNorm
|
512 |
+
elif norm_type == 'layernorm':
|
513 |
+
norm_klass = nn.LayerNorm
|
514 |
+
|
515 |
+
self.group_size = group_size
|
516 |
+
|
517 |
+
# First convolutional layer with PReLU activation
|
518 |
+
self.conv1 = nn.Sequential(
|
519 |
+
nn.Conv1d(dim, inner_channels, kernel_size=1),
|
520 |
+
nn.PReLU(),
|
521 |
+
)
|
522 |
+
self.norm1 = CLayerNorm(inner_channels) # Normalization after first convolution
|
523 |
+
self.gated_fsmn = Gated_FSMN(inner_channels, inner_channels, lorder=20, hidden_size=inner_channels) # Gated FSMN layer
|
524 |
+
self.norm2 = CLayerNorm(inner_channels) # Normalization after FSMN
|
525 |
+
self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1) # Final convolutional layer
|
526 |
+
|
527 |
+
def forward(self, input):
|
528 |
+
"""
|
529 |
+
Forward pass for the Gated FSMN Block.
|
530 |
+
|
531 |
+
Args:
|
532 |
+
input (Tensor): Input tensor of shape (batch_size, dim, sequence_length).
|
533 |
+
|
534 |
+
Returns:
|
535 |
+
Tensor: Output tensor after processing through the block.
|
536 |
+
"""
|
537 |
+
conv1 = self.conv1(input.transpose(2, 1)) # Apply first convolution
|
538 |
+
norm1 = self.norm1(conv1) # Apply normalization
|
539 |
+
seq_out = self.gated_fsmn(norm1.transpose(2, 1)) # Apply gated FSMN
|
540 |
+
norm2 = self.norm2(seq_out.transpose(2, 1)) # Apply second normalization
|
541 |
+
conv2 = self.conv2(norm2) # Apply final convolution
|
542 |
+
return conv2.transpose(2, 1) + input # Residual connection
|
543 |
+
|
544 |
+
|
545 |
+
class MossformerBlock_GFSMN(nn.Module):
|
546 |
+
"""
|
547 |
+
Mossformer Block with Gated FSMN.
|
548 |
+
|
549 |
+
This block combines attention mechanisms and gated FSMN layers
|
550 |
+
to process input sequences.
|
551 |
+
|
552 |
+
Args:
|
553 |
+
dim (int): Dimensionality of the input.
|
554 |
+
depth (int): Number of layers in the block.
|
555 |
+
group_size (int): Size of the groups for normalization.
|
556 |
+
query_key_dim (int): Dimension of the query and key in attention.
|
557 |
+
expansion_factor (float): Expansion factor for feedforward layers.
|
558 |
+
causal (bool): If True, enables causal attention.
|
559 |
+
attn_dropout (float): Dropout rate for attention layers.
|
560 |
+
norm_type (str): Type of normalization to use ('scalenorm' or 'layernorm').
|
561 |
+
shift_tokens (bool): If True, shifts tokens in the attention layer.
|
562 |
+
"""
|
563 |
+
def __init__(self, *, dim, depth, group_size=256, query_key_dim=128, expansion_factor=4., causal=False, attn_dropout=0.1, norm_type='scalenorm', shift_tokens=True):
|
564 |
+
super().__init__()
|
565 |
+
assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
|
566 |
+
|
567 |
+
if norm_type == 'scalenorm':
|
568 |
+
norm_klass = ScaleNorm
|
569 |
+
elif norm_type == 'layernorm':
|
570 |
+
norm_klass = nn.LayerNorm
|
571 |
+
|
572 |
+
self.group_size = group_size
|
573 |
+
|
574 |
+
# Rotary positional embedding for attention
|
575 |
+
rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
|
576 |
+
|
577 |
+
# Create a list of Gated FSMN blocks
|
578 |
+
self.fsmn = nn.ModuleList([Gated_FSMN_Block(dim) for _ in range(depth)])
|
579 |
+
|
580 |
+
# Create a list of attention layers using FLASH_ShareA_FFConvM
|
581 |
+
self.layers = nn.ModuleList([
|
582 |
+
FLASH_ShareA_FFConvM(
|
583 |
+
dim=dim,
|
584 |
+
group_size=group_size,
|
585 |
+
query_key_dim=query_key_dim,
|
586 |
+
expansion_factor=expansion_factor,
|
587 |
+
causal=causal,
|
588 |
+
dropout=attn_dropout,
|
589 |
+
rotary_pos_emb=rotary_pos_emb,
|
590 |
+
norm_klass=norm_klass,
|
591 |
+
shift_tokens=shift_tokens
|
592 |
+
) for _ in range(depth)
|
593 |
+
])
|
594 |
+
|
595 |
+
def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
|
596 |
+
"""
|
597 |
+
Builds repeated UniDeep FSMN layers.
|
598 |
+
|
599 |
+
Args:
|
600 |
+
in_channels (int): Number of input channels.
|
601 |
+
out_channels (int): Number of output channels.
|
602 |
+
lorder (int): Order of the filter for FSMN.
|
603 |
+
hidden_size (int): Number of hidden units.
|
604 |
+
repeats (int): Number of repetitions.
|
605 |
+
|
606 |
+
Returns:
|
607 |
+
Sequential: A sequential container with repeated layers.
|
608 |
+
"""
|
609 |
+
repeats = [
|
610 |
+
UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
|
611 |
+
for i in range(repeats)
|
612 |
+
]
|
613 |
+
return nn.Sequential(*repeats)
|
614 |
+
|
615 |
+
def forward(self, x, *, mask=None):
|
616 |
+
"""
|
617 |
+
Forward pass for the Mossformer Block with Gated FSMN.
|
618 |
+
|
619 |
+
Args:
|
620 |
+
x (Tensor): Input tensor of shape (batch_size, dim, sequence_length).
|
621 |
+
mask (Tensor, optional): Mask tensor for attention operations.
|
622 |
+
|
623 |
+
Returns:
|
624 |
+
Tensor: Output tensor after processing through the block.
|
625 |
+
"""
|
626 |
+
ii = 0
|
627 |
+
for flash in self.layers: # Process through each layer
|
628 |
+
x = flash(x, mask=mask)
|
629 |
+
x = self.fsmn[ii](x) # Apply corresponding Gated FSMN block
|
630 |
+
ii += 1
|
631 |
+
|
632 |
+
return x
|
633 |
+
|
634 |
+
|
635 |
+
class MossformerBlock(nn.Module):
|
636 |
+
"""
|
637 |
+
Mossformer Block with attention mechanisms.
|
638 |
+
|
639 |
+
This block is designed to process input sequences using attention
|
640 |
+
layers and incorporates rotary positional embeddings. It allows
|
641 |
+
for configurable normalization types and can handle causal
|
642 |
+
attention.
|
643 |
+
|
644 |
+
Args:
|
645 |
+
dim (int): Dimensionality of the input.
|
646 |
+
        depth (int): Number of attention layers in the block.
        group_size (int, optional): Size of groups for normalization. Default is 256.
        query_key_dim (int, optional): Dimension of the query and key in attention. Default is 128.
        expansion_factor (float, optional): Expansion factor for feedforward layers. Default is 4.
        causal (bool, optional): If True, enables causal attention. Default is False.
        attn_dropout (float, optional): Dropout rate for attention layers. Default is 0.1.
        norm_type (str, optional): Type of normalization to use ('scalenorm' or 'layernorm'). Default is 'scalenorm'.
        shift_tokens (bool, optional): If True, shifts tokens in the attention layer. Default is True.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        causal=False,
        attn_dropout=0.1,
        norm_type='scalenorm',
        shift_tokens=True
    ):
        super().__init__()

        # Ensure normalization type is valid
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        # Select normalization class based on the provided type
        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.group_size = group_size  # Group size for normalization

        # Rotary positional embedding for attention
        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
        # Max rotary embedding dimension of 32, partial rotary embeddings, from Wang et al. (GPT-J)

        # Create a list of attention layers using FLASH_ShareA_FFConvM
        self.layers = nn.ModuleList([
            FLASH_ShareA_FFConvM(
                dim=dim,
                group_size=group_size,
                query_key_dim=query_key_dim,
                expansion_factor=expansion_factor,
                causal=causal,
                dropout=attn_dropout,
                rotary_pos_emb=rotary_pos_emb,
                norm_klass=norm_klass,
                shift_tokens=shift_tokens
            ) for _ in range(depth)
        ])

    def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
        """
        Builds repeated UniDeep FSMN layers.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            lorder (int): Order of the filter for FSMN.
            hidden_size (int): Number of hidden units.
            repeats (int, optional): Number of repetitions. Default is 1.

        Returns:
            Sequential: A sequential container with repeated layers.
        """
        repeats = [
            UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
            for _ in range(repeats)
        ]
        return nn.Sequential(*repeats)

    def forward(self, x, *, mask=None):
        """
        Forward pass for the Mossformer Block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, dim, sequence_length).
            mask (Tensor, optional): Mask tensor for attention operations.

        Returns:
            Tensor: Output tensor after processing through the block.
        """
        # Process input through each attention layer
        for flash in self.layers:
            x = flash(x, mask=mask)  # Apply attention layer with optional mask

        return x  # Return the final output tensor
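As a quick reference, here is a hedged usage sketch of this attention block. The class name `MossformerBlock` and the import path are assumptions for illustration (the actual class name is defined earlier in mossformer2_block.py); the tensor layout follows the docstring above.

    # Hypothetical usage sketch -- the class name and import are assumptions, not part of the diff above.
    import torch
    from models.mossformer2_sr.mossformer2_block import MossformerBlock  # assumed export name

    block = MossformerBlock(dim=512, depth=4)   # defaults: group_size=256, query_key_dim=128, scalenorm
    x = torch.randn(2, 512, 1000)               # (batch_size, dim, sequence_length) per the docstring
    y = block(x)                                # output tensor after processing through the block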
models/mossformer2_sr/mossformer2_sr_wrapper.py ADDED @@ -0,0 +1,52 @@
from models.mossformer2_sr.generator import Mossformer, Generator
import torch.nn as nn

class MossFormer2_SR_48K(nn.Module):
    """
    The MossFormer2_SR_48K model for speech super-resolution.

    This class encapsulates the functionality of the MossFormer2 and HiFi-GAN
    Generator within a higher-level model. It processes input audio data to produce
    higher-resolution outputs.

    Arguments
    ---------
    args : Namespace
        Configuration arguments that may include hyperparameters
        and model settings (not utilized in this implementation but
        can be extended for flexibility).

    Example
    ---------
    >>> model = MossFormer2_SR_48K(args)
    >>> x = torch.randn(10, 180, 2000)  # Example input
    >>> outputs = model(x)  # Forward pass
    >>> outputs.shape  # Check output shape
    """

    def __init__(self, args):
        super(MossFormer2_SR_48K, self).__init__()
        # Initialize the MossFormer2 backbone and the HiFi-GAN generator
        self.model_m = Mossformer()  # MossFormer2 feature-refinement network
        self.model_g = Generator(args)  # HiFi-GAN generator for waveform synthesis

    def forward(self, x):
        """
        Forward pass through the model.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S], where B is the batch size,
            N is the number of mel bins, and S is the sequence length
            (e.g., time frames).

        Returns
        -------
        outputs : torch.Tensor
            Bandwidth-expanded audio output tensor from the model.
        """
        x = self.model_m(x)  # Refine features with the MossFormer2 backbone
        outputs = self.model_g(x)  # Synthesize the bandwidth-expanded waveform
        return outputs  # Return the outputs
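For completeness, a hedged inference sketch for this wrapper. The `args` namespace is a placeholder: the real Generator expects the HiFi-GAN-style configuration defined elsewhere in this repo, which is not reproduced here.

    # Hypothetical inference sketch -- "args" and the input sizes are placeholders.
    import torch
    from models.mossformer2_sr.mossformer2_sr_wrapper import MossFormer2_SR_48K

    model = MossFormer2_SR_48K(args).eval()   # args: configuration Namespace (assumed available)
    feats = torch.randn(1, 180, 2000)         # [B, N, S] input features, as in the docstring example
    with torch.no_grad():
        audio = model(feats)                  # bandwidth-expanded 48 kHz output from the Generator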
models/mossformer2_sr/snake.py ADDED @@ -0,0 +1,33 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
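The `snake` function above implements the periodic Snake activation, snake(x) = x + (1/alpha) * sin^2(alpha * x), with the 1e-9 term guarding against division by zero, and `Snake1d` learns one alpha per channel. A minimal sanity-check sketch:

    # Minimal check of Snake1d on a (batch, channels, time) tensor.
    import torch
    from models.mossformer2_sr.snake import Snake1d

    act = Snake1d(channels=64)
    x = torch.randn(2, 64, 1000)
    y = act(x)                     # same shape; per-channel alpha controls the periodicity
    assert y.shape == x.shape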
models/mossformer2_sr/utils.py ADDED @@ -0,0 +1,37 @@
import glob
import os
import torch
from torch.nn.utils import weight_norm

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)

def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")

def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
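In these utilities, `get_padding` returns the padding that keeps a Conv1d output the same length as its input for odd kernel sizes (padding = (k*d - d) / 2), and `scan_checkpoint` returns the newest file whose name is the prefix followed by eight characters (e.g. g_00010000), or None if none exists. A brief usage sketch with placeholder paths:

    # Usage sketch -- the checkpoint directory and prefix are placeholders.
    import torch
    import torch.nn as nn
    from models.mossformer2_sr.utils import get_padding, scan_checkpoint, load_checkpoint

    conv = nn.Conv1d(80, 80, kernel_size=7, dilation=2, padding=get_padding(7, 2))
    x = torch.randn(1, 80, 2000)
    assert conv(x).shape[-1] == x.shape[-1]      # sequence length preserved

    cp = scan_checkpoint('checkpoints/', 'g_')   # newest checkpoint matching g_????????, or None
    if cp is not None:
        state = load_checkpoint(cp, torch.device('cpu'))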