alibabasglab committed on
Commit 0adb99d · verified · 1 Parent(s): b89149b

Upload networks.py
Files changed (1)
  1. networks.py +613 -0
networks.py ADDED
@@ -0,0 +1,613 @@
# Copyright (c) 2024 Alibaba Inc (authors: Shengkui Zhao, Zexu Pan)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import soundfile as sf
import os
import subprocess
import librosa
from tqdm import tqdm
import numpy as np
from pydub import AudioSegment
from utils.decode import decode_one_audio
from dataloader.dataloader import DataReader

MAX_WAV_VALUE = 32768.0

class SpeechModel:
    """
    The SpeechModel class is a base class designed to handle speech processing tasks,
    such as loading, processing, and decoding audio data. It initializes the computational
    device (CPU or GPU) and holds model-related attributes. The class is flexible and intended
    to be extended by specific speech models for tasks like speech enhancement, speech separation,
    and target speaker extraction.

    Attributes:
    - args: Argument parser object that contains configuration settings.
    - device: The device (CPU or GPU) on which the model will run.
    - model: The actual model used for speech processing tasks (to be loaded by subclasses).
    - name: A placeholder for the model's name.
    - data: A dictionary to store any additional data related to the model, such as audio input.
    """

    def __init__(self, args):
        """
        Initializes the SpeechModel class by determining the computation device
        (GPU or CPU) to be used for running the model, based on system availability.

        Args:
        - args: Argument parser object containing settings like whether to use CUDA (GPU) or not.
        """
        # Check if a GPU is available
        if torch.cuda.is_available():
            # Find the GPU with the most free memory using a custom method
            free_gpu_id = self.get_free_gpu()
            if free_gpu_id is not None:
                args.use_cuda = 1
                torch.cuda.set_device(free_gpu_id)
                self.device = torch.device('cuda')
            else:
                # If no GPU is detected, use the CPU
                #print("No GPU found. Using CPU.")
                args.use_cuda = 0
                self.device = torch.device('cpu')
        else:
            # If no GPU is detected, use the CPU
            args.use_cuda = 0
            self.device = torch.device('cpu')

        self.args = args
        self.model = None
        self.name = None
        self.data = {}
        self.print = False
    def get_free_gpu(self):
        """
        Identifies the GPU with the most free memory using 'nvidia-smi' and returns its index.

        This function queries the available GPUs on the system and determines which one has
        the highest amount of free memory. It uses the `nvidia-smi` command-line tool to gather
        GPU memory usage data. If successful, it returns the index of the GPU with the most free memory.
        If the query fails or an error occurs, it returns None.

        Returns:
        int: Index of the GPU with the most free memory, or None if no GPU is found or an error occurs.
        """
        try:
            # Run nvidia-smi to query GPU memory usage and free memory
            result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'], stdout=subprocess.PIPE)
            gpu_info = result.stdout.decode('utf-8').strip().split('\n')

            free_gpu = None
            max_free_memory = 0
            for i, info in enumerate(gpu_info):
                used, free = map(int, info.split(','))
                if free > max_free_memory:
                    max_free_memory = free
                    free_gpu = i
            return free_gpu
        except Exception as e:
            print(f"Error finding free GPU: {e}")
            return None
    def download_model(self, model_name):
        checkpoint_dir = self.args.checkpoint_dir
        from huggingface_hub import snapshot_download
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        print(f'downloading checkpoint for {model_name}')
        try:
            snapshot_download(repo_id=f'alibabasglab/{model_name}', local_dir=checkpoint_dir)
            return True
        except Exception:
            return False
    def load_model(self):
        """
        Loads pre-trained model checkpoints from the specified directory. It checks for
        the best model ('last_best_checkpoint') in the checkpoint directory. If a model is
        found, it loads the model state into the current model instance.

        If no checkpoint is found, it will try to download the model from Hugging Face.
        If the download fails, it prints a warning message.

        Steps:
        - Search for the best model checkpoint or the most recent one.
        - Load the model's state dictionary from the checkpoint file.

        Raises:
        - FileNotFoundError: If neither 'last_best_checkpoint' nor 'last_checkpoint' files are found.
        """
        # Define paths for the best model and the last checkpoint
        best_name = os.path.join(self.args.checkpoint_dir, 'last_best_checkpoint')
        # Check if the last best checkpoint exists
        if not os.path.isfile(best_name):
            if not self.download_model(self.name):
                # If downloading is unsuccessful
                print(f'Warning: Downloading model {self.name} was not successful. Please try again or manually download from https://huggingface.co/alibabasglab/{self.name}/tree/main !')
                return

        if isinstance(self.model, nn.ModuleList):
            with open(best_name, 'r') as f:
                model_name = f.readline().strip()
                checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
                self._load_model(self.model[0], checkpoint_path, model_key='mossformer')
                model_name = f.readline().strip()
                checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
                self._load_model(self.model[1], checkpoint_path, model_key='generator')
        else:
            # Read the model's checkpoint name from the file
            with open(best_name, 'r') as f:
                model_name = f.readline().strip()
            # Form the full path to the model's checkpoint
            checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
            self._load_model(self.model, checkpoint_path, model_key='model')

    def _load_model(self, model, checkpoint_path, model_key=None):
        # Load the checkpoint file into memory (map_location ensures compatibility with different devices)
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        # Load the model's state dictionary (weights and biases) into the current model
        if model_key in checkpoint:
            pretrained_model = checkpoint[model_key]
        else:
            pretrained_model = checkpoint
        state = model.state_dict()
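        # The loop below copies parameters by name, tolerating checkpoints saved
        # with or without the 'module.' prefix added by nn.DataParallel /
        # DistributedDataParallel; keys whose names or shapes do not match are left
        # at their initialized values (and reported only when self.print is True).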
        for key in state.keys():
            if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
                state[key] = pretrained_model[key]
            elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
                state[key] = pretrained_model[key.replace('module.', '')]
            elif 'module.'+key in pretrained_model and state[key].shape == pretrained_model['module.'+key].shape:
                state[key] = pretrained_model['module.'+key]
            elif self.print: print(f'{key} not loaded')
        model.load_state_dict(state)

    def decode(self):
        """
        Decodes the input audio data using the loaded model and ensures the output matches the original audio length.

        This method processes the audio through a speech model (e.g., for enhancement, separation, etc.),
        and truncates the resulting audio to match the original input's length. The method supports multiple speakers
        if the model handles multi-speaker audio.

        Returns:
        output_audio: The decoded audio after processing, truncated to the input audio length.
        If multi-speaker audio is processed, a list of truncated audio outputs per speaker is returned.
        """
        # Decode the audio using the loaded model on the given device (e.g., CPU or GPU)
        output_audios = []
        for i in range(len(self.data['audio'])):
            output_audio = decode_one_audio(self.model, self.device, self.data['audio'][i], self.args)
            # Ensure the decoded output matches the length of the input audio
            if isinstance(output_audio, list):
                # If multi-speaker audio (a list of outputs), truncate each speaker's audio to input length
                for spk in range(self.args.num_spks):
                    output_audio[spk] = output_audio[spk][:self.data['audio_len']]
            else:
                # Single output, truncate to input audio length
                output_audio = output_audio[:self.data['audio_len']]
            output_audios.append(output_audio)

        if isinstance(output_audios[0], list):
            output_audios_np = []
            for spk in range(self.args.num_spks):
                output_audio_buf = []
                for i in range(len(output_audios)):
                    output_audio_buf.append(output_audios[i][spk])
                    #output_audio_buf = np.vstack((output_audio_buf, output_audios[i][spk])).T
                output_audios_np.append(np.array(output_audio_buf))
        else:
            output_audios_np = np.array(output_audios)
        return output_audios_np
    def process(self, input_path, online_write=False, output_path=None):
        """
        Load and process audio files from the specified input path. Optionally,
        write the output audio files to the specified output directory.

        Args:
            input_path (str): Path to the input audio files or folder.
            online_write (bool): Whether to write the processed audio to disk in real-time.
            output_path (str): Optional path for writing output files. If None, output
                will be stored in self.result.

        Returns:
            dict or ndarray: Processed audio results either as a dictionary or as a single array,
                depending on the number of audio files processed.
                Returns None if online_write is enabled.
        """

        self.result = {}
        self.args.input_path = input_path
        data_reader = DataReader(self.args)  # Initialize a data reader to load the audio files

        # Check if online writing is enabled
        if online_write:
            output_wave_dir = self.args.output_dir  # Set the default output directory
            if isinstance(output_path, str):  # If a specific output path is provided, use it
                output_wave_dir = os.path.join(output_path, self.name)
            # Create the output directory if it does not exist
            if not os.path.isdir(output_wave_dir):
                os.makedirs(output_wave_dir)

        num_samples = len(data_reader)  # Get the total number of samples to process
        print(f'Running {self.name} ...')  # Display the model being used

        if self.args.task == 'target_speaker_extraction':
            from utils.video_process import process_tse
            assert online_write == True
            process_tse(self.args, self.model, self.device, data_reader, output_wave_dir)
        else:
            # Disable gradient calculation for better efficiency during inference
            with torch.no_grad():
                for idx in tqdm(range(num_samples)):  # Loop over all audio samples
                    self.data = {}
                    # Read the audio, waveform ID, and audio length from the data reader
                    input_audio, wav_id, input_len, scalars, audio_info = data_reader[idx]
                    # Store the input audio and metadata in self.data
                    self.data['audio'] = input_audio
                    self.data['id'] = wav_id
                    self.data['audio_len'] = input_len
                    self.data.update(audio_info)

                    # Perform the audio decoding/processing
                    output_audios = self.decode()

                    # Perform audio renormalization
                    if not isinstance(output_audios, list):
                        if len(scalars) > 1:
                            for i in range(len(scalars)):
                                output_audios[:,i] = output_audios[:,i] * scalars[i]
                        else:
                            output_audios = output_audios * scalars[0]

                    if online_write:
                        # If online writing is enabled, save the output audio to files
                        if isinstance(output_audios, list):
                            # In case of multi-speaker output, save each speaker's output separately
                            for spk in range(self.args.num_spks):
                                output_file = os.path.join(output_wave_dir, wav_id.replace('.'+self.data['ext'], f'_s{spk+1}.'+self.data['ext']))
                                self.write_audio(output_file, key=None, spk=spk, audio=output_audios)
                        else:
                            # Single-speaker or standard output
                            output_file = os.path.join(output_wave_dir, wav_id)
                            self.write_audio(output_file, key=None, spk=None, audio=output_audios)
                    else:
                        # If not writing to disk, store the output in the result dictionary
                        self.result[wav_id] = output_audios

        # Return the processed results if not writing to disk
        if not online_write:
            if len(self.result) == 1:
                # If there is only one result, return it directly
                return next(iter(self.result.values()))
            else:
                # Otherwise, return the entire result dictionary
                return self.result
    def write_audio(self, output_path, key=None, spk=None, audio=None):
        """
        This function writes an audio signal to an output file, applying necessary transformations
        such as resampling, channel handling, and format conversion based on the provided parameters
        and the instance's internal settings.

        Args:
            output_path (str): The file path where the audio will be saved.
            key (str, optional): The key used to retrieve audio from the internal result dictionary
                if audio is not provided.
            spk (str, optional): A specific speaker identifier, used to extract a particular speaker's
                audio from a multi-speaker dataset or result.
            audio (numpy.ndarray, optional): A numpy array containing the audio data to be written.
                If provided, key and spk are ignored.
        """

        if audio is not None:
            if spk is not None:
                result_ = audio[spk]
            else:
                result_ = audio
        else:
            if spk is not None:
                result_ = self.result[key][spk]
            else:
                result_ = self.result[key]

        if self.data['sample_rate'] != self.args.sampling_rate:
            if self.data['channels'] == 2:
                left_channel = librosa.resample(result_[0,:], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
                right_channel = librosa.resample(result_[1,:], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
                result = np.vstack((left_channel, right_channel)).T
            else:
                result = librosa.resample(result_[0,:], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
        else:
            if self.data['channels'] == 2:
                left_channel = result_[0,:]
                right_channel = result_[1,:]
                result = np.vstack((left_channel, right_channel)).T
            else:
                result = result_[0,:]
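        # Convert the floating-point output to integer PCM for pydub: the target
        # dtype and full-scale value are chosen from the source file's sample width
        # (4 bytes -> int32, 2 bytes or anything else -> int16).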
        if self.data['sample_width'] == 4: ##32 bit float
            MAX_WAV_VALUE = 2147483648.0
            np_type = np.int32
        elif self.data['sample_width'] == 2: ##16 bit int
            MAX_WAV_VALUE = 32768.0
            np_type = np.int16
        else:
            self.data['sample_width'] = 2 ##16 bit int
            MAX_WAV_VALUE = 32768.0
            np_type = np.int16

        result = result * MAX_WAV_VALUE
        result = result.astype(np_type)
        audio_segment = AudioSegment(
            result.tobytes(),  # Raw audio data as bytes
            frame_rate=self.data['sample_rate'],  # Sample rate
            sample_width=self.data['sample_width'],  # No. bytes per sample
            channels=self.data['channels']  # No. channels
        )
        audio_format = 'ipod' if self.data['ext'] in ['m4a', 'aac'] else self.data['ext']
        audio_segment.export(output_path, format=audio_format)
    def write(self, output_path, add_subdir=False, use_key=False):
        """
        Write the processed audio results to the specified output path.

        Args:
            output_path (str): The directory or file path where processed audio will be saved. If not
                provided, defaults to self.args.output_dir.
            add_subdir (bool): If True, appends the model name as a subdirectory to the output path.
            use_key (bool): If True, uses the result dictionary's keys (audio file IDs) for filenames.

        Returns:
            None: Outputs are written to disk, no data is returned.
        """

        # Ensure the output path is a string. If not provided, use the default output directory
        if not isinstance(output_path, str):
            output_path = self.args.output_dir

        # If add_subdir is enabled, create a subdirectory for the model name
        if add_subdir:
            if os.path.isfile(output_path):
                print(f'File exists: {output_path}, remove it and try again!')
                return
            output_path = os.path.join(output_path, self.name)
            if not os.path.isdir(output_path):
                os.makedirs(output_path)

        # Ensure proper directory setup when using keys for filenames
        if use_key and not os.path.isdir(output_path):
            if os.path.exists(output_path):
                print(f'File exists: {output_path}, remove it and try again!')
                return
            os.makedirs(output_path)
        # If not using keys and output path is a directory, check for conflicts
        if not use_key and os.path.isdir(output_path):
            print(f'Directory exists: {output_path}, remove it and try again!')
            return

        # Iterate over the results dictionary to write the processed audio to disk
        for key in self.result:
            if use_key:
                # If using keys, format filenames based on the result dictionary's keys (audio IDs)
                if isinstance(self.result[key], list):  # For multi-speaker outputs
                    for spk in range(self.args.num_spks):
                        output_file = os.path.join(output_path, key.replace('.'+self.data['ext'], f'_s{spk+1}.'+self.data['ext']))
                        self.write_audio(output_file, key, spk)
                else:
                    output_file = os.path.join(output_path, key)
                    self.write_audio(output_file, key)
            else:
                # If not using keys, write audio to the specified output path directly
                if isinstance(self.result[key], list):  # For multi-speaker outputs
                    for spk in range(self.args.num_spks):
                        output_file = output_path.replace('.'+self.data['ext'], f'_s{spk+1}.'+self.data['ext'])
                        self.write_audio(output_file, key, spk)
                else:
                    self.write_audio(output_path, key)
# The model classes for specific sub-tasks

class CLS_FRCRN_SE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements a speech enhancement model using
    the FRCRN architecture for 16 kHz speech enhancement.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_FRCRN_SE_16K, self).__init__(args)

        # Import the FRCRN speech enhancement model for 16 kHz
        from models.frcrn_se.frcrn import FRCRN_SE_16K

        # Initialize the model
        self.model = FRCRN_SE_16K(args).model
        self.name = 'FRCRN_SE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()

class CLS_MossFormer2_SE_48K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    48 kHz speech enhancement.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SE_48K, self).__init__(args)

        # Import the MossFormer2 speech enhancement model for 48 kHz
        from models.mossformer2_se.mossformer2_se_wrapper import MossFormer2_SE_48K

        # Initialize the model
        self.model = MossFormer2_SE_48K(args).model
        self.name = 'MossFormer2_SE_48K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
class CLS_MossFormer2_SR_48K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    48 kHz speech super-resolution.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SR_48K, self).__init__(args)

        # Import the MossFormer2 speech super-resolution model for 48 kHz
        from models.mossformer2_sr.mossformer2_sr_wrapper import MossFormer2_SR_48K

        # Initialize the model
        self.model = nn.ModuleList()
        self.model.append(MossFormer2_SR_48K(args).model_m)
        self.model.append(MossFormer2_SR_48K(args).model_g)
        self.name = 'MossFormer2_SR_48K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            for model in self.model:
                model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        for model in self.model:
            model.eval()
        self.model[1].remove_weight_norm()
class CLS_MossFormerGAN_SE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormerGAN architecture for
    16 kHz speech enhancement, utilizing GAN-based speech processing.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormerGAN_SE_16K, self).__init__(args)

        # Import the MossFormerGAN speech enhancement model for 16 kHz
        from models.mossformer_gan_se.generator import MossFormerGAN_SE_16K

        # Initialize the model
        self.model = MossFormerGAN_SE_16K(args).model
        self.name = 'MossFormerGAN_SE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()

class CLS_MossFormer2_SS_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    16 kHz speech separation.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SS_16K, self).__init__(args)

        # Import the MossFormer2 speech separation model for 16 kHz
        from models.mossformer2_ss.mossformer2 import MossFormer2_SS_16K

        # Initialize the model
        self.model = MossFormer2_SS_16K(args).model
        self.name = 'MossFormer2_SS_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
class CLS_AV_MossFormer2_TSE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements an audio-visual (AV) model using
    the AV-MossFormer2 architecture for target speaker extraction (TSE) at 16 kHz.
    This model leverages both audio and visual cues to perform speaker extraction.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_AV_MossFormer2_TSE_16K, self).__init__(args)

        # Import the AV-MossFormer2 model for 16 kHz target speaker extraction
        from models.av_mossformer2_tse.av_mossformer2 import AV_MossFormer2_TSE_16K

        # Initialize the model
        self.model = AV_MossFormer2_TSE_16K(args).model
        self.name = 'AV_MossFormer2_TSE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
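
A minimal usage sketch, not part of the uploaded networks.py: the classes above are driven by an argparse-style `args` namespace, and the sketch below only sets the attributes this file itself reads (`checkpoint_dir`, `output_dir`, `sampling_rate`, `num_spks`, `task`, `use_cuda`). The concrete values and any further options consumed inside `decode_one_audio` and `DataReader` are assumptions, so treat this as an illustration rather than the official entry point.

# Sketch only: attribute names mirror those referenced in networks.py above;
# the project's real entry scripts may require additional options.
import argparse
from networks import CLS_MossFormer2_SE_48K

args = argparse.Namespace(
    checkpoint_dir='checkpoints/MossFormer2_SE_48K',  # where load_model()/download_model() look (hypothetical path)
    output_dir='outputs',             # default target for process(..., online_write=True)
    input_path='samples/noisy.wav',   # file or folder; process() overwrites this attribute anyway
    sampling_rate=48000,              # processing rate used when resampling in write_audio()
    num_spks=1,                       # only consulted for multi-speaker (list) outputs
    task='speech_enhancement',        # anything other than 'target_speaker_extraction'
    use_cuda=1,                       # reset in SpeechModel.__init__ based on GPU availability
)

model = CLS_MossFormer2_SE_48K(args)           # builds the network and loads (or downloads) the checkpoint
enhanced = model.process('samples/noisy.wav')  # returns the enhanced waveform(s) without writing files
model.process('samples/noisy.wav', online_write=True, output_path='outputs')  # or write results to disk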