Spaces:
Runtime error
Runtime error
Add training
Browse files- __pycache__/cnn.cpython-39.pyc +0 -0
- __pycache__/dataset.cpython-39.pyc +0 -0
- cnn.py +1 -1
- dataset.py +4 -1
- models/void.pth +0 -0
- notebooks/playground.ipynb +107 -16
- train.py +82 -0
__pycache__/cnn.cpython-39.pyc
CHANGED
Binary files a/__pycache__/cnn.cpython-39.pyc and b/__pycache__/cnn.cpython-39.pyc differ
|
|
__pycache__/dataset.cpython-39.pyc
CHANGED
Binary files a/__pycache__/dataset.cpython-39.pyc and b/__pycache__/dataset.cpython-39.pyc differ
|
|
cnn.py
CHANGED
@@ -52,7 +52,7 @@ class CNNetwork(nn.Module):
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
-
self.linear = nn.Linear(128 * 5 *
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
+
self.linear = nn.Linear(128 * 5 * 11, 10)
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
dataset.py
CHANGED
@@ -18,6 +18,7 @@ class VoiceDataset(Dataset):
|
|
18 |
# file processing
|
19 |
self._data_path = os.path.join(data_directory)
|
20 |
self._labels = os.listdir(self._data_path)
|
|
|
21 |
self.audio_files_labels = self._join_audio_files()
|
22 |
|
23 |
self.device = device
|
@@ -50,7 +51,9 @@ class VoiceDataset(Dataset):
|
|
50 |
|
51 |
# apply transformation
|
52 |
wav = self.transformation(wav)
|
53 |
-
|
|
|
|
|
54 |
|
55 |
|
56 |
def _join_audio_files(self):
|
|
|
18 |
# file processing
|
19 |
self._data_path = os.path.join(data_directory)
|
20 |
self._labels = os.listdir(self._data_path)
|
21 |
+
self.label_mapping = {label: i for i, label in enumerate(self._labels)}
|
22 |
self.audio_files_labels = self._join_audio_files()
|
23 |
|
24 |
self.device = device
|
|
|
51 |
|
52 |
# apply transformation
|
53 |
wav = self.transformation(wav)
|
54 |
+
|
55 |
+
# return wav and integer representation of the label
|
56 |
+
return wav, self.label_mapping[label]
|
57 |
|
58 |
|
59 |
def _join_audio_files(self):
|
models/void.pth
ADDED
Binary file (673 kB). View file
|
|
notebooks/playground.ipynb
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 8,
|
6 |
-
"id": "
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
@@ -14,7 +14,7 @@
|
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": 10,
|
17 |
-
"id": "
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
@@ -25,7 +25,7 @@
|
|
25 |
{
|
26 |
"cell_type": "code",
|
27 |
"execution_count": 86,
|
28 |
-
"id": "
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
@@ -38,7 +38,7 @@
|
|
38 |
{
|
39 |
"cell_type": "code",
|
40 |
"execution_count": 85,
|
41 |
-
"id": "
|
42 |
"metadata": {},
|
43 |
"outputs": [],
|
44 |
"source": [
|
@@ -49,7 +49,7 @@
|
|
49 |
{
|
50 |
"cell_type": "code",
|
51 |
"execution_count": 78,
|
52 |
-
"id": "
|
53 |
"metadata": {},
|
54 |
"outputs": [
|
55 |
{
|
@@ -70,8 +70,8 @@
|
|
70 |
},
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
-
"execution_count":
|
74 |
-
"id": "
|
75 |
"metadata": {},
|
76 |
"outputs": [],
|
77 |
"source": [
|
@@ -81,13 +81,13 @@
|
|
81 |
" hop_length=512,\n",
|
82 |
" n_mels=64\n",
|
83 |
" )\n",
|
84 |
-
"dataset = VoiceDataset('../data', mel_spectrogram, 16000, device)"
|
85 |
]
|
86 |
},
|
87 |
{
|
88 |
"cell_type": "code",
|
89 |
-
"execution_count":
|
90 |
-
"id": "
|
91 |
"metadata": {},
|
92 |
"outputs": [
|
93 |
{
|
@@ -96,7 +96,7 @@
|
|
96 |
"5718"
|
97 |
]
|
98 |
},
|
99 |
-
"execution_count":
|
100 |
"metadata": {},
|
101 |
"output_type": "execute_result"
|
102 |
}
|
@@ -108,7 +108,7 @@
|
|
108 |
{
|
109 |
"cell_type": "code",
|
110 |
"execution_count": 82,
|
111 |
-
"id": "
|
112 |
"metadata": {},
|
113 |
"outputs": [
|
114 |
{
|
@@ -136,7 +136,7 @@
|
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
"execution_count": 83,
|
139 |
-
"id": "
|
140 |
"metadata": {},
|
141 |
"outputs": [
|
142 |
{
|
@@ -157,7 +157,7 @@
|
|
157 |
{
|
158 |
"cell_type": "code",
|
159 |
"execution_count": 87,
|
160 |
-
"id": "
|
161 |
"metadata": {},
|
162 |
"outputs": [
|
163 |
{
|
@@ -200,13 +200,104 @@
|
|
200 |
"summary(cnn, (1, 64, 44))"
|
201 |
]
|
202 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
{
|
204 |
"cell_type": "code",
|
205 |
"execution_count": null,
|
206 |
-
"id": "
|
207 |
"metadata": {},
|
208 |
"outputs": [],
|
209 |
-
"source": [
|
|
|
|
|
210 |
}
|
211 |
],
|
212 |
"metadata": {
|
|
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 8,
|
6 |
+
"id": "55895dc1",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
|
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": 10,
|
17 |
+
"id": "d1e95c40",
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
|
|
25 |
{
|
26 |
"cell_type": "code",
|
27 |
"execution_count": 86,
|
28 |
+
"id": "0ae6ce32",
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
|
|
38 |
{
|
39 |
"cell_type": "code",
|
40 |
"execution_count": 85,
|
41 |
+
"id": "4200acc4",
|
42 |
"metadata": {},
|
43 |
"outputs": [],
|
44 |
"source": [
|
|
|
49 |
{
|
50 |
"cell_type": "code",
|
51 |
"execution_count": 78,
|
52 |
+
"id": "b98d408a",
|
53 |
"metadata": {},
|
54 |
"outputs": [
|
55 |
{
|
|
|
70 |
},
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
+
"execution_count": 97,
|
74 |
+
"id": "f26723ab",
|
75 |
"metadata": {},
|
76 |
"outputs": [],
|
77 |
"source": [
|
|
|
81 |
" hop_length=512,\n",
|
82 |
" n_mels=64\n",
|
83 |
" )\n",
|
84 |
+
"dataset = VoiceDataset('../data/train', mel_spectrogram, 16000, device)"
|
85 |
]
|
86 |
},
|
87 |
{
|
88 |
"cell_type": "code",
|
89 |
+
"execution_count": 93,
|
90 |
+
"id": "7664a918",
|
91 |
"metadata": {},
|
92 |
"outputs": [
|
93 |
{
|
|
|
96 |
"5718"
|
97 |
]
|
98 |
},
|
99 |
+
"execution_count": 93,
|
100 |
"metadata": {},
|
101 |
"output_type": "execute_result"
|
102 |
}
|
|
|
108 |
{
|
109 |
"cell_type": "code",
|
110 |
"execution_count": 82,
|
111 |
+
"id": "0adfe082",
|
112 |
"metadata": {},
|
113 |
"outputs": [
|
114 |
{
|
|
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
"execution_count": 83,
|
139 |
+
"id": "6f095274",
|
140 |
"metadata": {},
|
141 |
"outputs": [
|
142 |
{
|
|
|
157 |
{
|
158 |
"cell_type": "code",
|
159 |
"execution_count": 87,
|
160 |
+
"id": "362d6f74",
|
161 |
"metadata": {},
|
162 |
"outputs": [
|
163 |
{
|
|
|
200 |
"summary(cnn, (1, 64, 44))"
|
201 |
]
|
202 |
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": 91,
|
206 |
+
"id": "d2da6515",
|
207 |
+
"metadata": {},
|
208 |
+
"outputs": [
|
209 |
+
{
|
210 |
+
"data": {
|
211 |
+
"text/plain": [
|
212 |
+
"tensor(0)"
|
213 |
+
]
|
214 |
+
},
|
215 |
+
"execution_count": 91,
|
216 |
+
"metadata": {},
|
217 |
+
"output_type": "execute_result"
|
218 |
+
}
|
219 |
+
],
|
220 |
+
"source": [
|
221 |
+
"torch.tensor(0)"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"cell_type": "code",
|
226 |
+
"execution_count": 95,
|
227 |
+
"id": "8a10cc8c",
|
228 |
+
"metadata": {},
|
229 |
+
"outputs": [
|
230 |
+
{
|
231 |
+
"data": {
|
232 |
+
"text/plain": [
|
233 |
+
"{'aman': 0, 'imran': 1, 'labib': 2}"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
"execution_count": 95,
|
237 |
+
"metadata": {},
|
238 |
+
"output_type": "execute_result"
|
239 |
+
}
|
240 |
+
],
|
241 |
+
"source": [
|
242 |
+
"dataset.label_mapping"
|
243 |
+
]
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"cell_type": "code",
|
247 |
+
"execution_count": 98,
|
248 |
+
"id": "e65a95c3",
|
249 |
+
"metadata": {},
|
250 |
+
"outputs": [
|
251 |
+
{
|
252 |
+
"ename": "TypeError",
|
253 |
+
"evalue": "join() argument must be str, bytes, or os.PathLike object, not 'int'",
|
254 |
+
"output_type": "error",
|
255 |
+
"traceback": [
|
256 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
257 |
+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
258 |
+
"Cell \u001b[0;32mIn[98], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n",
|
259 |
+
"File \u001b[0;32m~/ml-sandbox/VoID/notebooks/../dataset.py:41\u001b[0m, in \u001b[0;36mVoiceDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, index):\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# get file\u001b[39;00m\n\u001b[1;32m 40\u001b[0m file, label \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maudio_files_labels[index]\n\u001b[0;32m---> 41\u001b[0m filepath \u001b[38;5;241m=\u001b[39m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfile\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;66;03m# load wav\u001b[39;00m\n\u001b[1;32m 44\u001b[0m wav, sr \u001b[38;5;241m=\u001b[39m torchaudio\u001b[38;5;241m.\u001b[39mload(filepath, normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
260 |
+
"File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/posixpath.py:90\u001b[0m, in \u001b[0;36mjoin\u001b[0;34m(a, *p)\u001b[0m\n\u001b[1;32m 88\u001b[0m path \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m sep \u001b[38;5;241m+\u001b[39m b\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mTypeError\u001b[39;00m, \u001b[38;5;167;01mAttributeError\u001b[39;00m, \u001b[38;5;167;01mBytesWarning\u001b[39;00m):\n\u001b[0;32m---> 90\u001b[0m \u001b[43mgenericpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_arg_types\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mjoin\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m path\n",
|
261 |
+
"File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/genericpath.py:152\u001b[0m, in \u001b[0;36m_check_arg_types\u001b[0;34m(funcname, *args)\u001b[0m\n\u001b[1;32m 150\u001b[0m hasbytes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 152\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfuncname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m() argument must be str, bytes, or \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mos.PathLike object, not \u001b[39m\u001b[38;5;132;01m{\u001b[39;00ms\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hasstr \u001b[38;5;129;01mand\u001b[39;00m hasbytes:\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt mix strings and bytes in path components\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
262 |
+
"\u001b[0;31mTypeError\u001b[0m: join() argument must be str, bytes, or os.PathLike object, not 'int'"
|
263 |
+
]
|
264 |
+
}
|
265 |
+
],
|
266 |
+
"source": [
|
267 |
+
"dataset[0]"
|
268 |
+
]
|
269 |
+
},
|
270 |
+
{
|
271 |
+
"cell_type": "code",
|
272 |
+
"execution_count": 104,
|
273 |
+
"id": "a1357e5b",
|
274 |
+
"metadata": {},
|
275 |
+
"outputs": [
|
276 |
+
{
|
277 |
+
"data": {
|
278 |
+
"text/plain": [
|
279 |
+
"'2023-05-12 22:28:06.556207'"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
"execution_count": 104,
|
283 |
+
"metadata": {},
|
284 |
+
"output_type": "execute_result"
|
285 |
+
}
|
286 |
+
],
|
287 |
+
"source": [
|
288 |
+
"from datetime import datetime\n",
|
289 |
+
"now = datetime.now()"
|
290 |
+
]
|
291 |
+
},
|
292 |
{
|
293 |
"cell_type": "code",
|
294 |
"execution_count": null,
|
295 |
+
"id": "190c8d4b",
|
296 |
"metadata": {},
|
297 |
"outputs": [],
|
298 |
+
"source": [
|
299 |
+
"now.strftime(\"%Y%m%d-%H%M%S\")"
|
300 |
+
]
|
301 |
}
|
302 |
],
|
303 |
"metadata": {
|
train.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
# torch
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
from torch import nn
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
# internal
|
11 |
+
from dataset import VoiceDataset
|
12 |
+
from cnn import CNNetwork
|
13 |
+
|
14 |
+
BATCH_SIZE = 128
|
15 |
+
EPOCHS = 100
|
16 |
+
LEARNING_RATE = 0.001
|
17 |
+
|
18 |
+
TRAIN_FILE="data/train"
|
19 |
+
SAMPLE_RATE=16000
|
20 |
+
|
21 |
+
def train(model, dataloader, loss_fn, optimizer, device, epochs):
|
22 |
+
for i in tqdm(range(epochs), "Training model..."):
|
23 |
+
print(f"Epoch {i + 1}")
|
24 |
+
|
25 |
+
train_epoch(model, dataloader, loss_fn, optimizer, device)
|
26 |
+
|
27 |
+
print (f"----------------------------------- \n")
|
28 |
+
|
29 |
+
print("---- Finished Training ----")
|
30 |
+
|
31 |
+
|
32 |
+
def train_epoch(model, dataloader, loss_fn, optimizer, device):
|
33 |
+
for x, y in dataloader:
|
34 |
+
x, y = x.to(device), y.to(device)
|
35 |
+
|
36 |
+
# calculate loss
|
37 |
+
pred = model(x)
|
38 |
+
loss = loss_fn(pred, y)
|
39 |
+
|
40 |
+
# backprop and update weights
|
41 |
+
optimizer.zero_grad()
|
42 |
+
loss.backward()
|
43 |
+
optimizer.step()
|
44 |
+
|
45 |
+
print(f"Loss: {loss.item()}")
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
if torch.cuda.is_available():
|
49 |
+
device = "cuda"
|
50 |
+
else:
|
51 |
+
device = "cpu"
|
52 |
+
print(f"Using {device} device.")
|
53 |
+
|
54 |
+
# instantiating our dataset object and create data loader
|
55 |
+
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
56 |
+
sample_rate=SAMPLE_RATE,
|
57 |
+
n_fft=1024,
|
58 |
+
hop_length=512,
|
59 |
+
n_mels=64
|
60 |
+
)
|
61 |
+
|
62 |
+
train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
|
63 |
+
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
64 |
+
|
65 |
+
# construct model
|
66 |
+
model = CNNetwork().to(device)
|
67 |
+
print(model)
|
68 |
+
|
69 |
+
# init loss function and optimizer
|
70 |
+
loss_fn = nn.CrossEntropyLoss()
|
71 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
72 |
+
|
73 |
+
|
74 |
+
# train model
|
75 |
+
train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
|
76 |
+
|
77 |
+
# save model
|
78 |
+
now = datetime.now()
|
79 |
+
now = now.strftime("%Y%m%d_%H%M%S")
|
80 |
+
model_filename = f"models/void_{now}.pth"
|
81 |
+
torch.save(model.state_dict(), model_filename)
|
82 |
+
print(f"Trained feed forward net saved at {model_filename}")
|