Spaces:
Runtime error
Runtime error
carlfeynman
commited on
Commit
•
c32023c
1
Parent(s):
fccfd79
full model inference/ui added
Browse files- __pycache__/mnist_classifier.cpython-39.pyc +0 -0
- __pycache__/server.cpython-39.pyc +0 -0
- classifier.pkl +0 -0
- classifier.pth +0 -0
- mnist_classifier.ipynb +142 -38
- mnist_classifier.py +16 -38
- requirements.txt +6 -1
- server.py +21 -3
- static/index.html +105 -10
__pycache__/mnist_classifier.cpython-39.pyc
ADDED
Binary file (4.5 kB). View file
|
|
__pycache__/server.cpython-39.pyc
ADDED
Binary file (1.32 kB). View file
|
|
classifier.pkl
DELETED
Binary file (654 kB)
|
|
classifier.pth
ADDED
Binary file (647 kB). View file
|
|
mnist_classifier.ipynb
CHANGED
@@ -16,7 +16,19 @@
|
|
16 |
"import torchvision.transforms.functional as TF\n",
|
17 |
"from torch.utils.data import default_collate, DataLoader\n",
|
18 |
"import torch.optim as optim\n",
|
19 |
-
"import pickle\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
"%matplotlib inline\n",
|
21 |
"plt.rcParams['figure.figsize'] = [2, 2]"
|
22 |
]
|
@@ -24,7 +36,11 @@
|
|
24 |
{
|
25 |
"cell_type": "code",
|
26 |
"execution_count": 101,
|
27 |
-
"metadata": {
|
|
|
|
|
|
|
|
|
28 |
"outputs": [
|
29 |
{
|
30 |
"name": "stderr",
|
@@ -43,27 +59,25 @@
|
|
43 |
},
|
44 |
{
|
45 |
"cell_type": "code",
|
46 |
-
"execution_count":
|
47 |
"metadata": {},
|
48 |
-
"outputs": [
|
49 |
-
{
|
50 |
-
"data": {
|
51 |
-
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3+kraZ+9VyFfUK7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=",
|
52 |
-
"text/plain": [
|
53 |
-
"<Figure size 144x144 with 1 Axes>"
|
54 |
-
]
|
55 |
-
},
|
56 |
-
"metadata": {
|
57 |
-
"needs_background": "light"
|
58 |
-
},
|
59 |
-
"output_type": "display_data"
|
60 |
-
}
|
61 |
-
],
|
62 |
"source": [
|
63 |
"def transform_ds(b):\n",
|
64 |
" b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
|
65 |
-
" return b
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
"dst = ds.with_transform(transform_ds)\n",
|
68 |
"plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
|
69 |
]
|
@@ -93,8 +107,19 @@
|
|
93 |
"\n",
|
94 |
"def collate_fn(b):\n",
|
95 |
" collate = default_collate(b)\n",
|
96 |
-
" return (collate[x], collate[y])
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
"dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
|
99 |
"xb,yb = next(iter(dls.train))\n",
|
100 |
"xb.shape, yb.shape"
|
@@ -178,23 +203,27 @@
|
|
178 |
},
|
179 |
{
|
180 |
"cell_type": "code",
|
181 |
-
"execution_count":
|
182 |
-
"metadata": {
|
|
|
|
|
|
|
|
|
183 |
"outputs": [
|
184 |
{
|
185 |
"name": "stdout",
|
186 |
"output_type": "stream",
|
187 |
"text": [
|
188 |
-
"train, epoch:1, loss: 0.
|
189 |
-
"eval, epoch:1, loss: 0.
|
190 |
-
"train, epoch:2, loss: 0.
|
191 |
-
"eval, epoch:2, loss: 0.
|
192 |
-
"train, epoch:3, loss: 0.
|
193 |
-
"eval, epoch:3, loss: 0.
|
194 |
-
"train, epoch:4, loss: 0.
|
195 |
-
"eval, epoch:4, loss: 0.
|
196 |
-
"train, epoch:5, loss: 0.
|
197 |
-
"eval, epoch:5, loss: 0.
|
198 |
]
|
199 |
}
|
200 |
],
|
@@ -227,7 +256,7 @@
|
|
227 |
},
|
228 |
{
|
229 |
"cell_type": "code",
|
230 |
-
"execution_count":
|
231 |
"metadata": {
|
232 |
"tags": [
|
233 |
"exclude"
|
@@ -235,8 +264,76 @@
|
|
235 |
},
|
236 |
"outputs": [],
|
237 |
"source": [
|
238 |
-
"
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
]
|
241 |
},
|
242 |
{
|
@@ -252,7 +349,7 @@
|
|
252 |
},
|
253 |
{
|
254 |
"cell_type": "code",
|
255 |
-
"execution_count":
|
256 |
"metadata": {
|
257 |
"tags": [
|
258 |
"exclude"
|
@@ -264,13 +361,20 @@
|
|
264 |
"output_type": "stream",
|
265 |
"text": [
|
266 |
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
|
267 |
-
"[NbConvertApp] Writing
|
268 |
]
|
269 |
}
|
270 |
],
|
271 |
"source": [
|
272 |
"!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
|
273 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
}
|
275 |
],
|
276 |
"metadata": {
|
|
|
16 |
"import torchvision.transforms.functional as TF\n",
|
17 |
"from torch.utils.data import default_collate, DataLoader\n",
|
18 |
"import torch.optim as optim\n",
|
19 |
+
"import pickle\n"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": null,
|
25 |
+
"metadata": {
|
26 |
+
"tags": [
|
27 |
+
"exclude"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
"%matplotlib inline\n",
|
33 |
"plt.rcParams['figure.figsize'] = [2, 2]"
|
34 |
]
|
|
|
36 |
{
|
37 |
"cell_type": "code",
|
38 |
"execution_count": 101,
|
39 |
+
"metadata": {
|
40 |
+
"tags": [
|
41 |
+
"exclude"
|
42 |
+
]
|
43 |
+
},
|
44 |
"outputs": [
|
45 |
{
|
46 |
"name": "stderr",
|
|
|
59 |
},
|
60 |
{
|
61 |
"cell_type": "code",
|
62 |
+
"execution_count": 112,
|
63 |
"metadata": {},
|
64 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
"source": [
|
66 |
"def transform_ds(b):\n",
|
67 |
" b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
|
68 |
+
" return b"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": null,
|
74 |
+
"metadata": {
|
75 |
+
"tags": [
|
76 |
+
"exclude"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
"outputs": [],
|
80 |
+
"source": [
|
81 |
"dst = ds.with_transform(transform_ds)\n",
|
82 |
"plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
|
83 |
]
|
|
|
107 |
"\n",
|
108 |
"def collate_fn(b):\n",
|
109 |
" collate = default_collate(b)\n",
|
110 |
+
" return (collate[x], collate[y])"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "code",
|
115 |
+
"execution_count": null,
|
116 |
+
"metadata": {
|
117 |
+
"tags": [
|
118 |
+
"exclude"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
"outputs": [],
|
122 |
+
"source": [
|
123 |
"dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
|
124 |
"xb,yb = next(iter(dls.train))\n",
|
125 |
"xb.shape, yb.shape"
|
|
|
203 |
},
|
204 |
{
|
205 |
"cell_type": "code",
|
206 |
+
"execution_count": 195,
|
207 |
+
"metadata": {
|
208 |
+
"tags": [
|
209 |
+
"exclude"
|
210 |
+
]
|
211 |
+
},
|
212 |
"outputs": [
|
213 |
{
|
214 |
"name": "stdout",
|
215 |
"output_type": "stream",
|
216 |
"text": [
|
217 |
+
"train, epoch:1, loss: 0.0776, accuracy: 0.9172\n",
|
218 |
+
"eval, epoch:1, loss: 0.0372, accuracy: 0.9818\n",
|
219 |
+
"train, epoch:2, loss: 0.0571, accuracy: 0.9828\n",
|
220 |
+
"eval, epoch:2, loss: 0.0287, accuracy: 0.9863\n",
|
221 |
+
"train, epoch:3, loss: 0.0425, accuracy: 0.9847\n",
|
222 |
+
"eval, epoch:3, loss: 0.0256, accuracy: 0.9865\n",
|
223 |
+
"train, epoch:4, loss: 0.0271, accuracy: 0.9868\n",
|
224 |
+
"eval, epoch:4, loss: 0.0378, accuracy: 0.9826\n",
|
225 |
+
"train, epoch:5, loss: 0.0395, accuracy: 0.9844\n",
|
226 |
+
"eval, epoch:5, loss: 0.0307, accuracy: 0.9873\n"
|
227 |
]
|
228 |
}
|
229 |
],
|
|
|
256 |
},
|
257 |
{
|
258 |
"cell_type": "code",
|
259 |
+
"execution_count": 196,
|
260 |
"metadata": {
|
261 |
"tags": [
|
262 |
"exclude"
|
|
|
264 |
},
|
265 |
"outputs": [],
|
266 |
"source": [
|
267 |
+
"torch.save(model.state_dict(), 'classifier.pth')"
|
268 |
+
]
|
269 |
+
},
|
270 |
+
{
|
271 |
+
"cell_type": "code",
|
272 |
+
"execution_count": 197,
|
273 |
+
"metadata": {},
|
274 |
+
"outputs": [],
|
275 |
+
"source": [
|
276 |
+
"loaded_model = cnn_classifier()\n",
|
277 |
+
"loaded_model.load_state_dict(torch.load('classifier.pth'))\n",
|
278 |
+
"loaded_model.eval();"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 206,
|
284 |
+
"metadata": {},
|
285 |
+
"outputs": [],
|
286 |
+
"source": [
|
287 |
+
"def predict(img):\n",
|
288 |
+
" with torch.no_grad():\n",
|
289 |
+
" img = img[None,]\n",
|
290 |
+
" pred = loaded_model(img)[0]\n",
|
291 |
+
" pred_probs = F.softmax(pred, dim=0)\n",
|
292 |
+
" pred = [{\"digit\": i, \"prob\": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]\n",
|
293 |
+
" pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)\n",
|
294 |
+
" return pred"
|
295 |
+
]
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"cell_type": "code",
|
299 |
+
"execution_count": 204,
|
300 |
+
"metadata": {
|
301 |
+
"tags": [
|
302 |
+
"exclude"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
"outputs": [
|
306 |
+
{
|
307 |
+
"name": "stdout",
|
308 |
+
"output_type": "stream",
|
309 |
+
"text": [
|
310 |
+
"tensor(5)\n"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"data": {
|
315 |
+
"text/plain": [
|
316 |
+
"[{'digit': 0, 'prob': '21.42%', 'logits': tensor(0.0559)},\n",
|
317 |
+
" {'digit': 8, 'prob': '19.44%', 'logits': tensor(-0.0408)},\n",
|
318 |
+
" {'digit': 4, 'prob': '18.08%', 'logits': tensor(-0.1135)},\n",
|
319 |
+
" {'digit': 9, 'prob': '16.41%', 'logits': tensor(-0.2104)},\n",
|
320 |
+
" {'digit': 6, 'prob': '12.23%', 'logits': tensor(-0.5049)},\n",
|
321 |
+
" {'digit': 1, 'prob': '6.87%', 'logits': tensor(-1.0806)},\n",
|
322 |
+
" {'digit': 7, 'prob': '2.33%', 'logits': tensor(-2.1633)},\n",
|
323 |
+
" {'digit': 5, 'prob': '1.19%', 'logits': tensor(-2.8386)},\n",
|
324 |
+
" {'digit': 2, 'prob': '1.06%', 'logits': tensor(-2.9527)},\n",
|
325 |
+
" {'digit': 3, 'prob': '0.97%', 'logits': tensor(-3.0359)}]"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
"execution_count": 204,
|
329 |
+
"metadata": {},
|
330 |
+
"output_type": "execute_result"
|
331 |
+
}
|
332 |
+
],
|
333 |
+
"source": [
|
334 |
+
"img = xb[0].reshape(1, 28, 28)\n",
|
335 |
+
"print(yb[0])\n",
|
336 |
+
"predict(img)"
|
337 |
]
|
338 |
},
|
339 |
{
|
|
|
349 |
},
|
350 |
{
|
351 |
"cell_type": "code",
|
352 |
+
"execution_count": 205,
|
353 |
"metadata": {
|
354 |
"tags": [
|
355 |
"exclude"
|
|
|
361 |
"output_type": "stream",
|
362 |
"text": [
|
363 |
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
|
364 |
+
"[NbConvertApp] Writing 2904 bytes to mnist_classifier.py\n"
|
365 |
]
|
366 |
}
|
367 |
],
|
368 |
"source": [
|
369 |
"!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
|
370 |
]
|
371 |
+
},
|
372 |
+
{
|
373 |
+
"cell_type": "code",
|
374 |
+
"execution_count": null,
|
375 |
+
"metadata": {},
|
376 |
+
"outputs": [],
|
377 |
+
"source": []
|
378 |
}
|
379 |
],
|
380 |
"metadata": {
|
mnist_classifier.py
CHANGED
@@ -12,22 +12,12 @@ import torchvision.transforms.functional as TF
|
|
12 |
from torch.utils.data import default_collate, DataLoader
|
13 |
import torch.optim as optim
|
14 |
import pickle
|
15 |
-
get_ipython().run_line_magic('matplotlib', 'inline')
|
16 |
-
plt.rcParams['figure.figsize'] = [2, 2]
|
17 |
-
|
18 |
-
|
19 |
-
dataset_nm = 'mnist'
|
20 |
-
x,y = 'image', 'label'
|
21 |
-
ds = load_dataset(dataset_nm)
|
22 |
|
23 |
|
24 |
def transform_ds(b):
|
25 |
b[x] = [TF.to_tensor(ele) for ele in b[x]]
|
26 |
return b
|
27 |
|
28 |
-
dst = ds.with_transform(transform_ds)
|
29 |
-
plt.imshow(dst['train'][0]['image'].permute(1,2,0));
|
30 |
-
|
31 |
|
32 |
bs = 1024
|
33 |
class DataLoaders:
|
@@ -39,10 +29,6 @@ def collate_fn(b):
|
|
39 |
collate = default_collate(b)
|
40 |
return (collate[x], collate[y])
|
41 |
|
42 |
-
dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)
|
43 |
-
xb,yb = next(iter(dls.train))
|
44 |
-
xb.shape, yb.shape
|
45 |
-
|
46 |
|
47 |
class Reshape(nn.Module):
|
48 |
def __init__(self, dim):
|
@@ -96,28 +82,20 @@ def kaiming_init(m):
|
|
96 |
nn.init.kaiming_normal_(m.weight)
|
97 |
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
opt.step()
|
116 |
-
opt.zero_grad()
|
117 |
-
with torch.no_grad():
|
118 |
-
accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
|
119 |
-
if train:
|
120 |
-
sched.step()
|
121 |
-
accuracy /= len(dl)
|
122 |
-
print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
|
123 |
|
|
|
12 |
from torch.utils.data import default_collate, DataLoader
|
13 |
import torch.optim as optim
|
14 |
import pickle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
def transform_ds(b):
|
18 |
b[x] = [TF.to_tensor(ele) for ele in b[x]]
|
19 |
return b
|
20 |
|
|
|
|
|
|
|
21 |
|
22 |
bs = 1024
|
23 |
class DataLoaders:
|
|
|
29 |
collate = default_collate(b)
|
30 |
return (collate[x], collate[y])
|
31 |
|
|
|
|
|
|
|
|
|
32 |
|
33 |
class Reshape(nn.Module):
|
34 |
def __init__(self, dim):
|
|
|
82 |
nn.init.kaiming_normal_(m.weight)
|
83 |
|
84 |
|
85 |
+
loaded_model = cnn_classifier()
|
86 |
+
loaded_model.load_state_dict(torch.load('classifier.pth'))
|
87 |
+
loaded_model.eval();
|
88 |
+
|
89 |
+
|
90 |
+
def predict(img):
|
91 |
+
with torch.no_grad():
|
92 |
+
img = img[None,]
|
93 |
+
pred = loaded_model(img)[0]
|
94 |
+
pred_probs = F.softmax(pred, dim=0)
|
95 |
+
pred = [{"digit": i, "prob": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]
|
96 |
+
pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)
|
97 |
+
return pred
|
98 |
+
|
99 |
+
|
100 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
requirements.txt
CHANGED
@@ -1,3 +1,8 @@
|
|
1 |
fastapi==0.68.1
|
2 |
uvicorn==0.15.0
|
3 |
-
aiofiles
|
|
|
|
|
|
|
|
|
|
|
|
1 |
fastapi==0.68.1
|
2 |
uvicorn==0.15.0
|
3 |
+
aiofiles
|
4 |
+
torch
|
5 |
+
fastcore
|
6 |
+
torchvision
|
7 |
+
datasets
|
8 |
+
|
server.py
CHANGED
@@ -1,7 +1,11 @@
|
|
1 |
-
from fastapi import FastAPI
|
|
|
|
|
2 |
from fastapi.responses import FileResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
4 |
from pathlib import Path
|
|
|
|
|
5 |
|
6 |
app = FastAPI()
|
7 |
|
@@ -10,6 +14,20 @@ app.mount("/static", StaticFiles(directory=Path("static")), name="static")
|
|
10 |
async def root():
|
11 |
return FileResponse("static/index.html")
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
@app.post("/predict")
|
14 |
-
async def predict():
|
15 |
-
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile
|
2 |
+
import io
|
3 |
+
from PIL import Image
|
4 |
from fastapi.responses import FileResponse
|
5 |
from fastapi.staticfiles import StaticFiles
|
6 |
from pathlib import Path
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
import mnist_classifier
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
|
|
14 |
async def root():
|
15 |
return FileResponse("static/index.html")
|
16 |
|
17 |
+
def process_image(file: UploadFile):
|
18 |
+
image_bytes = file.file.read()
|
19 |
+
pil_image = Image.open(io.BytesIO(image_bytes))
|
20 |
+
transform = transforms.Compose([
|
21 |
+
transforms.Resize((28, 28)),
|
22 |
+
transforms.Grayscale(num_output_channels=1),
|
23 |
+
transforms.ToTensor(),
|
24 |
+
])
|
25 |
+
tensor_image = transform(pil_image)
|
26 |
+
return tensor_image
|
27 |
+
|
28 |
@app.post("/predict")
|
29 |
+
async def predict(image: UploadFile):
|
30 |
+
tensor_image = process_image(image)
|
31 |
+
prediction = mnist_classifier.predict(tensor_image)
|
32 |
+
return {"prediction": prediction}
|
33 |
+
|
static/index.html
CHANGED
@@ -4,14 +4,46 @@
|
|
4 |
<head>
|
5 |
<meta charset="UTF-8">
|
6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
7 |
-
<link rel="icon" type="image/png" href="favicon.png">
|
8 |
-
<link rel="icon" type="image/x-icon" href="favicon.ico">
|
9 |
<title>Draw and Predict Handwritten Digits</title>
|
10 |
<style>
|
11 |
body {
|
12 |
font-family: 'Montserrat', sans-serif;
|
13 |
}
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
#whiteboard {
|
16 |
border: 3px solid #088395;
|
17 |
/* Simple black border */
|
@@ -52,6 +84,7 @@
|
|
52 |
flex-direction: column;
|
53 |
align-items: center;
|
54 |
justify-content: center;
|
|
|
55 |
}
|
56 |
|
57 |
#btn-container {
|
@@ -59,19 +92,28 @@
|
|
59 |
flex-direction: row;
|
60 |
align-items: center;
|
61 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
</style>
|
63 |
</head>
|
64 |
|
65 |
<body>
|
66 |
<h3 style="text-align:center;">Draw and Predict Handwritten Digits</h3>
|
|
|
67 |
<div id='container'>
|
68 |
<canvas id="whiteboard" width="400" height="200"></canvas>
|
69 |
<div id='btn-container'>
|
70 |
<button id="capture-button">Predict</button>
|
71 |
<button id="clear-button">Clear</button>
|
72 |
</div>
|
73 |
-
<div id="prediction-result"></div>
|
74 |
</div>
|
|
|
|
|
|
|
75 |
<script>
|
76 |
var canvas = document.getElementById('whiteboard');
|
77 |
var context = canvas.getContext('2d');
|
@@ -103,27 +145,80 @@
|
|
103 |
clearButton.addEventListener('click', function () {
|
104 |
context.clearRect(0, 0, canvas.width, canvas.height);
|
105 |
});
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
var predictionResult = document.getElementById('prediction-result');
|
108 |
var captureButton = document.getElementById('capture-button');
|
109 |
captureButton.addEventListener('click', function () {
|
|
|
|
|
|
|
110 |
var imageData = canvas.toDataURL("image/png");
|
|
|
|
|
111 |
fetch('/predict', {
|
112 |
method: 'POST',
|
113 |
-
|
114 |
-
'Content-Type': 'application/json',
|
115 |
-
},
|
116 |
-
body: JSON.stringify({ image: imageData }), // Send the image data as JSON
|
117 |
})
|
118 |
.then(response => response.json())
|
119 |
.then(data => {
|
120 |
-
predictionResult.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
})
|
122 |
.catch(error => {
|
123 |
console.error('Error:', error);
|
124 |
-
predictionResult.textContent = '
|
125 |
-
|
|
|
|
|
|
|
126 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
</script>
|
128 |
</body>
|
129 |
|
|
|
4 |
<head>
|
5 |
<meta charset="UTF-8">
|
6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
7 |
+
<link rel="icon" type="image/png" href="/static/favicon.png">
|
8 |
+
<link rel="icon" type="image/x-icon" href="/static/favicon.ico">
|
9 |
<title>Draw and Predict Handwritten Digits</title>
|
10 |
<style>
|
11 |
body {
|
12 |
font-family: 'Montserrat', sans-serif;
|
13 |
}
|
14 |
|
15 |
+
#prediction-result {
|
16 |
+
}
|
17 |
+
|
18 |
+
.prediction-item {
|
19 |
+
margin-bottom: 10px;
|
20 |
+
display: flex;
|
21 |
+
}
|
22 |
+
|
23 |
+
.prediction-digit {
|
24 |
+
font-size: 14px;
|
25 |
+
}
|
26 |
+
|
27 |
+
.progress-container {
|
28 |
+
display: flex;
|
29 |
+
align-items: center;
|
30 |
+
}
|
31 |
+
|
32 |
+
.progress-bar {
|
33 |
+
width: 100px;
|
34 |
+
height: 10px;
|
35 |
+
background-color: #ccc;
|
36 |
+
border-radius: 6px;
|
37 |
+
margin-left: 10px;
|
38 |
+
}
|
39 |
+
|
40 |
+
.progress {
|
41 |
+
height: 100%;
|
42 |
+
border-radius: 5px;
|
43 |
+
background-color: #088395;
|
44 |
+
transition: width 0.3s ease-in-out;
|
45 |
+
}
|
46 |
+
|
47 |
#whiteboard {
|
48 |
border: 3px solid #088395;
|
49 |
/* Simple black border */
|
|
|
84 |
flex-direction: column;
|
85 |
align-items: center;
|
86 |
justify-content: center;
|
87 |
+
margin-right: 10px;
|
88 |
}
|
89 |
|
90 |
#btn-container {
|
|
|
92 |
flex-direction: row;
|
93 |
align-items: center;
|
94 |
}
|
95 |
+
|
96 |
+
#parent {
|
97 |
+
display: flex;
|
98 |
+
justify-content: center;
|
99 |
+
align-items: center;
|
100 |
+
}
|
101 |
</style>
|
102 |
</head>
|
103 |
|
104 |
<body>
|
105 |
<h3 style="text-align:center;">Draw and Predict Handwritten Digits</h3>
|
106 |
+
<div id="parent">
|
107 |
<div id='container'>
|
108 |
<canvas id="whiteboard" width="400" height="200"></canvas>
|
109 |
<div id='btn-container'>
|
110 |
<button id="capture-button">Predict</button>
|
111 |
<button id="clear-button">Clear</button>
|
112 |
</div>
|
|
|
113 |
</div>
|
114 |
+
<div id="prediction-result"></div>
|
115 |
+
</div>
|
116 |
+
|
117 |
<script>
|
118 |
var canvas = document.getElementById('whiteboard');
|
119 |
var context = canvas.getContext('2d');
|
|
|
145 |
clearButton.addEventListener('click', function () {
|
146 |
context.clearRect(0, 0, canvas.width, canvas.height);
|
147 |
});
|
148 |
+
function isCanvasEmpty(canvas) {
|
149 |
+
const context = canvas.getContext('2d');
|
150 |
+
const pixelBuffer = new Uint32Array(context.getImageData(0, 0, canvas.width, canvas.height).data.buffer);
|
151 |
+
return !pixelBuffer.some(color => color !== 0);
|
152 |
+
}
|
153 |
|
154 |
var predictionResult = document.getElementById('prediction-result');
|
155 |
var captureButton = document.getElementById('capture-button');
|
156 |
captureButton.addEventListener('click', function () {
|
157 |
+
captureButton.disabled = true
|
158 |
+
captureButton.style.opacity = "0.5";
|
159 |
+
captureButton.innerText = 'Loading...'
|
160 |
var imageData = canvas.toDataURL("image/png");
|
161 |
+
var formData = new FormData();
|
162 |
+
formData.append("image", dataURLtoBlob(imageData), "image.png");
|
163 |
fetch('/predict', {
|
164 |
method: 'POST',
|
165 |
+
body: formData,
|
|
|
|
|
|
|
166 |
})
|
167 |
.then(response => response.json())
|
168 |
.then(data => {
|
169 |
+
predictionResult.innerHTML = '';
|
170 |
+
data.prediction.forEach(item => {
|
171 |
+
const digitItem = document.createElement('div');
|
172 |
+
digitItem.classList.add('prediction-item');
|
173 |
+
|
174 |
+
const digitText = document.createElement('div');
|
175 |
+
digitText.classList.add('prediction-digit');
|
176 |
+
digitText.textContent = `Digit ${item.digit}:`;
|
177 |
+
|
178 |
+
const progressContainer = document.createElement('div');
|
179 |
+
progressContainer.classList.add('progress-container');
|
180 |
+
|
181 |
+
const progressBar = document.createElement('div');
|
182 |
+
progressBar.classList.add('progress-bar');
|
183 |
+
|
184 |
+
const progress = document.createElement('div');
|
185 |
+
progress.classList.add('progress');
|
186 |
+
progress.style.width = item.prob;
|
187 |
+
|
188 |
+
progressBar.appendChild(progress);
|
189 |
+
progressContainer.appendChild(progressBar);
|
190 |
+
|
191 |
+
digitItem.appendChild(digitText);
|
192 |
+
digitItem.appendChild(progressContainer);
|
193 |
+
|
194 |
+
predictionResult.appendChild(digitItem);
|
195 |
+
|
196 |
+
captureButton.disabled = false
|
197 |
+
captureButton.style.opacity = "1";
|
198 |
+
captureButton.innerText = 'Predict'
|
199 |
+
});
|
200 |
})
|
201 |
.catch(error => {
|
202 |
console.error('Error:', error);
|
203 |
+
predictionResult.textContent = 'Something wentr wrong, try again.';
|
204 |
+
captureButton.disabled = false
|
205 |
+
captureButton.style.opacity = "1";
|
206 |
+
captureButton.innerText = 'Predict'
|
207 |
+
});
|
208 |
});
|
209 |
+
|
210 |
+
// Function to convert a data URL to a Blob
|
211 |
+
function dataURLtoBlob(dataURL) {
|
212 |
+
var arr = dataURL.split(',');
|
213 |
+
var mime = arr[0].match(/:(.*?);/)[1];
|
214 |
+
var bstr = atob(arr[1]);
|
215 |
+
var n = bstr.length;
|
216 |
+
var u8arr = new Uint8Array(n);
|
217 |
+
while (n--) {
|
218 |
+
u8arr[n] = bstr.charCodeAt(n);
|
219 |
+
}
|
220 |
+
return new Blob([u8arr], { type: mime });
|
221 |
+
}
|
222 |
</script>
|
223 |
</body>
|
224 |
|