Spaces:
Build error
Build error
import os | |
import logging | |
import h5py | |
import soundfile | |
import librosa | |
import numpy as np | |
import pandas as pd | |
from scipy import stats | |
import datetime | |
import pickle | |
def create_folder(fd):
    """Create directory ``fd`` (and any missing parents) if nothing exists there.

    A no-op when ``fd`` is empty (e.g. ``os.path.dirname('file.txt')``, which
    denotes the current directory) or when something already exists at ``fd``.

    Args:
      fd: str, directory path.
    """
    if not fd:
        # os.makedirs('') raises FileNotFoundError; '' means "current dir".
        return
    if not os.path.exists(fd):
        # exist_ok guards against another process creating the directory
        # between the existence check and this call.
        os.makedirs(fd, exist_ok=True)
def get_filename(path):
    """Return the base name of ``path`` without its final extension.

    Symlinks are resolved first (via ``os.path.realpath``), so for a link the
    target's name is returned. Example: ``'dir/clip.wav'`` -> ``'clip'``.

    Args:
      path: str

    Returns:
      str
    """
    path = os.path.realpath(path)
    # os.path.basename honours the platform separator; splitting on '/' broke
    # on Windows, where realpath returns backslash-separated paths.
    na_ext = os.path.basename(path)
    return os.path.splitext(na_ext)[0]
def get_sub_filepaths(folder):
    """Recursively collect the paths of all files below ``folder``.

    Paths are produced in ``os.walk`` order (top-down). A missing ``folder``
    yields an empty list because ``os.walk`` produces nothing for it.

    Args:
      folder: str, root directory to scan.

    Returns:
      list of str, each joined from the walked directory and the file name.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(folder)
        for filename in filenames
    ]
def create_logging(log_dir, filemode):
    """Configure root logging to a fresh numbered log file plus the console.

    Log files are named ``0000.log``, ``0001.log``, ... inside ``log_dir``;
    the first index whose file does not exist yet is chosen.

    Args:
      log_dir: str, directory for log files (created if missing).
      filemode: str, file mode handed to ``logging.basicConfig``.

    Returns:
      The ``logging`` module itself.
    """
    create_folder(log_dir)

    def _numbered_log(index):
        return os.path.join(log_dir, '{:04d}.log'.format(index))

    index = 0
    while True:
        candidate = _numbered_log(index)
        if not os.path.isfile(candidate):
            break
        index += 1
    log_path = candidate

    # NOTE: basicConfig does nothing if the root logger already has handlers,
    # so it must run before the console handler below is attached.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename=log_path,
        filemode=filemode,
    )

    # Mirror INFO-and-above records to the console (StreamHandler's default
    # stream is sys.stderr).
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)
    return logging
def read_metadata(csv_path, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.

    The first three lines of the file are skipped as a header. Each remaining
    non-blank line is split on ``', '`` and is expected to look like::

        --4gqARaEJE, 0.000, 10.000, "/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"

    Args:
      csv_path: str, path of the csv file.
      classes_num: int, number of classes (width of the target matrix).
      id_to_ix: mapping from a label id string (e.g. '/m/068hy') to its
        column index in the target matrix.

    Returns:
      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
        'target' is a boolean multi-hot matrix.

    Raises:
      KeyError: if id_to_ix is a dict and a label id is missing from it.
    """
    with open(csv_path, 'r') as fr:
        lines = fr.readlines()

    lines = lines[3:]  # Remove heads
    # Drop blank lines (e.g. a trailing empty line at EOF), which would
    # otherwise fail the field parsing below with an IndexError.
    lines = [line for line in lines if line.strip()]

    audios_num = len(lines)
    # np.bool was removed in NumPy 1.24; the builtin bool gives the same
    # np.bool_ dtype on every NumPy version.
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []

    for n, line in enumerate(lines):
        items = line.split(', ')
        # Audios are stored with an extra leading 'Y' when downloading.
        audio_names.append('Y{}.wav'.format(items[0]))
        # items[3] is the quoted, comma-separated label list.
        label_ids = items[3].split('"')[1].split(',')
        for label_id in label_ids:
            targets[n, id_to_ix[label_id]] = True

    return {'audio_name': np.array(audio_names), 'target': targets}
def float32_to_int16(x):
    """Convert float samples in roughly [-1, 1] to int16 by scaling with 32767.

    Values are clipped to [-1, 1] before scaling; the final cast truncates
    toward zero.

    Args:
      x: array-like of floats. As a sanity check, magnitudes above 1.2 are
        rejected; values in (1, 1.2] are clipped.

    Returns:
      np.ndarray of dtype int16 with the same shape as x.

    Raises:
      AssertionError: if any |x| > 1.2 (only when assertions are enabled).
    """
    x = np.asarray(x)
    # np.max raises ValueError on a zero-size array, so only check non-empty
    # input; an empty input simply yields an empty int16 array.
    if x.size:
        assert np.max(np.abs(x)) <= 1.2
    x = np.clip(x, -1, 1)
    return (x * 32767.).astype(np.int16)
