{ "cells": [
{ "cell_type": "markdown", "id": "a7e3c1f0", "metadata": {}, "source": [ "# VoID: CNN training\n", "\n", "Trains `CNNetwork` on mel spectrograms of the augmented AISF voice data, plots the per-epoch accuracy and loss history, and saves the trained weights under `../models/aisf/`.\n", "\n", "Run it top to bottom (**Restart Kernel → Run All**): no cell relies on state created further down or left over from an earlier session." ] },
{ "cell_type": "code", "execution_count": 3, "id": "2d4667c5", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "markdown", "id": "524483a8", "metadata": {}, "source": [ "# Dependencies" ] },
{ "cell_type": "code", "execution_count": 4, "id": "4c51929f", "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "# Make the project-root modules (cnn, dataset, train) importable from notebooks/.\n", "sys.path.append('..')" ] },
{ "cell_type": "code", "execution_count": null, "id": "8f19f1c8", "metadata": {}, "outputs": [], "source": [ "# standard library\n", "import os\n", "from datetime import datetime\n", "\n", "# third-party\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torchaudio\n", "from torch.utils.data import DataLoader\n", "\n", "# project modules (resolved through the sys.path entry added above)\n", "from cnn import CNNetwork\n", "from dataset import VoiceDataset\n", "from train import train" ] },
{ "cell_type": "markdown", "id": "4f51b4f8", "metadata": {}, "source": [ "## Globals" ] },
{ "cell_type": "code", "execution_count": 37, "id": "6f1716bf", "metadata": {}, "outputs": [], "source": [ "DATA_PATH = os.path.join('..', 'data', 'aisf', 'augmented')" ] },
{ "cell_type": "code", "execution_count": 41, "id": "13fa6700", "metadata": {}, "outputs": [], "source": [ "TRAIN_PATH = os.path.join(DATA_PATH, 'train')\n", "TEST_PATH = os.path.join(DATA_PATH, 'test')" ] },
{ "cell_type": "code", "execution_count": null, "id": "152c1fbf", "metadata": {}, "outputs": [], "source": [ "EPOCHS = 25\n", "BATCH_SIZE = 128\n", "LEARNING_RATE = 0.001\n", "SAMPLE_RATE = 48000\n", "CLIP_SECONDS = 3  # clip length handed to VoiceDataset as time_limit_in_secs\n", "SEED = 42\n", "\n", "# Seed torch's global RNG so weight initialisation and DataLoader shuffling repeat across runs.\n", "torch.manual_seed(SEED);" ] },
{ "cell_type": "code", "execution_count": 27, "id": "98cf668e", "metadata": {}, "outputs": [], "source": [ "# 128 mel bands. With hop_length=512, a 3 s clip at 48 kHz (144000 samples) gives\n", "# 1 + 144000 // 512 = 282 frames, matching the (1, 128, 282) inputs checked below.\n", "MEL_SPEC = torchaudio.transforms.MelSpectrogram(\n", "    sample_rate=SAMPLE_RATE,\n", "    n_fft=2048,\n", "    hop_length=512,\n", "    n_mels=128,\n", ")" ] },
{ "cell_type": "markdown", "id": "063c13a4", "metadata": {}, "source": [ "# `train()` setup" ] },
{ "cell_type": "code", "execution_count": 28, "id": "19b58a13", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cpu device.\n" ] } ], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(f\"Using {device} device.\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "848745e7", "metadata": {}, "outputs": [], "source": [ "# Datasets\n", "train_dataset = VoiceDataset(TRAIN_PATH, MEL_SPEC, device, SAMPLE_RATE, time_limit_in_secs=CLIP_SECONDS)\n", "test_dataset = VoiceDataset(TEST_PATH, MEL_SPEC, device, SAMPLE_RATE, time_limit_in_secs=CLIP_SECONDS)\n", "\n", "# Dataloaders: only training benefits from shuffling; a fixed test order keeps the\n", "# evaluation batches identical between epochs and between runs.\n", "train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)" ] },
{ "cell_type": "code", "execution_count": 43, "id": "ec97cc05", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 128, 282])\n" ] }, { "data": { "text/plain": [ "4221" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(train_dataset[0][0].shape)\n", "len(train_dataset)" ] },
{ "cell_type": "code", "execution_count": 44, "id": "6176760d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 128, 282])\n" ] }, { "data": { "text/plain": [ "2121" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The model's Linear layer (in_features=21888) is sized for one spectrogram shape,\n", "# so both splits must produce the same shape.\n", "assert test_dataset[0][0].shape == train_dataset[0][0].shape, \"train/test spectrogram shapes differ\"\n", "print(test_dataset[0][0].shape)\n", "len(test_dataset)" ] },
{ "cell_type": "code", "execution_count": 45, "id": "67a76338", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CNNetwork(\n", "  (conv1): Sequential(\n", "    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))\n", "    (1): ReLU()\n", "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", "  )\n", "  (conv2): Sequential(\n", "    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))\n", "    (1): ReLU()\n", "    (2): MaxPool2d(kernel_size=2, 
stride=2, padding=0, dilation=1, ceil_mode=False)\n", "  )\n", "  (conv3): Sequential(\n", "    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))\n", "    (1): ReLU()\n", "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", "  )\n", "  (conv4): Sequential(\n", "    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))\n", "    (1): ReLU()\n", "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", "  )\n", "  (flatten): Flatten(start_dim=1, end_dim=-1)\n", "  (linear): Linear(in_features=21888, out_features=3, bias=True)\n", "  (softmax): Softmax(dim=1)\n", ")\n" ] } ], "source": [ "model = CNNetwork().to(device)\n", "print(model)" ] },
{ "cell_type": "code", "execution_count": null, "id": "5c5b08c3", "metadata": {}, "outputs": [], "source": [ "# CNNetwork ends in Softmax (see the architecture printed above), so the model returns\n", "# probabilities. torch.nn.CrossEntropyLoss expects raw logits and applies log-softmax itself;\n", "# feeding it probabilities softmaxes twice, which floors the loss at log(1 + 2/e) ≈ 0.55 for\n", "# 3 classes and shrinks the gradients (an earlier run plateaued at a loss of ~0.55 while\n", "# reaching >99% accuracy). Taking the log of the probabilities and applying NLL recovers\n", "# the intended cross-entropy.\n", "# TODO: cleaner fix is to drop the Softmax from CNNetwork.forward and use CrossEntropyLoss.\n", "def loss_fn(probs, target):\n", "    \"\"\"Cross-entropy for a model whose forward pass already applies Softmax.\"\"\"\n", "    return torch.nn.functional.nll_loss(torch.log(probs.clamp_min(1e-12)), target)\n", "\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)" ] },
{ "cell_type": "code", "execution_count": 47, "id": "f7b14a2b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'shafqat': 0, 'aman': 1, 'jake': 2}" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset.label_mapping" ] },
{ "cell_type": "code", "execution_count": null, "id": "87ed6df3", "metadata": {}, "outputs": [], "source": [ "# The mapping above is not alphabetical, so it presumably follows directory-listing order,\n", "# which is not guaranteed to agree between the train and test folders. Fail loudly if the\n", "# splits disagree; otherwise the test metrics would compare mismatched class ids.\n", "assert test_dataset.label_mapping == train_dataset.label_mapping, \"train/test label mappings differ\"\n", "model.labels = train_dataset.label_mapping" ] },
{ "cell_type": "markdown", "id": "e6eb6722", "metadata": {}, "source": [ "# Train/Test\n", "\n", "An earlier CPU run took roughly 1.5 min per epoch (33 training + 17 testing batches), i.e. about 40 min for 25 epochs; the cell is timed with `%%time`." ] },
{ "cell_type": "code", "execution_count": null, "id": "f4b5c789", "metadata": {}, "outputs": [], "source": [ "%%time\n", "history = train(\n", "    model,\n", "    train_dataloader,\n", "    loss_fn,\n", "    optimizer,\n", "    device,\n", "    EPOCHS,\n", "    test_dataloader\n", ")" ] },
{ "cell_type": "markdown", "id": "572ae86d", "metadata": {}, "source": [ "# Visualizations" ] },
{ "cell_type": "code", "execution_count": 50, "id": "0b986ac4", "metadata": {}, "outputs": [], "source": [ "# NOTE(review): assumes train() returns (training_acc, training_loss, testing_acc, testing_loss)\n", "# in that order, one value per epoch — confirm against train.py.\n", "training_acc = history[0]\n", "training_loss = history[1]\n", "\n", "testing_acc = history[2]\n", "testing_loss = history[3]" ] },
{ "cell_type": "code", "execution_count": null, "id": "d2f0a6b1", "metadata": {}, "outputs": [], "source": [ "def plot_history(title, train_values, test_values):\n", "    \"\"\"Plot per-epoch training and testing curves side by side.\n", "\n", "    `title` is used as both the figure title and the y-axis label.\n", "    Returns the matplotlib Figure.\n", "    \"\"\"\n", "    fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n", "    fig.suptitle(title)\n", "    for ax, split, values in zip(axes, (\"Training\", \"Testing\"), (train_values, test_values)):\n", "        ax.plot(range(1, len(values) + 1), values)\n", "        ax.set(title=split, xlabel=\"Epoch\", ylabel=title)\n", "    return fig" ] },
{ "cell_type": "code", "execution_count": null, "id": "5cfbedcc", "metadata": {}, "outputs": [], "source": [ "plot_history(\"Accuracy\", training_acc, testing_acc);" ] },
{ "cell_type": "code", "execution_count": null, "id": "2475aebc", "metadata": {}, "outputs": [], "source": [ "plot_history(\"Loss\", training_loss, testing_loss);" ] },
{ "cell_type": "code", "execution_count": null, "id": "fc0bc3d9", "metadata": {}, "outputs": [], "source": [ "# Only the weights are saved: `model.labels` is a plain dict attribute, not a parameter or\n", "# buffer, so state_dict() does not include it. TODO: persist the label mapping alongside\n", "# the weights if inference needs it.\n", "MODEL_DIR = os.path.join('..', 'models', 'aisf')\n", "os.makedirs(MODEL_DIR, exist_ok=True)\n", "\n", "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", "model_filename = os.path.join(MODEL_DIR, f\"void_{timestamp}.pth\")\n", "torch.save(model.state_dict(), model_filename)\n", "print(f\"Trained void model saved at {model_filename}\")" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }