|
import time |
|
import argparse |
|
import numpy as np |
|
import torch |
|
import tqdm |
|
from torch import optim |
|
from torch.utils.data import DataLoader |
|
|
|
from data_proc.cross_entropy_dataset import FBanksCrossEntropyDataset |
|
from models.cross_entropy_model import FBankCrossEntropyNetV2 |
|
from utils.pt_util import restore_objects, save_model, save_objects, restore_model |
|
from trainer.cross_entropy_train import train, test |
|
|
|
|
|
def main(args):
    """Train the FBank cross-entropy network, checkpointing on best test accuracy.

    Builds shuffled training and ordered evaluation loaders from
    ``args.train_folder`` and ``args.test_folder``, passes a freshly built
    model through ``restore_model`` and reads the saved training history via
    ``restore_objects``, then runs epochs up to ``args.epochs``. After every
    epoch the losses and accuracies are printed and appended to the history
    lists; the model and the history are saved only when the test accuracy
    exceeds the best value seen so far.

    Args:
        args: Namespace providing ``num_layers``, ``train_folder``,
            ``test_folder``, ``epochs``, ``batch_size`` and ``lr``.
    """
    import multiprocessing

    model_path = f"saved_models_cross_entropy/{args.num_layers}/"

    # Probe CUDA once and derive both the device and the loader options from
    # it, so page-locked memory is only requested when a GPU can use it.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print('using device', device)

    num_cpus = multiprocessing.cpu_count()
    print('num cpus:', num_cpus)

    # Worker processes are used on every host; pinning only when CUDA exists.
    kwargs = {'num_workers': num_cpus, 'pin_memory': use_cuda}

    train_dataset = FBanksCrossEntropyDataset(args.train_folder)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    test_dataset = FBanksCrossEntropyDataset(args.test_folder)
    # Evaluation does not need shuffling; a fixed order keeps the reported
    # metrics reproducible from run to run.
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)

    model = FBankCrossEntropyNetV2(num_layers=args.num_layers, reduction='mean').to(device)
    # NOTE(review): presumably loads the latest checkpoint from model_path if
    # one exists -- confirm against utils.pt_util.restore_model.
    model = restore_model(model, model_path)

    # The second argument looks like the fallback used when nothing has been
    # saved yet (epoch 0, accuracy 0, empty histories) -- confirm against
    # utils.pt_util.restore_objects.
    (last_epoch, max_accuracy, train_losses, test_losses,
     train_accuracies, test_accuracies) = restore_objects(model_path, (0, 0, [], [], [], []))

    # A positive best accuracy means a previous run saved progress, so resume
    # from the epoch after it; otherwise start from epoch 0.
    start = last_epoch + 1 if max_accuracy > 0 else 0

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(start, args.epochs):
        # NOTE(review): 500 is presumably a logging interval -- confirm
        # against trainer.cross_entropy_train.train.
        train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch, 500)
        test_loss, test_accuracy = test(model, device, test_loader)
        print('After epoch: {}, train_loss: {}, test loss is: {}, train_accuracy: {}, '
              'test_accuracy: {}'.format(epoch, train_loss, test_loss, train_accuracy, test_accuracy))

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        # Only checkpoint when this epoch beats the best test accuracy so far.
        if test_accuracy > max_accuracy:
            max_accuracy = test_accuracy
            save_model(model, epoch, model_path)
            save_objects(
                (epoch, max_accuracy, train_losses, test_losses, train_accuracies, test_accuracies),
                epoch,
                model_path,
            )
            print('saved epoch: {} as checkpoint'.format(epoch))
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    parser = argparse.ArgumentParser(description='FBank Cross Entropy Training Script')

    # One row per option: (flag, value type, default, help text).
    option_table = (
        ('--num_layers', int, 2, 'Number of layers in the model'),
        ('--train_folder', str, 'fbanks_train', 'Training dataset folder'),
        ('--test_folder', str, 'fbanks_test', 'Testing dataset folder'),
        ('--epochs', int, 20, 'Number of epochs to train'),
        ('--batch_size', int, 64, 'Batch size for training'),
        ('--lr', float, 0.0005, 'Learning rate for the optimizer'),
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    main(parser.parse_args())
|
|