Spaces:
Runtime error
Runtime error
# Adapted from:
# https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/utils.py
import os | |
import csv | |
import torch | |
import fnmatch | |
import numpy as np | |
import random | |
from enum import Enum | |
import pyloudnorm as pyln | |
class DSPMode(Enum):
    """Closed set of DSP modes, each member backed by a plain string value."""

    NONE = "none"
    TRAIN_INFER = "train_infer"
    INFER = "infer"

    def __str__(self):
        """Render the member as its underlying string value."""
        return self._value_
def loudness_normalize(x, sample_rate, target_loudness=-24.0):
    """Normalize audio to a target integrated loudness using pyloudnorm.

    The input is flattened to one row, duplicated into two identical channels
    for metering, gain-adjusted, and the first channel is returned.

    Args:
        x (torch.Tensor): Audio tensor; it is reshaped with ``view(1, -1)`` and
            converted with ``.numpy()``.
        sample_rate: Sample rate handed to ``pyln.Meter``.
        target_loudness (float): Loudness target forwarded to
            ``pyln.normalize.loudness`` (presumably LUFS; confirm against
            pyloudnorm). Default: -24.0.

    Returns:
        torch.Tensor: Loudness-normalized audio with shape (1, samples).
    """
    mono = x.view(1, -1)
    # Duplicate the single row and transpose to a (samples, 2) numpy array.
    dual_mono = mono.repeat(2, 1).permute(1, 0).numpy()

    measured = pyln.Meter(sample_rate).integrated_loudness(dual_mono)
    normalized = pyln.normalize.loudness(dual_mono, measured, target_loudness)

    # Back to (2, samples); both channels are identical, so keep the first.
    channels_first = torch.tensor(normalized).permute(1, 0)
    return channels_first[0, :].view(1, -1)
def get_random_file_id(keys):
    """Pick one file id uniformly at random from `keys`.

    Args:
        keys: Collection of file ids (for example a dict keys view). It is
            materialized with ``list(keys)`` before indexing.

    Returns:
        A randomly selected element of `keys`.

    Raises:
        RuntimeError: Raised by ``torch.randint`` when `keys` is empty.
    """
    file_ids = list(keys)
    # torch.randint excludes the upper bound, so pass len(file_ids) itself:
    # every key (including the last) is reachable and a single key works.
    rand_input_idx = int(torch.randint(0, len(file_ids), [1])[0])
    # return the key (file_id) corresponding to the random index
    return file_ids[rand_input_idx]
def get_random_patch(audio_file, length, check_silence=True):
    """Draw random start/stop sample indices for a patch of `length` samples.

    NOTE: this module defines ``get_random_patch`` a second time further down;
    that later definition replaces this one when the module is imported.

    Args:
        audio_file: Object exposing ``num_frames`` and an ``audio`` tensor that
            is sliced along its second axis.
        length (int): Number of samples in the patch.
        check_silence (bool): When True, keep drawing until the patch's
            mean-square value exceeds 1e-4.

    Returns:
        tuple: ``(start_idx, stop_idx)`` of the accepted patch.
    """
    while True:
        start_idx = int(torch.rand(1) * (audio_file.num_frames - length))
        stop_idx = start_idx + length
        candidate = audio_file.audio[:, start_idx:stop_idx].clone().detach()
        energy = (candidate ** 2).mean()
        if energy > 1e-4 or not check_silence:
            break
    return start_idx, stop_idx
def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` from the current torch initial seed.

    Presumably meant to be used as a DataLoader ``worker_init_fn`` (confirm
    against the caller).

    Args:
        worker_id: Unused by this function.
    """
    # Reduce the torch seed to 32 bits before handing it to the other RNGs.
    seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed)
def getFilesPath(directory, extension):
    """Recursively collect files under `directory` whose names match a pattern.

    Args:
        directory: Root directory traversed with ``os.walk``.
        extension: ``fnmatch``-style pattern matched against each file name
            (for example ``"*.wav"``).

    Returns:
        list: Sorted list of the matching file paths.
    """
    matches = [
        os.path.join(root, filename)
        for root, _subdirs, filenames in os.walk(directory)
        for filename in fnmatch.filter(filenames, extension)
    ]
    return sorted(matches)
def count_parameters(model, trainable_only=True):
    """Count the scalar parameters of a model.

    Args:
        model: Object exposing ``parameters()`` (for example a torch module).
        trainable_only (bool): When True, count only parameters whose
            ``requires_grad`` flag is set; otherwise count every parameter.

    Returns:
        int: Total number of elements; 0 when the model has no parameters.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
def system_summary(system):
    """Print the trainable parameter counts (in millions) of a system's parts.

    Reports ``system.encoder`` and ``system.processor``, and, when the system
    has an ``adv_loss_fn`` attribute, each of its ``discriminators``.

    Args:
        system: Object exposing ``encoder`` and ``processor`` and optionally
            ``adv_loss_fn.discriminators``.
    """
    print(f"Encoder: {count_parameters(system.encoder)/1e6:0.2f} M")
    print(f"Processor: {count_parameters(system.processor)/1e6:0.2f} M")

    # Systems without an adversarial loss have no discriminators to report.
    if not hasattr(system, "adv_loss_fn"):
        return

    for idx, disc in enumerate(system.adv_loss_fn.discriminators):
        print(f"Discriminator {idx+1}: {count_parameters(disc)/1e6:0.2f} M")
def center_crop(x, length: int):
    """Crop the last dimension of `x` to `length` samples around its center.

    Args:
        x: Array-like with a ``shape`` attribute that supports ``...`` slicing.
        length (int): Desired size of the last dimension.

    Returns:
        `x` itself when the last dimension already equals `length`, otherwise
        the centered slice.
    """
    total = x.shape[-1]
    if total == length:
        return x
    offset = (total - length) // 2
    return x[..., offset:offset + length]
def causal_crop(x, length: int):
    """Keep only the final `length` samples of the last dimension.

    The previous implementation ended its slice at ``shape[-1] - 1`` and so
    dropped the most recent sample; the crop now ends at the last sample.

    Args:
        x: Array-like with a ``shape`` attribute that supports ``...`` slicing.
        length (int): Desired size of the last dimension.

    Returns:
        The trailing `length` samples of `x`. When `x` is not longer than
        `length` there is nothing to crop and `x` is returned unchanged.
    """
    total = x.shape[-1]
    if total <= length:
        # Too short (or already the right size): slicing cannot lengthen it.
        return x
    return x[..., total - length:]
def denormalize(norm_val, max_val, min_val):
    """Scale a normalized value back into the ``[min_val, max_val]`` range.

    Args:
        norm_val: Normalized value (0 maps to `min_val`, 1 to `max_val`).
        max_val: Upper end of the target range.
        min_val: Lower end of the target range.

    Returns:
        ``norm_val * (max_val - min_val) + min_val``.
    """
    value_range = max_val - min_val
    return norm_val * value_range + min_val
def normalize(denorm_val, max_val, min_val):
    """Scale a value from the ``[min_val, max_val]`` range to ``[0, 1]``.

    Args:
        denorm_val: Value in the original range.
        max_val: Upper end of the original range.
        min_val: Lower end of the original range.

    Returns:
        ``(denorm_val - min_val) / (max_val - min_val)``.
    """
    offset = denorm_val - min_val
    return offset / (max_val - min_val)
def get_random_patch(audio_file, length, energy_treshold=1e-4, check_silence=True):
    """Produce sample indices for a random patch of size `length`.

    This function checks the energy of the selected patch to ensure that it is
    not silence. If the patch is silent, it keeps drawing until a non-silent
    patch is found (it does not terminate if no patch ever exceeds the
    threshold).

    Args:
        audio_file (AudioFile): Audio file object exposing ``num_frames`` and
            an ``audio`` tensor that is sliced along its second axis.
        length (int): Number of samples in random patch.
        energy_treshold (float): Mean-square value a patch must exceed to be
            accepted. Default: 1e-4.
        check_silence (bool): When False, accept the first drawn patch without
            measuring its energy. Accepted so that callers written against the
            earlier ``get_random_patch(audio_file, length, check_silence)``
            definition in this module keep working with keyword arguments.
            Default: True.

    Returns:
        start_idx (int): Starting sample index
        stop_idx (int): Stop sample index
    """
    # Clamp at zero so a file shorter than `length` yields start_idx == 0
    # instead of a negative (wrap-around) start index.
    max_start = max(audio_file.num_frames - length, 0)
    while True:
        start_idx = int(torch.rand(1) * max_start)
        stop_idx = start_idx + length
        if not check_silence:
            break
        patch = audio_file.audio[:, start_idx:stop_idx]
        if (patch ** 2).mean() > energy_treshold:
            break
    return start_idx, stop_idx
def split_dataset(file_list, subset, train_frac):
    """Given a list of files, split into train/val/test sets.

    The first ``int(len(file_list) * train_frac)`` files form the training
    set, the next ``int(len(file_list) * (1 - train_frac) / 2)`` files the
    validation set, and all remaining files the test set.

    Args:
        file_list (list): List of audio files.
        subset (str): One of "train", "val", or "test".
        train_frac (float): Fraction of the dataset to use for training.
            Must satisfy ``0.1 < train_frac < 1.0``.

    Returns:
        file_list (list): List of audio files corresponding to subset.

    Raises:
        AssertionError: If `train_frac` is outside ``(0.1, 1.0)``.
        ValueError: If `subset` is not one of "train", "val", "test".
    """
    assert train_frac > 0.1 and train_frac < 1.0

    total_num_examples = len(file_list)
    train_num_examples = int(total_num_examples * train_frac)
    val_num_examples = int(total_num_examples * (1 - train_frac) / 2)
    test_num_examples = total_num_examples - (train_num_examples + val_num_examples)

    # NOTE(review): with train_frac in (0.1, 1.0) these counts are never
    # negative, so the checks below cannot fire; an empty split is returned as
    # an empty list. Confirm whether `< 1` was intended before tightening.
    if train_num_examples < 0:
        raise ValueError(
            f"No examples in training set. Try increasing train_frac: {train_frac}."
        )
    elif val_num_examples < 0:
        raise ValueError(
            f"No examples in validation set. Try decreasing train_frac: {train_frac}."
        )
    elif test_num_examples < 0:
        raise ValueError(
            f"No examples in test set. Try decreasing train_frac: {train_frac}."
        )

    if subset == "train":
        start_idx = 0
        stop_idx = train_num_examples
    elif subset == "val":
        start_idx = train_num_examples
        stop_idx = start_idx + val_num_examples
    elif subset == "test":
        start_idx = train_num_examples + val_num_examples
        # The test split runs to the end of the list.
        stop_idx = start_idx + test_num_examples
    else:
        # f-string so the offending value actually appears in the message.
        raise ValueError(f"Invalid subset: {subset}.")

    return file_list[start_idx:stop_idx]
def rademacher(size):
    """Generate random samples from a Rademacher distribution (+1 or -1).

    Args:
        size: Sample shape forwarded to ``Binomial.sample`` (for example a
            ``torch.Size`` or a tuple of ints).

    Returns:
        torch.Tensor: Tensor of the requested shape whose entries are -1 or +1.
    """
    fair_coin = torch.distributions.binomial.Binomial(1, 0.5)
    flips = fair_coin.sample(size)
    # Map the {0, 1} outcomes onto {-1, +1}.
    return 2 * flips - 1
def get_subset(csv_file):
    """Read the unique entries of the "filepath" column from a CSV file.

    Args:
        csv_file: Path of a CSV file whose header row includes a "filepath"
            column.

    Returns:
        list: Unique file paths, in order of first appearance in the file.

    Raises:
        KeyError: If a row has no "filepath" field.
    """
    # newline="" lets the csv module handle line endings itself.
    with open(csv_file, newline="") as fp:
        reader = csv.DictReader(fp)
        # dict.fromkeys de-duplicates while keeping insertion order; going
        # through set() would make the order depend on string hashing, which
        # varies between interpreter runs.
        return list(dict.fromkeys(row["filepath"] for row in reader))
def conform_length(x: torch.Tensor, length: int):
    """Crop or right-pad the last dimension of `x` so it has `length` samples.

    Args:
        x (torch.Tensor): Input tensor.
        length (int): Desired size of the last dimension.

    Returns:
        torch.Tensor: `x` padded at the end (``torch.nn.functional.pad`` with
        its default settings), truncated to the first `length` samples, or
        returned as-is when it already has the requested size.
    """
    missing = length - x.shape[-1]
    if missing > 0:
        return torch.nn.functional.pad(x, (0, missing))
    if missing < 0:
        return x[..., :length]
    return x
def linear_fade(
    x: torch.Tensor,
    fade_ms: float = 50.0,
    sample_rate: float = 22050,
):
    """Apply a linear fade in and fade out to the last dim, in place.

    Args:
        x (torch.Tensor): Audio tensor; it is modified in place.
        fade_ms (float): Duration of each fade in milliseconds. Default: 50.0.
        sample_rate (float): Sample rate used to convert `fade_ms` into a
            number of samples. Default: 22050.

    Returns:
        torch.Tensor: The same tensor `x` with both ends faded.
    """
    # Honor the caller's sample rate (it was previously hard-coded to 22050).
    fade_samples = int(fade_ms * 1e-3 * sample_rate)
    if fade_samples <= 0:
        # A zero-length fade is a no-op; x[..., -0:] would select all of x.
        return x

    # Build the ramps on x's device so the in-place multiply also works for
    # tensors that do not live on the CPU.
    fade_in = torch.linspace(0.0, 1.0, steps=fade_samples, device=x.device)
    fade_out = torch.linspace(1.0, 0.0, steps=fade_samples, device=x.device)

    # fade in
    x[..., :fade_samples] *= fade_in
    # fade out
    x[..., -fade_samples:] *= fade_out
    return x
# def get_random_patch(x, sample_rate, length_samples): | |
# length = length_samples | |
# silent = True | |
# while silent: | |
# start_idx = np.random.randint(0, x.shape[-1] - length - 1) | |
# stop_idx = start_idx + length | |
# x_crop = x[0:1, start_idx:stop_idx] | |
# # check for silence | |
# frames = length // sample_rate | |
# silent_frames = [] | |
# for n in range(frames): | |
# start_idx = n * sample_rate | |
# stop_idx = start_idx + sample_rate | |
# x_frame = x_crop[0:1, start_idx:stop_idx] | |
# if (x_frame ** 2).mean() > 3e-4: | |
# silent_frames.append(False) | |
# else: | |
# silent_frames.append(True) | |
# silent = True if any(silent_frames) else False | |
# x_crop /= x_crop.abs().max() | |
# return x_crop | |