from datetime import datetime

from tqdm import tqdm
import wandb

# torch
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader

# internal
from dataset import VoiceDataset
from cnn import CNNetwork

BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001

TRAIN_FILE = "data/train"
AISF_TRAIN_FILE = "data/aisf/train"
TEST_FILE = "data/test"
SAMPLE_RATE = 48000
def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
    training_acc = []
    training_loss = []
    testing_acc = []
    testing_loss = []

    for i in range(epochs):
        print(f"Epoch {i + 1}/{epochs}")

        # train model
        train_epoch_loss, train_epoch_acc = train_epoch(model, train_dataloader, loss_fn, optimizer, device)

        # training metrics, averaged over batches
        training_loss.append(train_epoch_loss / len(train_dataloader))
        training_acc.append(train_epoch_acc / len(train_dataloader))
        print(f"Training Loss: {training_loss[i]:.2f}, Training Accuracy: {training_acc[i]:.2f}")
        wandb.log({'training_loss': training_loss[i], 'training_acc': training_acc[i]})

        if test_dataloader:
            # test model
            test_epoch_loss, test_epoch_acc = validate_epoch(model, test_dataloader, loss_fn, device)

            # testing metrics, averaged over batches
            testing_loss.append(test_epoch_loss / len(test_dataloader))
            testing_acc.append(test_epoch_acc / len(test_dataloader))
            print(f"Testing Loss: {testing_loss[i]:.2f}, Testing Accuracy: {testing_acc[i]:.2f}")
            wandb.log({'testing_loss': testing_loss[i], 'testing_acc': testing_acc[i]})

        print("--------------------------------------------\n")

    print("---- Finished Training ----")
    return training_acc, training_loss, testing_acc, testing_loss
def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    train_loss = 0.0
    train_acc = 0.0
    total = 0.0

    model.train()
    for wav, target in tqdm(train_dataloader, "Training batch..."):
        wav, target = wav.to(device), target.to(device)

        # calculate loss
        output = model(wav)
        loss = loss_fn(output, target)

        # backprop and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # metrics: accumulate batch loss and per-batch accuracy
        train_loss += loss.item()
        prediction = torch.argmax(output, 1)
        train_acc += (prediction == target).sum().item() / len(prediction)
        total += 1

    return train_loss, train_acc
def validate_epoch(model, test_dataloader, loss_fn, device):
    test_loss = 0.0
    test_acc = 0.0
    total = 0.0

    model.eval()
    with torch.no_grad():
        for wav, target in tqdm(test_dataloader, "Testing batch..."):
            wav, target = wav.to(device), target.to(device)

            output = model(wav)
            loss = loss_fn(output, target)

            # metrics: accumulate batch loss and per-batch accuracy
            test_loss += loss.item()
            prediction = torch.argmax(output, 1)
            test_acc += (prediction == target).sum().item() / len(prediction)
            total += 1

    return test_loss, test_acc
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using {device} device.")

    # instantiate the dataset object and create the data loader
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=2048,
        hop_length=512,
        n_mels=128
    )

    train_dataset = VoiceDataset(AISF_TRAIN_FILE, mel_spectrogram, device, time_limit_in_secs=3)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # construct model
    model = CNNetwork().to(device)
    print(model)
    print(train_dataset.label_mapping)

    # init loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

    wandb.init(project="void-train")

    # train model
    train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
    model.label_mapping = train_dataset.label_mapping

    # save model
    now = datetime.now()
    now = now.strftime("%Y%m%d_%H%M%S")
    model_filename = f"models/aisf/void_{now}.pth"
    torch.save(model.state_dict(), model_filename)
    print(f"Trained void model saved at {model_filename}")

    wandb.finish()
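
    # Untested usage sketch (not run as part of training): reloading the saved
    # checkpoint for inference. This assumes CNNetwork() takes no constructor
    # arguments, as above. Note that torch.save(model.state_dict(), ...) stores
    # only parameters/buffers, so label_mapping is not part of the .pth file and
    # would have to be restored separately (e.g. from the dataset).
    #
    #   inference_model = CNNetwork()
    #   inference_model.load_state_dict(torch.load(model_filename, map_location="cpu"))
    #   inference_model.eval()
    #   inference_model.label_mapping = train_dataset.label_mapping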