alibabasglab committed on
Commit 126c408 · verified · 1 Parent(s): e1ed673

Upload 32 files

Files changed (32)
  1. models/mossformer2_sr/__init__.py +0 -0
  2. models/mossformer2_sr/__pycache__/__init__.cpython-312.pyc +0 -0
  3. models/mossformer2_sr/__pycache__/__init__.cpython-38.pyc +0 -0
  4. models/mossformer2_sr/__pycache__/conv_module.cpython-312.pyc +0 -0
  5. models/mossformer2_sr/__pycache__/conv_module.cpython-38.pyc +0 -0
  6. models/mossformer2_sr/__pycache__/env.cpython-312.pyc +0 -0
  7. models/mossformer2_sr/__pycache__/fsmn.cpython-312.pyc +0 -0
  8. models/mossformer2_sr/__pycache__/fsmn.cpython-38.pyc +0 -0
  9. models/mossformer2_sr/__pycache__/generator.cpython-312.pyc +0 -0
  10. models/mossformer2_sr/__pycache__/layer_norm.cpython-312.pyc +0 -0
  11. models/mossformer2_sr/__pycache__/layer_norm.cpython-38.pyc +0 -0
  12. models/mossformer2_sr/__pycache__/mossformer.cpython-38.pyc +0 -0
  13. models/mossformer2_sr/__pycache__/mossformer2.cpython-312.pyc +0 -0
  14. models/mossformer2_sr/__pycache__/mossformer2.cpython-38.pyc +0 -0
  15. models/mossformer2_sr/__pycache__/mossformer2_block.cpython-312.pyc +0 -0
  16. models/mossformer2_sr/__pycache__/mossformer2_block.cpython-38.pyc +0 -0
  17. models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-312.pyc +0 -0
  18. models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-38.pyc +0 -0
  19. models/mossformer2_sr/__pycache__/mossformer2_sr_wrapper.cpython-312.pyc +0 -0
  20. models/mossformer2_sr/__pycache__/mossformer_block.cpython-38.pyc +0 -0
  21. models/mossformer2_sr/__pycache__/snake.cpython-312.pyc +0 -0
  22. models/mossformer2_sr/__pycache__/utils.cpython-312.pyc +0 -0
  23. models/mossformer2_sr/conv_module.py +388 -0
  24. models/mossformer2_sr/env.py +15 -0
  25. models/mossformer2_sr/fsmn.py +214 -0
  26. models/mossformer2_sr/generator.py +448 -0
  27. models/mossformer2_sr/layer_norm.py +126 -0
  28. models/mossformer2_sr/mossformer2.py +711 -0
  29. models/mossformer2_sr/mossformer2_block.py +735 -0
  30. models/mossformer2_sr/mossformer2_sr_wrapper.py +52 -0
  31. models/mossformer2_sr/snake.py +33 -0
  32. models/mossformer2_sr/utils.py +37 -0
models/mossformer2_sr/__init__.py ADDED
File without changes
models/mossformer2_sr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (198 Bytes)
models/mossformer2_sr/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (193 Bytes)
models/mossformer2_sr/__pycache__/conv_module.cpython-312.pyc ADDED
Binary file (20.4 kB)
models/mossformer2_sr/__pycache__/conv_module.cpython-38.pyc ADDED
Binary file (13.9 kB)
models/mossformer2_sr/__pycache__/env.cpython-312.pyc ADDED
Binary file (1.19 kB)
models/mossformer2_sr/__pycache__/fsmn.cpython-312.pyc ADDED
Binary file (12.4 kB)
models/mossformer2_sr/__pycache__/fsmn.cpython-38.pyc ADDED
Binary file (8.51 kB)
models/mossformer2_sr/__pycache__/generator.cpython-312.pyc ADDED
Binary file (24.6 kB)
models/mossformer2_sr/__pycache__/layer_norm.cpython-312.pyc ADDED
Binary file (6.59 kB)
models/mossformer2_sr/__pycache__/layer_norm.cpython-38.pyc ADDED
Binary file (4.24 kB)
models/mossformer2_sr/__pycache__/mossformer.cpython-38.pyc ADDED
Binary file (16.3 kB)
models/mossformer2_sr/__pycache__/mossformer2.cpython-312.pyc ADDED
Binary file (22.8 kB)
models/mossformer2_sr/__pycache__/mossformer2.cpython-38.pyc ADDED
Binary file (15.9 kB)
models/mossformer2_sr/__pycache__/mossformer2_block.cpython-312.pyc ADDED
Binary file (30.8 kB)
models/mossformer2_sr/__pycache__/mossformer2_block.cpython-38.pyc ADDED
Binary file (23.5 kB)
models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-312.pyc ADDED
Binary file (4.05 kB)
models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-38.pyc ADDED
Binary file (3.6 kB)
models/mossformer2_sr/__pycache__/mossformer2_sr_wrapper.cpython-312.pyc ADDED
Binary file (2.36 kB)
models/mossformer2_sr/__pycache__/mossformer_block.cpython-38.pyc ADDED
Binary file (21.2 kB)
models/mossformer2_sr/__pycache__/snake.cpython-312.pyc ADDED
Binary file (2.24 kB)
models/mossformer2_sr/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.46 kB)
models/mossformer2_sr/conv_module.py ADDED
@@ -0,0 +1,388 @@
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.init as init
import torch.nn.functional as F

EPS = 1e-8

class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Arguments
    ---------
    dim : (int or list or torch.Size)
        Input shape from an expected input of size.
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x = N x C x K x S or N x C x L
        # cln: mean, var N x 1 x K x S
        # gln: mean, var N x 1 x 1
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x: N x C x K x S or N x C x L
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1).contiguous()
            # N x K x S x C == only channel norm
            x = super().forward(x)
            # N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            x = torch.transpose(x, 1, 2)
            # N x L x C == only channel norm
            x = super().forward(x)
            # N x C x L
            x = torch.transpose(x, 1, 2)
        return x


def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type."""

    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    if norm == "cln":
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == "ln":
        return nn.GroupNorm(1, dim, eps=1e-8)
    else:
        return nn.BatchNorm1d(dim)


class Swish(nn.Module):
    """
    Swish is a smooth, non-monotonic activation that consistently matches or outperforms
    ReLU on deep networks across a variety of challenging domains such as image
    classification and machine translation.
    """
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs * inputs.sigmoid()


class GLU(nn.Module):
    """
    Gated Linear Unit (GLU), a gating mechanism first introduced for natural language
    processing in the paper "Language Modeling with Gated Convolutional Networks".
    """
    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()


class Transpose(nn.Module):
    """Wrapper class of torch.transpose() for use inside nn.Sequential."""
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)


class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weights are initialized with Xavier initialization and biases to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive
    integer, this operation is termed in the literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input sequence
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    """
    When kernel_size == 1, a conv1d is termed in the literature as pointwise convolution.
    This operation is often used to match dimensions.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input sequence
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConvModule(nn.Module):
    """
    Convolution module adapted from the Conformer convolution block, reduced here to a
    single depthwise 1-D convolution with 'SAME' padding and a residual connection.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 17
        dropout_p (float, optional): probability of dropout

    Inputs: inputs
        inputs (batch, time, dim): Tensor containing the input sequence
    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the convolution module.
    """
    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"

        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs + self.sequential(inputs).transpose(1, 2)


class ConvModule_Dilated(nn.Module):
    """
    Dilated variant of ConvModule. The current body is identical to ConvModule: a single
    depthwise 1-D convolution with 'SAME' padding and a residual connection.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 17
        dropout_p (float, optional): probability of dropout

    Inputs: inputs
        inputs (batch, time, dim): Tensor containing the input sequence
    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the convolution module.
    """
    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule_Dilated, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs + self.sequential(inputs).transpose(1, 2)


class DilatedDenseNet(nn.Module):
    def __init__(self, depth=4, lorder=20, in_channels=64):
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
        self.twidth = lorder * 2 - 1
        self.kernel_size = (self.twidth, 1)
        for i in range(self.depth):
            dil = 2 ** i
            pad_length = lorder + (dil - 1) * (lorder - 1) - 1
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1), groups=self.in_channels, bias=False))
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x_per = x.permute(0, 3, 2, 1)
        skip = x_per
        for i in range(self.depth):
            out = getattr(self, 'pad{}'.format(i + 1))(skip)
            out = getattr(self, 'conv{}'.format(i + 1))(out)
            out = getattr(self, 'norm{}'.format(i + 1))(out)
            out = getattr(self, 'prelu{}'.format(i + 1))(out)
            skip = torch.cat([out, skip], dim=1)
        out1 = out.permute(0, 3, 2, 1)
        return out1.squeeze(1)


class FFConvM_Dilated(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass=nn.LayerNorm,
        dropout=0.1
    ):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            DilatedDenseNet(depth=2, lorder=17, in_channels=dim_out),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        output = self.mdl(x)
        return output
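
A minimal usage sketch for the two most-used modules above (assumptions for illustration: the repo root is on PYTHONPATH so the package import resolves; shapes follow the docstrings):

# Usage sketch (illustrative; assumes the repo root is on PYTHONPATH).
import torch
from models.mossformer2_sr.conv_module import ConvModule, GlobalLayerNorm

x = torch.randn(4, 100, 64)             # (batch, time, dim)
conv = ConvModule(in_channels=64)       # depthwise conv with residual connection
print(conv(x).shape)                    # torch.Size([4, 100, 64])

z = torch.randn(4, 64, 200)             # (N, C, L)
gln = GlobalLayerNorm(dim=64, shape=3)  # global mean/var over channel and time
print(gln(z).shape)                     # torch.Size([4, 64, 200])
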
models/mossformer2_sr/env.py ADDED
@@ -0,0 +1,15 @@
import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
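
AttrDict simply aliases a dict's keys as attributes, which is how the generator config h is read in generator.py. A minimal sketch:

# Usage sketch: dict keys are readable and writable as attributes.
h = AttrDict({"resblock": "1", "upsample_initial_channel": 512})
assert h.resblock == "1"
assert h["upsample_initial_channel"] == 512
h.num_gpus = 1                 # attribute writes show up as dict entries
assert h["num_gpus"] == 1
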
models/mossformer2_sr/fsmn.py ADDED
@@ -0,0 +1,214 @@
import torch.nn as nn
import torch.nn.functional as F
import torch as th
from torch.nn.parameter import Parameter
import numpy as np
import os


class UniDeepFsmn(nn.Module):
    """
    UniDeepFsmn is a neural network module that implements a single-deep feedforward
    sequential memory network (FSMN).

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        lorder (int): Order (context length) of the memory convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to the hidden size.
        project (nn.Linear): Linear layer to project hidden features to the output dimension.
        conv1 (nn.Conv2d): Convolutional layer that processes the output in a grouped (depthwise) manner.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize the layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim, bias=False)  # Convolution layer

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, time, input_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape as the input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad along time for the memory convolution
        out = x_per + self.conv1(y)  # Add original input to convolution output
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions
        return input + out1.squeeze()  # Return enhanced input


class UniDeepFsmn_dual(nn.Module):
    """
    UniDeepFsmn_dual is a neural network module that implements a dual-deep feedforward
    sequential memory network (FSMN).

    This class extends UniDeepFsmn by adding a second convolution layer for richer
    feature extraction.

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        lorder (int): Order (context length) of the memory convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to the hidden size.
        project (nn.Linear): Linear layer to project hidden features to the output dimension.
        conv1 (nn.Conv2d): First convolutional layer for processing the output.
        conv2 (nn.Conv2d): Second convolutional layer for further processing the features.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn_dual, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize the layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim, bias=False)  # First convolution layer
        self.conv2 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim // 4, bias=False)  # Second convolution layer

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn_dual model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, time, input_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape as the input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad along time for the first convolution
        conv1_out = x_per + self.conv1(y)  # Add original input to first convolution output
        z = F.pad(conv1_out, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad for second convolution
        out = conv1_out + self.conv2(z)  # Add output of second convolution
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions
        return input + out1.squeeze()  # Return enhanced input


class DilatedDenseNet(nn.Module):
    """
    DilatedDenseNet implements a dense network structure with dilated convolutions.

    This architecture enables wider receptive fields while maintaining a lower number of
    parameters. It consists of multiple convolutional layers whose dilation rates double
    at each layer.

    Attributes:
        depth (int): Number of convolutional layers in the network.
        in_channels (int): Number of input channels for the first layer.
        pad (nn.ConstantPad2d): Padding layer to maintain dimensions.
        twidth (int): Width of the kernel used in convolution.
        kernel_size (tuple): Kernel size for convolution operations.
    """

    def __init__(self, depth=4, lorder=20, in_channels=64):
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)  # Padding for the input
        self.twidth = lorder * 2 - 1  # Width of the kernel
        self.kernel_size = (self.twidth, 1)  # Kernel size for convolutions

        # Initialize layers dynamically based on depth
        for i in range(self.depth):
            dil = 2 ** i  # Calculate dilation rate
            pad_length = lorder + (dil - 1) * (lorder - 1) - 1  # Calculate padding length
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))  # Padding for dilation
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1), groups=self.in_channels, bias=False))  # Convolution layer with dilation
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))  # Normalization layer
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))  # Activation layer

    def forward(self, x):
        """
        Forward pass for the DilatedDenseNet model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after applying the dense layers.
        """
        skip = x  # Initialize skip connection
        for i in range(self.depth):
            out = getattr(self, 'pad{}'.format(i + 1))(skip)  # Apply padding
            out = getattr(self, 'conv{}'.format(i + 1))(out)  # Apply convolution
            out = getattr(self, 'norm{}'.format(i + 1))(out)  # Apply normalization
            out = getattr(self, 'prelu{}'.format(i + 1))(out)  # Apply PReLU activation
            skip = th.cat([out, skip], dim=1)  # Concatenate the output with the skip connection
        return out  # Return the final output


class UniDeepFsmn_dilated(nn.Module):
    """
    UniDeepFsmn_dilated combines the UniDeepFsmn architecture with a dilated dense network
    to enhance feature extraction while maintaining efficient computation.

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        depth (int): Depth of the dilated dense network.
        lorder (int): Order (context length) of the memory convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to the hidden size.
        project (nn.Linear): Linear layer to project hidden features to the output dimension.
        conv (DilatedDenseNet): Instance of DilatedDenseNet for feature extraction.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None, depth=2):
        super(UniDeepFsmn_dilated, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.depth = depth
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv = DilatedDenseNet(depth=self.depth, lorder=lorder, in_channels=output_dim)  # Dilated dense network for feature extraction

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn_dilated model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, time, input_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape as the input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        out = self.conv(x_per)  # Pass through the dilated dense network
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions

        return input + out1.squeeze()  # Return enhanced input
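
A minimal usage sketch for the FSMN blocks above (shapes are assumptions based on the forward passes, which operate on (batch, time, dim) sequences; a batch size greater than 1 is assumed because out1.squeeze() would also drop a singleton batch dimension):

# Usage sketch (illustrative shapes; batch size > 1 assumed).
import torch
fsmn = UniDeepFsmn(input_dim=64, output_dim=64, lorder=20, hidden_size=128)
x = torch.randn(2, 100, 64)    # (batch, time, dim)
print(fsmn(x).shape)           # torch.Size([2, 100, 64])

dilated = UniDeepFsmn_dilated(input_dim=64, output_dim=64, lorder=20, hidden_size=128)
print(dilated(x).shape)        # torch.Size([2, 100, 64])
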
models/mossformer2_sr/generator.py ADDED
@@ -0,0 +1,448 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from models.mossformer2_sr.utils import init_weights, get_padding
from models.mossformer2_sr.mossformer2 import MossFormer_MaskNet
from models.mossformer2_sr.snake import Snake1d
from typing import Optional, List, Union, Dict, Tuple
from models.mossformer2_sr.env import AttrDict
import typing
from torchaudio.transforms import Spectrogram, Resample

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)
        # Snake activations replace the leaky ReLUs of the original HiFi-GAN ResBlock.
        self.convs1_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels),
            Snake1d(channels)
        ])
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
        self.convs2_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels),
            Snake1d(channels)
        ])

    def forward(self, x):
        for c1, c2, act1, act2 in zip(self.convs1, self.convs2, self.convs1_activates, self.convs2_activates):
            xt = act1(x)
            xt = c1(xt)
            xt = act2(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)
        self.convs_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels)
        ])

    def forward(self, x):
        for c, act in zip(self.convs, self.convs_activates):
            xt = act(x)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        self.snakes = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.snakes.append(Snake1d(h.upsample_initial_channel // (2 ** i)))
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.snake_post = Snake1d(ch)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = self.snakes[i](x)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = self.snake_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorB(nn.Module):
    def __init__(
        self,
        window_length: int,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = (
            (0.0, 0.1),
            (0.1, 0.25),
            (0.25, 0.5),
            (0.5, 0.75),
            (0.75, 1.0),
        ),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        self.spec_fn = Spectrogram(
            n_fft=window_length,
            hop_length=int(window_length * hop_factor),
            win_length=window_length,
            power=None,
        )
        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        self.conv_post = weight_norm(
            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
        )

    def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdim=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
        # Split into bands
        x_bands = [x[..., b[0]:b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x_bands = self.spectrogram(x.squeeze(1))
        fmap = []
        x = []

        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return x, fmap


# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiBandDiscriminator(nn.Module):
    def __init__(
        self,
        h,
    ):
        """
        Multi-band multi-scale STFT discriminator, with the architecture based on
        https://github.com/descriptinc/descript-audio-codec and the modified code
        adapted from https://github.com/gemelo-ai/vocos.
        """
        super().__init__()
        # fft_sizes (list[int]): Window lengths for the FFT.
        # Defaults to [2048, 1024, 512] if not set in h.
        self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
        self.discriminators = nn.ModuleList(
            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
        )

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[List[torch.Tensor]],
        List[List[torch.Tensor]],
    ]:

        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses


class Mossformer(nn.Module):

    def __init__(self):
        super(Mossformer, self).__init__()
        self.mossformer = MossFormer_MaskNet(in_channels=80, out_channels=512, out_channels_final=80)

    def forward(self, input):
        out = self.mossformer(input)
        return out
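
A minimal sketch of driving the Generator from an AttrDict config (the hyperparameter values below are illustrative assumptions, not the repo's shipped configuration; the output length is the input frame count times the product of upsample_rates):

# Usage sketch with assumed, HiFi-GAN-style hyperparameters.
import torch
from models.mossformer2_sr.env import AttrDict

h = AttrDict({
    "resblock": "1",
    "upsample_rates": [8, 8, 2, 2],
    "upsample_kernel_sizes": [16, 16, 4, 4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
})
g = Generator(h)
mel = torch.randn(1, 80, 32)   # (batch, mel_bins, frames); conv_pre expects 80 bins
wav = g(mel)                   # (1, 1, 32 * 8 * 8 * 2 * 2) = (1, 1, 8192)
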
models/mossformer2_sr/layer_norm.py ADDED
@@ -0,0 +1,126 @@
#!/usr/bin/env python -u
# -*- coding: utf-8 -*-

# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn


class CLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization."""

    def __init__(self, *args, **kwargs):
        super(CLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                type(self).__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # LayerNorm
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class ILayerNorm(nn.InstanceNorm1d):
    """Instance-norm-based channel-wise normalization."""

    def __init__(self, *args, **kwargs):
        super(ILayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                type(self).__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # InstanceNorm
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class GLayerNorm(nn.Module):
    """Global Layer Normalization for TasNet."""

    def __init__(self, channels, eps=1e-5):
        super(GLayerNorm, self).__init__()
        self.eps = eps
        self.norm_dim = channels
        self.gamma = nn.Parameter(torch.Tensor(channels))
        self.beta = nn.Parameter(torch.Tensor(channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                type(self).__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # Mean and variance [N, 1, 1]
        mean = torch.mean(sample, (1, 2), keepdim=True)
        var = torch.mean((sample - mean) ** 2, (1, 2), keepdim=True)
        sample = (sample - mean) / torch.sqrt(var + self.eps) * \
            self.gamma + self.beta
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""

    def __init__(self, channel_size):
        super(_LayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size),
                                  requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size),
                                 requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Assumes input of size `[batch, channel, *]`."""
        return (self.gamma * normed_x.transpose(1, -1) +
                self.beta).transpose(1, -1)


class GlobLayerNorm(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies the forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
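
A minimal usage sketch (these norms take a [batch_size, channels, length] tensor and return the same shape):

# Usage sketch.
import torch
x = torch.randn(2, 64, 100)
cln = CLayerNorm(64)   # LayerNorm over channels at each time step
gln = GLayerNorm(64)   # global normalization over channels and time
assert cln(x).shape == x.shape
assert gln(x).shape == x.shape
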
models/mossformer2_sr/mossformer2.py ADDED
@@ -0,0 +1,711 @@
1
+ """
2
+ modified from https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/dual_path.py
3
+ Author: Shengkui Zhao
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import copy
11
+ from models.mossformer2_sr.mossformer2_block import ScaledSinuEmbedding, MossformerBlock_GFSMN, MossformerBlock
12
+
13
+
14
+ EPS = 1e-8
15
+
16
+
17
+ class GlobalLayerNorm(nn.Module):
18
+ """Calculate Global Layer Normalization.
19
+
20
+ Arguments
21
+ ---------
22
+ dim : (int or list or torch.Size)
23
+ Input shape from an expected input of size.
24
+ eps : float
25
+ A value added to the denominator for numerical stability.
26
+ elementwise_affine : bool
27
+ A boolean value that when set to True,
28
+ this module has learnable per-element affine parameters
29
+ initialized to ones (for weights) and zeros (for biases).
30
+
31
+ Example
32
+ -------
33
+ >>> x = torch.randn(5, 10, 20)
34
+ >>> GLN = GlobalLayerNorm(10, 3)
35
+ >>> x_norm = GLN(x)
36
+ """
37
+
38
+ def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
39
+ super(GlobalLayerNorm, self).__init__()
40
+ self.dim = dim
41
+ self.eps = eps
42
+ self.elementwise_affine = elementwise_affine
43
+
44
+ if self.elementwise_affine:
45
+ if shape == 3:
46
+ self.weight = nn.Parameter(torch.ones(self.dim, 1))
47
+ self.bias = nn.Parameter(torch.zeros(self.dim, 1))
48
+ if shape == 4:
49
+ self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
50
+ self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
51
+ else:
52
+ self.register_parameter("weight", None)
53
+ self.register_parameter("bias", None)
54
+
55
+ def forward(self, x):
56
+ """Returns the normalized tensor.
57
+
58
+ Arguments
59
+ ---------
60
+ x : torch.Tensor
61
+ Tensor of size [N, C, K, S] or [N, C, L].
62
+ """
63
+ # x = N x C x K x S or N x C x L
64
+ # N x 1 x 1
65
+ # cln: mean,var N x 1 x K x S
66
+ # gln: mean,var N x 1 x 1
67
+ if x.dim() == 3:
68
+ mean = torch.mean(x, (1, 2), keepdim=True)
69
+ var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
70
+ if self.elementwise_affine:
71
+ x = (
72
+ self.weight * (x - mean) / torch.sqrt(var + self.eps)
73
+ + self.bias
74
+ )
75
+ else:
76
+ x = (x - mean) / torch.sqrt(var + self.eps)
77
+
78
+ if x.dim() == 4:
79
+ mean = torch.mean(x, (1, 2, 3), keepdim=True)
80
+ var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
81
+ if self.elementwise_affine:
82
+ x = (
83
+ self.weight * (x - mean) / torch.sqrt(var + self.eps)
84
+ + self.bias
85
+ )
86
+ else:
87
+ x = (x - mean) / torch.sqrt(var + self.eps)
88
+ return x
89
+
90
+
91
+ class CumulativeLayerNorm(nn.LayerNorm):
92
+ """Calculate Cumulative Layer Normalization.
93
+
94
+ Arguments
95
+ ---------
96
+ dim : int
97
+ Dimension that you want to normalize.
98
+ elementwise_affine : True
99
+ Learnable per-element affine parameters.
100
+
101
+ Example
102
+ -------
103
+ >>> x = torch.randn(5, 10, 20)
104
+ >>> CLN = CumulativeLayerNorm(10)
105
+ >>> x_norm = CLN(x)
106
+ """
107
+
108
+ def __init__(self, dim, elementwise_affine=True):
109
+ super(CumulativeLayerNorm, self).__init__(
110
+ dim, elementwise_affine=elementwise_affine, eps=1e-8
111
+ )
112
+
113
+ def forward(self, x):
114
+ """Returns the normalized tensor.
115
+
116
+ Arguments
117
+ ---------
118
+ x : torch.Tensor
119
+ Tensor size [N, C, K, S] or [N, C, L]
120
+ """
121
+ # x: N x C x K x S or N x C x L
122
+ # N x K x S x C
123
+ if x.dim() == 4:
124
+ x = x.permute(0, 2, 3, 1).contiguous()
125
+ # N x K x S x C == only channel norm
126
+ x = super().forward(x)
127
+ # N x C x K x S
128
+ x = x.permute(0, 3, 1, 2).contiguous()
129
+ if x.dim() == 3:
130
+ x = torch.transpose(x, 1, 2)
131
+ # N x L x C == only channel norm
132
+ x = super().forward(x)
133
+ # N x C x L
134
+ x = torch.transpose(x, 1, 2)
135
+ return x
136
+
137
+
138
+ def select_norm(norm, dim, shape):
139
+ """Just a wrapper to select the normalization type.
140
+ """
141
+
142
+ if norm == "gln":
143
+ return GlobalLayerNorm(dim, shape, elementwise_affine=True)
144
+ if norm == "cln":
145
+ return CumulativeLayerNorm(dim, elementwise_affine=True)
146
+ if norm == "ln":
147
+ return nn.GroupNorm(1, dim, eps=1e-8)
148
+ else:
149
+ return nn.BatchNorm1d(dim)
150
+
151
+
152
+ class Encoder(nn.Module):
153
+ """Convolutional Encoder Layer.
154
+
155
+ Arguments
156
+ ---------
157
+ kernel_size : int
158
+ Length of filters.
159
+ in_channels : int
160
+ Number of input channels.
161
+ out_channels : int
162
+ Number of output channels.
163
+
164
+ Example
165
+ -------
166
+ >>> x = torch.randn(2, 1000)
167
+ >>> encoder = Encoder(kernel_size=4, out_channels=64)
168
+ >>> h = encoder(x)
169
+ >>> h.shape
170
+ torch.Size([2, 64, 499])
171
+ """
172
+
173
+ def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
174
+ super(Encoder, self).__init__()
175
+ self.conv1d = nn.Conv1d(
176
+ in_channels=in_channels,
177
+ out_channels=out_channels,
178
+ kernel_size=kernel_size,
179
+ stride=kernel_size // 2,
180
+ groups=1,
181
+ bias=False,
182
+ )
183
+ self.in_channels = in_channels
184
+
185
+ def forward(self, x):
186
+ """Return the encoded output.
187
+
188
+ Arguments
189
+ ---------
190
+ x : torch.Tensor
191
+ Input tensor with dimensionality [B, L].
192
+ Return
193
+ ------
194
+ x : torch.Tensor
195
+ Encoded tensor with dimensionality [B, N, T_out].
196
+
197
+ where B = Batchsize
198
+ L = Number of timepoints
199
+ N = Number of filters
200
+ T_out = Number of timepoints at the output of the encoder
201
+ """
202
+ # B x L -> B x 1 x L
203
+ if self.in_channels == 1:
204
+ x = torch.unsqueeze(x, dim=1)
205
+ # B x 1 x L -> B x N x T_out
206
+ x = self.conv1d(x)
207
+ x = F.relu(x)
208
+
209
+ return x
210
+
211
+
212
+ class Decoder(nn.ConvTranspose1d):
213
+ """A decoder layer that consists of ConvTranspose1d.
214
+
215
+ Arguments
216
+ ---------
217
+ kernel_size : int
218
+ Length of filters.
219
+ in_channels : int
220
+ Number of input channels.
221
+ out_channels : int
222
+ Number of output channels.
223
+
224
+
225
+ Example
226
+ ---------
227
+ >>> x = torch.randn(2, 100, 1000)
228
+ >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
229
+ >>> h = decoder(x)
230
+ >>> h.shape
231
+ torch.Size([2, 1003])
232
+ """
233
+
234
+ def __init__(self, *args, **kwargs):
235
+ super(Decoder, self).__init__(*args, **kwargs)
236
+
237
+ def forward(self, x):
238
+ """Return the decoded output.
239
+
240
+ Arguments
241
+ ---------
242
+ x : torch.Tensor
243
+ Input tensor with dimensionality [B, N, L].
244
+ where, B = Batchsize,
245
+ N = number of filters
246
+ L = time points
247
+ """
248
+
249
+ if x.dim() not in [2, 3]:
250
+ raise RuntimeError(
251
+ "{} accepts 2/3D tensor as input".format(self.__class__.__name__)
252
+ )
253
+ x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
254
+
255
+ if torch.squeeze(x).dim() == 1:
256
+ x = torch.squeeze(x, dim=1)
257
+ else:
258
+ x = torch.squeeze(x)
259
+ return x
260
+
261
+
262
+ class IdentityBlock:
263
+ """This block is used when we want an identity transformation within the dual-path block.
264
+
265
+ Example
266
+ -------
267
+ >>> x = torch.randn(10, 100)
268
+ >>> IB = IdentityBlock()
269
+ >>> xhat = IB(x)
270
+ """
271
+
272
+ def __init__(self, **kwargs):
273
+ pass
274
+
275
+ def __call__(self, x):
276
+ return x
277
+
278
+
279
+ class MossFormerM(nn.Module):
280
+ """This class implements the transformer encoder built on MossFormer2 blocks (attention plus gated FSMN).
281
+
282
+ Arguments
283
+ ---------
284
+ num_blocks : int
285
+ Number of mossformer blocks to include.
286
+ d_model : int
287
+ The dimension of the input embedding.
288
+ attn_dropout : float
289
+ Dropout for the self-attention (Optional).
290
+ group_size: int
291
+ the chunk size
292
+ query_key_dim: int
293
+ the attention vector dimension
294
+ expansion_factor: int
295
+ the expansion factor for the linear projection in conv module
296
+ causal: bool
297
+ true for causal / false for non causal
298
+
299
+ Example
300
+ -------
301
+ >>> import torch
302
+ >>> x = torch.rand((8, 60, 512))
303
+ >>> net = MossFormerM(num_blocks=8, d_model=512)
304
+ >>> output = net(x)
305
+ >>> output.shape
306
+ torch.Size([8, 60, 512])
307
+ """
308
+ def __init__(
309
+ self,
310
+ num_blocks,
311
+ d_model=None,
312
+ causal=False,
313
+ group_size = 256,
314
+ query_key_dim = 128,
315
+ expansion_factor = 4.,
316
+ attn_dropout = 0.1
317
+ ):
318
+ super().__init__()
319
+
320
+ self.mossformerM = MossformerBlock_GFSMN(
321
+ dim=d_model,
322
+ depth=num_blocks,
323
+ group_size=group_size,
324
+ query_key_dim=query_key_dim,
325
+ expansion_factor=expansion_factor,
326
+ causal=causal,
327
+ attn_dropout=attn_dropout
328
+ )
329
+ self.norm = nn.LayerNorm(d_model, eps=1e-6)
330
+ def forward(
331
+ self,
332
+ src,
333
+ ):
334
+ """
335
+ Arguments
336
+ ----------
337
+ src : torch.Tensor
338
+ Tensor shape [B, L, N],
339
+ where, B = Batchsize,
340
+ L = time points
341
+ N = number of filters
342
+ The sequence to the encoder layer (required).
343
347
+ """
348
+ output = self.mossformerM(src)
349
+ output = self.norm(output)
350
+
351
+ return output
352
+
353
+ class MossFormerM2(nn.Module):
354
+ """This class implements the transformer encoder built on the original MossFormer blocks.
355
+
356
+ Arguments
357
+ ---------
358
+ num_blocks : int
359
+ Number of mossformer blocks to include.
360
+ d_model : int
361
+ The dimension of the input embedding.
362
+ attn_dropout : float
363
+ Dropout for the self-attention (Optional).
364
+ group_size: int
365
+ the chunk size
366
+ query_key_dim: int
367
+ the attention vector dimension
368
+ expansion_factor: int
369
+ the expansion factor for the linear projection in conv module
370
+ causal: bool
371
+ true for causal / false for non causal
372
+
373
+ Example
374
+ -------
375
+ >>> import torch
376
+ >>> x = torch.rand((8, 60, 512))
377
+ >>> net = MossFormerM2(num_blocks=8, d_model=512)
378
+ >>> output = net(x)
379
+ >>> output.shape
380
+ torch.Size([8, 60, 512])
381
+ """
382
+ def __init__(
383
+ self,
384
+ num_blocks,
385
+ d_model=None,
386
+ causal=False,
387
+ group_size = 256,
388
+ query_key_dim = 128,
389
+ expansion_factor = 4.,
390
+ attn_dropout = 0.1
391
+ ):
392
+ super().__init__()
393
+
394
+ self.mossformerM = MossformerBlock(
395
+ dim=d_model,
396
+ depth=num_blocks,
397
+ group_size=group_size,
398
+ query_key_dim=query_key_dim,
399
+ expansion_factor=expansion_factor,
400
+ causal=causal,
401
+ attn_dropout=attn_dropout
402
+ )
403
+ self.norm = nn.LayerNorm(d_model, eps=1e-6)
404
+
405
+ def forward(
406
+ self,
407
+ src,
408
+ ):
409
+ """
410
+ Arguments
411
+ ----------
412
+ src : torch.Tensor
413
+ Tensor shape [B, L, N],
414
+ where, B = Batchsize,
415
+ L = time points
416
+ N = number of filters
417
+ The sequence to the encoder layer (required).
418
422
+ """
423
+ output = self.mossformerM(src)
424
+ output = self.norm(output)
425
+
426
+ return output
427
+
428
+ class Computation_Block(nn.Module):
429
+ """Computation block for dual-path processing.
430
+
431
+ Arguments
432
+ ---------
433
+ num_blocks : int
+ Number of MossFormer blocks in the intra model.
+ out_channels : int
+ Dimensionality of the intra model.
+ norm : str
+ Normalization type.
+ skip_around_intra : bool
+ Skip connection around the intra layer.
445
+
446
+ Example
447
+ ---------
448
+ >>> comp_block = Computation_Block(num_blocks=1, out_channels=64)
449
+ >>> x = torch.randn(10, 64, 100)
450
+ >>> x = comp_block(x)
451
+ >>> x.shape
452
+ torch.Size([10, 64, 100])
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ num_blocks,
458
+ out_channels,
459
+ norm="ln",
460
+ skip_around_intra=True,
461
+ ):
462
+ super(Computation_Block, self).__init__()
463
+
464
+ ## MossFormerM: MossFormer2 with recurrence (gated FSMN)
465
+ self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
466
+ ## MossFormerM2: the original MossFormer
467
+ #self.intra_mdl = MossFormerM2(num_blocks=num_blocks, d_model=out_channels)
468
+ self.skip_around_intra = skip_around_intra
469
+
470
+ # Norm
471
+ self.norm = norm
472
+ if norm is not None:
473
+ self.intra_norm = select_norm(norm, out_channels, 3)
474
+
475
+ def forward(self, x):
476
+ """Returns the output tensor.
477
+
478
+ Arguments
479
+ ---------
480
+ x : torch.Tensor
481
+ Input tensor of dimension [B, N, S].
482
+
483
+
484
+ Return
485
+ ---------
486
+ out: torch.Tensor
487
+ Output tensor of dimension [B, N, S].
488
+ where, B = Batchsize,
489
+ N = number of filters
490
+ S = sequence time index
491
+ """
492
+ B, N, S = x.shape
493
+ # intra MossFormer processing
494
+ # [B, S, N]
495
+ intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
496
+
497
+ intra = self.intra_mdl(intra)
498
+
499
+ # [B, N, S]
500
+ intra = intra.permute(0, 2, 1).contiguous()
501
+ if self.norm is not None:
502
+ intra = self.intra_norm(intra)
503
+
504
+ # [B, N, S]
505
+ if self.skip_around_intra:
506
+ intra = intra + x
507
+
508
+ out = intra
509
+ return out
510
+
511
+
512
+ class MossFormer_MaskNet(nn.Module):
513
+ """The MaskNet built on the dual-path framework shared by dual-path RNN, SepFormer, and DPTNet.
514
+
515
+ Arguments
516
+ ---------
517
+ in_channels : int
518
+ Number of channels at the output of the encoder.
519
+ out_channels : int
520
+ Number of channels that would be inputted to the intra and inter blocks.
521
+ out_channels_final : int
+ Number of channels of the decoded output.
+ num_blocks : int
+ Number of MossFormer blocks in the Computation Block.
525
+ norm : str
526
+ Normalization type.
527
+ num_spks : int
528
+ Number of sources (speakers).
529
+ skip_around_intra : bool
530
+ Skip connection around intra.
531
+ use_global_pos_enc : bool
532
+ Global positional encodings.
533
+ max_length : int
534
+ Maximum sequence length.
535
+
536
+ Example
537
+ ---------
538
+ >>> mossformer_masknet = MossFormer_MaskNet(64, 64, 64, num_spks=1)
+ >>> x = torch.randn(10, 64, 2000)
+ >>> x = mossformer_masknet(x)
+ >>> x.shape
+ torch.Size([10, 64, 2000])
544
+ """
545
+
546
+ def __init__(
547
+ self,
548
+ in_channels,
549
+ out_channels,
550
+ out_channels_final,
551
+ num_blocks=24,
552
+ norm="ln",
553
+ num_spks=1,
554
+ skip_around_intra=True,
555
+ use_global_pos_enc=True,
556
+ max_length=20000,
557
+ ):
558
+ super(MossFormer_MaskNet, self).__init__()
559
+ self.num_spks = num_spks
560
+ self.num_blocks = num_blocks
561
+ self.norm = select_norm(norm, in_channels, 3)
562
+ self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
563
+ self.use_global_pos_enc = use_global_pos_enc
564
+
565
+ if self.use_global_pos_enc:
566
+ self.pos_enc = ScaledSinuEmbedding(out_channels)
567
+
568
+ self.mdl = Computation_Block(
569
+ num_blocks,
570
+ out_channels,
571
+ norm,
572
+ skip_around_intra=skip_around_intra,
573
+ )
574
+
575
+ self.conv1d_out = nn.Conv1d(
576
+ out_channels, out_channels * num_spks, kernel_size=1
577
+ )
578
+ self.conv1_decoder = nn.Conv1d(out_channels, out_channels_final, 1, bias=False)
579
+ self.prelu = nn.PReLU()
580
+ self.activation = nn.ReLU()
581
+ # gated output layer
582
+ self.output = nn.Sequential(
583
+ nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
584
+ )
585
+ self.output_gate = nn.Sequential(
586
+ nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
587
+ )
588
+
589
+ def forward(self, x):
590
+ """Returns the output tensor.
591
+
592
+ Arguments
593
+ ---------
594
+ x : torch.Tensor
595
+ Input tensor of dimension [B, N, S].
596
+
597
+ Returns
598
+ -------
599
+ out : torch.Tensor
+ Output tensor of dimension [B, N, S] (the mask for the first source),
+ where B = Batchsize,
+ N = number of filters
+ S = the number of time frames
605
+ """
606
+
607
+ # before each line we indicate the shape after executing the line
608
+
609
+ # [B, N, L]
610
+ x = self.norm(x)
611
+
612
+ # [B, N, L]
613
+ x = self.conv1d_encoder(x)
614
+ if self.use_global_pos_enc:
615
+ #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
616
+ # x.size(1) ** 0.5)
617
+ base = x
618
+ x = x.transpose(1, -1)
619
+ emb = self.pos_enc(x)
620
+ emb = emb.transpose(0, -1)
621
+ #print('base: {}, emb: {}'.format(base.shape, emb.shape))
622
+ x = base + emb
623
+
624
+
625
+ # [B, N, S]
626
+ #for i in range(self.num_modules):
627
+ # x = self.dual_mdl[i](x)
628
+ x = self.mdl(x)
629
+ x = self.prelu(x)
630
+
631
+ # [B, N*spks, S]
632
+ x = self.conv1d_out(x)
633
+ B, _, S = x.shape
634
+
635
+ # [B*spks, N, S]
636
+ x = x.view(B * self.num_spks, -1, S)
637
+
638
+ # [B*spks, N, S]
639
+ x = self.output(x) * self.output_gate(x)
640
+
641
+ # [B*spks, N, S]
642
+ x = self.conv1_decoder(x)
643
+
644
+ # [B, spks, N, S]
645
+ _, N, L = x.shape
646
+ x = x.view(B, self.num_spks, N, L)
647
+ x = self.activation(x)
648
+
649
+ # [spks, B, N, S]
650
+ x = x.transpose(0, 1)
651
+
652
+ return x[0] # [B, N, S]: keep the single estimated source
653
+
654
+ class MossFormer(nn.Module):
655
+ def __init__(
656
+ self,
657
+ in_channels=512,
658
+ out_channels=512,
659
+ num_blocks=24,
660
+ kernel_size=16,
661
+ norm="ln",
662
+ num_spks=2,
663
+ skip_around_intra=True,
664
+ use_global_pos_enc=True,
665
+ max_length=20000,
666
+ ):
667
+ super(MossFormer, self).__init__()
668
+ self.num_spks = num_spks
669
+ self.enc = Encoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=180)
670
+ self.mask_net = MossFormer_MaskNet(
671
+ in_channels=in_channels,
672
+ out_channels=out_channels,
673
+ num_blocks=num_blocks,
674
+ norm=norm,
675
+ num_spks=num_spks,
676
+ skip_around_intra=skip_around_intra,
677
+ use_global_pos_enc=use_global_pos_enc,
678
+ max_length=max_length,
679
+ )
680
+ self.dec = Decoder(
681
+ in_channels=out_channels,
682
+ out_channels=1,
683
+ kernel_size=kernel_size,
684
+ stride = kernel_size//2,
685
+ bias=False
686
+ )
687
+ def forward(self, input):
688
+ x = self.enc(input)
689
+ mask = self.mask_net(x)
690
+ x = torch.stack([x] * self.num_spks)
691
+ sep_x = x * mask
692
+
693
+ # Decoding
694
+ est_source = torch.cat(
695
+ [
696
+ self.dec(sep_x[i]).unsqueeze(-1)
697
+ for i in range(self.num_spks)
698
+ ],
699
+ dim=-1,
700
+ )
701
+ T_origin = input.size(1)
702
+ T_est = est_source.size(1)
703
+ if T_origin > T_est:
704
+ est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
705
+ else:
706
+ est_source = est_source[:, :T_origin, :]
707
+
708
+ out = []
709
+ for spk in range(self.num_spks):
710
+ out.append(est_source[:,:,spk])
711
+ return out
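+ # Usage sketch (hypothetical shapes; the Encoder above expects 180 input channels):
+ # >>> model = MossFormer(num_spks=2)
+ # >>> mix = torch.randn(2, 180, 1000)
+ # >>> sources = model(mix) # list with num_spks separated estimates
+ # >>> len(sources)
+ # 2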
models/mossformer2_sr/mossformer2_block.py ADDED
@@ -0,0 +1,735 @@
1
+ """
2
+ This source code is modified by Shengkui Zhao based on https://github.com/lucidrains/FLASH-pytorch
3
+ """
4
+
5
+ import math
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn, einsum
9
+ from einops import rearrange
10
+ from rotary_embedding_torch import RotaryEmbedding
11
+ from models.mossformer2_sr.conv_module import ConvModule, GLU, FFConvM_Dilated
12
+ from models.mossformer2_sr.fsmn import UniDeepFsmn, UniDeepFsmn_dilated
13
+ from torchinfo import summary
14
+ from models.mossformer2_sr.layer_norm import CLayerNorm, GLayerNorm, GlobLayerNorm, ILayerNorm
15
+
16
+ # Helper functions
17
+
18
+ def identity(t, *args, **kwargs):
19
+ """
20
+ Returns the input tensor unchanged.
21
+
22
+ Args:
23
+ t (torch.Tensor): Input tensor.
24
+ *args: Additional arguments (ignored).
25
+ **kwargs: Additional keyword arguments (ignored).
26
+
27
+ Returns:
28
+ torch.Tensor: The input tensor.
29
+ """
30
+ return t
31
+
32
+ def append_dims(x, num_dims):
33
+ """
34
+ Adds additional dimensions to the input tensor.
35
+
36
+ Args:
37
+ x (torch.Tensor): Input tensor.
38
+ num_dims (int): Number of dimensions to append.
39
+
40
+ Returns:
41
+ torch.Tensor: Tensor with appended dimensions.
42
+ """
43
+ if num_dims <= 0:
44
+ return x
45
+ return x.view(*x.shape, *((1,) * num_dims)) # Reshape to append dimensions
46
+
47
+ def exists(val):
48
+ """
49
+ Checks if a value exists (is not None).
50
+
51
+ Args:
52
+ val: The value to check.
53
+
54
+ Returns:
55
+ bool: True if value exists, False otherwise.
56
+ """
57
+ return val is not None
58
+
59
+ def default(val, d):
60
+ """
61
+ Returns a default value if the given value does not exist.
62
+
63
+ Args:
64
+ val: The value to check.
65
+ d: Default value to return if val does not exist.
66
+
67
+ Returns:
68
+ The original value if it exists, otherwise the default value.
69
+ """
70
+ return val if exists(val) else d
71
+
72
+ def padding_to_multiple_of(n, mult):
73
+ """
74
+ Calculates the amount of padding needed to make a number a multiple of another.
75
+
76
+ Args:
77
+ n (int): The number to pad.
78
+ mult (int): The multiple to match.
79
+
80
+ Returns:
81
+ int: The padding amount required to make n a multiple of mult.
82
+ """
83
+ remainder = n % mult
84
+ if remainder == 0:
85
+ return 0
86
+ return mult - remainder # Return the required padding
87
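+ # For example, padding_to_multiple_of(1000, 256) == 24: a 1000-step sequence is
+ # padded to 1024 so it can be split evenly into groups of 256.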
+
88
+ # Scale Normalization class
89
+
90
+ class ScaleNorm(nn.Module):
91
+ """
92
+ ScaleNorm implements a scaled normalization technique for neural network layers.
93
+
94
+ Attributes:
95
+ scale (float): Fixed scaling factor, dim ** -0.5.
96
+ eps (float): Small value to prevent division by zero.
97
+ g (nn.Parameter): Learnable parameter for scaling.
98
+ """
99
+
100
+ def __init__(self, dim, eps=1e-5):
101
+ super().__init__()
102
+ self.scale = dim ** -0.5 # Calculate scale factor
103
+ self.eps = eps # Set epsilon
104
+ self.g = nn.Parameter(torch.ones(1)) # Initialize scaling parameter
105
+
106
+ def forward(self, x):
107
+ """
108
+ Forward pass for the ScaleNorm layer.
109
+
110
+ Args:
111
+ x (torch.Tensor): Input tensor.
112
+
113
+ Returns:
114
+ torch.Tensor: Scaled and normalized output tensor.
115
+ """
116
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale # Compute norm
117
+ return x / norm.clamp(min=self.eps) * self.g # Normalize and scale
118
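+ # In effect: y = g * x / max(||x||_2 * dim ** -0.5, eps), so the per-token L2
+ # norm is replaced by a single learned scale g.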
+
119
+ # Absolute positional encodings class
120
+
121
+ class ScaledSinuEmbedding(nn.Module):
122
+ """
123
+ ScaledSinuEmbedding provides sinusoidal positional encodings for inputs.
124
+
125
+ Attributes:
126
+ scale (nn.Parameter): Learnable scale factor for the embeddings.
127
+ inv_freq (torch.Tensor): Inverse frequency used for sine and cosine calculations.
128
+ """
129
+
130
+ def __init__(self, dim):
131
+ super().__init__()
132
+ self.scale = nn.Parameter(torch.ones(1,)) # Initialize scale
133
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) # Calculate inverse frequency
134
+ self.register_buffer('inv_freq', inv_freq) # Register as a buffer
135
+
136
+ def forward(self, x):
137
+ """
138
+ Forward pass for the ScaledSinuEmbedding layer.
139
+
140
+ Args:
141
+ x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).
142
+
143
+ Returns:
144
+ torch.Tensor: Positional encoding tensor of shape (batch_size, sequence_length, dim).
145
+ """
146
+ n, device = x.shape[1], x.device # Extract sequence length and device
147
+ t = torch.arange(n, device=device).type_as(self.inv_freq) # Create time steps
148
+ sinu = einsum('i , j -> i j', t, self.inv_freq) # Calculate sine and cosine embeddings
149
+ emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1) # Concatenate sine and cosine embeddings
150
+ return emb * self.scale # Scale the embeddings
151
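+ # Shape sketch (illustrative): the encoding depends only on the sequence length,
+ # so it broadcasts over the batch dimension.
+ # >>> pos = ScaledSinuEmbedding(64)
+ # >>> pos(torch.zeros(2, 100, 64)).shape
+ # torch.Size([100, 64])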
+
152
+ class OffsetScale(nn.Module):
153
+ """
154
+ OffsetScale applies learned offsets and scales to the input tensor.
155
+
156
+ Attributes:
157
+ gamma (nn.Parameter): Learnable scale parameter for each head.
158
+ beta (nn.Parameter): Learnable offset parameter for each head.
159
+ """
160
+
161
+ def __init__(self, dim, heads=1):
162
+ super().__init__()
163
+ self.gamma = nn.Parameter(torch.ones(heads, dim)) # Initialize scale parameters
164
+ self.beta = nn.Parameter(torch.zeros(heads, dim)) # Initialize offset parameters
165
+ nn.init.normal_(self.gamma, std=0.02) # Normal initialization for gamma
166
+
167
+ def forward(self, x):
168
+ """
169
+ Forward pass for the OffsetScale layer.
170
+
171
+ Args:
172
+ x (torch.Tensor): Input tensor.
173
+
174
+ Returns:
175
+ List[torch.Tensor]: A list of tensors with applied offsets and scales for each head.
176
+ """
177
+ out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta # Apply scaling and offsets
178
+ return out.unbind(dim=-2) # Unbind heads into a list
179
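+ # Usage sketch: with heads=4, one shared projection is cheaply re-styled into the
+ # four attention views used below (quad_q, lin_q, quad_k, lin_k).
+ # >>> qk = torch.randn(2, 100, 128)
+ # >>> quad_q, lin_q, quad_k, lin_k = OffsetScale(128, heads=4)(qk)
+ # >>> quad_q.shape
+ # torch.Size([2, 100, 128])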
+
180
+ # Feed-Forward Convolutional Module
181
+
182
+ class FFConvM(nn.Module):
183
+ """
184
+ FFConvM is a feed-forward convolutional module with normalization and dropout.
185
+
186
+ Attributes:
187
+ dim_in (int): Input dimension of the features.
188
+ dim_out (int): Output dimension after processing.
189
+ norm_klass (nn.Module): Normalization class to be used.
190
+ dropout (float): Dropout probability.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ dim_in,
196
+ dim_out,
197
+ norm_klass=nn.LayerNorm,
198
+ dropout=0.1
199
+ ):
200
+ super().__init__()
201
+ self.mdl = nn.Sequential(
202
+ norm_klass(dim_in), # Normalize input
203
+ nn.Linear(dim_in, dim_out), # Linear transformation
204
+ nn.SiLU(), # Activation function
205
+ ConvModule(dim_out), # Convolution module
206
+ nn.Dropout(dropout) # Apply dropout
207
+ )
208
+
209
+ def forward(self, x):
210
+ """
211
+ Forward pass for the FFConvM module.
212
+
213
+ Args:
214
+ x (torch.Tensor): Input tensor.
215
+
216
+ Returns:
217
+ torch.Tensor: Output tensor after processing.
218
+ """
219
+ output = self.mdl(x) # Pass through the model
220
+ return output
221
+
222
+ class FFM(nn.Module):
223
+ """
224
+ FFM is a feed-forward module with normalization and dropout.
225
+
226
+ Attributes:
227
+ dim_in (int): Input dimension of the features.
228
+ dim_out (int): Output dimension after processing.
229
+ norm_klass (nn.Module): Normalization class to be used.
230
+ dropout (float): Dropout probability.
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ dim_in,
236
+ dim_out,
237
+ norm_klass=nn.LayerNorm,
238
+ dropout=0.1
239
+ ):
240
+ super().__init__()
241
+ self.mdl = nn.Sequential(
242
+ norm_klass(dim_in), # Normalize input
243
+ nn.Linear(dim_in, dim_out), # Linear transformation
244
+ nn.SiLU(), # Activation function
245
+ nn.Dropout(dropout) # Apply dropout
246
+ )
247
+
248
+ def forward(self, x):
249
+ """
250
+ Forward pass for the FFM module.
251
+
252
+ Args:
253
+ x (torch.Tensor): Input tensor.
254
+
255
+ Returns:
256
+ torch.Tensor: Output tensor after processing.
257
+ """
258
+ output = self.mdl(x) # Pass through the model
259
+ return output
260
+
261
+ class FLASH_ShareA_FFConvM(nn.Module):
262
+ """
263
+ Fast Shared Dual Attention Mechanism with feed-forward convolutional blocks.
264
+ Published in paper: "MossFormer: Pushing the Performance Limit of Monaural Speech Separation
265
+ using Gated Single-Head Transformer with Convolution-Augmented Joint Self-Attentions", ICASSP 2023.
266
+ (https://arxiv.org/abs/2302.11824)
267
+
268
+ Args:
269
+ dim (int): Input dimension.
270
+ group_size (int, optional): Size of groups for processing. Defaults to 256.
271
+ query_key_dim (int, optional): Dimension of the query and key. Defaults to 128.
272
+ expansion_factor (float, optional): Factor to expand the hidden dimension. Defaults to 1.
273
+ causal (bool, optional): Whether to use causal masking. Defaults to False.
274
+ dropout (float, optional): Dropout rate. Defaults to 0.1.
275
+ rotary_pos_emb (optional): Rotary positional embeddings for attention. Defaults to None.
276
+ norm_klass (callable, optional): Normalization class to use. Defaults to nn.LayerNorm.
277
+ shift_tokens (bool, optional): Whether to shift tokens for attention calculation. Defaults to True.
278
+ """
279
+
280
+ def __init__(
281
+ self,
282
+ *,
283
+ dim,
284
+ group_size=256,
285
+ query_key_dim=128,
286
+ expansion_factor=1.,
287
+ causal=False,
288
+ dropout=0.1,
289
+ rotary_pos_emb=None,
290
+ norm_klass=nn.LayerNorm,
291
+ shift_tokens=True
292
+ ):
293
+ super().__init__()
294
+ hidden_dim = int(dim * expansion_factor)
295
+ self.group_size = group_size
296
+ self.causal = causal
297
+ self.shift_tokens = shift_tokens
298
+
299
+ # Initialize positional embeddings, dropout, and projections
300
+ self.rotary_pos_emb = rotary_pos_emb
301
+ self.dropout = nn.Dropout(dropout)
302
+
303
+ # Feed-forward layers
304
+ self.to_hidden = FFConvM(
305
+ dim_in=dim,
306
+ dim_out=hidden_dim,
307
+ norm_klass=norm_klass,
308
+ dropout=dropout,
309
+ )
310
+ self.to_qk = FFConvM(
311
+ dim_in=dim,
312
+ dim_out=query_key_dim,
313
+ norm_klass=norm_klass,
314
+ dropout=dropout,
315
+ )
316
+
317
+ # Offset and scale for query and key
318
+ self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
319
+
320
+ self.to_out = FFConvM(
321
+ dim_in=dim * 2,
322
+ dim_out=dim,
323
+ norm_klass=norm_klass,
324
+ dropout=dropout,
325
+ )
326
+
327
+ self.gateActivate = nn.Sigmoid()
328
+
329
+ def forward(self, x, *, mask=None):
330
+ """
331
+ Forward pass for FLASH layer.
332
+
333
+ Args:
334
+ x (Tensor): Input tensor of shape (batch, seq_len, features).
335
+ mask (Tensor, optional): Mask for attention. Defaults to None.
336
+
337
+ Returns:
338
+ Tensor: Output tensor after applying attention and projections.
339
+ """
340
+
341
+ # Normalization happens inside the FFConvM projections below
342
+ normed_x = x
343
+ residual = x # Save residual for skip connection
344
+
345
+ # Token shifting if enabled
346
+ if self.shift_tokens:
347
+ x_shift, x_pass = normed_x.chunk(2, dim=-1)
348
+ x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
349
+ normed_x = torch.cat((x_shift, x_pass), dim=-1)
350
+
351
+ # Initial projections
352
+ v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
353
+ qk = self.to_qk(normed_x)
354
+
355
+ # Offset and scale
356
+ quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
357
+ att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
358
+
359
+ # Output calculation with gating
360
+ out = (att_u * v) * self.gateActivate(att_v * u)
361
+ x = x + self.to_out(out) # Residual connection
362
+ return x
363
+
364
+ def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
365
+ """
366
+ Calculate attention output using quadratic and linear attention mechanisms.
367
+
368
+ Args:
369
+ x (Tensor): Input tensor of shape (batch, seq_len, features).
370
+ quad_q (Tensor): Quadratic query representation.
371
+ lin_q (Tensor): Linear query representation.
372
+ quad_k (Tensor): Quadratic key representation.
373
+ lin_k (Tensor): Linear key representation.
374
+ v (Tensor): Value representation.
375
+ u (Tensor): Additional value representation.
376
+ mask (Tensor, optional): Mask for attention. Defaults to None.
377
+
378
+ Returns:
379
+ Tuple[Tensor, Tensor]: Attention outputs for v and u.
380
+ """
381
+ b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
382
+
383
+ # Apply mask to linear keys if provided
384
+ if exists(mask):
385
+ lin_mask = rearrange(mask, '... -> ... 1')
386
+ lin_k = lin_k.masked_fill(~lin_mask, 0.)
387
+
388
+ # Rotate queries and keys with rotary positional embeddings
389
+ if exists(self.rotary_pos_emb):
390
+ quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
391
+
392
+ # Padding for group processing
393
+ padding = padding_to_multiple_of(n, g)
394
+ if padding > 0:
395
+ quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value=0.), (quad_q, quad_k, lin_q, lin_k, v, u))
396
+ mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
397
+ mask = F.pad(mask, (0, padding), value=False)
398
+
399
+ # Group along sequence for attention
400
+ quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
401
+
402
+ if exists(mask):
403
+ mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g)
404
+
405
+ # Calculate quadratic attention output
406
+ sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
407
+ attn = F.relu(sim) ** 2 # ReLU activation
408
+ attn = self.dropout(attn)
409
+
410
+ # Apply mask to attention if provided
411
+ if exists(mask):
412
+ attn = attn.masked_fill(~mask, 0.)
413
+
414
+ # Apply causal mask if needed
415
+ if self.causal:
416
+ causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
417
+ attn = attn.masked_fill(causal_mask, 0.)
418
+
419
+ # Calculate output from attention
420
+ quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
421
+ quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
422
+
423
+ # Calculate linear attention output
424
+ if self.causal:
425
+ lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
426
+ lin_kv = lin_kv.cumsum(dim=1) # Cumulative sum for linear attention
427
+ lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
428
+ lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
429
+
430
+ lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
431
+ lin_ku = lin_ku.cumsum(dim=1) # Cumulative sum for linear attention
432
+ lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
433
+ lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
434
+ else:
435
+ lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
436
+ lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
437
+
438
+ lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
439
+ lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
440
+
441
+ # Reshape and remove padding from outputs
442
+ return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v + lin_out_v, quad_out_u + lin_out_u))
443
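+ # Shape sketch for the layer above (illustrative; expansion_factor=4. keeps the
+ # hidden split consistent with to_out's dim * 2 input):
+ # >>> layer = FLASH_ShareA_FFConvM(dim=512, group_size=256, query_key_dim=128, expansion_factor=4.)
+ # >>> x = torch.randn(1, 1000, 512)
+ # >>> layer(x).shape # quadratic attention within 256-step groups, linear across them
+ # torch.Size([1, 1000, 512])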
+
444
+ class Gated_FSMN(nn.Module):
445
+ """
446
+ Gated Feedforward Sequential Memory Network (FSMN) class.
447
+
448
+ This class implements a gated FSMN that combines two feedforward
449
+ convolutional networks with a feedforward sequential memory module.
450
+
451
+ Args:
452
+ in_channels (int): Number of input channels.
453
+ out_channels (int): Number of output channels.
454
+ lorder (int): Order of the filter for FSMN.
455
+ hidden_size (int): Number of hidden units in the network.
456
+ """
457
+ def __init__(self, in_channels, out_channels, lorder, hidden_size):
458
+ super().__init__()
459
+ # Feedforward network for the first branch (u)
460
+ self.to_u = FFConvM(
461
+ dim_in=in_channels,
462
+ dim_out=hidden_size,
463
+ norm_klass=nn.LayerNorm,
464
+ dropout=0.1,
465
+ )
466
+ # Feedforward network for the second branch (v)
467
+ self.to_v = FFConvM(
468
+ dim_in=in_channels,
469
+ dim_out=hidden_size,
470
+ norm_klass=nn.LayerNorm,
471
+ dropout=0.1,
472
+ )
473
+ # Frequency selective memory network
474
+ self.fsmn = UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
475
+
476
+ def forward(self, x):
477
+ """
478
+ Forward pass for the Gated FSMN.
479
+
480
+ Args:
481
+ x (Tensor): Input tensor of shape (batch_size, sequence_length, in_channels).
482
+
483
+ Returns:
484
+ Tensor: Output tensor after applying gated FSMN operations.
485
+ """
486
+ input = x
487
+ x_u = self.to_u(x) # Process input through the first branch
488
+ x_v = self.to_v(x) # Process input through the second branch
489
+ x_u = self.fsmn(x_u) # Apply FSMN to the output of the first branch
490
+ x = x_v * x_u + input # Combine outputs with the original input
491
+ return x
492
+
493
+
494
+ class Gated_FSMN_Block(nn.Module):
495
+ """
496
+ A 1-D convolutional block that incorporates a gated FSMN.
497
+
498
+ This block consists of two convolutional layers, followed by a
499
+ gated FSMN and normalization layers.
500
+
501
+ Args:
502
+ dim (int): Dimensionality of the input.
503
+ inner_channels (int): Number of channels in the inner layers.
504
+ group_size (int): Size of the groups for normalization.
505
+ norm_type (str): Type of normalization to use ('scalenorm' or 'layernorm').
506
+ """
507
+ def __init__(self, dim, inner_channels=256, group_size=256, norm_type='scalenorm'):
508
+ super(Gated_FSMN_Block, self).__init__()
509
+ # Choose normalization class based on the provided type
510
+ if norm_type == 'scalenorm':
511
+ norm_klass = ScaleNorm
512
+ elif norm_type == 'layernorm':
513
+ norm_klass = nn.LayerNorm
514
+
515
+ self.group_size = group_size
516
+
517
+ # First convolutional layer with PReLU activation
518
+ self.conv1 = nn.Sequential(
519
+ nn.Conv1d(dim, inner_channels, kernel_size=1),
520
+ nn.PReLU(),
521
+ )
522
+ self.norm1 = CLayerNorm(inner_channels) # Normalization after first convolution
523
+ self.gated_fsmn = Gated_FSMN(inner_channels, inner_channels, lorder=20, hidden_size=inner_channels) # Gated FSMN layer
524
+ self.norm2 = CLayerNorm(inner_channels) # Normalization after FSMN
525
+ self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1) # Final convolutional layer
526
+
527
+ def forward(self, input):
528
+ """
529
+ Forward pass for the Gated FSMN Block.
530
+
531
+ Args:
532
+ input (Tensor): Input tensor of shape (batch_size, sequence_length, dim).
533
+
534
+ Returns:
535
+ Tensor: Output tensor after processing through the block.
536
+ """
537
+ conv1 = self.conv1(input.transpose(2, 1)) # Apply first convolution
538
+ norm1 = self.norm1(conv1) # Apply normalization
539
+ seq_out = self.gated_fsmn(norm1.transpose(2, 1)) # Apply gated FSMN
540
+ norm2 = self.norm2(seq_out.transpose(2, 1)) # Apply second normalization
541
+ conv2 = self.conv2(norm2) # Apply final convolution
542
+ return conv2.transpose(2, 1) + input # Residual connection
543
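+ # Usage sketch (illustrative): the block is shape-preserving on [B, S, D] input.
+ # >>> blk = Gated_FSMN_Block(dim=512)
+ # >>> x = torch.randn(2, 100, 512)
+ # >>> blk(x).shape
+ # torch.Size([2, 100, 512])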
+
544
+
545
+ class MossformerBlock_GFSMN(nn.Module):
546
+ """
547
+ Mossformer Block with Gated FSMN.
548
+
549
+ This block combines attention mechanisms and gated FSMN layers
550
+ to process input sequences.
551
+
552
+ Args:
553
+ dim (int): Dimensionality of the input.
554
+ depth (int): Number of layers in the block.
555
+ group_size (int): Size of the groups for normalization.
556
+ query_key_dim (int): Dimension of the query and key in attention.
557
+ expansion_factor (float): Expansion factor for feedforward layers.
558
+ causal (bool): If True, enables causal attention.
559
+ attn_dropout (float): Dropout rate for attention layers.
560
+ norm_type (str): Type of normalization to use ('scalenorm' or 'layernorm').
561
+ shift_tokens (bool): If True, shifts tokens in the attention layer.
562
+ """
563
+ def __init__(self, *, dim, depth, group_size=256, query_key_dim=128, expansion_factor=4., causal=False, attn_dropout=0.1, norm_type='scalenorm', shift_tokens=True):
564
+ super().__init__()
565
+ assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
566
+
567
+ if norm_type == 'scalenorm':
568
+ norm_klass = ScaleNorm
569
+ elif norm_type == 'layernorm':
570
+ norm_klass = nn.LayerNorm
571
+
572
+ self.group_size = group_size
573
+
574
+ # Rotary positional embedding for attention
575
+ rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
576
+
577
+ # Create a list of Gated FSMN blocks
578
+ self.fsmn = nn.ModuleList([Gated_FSMN_Block(dim) for _ in range(depth)])
579
+
580
+ # Create a list of attention layers using FLASH_ShareA_FFConvM
581
+ self.layers = nn.ModuleList([
582
+ FLASH_ShareA_FFConvM(
583
+ dim=dim,
584
+ group_size=group_size,
585
+ query_key_dim=query_key_dim,
586
+ expansion_factor=expansion_factor,
587
+ causal=causal,
588
+ dropout=attn_dropout,
589
+ rotary_pos_emb=rotary_pos_emb,
590
+ norm_klass=norm_klass,
591
+ shift_tokens=shift_tokens
592
+ ) for _ in range(depth)
593
+ ])
594
+
595
+ def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
596
+ """
597
+ Builds repeated UniDeep FSMN layers.
598
+
599
+ Args:
600
+ in_channels (int): Number of input channels.
601
+ out_channels (int): Number of output channels.
602
+ lorder (int): Order of the filter for FSMN.
603
+ hidden_size (int): Number of hidden units.
604
+ repeats (int): Number of repetitions.
605
+
606
+ Returns:
607
+ Sequential: A sequential container with repeated layers.
608
+ """
609
+ repeats = [
610
+ UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
611
+ for i in range(repeats)
612
+ ]
613
+ return nn.Sequential(*repeats)
614
+
615
+ def forward(self, x, *, mask=None):
616
+ """
617
+ Forward pass for the Mossformer Block with Gated FSMN.
618
+
619
+ Args:
620
+ x (Tensor): Input tensor of shape (batch_size, sequence_length, dim).
621
+ mask (Tensor, optional): Mask tensor for attention operations.
622
+
623
+ Returns:
624
+ Tensor: Output tensor after processing through the block.
625
+ """
626
+ for flash, fsmn in zip(self.layers, self.fsmn): # Process through each attention/FSMN pair
+ x = flash(x, mask=mask) # Attention layer
+ x = fsmn(x) # Apply the corresponding Gated FSMN block
631
+
632
+ return x
633
+
634
+
635
+ class MossformerBlock(nn.Module):
636
+ """
637
+ Mossformer Block with attention mechanisms.
638
+
639
+ This block is designed to process input sequences using attention
640
+ layers and incorporates rotary positional embeddings. It allows
641
+ for configurable normalization types and can handle causal
642
+ attention.
643
+
644
+ Args:
645
+ dim (int): Dimensionality of the input.
646
+ depth (int): Number of attention layers in the block.
647
+ group_size (int, optional): Size of groups for normalization. Default is 256.
648
+ query_key_dim (int, optional): Dimension of the query and key in attention. Default is 128.
649
+ expansion_factor (float, optional): Expansion factor for feedforward layers. Default is 4.
650
+ causal (bool, optional): If True, enables causal attention. Default is False.
651
+ attn_dropout (float, optional): Dropout rate for attention layers. Default is 0.1.
652
+ norm_type (str, optional): Type of normalization to use ('scalenorm' or 'layernorm'). Default is 'scalenorm'.
653
+ shift_tokens (bool, optional): If True, shifts tokens in the attention layer. Default is True.
654
+ """
655
+ def __init__(
656
+ self,
657
+ *,
658
+ dim,
659
+ depth,
660
+ group_size=256,
661
+ query_key_dim=128,
662
+ expansion_factor=4.0,
663
+ causal=False,
664
+ attn_dropout=0.1,
665
+ norm_type='scalenorm',
666
+ shift_tokens=True
667
+ ):
668
+ super().__init__()
669
+
670
+ # Ensure normalization type is valid
671
+ assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
672
+
673
+ # Select normalization class based on the provided type
674
+ if norm_type == 'scalenorm':
675
+ norm_klass = ScaleNorm
676
+ elif norm_type == 'layernorm':
677
+ norm_klass = nn.LayerNorm
678
+
679
+ self.group_size = group_size # Group size for normalization
680
+
681
+ # Rotary positional embedding for attention
682
+ rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
683
+ # Max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
684
+
685
+ # Create a list of attention layers using FLASH_ShareA_FFConvM
686
+ self.layers = nn.ModuleList([
687
+ FLASH_ShareA_FFConvM(
688
+ dim=dim,
689
+ group_size=group_size,
690
+ query_key_dim=query_key_dim,
691
+ expansion_factor=expansion_factor,
692
+ causal=causal,
693
+ dropout=attn_dropout,
694
+ rotary_pos_emb=rotary_pos_emb,
695
+ norm_klass=norm_klass,
696
+ shift_tokens=shift_tokens
697
+ ) for _ in range(depth)
698
+ ])
699
+
700
+ def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
701
+ """
702
+ Builds repeated UniDeep FSMN layers.
703
+
704
+ Args:
705
+ in_channels (int): Number of input channels.
706
+ out_channels (int): Number of output channels.
707
+ lorder (int): Order of the filter for FSMN.
708
+ hidden_size (int): Number of hidden units.
709
+ repeats (int, optional): Number of repetitions. Default is 1.
710
+
711
+ Returns:
712
+ Sequential: A sequential container with repeated layers.
713
+ """
714
+ repeats = [
715
+ UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
716
+ for _ in range(repeats)
717
+ ]
718
+ return nn.Sequential(*repeats)
719
+
720
+ def forward(self, x, *, mask=None):
721
+ """
722
+ Forward pass for the Mossformer Block.
723
+
724
+ Args:
725
+ x (Tensor): Input tensor of shape (batch_size, sequence_length, dim).
726
+ mask (Tensor, optional): Mask tensor for attention operations.
727
+
728
+ Returns:
729
+ Tensor: Output tensor after processing through the block.
730
+ """
731
+ # Process input through each attention layer
732
+ for flash in self.layers:
733
+ x = flash(x, mask=mask) # Apply attention layer with optional mask
734
+
735
+ return x # Return the final output tensor
models/mossformer2_sr/mossformer2_sr_wrapper.py ADDED
@@ -0,0 +1,52 @@
1
+ from models.mossformer2_sr.generator import Mossformer, Generator
2
+ import torch.nn as nn
3
+
4
+ class MossFormer2_SR_48K(nn.Module):
5
+ """
6
+ The MossFormer2_SR_48K model for speech super-resolution.
7
+
8
+ This class encapsulates the functionality of the MossFormer2 module and HiFi-GAN
9
+ Generator within a higher-level model. It processes input audio data to produce
10
+ higher-resolution outputs.
11
+
12
+ Arguments
13
+ ---------
14
+ args : Namespace
15
+ Configuration arguments that may include hyperparameters
16
+ and model settings (passed through to the waveform Generator).
18
+
19
+ Example
20
+ ---------
21
+ >>> model = MossFormer2_SR_48K(args)
22
+ >>> x = torch.randn(10, 180, 2000) # Example input
23
+ >>> outputs = model(x) # Forward pass
24
+ >>> outputs.shape # Check output shape
25
+ """
26
+
27
+ def __init__(self, args):
28
+ super(MossFormer2_SR_48K, self).__init__()
29
+ # Initialize the MossFormer front-end, which contains the MossFormer MaskNet
+ self.model_m = Mossformer() # MossFormer feature model
+ self.model_g = Generator(args) # HiFi-GAN style waveform generator
32
+
33
+ def forward(self, x):
34
+ """
35
+ Forward pass through the model.
36
+
37
+ Arguments
38
+ ---------
39
+ x : torch.Tensor
40
+ Input tensor of dimension [B, N, S], where B is the batch size,
41
+ N is the number of mel bins (180 in this case), and S is the
42
+ sequence length (e.g., time frames).
43
+
44
+ Returns
45
+ -------
46
+ outputs : torch.Tensor
47
+ Bandwidth expanded audio output tensor from the model.
48
+
49
+ """
50
+ x = self.model_m(x) # MossFormer front-end processing
+ outputs = self.model_g(x) # Synthesize the bandwidth-expanded waveform
+ return outputs # Return the outputs
models/mossformer2_sr/snake.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings a roughly 1.4x model speed-up
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
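+ # Snake computes x + (1/alpha) * sin(alpha * x)^2 per channel, a periodic
+ # activation used in BigVGAN-style vocoders; it preserves the input shape.
+ # >>> act = Snake1d(channels=64)
+ # >>> act(torch.randn(2, 64, 100)).shape
+ # torch.Size([2, 64, 100])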
models/mossformer2_sr/utils.py ADDED
@@ -0,0 +1,37 @@
1
+ import glob
2
+ import os
3
+ import torch
4
+ from torch.nn.utils import weight_norm
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+ def apply_weight_norm(m):
12
+ classname = m.__class__.__name__
13
+ if classname.find("Conv") != -1:
14
+ weight_norm(m)
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size*dilation - dilation)/2)
18
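+ # For a stride-1 convolution this gives "same" padding, e.g. get_padding(7) == 3
+ # and get_padding(7, dilation=2) == 6.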
+
19
+ def load_checkpoint(filepath, device):
20
+ assert os.path.isfile(filepath)
21
+ print("Loading '{}'".format(filepath))
22
+ checkpoint_dict = torch.load(filepath, map_location=device)
23
+ print("Complete.")
24
+ return checkpoint_dict
25
+
26
+ def save_checkpoint(filepath, obj):
27
+ print("Saving checkpoint to {}".format(filepath))
28
+ torch.save(obj, filepath)
29
+ print("Complete.")
30
+
31
+ def scan_checkpoint(cp_dir, prefix):
32
+ pattern = os.path.join(cp_dir, prefix + '????????')
33
+ cp_list = glob.glob(pattern)
34
+ if len(cp_list) == 0:
35
+ return None
36
+ return sorted(cp_list)[-1]
37
+
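+ # Usage sketch (hypothetical paths): with 'g_00090000' and 'g_00095000' present in
+ # cp_dir, scan_checkpoint(cp_dir, 'g_') returns the latest one ('.../g_00095000'),
+ # since the glob pattern matches the 8-digit zero-padded step suffix.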