# Hosting-page status banner captured together with the source (not program code):
# Spaces: Build error
import os | |
import sys | |
sys.path.insert(1, os.path.join(sys.path[0], '../utils')) | |
import numpy as np | |
import argparse | |
import time | |
import logging | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import torch.utils.data | |
from utilities import (create_folder, get_filename, create_logging, Mixup, | |
StatisticsContainer) | |
from models import (PVT, PVT2, PVT_lr, PVT_nopretrain, PVT_2layer, Cnn14, Cnn14_no_specaug, Cnn14_no_dropout, | |
Cnn6, Cnn10, ResNet22, ResNet38, ResNet54, Cnn14_emb512, Cnn14_emb128, | |
Cnn14_emb32, MobileNetV1, MobileNetV2, LeeNet11, LeeNet24, DaiNet19, | |
Res1dNet31, Res1dNet51, Wavegram_Cnn14, Wavegram_Logmel_Cnn14, | |
Wavegram_Logmel128_Cnn14, Cnn14_16k, Cnn14_8k, Cnn14_mel32, Cnn14_mel128, | |
Cnn14_mixup_time_domain, Cnn14_DecisionLevelMax, Cnn14_DecisionLevelAtt, Cnn6_Transformer, GLAM, GLAM2, GLAM3, Cnn4, EAT) | |
#from models_test import (PVT_test) | |
#from models1 import (PVT1) | |
#from models_vig import (VIG, VIG2) | |
#from models_vvt import (VVT) | |
#from models2 import (MPVIT, MPVIT2) | |
#from models_reshape import (PVT_reshape, PVT_tscam) | |
#from models_swin import (Swin, Swin_nopretrain) | |
#from models_swin2 import (Swin2) | |
#from models_van import (Van, Van_tiny) | |
#from models_focal import (Focal) | |
#from models_cross import (Cross) | |
#from models_cov import (Cov) | |
#from models_cnn import (Cnn_light) | |
#from models_twins import (Twins) | |
#from models_cmt import (Cmt, Cmt1) | |
#from models_shunted import (Shunted) | |
#from models_quadtree import (Quadtree, Quadtree2, Quadtree_nopretrain) | |
#from models_davit import (Davit_tscam, Davit, Davit_nopretrain) | |
from pytorch_utils import (move_data_to_device, count_parameters, count_flops, | |
do_mixup) | |
from data_generator import (AudioSetDataset, TrainSampler, BalancedTrainSampler, | |
AlternateTrainSampler, EvaluateSampler, collate_fn) | |
from evaluate import Evaluator | |
import config | |
from losses import get_loss_func | |
def _run_subdir(args):
    """Return the relative directory identifying one training configuration.

    The same sub-path is appended to the checkpoints, statistics and logs
    roots, so it is built once here instead of being repeated at each use.
    """
    return os.path.join(
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
            args.sample_rate, args.window_size, args.hop_size, args.mel_bins,
            args.fmin, args.fmax),
        'data_type={}'.format(args.data_type), args.model_type,
        'loss_type={}'.format(args.loss_type),
        'balanced={}'.format(args.balanced),
        'augmentation={}'.format(args.augmentation),
        'batch_size={}'.format(args.batch_size))


def _resolve_model_class(model_type):
    """Return the callable named ``model_type`` from this module's globals.

    A plain dictionary lookup is used instead of ``eval`` so that the
    command-line string is never executed as code.

    Raises:
      ValueError: if no callable with that name is available in this module.
    """
    model_class = globals().get(model_type)
    if not callable(model_class):
        raise ValueError('Unknown model_type: {!r}'.format(model_type))
    return model_class