def int16_to_float32(x):
    """Scale integer samples back by 1/32767 and return them as float32.

    Args:
      x: np.ndarray (typically int16) of samples.

    Returns:
      np.ndarray of dtype float32.
    """
    full_scale = 32767.
    normalized = x / full_scale
    return normalized.astype(np.float32)
def pad_or_truncate(x, audio_length):
    """Pad all audio to specific length.

    Zero-pads (at the end) or truncates ``x`` along its first axis so that
    axis has exactly ``audio_length`` entries. The result keeps ``x``'s dtype.

    Args:
      x: array-like; padding/truncation is applied along axis 0.
      audio_length: int, target length of axis 0.

    Returns:
      np.ndarray whose first dimension equals audio_length.
    """
    x = np.asarray(x)
    if len(x) <= audio_length:
        # Match x's dtype and trailing dimensions so padding neither upcasts
        # (e.g. int16/float32 -> float64) nor fails for 2-D input.
        padding = np.zeros((audio_length - len(x),) + x.shape[1:], dtype=x.dtype)
        return np.concatenate((x, padding), axis=0)
    return x[:audio_length]
def d_prime(auc):
    """Convert an AUC value into the sensitivity index d'.

    Computes ``sqrt(2) * ppf(auc)``, where ``ppf`` is the inverse CDF of the
    standard normal distribution.

    Args:
      auc: float or array-like of floats in (0, 1).

    Returns:
      d' with the same shape as ``auc``.
    """
    return np.sqrt(2.0) * stats.norm.ppf(auc)
class Mixup(object):
    def __init__(self, mixup_alpha, random_seed=1234):
        """Mixup coefficient generator.

        Args:
          mixup_alpha: float, used as both shape parameters of the Beta
            distribution that mixing coefficients are drawn from.
          random_seed: int, seed of the private RandomState, making the
            coefficient sequence reproducible.
        """
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Get mixup random coefficients.

        Coefficients come in complementary pairs: for each pair one
        ``lam ~ Beta(alpha, alpha)`` is drawn and ``[lam, 1 - lam]`` is
        emitted, so entries (0, 1), (2, 3), ... sum to one.

        Args:
          batch_size: int

        Returns:
          mixup_lambdas: (batch_size,)
        """
        mixup_lambdas = []
        for _ in range(0, batch_size, 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
            mixup_lambdas.extend((lam, 1. - lam))
        # For an odd batch_size the pairing loop produces one coefficient too
        # many; trim so the result matches the documented (batch_size,) shape.
        # The number of RNG draws is unchanged, so later calls are unaffected.
        return np.array(mixup_lambdas[:batch_size])
class StatisticsContainer(object):
    def __init__(self, statistics_path):
        """Contain statistics of different training iterations.

        Args:
          statistics_path: str, pickle file the statistics are dumped to and
            loaded from. A timestamped backup path beside it is derived once,
            at construction time.
        """
        self.statistics_path = statistics_path
        self.backup_statistics_path = '{}_{}.pkl'.format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
        self.statistics_dict = {'bal': [], 'test': []}

    def append(self, iteration, statistics, data_type):
        """Record ``statistics`` under ``data_type`` (e.g. 'bal' or 'test').

        Note: ``statistics`` (a dict) is tagged in place with its iteration.
        """
        statistics['iteration'] = iteration
        self.statistics_dict[data_type].append(statistics)

    def dump(self):
        """Pickle the statistics to both the main and the backup path."""
        for path in (self.statistics_path, self.backup_statistics_path):
            # Context manager closes (and flushes) each file deterministically;
            # previously the handles were left open.
            with open(path, 'wb') as fw:
                pickle.dump(self.statistics_dict, fw)
        logging.info(' Dump statistics to {}'.format(self.statistics_path))
        logging.info(' Dump statistics to {}'.format(self.backup_statistics_path))

    def load_state_dict(self, resume_iteration):
        """Reload statistics from ``statistics_path``, keeping only records
        whose iteration is <= ``resume_iteration``.

        Args:
          resume_iteration: int
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # statistics files written by this program.
        with open(self.statistics_path, 'rb') as fr:
            loaded = pickle.load(fr)
        resume_statistics_dict = {'bal': [], 'test': []}
        for key, records in loaded.items():
            # setdefault keeps any extra data types instead of raising KeyError.
            resume_statistics_dict.setdefault(key, []).extend(
                statistics for statistics in records
                if statistics['iteration'] <= resume_iteration)
        self.statistics_dict = resume_statistics_dict