# Copyright (c) 2024 Alibaba Inc (authors: Shengkui Zhao, Zexu Pan)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import soundfile as sf
import os
import subprocess
import librosa
from tqdm import tqdm
import numpy as np
from pydub import AudioSegment

from utils.decode import decode_one_audio
from dataloader.dataloader import DataReader

MAX_WAV_VALUE = 32768.0


class SpeechModel:
    """
    The SpeechModel class is a base class designed to handle speech processing tasks
    such as loading, processing, and decoding audio data. It initializes the computational
    device (CPU or GPU) and holds model-related attributes. The class is flexible and intended
    to be extended by task-specific speech models, e.g. for speech enhancement, speech
    separation, or target speaker extraction.

    Attributes:
    - args: Argument parser object that contains configuration settings.
    - device: The device (CPU or GPU) on which the model will run.
    - model: The actual model used for speech processing tasks (to be loaded by subclasses).
    - name: A placeholder for the model's name.
    - data: A dictionary to store any additional data related to the model, such as audio input.
    """

    def __init__(self, args):
        """
        Initializes the SpeechModel class by determining the computation device
        (GPU or CPU) to be used for running the model, based on system availability.

        Args:
        - args: Argument parser object containing settings such as whether to use CUDA (GPU).
        """
        # Check if a GPU is available
        if torch.cuda.is_available():
            # Find the GPU with the most free memory using a custom method
            free_gpu_id = self.get_free_gpu()
            if free_gpu_id is not None:
                args.use_cuda = 1
                torch.cuda.set_device(free_gpu_id)
                self.device = torch.device('cuda')
            else:
                # Querying the GPUs failed; fall back to the CPU
                args.use_cuda = 0
                self.device = torch.device('cpu')
        else:
            # If no GPU is detected, use the CPU
            args.use_cuda = 0
            self.device = torch.device('cpu')

        self.args = args
        self.model = None
        self.name = None
        self.data = {}
        self.print = False

    def get_free_gpu(self):
        """
        Identifies the GPU with the most free memory using 'nvidia-smi' and returns its index.

        This function queries the available GPUs on the system and determines which one has
        the highest amount of free memory. It uses the `nvidia-smi` command-line tool to gather
        GPU memory usage data. If successful, it returns the index of the GPU with the most
        free memory. If the query fails or an error occurs, it returns None.

        Returns:
            int: Index of the GPU with the most free memory, or None if no GPU is found or an error occurs.
        """
        try:
            # Run nvidia-smi to query per-GPU used and free memory (CSV, no units or header)
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'],
                stdout=subprocess.PIPE)
            gpu_info = result.stdout.decode('utf-8').strip().split('\n')

            free_gpu = None
            max_free_memory = 0
            for i, info in enumerate(gpu_info):
                used, free = map(int, info.split(','))
                if free > max_free_memory:
                    max_free_memory = free
                    free_gpu = i
            return free_gpu
        except Exception as e:
            print(f"Error finding free GPU: {e}")
            return None
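
    # A minimal sketch of the nvidia-smi output that get_free_gpu() parses (illustrative
    # values, not from a real machine). Each line is "<used>, <free>" in MiB for one GPU:
    #
    #     4096, 20384
    #     512, 23968
    #
    # int() tolerates the leading space after the comma, and the second line above would win
    # with 23968 MiB free, so get_free_gpu() would return index 1.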

    def download_model(self, model_name):
        checkpoint_dir = self.args.checkpoint_dir
        from huggingface_hub import snapshot_download
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        print(f'downloading checkpoint for {model_name}')
        try:
            snapshot_download(repo_id=f'alibabasglab/{model_name}', local_dir=checkpoint_dir)
            return True
        except Exception:
            return False
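
    # Note: snapshot_download mirrors the entire 'alibabasglab/<model_name>' repository into
    # checkpoint_dir, which is assumed to include the 'last_best_checkpoint' index file that
    # load_model() reads next, alongside the checkpoint weights it names.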

    def load_model(self):
        """
        Loads a pre-trained model checkpoint from a specified directory. It checks for
        the best model ('last_best_checkpoint') in the checkpoint directory. If a model is
        found, it loads the model state into the current model instance.

        If no checkpoint is found, it will try to download the model from Hugging Face.
        If the download fails, it prints a warning message.

        Steps:
        - Search for the best model checkpoint or the most recent one.
        - Load the model's state dictionary from the checkpoint file.

        Raises:
        - FileNotFoundError: If neither 'last_best_checkpoint' nor 'last_checkpoint' files are found.
        """
        # Define the path to the file that names the best checkpoint
        best_name = os.path.join(self.args.checkpoint_dir, 'last_best_checkpoint')
        # Check if the last best checkpoint exists; if not, try to download the model
        if not os.path.isfile(best_name):
            if not self.download_model(self.name):
                # If downloading is unsuccessful
                print(f'Warning: Downloading model {self.name} is not successful. Please try again or manually download from https://huggingface.co/alibabasglab/{self.name}/tree/main !')
                return

        if isinstance(self.model, nn.ModuleList):
            # Two-stage models list one checkpoint name per line in the index file
            with open(best_name, 'r') as f:
                model_name = f.readline().strip()
                checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
                self._load_model(self.model[0], checkpoint_path, model_key='mossformer')
                model_name = f.readline().strip()
                checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
                self._load_model(self.model[1], checkpoint_path, model_key='generator')
        else:
            # Read the model's checkpoint name from the file
            with open(best_name, 'r') as f:
                model_name = f.readline().strip()
            # Form the full path to the model's checkpoint
            checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
            self._load_model(self.model, checkpoint_path, model_key='model')
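
    # Sketch of the 'last_best_checkpoint' index file this method reads (hypothetical
    # filenames): a single-stage model lists one checkpoint filename, e.g.
    #
    #     frcrn_se_16k.pt
    #
    # while a two-stage nn.ModuleList model lists two filenames, one per line: first the
    # 'mossformer' checkpoint, then the 'generator' checkpoint.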

    def _load_model(self, model, checkpoint_path, model_key=None):
        # Load the checkpoint file into memory (map_location ensures compatibility with different devices)
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        # Load the model's state dictionary (weights and biases) into the current model
        if model_key in checkpoint:
            pretrained_model = checkpoint[model_key]
        else:
            pretrained_model = checkpoint
        state = model.state_dict()
        for key in state.keys():
            if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
                state[key] = pretrained_model[key]
            elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
                state[key] = pretrained_model[key.replace('module.', '')]
            elif 'module.' + key in pretrained_model and state[key].shape == pretrained_model['module.' + key].shape:
                state[key] = pretrained_model['module.' + key]
            elif self.print:
                print(f'{key} not loaded')
        model.load_state_dict(state)
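
    # Illustration of the prefix tolerance above (hypothetical key names): a checkpoint
    # saved from an nn.DataParallel-wrapped model stores 'module.encoder.weight', while a
    # bare model expects 'encoder.weight'. _load_model matches either spelling as long as
    # the tensor shapes agree; when no match is found, the randomly initialized weight is
    # kept silently (or the key is printed when self.print is True).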

    def decode(self):
        """
        Decodes the input audio data using the loaded model and ensures the output matches
        the original audio length.

        This method processes the audio through a speech model (e.g., for enhancement,
        separation, etc.) and truncates the resulting audio to match the original input's
        length. The method supports multiple speakers if the model handles multi-speaker audio.

        Returns:
            output_audio: The decoded audio after processing, truncated to the input audio length.
                          If multi-speaker audio is processed, a list of truncated audio outputs
                          per speaker is returned.
        """
        # Decode the audio using the loaded model on the given device (e.g., CPU or GPU)
        output_audios = []
        for i in range(len(self.data['audio'])):
            output_audio = decode_one_audio(self.model, self.device, self.data['audio'][i], self.args)
            # Ensure the decoded output matches the length of the input audio
            if isinstance(output_audio, list):
                # If multi-speaker audio (a list of outputs), truncate each speaker's audio to the input length
                for spk in range(self.args.num_spks):
                    output_audio[spk] = output_audio[spk][:self.data['audio_len']]
            else:
                # Single output, truncate to the input audio length
                output_audio = output_audio[:self.data['audio_len']]
            output_audios.append(output_audio)

        if isinstance(output_audios[0], list):
            # Regroup the per-channel outputs by speaker before converting to arrays
            output_audios_np = []
            for spk in range(self.args.num_spks):
                output_audio_buf = []
                for i in range(len(output_audios)):
                    output_audio_buf.append(output_audios[i][spk])
                output_audios_np.append(np.array(output_audio_buf))
        else:
            output_audios_np = np.array(output_audios)
        return output_audios_np
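
    # A sketch of the shapes decode() returns, under the assumption that each entry of
    # self.data['audio'] is one channel of length audio_len: single-output models yield an
    # ndarray of shape (channels, audio_len); multi-speaker models yield a list of num_spks
    # such arrays, one per separated speaker. write_audio() below indexes result_[0, :] and
    # result_[1, :] accordingly.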

    def process(self, input_path, online_write=False, output_path=None):
        """
        Load and process audio files from the specified input path. Optionally,
        write the output audio files to the specified output directory.

        Args:
            input_path (str): Path to the input audio files or folder.
            online_write (bool): Whether to write the processed audio to disk in real-time.
            output_path (str): Optional path for writing output files. If None, output
                               will be stored in self.result.

        Returns:
            dict or ndarray: Processed audio results, either as a dictionary or as a single
                             array, depending on the number of audio files processed.
                             Returns None if online_write is enabled.
        """
        self.result = {}
        self.args.input_path = input_path
        data_reader = DataReader(self.args)  # Initialize a data reader to load the audio files

        # Check if online writing is enabled
        if online_write:
            output_wave_dir = self.args.output_dir  # Set the default output directory
            if isinstance(output_path, str):  # If a specific output path is provided, use it
                output_wave_dir = os.path.join(output_path, self.name)
            # Create the output directory if it does not exist
            if not os.path.isdir(output_wave_dir):
                os.makedirs(output_wave_dir)

        num_samples = len(data_reader)  # Get the total number of samples to process
        print(f'Running {self.name} ...')  # Display the model being used

        if self.args.task == 'target_speaker_extraction':
            from utils.video_process import process_tse
            assert online_write, 'target_speaker_extraction requires online_write=True'
            process_tse(self.args, self.model, self.device, data_reader, output_wave_dir)
        else:
            # Disable gradient calculation for better efficiency during inference
            with torch.no_grad():
                for idx in tqdm(range(num_samples)):  # Loop over all audio samples
                    self.data = {}
                    # Read the audio, waveform ID, and audio length from the data reader
                    input_audio, wav_id, input_len, scalars, audio_info = data_reader[idx]
                    # Store the input audio and metadata in self.data
                    self.data['audio'] = input_audio
                    self.data['id'] = wav_id
                    self.data['audio_len'] = input_len
                    self.data.update(audio_info)

                    # Perform the audio decoding/processing
                    output_audios = self.decode()

                    # Perform audio renormalization
                    if not isinstance(output_audios, list):
                        if len(scalars) > 1:
                            for i in range(len(scalars)):
                                output_audios[:, i] = output_audios[:, i] * scalars[i]
                        else:
                            output_audios = output_audios * scalars[0]

                    if online_write:
                        # If online writing is enabled, save the output audio to files
                        if isinstance(output_audios, list):
                            # In case of multi-speaker output, save each speaker's output separately
                            for spk in range(self.args.num_spks):
                                output_file = os.path.join(output_wave_dir, wav_id.replace('.' + self.data['ext'], f'_s{spk+1}.' + self.data['ext']))
                                self.write_audio(output_file, key=None, spk=spk, audio=output_audios)
                        else:
                            # Single-speaker or standard output
                            output_file = os.path.join(output_wave_dir, wav_id)
                            self.write_audio(output_file, key=None, spk=None, audio=output_audios)
                    else:
                        # If not writing to disk, store the output in the result dictionary
                        self.result[wav_id] = output_audios

        # Return the processed results if not writing to disk
        if not online_write:
            if len(self.result) == 1:
                # If there is only one result, return it directly
                return next(iter(self.result.values()))
            else:
                # Otherwise, return the entire result dictionary
                return self.result
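
    # Example return values from process() (hypothetical file names): with one input file,
    # the call returns that file's array (or per-speaker list) directly; with a folder of
    # inputs it returns {'a.wav': ndarray, 'b.wav': ndarray, ...} keyed by waveform ID.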

    def write_audio(self, output_path, key=None, spk=None, audio=None):
        """
        This function writes an audio signal to an output file, applying necessary
        transformations such as resampling, channel handling, and format conversion based
        on the provided parameters and the instance's internal settings.

        Args:
            output_path (str): The file path where the audio will be saved.
            key (str, optional): The key used to retrieve audio from the internal result
                                 dictionary if audio is not provided.
            spk (str, optional): A specific speaker identifier, used to extract a particular
                                 speaker's audio from a multi-speaker dataset or result.
            audio (numpy.ndarray, optional): A numpy array containing the audio data to be
                                             written. If provided, key and spk are ignored.
        """
        if audio is not None:
            if spk is not None:
                result_ = audio[spk]
            else:
                result_ = audio
        else:
            if spk is not None:
                result_ = self.result[key][spk]
            else:
                result_ = self.result[key]

        # Resample back to the source file's sample rate if it differs from the model's rate
        if self.data['sample_rate'] != self.args.sampling_rate:
            if self.data['channels'] == 2:
                left_channel = librosa.resample(result_[0, :], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
                right_channel = librosa.resample(result_[1, :], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
                result = np.vstack((left_channel, right_channel)).T
            else:
                result = librosa.resample(result_[0, :], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
        else:
            if self.data['channels'] == 2:
                left_channel = result_[0, :]
                right_channel = result_[1, :]
                result = np.vstack((left_channel, right_channel)).T
            else:
                result = result_[0, :]

        # Choose the integer PCM width to match the source file's sample width
        if self.data['sample_width'] == 4:  # 32-bit integer PCM
            MAX_WAV_VALUE = 2147483648.0
            np_type = np.int32
        elif self.data['sample_width'] == 2:  # 16-bit integer PCM
            MAX_WAV_VALUE = 32768.0
            np_type = np.int16
        else:
            self.data['sample_width'] = 2  # Default to 16-bit integer PCM
            MAX_WAV_VALUE = 32768.0
            np_type = np.int16

        # Scale the float waveform to the integer range and convert
        result = result * MAX_WAV_VALUE
        result = result.astype(np_type)
        audio_segment = AudioSegment(
            result.tobytes(),                        # Raw audio data as bytes
            frame_rate=self.data['sample_rate'],     # Sample rate
            sample_width=self.data['sample_width'],  # Number of bytes per sample
            channels=self.data['channels']           # Number of channels
        )
        # pydub/ffmpeg uses the 'ipod' container for m4a/aac output
        audio_format = 'ipod' if self.data['ext'] in ['m4a', 'aac'] else self.data['ext']
        audio_segment.export(output_path, format=audio_format)
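
    # A worked example of the scaling above (illustrative numbers): a float sample of 0.5
    # with a 16-bit target becomes int(0.5 * 32768) = 16384; with a 32-bit target it becomes
    # int(0.5 * 2147483648) = 1073741824. A value at exactly +1.0 would overflow by one step,
    # so the upstream renormalization is assumed to keep samples strictly below 1.0.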

    def write(self, output_path, add_subdir=False, use_key=False):
        """
        Write the processed audio results to the specified output path.

        Args:
            output_path (str): The directory or file path where processed audio will be saved.
                               If not provided, defaults to self.args.output_dir.
            add_subdir (bool): If True, appends the model name as a subdirectory to the output path.
            use_key (bool): If True, uses the result dictionary's keys (audio file IDs) for filenames.

        Returns:
            None: Outputs are written to disk; no data is returned.
        """
        # Ensure the output path is a string. If not provided, use the default output directory
        if not isinstance(output_path, str):
            output_path = self.args.output_dir

        # If add_subdir is enabled, create a subdirectory for the model name
        if add_subdir:
            if os.path.isfile(output_path):
                print(f'File exists: {output_path}, remove it and try again!')
                return
            output_path = os.path.join(output_path, self.name)
            if not os.path.isdir(output_path):
                os.makedirs(output_path)

        # Ensure proper directory setup when using keys for filenames
        if use_key and not os.path.isdir(output_path):
            if os.path.exists(output_path):
                print(f'File exists: {output_path}, remove it and try again!')
                return
            os.makedirs(output_path)
        # If not using keys and the output path is a directory, check for conflicts
        if not use_key and os.path.isdir(output_path):
            print(f'Directory exists: {output_path}, remove it and try again!')
            return

        # Iterate over the results dictionary to write the processed audio to disk
        for key in self.result:
            if use_key:
                # If using keys, format filenames based on the result dictionary's keys (audio IDs)
                if isinstance(self.result[key], list):  # For multi-speaker outputs
                    for spk in range(self.args.num_spks):
                        output_file = os.path.join(output_path, key.replace('.' + self.data['ext'], f'_s{spk+1}.' + self.data['ext']))
                        self.write_audio(output_file, key, spk)
                else:
                    output_file = os.path.join(output_path, key)
                    self.write_audio(output_file, key)
            else:
                # If not using keys, write audio to the specified output path directly
                if isinstance(self.result[key], list):  # For multi-speaker outputs
                    for spk in range(self.args.num_spks):
                        output_file = output_path.replace('.' + self.data['ext'], f'_s{spk+1}.' + self.data['ext'])
                        self.write_audio(output_file, key, spk)
                else:
                    self.write_audio(output_path, key)
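
    # A minimal usage sketch (hypothetical paths; args construction is shown at the end of
    # this file): after model.process('input/') populates self.result, either
    #
    #     model.write('enhanced_out/', use_key=True)   # one file per input ID
    #     model.write('enhanced.wav')                  # single-input case, explicit filename
    #
    # writes the buffered results to disk without re-running inference.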


# The model classes for specific sub-tasks

class CLS_FRCRN_SE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements a speech enhancement model using
    the FRCRN architecture for 16 kHz speech enhancement.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_FRCRN_SE_16K, self).__init__(args)

        # Import the FRCRN speech enhancement model for 16 kHz
        from models.frcrn_se.frcrn import FRCRN_SE_16K

        # Initialize the model
        self.model = FRCRN_SE_16K(args).model
        self.name = 'FRCRN_SE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move the model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
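
# Each CLS_* wrapper below follows the same pattern as CLS_FRCRN_SE_16K above: import the
# task-specific network, assign self.model and self.name (the name doubles as the Hugging
# Face repo suffix used by download_model), load the checkpoint, move the model to the
# device, and switch to eval mode.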


class CLS_MossFormer2_SE_48K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    48 kHz speech enhancement.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SE_48K, self).__init__(args)

        # Import the MossFormer2 speech enhancement model for 48 kHz
        from models.mossformer2_se.mossformer2_se_wrapper import MossFormer2_SE_48K

        # Initialize the model
        self.model = MossFormer2_SE_48K(args).model
        self.name = 'MossFormer2_SE_48K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move the model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()


class CLS_MossFormer2_SR_48K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    48 kHz speech super-resolution.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SR_48K, self).__init__(args)

        # Import the MossFormer2 speech super-resolution model for 48 kHz
        from models.mossformer2_sr.mossformer2_sr_wrapper import MossFormer2_SR_48K

        # Initialize the two-stage model: the MossFormer2 module, then the generator
        self.model = nn.ModuleList()
        self.model.append(MossFormer2_SR_48K(args).model_m)
        self.model.append(MossFormer2_SR_48K(args).model_g)
        self.name = 'MossFormer2_SR_48K'

        # Load pre-trained model checkpoints (one per stage)
        self.load_model()

        # Move the models to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            for model in self.model:
                model.to(self.device)

        # Set the models to evaluation mode (no gradient calculation)
        for model in self.model:
            model.eval()
        self.model[1].remove_weight_norm()


class CLS_MossFormerGAN_SE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormerGAN architecture for
    16 kHz speech enhancement, utilizing GAN-based speech processing.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormerGAN_SE_16K, self).__init__(args)

        # Import the MossFormerGAN speech enhancement model for 16 kHz
        from models.mossformer_gan_se.generator import MossFormerGAN_SE_16K

        # Initialize the model
        self.model = MossFormerGAN_SE_16K(args).model
        self.name = 'MossFormerGAN_SE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move the model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()


class CLS_MossFormer2_SS_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    16 kHz speech separation.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SS_16K, self).__init__(args)

        # Import the MossFormer2 speech separation model for 16 kHz
        from models.mossformer2_ss.mossformer2 import MossFormer2_SS_16K

        # Initialize the model
        self.model = MossFormer2_SS_16K(args).model
        self.name = 'MossFormer2_SS_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move the model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()


class CLS_AV_MossFormer2_TSE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements an audio-visual (AV) model using
    the AV-MossFormer2 architecture for target speaker extraction (TSE) at 16 kHz.
    This model leverages both audio and visual cues to perform speaker extraction.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_AV_MossFormer2_TSE_16K, self).__init__(args)

        # Import the AV-MossFormer2 model for 16 kHz target speaker extraction
        from models.av_mossformer2_tse.av_mossformer2 import AV_MossFormer2_TSE_16K

        # Initialize the model
        self.model = AV_MossFormer2_TSE_16K(args).model
        self.name = 'AV_MossFormer2_TSE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move the model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
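

# A minimal end-to-end usage sketch. The attribute names on `args` below (checkpoint_dir,
# output_dir, task, sampling_rate, num_spks, plus whatever decoding settings
# decode_one_audio consumes) are assumptions inferred from how they are read in this file,
# not a definitive configuration; consult the accompanying configs for the full list.
#
#     from argparse import Namespace
#
#     args = Namespace(
#         checkpoint_dir='checkpoints/MossFormer2_SE_48K',  # hypothetical path
#         output_dir='output',
#         task='speech_enhancement',
#         sampling_rate=48000,
#         num_spks=1,
#     )
#     model = CLS_MossFormer2_SE_48K(args)
#     enhanced = model.process('noisy.wav')          # returns the enhanced waveform
#     model.process('noisy.wav', online_write=True,  # or write directly to disk
#                   output_path='output')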