from datetime import datetime
import os

from tqdm import tqdm
import wandb

# torch
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader

# internal
from dataset import VoiceDataset
from cnn import CNNetwork

BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001

TRAIN_FILE = "data/train"
AISF_TRAIN_FILE = "data/aisf/train"
TEST_FILE = "data/test"
SAMPLE_RATE = 48000


def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
    """Train ``model`` for ``epochs`` epochs, optionally evaluating after each one.

    Each epoch's loss/accuracy is the sum returned by ``train_epoch`` /
    ``validate_epoch`` divided by the number of batches (i.e. a mean of
    per-batch values). Metrics are printed and logged to wandb once per epoch.

    Args:
        model: the ``nn.Module`` to train.
        train_dataloader: ``DataLoader`` yielding ``(wav, target)`` batches.
        loss_fn: loss callable taking ``(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device string the batches are moved to (e.g. ``"cuda"``).
        epochs: number of epochs to run.
        test_dataloader: optional ``DataLoader`` evaluated after every epoch;
            skipped when ``None`` or empty.

    Returns:
        ``(training_acc, training_loss, testing_acc, testing_loss)`` — lists
        with one float per epoch. The testing lists stay empty when no test
        loader is evaluated.
    """
    training_acc = []
    training_loss = []
    testing_acc = []
    testing_loss = []

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        # train model
        train_epoch_loss, train_epoch_acc = train_epoch(model, train_dataloader, loss_fn, optimizer, device)

        # training metrics (mean over batches)
        training_loss.append(train_epoch_loss / len(train_dataloader))
        training_acc.append(train_epoch_acc / len(train_dataloader))
        print("Training Loss: {:.2f}, Training Accuracy {}".format(training_loss[-1], training_acc[-1]))

        # Gather every metric for this epoch and log them in a single call, so
        # training and testing values share one wandb step instead of landing
        # on two consecutive steps.
        epoch_metrics = {"training_loss": training_loss[-1], "training_acc": training_acc[-1]}

        if test_dataloader:
            # test model
            test_epoch_loss, test_epoch_acc = validate_epoch(model, test_dataloader, loss_fn, device)

            # testing metrics (mean over batches)
            testing_loss.append(test_epoch_loss / len(test_dataloader))
            testing_acc.append(test_epoch_acc / len(test_dataloader))
            print("Testing Loss: {:.2f}, Testing Accuracy {}".format(testing_loss[-1], testing_acc[-1]))

            # Log the *testing* accuracy here (this previously logged
            # training_acc a second time and never recorded testing_acc).
            epoch_metrics["testing_loss"] = testing_loss[-1]
            epoch_metrics["testing_acc"] = testing_acc[-1]

        wandb.log(epoch_metrics)
        print("-------------------------------------------- \n")

    print("---- Finished Training ----")
    return training_acc, training_loss, testing_acc, testing_loss


def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``train_dataloader``.

    Returns:
        ``(train_loss, train_acc)``: the SUM over batches of the batch loss and
        of the per-batch accuracy. The caller divides by the number of
        batches, so every batch (including a smaller final one) is weighted
        equally.
    """
    train_loss = 0.0
    train_acc = 0.0

    model.train()
    for wav, target in tqdm(train_dataloader, "Training batch..."):
        wav, target = wav.to(device), target.to(device)

        # calculate loss
        output = model(wav)
        loss = loss_fn(output, target)

        # backprop and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # metrics
        train_loss += loss.item()
        prediction = torch.argmax(output, 1)
        train_acc += (prediction == target).sum().item() / len(prediction)

    return train_loss, train_acc


def validate_epoch(model, test_dataloader, loss_fn, device):
    """Evaluate ``model`` on ``test_dataloader`` without tracking gradients.

    Returns:
        ``(test_loss, test_acc)``: the SUM over batches of the batch loss and
        of the per-batch accuracy (the caller averages by batch count).
    """
    test_loss = 0.0
    test_acc = 0.0

    model.eval()
    with torch.no_grad():
        for wav, target in tqdm(test_dataloader, "Testing batch..."):
            wav, target = wav.to(device), target.to(device)

            output = model(wav)
            loss = loss_fn(output, target)

            test_loss += loss.item()
            prediction = torch.argmax(output, 1)
            test_acc += (prediction == target).sum().item() / len(prediction)

    return test_loss, test_acc


def main():
    """Train a ``CNNetwork`` on the AISF training set and save its weights."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device.")

    # instantiating our dataset object and create data loader
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=2048,
        hop_length=512,
        n_mels=128,
    )

    train_dataset = VoiceDataset(AISF_TRAIN_FILE, mel_spectrogram, device, time_limit_in_secs=3)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # construct model
    model = CNNetwork().to(device)
    print(model)
    print(train_dataset.label_mapping)

    # init loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

    wandb.init(project="void-train")

    # train model
    train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)

    # NOTE(review): this attribute is NOT part of model.state_dict(), so the
    # label mapping is not written by the torch.save below. Persist it
    # separately if the loader needs it.
    model.label_mapping = train_dataset.label_mapping

    # save model (create the target directory if it does not exist yet,
    # otherwise torch.save fails after a full training run)
    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"models/aisf/void_{now}.pth"
    os.makedirs(os.path.dirname(model_filename), exist_ok=True)
    torch.save(model.state_dict(), model_filename)
    print(f"Trained void model saved at {model_filename}")

    wandb.finish()


if __name__ == "__main__":
    main()