yourusername's picture
:beers: cheers
66a6dc0
# Adapted from:
# https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/utils.py
import os
import csv
import torch
import fnmatch
import numpy as np
import random
from enum import Enum
import pyloudnorm as pyln
class DSPMode(Enum):
    """Enumeration of DSP usage modes whose ``str()`` is the raw value.

    The lowercase string values can be parsed back with ``DSPMode("infer")``.
    Presumably they come from CLI/config input; confirm against callers.
    """

    NONE = "none"
    TRAIN_INFER = "train_infer"
    INFER = "infer"

    def __str__(self):
        # Render as the bare value (e.g. "infer") rather than "DSPMode.INFER".
        return str(self.value)
def loudness_normalize(x, sample_rate, target_loudness=-24.0):
    """Normalize a single-channel signal to a target integrated loudness.

    The signal is flattened to one channel and duplicated into two identical
    channels, so pyloudnorm meters it as stereo. It is then gain-normalized,
    and the first channel is returned.

    Args:
        x (torch.Tensor): Audio signal. It is flattened with ``view(1, -1)``,
            so every element is treated as a sample of a single channel.
        sample_rate (int): Sample rate in Hz, passed to ``pyln.Meter``.
        target_loudness (float): Target integrated loudness in LUFS.

    Returns:
        torch.Tensor: Normalized signal of shape (1, num_samples), on the same
        device and with the same dtype as ``x``. If the measured loudness is
        not finite (e.g. silence), the flattened ``x`` is returned unchanged
        instead of a tensor full of NaNs.
    """
    x = x.view(1, -1)
    # pyloudnorm works on NumPy arrays shaped (samples, channels). Detach and
    # move to CPU so tensors on a GPU or requiring grad can be converted.
    stereo_audio = x.detach().cpu().repeat(2, 1).permute(1, 0).numpy()
    meter = pyln.Meter(sample_rate)
    loudness = meter.integrated_loudness(stereo_audio)
    if not np.isfinite(loudness):
        # Silence measures as -inf/NaN LUFS; the gain needed to reach the
        # target would be infinite and turn every sample into NaN.
        return x
    norm_x = pyln.normalize.loudness(
        stereo_audio,
        loudness,
        target_loudness,
    )
    x_norm = torch.tensor(norm_x).permute(1, 0)
    x_norm = x_norm[0, :].view(1, -1)
    # pyloudnorm may upcast to float64; restore the caller's dtype/device.
    return x_norm.to(device=x.device, dtype=x.dtype)
def get_random_file_id(keys):
    """Pick a uniformly random element of ``keys``.

    Args:
        keys: Non-empty iterable of file ids (e.g. the keys of a dict).

    Returns:
        One element of ``keys``, chosen using torch's global RNG.

    Raises:
        ValueError: If ``keys`` is empty.
    """
    key_list = list(keys)
    if not key_list:
        raise ValueError("Cannot pick a random file id from an empty collection.")
    # torch.randint's upper bound is exclusive, so passing len(key_list) makes
    # every index reachable, including the last one. Passing len - 1 never
    # picked the last key and crashed when there was only a single key.
    rand_input_idx = torch.randint(0, len(key_list), (1,)).item()
    return key_list[rand_input_idx]
