import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '../utils'))
import numpy as np
import argparse
import h5py
import math
import time
import logging
import matplotlib.pyplot as plt

import torch
torch.backends.cudnn.benchmark = True
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

from utilities import get_filename
from models import *
import config


class Transfer_Cnn14(nn.Module):
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
            fmax, classes_num, freeze_base):
        """Classifier for a new task using a pretrained Cnn14 as a sub-module."""
        super(Transfer_Cnn14, self).__init__()
        audioset_classes_num = 527

        self.base = Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin,
            fmax, audioset_classes_num)

        # Transfer layer: maps the 2048-d Cnn14 embedding to the new task's classes
        self.fc_transfer = nn.Linear(2048, classes_num, bias=True)

        if freeze_base:
            # Freeze the AudioSet-pretrained layers; only fc_transfer is trained
            for param in self.base.parameters():
                param.requires_grad = False

        self.init_weights()

    def init_weights(self):
        init_layer(self.fc_transfer)

    def load_from_pretrain(self, pretrained_checkpoint_path):
        # map_location='cpu' lets the checkpoint load regardless of the device
        # it was saved on; the model is moved to the target device afterwards
        checkpoint = torch.load(pretrained_checkpoint_path, map_location='cpu')
        self.base.load_state_dict(checkpoint['model'])

    def forward(self, input, mixup_lambda=None):
        """Input: (batch_size, data_length)"""
        output_dict = self.base(input, mixup_lambda)
        embedding = output_dict['embedding']

        # Replace the AudioSet head's output with log-probabilities for the new task
        clipwise_output = torch.log_softmax(self.fc_transfer(embedding), dim=-1)
        output_dict['clipwise_output'] = clipwise_output

        return output_dict


def train(args):
    # Arguments & parameters
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    classes_num = config.classes_num
    pretrain = True if pretrained_checkpoint_path else False

    # Model: resolve the model class by name (e.g. Transfer_Cnn14) from the
    # names imported via `from models import *`
    Model = eval(model_type)
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
        classes_num, freeze_base)

    # Load pretrained model
    if pretrain:
        logging.info('Load pretrained model from {}'.format(
            pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    print('Load pretrained model successfully!')
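
# A minimal smoke test of Transfer_Cnn14 (an illustrative sketch, not part of
# the training flow). The audio settings below are typical Cnn14 values and the
# 10-class target is an arbitrary example; adjust both for your own task.
def _smoke_test():
    model = Transfer_Cnn14(sample_rate=32000, window_size=1024, hop_size=320,
        mel_bins=64, fmin=50, fmax=14000, classes_num=10, freeze_base=True)
    waveform = torch.zeros(2, 32000)  # (batch_size, data_length): two 1 s clips
    output_dict = model(waveform)
    # clipwise_output holds log-probabilities from the log_softmax head
    print(output_dict['clipwise_output'].shape)  # torch.Size([2, 10])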
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Example of fine-tuning a pretrained Cnn14 on a new task.')
    subparsers = parser.add_subparsers(dest='mode')

    # Train
    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--sample_rate', type=int, required=True)
    parser_train.add_argument('--window_size', type=int, required=True)
    parser_train.add_argument('--hop_size', type=int, required=True)
    parser_train.add_argument('--mel_bins', type=int, required=True)
    parser_train.add_argument('--fmin', type=int, required=True)
    parser_train.add_argument('--fmax', type=int, required=True)
    parser_train.add_argument('--model_type', type=str, required=True)
    parser_train.add_argument('--pretrained_checkpoint_path', type=str)
    parser_train.add_argument('--freeze_base', action='store_true', default=False)
    parser_train.add_argument('--cuda', action='store_true', default=False)

    # Parse arguments
    args = parser.parse_args()
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)
    else:
        raise Exception("Incorrect argument: mode must be 'train'.")
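
# Example invocation (a sketch: the script name and checkpoint path are
# placeholders, and the numeric values are common Cnn14/AudioSet settings,
# not requirements):
#
#   python finetune_template.py train \
#       --sample_rate 32000 --window_size 1024 --hop_size 320 --mel_bins 64 \
#       --fmin 50 --fmax 14000 --model_type Transfer_Cnn14 \
#       --pretrained_checkpoint_path Cnn14_mAP=0.431.pth \
#       --freeze_base --cuda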