alibabasglab committed
Commit b89149b · verified · 1 Parent(s): dc79b90

Delete networks.py

Files changed (1)
  1. networks.py +0 -490
networks.py DELETED
@@ -1,490 +0,0 @@
- """
- Authors: Shengkui Zhao, Zexu Pan
- """
-
- import torch
- import soundfile as sf
- import os
- import subprocess
- from tqdm import tqdm
- from utils.decode import decode_one_audio
- from dataloader.dataloader import DataReader
-
- class SpeechModel:
-     """
-     The SpeechModel class is a base class designed to handle speech processing tasks
-     such as loading, processing, and decoding audio data. It initializes the computational
-     device (CPU or GPU) and holds model-related attributes. The class is intended to be
-     extended by task-specific speech models, e.g. for speech enhancement, speech
-     separation, or target speaker extraction.
-
-     Attributes:
-     - args: Argument parser object that contains configuration settings.
-     - device: The device (CPU or GPU) on which the model will run.
-     - model: The actual model used for speech processing tasks (to be loaded by subclasses).
-     - name: A placeholder for the model's name.
-     - data: A dictionary to store any additional data related to the model, such as audio input.
-     """
-
-     def __init__(self, args):
-         """
-         Initializes the SpeechModel class by determining the computation device
-         (GPU or CPU) to be used for running the model, based on system availability.
-
-         Args:
-         - args: Argument parser object containing settings like whether to use CUDA (GPU) or not.
-         """
-         # Check if a GPU is available
-         if torch.cuda.is_available():
-             print('GPU is found and used!')
-             args.use_cuda = 1
-             self.device = torch.device('cuda')
-         else:
-             # If no GPU is detected, fall back to the CPU
-             args.use_cuda = 0
-             self.device = torch.device('cpu')
-
-         self.args = args
-         self.model = None
-         self.name = None
-         self.data = {}
-
-     def get_free_gpu(self):
-         """
-         Identifies the GPU with the most free memory using 'nvidia-smi' and returns its index.
-
-         This function queries the available GPUs on the system and determines which one has
-         the highest amount of free memory. It uses the `nvidia-smi` command-line tool to gather
-         GPU memory usage data. If successful, it returns the index of the GPU with the most
-         free memory. If the query fails or an error occurs, it returns None.
-
-         Returns:
-             int: Index of the GPU with the most free memory, or None if no GPU is found or an error occurs.
-         """
-         try:
-             # Run nvidia-smi to query GPU memory usage and free memory
-             result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'], stdout=subprocess.PIPE)
-             gpu_info = result.stdout.decode('utf-8').strip().split('\n')
-
-             free_gpu = None
-             max_free_memory = 0
-             for i, info in enumerate(gpu_info):
-                 used, free = map(int, info.split(','))
-                 if free > max_free_memory:
-                     max_free_memory = free
-                     free_gpu = i
-             return free_gpu
-         except Exception as e:
-             print(f"Error finding free GPU: {e}")
-             return None
-
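For reference, the same selection can be done without shelling out to nvidia-smi. A minimal sketch using torch.cuda.mem_get_info; the name get_free_gpu_torch is hypothetical, and this assumes a PyTorch version that provides mem_get_info:

import torch

def get_free_gpu_torch():
    # Return the index of the CUDA device with the most free memory, or None.
    if not torch.cuda.is_available():
        return None
    # mem_get_info(i) returns (free_bytes, total_bytes) for device i
    free = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
    return max(range(len(free)), key=free.__getitem__)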
-     def load_model(self):
-         """
-         Loads a pre-trained model checkpoint from a specified directory. It checks for
-         the best model ('last_best_checkpoint') or the most recent checkpoint ('last_checkpoint')
-         in the checkpoint directory. If a model is found, it loads the model state into the
-         current model instance.
-
-         If no checkpoint is found, it prints a warning and returns without loading weights.
-
-         Steps:
-         - Search for the best model checkpoint or the most recent one.
-         - Load the model's state dictionary from the checkpoint file.
-         """
-         # Define paths for the best model and the last checkpoint
-         best_name = os.path.join(self.args.checkpoint_dir, 'last_best_checkpoint')
-         ckpt_name = os.path.join(self.args.checkpoint_dir, 'last_checkpoint')
-
-         # Check if the best checkpoint or last checkpoint exists
-         if os.path.isfile(best_name):
-             name = best_name  # Prioritize loading the best model
-         elif os.path.isfile(ckpt_name):
-             name = ckpt_name  # Otherwise, load the last saved checkpoint
-         else:
-             # If no checkpoint exists, print a warning and exit the function
-             print('Warning: No existing checkpoint or best model found!')
-             return
-
-         # Read the model's checkpoint filename from the pointer file
-         with open(name, 'r') as f:
-             model_name = f.readline().strip()
-
-         # Form the full path to the model's checkpoint
-         checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
-
-         # Load the checkpoint file into memory (map_location ensures compatibility with different devices)
-         checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
-
-         # Unwrap the state dictionary if the checkpoint stores it under a 'model' key
-         if 'model' in checkpoint:
-             pretrained_model = checkpoint['model']
-         else:
-             pretrained_model = checkpoint
-         # Copy matching weights into the current model, tolerating a 'module.' prefix
-         # left over from DataParallel/DistributedDataParallel training
-         state = self.model.state_dict()
-         for key in state.keys():
-             if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
-                 state[key] = pretrained_model[key]
-             elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
-                 state[key] = pretrained_model[key.replace('module.', '')]
-             elif 'module.' + key in pretrained_model and state[key].shape == pretrained_model['module.' + key].shape:
-                 state[key] = pretrained_model['module.' + key]
-             else:
-                 print(f'{key} not loaded')
-         self.model.load_state_dict(state)
-         print(f'Successfully loaded {model_name} for decoding')
-
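The checkpoint discovery above relies on small pointer files. A sketch of the assumed checkpoint_dir layout; all filenames here are illustrative, inferred from how load_model() reads the first line of the pointer file:

# Assumed layout:
#   checkpoint_dir/
#   ├── last_best_checkpoint   <- text file: one line naming the weights file
#   ├── last_checkpoint        <- same format, most recent save
#   └── model.ckpt.pt          <- the torch.save()'d checkpoint itself
import os

checkpoint_dir = 'checkpoints/demo'  # hypothetical path
os.makedirs(checkpoint_dir, exist_ok=True)
with open(os.path.join(checkpoint_dir, 'last_best_checkpoint'), 'w') as f:
    f.write('model.ckpt.pt\n')       # load_model() reads this first line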
-     def decode(self):
-         """
-         Decodes the input audio data using the loaded model and ensures the output matches
-         the original audio length.
-
-         This method processes the audio through a speech model (e.g., for enhancement or
-         separation) and truncates the resulting audio to match the original input's length.
-         The method supports multiple speakers if the model handles multi-speaker audio.
-
-         Returns:
-             output_audio: The decoded audio after processing, truncated to the input audio length.
-                           If multi-speaker audio is processed, a list of truncated audio outputs
-                           per speaker is returned.
-         """
-         # Decode the audio using the loaded model on the given device (e.g., CPU or GPU)
-         output_audio = decode_one_audio(self.model, self.device, self.data['audio'], self.args)
-
-         # Ensure the decoded output matches the length of the input audio
-         if isinstance(output_audio, list):
-             # If multi-speaker audio (a list of outputs), truncate each speaker's audio to input length
-             for spk in range(self.args.num_spks):
-                 output_audio[spk] = output_audio[spk][:self.data['audio_len']]
-         else:
-             # Single output, truncate to input audio length
-             output_audio = output_audio[:self.data['audio_len']]
-
-         return output_audio
-
-     def process(self, input_path, online_write=False, output_path=None):
-         """
-         Load and process audio files from the specified input path. Optionally,
-         write the output audio files to the specified output directory.
-
-         Args:
-             input_path (str): Path to the input audio files or folder.
-             online_write (bool): Whether to write the processed audio to disk in real-time.
-             output_path (str): Optional path for writing output files. If None, output
-                                will be stored in self.result.
-
-         Returns:
-             dict or ndarray: Processed audio results, either as a dictionary or as a single
-                              array, depending on the number of audio files processed.
-                              Returns None if online_write is enabled.
-         """
-
-         self.result = {}
-         self.args.input_path = input_path
-         data_reader = DataReader(self.args)  # Initialize a data reader to load the audio files
-
-         # Check if online writing is enabled
-         if online_write:
-             output_wave_dir = self.args.output_dir  # Set the default output directory
-             if isinstance(output_path, str):  # If a specific output path is provided, use it
-                 output_wave_dir = os.path.join(output_path, self.name)
-             # Create the output directory if it does not exist
-             if not os.path.isdir(output_wave_dir):
-                 os.makedirs(output_wave_dir)
-
-         num_samples = len(data_reader)  # Get the total number of samples to process
-         print(f'Running {self.name} ...')  # Display the model being used
-
-         if self.args.task == 'target_speaker_extraction':
-             from utils.video_process import process_tse
-             assert online_write  # TSE results are always written to disk
-             process_tse(self.args, self.model, self.device, data_reader, output_wave_dir)
-         else:
-             # Disable gradient calculation for better efficiency during inference
-             with torch.no_grad():
-                 for idx in tqdm(range(num_samples)):  # Loop over all audio samples
-                     self.data = {}
-                     # Read the audio, waveform ID, audio length, and scaling factor from the data reader
-                     input_audio, wav_id, input_len, scalar = data_reader[idx]
-                     # Store the input audio and metadata in self.data
-                     self.data['audio'] = input_audio
-                     self.data['id'] = wav_id
-                     self.data['audio_len'] = input_len
-
-                     # Perform the audio decoding/processing
-                     output_audio = self.decode()
-
-                     # Rescale the output using the data reader's scalar (only FRCRN_SE_16K needs this)
-                     if self.args.network == 'FRCRN_SE_16K':
-                         output_audio = output_audio * scalar
-
-                     if online_write:
-                         # If online writing is enabled, save the output audio to files
-                         if isinstance(output_audio, list):
-                             # In case of multi-speaker output, save each speaker's output separately
-                             for spk in range(self.args.num_spks):
-                                 output_file = os.path.join(output_wave_dir, wav_id.replace('.wav', f'_s{spk+1}.wav'))
-                                 sf.write(output_file, output_audio[spk], self.args.sampling_rate)
-                         else:
-                             # Single-speaker or standard output
-                             output_file = os.path.join(output_wave_dir, wav_id)
-                             sf.write(output_file, output_audio, self.args.sampling_rate)
-                     else:
-                         # If not writing to disk, store the output in the result dictionary
-                         self.result[wav_id] = output_audio
-
-         # Return the processed results if not writing to disk
-         if not online_write:
-             if len(self.result) == 1:
-                 # If there is only one result, return it directly
-                 return next(iter(self.result.values()))
-             else:
-                 # Otherwise, return the entire result dictionary
-                 return self.result
-
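A sketch of the two calling conventions of process(); the model variable and paths are hypothetical, and args must carry the fields referenced above (task, network, num_spks, sampling_rate, output_dir, ...):

# 'model' is any SpeechModel subclass instance (see the classes below)
out = model.process('mixtures/item1.wav')    # one file -> ndarray (or list per speaker)
outs = model.process('mixtures/')            # folder -> dict keyed by wav_id
# Streaming use: each result is written to disk as soon as it is decoded
model.process('mixtures/', online_write=True, output_path='separated')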
-     def write(self, output_path, add_subdir=False, use_key=False):
-         """
-         Write the processed audio results to the specified output path.
-
-         Args:
-             output_path (str): The directory or file path where processed audio will be saved.
-                                If not provided, defaults to self.args.output_dir.
-             add_subdir (bool): If True, appends the model name as a subdirectory to the output path.
-             use_key (bool): If True, uses the result dictionary's keys (audio file IDs) for filenames.
-
-         Returns:
-             None: Outputs are written to disk, no data is returned.
-         """
-
-         # Ensure the output path is a string. If not provided, use the default output directory
-         if not isinstance(output_path, str):
-             output_path = self.args.output_dir
-
-         # If add_subdir is enabled, create a subdirectory for the model name
-         if add_subdir:
-             if os.path.isfile(output_path):
-                 print(f'File exists: {output_path}, remove it and try again!')
-                 return
-             output_path = os.path.join(output_path, self.name)
-             if not os.path.isdir(output_path):
-                 os.makedirs(output_path)
-
-         # Ensure proper directory setup when using keys for filenames
-         if use_key and not os.path.isdir(output_path):
-             if os.path.exists(output_path):
-                 print(f'File exists: {output_path}, remove it and try again!')
-                 return
-             os.makedirs(output_path)
-         # If not using keys and output path is a directory, check for conflicts
-         if not use_key and os.path.isdir(output_path):
-             print(f'Directory exists: {output_path}, remove it and try again!')
-             return
-
-         # Iterate over the results dictionary to write the processed audio to disk
-         for key in self.result:
-             if use_key:
-                 # If using keys, format filenames based on the result dictionary's keys (audio IDs)
-                 if isinstance(self.result[key], list):  # For multi-speaker outputs
-                     for spk in range(self.args.num_spks):
-                         sf.write(os.path.join(output_path, key.replace('.wav', f'_s{spk+1}.wav')),
-                                  self.result[key][spk], self.args.sampling_rate)
-                 else:
-                     sf.write(os.path.join(output_path, key), self.result[key], self.args.sampling_rate)
-             else:
-                 # If not using keys, write audio to the specified output path directly
-                 if isinstance(self.result[key], list):  # For multi-speaker outputs
-                     for spk in range(self.args.num_spks):
-                         sf.write(output_path.replace('.wav', f'_s{spk+1}.wav'),
-                                  self.result[key][spk], self.args.sampling_rate)
-                 else:
-                     sf.write(output_path, self.result[key], self.args.sampling_rate)
-
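The multi-speaker filename convention used by both process() and write(), spelled out; the key and num_spks values are illustrative:

key, num_spks = 'item1.wav', 2
print([key.replace('.wav', f'_s{spk + 1}.wav') for spk in range(num_spks)])
# -> ['item1_s1.wav', 'item1_s2.wav']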
- # The model classes for specific sub-tasks
-
- class CLS_FRCRN_SE_16K(SpeechModel):
-     """
-     A subclass of SpeechModel that implements speech enhancement using the
-     FRCRN architecture at 16 kHz.
-
-     Args:
-         args (Namespace): The argument parser containing model configurations and paths.
-     """
-
-     def __init__(self, args):
-         # Initialize the parent SpeechModel class
-         super(CLS_FRCRN_SE_16K, self).__init__(args)
-
-         # Import the FRCRN speech enhancement model for 16 kHz
-         from models.frcrn_se.frcrn import FRCRN_SE_16K
-
-         # Initialize the model
-         self.model = FRCRN_SE_16K(args).model
-         self.name = 'FRCRN_SE_16K'
-
-         # Load pre-trained model checkpoint
-         self.load_model()
-
-         # Move model to the appropriate device (GPU/CPU)
-         self.model.to(self.device)
-
-         # Set the model to evaluation mode (no gradient calculation)
-         self.model.eval()
-
- class CLS_MossFormer2_SE_48K(SpeechModel):
-     """
-     A subclass of SpeechModel that implements the MossFormer2 architecture for
-     48 kHz speech enhancement.
-
-     Args:
-         args (Namespace): The argument parser containing model configurations and paths.
-     """
-
-     def __init__(self, args):
-         # Initialize the parent SpeechModel class
-         super(CLS_MossFormer2_SE_48K, self).__init__(args)
-
-         # Import the MossFormer2 speech enhancement model for 48 kHz
-         from models.mossformer2_se.mossformer2_se_wrapper import MossFormer2_SE_48K
-
-         # Initialize the model
-         self.model = MossFormer2_SE_48K(args).model
-         self.name = 'MossFormer2_SE_48K'
-
-         # Load pre-trained model checkpoint
-         self.load_model()
-
-         # Move model to the appropriate device (GPU/CPU)
-         self.model.to(self.device)
-
-         # Set the model to evaluation mode (no gradient calculation)
-         self.model.eval()
-
- class CLS_MossFormerGAN_SE_16K(SpeechModel):
-     """
-     A subclass of SpeechModel that implements the MossFormerGAN architecture for
-     16 kHz speech enhancement, utilizing GAN-based speech processing.
-
-     Args:
-         args (Namespace): The argument parser containing model configurations and paths.
-     """
-
-     def __init__(self, args):
-         # Initialize the parent SpeechModel class
-         super(CLS_MossFormerGAN_SE_16K, self).__init__(args)
-
-         # Import the MossFormerGAN speech enhancement model for 16 kHz
-         from models.mossformer_gan_se.generator import MossFormerGAN_SE_16K
-
-         # Initialize the model
-         self.model = MossFormerGAN_SE_16K(args).model
-         self.name = 'MossFormerGAN_SE_16K'
-
-         # Load pre-trained model checkpoint
-         self.load_model()
-
-         # Move model to the appropriate device (GPU/CPU)
-         self.model.to(self.device)
-
-         # Set the model to evaluation mode (no gradient calculation)
-         self.model.eval()
-
- class CLS_MossFormer2_SS_16K(SpeechModel):
-     """
-     A subclass of SpeechModel that implements the MossFormer2 architecture for
-     16 kHz speech separation.
-
-     Args:
-         args (Namespace): The argument parser containing model configurations and paths.
-     """
-
-     def __init__(self, args):
-         # Initialize the parent SpeechModel class
-         super(CLS_MossFormer2_SS_16K, self).__init__(args)
-
-         # Import the MossFormer2 speech separation model for 16 kHz
-         from models.mossformer2_ss.mossformer2 import MossFormer2_SS_16K
-
-         # Initialize the model
-         self.model = MossFormer2_SS_16K(args).model
-         self.name = 'MossFormer2_SS_16K'
-
-         # Load pre-trained model checkpoint
-         self.load_model()
-
-         # Move model to the appropriate device (GPU/CPU)
-         self.model.to(self.device)
-
-         # Set the model to evaluation mode (no gradient calculation)
-         self.model.eval()
-
- class CLS_AV_MossFormer2_TSE_16K(SpeechModel):
-     """
-     A subclass of SpeechModel that implements an audio-visual (AV) model using
-     the AV-MossFormer2 architecture for target speaker extraction (TSE) at 16 kHz.
-     This model leverages both audio and visual cues to perform speaker extraction.
-
-     Args:
-         args (Namespace): The argument parser containing model configurations and paths.
-     """
-
-     def __init__(self, args):
-         # Initialize the parent SpeechModel class
-         super(CLS_AV_MossFormer2_TSE_16K, self).__init__(args)
-
-         # Import the AV-MossFormer2 model for 16 kHz target speaker extraction
-         from models.av_mossformer2_tse.av_mossformer2 import AV_MossFormer2_TSE_16K
-
-         # Initialize the model
-         self.model = AV_MossFormer2_TSE_16K(args).model
-         self.name = 'AV_MossFormer2_TSE_16K'
-
-         # Load pre-trained model checkpoint
-         self.load_model()
-
-         # Move model to the appropriate device (GPU/CPU)
-         self.model.to(self.device)
-
-         # Set the model to evaluation mode (no gradient calculation)
-         self.model.eval()
-
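Taken together, a hypothetical end-to-end sketch of the deleted module; every args field below is inferred from the attribute accesses in this file, not from a documented schema, and the model wrappers may require additional fields:

from argparse import Namespace

args = Namespace(
    task='speech_enhancement',
    network='MossFormer2_SE_48K',
    checkpoint_dir='checkpoints/MossFormer2_SE_48K',
    output_dir='outputs',
    sampling_rate=48000,
    num_spks=1,
)
model = CLS_MossFormer2_SE_48K(args)   # loads weights, moves to device, sets eval()
model.process('noisy_wavs/', online_write=True, output_path='outputs')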