import os
import sys
import csv
import glob
import torch
import random
from tqdm import tqdm
from typing import List, Any

from deepafx_st.data.audio import AudioFile
import deepafx_st.utils as utils
import deepafx_st.data.augmentations as augmentations


class AudioDataset(torch.utils.data.Dataset):
"""Audio dataset which returns an input and target file. | |
Args: | |
audio_dir (str): Path to the top level of the audio dataset. | |
input_dir (List[str], optional): List of paths to the directories containing input audio files. Default: ["clean"] | |
subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train" | |
length (int, optional): Number of samples to load for each example. Default: 65536 | |
train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8 | |
val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1 | |
buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0 | |
Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers | |
buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000 | |
half (bool, optional): Sotre audio samples as float 16. Default: False | |
num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000 | |
random_scale_input (bool, optional): Apply random gain scaling to input utterances. Default: False | |
random_scale_target (bool, optional): Apply same random gain scaling to target utterances. Default: False | |
augmentations (dict, optional): List of augmentation types to apply to inputs. Default: [] | |
freq_corrupt (bool, optional): Apply bad EQ filters. Default: False | |
drc_corrupt (bool, optional): Apply an expander to corrupt dynamic range. Default: False | |
ext (str, optional): Expected audio file extension. Default: "wav" | |
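
    Example (a minimal usage sketch; the dataset path is hypothetical)::

        >>> dataset = AudioDataset("path/to/daps", subset="train", length=65536)
        >>> input_audio, target_audio = dataset[0]
        >>> input_audio.shape
        torch.Size([1, 131072])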
""" | |

    def __init__(
        self,
        audio_dir,
        input_dirs: List[str] = ["cleanraw"],
        subset: str = "train",
        length: int = 65536,
        train_frac: float = 0.8,
        val_per: float = 0.1,
        buffer_size_gb: float = 1.0,
        buffer_reload_rate: int = 1000,
        half: bool = False,
        num_examples_per_epoch: int = 10000,
        random_scale_input: bool = False,
        random_scale_target: bool = False,
        augmentations: dict = {},
        freq_corrupt: bool = False,
        drc_corrupt: bool = False,
        ext: str = "wav",
    ):
        super().__init__()
        self.audio_dir = audio_dir
        self.dataset_name = os.path.basename(audio_dir)
        self.input_dirs = input_dirs
        self.subset = subset
        self.length = length
        self.train_frac = train_frac
        self.val_per = val_per
        self.buffer_size_gb = buffer_size_gb
        self.buffer_reload_rate = buffer_reload_rate
        self.half = half
        self.num_examples_per_epoch = num_examples_per_epoch
        self.random_scale_input = random_scale_input
        self.random_scale_target = random_scale_target
        self.augmentations = augmentations
        self.freq_corrupt = freq_corrupt
        self.drc_corrupt = drc_corrupt
        self.ext = ext

        self.input_filepaths = []
        for input_dir in input_dirs:
            search_path = os.path.join(audio_dir, input_dir, f"*.{ext}")
            self.input_filepaths += glob.glob(search_path)
        self.input_filepaths = sorted(self.input_filepaths)

        # create dataset split based on subset
        self.input_filepaths = utils.split_dataset(
            self.input_filepaths,
            subset,
            train_frac,
        )

        # get details about input audio files
        input_files = {}
        input_dur_frames = 0
        for input_filepath in tqdm(self.input_filepaths, ncols=80):
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=False,
                half=half,
            )

            # skip files that are too short to supply a full training patch
            if audio_file.num_frames < (self.length * 2):
                continue

            input_files[file_id] = audio_file
            input_dur_frames += input_files[file_id].num_frames

        if len(input_files) < 1:
            raise RuntimeError(f"No usable files found in {audio_dir}.")

        # all files are assumed to share the same sample rate
        self.sample_rate = next(iter(input_files.values())).sample_rate
        input_dur_hr = (input_dur_frames / self.sample_rate) / 3600
        print(
            f"\nLoaded {len(input_files)} files for {subset} = {input_dur_hr:0.2f} hours."
        )

        # save a csv file with details about the train and test split
        splits_dir = os.path.join("configs", "splits")
        if not os.path.isdir(splits_dir):
            os.makedirs(splits_dir)
        csv_filepath = os.path.join(
            splits_dir, f"{self.dataset_name}_{self.subset}_set.csv"
        )

        with open(csv_filepath, "w") as fp:
            dw = csv.DictWriter(fp, ["file_id", "filepath", "type", "subset"])
            dw.writeheader()
            for input_filepath in self.input_filepaths:
                dw.writerow(
                    {
                        "file_id": self.get_file_id(input_filepath),
                        "filepath": input_filepath,
                        "type": "input",
                        "subset": self.subset,
                    }
                )
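
        # The resulting CSV holds one row per input file, e.g. (illustrative row):
        #   file_id,filepath,type,subset
        #   f1_script1,path/to/daps/cleanraw/f1_script1.wav,input,train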

        # some setup for iterable loading of the dataset into RAM
        self.items_since_load = self.buffer_reload_rate

    def __len__(self):
        return self.num_examples_per_epoch

    def load_audio_buffer(self):
        self.input_files_loaded = {}  # clear audio buffer
        self.items_since_load = 0  # reset iteration counter
        nbytes_loaded = 0  # counter for data in RAM

        # shuffle so each reload (and each DataLoader worker) sees a different subset
        random.shuffle(self.input_filepaths)

        # load files into RAM
        for input_filepath in self.input_filepaths:
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=True,
                half=self.half,
            )

            if audio_file.num_frames < (self.length * 2):
                continue

            self.input_files_loaded[file_id] = audio_file

            nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
            nbytes_loaded += nbytes

            # stop once the loaded data exceeds the buffer limit
            if nbytes_loaded > self.buffer_size_gb * 1e9:
                break
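
        # Rough sizing example (assumed values): mono float32 audio at 24 kHz
        # uses ~96 kB per second, so a 1.0 GB buffer holds about
        # 1e9 / 96e3 ≈ 2.9 hours of audio per DataLoader worker.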

    def generate_pair(self):
        # ------------------------ Input audio ----------------------
        rand_input_file_id = None
        input_file = None
        start_idx = None
        stop_idx = None
        while True:
            rand_input_file_id = self.get_random_file_id(
                self.input_files_loaded.keys()
            )

            # use this random key to retrieve an input file
            input_file = self.input_files_loaded[rand_input_file_id]

            # sanity check: the buffer should only contain preloaded files
            if not input_file.loaded:
                raise RuntimeError("Audio not loaded.")

            # get a random patch of size `self.length` x 2
            start_idx, stop_idx = self.get_random_patch(
                input_file, int(self.length * 2)
            )
            if start_idx >= 0:
                break

        input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()
        input_audio = input_audio.view(1, -1)

        if self.half:
            input_audio = input_audio.float()  # convert back to float32 for processing

        # peak normalize to -12 dBFS
        input_audio /= input_audio.abs().max()
        input_audio *= 10 ** (-12.0 / 20)  # with min 3 dBFS headroom
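
        # Note: -12 dBFS corresponds to a linear gain of 10 ** (-12 / 20) ≈ 0.251,
        # so the peak sample magnitude is scaled to about a quarter of full scale.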

        if len(self.augmentations) > 0:
            # apply augmentations to half of the examples at random
            if torch.rand(1).sum() < 0.5:
                input_audio_aug = augmentations.apply(
                    [input_audio],
                    self.sample_rate,
                    self.augmentations,
                )[0]
            else:
                input_audio_aug = input_audio.clone()
        else:
            input_audio_aug = input_audio.clone()

        input_audio_corrupt = input_audio_aug.clone()

        # apply frequency and dynamic range corruption (expander)
        if self.freq_corrupt and torch.rand(1).sum() < 0.75:
            input_audio_corrupt = augmentations.frequency_corruption(
                [input_audio_corrupt], self.sample_rate
            )[0]

        # peak normalize again before passing through dynamic range expander
        input_audio_corrupt /= input_audio_corrupt.abs().max()
        input_audio_corrupt *= 10 ** (-12.0 / 20)  # with min 3 dBFS headroom

        if self.drc_corrupt and torch.rand(1).sum() < 0.10:
            input_audio_corrupt = augmentations.dynamic_range_corruption(
                [input_audio_corrupt], self.sample_rate
            )[0]

        # ------------------------ Target audio ----------------------
        # use the same augmented audio clip, add different random EQ and compressor
        target_audio_corrupt = input_audio_aug.clone()

        # apply frequency and dynamic range corruption (compressor)
        if self.freq_corrupt and torch.rand(1).sum() < 0.75:
            target_audio_corrupt = augmentations.frequency_corruption(
                [target_audio_corrupt], self.sample_rate
            )[0]

        # peak normalize again before passing through dynamic range compressor
        target_audio_corrupt /= target_audio_corrupt.abs().max()
        target_audio_corrupt *= 10 ** (-12.0 / 20)  # with min 3 dBFS headroom

        if self.drc_corrupt and torch.rand(1).sum() < 0.75:
            target_audio_corrupt = augmentations.dynamic_range_compression(
                [target_audio_corrupt], self.sample_rate
            )[0]

        return input_audio_corrupt, target_audio_corrupt

    def __getitem__(self, _):
        """Generate a randomly corrupted input/target pair on the fly."""
        # increment counter
        self.items_since_load += 1

        # load next chunk into buffer if needed
        if self.items_since_load > self.buffer_reload_rate:
            self.load_audio_buffer()

        # generate pairs for style training
        input_audio, target_audio = self.generate_pair()

        # ------------------------ Conform length of files -------------------
        input_audio = utils.conform_length(input_audio, int(self.length * 2))
        target_audio = utils.conform_length(target_audio, int(self.length * 2))

        # ------------------------ Apply fade in and fade out ----------------
        input_audio = utils.linear_fade(input_audio, sample_rate=self.sample_rate)
        target_audio = utils.linear_fade(target_audio, sample_rate=self.sample_rate)

        # ------------------------ Final normalization -----------------------
        # always peak normalize final input to -12 dBFS
        input_audio /= input_audio.abs().max()
        input_audio *= 10 ** (-12.0 / 20.0)

        # always peak normalize the target to -12 dBFS
        target_audio /= target_audio.abs().max()
        target_audio *= 10 ** (-12.0 / 20.0)

        return input_audio, target_audio

    def get_random_file_id(self, keys):
        # generate a random index into the keys of the input files
        # (torch.randint's upper bound is exclusive, so len(keys) covers all keys)
        rand_input_idx = torch.randint(0, len(keys), [1])[0]
        # find the key (file_id) corresponding to the random index
        rand_input_file_id = list(keys)[rand_input_idx]
        return rand_input_file_id

    def get_random_patch(self, audio_file, length, check_silence=True):
        silent = True
        count = 0
        while silent:
            count += 1
            start_idx = torch.randint(0, audio_file.num_frames - length - 1, [1])[0]
            stop_idx = start_idx + length
            patch = audio_file.audio[:, start_idx:stop_idx].clone().detach()

            # reject the patch if either half is effectively silent
            length = patch.shape[-1]
            first_patch = patch[..., : length // 2]
            second_patch = patch[..., length // 2 :]

            if (
                (first_patch**2).mean() > 1e-5 and (second_patch**2).mean() > 1e-5
            ) or not check_silence:
                silent = False

            # give up after 100 attempts and signal failure with (-1, -1)
            if count > 100:
                print("get_random_patch count", count)
                return -1, -1

        return start_idx, stop_idx

    def get_file_id(self, filepath):
        """Given a filepath, extract the DAPS file id.

        Args:
            filepath (str): Path to an audio file in the DAPS dataset.

        Returns:
            file_id (str): DAPS file id of the form <participant_id>_<script_id>
        """
        file_id = os.path.basename(filepath).split("_")[:2]
        file_id = "_".join(file_id)
        return file_id

    def get_file_set(self, filepath):
        """Given a filepath, extract the DAPS file set name.

        Args:
            filepath (str): Path to an audio file in the DAPS dataset.

        Returns:
            file_set (str): The DAPS set to which the file belongs.
        """
        file_set = os.path.basename(filepath).split("_")[2:]
        file_set = "_".join(file_set)
        file_set = file_set.replace(f".{self.ext}", "")
        return file_set
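

# Minimal usage sketch (the dataset root below is hypothetical; it assumes audio
# laid out as <audio_dir>/<input_dir>/*.wav, e.g. DAPS "cleanraw" recordings):
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = AudioDataset(
        "path/to/daps",  # hypothetical dataset root
        input_dirs=["cleanraw"],
        subset="train",
        length=65536,
        freq_corrupt=True,
        drc_corrupt=True,
    )

    loader = DataLoader(dataset, batch_size=4, num_workers=0)
    input_audio, target_audio = next(iter(loader))
    print(input_audio.shape, target_audio.shape)  # torch.Size([4, 1, 131072]) each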