carlfeynman commited on
Commit
c32023c
1 Parent(s): fccfd79

full model inference/ui added

Browse files
__pycache__/mnist_classifier.cpython-39.pyc ADDED
Binary file (4.5 kB). View file
 
__pycache__/server.cpython-39.pyc ADDED
Binary file (1.32 kB). View file
 
classifier.pkl DELETED
Binary file (654 kB)
 
classifier.pth ADDED
Binary file (647 kB). View file
 
mnist_classifier.ipynb CHANGED
@@ -16,7 +16,19 @@
16
  "import torchvision.transforms.functional as TF\n",
17
  "from torch.utils.data import default_collate, DataLoader\n",
18
  "import torch.optim as optim\n",
19
- "import pickle\n",
 
 
 
 
 
 
 
 
 
 
 
 
20
  "%matplotlib inline\n",
21
  "plt.rcParams['figure.figsize'] = [2, 2]"
22
  ]
@@ -24,7 +36,11 @@
24
  {
25
  "cell_type": "code",
26
  "execution_count": 101,
27
- "metadata": {},
 
 
 
 
28
  "outputs": [
29
  {
30
  "name": "stderr",
@@ -43,27 +59,25 @@
43
  },
44
  {
45
  "cell_type": "code",
46
- "execution_count": 102,
47
  "metadata": {},
48
- "outputs": [
49
- {
50
- "data": {
51
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3+kraZ+9VyFfUK7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=",
52
- "text/plain": [
53
- "<Figure size 144x144 with 1 Axes>"
54
- ]
55
- },
56
- "metadata": {
57
- "needs_background": "light"
58
- },
59
- "output_type": "display_data"
60
- }
61
- ],
62
  "source": [
63
  "def transform_ds(b):\n",
64
  " b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
65
- " return b\n",
66
- "\n",
 
 
 
 
 
 
 
 
 
 
 
67
  "dst = ds.with_transform(transform_ds)\n",
68
  "plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
69
  ]
@@ -93,8 +107,19 @@
93
  "\n",
94
  "def collate_fn(b):\n",
95
  " collate = default_collate(b)\n",
96
- " return (collate[x], collate[y])\n",
97
- "\n",
 
 
 
 
 
 
 
 
 
 
 
98
  "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
99
  "xb,yb = next(iter(dls.train))\n",
100
  "xb.shape, yb.shape"
@@ -178,23 +203,27 @@
178
  },
179
  {
180
  "cell_type": "code",
181
- "execution_count": 109,
182
- "metadata": {},
 
 
 
 
183
  "outputs": [
184
  {
185
  "name": "stdout",
186
  "output_type": "stream",
187
  "text": [
188
- "train, epoch:1, loss: 0.1077, accuracy: 0.9104\n",
189
- "eval, epoch:1, loss: 0.0382, accuracy: 0.9791\n",
190
- "train, epoch:2, loss: 0.0410, accuracy: 0.9832\n",
191
- "eval, epoch:2, loss: 0.0221, accuracy: 0.9866\n",
192
- "train, epoch:3, loss: 0.0538, accuracy: 0.9871\n",
193
- "eval, epoch:3, loss: 0.0141, accuracy: 0.9887\n",
194
- "train, epoch:4, loss: 0.0343, accuracy: 0.9858\n",
195
- "eval, epoch:4, loss: 0.0163, accuracy: 0.9871\n",
196
- "train, epoch:5, loss: 0.0390, accuracy: 0.9865\n",
197
- "eval, epoch:5, loss: 0.0169, accuracy: 0.9871\n"
198
  ]
199
  }
200
  ],
@@ -227,7 +256,7 @@
227
  },
228
  {
229
  "cell_type": "code",
230
- "execution_count": 110,
231
  "metadata": {
232
  "tags": [
233
  "exclude"
@@ -235,8 +264,76 @@
235
  },
236
  "outputs": [],
237
  "source": [
238
- "with open('./classifier.pkl', 'wb') as model_file:\n",
239
- " pickle.dump(model, model_file)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  ]
241
  },
242
  {
@@ -252,7 +349,7 @@
252
  },
253
  {
254
  "cell_type": "code",
255
- "execution_count": 111,
256
  "metadata": {
257
  "tags": [
258
  "exclude"
@@ -264,13 +361,20 @@
264
  "output_type": "stream",
265
  "text": [
266
  "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
267
- "[NbConvertApp] Writing 3691 bytes to mnist_classifier.py\n"
268
  ]
269
  }
270
  ],
271
  "source": [
272
  "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
273
  ]
 
 
 
 
 
 
 
274
  }
275
  ],
276
  "metadata": {
 
16
  "import torchvision.transforms.functional as TF\n",
17
  "from torch.utils.data import default_collate, DataLoader\n",
18
  "import torch.optim as optim\n",
19
+ "import pickle\n"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {
26
+ "tags": [
27
+ "exclude"
28
+ ]
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
  "%matplotlib inline\n",
33
  "plt.rcParams['figure.figsize'] = [2, 2]"
34
  ]
 
36
  {
37
  "cell_type": "code",
38
  "execution_count": 101,
39
+ "metadata": {
40
+ "tags": [
41
+ "exclude"
42
+ ]
43
+ },
44
  "outputs": [
45
  {
46
  "name": "stderr",
 
59
  },
60
  {
61
  "cell_type": "code",
62
+ "execution_count": 112,
63
  "metadata": {},
64
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  "source": [
66
  "def transform_ds(b):\n",
67
  " b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
68
+ " return b"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "metadata": {
75
+ "tags": [
76
+ "exclude"
77
+ ]
78
+ },
79
+ "outputs": [],
80
+ "source": [
81
  "dst = ds.with_transform(transform_ds)\n",
82
  "plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
83
  ]
 
107
  "\n",
108
  "def collate_fn(b):\n",
109
  " collate = default_collate(b)\n",
110
+ " return (collate[x], collate[y])"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": null,
116
+ "metadata": {
117
+ "tags": [
118
+ "exclude"
119
+ ]
120
+ },
121
+ "outputs": [],
122
+ "source": [
123
  "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
124
  "xb,yb = next(iter(dls.train))\n",
125
  "xb.shape, yb.shape"
 
203
  },
204
  {
205
  "cell_type": "code",
206
+ "execution_count": 195,
207
+ "metadata": {
208
+ "tags": [
209
+ "exclude"
210
+ ]
211
+ },
212
  "outputs": [
213
  {
214
  "name": "stdout",
215
  "output_type": "stream",
216
  "text": [
217
+ "train, epoch:1, loss: 0.0776, accuracy: 0.9172\n",
218
+ "eval, epoch:1, loss: 0.0372, accuracy: 0.9818\n",
219
+ "train, epoch:2, loss: 0.0571, accuracy: 0.9828\n",
220
+ "eval, epoch:2, loss: 0.0287, accuracy: 0.9863\n",
221
+ "train, epoch:3, loss: 0.0425, accuracy: 0.9847\n",
222
+ "eval, epoch:3, loss: 0.0256, accuracy: 0.9865\n",
223
+ "train, epoch:4, loss: 0.0271, accuracy: 0.9868\n",
224
+ "eval, epoch:4, loss: 0.0378, accuracy: 0.9826\n",
225
+ "train, epoch:5, loss: 0.0395, accuracy: 0.9844\n",
226
+ "eval, epoch:5, loss: 0.0307, accuracy: 0.9873\n"
227
  ]
228
  }
229
  ],
 
256
  },
257
  {
258
  "cell_type": "code",
259
+ "execution_count": 196,
260
  "metadata": {
261
  "tags": [
262
  "exclude"
 
264
  },
265
  "outputs": [],
266
  "source": [
267
+ "torch.save(model.state_dict(), 'classifier.pth')"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 197,
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "loaded_model = cnn_classifier()\n",
277
+ "loaded_model.load_state_dict(torch.load('classifier.pth'))\n",
278
+ "loaded_model.eval();"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 206,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "def predict(img):\n",
288
+ " with torch.no_grad():\n",
289
+ " img = img[None,]\n",
290
+ " pred = loaded_model(img)[0]\n",
291
+ " pred_probs = F.softmax(pred, dim=0)\n",
292
+ " pred = [{\"digit\": i, \"prob\": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]\n",
293
+ " pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)\n",
294
+ " return pred"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": 204,
300
+ "metadata": {
301
+ "tags": [
302
+ "exclude"
303
+ ]
304
+ },
305
+ "outputs": [
306
+ {
307
+ "name": "stdout",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "tensor(5)\n"
311
+ ]
312
+ },
313
+ {
314
+ "data": {
315
+ "text/plain": [
316
+ "[{'digit': 0, 'prob': '21.42%', 'logits': tensor(0.0559)},\n",
317
+ " {'digit': 8, 'prob': '19.44%', 'logits': tensor(-0.0408)},\n",
318
+ " {'digit': 4, 'prob': '18.08%', 'logits': tensor(-0.1135)},\n",
319
+ " {'digit': 9, 'prob': '16.41%', 'logits': tensor(-0.2104)},\n",
320
+ " {'digit': 6, 'prob': '12.23%', 'logits': tensor(-0.5049)},\n",
321
+ " {'digit': 1, 'prob': '6.87%', 'logits': tensor(-1.0806)},\n",
322
+ " {'digit': 7, 'prob': '2.33%', 'logits': tensor(-2.1633)},\n",
323
+ " {'digit': 5, 'prob': '1.19%', 'logits': tensor(-2.8386)},\n",
324
+ " {'digit': 2, 'prob': '1.06%', 'logits': tensor(-2.9527)},\n",
325
+ " {'digit': 3, 'prob': '0.97%', 'logits': tensor(-3.0359)}]"
326
+ ]
327
+ },
328
+ "execution_count": 204,
329
+ "metadata": {},
330
+ "output_type": "execute_result"
331
+ }
332
+ ],
333
+ "source": [
334
+ "img = xb[0].reshape(1, 28, 28)\n",
335
+ "print(yb[0])\n",
336
+ "predict(img)"
337
  ]
338
  },
339
  {
 
349
  },
350
  {
351
  "cell_type": "code",
352
+ "execution_count": 205,
353
  "metadata": {
354
  "tags": [
355
  "exclude"
 
361
  "output_type": "stream",
362
  "text": [
363
  "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
364
+ "[NbConvertApp] Writing 2904 bytes to mnist_classifier.py\n"
365
  ]
366
  }
367
  ],
368
  "source": [
369
  "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
370
  ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": null,
375
+ "metadata": {},
376
+ "outputs": [],
377
+ "source": []
378
  }
379
  ],
380
  "metadata": {
mnist_classifier.py CHANGED
@@ -12,22 +12,12 @@ import torchvision.transforms.functional as TF
12
  from torch.utils.data import default_collate, DataLoader
13
  import torch.optim as optim
14
  import pickle
15
- get_ipython().run_line_magic('matplotlib', 'inline')
16
- plt.rcParams['figure.figsize'] = [2, 2]
17
-
18
-
19
- dataset_nm = 'mnist'
20
- x,y = 'image', 'label'
21
- ds = load_dataset(dataset_nm)
22
 
23
 
24
  def transform_ds(b):
25
  b[x] = [TF.to_tensor(ele) for ele in b[x]]
26
  return b
27
 
28
- dst = ds.with_transform(transform_ds)
29
- plt.imshow(dst['train'][0]['image'].permute(1,2,0));
30
-
31
 
32
  bs = 1024
33
  class DataLoaders:
@@ -39,10 +29,6 @@ def collate_fn(b):
39
  collate = default_collate(b)
40
  return (collate[x], collate[y])
41
 
42
- dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)
43
- xb,yb = next(iter(dls.train))
44
- xb.shape, yb.shape
45
-
46
 
47
  class Reshape(nn.Module):
48
  def __init__(self, dim):
@@ -96,28 +82,20 @@ def kaiming_init(m):
96
  nn.init.kaiming_normal_(m.weight)
97
 
98
 
99
- model = cnn_classifier()
100
- model.apply(kaiming_init)
101
- lr = 0.1
102
- max_lr = 0.3
103
- epochs = 5
104
- opt = optim.AdamW(model.parameters(), lr=lr)
105
- sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)
106
- for epoch in range(epochs):
107
- for train in (True, False):
108
- accuracy = 0
109
- dl = dls.train if train else dls.valid
110
- for xb,yb in dl:
111
- preds = model(xb)
112
- loss = F.cross_entropy(preds, yb)
113
- if train:
114
- loss.backward()
115
- opt.step()
116
- opt.zero_grad()
117
- with torch.no_grad():
118
- accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
119
- if train:
120
- sched.step()
121
- accuracy /= len(dl)
122
- print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
123
 
 
12
  from torch.utils.data import default_collate, DataLoader
13
  import torch.optim as optim
14
  import pickle
 
 
 
 
 
 
 
15
 
16
 
17
  def transform_ds(b):
18
  b[x] = [TF.to_tensor(ele) for ele in b[x]]
19
  return b
20
 
 
 
 
21
 
22
  bs = 1024
23
  class DataLoaders:
 
29
  collate = default_collate(b)
30
  return (collate[x], collate[y])
31
 
 
 
 
 
32
 
33
  class Reshape(nn.Module):
34
  def __init__(self, dim):
 
82
  nn.init.kaiming_normal_(m.weight)
83
 
84
 
85
+ loaded_model = cnn_classifier()
86
+ loaded_model.load_state_dict(torch.load('classifier.pth'))
87
+ loaded_model.eval();
88
+
89
+
90
+ def predict(img):
91
+ with torch.no_grad():
92
+ img = img[None,]
93
+ pred = loaded_model(img)[0]
94
+ pred_probs = F.softmax(pred, dim=0)
95
+ pred = [{"digit": i, "prob": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]
96
+ pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)
97
+ return pred
98
+
99
+
100
+
 
 
 
 
 
 
 
 
101
 
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
  fastapi==0.68.1
2
  uvicorn==0.15.0
3
- aiofiles
 
 
 
 
 
 
1
  fastapi==0.68.1
2
  uvicorn==0.15.0
3
+ aiofiles
4
+ torch
5
+ fastcore
6
+ torchvision
7
+ datasets
8
+
server.py CHANGED
@@ -1,7 +1,11 @@
1
- from fastapi import FastAPI
 
 
2
  from fastapi.responses import FileResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from pathlib import Path
 
 
5
 
6
  app = FastAPI()
7
 
@@ -10,6 +14,20 @@ app.mount("/static", StaticFiles(directory=Path("static")), name="static")
10
  async def root():
11
  return FileResponse("static/index.html")
12
 
 
 
 
 
 
 
 
 
 
 
 
13
  @app.post("/predict")
14
- async def predict():
15
- return {"prediction": "Hello, World!"}
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ import io
3
+ from PIL import Image
4
  from fastapi.responses import FileResponse
5
  from fastapi.staticfiles import StaticFiles
6
  from pathlib import Path
7
+ import torchvision.transforms as transforms
8
+ import mnist_classifier
9
 
10
  app = FastAPI()
11
 
 
14
  async def root():
15
  return FileResponse("static/index.html")
16
 
17
+ def process_image(file: UploadFile):
18
+ image_bytes = file.file.read()
19
+ pil_image = Image.open(io.BytesIO(image_bytes))
20
+ transform = transforms.Compose([
21
+ transforms.Resize((28, 28)),
22
+ transforms.Grayscale(num_output_channels=1),
23
+ transforms.ToTensor(),
24
+ ])
25
+ tensor_image = transform(pil_image)
26
+ return tensor_image
27
+
28
  @app.post("/predict")
29
+ async def predict(image: UploadFile):
30
+ tensor_image = process_image(image)
31
+ prediction = mnist_classifier.predict(tensor_image)
32
+ return {"prediction": prediction}
33
+
static/index.html CHANGED
@@ -4,14 +4,46 @@
4
  <head>
5
  <meta charset="UTF-8">
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
- <link rel="icon" type="image/png" href="favicon.png">
8
- <link rel="icon" type="image/x-icon" href="favicon.ico">
9
  <title>Draw and Predict Handwritten Digits</title>
10
  <style>
11
  body {
12
  font-family: 'Montserrat', sans-serif;
13
  }
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  #whiteboard {
16
  border: 3px solid #088395;
17
  /* Simple black border */
@@ -52,6 +84,7 @@
52
  flex-direction: column;
53
  align-items: center;
54
  justify-content: center;
 
55
  }
56
 
57
  #btn-container {
@@ -59,19 +92,28 @@
59
  flex-direction: row;
60
  align-items: center;
61
  }
 
 
 
 
 
 
62
  </style>
63
  </head>
64
 
65
  <body>
66
  <h3 style="text-align:center;">Draw and Predict Handwritten Digits</h3>
 
67
  <div id='container'>
68
  <canvas id="whiteboard" width="400" height="200"></canvas>
69
  <div id='btn-container'>
70
  <button id="capture-button">Predict</button>
71
  <button id="clear-button">Clear</button>
72
  </div>
73
- <div id="prediction-result"></div>
74
  </div>
 
 
 
75
  <script>
76
  var canvas = document.getElementById('whiteboard');
77
  var context = canvas.getContext('2d');
@@ -103,27 +145,80 @@
103
  clearButton.addEventListener('click', function () {
104
  context.clearRect(0, 0, canvas.width, canvas.height);
105
  });
 
 
 
 
 
106
 
107
  var predictionResult = document.getElementById('prediction-result');
108
  var captureButton = document.getElementById('capture-button');
109
  captureButton.addEventListener('click', function () {
 
 
 
110
  var imageData = canvas.toDataURL("image/png");
 
 
111
  fetch('/predict', {
112
  method: 'POST',
113
- headers: {
114
- 'Content-Type': 'application/json',
115
- },
116
- body: JSON.stringify({ image: imageData }), // Send the image data as JSON
117
  })
118
  .then(response => response.json())
119
  .then(data => {
120
- predictionResult.textContent = 'Predicted Digit: ' + data.prediction;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  })
122
  .catch(error => {
123
  console.error('Error:', error);
124
- predictionResult.textContent = 'Prediction failed.';
125
- });
 
 
 
126
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  </script>
128
  </body>
129
 
 
4
  <head>
5
  <meta charset="UTF-8">
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <link rel="icon" type="image/png" href="/static/favicon.png">
8
+ <link rel="icon" type="image/x-icon" href="/static/favicon.ico">
9
  <title>Draw and Predict Handwritten Digits</title>
10
  <style>
11
  body {
12
  font-family: 'Montserrat', sans-serif;
13
  }
14
 
15
+ #prediction-result {
16
+ }
17
+
18
+ .prediction-item {
19
+ margin-bottom: 10px;
20
+ display: flex;
21
+ }
22
+
23
+ .prediction-digit {
24
+ font-size: 14px;
25
+ }
26
+
27
+ .progress-container {
28
+ display: flex;
29
+ align-items: center;
30
+ }
31
+
32
+ .progress-bar {
33
+ width: 100px;
34
+ height: 10px;
35
+ background-color: #ccc;
36
+ border-radius: 6px;
37
+ margin-left: 10px;
38
+ }
39
+
40
+ .progress {
41
+ height: 100%;
42
+ border-radius: 5px;
43
+ background-color: #088395;
44
+ transition: width 0.3s ease-in-out;
45
+ }
46
+
47
  #whiteboard {
48
  border: 3px solid #088395;
49
  /* Simple black border */
 
84
  flex-direction: column;
85
  align-items: center;
86
  justify-content: center;
87
+ margin-right: 10px;
88
  }
89
 
90
  #btn-container {
 
92
  flex-direction: row;
93
  align-items: center;
94
  }
95
+
96
+ #parent {
97
+ display: flex;
98
+ justify-content: center;
99
+ align-items: center;
100
+ }
101
  </style>
102
  </head>
103
 
104
  <body>
105
  <h3 style="text-align:center;">Draw and Predict Handwritten Digits</h3>
106
+ <div id="parent">
107
  <div id='container'>
108
  <canvas id="whiteboard" width="400" height="200"></canvas>
109
  <div id='btn-container'>
110
  <button id="capture-button">Predict</button>
111
  <button id="clear-button">Clear</button>
112
  </div>
 
113
  </div>
114
+ <div id="prediction-result"></div>
115
+ </div>
116
+
117
  <script>
118
  var canvas = document.getElementById('whiteboard');
119
  var context = canvas.getContext('2d');
 
145
  clearButton.addEventListener('click', function () {
146
  context.clearRect(0, 0, canvas.width, canvas.height);
147
  });
148
+ function isCanvasEmpty(canvas) {
149
+ const context = canvas.getContext('2d');
150
+ const pixelBuffer = new Uint32Array(context.getImageData(0, 0, canvas.width, canvas.height).data.buffer);
151
+ return !pixelBuffer.some(color => color !== 0);
152
+ }
153
 
154
  var predictionResult = document.getElementById('prediction-result');
155
  var captureButton = document.getElementById('capture-button');
156
  captureButton.addEventListener('click', function () {
157
+ captureButton.disabled = true
158
+ captureButton.style.opacity = "0.5";
159
+ captureButton.innerText = 'Loading...'
160
  var imageData = canvas.toDataURL("image/png");
161
+ var formData = new FormData();
162
+ formData.append("image", dataURLtoBlob(imageData), "image.png");
163
  fetch('/predict', {
164
  method: 'POST',
165
+ body: formData,
 
 
 
166
  })
167
  .then(response => response.json())
168
  .then(data => {
169
+ predictionResult.innerHTML = '';
170
+ data.prediction.forEach(item => {
171
+ const digitItem = document.createElement('div');
172
+ digitItem.classList.add('prediction-item');
173
+
174
+ const digitText = document.createElement('div');
175
+ digitText.classList.add('prediction-digit');
176
+ digitText.textContent = `Digit ${item.digit}:`;
177
+
178
+ const progressContainer = document.createElement('div');
179
+ progressContainer.classList.add('progress-container');
180
+
181
+ const progressBar = document.createElement('div');
182
+ progressBar.classList.add('progress-bar');
183
+
184
+ const progress = document.createElement('div');
185
+ progress.classList.add('progress');
186
+ progress.style.width = item.prob;
187
+
188
+ progressBar.appendChild(progress);
189
+ progressContainer.appendChild(progressBar);
190
+
191
+ digitItem.appendChild(digitText);
192
+ digitItem.appendChild(progressContainer);
193
+
194
+ predictionResult.appendChild(digitItem);
195
+
196
+ captureButton.disabled = false
197
+ captureButton.style.opacity = "1";
198
+ captureButton.innerText = 'Predict'
199
+ });
200
  })
201
  .catch(error => {
202
  console.error('Error:', error);
203
+ predictionResult.textContent = 'Something wentr wrong, try again.';
204
+ captureButton.disabled = false
205
+ captureButton.style.opacity = "1";
206
+ captureButton.innerText = 'Predict'
207
+ });
208
  });
209
+
210
+ // Function to convert a data URL to a Blob
211
+ function dataURLtoBlob(dataURL) {
212
+ var arr = dataURL.split(',');
213
+ var mime = arr[0].match(/:(.*?);/)[1];
214
+ var bstr = atob(arr[1]);
215
+ var n = bstr.length;
216
+ var u8arr = new Uint8Array(n);
217
+ while (n--) {
218
+ u8arr[n] = bstr.charCodeAt(n);
219
+ }
220
+ return new Blob([u8arr], { type: mime });
221
+ }
222
  </script>
223
  </body>
224