{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Voice classifier: inference sanity checks\n",
    "\n",
    "Builds the mel-spectrogram `VoiceDataset` from `../data/train`, restores a trained `CNNetwork` checkpoint from `../models`, and compares predicted vs. expected labels (`'aman'`, `'imran'`, `'labib'`) on one training sample and one sample from `../data/test`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9db7bd27",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "72b076a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "391c8ebe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from datetime import datetime\n",
    "\n",
    "import torch\n",
    "import torchaudio\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "0f0b166a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import VoiceDataset\n",
    "from cnn import CNNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "b690f559",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device cpu\n"
     ]
    }
   ],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using device {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-data",
   "metadata": {},
   "source": [
    "## Training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "5b4cac66",
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_RATE = 16000  # Hz; also passed to VoiceDataset (presumably its target rate)\n",
    "\n",
    "mel_spectrogram = torchaudio.transforms.MelSpectrogram(\n",
    "    sample_rate=SAMPLE_RATE,\n",
    "    n_fft=1024,\n",
    "    hop_length=512,\n",
    "    n_mels=64,\n",
    ")\n",
    "dataset = VoiceDataset('../data/train', mel_spectrogram, SAMPLE_RATE, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "55928782",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5717"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "296fc1d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[0.2647, 0.0247, 0.0324, ..., 0.0230, 0.1026, 0.5454],\n",
       " [0.0812, 0.0178, 0.0890, ..., 0.2376, 0.5061, 0.5292],\n",
       " [0.0052, 0.0212, 0.1341, ..., 0.9336, 0.2778, 0.1372],\n",
       " ...,\n",
       " [0.5154, 0.3950, 0.4497, ..., 0.4916, 0.4505, 0.7709],\n",
       " [0.1919, 0.4804, 0.5144, ..., 0.5931, 0.4466, 0.4706],\n",
       " [0.1208, 0.4357, 0.4016, ..., 0.5168, 0.7007, 0.3696]]]),\n",
       " 0)"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "b921ef42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 64, 157])"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "4845de38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'aman': 0, 'imran': 1, 'labib': 2}"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.label_mapping"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "83671781",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the same device VoiceDataset puts its samples on.\n",
    "cnn = CNNetwork().to(device)\n",
    "# summary(cnn, (1, 64, 157), device=device)  # input size = dataset[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "5a12b59f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0)"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "ba6b88ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "now = datetime.now()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "a6046ccf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'20230512_222912'"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "now.strftime(\"%Y%m%d_%H%M%S\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "d7789a04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 145,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).\n",
    "cnn.load_state_dict(torch.load(\"../models/void_20230512_225714.pth\", map_location=device))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-train-pred",
   "metadata": {},
   "source": [
    "## Prediction on a training sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "a6030b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = dataset[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "78352b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = dataset._labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "b8cc2162",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add a batch dimension out-of-place so `x` keeps its shape and the cell is re-runnable.\n",
    "sample_input = x.unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "id": "845ecea4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(model, inputs, target, class_mapping):\n",
    "    \"\"\"Score a batch with `model` and decode the first item into label names.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate; it is switched to eval mode.\n",
    "        inputs: batched input; only the prediction for the first item is decoded\n",
    "            (in this notebook a tensor of shape (1, 1, 64, 157)).\n",
    "        target: class index of the true label (int or 0-d tensor).\n",
    "        class_mapping: indexable mapping from class index to label name.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(predicted, expected)`` of label names.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        predictions = model(inputs)\n",
    "    # .item() turns the 0-d index tensor into a plain int so both list- and\n",
    "    # dict-style mappings can be indexed with it.\n",
    "    predicted_index = predictions[0].argmax(0).item()\n",
    "    predicted = class_mapping[predicted_index]\n",
    "    expected = class_mapping[int(target)]\n",
    "    return predicted, expected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "eb8d1e55",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('aman', 'aman')"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict(cnn, sample_input, y, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "5d58683e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0.0259, 0.1384, 0.0784, ..., 0.0000, 0.0000, 0.0000],\n",
       " [0.0334, 0.1320, 0.0701, ..., 0.0000, 0.0000, 0.0000],\n",
       " [0.0481, 0.0324, 0.0545, ..., 0.0000, 0.0000, 0.0000],\n",
       " ...,\n",
       " [0.2665, 0.3647, 0.3147, ..., 0.0000, 0.0000, 0.0000],\n",
       " [0.2710, 0.3796, 0.2160, ..., 0.0000, 0.0000, 0.0000],\n",
       " [0.1950, 0.2607, 0.1905, ..., 0.0000, 0.0000, 0.0000]]]])"
      ]
     },
     "execution_count": 156,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "id": "b0af5b69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1, 64, 157])"
      ]
     },
     "execution_count": 157,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_input.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28c0768a",
   "metadata": {},
   "outputs": [],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "c5817d01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'data/aman/aman_1'"
      ]
     },
     "execution_count": 159,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.path.join('data', 'aman', 'aman_1')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-test-pred",
   "metadata": {},
   "source": [
    "## Prediction on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "03bb835e",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = VoiceDataset('../data/test', mel_spectrogram, SAMPLE_RATE, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "151f8cb9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('aman_test2.m4a', 'aman'),\n",
       " ('aman_6.m4a', 'aman'),\n",
       " ('aman_4.m4a', 'aman'),\n",
       " ('aman_5.m4a', 'aman'),\n",
       " ('aman_1.m4a', 'aman'),\n",
       " ('aman_2.m4a', 'aman'),\n",
       " ('aman_3.m4a', 'aman'),\n",
       " ('aman_test.m4a', 'aman'),\n",
       " ('Imran 6.m4a', 'imran'),\n",
       " ('Imran 5.m4a', 'imran'),\n",
       " ('Imran 4.m4a', 'imran'),\n",
       " ('Imran 1.m4a', 'imran'),\n",
       " ('Imran 3.m4a', 'imran'),\n",
       " ('Imran 2.m4a', 'imran'),\n",
       " ('labib_6.m4a', 'labib'),\n",
       " ('labib_4.m4a', 'labib'),\n",
       " ('labib_5.m4a', 'labib'),\n",
       " ('labib_1.m4a', 'labib'),\n",
       " ('labib_2.m4a', 'labib'),\n",
       " ('labib_3.m4a', 'labib')]"
      ]
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_dataset.audio_files_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "4c09f2c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 64, 157])"
      ]
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_input, test_target = test_dataset[0]\n",
    "test_input.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "id": "1f4254aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Out-of-place unsqueeze: re-running this cell must not keep adding dimensions.\n",
    "test_batch = test_input.unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "id": "f1559f5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1, 64, 157])"
      ]
     },
     "execution_count": 180,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e683c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_predicted, test_expected = predict(cnn, test_batch, test_target, test_dataset._labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "581911ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_predicted, test_expected"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}