def get_random_patch(audio_file, length, check_silence=True):
    """Pick random start/stop sample indices for a patch of ``length`` samples.

    NOTE: a second ``get_random_patch`` is defined later in this module and
    rebinds the name, so this version is not the one callers reach.

    Args:
        audio_file: Object exposing ``num_frames`` (int) and ``audio`` (a
            tensor indexed as ``audio[:, start:stop]``).
        length (int): Number of samples in the patch.
        check_silence (bool): If True, keep drawing positions until the patch's
            mean power (mean of squared samples) exceeds 1e-4. There is no
            attempt limit, so an entirely silent file loops forever.

    Returns:
        tuple[int, int]: ``(start_idx, stop_idx)`` with
        ``stop_idx - start_idx == length``.

    Raises:
        ValueError: If ``length`` exceeds ``audio_file.num_frames``.
    """
    if length > audio_file.num_frames:
        raise ValueError(
            f"Patch length {length} exceeds file length {audio_file.num_frames}."
        )
    while True:
        start_idx = int(torch.rand(1) * (audio_file.num_frames - length))
        stop_idx = start_idx + length
        if not check_silence:
            return start_idx, stop_idx
        # A slice is a view; no copy is needed just to measure its power.
        patch = audio_file.audio[:, start_idx:stop_idx]
        if (patch ** 2).mean() > 1e-4:
            return start_idx, stop_idx
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from ``torch.initial_seed()``.

    This follows the PyTorch ``worker_init_fn`` reproducibility pattern. The
    torch seed is reduced modulo 2**32 because ``np.random.seed`` only accepts
    32-bit seeds. ``worker_id`` is accepted for the callback signature but
    unused.
    """
    del worker_id  # unused; present only to satisfy the callback signature
    seed = torch.initial_seed() % (1 << 32)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
def getFilesPath(directory, extension):
    """Recursively collect files under ``directory`` whose names match a pattern.

    Args:
        directory (str): Root directory passed to ``os.walk``.
        extension (str): ``fnmatch`` pattern applied to bare file names,
            e.g. ``"*.wav"``.

    Returns:
        list[str]: Matching paths (each joined onto its walk directory),
        sorted lexicographically.
    """
    return sorted(
        os.path.join(root, filename)
        for root, _subdirs, filenames in os.walk(directory)
        for filename in filenames
        if fnmatch.fnmatch(filename, extension)
    )
def count_parameters(model, trainable_only=True):
    """Count the scalar parameters of ``model``.

    Args:
        model (torch.nn.Module): Module whose ``parameters()`` are counted.
        trainable_only (bool): If True, skip parameters whose
            ``requires_grad`` is False.

    Returns:
        int: Total number of elements (0 when there are no parameters).
    """
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total
def system_summary(system):
    """Print trainable parameter counts, in millions, for ``system``'s parts.

    Reports ``system.encoder`` and ``system.processor``. If ``system`` has an
    ``adv_loss_fn`` attribute, it also reports each entry of
    ``system.adv_loss_fn.discriminators``, numbered from 1.
    """

    def _report(label, module):
        print(f"{label}: {count_parameters(module)/1e6:0.2f} M")

    _report("Encoder", system.encoder)
    _report("Processor", system.processor)
    if not hasattr(system, "adv_loss_fn"):
        return
    for number, disc in enumerate(system.adv_loss_fn.discriminators, start=1):
        _report(f"Discriminator {number}", disc)
def center_crop(x, length: int):
    """Crop the last dimension of ``x`` to ``length`` samples around its center.

    When the number of surplus samples is odd, the extra one is removed from
    the end. Input that is already ``length`` long is returned unchanged.

    Args:
        x (torch.Tensor): Tensor to crop along its last dimension.
        length (int): Desired size of the last dimension.

    Returns:
        torch.Tensor: A view of ``x`` with last dimension ``length``.

    Raises:
        ValueError: If ``length`` is larger than ``x.shape[-1]``. The old code
            silently returned a mis-sized slice from a negative start index.
    """
    size = x.shape[-1]
    if length > size:
        raise ValueError(f"Cannot center-crop length {size} to longer length {length}.")
    if size != length:
        start = (size - length) // 2
        x = x[..., start:start + length]
    return x
def causal_crop(x, length: int):
    """Keep the last ``length`` samples along the final dimension of ``x``.

    The old implementation stopped at ``x.shape[-1] - 1``, which dropped the
    final sample and shifted the crop one step into the past. This version
    keeps the trailing ``length`` samples, including the last one.

    Args:
        x (torch.Tensor): Tensor to crop along its last dimension.
        length (int): Desired size of the last dimension.

    Returns:
        torch.Tensor: A view of ``x`` with last dimension ``length`` (``x``
        itself if it is already that long).

    Raises:
        ValueError: If ``length`` is larger than ``x.shape[-1]``.
    """
    size = x.shape[-1]
    if length > size:
        raise ValueError(f"Cannot causal-crop length {size} to longer length {length}.")
    if size != length:
        x = x[..., size - length:]
    return x
def denormalize(norm_val, max_val, min_val):
    """Linearly map ``norm_val`` from [0, 1] onto [min_val, max_val].

    No clipping is applied, so inputs outside [0, 1] extrapolate.
    """
    value_range = max_val - min_val
    return norm_val * value_range + min_val
def normalize(denorm_val, max_val, min_val):
    """Linearly map ``denorm_val`` from [min_val, max_val] onto [0, 1].

    No clipping is applied, so out-of-range inputs map outside [0, 1].
    A zero-width range (``max_val == min_val``) divides by zero.
    """
    offset = denorm_val - min_val
    return offset / (max_val - min_val)
def get_random_patch(audio_file, length, energy_treshold=1e-4, check_silence=True):
    """Produce sample indices for a random patch of size `length`.

    Unless ``check_silence`` is False, the patch's mean power (the mean of its
    squared samples) must exceed ``energy_treshold``; otherwise a new position
    is drawn. The search has no attempt limit, so a file that is silent
    everywhere makes this loop forever.

    Args:
        audio_file (AudioFile): Object exposing ``num_frames`` (int) and
            ``audio`` (a tensor indexed as ``audio[:, start:stop]``).
        length (int): Number of samples in random patch.
        energy_treshold (float): Minimum mean power for a patch to count as
            non-silent. The misspelled name is kept for keyword callers.
        check_silence (bool): If False, accept the first position drawn.

    Returns:
        start_idx (int): Starting sample index
        stop_idx (int): Stop sample index

    Raises:
        ValueError: If ``length`` exceeds ``audio_file.num_frames``.
            Previously this produced a negative start index.
    """
    num_frames = audio_file.num_frames
    if length > num_frames:
        raise ValueError(f"Patch length {length} exceeds file length {num_frames}.")
    while True:
        start_idx = int(torch.rand(1) * (num_frames - length))
        stop_idx = start_idx + length
        if not check_silence:
            return start_idx, stop_idx
        patch = audio_file.audio[:, start_idx:stop_idx]
        if (patch ** 2).mean() > energy_treshold:
            return start_idx, stop_idx
def split_dataset(file_list, subset, train_frac):
    """Given a list of files, split into train/val/test sets.

    The split is positional. The first ``int(len(file_list) * train_frac)``
    entries form the training set. The remaining entries are divided between
    validation and test, and test receives the extra item when the remainder
    is odd.

    Args:
        file_list (list): List of audio files.
        subset (str): One of "train", "val", or "test".
        train_frac (float): Fraction of the dataset to use for training,
            strictly between 0.1 and 1.0.

    Returns:
        file_list (list): List of audio files corresponding to subset.

    Raises:
        ValueError: If ``subset`` is invalid or the requested subset would
            contain no examples.
        AssertionError: If ``train_frac`` is outside (0.1, 1.0).
    """
    assert train_frac > 0.1 and train_frac < 1.0
    total_num_examples = len(file_list)
    train_num_examples = int(total_num_examples * train_frac)
    # Split the integer remainder instead of computing
    # int(total * (1 - train_frac) / 2). The float form can truncate one too
    # low: 10 files at 0.8 yielded 0 validation files.
    remaining = total_num_examples - train_num_examples
    val_num_examples = remaining // 2
    test_num_examples = remaining - val_num_examples

    if subset == "train":
        start_idx = 0
        stop_idx = train_num_examples
        empty_msg = f"No examples in training set. Try increasing train_frac: {train_frac}."
    elif subset == "val":
        start_idx = train_num_examples
        stop_idx = start_idx + val_num_examples
        empty_msg = f"No examples in validation set. Try decreasing train_frac: {train_frac}."
    elif subset == "test":
        start_idx = train_num_examples + val_num_examples
        stop_idx = start_idx + test_num_examples
        empty_msg = f"No examples in test set. Try decreasing train_frac: {train_frac}."
    else:
        raise ValueError(f"Invalid subset: {subset}.")

    # The old `< 0` checks could never fire because the counts are
    # non-negative; reject an empty requested subset instead.
    if stop_idx <= start_idx:
        raise ValueError(empty_msg)
    return file_list[start_idx:stop_idx]
def rademacher(size):
    """Generate random samples from a Rademacher distribution (+1 or -1).

    Args:
        size (int or sequence of int): Output shape. A bare int produces a
            1-D tensor of that length. ``Binomial.sample`` alone rejects a
            bare int.

    Returns:
        torch.Tensor: Float tensor of shape ``size`` whose entries are -1.0 or
        +1.0 with equal probability.
    """
    if isinstance(size, int):
        size = (size,)
    m = torch.distributions.binomial.Binomial(1, 0.5)
    x = m.sample(torch.Size(size))
    # Map Bernoulli outcomes {0, 1} onto {-1, +1}.
    x[x == 0] = -1
    return x
def get_subset(csv_file):
    """Read the distinct ``filepath`` values from a CSV file.

    Args:
        csv_file (str): Path to a CSV file whose header row includes a
            ``filepath`` column.

    Returns:
        list[str]: Each distinct ``filepath`` value, in order of first
        appearance in the file.
    """
    # newline="" is what the csv module expects of the file object.
    # dict.fromkeys de-duplicates while preserving file order. Iterating a
    # set of strings gives an order that changes between interpreter runs,
    # because string hashing is randomized per process.
    with open(csv_file, newline="") as fp:
        reader = csv.DictReader(fp)
        return list(dict.fromkeys(row["filepath"] for row in reader))
def conform_length(x: torch.Tensor, length: int):
    """Zero-pad or truncate the last dimension of ``x`` to exactly ``length``.

    Padding is appended at the end, and truncation keeps the leading samples.
    A tensor that already has the right length is returned as-is.
    """
    deficit = length - x.shape[-1]
    if deficit > 0:
        return torch.nn.functional.pad(x, (0, deficit))
    if deficit < 0:
        return x[..., :length]
    return x
def linear_fade(
    x: torch.Tensor,
    fade_ms: float = 50.0,
    sample_rate: float = 22050,
):
    """Apply a linear fade in and fade out along the last dimension, in place.

    Args:
        x (torch.Tensor): Signal to fade. It is modified in place and returned.
        fade_ms (float): Duration of each fade in milliseconds.
        sample_rate (float): Sample rate in Hz, used to convert ``fade_ms`` to
            samples. This value was previously ignored in favour of a
            hard-coded 22050, which is still the default.

    Returns:
        torch.Tensor: ``x`` itself, after fading.
    """
    fade_samples = int(fade_ms * 1e-3 * sample_rate)
    if fade_samples <= 0:
        # Nothing to fade. This also avoids x[..., -0:], which would select
        # the whole signal and fail to broadcast against an empty ramp.
        return x
    # Build the ramps on x's device so GPU inputs work.
    fade_in = torch.linspace(0.0, 1.0, steps=fade_samples, device=x.device)
    fade_out = torch.linspace(1.0, 0.0, steps=fade_samples, device=x.device)
    # fade in
    x[..., :fade_samples] *= fade_in
    # fade out
    x[..., -fade_samples:] *= fade_out
    return x
# def get_random_patch(x, sample_rate, length_samples):
# length = length_samples
# silent = True
# while silent:
# start_idx = np.random.randint(0, x.shape[-1] - length - 1)
# stop_idx = start_idx + length
# x_crop = x[0:1, start_idx:stop_idx]
# # check for silence
# frames = length // sample_rate
# silent_frames = []
# for n in range(frames):
# start_idx = n * sample_rate
# stop_idx = start_idx + sample_rate
# x_frame = x_crop[0:1, start_idx:stop_idx]
# if (x_frame ** 2).mean() > 3e-4:
# silent_frames.append(False)
# else:
# silent_frames.append(True)
# silent = True if any(silent_frames) else False
# x_crop /= x_crop.abs().max()
# return x_crop