def train(args):
    """Train an AudioSet tagging model.

    Args:
      args: namespace (as produced by the argument parser) with attributes
        workspace: str, root directory holding 'hdf5s/indexes' and receiving
          the 'checkpoints', 'statistics' and 'logs' sub-directories
        data_type: 'balanced_train' | 'full_train'
        sample_rate: int
        window_size: int
        hop_size: int
        mel_bins: int
        fmin: int
        fmax: int
        model_type: str, name of a model class imported by this module
        loss_type: 'clip_bce'
        balanced: 'none' | 'balanced' | 'alternate'
        augmentation: 'none' | 'mixup'
        batch_size: int
        learning_rate: float
        resume_iteration: int, iteration of the checkpoint to resume from
          (values <= 0 start from scratch)
        early_stop: int, training stops once this iteration is reached
        cuda: bool
        filename: str, used as a sub-directory name for all outputs

    Raises:
      ValueError: if ``model_type`` or ``balanced`` is not recognised.
    """
    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    filename = args.filename
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'

    num_workers = 8
    eval_interval = 2000   # evaluate / checkpoint / step the scheduler every N iterations
    log_interval = 10      # print the training loss every N iterations
    mixup_alpha = 0.5
    use_mixup = 'mixup' in augmentation
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = None

    indexes_dir = os.path.join(workspace, 'hdf5s', 'indexes')
    train_indexes_hdf5_path = os.path.join(
        indexes_dir, '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(indexes_dir, 'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(indexes_dir, 'eval.h5')

    run_subdir = _run_subdir(args)

    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename, run_subdir)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(
        workspace, 'statistics', filename, run_subdir, 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, run_subdir)
    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if device == 'cuda':
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    Model = _resolve_model_class(model_type)
    model = Model(sample_rate=sample_rate, window_size=window_size,
        hop_size=hop_size, mel_bins=mel_bins, fmin=fmin, fmax=fmax,
        classes_num=classes_num)

    total = sum(p.numel() for p in model.parameters())
    print("Total params: %.2fM" % (total/1e6))
    logging.info("Total params: %.2fM" % (total/1e6))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Train sampler
    samplers = {
        'none': TrainSampler,
        'balanced': BalancedTrainSampler,
        'alternate': AlternateTrainSampler,
    }
    if balanced not in samplers:
        raise ValueError('Unknown balanced type: {!r}'.format(balanced))
    Sampler = samplers[balanced]

    # With mixup the sampler yields twice as many clips per batch; the
    # mixed output/target then has batch_size rows.
    train_sampler = Sampler(
        indexes_hdf5_path=train_indexes_hdf5_path,
        batch_size=batch_size * 2 if use_mixup else batch_size,
        black_list_csv=black_list_csv)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)
    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=train_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)
    eval_bal_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=eval_bal_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)
    eval_test_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=eval_test_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=mixup_alpha)
        print(mixup_alpha)
        logging.info(mixup_alpha)

    # Evaluator
    evaluator = Evaluator(model=model)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0.05, amsgrad=True)
    # mode='max': the scheduler is driven by test mAP, which should increase.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
        factor=0.5, patience=4, min_lr=1e-06, verbose=True)

    train_bgn_time = time.time()

    # Resume training
    resume_checkpoint = None
    iteration = 0
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # The model is still on the CPU here; mapping to CPU lets a checkpoint
        # written on a GPU machine be resumed on a CPU-only one.
        resume_checkpoint = torch.load(
            resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(resume_checkpoint['model'])
        train_sampler.load_state_dict(resume_checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if device == 'cuda':
        model.to(device)

    # Optimizer/scheduler state is restored after the model has been moved,
    # i.e. once the parameters live on their final device.
    if resume_checkpoint is not None:
        optimizer.load_state_dict(resume_checkpoint['optimizer'])
        scheduler.load_state_dict(resume_checkpoint['scheduler'])
        print(optimizer.state_dict()['param_groups'][0]['lr'])

    # Most recent test mAP; None until the first evaluation has run.
    test_map = None

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (ifexist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate
        if (iteration % eval_interval == 0 and iteration >= resume_iteration) \
                or iteration == 0:
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)
            bal_map = np.mean(bal_statistics['average_precision'])
            test_map = np.mean(test_statistics['average_precision'])

            logging.info('Validate bal mAP: {:.3f}'.format(bal_map))
            logging.info('Validate test mAP: {:.3f}'.format(test_map))

            statistics_container.append(iteration, bal_statistics, data_type='bal')
            statistics_container.append(iteration, test_statistics, data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))
            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % eval_interval == 0:
            checkpoint = {
                'iteration': iteration,
                # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Mixup lambda
        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to device
        for key in batch_data_dict:
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Forward
        model.train()

        if use_mixup:
            # output: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                batch_data_dict['mixup_lambda'])
            # target: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': do_mixup(batch_data_dict['target'],
                batch_data_dict['mixup_lambda'])}
        else:
            # output: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)
            # target: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % log_interval == 0:
            print(iteration, loss)

        # Step the plateau scheduler with the latest test mAP; skipped if no
        # evaluation has produced one yet.
        if iteration % eval_interval == 0 and test_map is not None:
            scheduler.step(test_map)
            print(optimizer.state_dict()['param_groups'][0]['lr'])
            logging.info(optimizer.state_dict()['param_groups'][0]['lr'])

        # Stop learning. '>=' so that a run resumed beyond early_stop still ends.
        if iteration >= early_stop:
            break

        iteration += 1
if __name__ == '__main__':
    # Command-line entry point; the only supported sub-command is 'train'.
    parser = argparse.ArgumentParser(description='Example of parser. ')
    subparsers = parser.add_subparsers(dest='mode')
    parser_train = subparsers.add_parser('train')

    # (flag, add_argument keyword options), registered in the listed order.
    train_options = [
        ('--workspace', dict(type=str, required=True)),
        ('--data_type', dict(type=str, default='full_train',
                             choices=['balanced_train', 'full_train'])),
        ('--sample_rate', dict(type=int, default=32000)),
        ('--window_size', dict(type=int, default=1024)),
        ('--hop_size', dict(type=int, default=320)),
        ('--mel_bins', dict(type=int, default=64)),
        ('--fmin', dict(type=int, default=50)),
        ('--fmax', dict(type=int, default=14000)),
        ('--model_type', dict(type=str, required=True)),
        ('--loss_type', dict(type=str, default='clip_bce',
                             choices=['clip_bce'])),
        ('--balanced', dict(type=str, default='balanced',
                            choices=['none', 'balanced', 'alternate'])),
        ('--augmentation', dict(type=str, default='mixup',
                                choices=['none', 'mixup'])),
        ('--batch_size', dict(type=int, default=32)),
        ('--learning_rate', dict(type=float, default=1e-3)),
        ('--resume_iteration', dict(type=int, default=0)),
        ('--early_stop', dict(type=int, default=1000000)),
        ('--cuda', dict(action='store_true', default=False)),
    ]
    for flag, options in train_options:
        parser_train.add_argument(flag, **options)

    args = parser.parse_args()
    # train() reads args.filename to name its output sub-directories.
    args.filename = get_filename(__file__)

    # Guard clause: anything other than the 'train' sub-command is rejected.
    if args.mode != 'train':
        raise Exception('Error argument!')
    train(args)