amanmibra committed
Commit a38e25f
1 Parent(s): e5db3e9

Add training

__pycache__/cnn.cpython-39.pyc CHANGED
Binary files a/__pycache__/cnn.cpython-39.pyc and b/__pycache__/cnn.cpython-39.pyc differ
 
__pycache__/dataset.cpython-39.pyc CHANGED
Binary files a/__pycache__/dataset.cpython-39.pyc and b/__pycache__/dataset.cpython-39.pyc differ
 
cnn.py CHANGED
@@ -52,7 +52,7 @@ class CNNetwork(nn.Module):
              nn.MaxPool2d(kernel_size=2)
          )
          self.flatten = nn.Flatten()
-         self.linear = nn.Linear(128 * 5 * 4, 10)
+         self.linear = nn.Linear(128 * 5 * 11, 10)
          self.softmax = nn.Softmax(dim=1)
  
      def forward(self, input_data):
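
Note: the new in_features value (128 * 5 * 11 instead of 128 * 5 * 4) means the conv stack's flattened output changed shape, presumably because the input spectrograms are now wider. A minimal sketch, not part of this commit, of deriving that value with a dummy forward pass instead of hand-computing it; `conv_blocks` and the (1, 64, 44) input shape are placeholders, not the actual CNNetwork configuration:

```python
# Sketch: measure the flattened conv output once instead of hard-coding 128 * 5 * 11.
import torch
from torch import nn

def flattened_size(conv_blocks: nn.Module, input_shape) -> int:
    """Push one dummy sample through the conv layers and return the flattened width."""
    with torch.no_grad():
        dummy = torch.zeros(1, *input_shape)   # (batch, channels, n_mels, n_frames)
        out = conv_blocks(dummy)
    return out.flatten(start_dim=1).shape[1]

# Placeholder conv stack for illustration only.
conv_blocks = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=2), nn.ReLU(), nn.MaxPool2d(2),
)
print(flattened_size(conv_blocks, (1, 64, 44)))  # value to pass to nn.Linear(..., 10)
```

Measuring the size this way keeps the Linear layer in sync if n_mels, the clip length, or the conv stack changes.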
dataset.py CHANGED
@@ -18,6 +18,7 @@ class VoiceDataset(Dataset):
          # file processing
          self._data_path = os.path.join(data_directory)
          self._labels = os.listdir(self._data_path)
+         self.label_mapping = {label: i for i, label in enumerate(self._labels)}
          self.audio_files_labels = self._join_audio_files()
  
          self.device = device
@@ -50,7 +51,9 @@ class VoiceDataset(Dataset):
  
          # apply transformation
          wav = self.transformation(wav)
-         return wav, label
+
+         # return wav and integer representation of the label
+         return wav, self.label_mapping[label]
  
  
      def _join_audio_files(self):
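
Returning label_mapping[label] gives each sample an integer class index, which is the target format nn.CrossEntropyLoss (used by the new train.py) expects. A small illustrative pairing, not code from this commit; the batch and logits below are made up, and the mapping values mirror the dataset.label_mapping output shown in the notebook:

```python
# Sketch: integer class indices are what CrossEntropyLoss consumes as targets.
import torch
from torch import nn

label_mapping = {"aman": 0, "imran": 1, "labib": 2}  # mirrors dataset.label_mapping in the notebook
targets = torch.tensor([label_mapping["imran"], label_mapping["aman"]])  # batch of 2 integer labels
logits = torch.randn(2, 10)                          # model output: (batch, num_classes=10)
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.item())
```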
models/void.pth ADDED
Binary file (673 kB).
 
notebooks/playground.ipynb CHANGED
@@ -3,7 +3,7 @@
  {
  "cell_type": "code",
  "execution_count": 8,
- "id": "46dbbffd",
+ "id": "55895dc1",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -14,7 +14,7 @@
  {
  "cell_type": "code",
  "execution_count": 10,
- "id": "56056453",
+ "id": "d1e95c40",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -25,7 +25,7 @@
  {
  "cell_type": "code",
  "execution_count": 86,
- "id": "b5cbac9d",
+ "id": "0ae6ce32",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -38,7 +38,7 @@
  {
  "cell_type": "code",
  "execution_count": 85,
- "id": "1970ad63",
+ "id": "4200acc4",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -49,7 +49,7 @@
  {
  "cell_type": "code",
  "execution_count": 78,
- "id": "c28b0c5e",
+ "id": "b98d408a",
  "metadata": {},
  "outputs": [
  {
@@ -70,8 +70,8 @@
  },
  {
  "cell_type": "code",
- "execution_count": 80,
- "id": "69025839",
+ "execution_count": 97,
+ "id": "f26723ab",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -81,13 +81,13 @@
  " hop_length=512,\n",
  " n_mels=64\n",
  " )\n",
- "dataset = VoiceDataset('../data', mel_spectrogram, 16000, device)"
+ "dataset = VoiceDataset('../data/train', mel_spectrogram, 16000, device)"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 81,
- "id": "8dfbb1b4",
+ "execution_count": 93,
+ "id": "7664a918",
  "metadata": {},
  "outputs": [
  {
@@ -96,7 +96,7 @@
  "5718"
  ]
  },
- "execution_count": 81,
+ "execution_count": 93,
  "metadata": {},
  "output_type": "execute_result"
  }
@@ -108,7 +108,7 @@
  {
  "cell_type": "code",
  "execution_count": 82,
- "id": "1071e53d",
+ "id": "0adfe082",
  "metadata": {},
  "outputs": [
  {
@@ -136,7 +136,7 @@
  {
  "cell_type": "code",
  "execution_count": 83,
- "id": "7a6d8133",
+ "id": "6f095274",
  "metadata": {},
  "outputs": [
  {
@@ -157,7 +157,7 @@
  {
  "cell_type": "code",
  "execution_count": 87,
- "id": "4b8f75a0",
+ "id": "362d6f74",
  "metadata": {},
  "outputs": [
  {
@@ -200,13 +200,104 @@
  "summary(cnn, (1, 64, 44))"
  ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "id": "d2da6515",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0)"
+ ]
+ },
+ "execution_count": 91,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.tensor(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 95,
+ "id": "8a10cc8c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'aman': 0, 'imran': 1, 'labib': 2}"
+ ]
+ },
+ "execution_count": 95,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset.label_mapping"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "id": "e65a95c3",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "join() argument must be str, bytes, or os.PathLike object, not 'int'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[98], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n",
+ "File \u001b[0;32m~/ml-sandbox/VoID/notebooks/../dataset.py:41\u001b[0m, in \u001b[0;36mVoiceDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, index):\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# get file\u001b[39;00m\n\u001b[1;32m 40\u001b[0m file, label \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maudio_files_labels[index]\n\u001b[0;32m---> 41\u001b[0m filepath \u001b[38;5;241m=\u001b[39m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfile\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;66;03m# load wav\u001b[39;00m\n\u001b[1;32m 44\u001b[0m wav, sr \u001b[38;5;241m=\u001b[39m torchaudio\u001b[38;5;241m.\u001b[39mload(filepath, normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
+ "File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/posixpath.py:90\u001b[0m, in \u001b[0;36mjoin\u001b[0;34m(a, *p)\u001b[0m\n\u001b[1;32m 88\u001b[0m path \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m sep \u001b[38;5;241m+\u001b[39m b\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mTypeError\u001b[39;00m, \u001b[38;5;167;01mAttributeError\u001b[39;00m, \u001b[38;5;167;01mBytesWarning\u001b[39;00m):\n\u001b[0;32m---> 90\u001b[0m \u001b[43mgenericpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_arg_types\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mjoin\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m path\n",
+ "File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/genericpath.py:152\u001b[0m, in \u001b[0;36m_check_arg_types\u001b[0;34m(funcname, *args)\u001b[0m\n\u001b[1;32m 150\u001b[0m hasbytes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 152\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfuncname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m() argument must be str, bytes, or \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mos.PathLike object, not \u001b[39m\u001b[38;5;132;01m{\u001b[39;00ms\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hasstr \u001b[38;5;129;01mand\u001b[39;00m hasbytes:\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt mix strings and bytes in path components\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "\u001b[0;31mTypeError\u001b[0m: join() argument must be str, bytes, or os.PathLike object, not 'int'"
+ ]
+ }
+ ],
+ "source": [
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "id": "a1357e5b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'2023-05-12 22:28:06.556207'"
+ ]
+ },
+ "execution_count": 104,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datetime import datetime\n",
+ "now = datetime.now()"
+ ]
+ },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "888726ed",
+ "id": "190c8d4b",
  "metadata": {},
  "outputs": [],
- "source": []
+ "source": [
+ "now.strftime(\"%Y%m%d-%H%M%S\")"
+ ]
  }
  ],
  "metadata": {
train.py ADDED
@@ -0,0 +1,82 @@
+ from datetime import datetime
+ from tqdm import tqdm
+
+ # torch
+ import torch
+ import torchaudio
+ from torch import nn
+ from torch.utils.data import DataLoader
+
+ # internal
+ from dataset import VoiceDataset
+ from cnn import CNNetwork
+
+ BATCH_SIZE = 128
+ EPOCHS = 100
+ LEARNING_RATE = 0.001
+
+ TRAIN_FILE="data/train"
+ SAMPLE_RATE=16000
+
+ def train(model, dataloader, loss_fn, optimizer, device, epochs):
+     for i in tqdm(range(epochs), "Training model..."):
+         print(f"Epoch {i + 1}")
+
+         train_epoch(model, dataloader, loss_fn, optimizer, device)
+
+         print(f"----------------------------------- \n")
+
+     print("---- Finished Training ----")
+
+
+ def train_epoch(model, dataloader, loss_fn, optimizer, device):
+     for x, y in dataloader:
+         x, y = x.to(device), y.to(device)
+
+         # calculate loss
+         pred = model(x)
+         loss = loss_fn(pred, y)
+
+         # backprop and update weights
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+     print(f"Loss: {loss.item()}")
+
+ if __name__ == "__main__":
+     if torch.cuda.is_available():
+         device = "cuda"
+     else:
+         device = "cpu"
+     print(f"Using {device} device.")
+
+     # instantiating our dataset object and create data loader
+     mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+         sample_rate=SAMPLE_RATE,
+         n_fft=1024,
+         hop_length=512,
+         n_mels=64
+     )
+
+     train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
+     train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+     # construct model
+     model = CNNetwork().to(device)
+     print(model)
+
+     # init loss function and optimizer
+     loss_fn = nn.CrossEntropyLoss()
+     optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+
+     # train model
+     train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
+
+     # save model
+     now = datetime.now()
+     now = now.strftime("%Y%m%d_%H%M%S")
+     model_filename = f"models/void_{now}.pth"
+     torch.save(model.state_dict(), model_filename)
+     print(f"Trained feed forward net saved at {model_filename}")