File size: 53,461 Bytes
a1ece17
 
 
 
338bbe8
a1ece17
5fe67f2
a1ece17
 
 
d64050c
 
55abd01
 
 
 
 
338bbe8
c32023c
 
 
 
338bbe8
c32023c
 
 
 
 
 
 
55abd01
 
a1ece17
 
55abd01
 
338bbe8
c32023c
 
 
 
 
0bbec58
 
 
 
 
 
9c06ef3
0bbec58
 
 
55abd01
 
 
 
 
 
 
 
338bbe8
55abd01
c32023c
55abd01
 
 
c32023c
 
 
 
 
338bbe8
c32023c
 
 
 
 
0bbec58
 
 
338bbe8
0bbec58
338bbe8
0bbec58
 
338bbe8
 
 
 
0bbec58
 
338bbe8
 
 
 
 
 
 
 
 
 
55abd01
 
 
 
 
338bbe8
55abd01
 
 
c32023c
 
 
 
 
338bbe8
c32023c
 
 
 
 
338bbe8
 
 
 
 
 
 
 
 
 
 
 
c32023c
55abd01
 
 
 
 
 
 
9c06ef3
55abd01
 
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
8e35bc7
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
8e35bc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
55abd01
 
 
 
 
9c06ef3
 
 
 
 
8e35bc7
55abd01
6b67354
55abd01
 
 
 
9c06ef3
55abd01
 
 
5993d2f
 
056ab4f
55abd01
 
 
 
9c06ef3
c32023c
 
 
 
 
b85505a
 
 
 
 
9c06ef3
 
 
 
 
 
 
 
 
 
b85505a
 
 
55abd01
5993d2f
 
55abd01
5993d2f
55abd01
 
 
 
 
 
338bbe8
55abd01
 
 
 
 
 
 
 
 
 
338bbe8
55abd01
 
 
338bbe8
 
55abd01
 
 
 
9c06ef3
338bbe8
 
 
 
 
 
 
 
 
 
9c06ef3
 
 
 
 
338bbe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
338bbe8
 
 
 
 
 
 
 
9c06ef3
338bbe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
53075d2
 
 
 
 
c004d97
55abd01
c32023c
 
 
 
 
9c06ef3
c32023c
 
 
 
338bbe8
c32023c
 
 
 
 
9c06ef3
338bbe8
 
 
 
 
 
 
 
9c06ef3
338bbe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
338bbe8
 
 
 
 
 
 
 
9c06ef3
338bbe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
c32023c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c06ef3
c32023c
 
 
 
 
 
 
 
 
 
9c06ef3
 
 
 
 
 
 
 
c32023c
 
 
 
 
9c06ef3
 
 
 
 
 
 
 
 
 
c32023c
 
9c06ef3
c32023c
 
 
 
 
338bbe8
 
c32023c
55abd01
 
 
0bbec58
53075d2
 
 
 
 
48f5984
0bbec58
48f5984
 
 
0bbec58
87d9e93
53075d2
 
 
 
 
0bbec58
 
 
 
 
9490409
87d9e93
0bbec58
 
 
 
8e35bc7
0bbec58
c32023c
9490409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2e6f5e
9490409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1ece17
 
 
 
48f5984
a1ece17
d64050c
a1ece17
 
d64050c
 
 
 
 
 
a1ece17
d64050c
 
5fe67f2
a1ece17
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from datasets import load_dataset\n",
    "import fastcore.all as fc\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import torchvision.transforms.functional as TF\n",
    "from torch.utils.data import default_collate, DataLoader\n",
    "import torch.optim as optim\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = [2, 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
      "100%|██████████| 2/2 [00:00<00:00, 75.69it/s]\n"
     ]
    }
   ],
   "source": [
    "dataset_nm = 'mnist'\n",
    "x,y = 'image', 'label'\n",
    "ds = load_dataset(dataset_nm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_ds(b):\n",
    "    b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
    "    return b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3+kraZ+9VyFfUK7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 144x144 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "dst = ds.with_transform(transform_ds)\n",
    "plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 1024\n",
    "class DataLoaders:\n",
    "    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
    "        self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
    "        self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
    "\n",
    "def collate_fn(b):\n",
    "    collate = default_collate(b)\n",
    "    return (collate[x], collate[y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
    "xb,yb = next(iter(dls.train))\n",
    "xb.shape, yb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Reshape(nn.Module):\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return x.reshape(self.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
    "    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n",
    "    if norm:\n",
    "        layers.append(norm)\n",
    "    if act:\n",
    "        layers.append(act())\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
    "    return nn.Sequential(\n",
    "        conv(ni, nf, ks=ks, s=1, norm=None, act=act),\n",
    "        conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
    "    )\n",
    "\n",
    "class ResBlock(nn.Module):\n",
    "    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n",
    "        super().__init__()\n",
    "        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n",
    "        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n",
    "        self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
    "        self.act = act()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.act(self.convs(x) + self.idconv(self.pool(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cnn_classifier():\n",
    "    return nn.Sequential(\n",
    "        ResBlock(1, 8, norm=nn.LayerNorm([8, 14, 14])),\n",
    "        ResBlock(8, 16, norm=nn.LayerNorm([16, 7, 7])),\n",
    "        ResBlock(16, 32, norm=nn.LayerNorm([32, 4, 4])),\n",
    "        ResBlock(32, 64, norm=nn.LayerNorm([64, 2, 2])),\n",
    "        ResBlock(64, 64,  norm=nn.LayerNorm([64, 1, 1])),\n",
    "        conv(64, 10, act=False),\n",
    "        nn.Flatten(),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kaiming_init(m):\n",
    "    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
    "        nn.init.kaiming_normal_(m.weight)        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train, epoch:1, loss: 1.8902, accuracy: 0.3183\n",
      "eval, epoch:1, loss: 1.0976, accuracy: 0.6274\n",
      "train, epoch:2, loss: 0.5929, accuracy: 0.8003\n",
      "eval, epoch:2, loss: 0.2895, accuracy: 0.9102\n",
      "train, epoch:3, loss: 0.2396, accuracy: 0.9264\n",
      "eval, epoch:3, loss: 0.1343, accuracy: 0.9597\n",
      "train, epoch:4, loss: 0.1139, accuracy: 0.9651\n",
      "eval, epoch:4, loss: 0.0801, accuracy: 0.9763\n",
      "train, epoch:5, loss: 0.1368, accuracy: 0.9582\n",
      "eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
     ]
    }
   ],
   "source": [
    "model = cnn_classifier()\n",
    "model.apply(kaiming_init)\n",
    "lr = 0.1\n",
    "max_lr = 0.3\n",
    "epochs = 5\n",
    "opt = optim.AdamW(model.parameters(), lr=lr)\n",
    "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
    "for epoch in range(epochs):\n",
    "    for train in (True, False):\n",
    "        accuracy = 0\n",
    "        total_loss = 0\n",
    "        dl = dls.train if train else dls.valid\n",
    "        for xb,yb in dl:\n",
    "            preds = model(xb)\n",
    "            loss = F.cross_entropy(preds, yb)\n",
    "            if train:\n",
    "                loss.backward()\n",
    "                opt.step()\n",
    "                opt.zero_grad()\n",
    "            with torch.no_grad():\n",
    "                accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
    "                total_loss  += loss.item()\n",
    "        if train:\n",
    "            sched.step()\n",
    "        accuracy /= len(dl)\n",
    "        total_loss /= len(dl)\n",
    "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "eval, epoch:1, loss: 0.0882, accuracy: 0.9722\n",
      "eval, epoch:2, loss: 0.0882, accuracy: 0.9722\n",
      "eval, epoch:3, loss: 0.0882, accuracy: 0.9722\n",
      "eval, epoch:4, loss: 0.0882, accuracy: 0.9722\n",
      "eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(epochs):\n",
    "        train = False\n",
    "        accuracy = 0\n",
    "        total_loss = 0\n",
    "        dl = dls.valid\n",
    "        for xb,yb in dl:\n",
    "            preds = model(xb)\n",
    "            loss = F.cross_entropy(preds, yb)\n",
    "            with torch.no_grad():\n",
    "                accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
    "                total_loss  += loss.item()\n",
    "        accuracy /= len(dl)\n",
    "        total_loss /= len(dl)\n",
    "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAc5klEQVR4nO3dd3yV1RkH8N+ThL0CGCDISBDCFEQUQaVOtG7UOhARq1itiijWWm21rlZxIVKUqihacWCduGitG1GGIgioTBkyIwHCkiSnf7yXc97nmje8SW6Sm7y/7+fDh+fkvOvek5ucnCnGGBARERFFRUpVPwARERFRZWLlh4iIiCKFlR8iIiKKFFZ+iIiIKFJY+SEiIqJIYeWHiIiIIqXKKz8iskJEjq/q56gsInKbiDxb1c9RUVieNQfLsmZhedYcLMvyq/LKT2nE3oA9IpLv+9ehgu95sYh8WpH3iLufEZHtvtf3RGXdu7KJSLqIPC0iG2L/bquEe1Z2eZ4mIt/EyvIzEelWWfeuTCJyrYgsE5GtIvKjiIwRkbQKvmellaWI7Cci00UkV0TyRGSGiBxRGfeuCuIZHXu9uSJyr4hIBd+zsj+bB4nIHBHZEfv/oMq6d2Xi783iJbTyU9E/7GJeNMY09P1bVgn3rGy9fK9veFU9RCWU5xgA9QFkAegLYKiI/LaC71lpRKQTgMkArgCQDmAqgDcq6XMS/ywVfc+pAA42xjQG0ANALwDXVPA9K1M+gEsAZABoCmA0gKlVUZZApZTn7wAMgleOPQGcCuDyCr5npRGR2gBeB/AsvPJ8GsDrsa9X9rPw92ZilOr35j4rP7HmtZtEZKGIbBaRp0SkbizvaBFZLSI3isg6AE+JSIqI/ElElsb+YpgiIs181xsqIj/E8v5crpdaCiJygIi8H7vvJhGZLCLpvvy2IvKKiGyMHfMPEekKYAKA/rHaZF7s2A9FZLjvXFXLFZGxIrIq9lfwHBEZUFmvc1+SrDxPA3CvMWaHMWYFgInwfsGEeR3VoTxPBPCJMeZTY0wBvF+Y+wM4KuT5JUqmsjTGLDXG5O29FIAiAB1Dvo6kL0tjzC5jzHfGmKLY6yuE90uzWclnhpdM5QlgGIAHjDGrjTFrADwA4OKQryPpyxPA0QDSADxkjNltjHkYXrkeG/L8EiVZWZbndVSHsiyTsC0/Q+D9ID8AQA6Av/jyWsH7AdAe3l8L18D7i+EoAK0BbAYwHgDEa/J/FMDQWF5zAG32XkhEjtz7RpXgNBH5SUQWiMjvQz4/4H1j3x27b1cAbQHcFrtvKoA3AfwArxVifwAvGGMWwfurfUasNpke8l6zABwE7315DsBLe7/xf/FQIvNE5IK4L38sIuti31RZIe9ZGslUnhIX9wj5GqpDeQp++fpK8xrDSJqyFJELRGQrgE3wWgz+GfI1VIeytF8DsAvAGwCeMMZsCHnfsJKlPLsD+NqX/jr2tTCqQ3l2BzDP6P2d5iH8awwjWcoS4O/NXzLGlPgPwAoAV/jSJwNYGouPBvAzgLq+/EUAjvOlMwHsgVfLvhXem7M3r0Hs/OP39Ryx47vBK4RUAIcDWAtgcJhzi7nWIABfxeL+ADYCSCvmuIsBfBr3tQ8BDC/pmLjjN8NrkgO8b5xnSzj2VwBqw+sm+QeAb4p7rrL+S7LyfBbAKwAawWslWApgd00pTwBdAGyPva+1AdwCr0XkpppWlnHP1QnAnQBa1ZSyjDunLoDBAIYlohyTsTzhtWx1iStTA0BqQnnGPosvxH1tMoDbamBZ8vdmMf/C9jWu8sU/xN7IvTYaY3b50u0BvCoiRb6vFQJoGTvPXssYs11EckM+A4wxC33Jz0RkLIDfAHh+X+eKSAsADwMYAO+XbQq8NxfwarM/GK9rotxE5HoAw+G9XgOgMYD9wpxrjPk4Fv4sIiMBbIVX456fiGeLSYryhPfXzjgAiwHkwivHwWFOrA7laYz5VkSGwfswZsKr7C0EsDoRzxWTLGVpGWMWi8gCAI8AOGtfx1eHsvSLvafPi8giEZlrjPl6nyeFlyzlmQ/vvdmrMYB8E/tNU5JqUp7xrw+x9LZEPFdMUpQlf28WL2y3V1tf3A7Aj/77xh27CsBJxph037+6xus3Xuu/lojUh9eEV1YGuluhJHfHju9pvEGZF/rOXQWgnRQ/8Ky4D/t2eAN192q1N4j1U94I4FwATY3X5LelFM9Z3P0TPcsiKcrTGPOTMWaIMaaVMaY7vO/HmSFPrxblaYz5tzGmhzGmOYC/wvshNyvMuSElRVkWIw1ec38Y1aIsi1ELQKJnzSRLeS6A13W5V6/Y18KoDuW5AEBPETWDrSfCv8YwkqUs4/H3JsJXfq4SkTaxAVg3A3ixhGMnAPibiLQHABHJEJEzYnn/BnBqrI+yNoA7SvEMEJEzRKSpePrCazl43Zf/oQRPl24Er7afJyL7A7jBlzcT3jfYPSLSQETqipvGuh5AG9GzAOYCOEtE6otIRwCXxt2nALHmQBG5Fb/8CyPo9XUXb/plqog0hDfIcA28JtFESpbyPEBEmsde70nw+r7v8uVX6/KMvYY+sdeXAW8MzFRjzLdhzw8hWcpyeOyvxL1jFG4C8D9ffrUuSxHpt/e9EZF6InIjvL/KvwhzfikkRXkCeAbAKBHZX0RaA7gewKS9mdW9POF1wRQCuEZE6ojI1bGvvx/y/DCSoiz5e7N4Yd/A5wD8B8Cy2L+7Sjh2LLzBgP8RkW0APgdwGAAYYxYAuCp2vbXwms9sF4CIDBCR/BKufT6AJfCaJp8BMNoY87Qvvy2A6QHn3g7gYHi1ybfgjTVB7LkK4c086ghgZeyZzotlvw/vr4F1IrIp9rUx8Ppc18ObIjnZd59pAN4B8D28ps5d0M2fingD0IbEki3hfUC2wnufswCcaozZE3R+GSVLefaB1yy5Dd5fGENi19yrupcn4L1/eQC+i/1/WdC5ZZQsZXkEgPkish3A27F/N/vyq3tZ1oE3ADUX3g/WkwGcYoz5Mej8MkqW8vwnvOUL5sMbP/EW9AD2al2expif4Y1fuQje5/ISAINiX0+UZClL/t4s7hr76sIVkRXwBim9V+KBVUxE2gB4yRjTv6qfJZmxPGsOlmXNwvKsOViWya9KFuiqCMaY1fBGn1MNwPKsOViWNQvLs+aIcllWq+0tiIiIiMprn91eRERERDUJW36IiIgoUlj5ISIiokgJNeB5YMo57BurYv8teikhCx2yLKteosoSYHkmA342aw5+NmuWksqTLT9EREQUKaz8EBERUaSw8kNERESRwsoPERERRQorP0RERBQprPwQERFRpLDyQ0RERJHCyg8RERFFCis/REREFCms/BAREVGksPJDREREkcLKDxEREUVKqI1NiYgSpeDYPja+e+IEldendqqNU0X/bVZoigKvecuGg2z80d2Hq7yGUz4vy2MSRU5aZisbrzmng8pLHbjJxqe2W6DyVu5sZuN5GzMDr/9+76dVemJedxu/d96hKq9wwXchnrjs2PJDREREkcLKDxEREUUKu72IKOHSstur9MIbW9r42RNcV1fv2vrvryIYF5vC0Pe7vcVXNh48ooXK2zYl9GWoGJsu76/SM28db+M+o6+2ccuZ+eq45YMa2PiLIQ+ovKap9W18wItXqLxOf3JlaXbvLsMTU0kKjuuj0k88OdbGDVLExi1SG6BM2pWUWU+lRjVbZuO3M45WeamoWGz5ISIiokhh5YeIiIgihZUfIiIiipSkHPOTtn9rG+cerTsQN5++I/C80zvNt/Hy7c1tvOjtHHVcxrw9Nq7z1qwyPydR1Eit2i7Ro5PKK6rrfpwc9/gnKu+1pq+Euv6Wol02PvSt61RenQ1uFMCI37yp8n7XZIWNj2mup8i+ldHFxoUbN4Z6jqhLbdzYxnXPXK/y/OOyZt04zsb35XZTxy3+6GgbbywyKq9hihvP9e2541VeTgM3BqjlR3rkR5PJXLagvGpt0eOoJuUdZuMFW9009fw9ddRx36924/ZqrdR5vm8JZM4oUFn1F+faeMX5rVReu3e22jhtvp4+r79jEo8tP0RERBQprPwQERFRpCRFt9ePN+gVWe+5/Ekbn1Bve+B5KRCVLgpqKLt6WuA1Jm7R3WqFvvrgfzfqZtyF0/WKl36tP3VNfQ3m/Rh4XLzCjW7VTE7rdFLTm9h445muHHafkaeOm9f3eRvHrwDce9YQGzea3FjlNXzpi0Q8ZuSsvv4QG381YlwJR4bzcv5+Kn3TJ2fbOOeKmSrPP32+39ClcVdy3SMf5HZWOezqKoMWbtjAhwe+FOqUG5ov1Omz/Om6oW/9/SluKYQPjtXn3bnzEhvXf1V/f8BUdEdJzWBmf6PSs884wMZ5E1yVYFh73cU47t1BNm5972eh7+dfsKLtnctUXlWWGFt+iIiIKFJY+SEiIqJIYeWHiIiIIiUpxvzUytc9fyWN8/Hr9tRVKt1ksYvz27rxQN1O/F4d93wHNwbosiarVJ5/3JB/+iwAoGPws6QMc/cLHHtUjEPuH2HjVmPC96PWNKnd9TiNonFuqfx72j9m48veGa6OO36sGwOQlqfHTBWc4sYNHX7TDJVX62Y3PuiDe/WYs8bPczptkB2ZwTurh3XM/HNs3OgPtVVezjezA89bN9AtgdGzdvDi95t2NlTpOtgUcCTtZfr3UumLJr2R0Ouft/TXKn1I+kobx48V8jum3i6dHveIjY81V6q8+q9yHF8YRUcepNJ3/Mv9fO1cy32++z46Sh2XNcGNFQq/8UzyYssPERERRQorP0RERBQpSdHt1fKzvNDH9vyn6ybKuiO4m6ipL86/R7/MQS1PC3Wvtafpnam3tw11WolqbdXT89s8MsfGUZuo6Z/OfsVrU1XemBUDbXzf+RfYuNOs4Kbt+PevrdscGgv/pQtv2f3pNj77xukqb87CHjYu+npR4P2iqPFS9/fSXzbo3aFvznBdi4MXn63yFn3XxsaZH7hrFH2juxi3Du7nEvqjgqEj3wl8Lv/K0BibEZe7IvA88uxqqVfsTU8NXkm/TNe/tJFKf9DS7RS/Z7zuwrx5v/kI48Cbv1bppa+W8eEiIO8i934/cvtYlefvQj7kvmtt3PYh/fu1JnR1+bHlh4iIiCKFlR8iIiKKFFZ+iIiIKFKqbMxPWpbbVqL/v74KPC5+imT2I27H5rB9kKZA7zJbsCbc9hMZE/Rx8SMJEiFK43xS6tdX6bTX69n4lgVnqLzWQ9xU2KLtP5T73gU/6CUN2p3j0p+c1l/lbb3D7TSceUkzlVeY+1O5n6U6a/mwGwcw7yk9jmNw66EusV5vKVH/9+5Hzabe7rt+wJ/09e9qoXf4DuvQqW4H+Jy3ZpZwJBVn/aF63E2P2rm+VD2EsblIT0sfMOkPNs5e+aXKS1nstjl4Z43eRijsmB8K74DLv7Vxnzpxy0t8NMzG2Q9FZ7kVtvwQERFRpLDyQ0RERJFSZd1eBRPdSpI3/WKFTzfHde2EA1TO9otcfa3Zd3sS/lz1PnbNg0XbtiX8+lG29Em9RPaijpNsfPohJ6u8gu3hVvkOK35V05Unuab8rD/r1Z/r+mbd17TpnYkU//lIqeeWExg+U3dln97g/Qp9lm73b7BxQQnHkZPaMdvGWf10t3DL1HBdXWsLd9r4pAl/VHlZf3ddKFHq3k8GRUf1Vunnsp+y8Q3rdF6HoQtsHKVyYssPERERRQorP0RERBQprPwQERFRpFTZmJ+3u7hdg0vaBf3j+4KnvqbErX9fmt3Ug67xQr6b0H7L53r6ddcb19i4YN36Ut8ritL2dztxLxowSeX1fvBqG2euTfwUS+nd3cb1/rZO5dXZ3DT+cCql1O6dVXrIi9NsfHqDzeW+/vg8Pd5v1S637MA9rWapvCW/zbRx1i0ryn3vmmj53XpJh4793BISb+ZMjT88lC93t7Jxm7+H/wzvHNTXxld2eKVM9z63mV7S4I/DLrdx06dnxB9e46U0cktPFNySG3jc6+/2U+nsgui9VwBbfoiIiChiWPkhIiKiSEmKXd23xK0MetbCC2289utWKq/pAgTa7Ho5kNlrXeBxx7T63sa3xq0mem5DN2X23OMfV3nj3utk4yde0CtPt70rOitjlsb3I9vbeNZu3S2Z+UBi37Nt5+vm3HNvedfGV6UvVXk9HxuR0HtH0eZeuuvQ/9kJa85unR7y2lU27jhK7/ie2t19/vAf3e314bD7bHzhByNVXtr7c0r9XDXRrwfOVukHMj8PODK8B2+4wMb1EH5l7dyu7lfP4EZlG0Lw3rbuKp3xnuvGi+JyBxvP72Hj2d0eVXnTd7mlZTreE/dLtLNbgmRrj+Y2Lqyth4SsG1CEIJkfuXaU9Pf1z9rCjRvjD08KbPkhIiKiSGHlh4iIiCKlyrq9Tu97qksU6ea0emuX27gDliOssPN3ZjZId8+RfmrgcUuubK/Sc4aNsfHg381TeYdnX2vjnEt183KU+Gd3AcBDZ7mVRS/7eqjKa434lb2LJ2m+b9MD9QyjXaPdStBrV+rG7of/d6KNs06arPKyx7qmX67iXDaNftDd1VevOdLGHyzrpPIyXnYrBtfNdeVUZ32+Oq7jgrJ1xeznW5F45a/1xo0dKnZx6aRWeMzBNj644ZsJuebIH4+wcaO5bnhBSV1NqRl6W2jTZ2vAkeH9tKeBSofdsLqmyusSPNs5PcX1L58wY6XKO7GBGx7QtbbefDq0QS58a0ddlTVqzrk2PuByvUl1Yd6Wst0vAdjyQ0RERJHCyg8RERFFCis/REREFClVt6t7FfbPFvl2DC8qYffwrD/rZ+y7Y5SNH7lkgsqbdvxYG1/b9WKVV7hocVkes1oqbKFHXp1Qz72/t9XSowLixwf5Lbs0y8a/OtXtEJ5eS0/T/GS0m97e+eUvVd6BX+yx8Z+fukjltcnj0gTlJdPnqvSKfqk2zi6ahzBKNd5q/SYbxq/+7F/KoFd//XnTe89Hy7Kzatl4SKO1ZbrGDesO09e8wr33ZkUJa4/4zxmvlyyZ339SmZ6FyqZ77Xq+eIXKW13gxtxO3OLK6W+zTg68Xq2VdVS6oL0b//fBUeNU3ncDnrFxh9uvUHmdRpZ/uYWyYssPERERRQorP0RERBQpSbHCc3XR9m+uq2TEbt18N3/UIzZedI3u+sn5fcU+VzJJWb5apc9c7JYS+Kz38/rYmW4F0fhNaQcuONvG3955oI3rTtWryDaCazZddo/euLHtHtf10n7iEpXH6e0VoKiC39UC122aX1i3hAOja9epfVX63dMf9KXK9p5NX9tBpZvNCdfVlZbtlgqZ0GdyCUeWzaz17VS6Gb4PODIast9w09l7ZA9RebuWu01PW+jF0ZE+262wXbjELS3TCXoYQVi/z7pApV/8dIqNz/mV7uaaW6Y7JAZbfoiIiChSWPkhIiKiSGHlh4iIiCKFY37KSOJWEi80brrgiAHvqbxpaFwZj5QU4pcrNye76eZ9L9E7qRvfpsGZE+eqvDo7VvhSKxAkrW0bG78z+D6VN2j8H23cej2ntlel1E6+cSO1awUfWIIdWe5zdGPz4D0rvvyyo0p3wqaAI2uewrp6J+7stLKN89lUuNPGuz7aLy433NiaFfc3tPERdfeUcGTZ7HfeGpUO3nM8GlI+ckuC7P9R+PMSPVJvR+cWKl1HyvZ5r2hs+SEiIqJIYeWHiIiIIoXdXhVg3MfHq3QOZgYcWfMV7dhh4xb/CO56KmuT9Y4n3arCJ3yiu9U6Peje9+D9jqk05FC37MBP3RuqvNxe7l3u2lvv3jwu263y2iatHhJtdG53d++H9ErGJe02XtM0maVXpf/VPLej9sc9p8QfHmhVoVvBt/V9wZ/btDb72zj3MV2uf+2YmF3k/Tq/cqWNc3bPSfj1qWwkzVUlhj/8qsqrJe5n9Etf91F5nVB1ZciWHyIiIooUVn6IiIgoUlj5ISIiokjhmJ9SSKlf38YF/bcGHtf8y9TAPCqf3accqtLPdB5j46FjR6k8UxCl0R4Vo/CYg1X6romP27hPnfijgy342U13PW/u+TZOTdGjvbo03WDjx9p+GPr6OXXdOJ9Plkd364uCH1ap9Jbph9t4dbedKq+ksVfDx420ccujdgYed9x4N6d6RNPFoZ+zJOt90+yPef4GlZdzq9tygZ/vqpOa3kSl97zslqEY0mi2ylu6J9/GXe/IVXlVWYJs+SEiIqJIYeWHiIiIIoXdXqWw7KZeNv6m/z9U3ujcbjbOePEblRf1lUfLKzUjw8bjxz+s8k58wq3i3O5NruKcaM3u1FPW/V1dX/2sv7MvG+u6Smrl68UFWrzmVgVuusl1j6iVnwHMPsNNpcd1H4Z+zmPruSne916md5Vu/viM0Nepadre5T4TYwYdo/IeyPw8/nBrzvXjKuyZ4p2w8CyVXjujtY073KrLjktWOCkNGqh0nu+zkz51gcor2rat3Nf/uV8XG2/9Q57Km9H15cDrXHzd9Tauv+yLUj9HRWHLDxEREUUKKz9EREQUKaz8EBERUaRwzE8c/9Te5afVVnkzzrnfxinQ00Qf//QoG+dsi+52FhVh6Ui3S/fc3W1UXvZzbopzoncnjqrU7p1t/GC7iXG57vv+vPeuVDldJ7lxBkXZupyWjswp9l6XnzlNpUuaLl3kGz23vnC3ymsk7u+4VkNWqLw9j4OS2LYprVW6fYTHaJXGkr/2VOnFFz5q44GXnqbyNj/Xw8YtPl4feM1vr3HjK/90/FSV97sm0wPPW13gprMP9I3DBIB2ryXPOB8/tvwQERFRpLDyQ0RERJFSY7u9/Ksxo1N7lbf0vHQbjxr0hsr7TSM3xbNJil4pdnWBm2h5xOSrVF7XB5bYmN0v5edfQfS6s1wZ/X3Seeq4Nks4vT3h1m+y4ZStumnd3y31/ckT9HknJ/Yx4qfSX3HvNTbOeFR3jRQdeZCNa+VuT+yD1BBLTmmqv/Bl8cdVhL9s0Lt5v/pufxt3eGqWyuN09nA6PrtZpU8/7Nc2npLzospremd9lNYeo3+TXbv2MBu/Mb+XyssZ/7ON282qHj+T2fJDREREkcLKDxEREUUKKz9EREQUKVU25ufn/7pxOPm79fbQ8u/moa6xsZ/uk2zfwe0IfVDz1Ta+r9WzgddIgaj0Zt8wg5xpl6u8Lg/vsHGHuXrMAcf5JNaie93U6EvTvrJx+9c3qeP4vide4Sa38/LLtw9UeYMfmGfj/VKDdwWP9+BPbmn8x+YeGXicyXPLS3R55CeVl7EoeAp0yqdzbczvieIV5W1R6cPmuG1AvujzXMLvN3DB2Tauf4HeXiF7kytLjvEpm6J536r0brfaCi7oeYnKW3FmMxuPGaqXr6gre2z8249/a+O2r6eq4+q95pZw6YQ5Kq86liFbfoiIiChSWPkhIiKiSKmybq/3ur1q46L4RrPe4a4R32Xlv87ygl02vn7t0eq4adMOsXHmpwUqr8HcVTbOWTs77vpUUeJ3EB599BQb3/bYhTZuvbB6TKOsKRq+pFdnvfil4C6rsDriq30fBHZfJZrZrVfFzrzSrcp7bF+9Wvdlf3e7dA9uFLwi8KH3jLBxozW6xBp/tc7GBb6uVKp48V1i7VxvNcbc3jXwvJy47qyajC0/REREFCms/BAREVGksPJDREREkVJlY36Ou8xNIy+op+tgG87ZaeNmjXeovIuz3BTJB+cdp/JSFzW0cfu33dRKM2u+Oi4LwVNmCwJzqCKtvlIvl96l9v9s3PZtN72d40CIEqNg9Rob1/fFADD5lTYuRhsEaYngMXj8WUrJjC0/REREFCms/BAREVGkVFm3V5233U6+deLysl9GoFeR4Y7DvMDjquOKk1E26cqHVPo3L1xn4+yFwd2UREREpcWWHyIiIooUVn6IiIgoUlj5ISIiokipsjE/RH43Z/dV6ewSliMgIiIqD7b8EBERUaSw8kNERESRIsZwUjgRERFFB1t+iIiIKFJY+SEiIqJIYeWHiIiIIoWVHyIiIooUVn6IiIgoUlj5ISIiokj5P0TtsIzRFepPAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 720x720 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "xbv,ybv = next(iter(dls.train))\n",
    "logits = model(xbv)\n",
    "probs = F.softmax(logits, dim=1)\n",
    "idx = 5\n",
    "_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
    "for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
    "    ax.imshow(im)\n",
    "    ax.set_axis_off()\n",
    "    ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'classifier.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = cnn_classifier()\n",
    "loaded_model.load_state_dict(torch.load('classifier.pth'));\n",
    "loaded_model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaPklEQVR4nO3dd3hVRfoH8O+bRopEOpEiCCSAiiKIgg1RdBXrsmtDEVQUxMK6Cuy6P8W26C6yrqhrxQa6CIrYXRFsKAKiSBMQpKj0SO9J5vfHOZk5c5/ccJLcfr6f58nzvJM55565d3Jv5k47opQCERERUVCkxbsARERERLHExg8REREFChs/REREFChs/BAREVGgsPFDREREgcLGDxEREQVK3Bs/IrJKRHrGuxyxIiL3iMj4eJcjWlifqYN1mVpYn6mDdVlzcW/8VIWI/ElEfhKR7SKyVkQeEZGMKF+zv4jMiOY1PNdqICJfikixiGwVkZkicnIsrh0PItJDRD4RkW0isipG14xZfbrXe0ZElopImYj0j9V140VEskRkiYj8EoNrxbQuPdftJyJKRAbE+tqxIiIfiMhOz89+EVkQ5WvG8rP21JDnt9Ot0z/E4vqxJiKdRORz93luEJEhUb5erD9n00XkAbddsENEvhOROpWdE9HGT7QbIgDeAdBJKZUP4GgAxwK4NcrXjKWdAK4F0BBAXQD/APBODF7XCsXgursAPA9gaJSvE0/fAxgM4Nt4FiKGf0NDAWyM0bViTkTqAvgrgEVxLkdU61Mpda5S6pDyHwBfAZgUzWvGklLqi5Dndz6cz98PY12WGHyBbwDneT0NoD6ANgA+iuY14+BeACcB6AYgH0BfAHsrO+GgjR+3e+2vIrJYRLaIyAsiku3mnS4iv4jIcBFZD+AFEUkTkb+IyAq3B2OiiNTzPF5fEVnt5v2tKs9OKbVCKbW1/KEAlMGpyIMSkdYiMt297mYRecXbMhSR5iIyWUQ2ucc8LiLtATwFoJvbYt7qHvup91tfaCtXRB4VkZ/dHqq5InKqz+e3Vym1VClV5j6/UjiNoHqVn+lfgtXnbKXUOAA/VeN5JHx9us/xCaXUNBzkjVgdiVSX7vlHALgKwINVPC8p6tL1IIAxADZX8byDSrT69DxOSwCnAhjn8/hkqs9y/QC8rpTaVc3zLQlWl38G8D+l1CtKqX1KqR1KqR98Po+Er0txvpD8CcD1SqnVyrFQKVWzxo/rSgC/A9AaQBGA//PkFcD559wCwA1wemIuBtAdQBMAWwA84RbySABPwmmVNYHTCm3meRKnlL9Q4YhIHxHZDufD51g4rVk/BM4HVxMA7QE0B3CP+5jpAN4FsBpASwBNAUxw/0AGAZjpfkOo4/NacwB0hPO6vApgUvkffgXPZ76I9An9HZx/lm8DeE4pFelv0glTnzWQNPUZZYlUl48BuBPAnio+h6SoSxE5AcDxcD7YoyWR6rPc1QC+UEqt9Hl8UtSn5/e5AP4I4CWf1/QrUeqyK4DfROQrEdkoIu+IyOE+n0My1GUHACUA/igi60VkmYjcdNCrKaUq/QGwCsAgT7oXgBVufDqA/QCyPfk/ADjTkz4MwAEAGQDudl+c8rw89/yeBytHBeUqBHA/gIKqnuuefzGA79y4G4BNADIqOK4/gBkhv/sUwIDKjgk5fguAY934HgDjfZQvG8AVAPpV5/klU30C6AlgVQ2fV6LX5wwA/VO1LgH8HsCHnmv/kkp1CSAdwDcAulV0nVSrz5ByLa/J324i1mfIOX0BrAQgqViXAJYB2AqgC5z/K2MAfJkqdQmgDwAFYCyAHADHuOU6q7Ln4nes8WdPvBpOK7DcJmV3L7UA8KaIlHl+VwqgsXuefiyl1C4RKfZZBotS6kcRWQTgPwB6H+x4EWkEp9JPBVAbTq/XFje7OYDVSqmS6pSlgmvdDmAAnOer4IxBNqjKY7iv6X9F5AcRmaeU+j4SZXMlXH1WVbLVZxTFvS5FJA/AP+F8wFdZktTlYADzlVIzI1GOSsS9Pr1E5BQ4vRSvV+GcZKhPr34AXlbuf9IISpS63APgTaXUHAAQkXsBbBaRQ5VS2yo7MUnqsryn+T6l1B4A80VkApzPo6nhTvI77NXcEx8OYK0nHfoH8zOAc5VSdTw/2UqpXwGs8z6W291Y32cZKpIBp0vRjwfdsh6jnAnTV8Hp0isv8+FS8cSzit4QuwDketIF5YE7TjkcwKUA6iqny2+b51pVlQmgVTXPDSdR67MqkrU+Iy0R6rIQTrf3F+4chskADnO7oFv6OD8Z6vJMAL93n9N6OJMrR4vI4z7OrYpEqE+vfgAmK6V2VuGcZKjP8sdoDqcn5mW/51RBotTl/JDrlcd+XqdkqMv5lVwzLL+Nn5tEpJk4E7DuBPBaJcc+BeDvItICAESkoYhc5Oa9DuB8d4wyC8B9VSgDRGSA2xItHwf9K4BpnvxPReSeMKfXhjObf6uINIW9wmg2nD+wh0QkT0SyxSwx3wCgmVvecvMA9BaRXBFpA+C6kOuUwO0OFJG74bRg/Ty/ruWvjYjkiMhwOC3/WX7Or4JEqc80d0w300lKtvd1Tvb6dJ9DlvscBUCmW5ZIrrJMhLpcCOfDuaP7MwDO69wR7jfWFKjL/nDmPJQ/x2/grDCp9kTiMBKhPuE+Xg6ASwC8WEFestdnub4AvlJKrajieX4kSl2+AKfh3lFEMgHcBWe4aat7raSuS7fuvgDwNxGpJc6E68vgzEcKy+8L+CqcpXE/uT8PVHLso3Am6n4kIjsAfA3gRLeQiwDc5D7eOjjdZ3o/EHH3XqjksU8GsEBEdgF43/2505PfHMCXYc69F0AnOK3J9+B8O4VbrlIAF8BZObbGLdNlbvZ0OMta14tI+QqPR+CMuW6AM0nuFc91/gfgAzjjrKvhTFz2dn9aRGSRiFzpJmvBmeRWDOBXON125yml1oY7v5oSpT5Pg9Nl+T6cb0Z7YC/BTPb6hPt89sDpKXjGjU8Ld341xL0ulVIlSqn15T8AfgNQ5qZL3cOSui6VUltDnuN+ANsPNmxQDXGvT4+L4dTJJxXkJXV9elyNyE90LpcQdamUmg7n/+R7cLahaANnnky5VKjLK+AMHRa75bxLOatsw5KDDXOKs/ncAKXUx5UeGGci0gzAJKVUt3iXJZGxPlMH6zK1sD5TB+sy8cVl87xoUEr9Amf2OaUA1mfqYF2mFtZn6ghyXSbV7S2IiIiIauqgw15EREREqYQ9P0RERBQobPwQERFRoPia8HxW2iUcG4uzqWWTIrKpHusy/iJVlwDrMxHwvZk6+N5MLZXVJ3t+iIiIKFBSZqk7ERERxV7GES10PGK6fRu4y9++RceFQ76OWZkOhj0/REREFChs/BAREVGgsPFDREREgcI5P0RERFRtW59M1/FxWXafSq3fErOPJTFLRURERBQlbPwQERFRoHDYi4iIiPw7oYOVHNN2rI43lO638lq+tVXHZVEtVNWw54eIiIgChY0fIiIiChQ2foiIiChQUmvOj2ccctk1OTp+85zHrMM6ZGXqOF3s9l/v5WfpeMddzay8tM++i0gxiYiIklXOqI1W+pgss9S9y0PDrLzG876KSZmqij0/REREFChs/BAREVGgJPWw17rbT7LS9w16Wcfn5W7T8du7GlrHzduXreO0kMV3r7X+UMfH3tLPymv+WfXLStF3+/JFVvrMnH06Dh3ePK/zOTouWbc+ugUjIkpym2/opuOPW4228qbtqavjJhOXW3ml0S1WtbHnh4iIiAKFjR8iIiIKFDZ+iIiIKFCSbs5PRkFjHY8YON7K887zOXL8zTouHLXMOq50c7GOJTPLyps4tYuOJx3/rJU3oM9tOs5/9euqFJti4ADSrXQZlIlVyMizUqDgKb6um53uWqLjomf32gfPXhCLIhElJO8cHwCYe8+TOl60X6y8Mef00nHphpXRLViEsOeHiIiIAoWNHyIiIgqUpBv2WnFjax1fmPe+lXf24t46LnzYLLfzDnOFUgfsO9AuXnOYjosK7SGx4gt36zj/VZ8FpqhKyzbbFqQj/FDWUeNuttKtiudGrUxBlV6/no53n9Daysv5zGxDULZ7N2Jp37lmKHtLD3toa0mPZ3T8SY9DrLxH2rSPbsESmLcuMybbn4OFh5jdfRfcbHbV//X0POu4Wt3M5+6szpH/wDzlL/Z7us64mRG/RtDIcUfp+OO77OXspcp81t74pyFWXs7y2dEtWBSw54eIiIgChY0fIiIiChQ2foiIiChQkm7OT+YuE/93R2MrL+eKnTqubJ6PV8mZna30hFOf0vEPB+zl0S0eZ1sx3tLbtrHS+c//puMzc8LPJclfYadD53pRzf3at52OZw991Mq7d6N5n809LrrvI+l8lJU+aeQsHY9oZM/1+mC32ZZ/xONXW3kFSMy7UUdDxmEFVrp4rJn/9Fmb/4Y97+Nxi3XcM2eHlZfm+W5dFnIboUjY3Xubla4zLuKXCISMpk10fM74GTo+JK2WdVzhmzfquOhd+32UjBuH8L85ERERBQobP0RERBQoSTfs1fRfZknda5NOtvJKN6/29Rhl3Y/T8SNjn7Dy2mdm6vjoGddZeS1nzPNbTIqg9MaNdJzxzE4rb1zLqbEuDoWxp3H4zm/vcFOvMwbpOGN6ZLYc8A51DZ7wppV3bq4ZjgkdfBk212yP0Wb8EisvUe9GXW1i78qb3qCBjre9kGvlfdYh/FCXV+hQl19rS/bpONMuFhqm1wJFT+hdDZY8ZKaPTKnzjo7/UWwPHxfebIaPk3GYKxR7foiIiChQ2PghIiKiQGHjh4iIiAIl6eb8qBJzF+aSleHn+EgtM27849gjrbxZ3R8Pe17PQbfq+IgP59nX9ltIiqw6+Tp8o81rvk/r+u0VOm48ZbmVl3LzOeLhhA5W8o6L39JxWsj3qk/2mK3xIzHPR53c0UoPe9mscz4t297G4IAytd1r8aVW3hGXz9dxKv5NeLeGWDqogZW3+NLHavz4M/eaz9lr37/ezvTO5Qn58Gw90dTR6vOyrbwFfcf4unb69Dq+jiPbxuvs7V2WnGH+H64sMbd/mXmeva0I8Es0ixVz7PkhIiKiQGHjh4iIiAIlIYe90vPNMIfUr2vlbTzd7EZZd6m9o++P15pl6k+f/pKOe+SE7tRqulmvXHm2lZP3mVnuWspdgBPC2rMbHfwg2F22ALD/c9PNX7ppWUTLFFQlZ5gu8xNHz7Hy+uWbYei5++zvVSPuMttG5OPral17x2VddfyXB1628k7JNnUfupzdO9RV6+xV1bp2stpVaO7OXpVhruHru+l4wZYmVl7m0No6Tt9qtp4oXDkLfpWe3sk8RuudlRxpaz/Z3Mm98An/1wu67X3Me2fOXfb2Lr+W7tFx/+F36Lj2z9V7nyYL9vwQERFRoLDxQ0RERIHCxg8REREFSkLO+Vl1y9E6/n5wzZdjVuaVIz6y0kM/OVHHU1fZ23vXG2/udJz7JsebY+WWwZN9Hdf7yaFWuumo4NyVO1aGPG1ue+C9bQRgz7X5YMcxVl7dj8ycq+ouKd/TZ6uOf5e7Lexxz21rZaWDNs/Hr6UHTE3cvuISKy/rWhNnrF5j5XlXrZfAH9XtWCt923Ov6riyW2R4l9IDQNvntuu4rCwVNyeIjuxr1um4LGTfge7Thui4aEJqz/PxYs8PERERBQobP0RERBQoCTns1XJKsY779epp5eVl+Ft+/tUU083asMdaK2/1yoY67tlxsZV3eQPT7Teqqz209V0n07F/u9xs5eVO5jBYvB3+xnorzU7xmts/tYWV7p4z05Oy7w5998YuOl54vr08urTYfg/6sXbYSVZ6Ssd/elL2cIh3qOu9i08IeaQVVb52qshbZj5LO7x0q5VXMNu8Q3KmzLby/A5nVWb/OebvYfPAXVae37vBX/PJtVa66Ptval6wANhwi/3emXuk2cX5gc32zuxtBy3Usd+7GGQ0b2alSxscqmMptT95y+YvQSJizw8REREFChs/REREFCgJOexVumipjotPtvOK4U8zeFb6PGjnFWGVju11DMDDx5hVDzcMy7HylvR4TsejR9s3R72j7CYdh3YhU9XtO9d0mbfOGhvHkqQ+ybSHr9SHZkftj9qFrrQzxz6zraWVM+84byr8MJf3Zpt7W9Sx84Zv0PG37eyVnpliVlt6b1YKAP8Zf4GOmy3jKr9ypcvMkN8Rd8Z2+G/jcWbH/W9PeKmSI21HTrxFx22HfWvl8ebS4aXl5em4zaXhd7Qf/0F3K91qnxnKTm9opoRsvNi+sam60Pz3HX3k61beydkHdLxb2VNTRmw4VceLhhxt5aXNmBe2nNHGnh8iIiIKFDZ+iIiIKFDY+CEiIqJAScg5P/HkXZbX9u+FVt60rrk6PjPHvqP8ukvNOGerKdEpWyo7cPbxVvry0e/r2DueHKr9pwN0XLQhuEuaa+LnO+zX/tt2j+o49A7pXltK8qz0+iGe5bViH1vrrE06/rLjBM/jh79CaM4Bz4SP0PlGLV9apeNILNOmqksvtHfWHt5voq/zQndxbjfG7EZccsDf1iYErP6z2d7l+1ahd0Ywb8jWr223cpY+abaGePdc894vCpkLmOZ5jDn77NlXnWb31fGzx46z8kYVmG1g5r5sbwlzX5ezdFxa/BtiiT0/REREFChs/BAREVGgcNirEqU//GilB7/fX8dL//AfK49LMGtmW8tMK33doaGbEDiWH9hnpetNzdZx6fbtoYeTD2MGPF2t84bXX2Slhw5b4PPM6n3nGunZmfbdR+3luvV+nRl6OMWAd6ir+2S7/q+obbYtCB3C7PC5Ga5u+qL93s9ayV2cq0N5hprTQsad08W859571x6W8tpeZv6TtZ12vZXX5C1TT3lv2MNXTWDulHDN3bdYeQsHmm1humbb733JsofWYok9P0RERBQobPwQERFRoLDxQ0RERIES3Tk/aek6TD/EXhabjPMzCr7yjKP+IX7lCJo1JXt0fPkjw6y8ghd5K4OaGjSrr5Ve2P3ZqF5vxEZzH4wRjeb6Pm/aiFN0XG8K5/gkglWXFuj4rXqTrLxMMZ//H+3OtvKajTX/ejI/5hyfSGg8x2wJUjYwZBaqMrOuykJmqM7eZ/6vDRt+m44LJ9nzevw6qdd8K+293oA1p9p523dU6xqRwJ4fIiIiChQ2foiIiChQojrsVdrd7Dj5t7HPW3kjL7lSx2quvWQ2UaQXtbbSw+4fH/ZYCZtDNbVgv+laL3iUw1yR1qrPPCt9Wt9bdbytjf2XnXXsFh0PLJph5T026QKE02qM2Tn9h4fM++reXt+FPafXkoutdM6U2WGPpdhZc7fZyfv96/6p4zLYOzV7h7pGDulv5dX6eE50ChdguV8u1fF1a3pYeSOamB3zD8/IsfKOzNyr4y1tzVBl7eOO8n3tH6+ureMXmzxs5a0sMZ8hv9xh/09N2zXP9zUijT0/REREFChs/BAREVGgsPFDREREgRLVOT8/9TbbYXerVWrlbeySr+OG/le7Rl1athmnXn9GIyvvvNxtYc+r90FO2DyqmHjuGryzRRwLQpY648wy8jqVHPc26lvpFgg/H8v77q/buKGOQ+/q/t0+830s61r7MXi39vhIy7O3KZl8rZnT0SSjVujh2sDp/XVc9B7n+ESbd/uYTde3s/KmTSrS8VW1V1l5h6SZOvz+Rs/d4G+0H997y4zQ5fJeB5R9y4rz7x+i4wYzEmeLCvb8EBERUaCw8UNERESBEtVhr7ZjPbs39rbznho+Rse3b7jZyst9s3o7S/qV3raNjpcObGDlnXaSWXb/dvPHEU7nOVdZ6WZvmfNKQw+mCqU3Mq/9wv7hX2tKftLZLJud1fllHYfe7bvPOzfpuHD119EuFoWx+YZuOm7bb4mV1yozM/TwChXdwKGueClbaNfZG+3NFI6nBl1k5eX3XqfjqUe94evx5+6z01d+OUDHh71lD3s1eD1xhrq82PNDREREgcLGDxEREQUKGz9EREQUKFGd85O2bZeOp+3JtfLOzNmt43//6zErb8JdJ4Z9zNe/7qLj+t+lhz2uuLOZeXNUu5+tvAFNP9RxZcvXlx3Yb6WvHvlnHTebuNjKS8a71BPFyu7meRX+/pM9h1jpdk//pmPOnYudjJaHW+nTB5p5lyMLws/BXH7AbEBw6/LL7MfEmgiVjiKp4VMhc3CeMuH56Fytx2yD8LepSVTs+SEiIqJAYeOHiIiIAiWqw14lK1fr+LGzzrHybh7cRMev/fFRK29k42/CPubIizx5F4U9zLejPre3kc2fbobnGn220cprsMx0F7JLPnae7OvdJ2F+3MpB/qXn51vpEaPH6jhTzHD14P/1t44rXBzdbS7ISG9odtpOf8ke4q9sqOuXErPOefiqS3Sc0ZPDXJQ82PNDREREgcLGDxEREQUKGz9EREQUKFGd8+Plnf8DAK2HmvQd026y8lZdaO4eO+uCR6y8UZtO0XFlc4M6ze6r453r7OW07f9drOMjfgyZQ6LM3Wo5rycxpO01y2lDb4dACaqgoZWsnbZXx5tLzfySptNiViIKsfx2c5ufBW3GVHKk7bIHhuq4/rOJeesCooNhzw8REREFChs/REREFCgxG/aqTNaH9t1/i8wGzOg7+OSQo82wVGW7UTbB4rB5HM5KDKWbzfBju4n20GdZnqml9j8tjVmZKDJKl62w0utLDtXxw7+aO4bnTubS9ljZeYm9c/70PqM8qVphz7t/Uycr3XDCQh1zGJqSFXt+iIiIKFDY+CEiIqJAYeOHiIiIAiUh5vxQMKl9Zpv8Nrd9HfY4ztFKfk//7mwdq23b41iSYCk7paOOH/zH01Zew/Tw83yOef5WHbceZ9/mp2zHitDDiZIOe36IiIgoUNj4ISIiokDhsBcRRV3JT6viXYRAylq7VceTtxxv5Z1YYHZnHlXcwcpr/epmHYduW0CUCtjzQ0RERIHCxg8REREFChs/REREFCic80NElKK8c61+CLkb0IXoUsmZP0alPESJgj0/REREFChs/BAREVGgiFLq4EcRERERpQj2/BAREVGgsPFDREREgcLGDxEREQUKGz9EREQUKGz8EBERUaCw8UNERESB8v9D4PCrWQgE7QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 720x720 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    xbv,ybv = next(iter(dls.train))\n",
    "    logits = loaded_model(xbv)\n",
    "    probs = F.softmax(logits, dim=1)\n",
    "    idx = 5\n",
    "    _,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
    "    for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
    "        ax.imshow(im)\n",
    "        ax.set_axis_off()\n",
    "        ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaPklEQVR4nO3dd3hVRfoH8O+bRopEOpEiCCSAiiKIgg1RdBXrsmtDEVQUxMK6Cuy6P8W26C6yrqhrxQa6CIrYXRFsKAKiSBMQpKj0SO9J5vfHOZk5c5/ccJLcfr6f58nzvJM55565d3Jv5k47opQCERERUVCkxbsARERERLHExg8REREFChs/REREFChs/BAREVGgsPFDREREgcLGDxEREQVK3Bs/IrJKRHrGuxyxIiL3iMj4eJcjWlifqYN1mVpYn6mDdVlzcW/8VIWI/ElEfhKR7SKyVkQeEZGMKF+zv4jMiOY1PNdqICJfikixiGwVkZkicnIsrh0PItJDRD4RkW0isipG14xZfbrXe0ZElopImYj0j9V140VEskRkiYj8EoNrxbQuPdftJyJKRAbE+tqxIiIfiMhOz89+EVkQ5WvG8rP21JDnt9Ot0z/E4vqxJiKdRORz93luEJEhUb5erD9n00XkAbddsENEvhOROpWdE9HGT7QbIgDeAdBJKZUP4GgAxwK4NcrXjKWdAK4F0BBAXQD/APBODF7XCsXgursAPA9gaJSvE0/fAxgM4Nt4FiKGf0NDAWyM0bViTkTqAvgrgEVxLkdU61Mpda5S6pDyHwBfAZgUzWvGklLqi5Dndz6cz98PY12WGHyBbwDneT0NoD6ANgA+iuY14+BeACcB6AYgH0BfAHsrO+GgjR+3e+2vIrJYRLaIyAsiku3mnS4iv4jIcBFZD+AFEUkTkb+IyAq3B2OiiNTzPF5fEVnt5v2tKs9OKbVCKbW1/KEAlMGpyIMSkdYiMt297mYRecXbMhSR5iIyWUQ2ucc8LiLtATwFoJvbYt7qHvup91tfaCtXRB4VkZ/dHqq5InKqz+e3Vym1VClV5j6/UjiNoHqVn+lfgtXnbKXUOAA/VeN5JHx9us/xCaXUNBzkjVgdiVSX7vlHALgKwINVPC8p6tL1IIAxADZX8byDSrT69DxOSwCnAhjn8/hkqs9y/QC8rpTaVc3zLQlWl38G8D+l1CtKqX1KqR1KqR98Po+Er0txvpD8CcD1SqnVyrFQKVWzxo/rSgC/A9AaQBGA//PkFcD559wCwA1wemIuBtAdQBMAWwA84RbySABPwmmVNYHTCm3meRKnlL9Q4YhIHxHZDufD51g4rVk/BM4HVxMA7QE0B3CP+5jpAN4FsBpASwBNAUxw/0AGAZjpfkOo4/NacwB0hPO6vApgUvkffgXPZ76I9An9HZx/lm8DeE4pFelv0glTnzWQNPUZZYlUl48BuBPAnio+h6SoSxE5AcDxcD7YoyWR6rPc1QC+UEqt9Hl8UtSn5/e5AP4I4CWf1/QrUeqyK4DfROQrEdkoIu+IyOE+n0My1GUHACUA/igi60VkmYjcdNCrKaUq/QGwCsAgT7oXgBVufDqA/QCyPfk/ADjTkz4MwAEAGQDudl+c8rw89/yeBytHBeUqBHA/gIKqnuuefzGA79y4G4BNADIqOK4/gBkhv/sUwIDKjgk5fguAY934HgDjfZQvG8AVAPpV5/klU30C6AlgVQ2fV6LX5wwA/VO1LgH8HsCHnmv/kkp1CSAdwDcAulV0nVSrz5ByLa/J324i1mfIOX0BrAQgqViXAJYB2AqgC5z/K2MAfJkqdQmgDwAFYCyAHADHuOU6q7Ln4nes8WdPvBpOK7DcJmV3L7UA8KaIlHl+VwqgsXuefiyl1C4RKfZZBotS6kcRWQTgPwB6H+x4EWkEp9JPBVAbTq/XFje7OYDVSqmS6pSlgmvdDmAAnOer4IxBNqjKY7iv6X9F5AcRmaeU+j4SZXMlXH1WVbLVZxTFvS5FJA/AP+F8wFdZktTlYADzlVIzI1GOSsS9Pr1E5BQ4vRSvV+GcZKhPr34AXlbuf9IISpS63APgTaXUHAAQkXsBbBaRQ5VS2yo7MUnqsryn+T6l1B4A80VkApzPo6nhTvI77NXcEx8OYK0nHfoH8zOAc5VSdTw/2UqpXwGs8z6W291Y32cZKpIBp0vRjwfdsh6jnAnTV8Hp0isv8+FS8cSzit4QuwDketIF5YE7TjkcwKUA6iqny2+b51pVlQmgVTXPDSdR67MqkrU+Iy0R6rIQTrf3F+4chskADnO7oFv6OD8Z6vJMAL93n9N6OJMrR4vI4z7OrYpEqE+vfgAmK6V2VuGcZKjP8sdoDqcn5mW/51RBotTl/JDrlcd+XqdkqMv5lVwzLL+Nn5tEpJk4E7DuBPBaJcc+BeDvItICAESkoYhc5Oa9DuB8d4wyC8B9VSgDRGSA2xItHwf9K4BpnvxPReSeMKfXhjObf6uINIW9wmg2nD+wh0QkT0SyxSwx3wCgmVvecvMA9BaRXBFpA+C6kOuUwO0OFJG74bRg/Ty/ruWvjYjkiMhwOC3/WX7Or4JEqc80d0w300lKtvd1Tvb6dJ9DlvscBUCmW5ZIrrJMhLpcCOfDuaP7MwDO69wR7jfWFKjL/nDmPJQ/x2/grDCp9kTiMBKhPuE+Xg6ASwC8WEFestdnub4AvlJKrajieX4kSl2+AKfh3lFEMgHcBWe4aat7raSuS7fuvgDwNxGpJc6E68vgzEcKy+8L+CqcpXE/uT8PVHLso3Am6n4kIjsAfA3gRLeQiwDc5D7eOjjdZ3o/EHH3XqjksU8GsEBEdgF43/2505PfHMCXYc69F0AnOK3J9+B8O4VbrlIAF8BZObbGLdNlbvZ0OMta14tI+QqPR+CMuW6AM0nuFc91/gfgAzjjrKvhTFz2dn9aRGSRiFzpJmvBmeRWDOBXON125yml1oY7v5oSpT5Pg9Nl+T6cb0Z7YC/BTPb6hPt89sDpKXjGjU8Ld341xL0ulVIlSqn15T8AfgNQ5qZL3cOSui6VUltDnuN+ANsPNmxQDXGvT4+L4dTJJxXkJXV9elyNyE90LpcQdamUmg7n/+R7cLahaANnnky5VKjLK+AMHRa75bxLOatsw5KDDXOKs/ncAKXUx5UeGGci0gzAJKVUt3iXJZGxPlMH6zK1sD5TB+sy8cVl87xoUEr9Amf2OaUA1mfqYF2mFtZn6ghyXSbV7S2IiIiIauqgw15EREREqYQ9P0RERBQobPwQERFRoPia8HxW2iUcG4uzqWWTIrKpHusy/iJVlwDrMxHwvZk6+N5MLZXVJ3t+iIiIKFBSZqk7ERERxV7GES10PGK6fRu4y9++RceFQ76OWZkOhj0/REREFChs/BAREVGgsPFDREREgcI5P0RERFRtW59M1/FxWXafSq3fErOPJTFLRURERBQlbPwQERFRoHDYi4iIiPw7oYOVHNN2rI43lO638lq+tVXHZVEtVNWw54eIiIgChY0fIiIiChQ2foiIiChQUmvOj2ccctk1OTp+85zHrMM6ZGXqOF3s9l/v5WfpeMddzay8tM++i0gxiYiIklXOqI1W+pgss9S9y0PDrLzG876KSZmqij0/REREFChs/BAREVGgJPWw17rbT7LS9w16Wcfn5W7T8du7GlrHzduXreO0kMV3r7X+UMfH3tLPymv+WfXLStF3+/JFVvrMnH06Dh3ePK/zOTouWbc+ugUjIkpym2/opuOPW4228qbtqavjJhOXW3ml0S1WtbHnh4iIiAKFjR8iIiIKFDZ+iIiIKFCSbs5PRkFjHY8YON7K887zOXL8zTouHLXMOq50c7GOJTPLyps4tYuOJx3/rJU3oM9tOs5/9euqFJti4ADSrXQZlIlVyMizUqDgKb6um53uWqLjomf32gfPXhCLIhElJO8cHwCYe8+TOl60X6y8Mef00nHphpXRLViEsOeHiIiIAoWNHyIiIgqUpBv2WnFjax1fmPe+lXf24t46LnzYLLfzDnOFUgfsO9AuXnOYjosK7SGx4gt36zj/VZ8FpqhKyzbbFqQj/FDWUeNuttKtiudGrUxBlV6/no53n9Daysv5zGxDULZ7N2Jp37lmKHtLD3toa0mPZ3T8SY9DrLxH2rSPbsESmLcuMybbn4OFh5jdfRfcbHbV//X0POu4Wt3M5+6szpH/wDzlL/Z7us64mRG/RtDIcUfp+OO77OXspcp81t74pyFWXs7y2dEtWBSw54eIiIgChY0fIiIiChQ2foiIiChQkm7OT+YuE/93R2MrL+eKnTqubJ6PV8mZna30hFOf0vEPB+zl0S0eZ1sx3tLbtrHS+c//puMzc8LPJclfYadD53pRzf3at52OZw991Mq7d6N5n809LrrvI+l8lJU+aeQsHY9oZM/1+mC32ZZ/xONXW3kFSMy7UUdDxmEFVrp4rJn/9Fmb/4Y97+Nxi3XcM2eHlZfm+W5dFnIboUjY3Xubla4zLuKXCISMpk10fM74GTo+JK2WdVzhmzfquOhd+32UjBuH8L85ERERBQobP0RERBQoSTfs1fRfZknda5NOtvJKN6/29Rhl3Y/T8SNjn7Dy2mdm6vjoGddZeS1nzPNbTIqg9MaNdJzxzE4rb1zLqbEuDoWxp3H4zm/vcFOvMwbpOGN6ZLYc8A51DZ7wppV3bq4ZjgkdfBk212yP0Wb8EisvUe9GXW1i78qb3qCBjre9kGvlfdYh/FCXV+hQl19rS/bpONMuFhqm1wJFT+hdDZY8ZKaPTKnzjo7/UWwPHxfebIaPk3GYKxR7foiIiChQ2PghIiKiQGHjh4iIiAIl6eb8qBJzF+aSleHn+EgtM27849gjrbxZ3R8Pe17PQbfq+IgP59nX9ltIiqw6+Tp8o81rvk/r+u0VOm48ZbmVl3LzOeLhhA5W8o6L39JxWsj3qk/2mK3xIzHPR53c0UoPe9mscz4t297G4IAytd1r8aVW3hGXz9dxKv5NeLeGWDqogZW3+NLHavz4M/eaz9lr37/ezvTO5Qn58Gw90dTR6vOyrbwFfcf4unb69Dq+jiPbxuvs7V2WnGH+H64sMbd/mXmeva0I8Es0ixVz7PkhIiKiQGHjh4iIiAIlIYe90vPNMIfUr2vlbTzd7EZZd6m9o++P15pl6k+f/pKOe+SE7tRqulmvXHm2lZP3mVnuWspdgBPC2rMbHfwg2F22ALD/c9PNX7ppWUTLFFQlZ5gu8xNHz7Hy+uWbYei5++zvVSPuMttG5OPral17x2VddfyXB1628k7JNnUfupzdO9RV6+xV1bp2stpVaO7OXpVhruHru+l4wZYmVl7m0No6Tt9qtp4oXDkLfpWe3sk8RuudlRxpaz/Z3Mm98An/1wu67X3Me2fOXfb2Lr+W7tFx/+F36Lj2z9V7nyYL9vwQERFRoLDxQ0RERIHCxg8REREFSkLO+Vl1y9E6/n5wzZdjVuaVIz6y0kM/OVHHU1fZ23vXG2/udJz7JsebY+WWwZN9Hdf7yaFWuumo4NyVO1aGPG1ue+C9bQRgz7X5YMcxVl7dj8ycq+ouKd/TZ6uOf5e7Lexxz21rZaWDNs/Hr6UHTE3cvuISKy/rWhNnrF5j5XlXrZfAH9XtWCt923Ov6riyW2R4l9IDQNvntuu4rCwVNyeIjuxr1um4LGTfge7Thui4aEJqz/PxYs8PERERBQobP0RERBQoCTns1XJKsY779epp5eVl+Ft+/tUU083asMdaK2/1yoY67tlxsZV3eQPT7Teqqz209V0n07F/u9xs5eVO5jBYvB3+xnorzU7xmts/tYWV7p4z05Oy7w5998YuOl54vr08urTYfg/6sXbYSVZ6Ssd/elL2cIh3qOu9i08IeaQVVb52qshbZj5LO7x0q5VXMNu8Q3KmzLby/A5nVWb/OebvYfPAXVae37vBX/PJtVa66Ptval6wANhwi/3emXuk2cX5gc32zuxtBy3Usd+7GGQ0b2alSxscqmMptT95y+YvQSJizw8REREFChs/REREFCgJOexVumipjotPtvOK4U8zeFb6PGjnFWGVju11DMDDx5hVDzcMy7HylvR4TsejR9s3R72j7CYdh3YhU9XtO9d0mbfOGhvHkqQ+ybSHr9SHZkftj9qFrrQzxz6zraWVM+84byr8MJf3Zpt7W9Sx84Zv0PG37eyVnpliVlt6b1YKAP8Zf4GOmy3jKr9ypcvMkN8Rd8Z2+G/jcWbH/W9PeKmSI21HTrxFx22HfWvl8ebS4aXl5em4zaXhd7Qf/0F3K91qnxnKTm9opoRsvNi+sam60Pz3HX3k61beydkHdLxb2VNTRmw4VceLhhxt5aXNmBe2nNHGnh8iIiIKFDZ+iIiIKFDY+CEiIqJAScg5P/HkXZbX9u+FVt60rrk6PjPHvqP8ukvNOGerKdEpWyo7cPbxVvry0e/r2DueHKr9pwN0XLQhuEuaa+LnO+zX/tt2j+o49A7pXltK8qz0+iGe5bViH1vrrE06/rLjBM/jh79CaM4Bz4SP0PlGLV9apeNILNOmqksvtHfWHt5voq/zQndxbjfG7EZccsDf1iYErP6z2d7l+1ahd0Ywb8jWr223cpY+abaGePdc894vCpkLmOZ5jDn77NlXnWb31fGzx46z8kYVmG1g5r5sbwlzX5ezdFxa/BtiiT0/REREFChs/BAREVGgcNirEqU//GilB7/fX8dL//AfK49LMGtmW8tMK33doaGbEDiWH9hnpetNzdZx6fbtoYeTD2MGPF2t84bXX2Slhw5b4PPM6n3nGunZmfbdR+3luvV+nRl6OMWAd6ir+2S7/q+obbYtCB3C7PC5Ga5u+qL93s9ayV2cq0N5hprTQsad08W859571x6W8tpeZv6TtZ12vZXX5C1TT3lv2MNXTWDulHDN3bdYeQsHmm1humbb733JsofWYok9P0RERBQobPwQERFRoLDxQ0RERIES3Tk/aek6TD/EXhabjPMzCr7yjKP+IX7lCJo1JXt0fPkjw6y8ghd5K4OaGjSrr5Ve2P3ZqF5vxEZzH4wRjeb6Pm/aiFN0XG8K5/gkglWXFuj4rXqTrLxMMZ//H+3OtvKajTX/ejI/5hyfSGg8x2wJUjYwZBaqMrOuykJmqM7eZ/6vDRt+m44LJ9nzevw6qdd8K+293oA1p9p523dU6xqRwJ4fIiIiChQ2foiIiChQojrsVdrd7Dj5t7HPW3kjL7lSx2quvWQ2UaQXtbbSw+4fH/ZYCZtDNbVgv+laL3iUw1yR1qrPPCt9Wt9bdbytjf2XnXXsFh0PLJph5T026QKE02qM2Tn9h4fM++reXt+FPafXkoutdM6U2WGPpdhZc7fZyfv96/6p4zLYOzV7h7pGDulv5dX6eE50ChdguV8u1fF1a3pYeSOamB3zD8/IsfKOzNyr4y1tzVBl7eOO8n3tH6+ureMXmzxs5a0sMZ8hv9xh/09N2zXP9zUijT0/REREFChs/BAREVGgsPFDREREgRLVOT8/9TbbYXerVWrlbeySr+OG/le7Rl1athmnXn9GIyvvvNxtYc+r90FO2DyqmHjuGryzRRwLQpY648wy8jqVHPc26lvpFgg/H8v77q/buKGOQ+/q/t0+830s61r7MXi39vhIy7O3KZl8rZnT0SSjVujh2sDp/XVc9B7n+ESbd/uYTde3s/KmTSrS8VW1V1l5h6SZOvz+Rs/d4G+0H997y4zQ5fJeB5R9y4rz7x+i4wYzEmeLCvb8EBERUaCw8UNERESBEtVhr7ZjPbs39rbznho+Rse3b7jZyst9s3o7S/qV3raNjpcObGDlnXaSWXb/dvPHEU7nOVdZ6WZvmfNKQw+mCqU3Mq/9wv7hX2tKftLZLJud1fllHYfe7bvPOzfpuHD119EuFoWx+YZuOm7bb4mV1yozM/TwChXdwKGueClbaNfZG+3NFI6nBl1k5eX3XqfjqUe94evx5+6z01d+OUDHh71lD3s1eD1xhrq82PNDREREgcLGDxEREQUKGz9EREQUKFGd85O2bZeOp+3JtfLOzNmt43//6zErb8JdJ4Z9zNe/7qLj+t+lhz2uuLOZeXNUu5+tvAFNP9RxZcvXlx3Yb6WvHvlnHTebuNjKS8a71BPFyu7meRX+/pM9h1jpdk//pmPOnYudjJaHW+nTB5p5lyMLws/BXH7AbEBw6/LL7MfEmgiVjiKp4VMhc3CeMuH56Fytx2yD8LepSVTs+SEiIqJAYeOHiIiIAiWqw14lK1fr+LGzzrHybh7cRMev/fFRK29k42/CPubIizx5F4U9zLejPre3kc2fbobnGn220cprsMx0F7JLPnae7OvdJ2F+3MpB/qXn51vpEaPH6jhTzHD14P/1t44rXBzdbS7ISG9odtpOf8ke4q9sqOuXErPOefiqS3Sc0ZPDXJQ82PNDREREgcLGDxEREQUKGz9EREQUKFGd8+Plnf8DAK2HmvQd026y8lZdaO4eO+uCR6y8UZtO0XFlc4M6ze6r453r7OW07f9drOMjfgyZQ6LM3Wo5rycxpO01y2lDb4dACaqgoZWsnbZXx5tLzfySptNiViIKsfx2c5ufBW3GVHKk7bIHhuq4/rOJeesCooNhzw8REREFChs/REREFCgxG/aqTNaH9t1/i8wGzOg7+OSQo82wVGW7UTbB4rB5HM5KDKWbzfBju4n20GdZnqml9j8tjVmZKDJKl62w0utLDtXxw7+aO4bnTubS9ljZeYm9c/70PqM8qVphz7t/Uycr3XDCQh1zGJqSFXt+iIiIKFDY+CEiIqJAYeOHiIiIAiUh5vxQMKl9Zpv8Nrd9HfY4ztFKfk//7mwdq23b41iSYCk7paOOH/zH01Zew/Tw83yOef5WHbceZ9/mp2zHitDDiZIOe36IiIgoUNj4ISIiokDhsBcRRV3JT6viXYRAylq7VceTtxxv5Z1YYHZnHlXcwcpr/epmHYduW0CUCtjzQ0RERIHCxg8REREFChs/REREFCic80NElKK8c61+CLkb0IXoUsmZP0alPESJgj0/REREFChs/BAREVGgiFLq4EcRERERpQj2/BAREVGgsPFDREREgcLGDxEREQUKGz9EREQUKGz8EBERUaCw8UNERESB8v9D4PCrWQgE7QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 720x720 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "logits = model(xbv)\n",
    "probs = F.softmax(logits, dim=1)\n",
    "idx = 5\n",
    "_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
    "for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
    "    ax.imshow(im)\n",
    "    ax.set_axis_off()\n",
    "    ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(img):\n",
    "    with torch.no_grad():\n",
    "        img = img[None,]\n",
    "        pred = loaded_model(img)[0]\n",
    "        pred_probs = F.softmax(pred, dim=0)\n",
    "        pred = [{\"digit\": i, \"prob\": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]\n",
    "        pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)\n",
    "    return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(3)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'digit': 0, 'prob': '0.00%', 'logits': tensor(-5.5980)},\n",
       " {'digit': 1, 'prob': '0.00%', 'logits': tensor(-0.4972)},\n",
       " {'digit': 2, 'prob': '0.02%', 'logits': tensor(1.2516)},\n",
       " {'digit': 3, 'prob': '99.95%', 'logits': tensor(9.9263)},\n",
       " {'digit': 4, 'prob': '0.00%', 'logits': tensor(-5.5094)},\n",
       " {'digit': 5, 'prob': '0.01%', 'logits': tensor(0.2367)},\n",
       " {'digit': 6, 'prob': '0.00%', 'logits': tensor(-9.4633)},\n",
       " {'digit': 7, 'prob': '0.00%', 'logits': tensor(-2.4315)},\n",
       " {'digit': 8, 'prob': '0.02%', 'logits': tensor(1.4733)},\n",
       " {'digit': 9, 'prob': '0.00%', 'logits': tensor(-0.0205)}]"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = xb[1].reshape(1, 28, 28)\n",
    "print(yb[1])\n",
    "predict(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "source": [
    "#### commit to .py file for deployment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
      "[NbConvertApp] Writing 2920 bytes to mnist_classifier.py\n"
     ]
    }
   ],
   "source": [
    "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.1046) tensor(1.) tensor(0.) tensor(0.3062)\n",
      "tensor(0.1435) tensor(1.) tensor(0.) tensor(0.3220)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFjklEQVR4nO3dz6tUdRjH8fcn8wfZJsvCVNKFRK4KJItaBCHd3Ngm0CBcCG4KClqk9Q+4atdGSHQRRlRgC0FKiggilJDS5KoF5k3JhCBXofC4mENM0713zjxn5syZOZ8XXGbOd+5wvosPz/me78x9riICs0HdNe4J2GRycCzFwbEUB8dSHBxLcXAspVJwJM1ImpV0SdK+YU3Kmk/ZfRxJS4ALwDZgDjgF7IqIn4c3PWuquyu890ngUkT8CiDpI2AHsGBwlml5rGBlhVNa3W7y142IWN07XiU4a4ErXcdzwNbF3rCClWzV8xVOaXX7Mj65PN94leBonrH/Xfck7QX2AqzgngqnsyapsjieA9Z3Ha8Drvb+UkQcjIgtEbFlKcsrnM6apErFOQVskrQR+B3YCbwylFlNmRNXzyz6+gsPP17LPIYpHZyIuC3pdeAEsAQ4FBHnhjYza7QqFYeIOA4cH9JcbIJ459hSKlUcm1+/Nc00cMWxFAfHUhwcS/EaZwwmcd+mlyuOpTg4luLgWIrXOEMwjZ9F9eOKYykOjqX4UpXUho8VFuOKYykOjqU4OJbiNU5Jg6xppvH2u5crjqU4OJbi4FiK1zhD0IY1TS9XHEtxcCzFwbEUr3EW0MavSgzCFcdSHBxLcXAsxWucQtu/XzMoVxxL6RscSYckXZd0tmtslaQvJF0sHu8b7TStacpUnMPATM/YPuBkRGwCThbH1iJ91zgR8Y2kDT3DO4DniudHgK+Bt4c5saZp+75Nr+wa56GIuAZQPD44vCnZJBj5XZXb1U6nbMX5Q9IagOLx+kK/6Ha10ykbnM+B3cXz3cCx4UzHJkWZ2/GjwHfAo5LmJO0BDgDbJF2k809ADox2mtY0Ze6qdi3wkv8pQ4t559hS/FlVUtv/zsoVx1IcHEtxcCzFa5ySqnxfp/e907DmccWxFAfHUnypGoLeS08bvobqimMpDo6lODiW4jVO0mK31P3WPNNwe+6KYykOjqU4OJbi4FiKg2MpDo6lODiW4n2ckiZxr2WUXHEsxcGxFAfHUhwcS3FwLMXBsRQHx1K8j1PwPs1gXHEspUx/nPWSvpJ0XtI5SW8U425Z22JlKs5t4K2IeAx4CnhN0mbcsrbV+gYnIq5FxA/F85vAeWAtnZa1R4pfOwK8NKI5WgMNtMYp+h0/AXyPW9a2WungSLoX+BR4MyL+HuB9eyWdlnT6Fv9k5mgNVCo4kpbSCc2HEfFZMVyqZa3b1U6nMndVAj4AzkfEe10vuWVti5XZAHwGeBX4SdKZYuwdOi1qPy7a1/4GvDySGVojlWlX+y2gBV52y9qW8s6xpTg4luLgWIqDYykOjqU4OJbi4FiKg2Mp/uroCLhdrdkCHBxLcXAsxWucMZiGP8VxxbEUB8dSHBxL8RpnCPrt20zDmqaXK46lODiW4uBYitc4SW34PGoxrjiW4uBYioNjKYqI+k4m/QlcBh4AbtR24sF4bv/1SESs7h2sNTj/nlQ6HRFbaj9xCZ5bOb5UWYqDYynjCs7BMZ23DM+thLGscWzy+VJlKbUGR9KMpFlJlySNtb2tpEOSrks62zXWiN7Nk9BburbgSFoCvA+8CGwGdhX9ksflMDDTM9aU3s3N7y0dEbX8AE8DJ7qO9wP76zr/AnPaAJztOp4F1hTP1wCz45xf17yOAduaNL86L1VrgStdx3PFWJM0rndzU3tL1xmc+foI+pZuEdne0nWoMzhzwPqu43XA1RrPX0ap3s11qNJbug51BucUsEnSRknLgJ10eiU3SSN6N09Eb+maF3nbgQvAL8C7Y15wHgWuAbfoVMM9wP107lYuFo+rxjS3Z+lcxn8EzhQ/25syv4jwzrHleOfYUhwcS3FwLMXBsRQHx1IcHEtxcCzFwbGUO+5QVpdhOxaLAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 144x144 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "import glob\n",
    "import os\n",
    "png_files = glob.glob(os.path.join('./uploads', '*png'))\n",
    "img = png_files[1]\n",
    "image = Image.open(img)\n",
    "image = TF.to_tensor(image).permute(1, 2, 0)\n",
    "print(image.mean(), image.max(), image.min(), image.std())\n",
    "print(xb[0].mean(), xb[0].max(), xb[0].min(), xb[0].std())\n",
    "plt.imshow(image);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7ff734b20fd0>"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI70lEQVR4nO3df2yU9R0H8PenpaURcR2MKhRSmrQguoUxGwpO0f1gqYyMzc1Ig8RtZGwB58RlQZmzJiabmWZsjE3jImEowS2gwzmMkc7JfhiBTTYQrDAIcqWzdmOuwlZ77Wd/9Gjv89Ber5/79dz1/UrI3ee55+l9Y95+73vPPfc5UVUQjVRRrgdA+YnBIRcGh1wYHHJhcMiFwSGXlIIjIg0i0iIix0TkrnQNisJPvOdxRKQYwBsAFgKIANgHoFFVD6dveBRWY1I4di6AY6p6HABE5EkASwAMGZxSGatlGJfCU1K2deJMh6pOCm5PJTiVAE7F1REA9YkOKMM41MsnUnhKyrbduv3kYNtTCY4Msu2C1z0RWQlgJQCU4aIUno7CJJXFcQTAtLh6KoDTwZ1U9VFVrVPVuhKMTeHpKExSCc4+ALUiUi0ipQCWAngmPcOisHO/VKlqVERuA/A8gGIAm1T1tbSNjEItlTUOVHUXgF1pGgvlEZ45JhcGh1wYHHJhcMiFwSEXBodcGBxyYXDIhcEhFwaHXBgccmFwyIXBIRcGh1wYHHJhcMglpQu5aHAdX51v6s6PnTX1znmPmHpWqb2If853V5m6YuOf0ji69OCMQy4MDrkwOOTCNU6Szt1ov6Ta8aHi/vvLP99sHvtS+YOmrii2a5jewPfLurXH1Gu/vs3UW5o/aQfT+pb9e+fOmVqjUWQaZxxyYXDIhcEhF65xYsZMrTR1ZOMlpt5T90NTXywD65TeQK+F585VmLpp/RcTPvfZBe+aekPdL+xYbrBdRn67ZrOp5+25zdQ1tx7sv5+p9Q5nHHJhcMiFwSEXdw9Aj0tkgoalI5dePdvU7933jqmbr3wq4fGvd3f137/xiTvNYzWPvGnqaKQ14d8qKisz9fjdtt3dtuoXEh7/z97/mnrZLQNrnqKXXk147HB26/Y/q2pdcDtnHHIZNjgisklE2kXkUNy2CSLygogcjd2+P7PDpLBJZsbZDKAhsO0uAM2qWgugOVbTKDLseRxV3SMi0wOblwC4Pnb/5wB+B2BtOgeWbsE1zdrHnzD1grL3AkfY3phz/7LU1Jd9beDcy/TWl81jw505kbH2s6rWJ6tN/Wr1VlMHzxO19dg1zU33fMvU5S/Z8WSCd41zqaq2AUDstmKY/anAZPzMMdvVFibvjPOWiEwGgNht+1A7sl1tYfLOOM8AuBXAA7HbnWkbUYa019lzIxeuaaw1bfb6m4qbI6aOBq6BiVc8cYKpOxbPNPVDTQ+bev7Y4DXFdn21t8vWX95q1zTTt2R+TROUzNvxbQBeBjBTRCIisgJ9gVkoIkfR9yMgD2R2mBQ2ybyrahzioXCcAqac4JljcinY63HGVE0z9YY7fjqi4/e2V5m6e/kFv7zT71/13aa+s95+trSqfLepg+dlgtp77PqpacXtpp7+YvbXNEGccciFwSEXBodcCnaNoxfZa1zGSfC8TTES+f1se90v7EddKIo71zLcmmXw34Qb2nWB8zTVIVjTBHHGIRcGh1wK9qWq58hRUzedXGLqnbW/SenvP955Wf/9//WWmMe+8r5Twd2Nd3u7TH3V02tMPfP+v5q61zPADOOMQy4MDrkwOORSsGucoO6P20uGPl08L+H+ZxqvMvXEX9nfqe25fOAjiXfus5dyrphtvx4TidrHm04vMnXt7a+YOoxrmiDOOOTC4JALg0Muo2aNg17bLk0DdVB54HJMqZxi6mVbnuu/3zjetlYL+k7rYlO/ffW/E+6fDzjjkAuDQy4MDrmMnjVOij747GlTLxsff17IXjZx8/FPmbrz2o5MDStnOOOQC4NDLgwOuXCNc16RvZT0+PfmmvrpST8KHDDwn65LbWOTll/PMPUUcI1DBIDBIScGh1y4xonpaviIqQ/fsjGwh10DxX8l5spnbUv8Gd8P308hphtnHHJJpj/ONBF5UUSOiMhrIvKN2Ha2rB3FkplxogC+qaqzAMwDsFpErgBb1o5qyTRWagNwvsNop4gcAVCJPGxZG+/sF2yrtp89tN7URbBfIQ6qv3d1//1ZT71uHkt8pU9hGNEaJ9bveA6AV8CWtaNa0sERkYsB7ABwh6r+ZwTHrRSR/SKyvxtdwx9AeSGp4IhICfpCs1VVz/+sSlIta9mutjANu8YREQHwGIAjqvqDuIfyqmVt7zUfNvX8dXtNXVNiQx1sXbIqssDUl+460X8/euZMGkaYX5I5AfhRAMsBHBSRA7Ft69AXmF/G2te+CeCmjIyQQimZd1V/wNCdgdiydpTimWNyKdjPqnqut5893fPYZlNfWxb8cSA7qV6+dbWpa+4/ZOrezn+kNL58xxmHXBgccmFwyKVg1jhSUmrqGQ8eNvWFaxrrj132/6EZ60+YOtrZmcLoCg9nHHJhcMilYF6qOj87x9Q/nvLwEHv22XHWXne2YZ39ld9xbba9GlmccciFwSEXBodcCmaNU77ffgTwuWO2JeyOGtuC/94DnzF11XauaUaCMw65MDjkwuCQS8GscaInTtr6Ovv4YtgW+1U4mOkhFTTOOOTC4JALg0MuDA65MDjkwuCQC4NDLqKqw++VricTeRvASQAfAELbw5Vjs6pUdVJwY1aD0/+kIvtVtS7rT5wEji05fKkiFwaHXHIVnEdz9LzJ4NiSkJM1DuU/vlSRS1aDIyINItIiIsdEJKftbUVkk4i0i8ihuG2h6N2cD72lsxYcESkG8BMANwC4AkBjrF9yrmwG0BDYFpbezeHvLa2qWfkHYD6A5+PquwHcna3nH2JM0wEciqtbAEyO3Z8MoCWX44sb104AC8M0vmy+VFUCOBVXR2LbwiR0vZvD2ls6m8EZrI8g39Il4O0tnQ3ZDE4EwLS4eiqA00PsmytJ9W7OhlR6S2dDNoOzD0CtiFSLSCmApejrlRwm53s3Azns3ZxEb2kg172ls7zIWwTgDQB/B/DtHC84t6Hvx0260TcbrgAwEX3vVo7GbifkaGzXoO9l/G8ADsT+LQrL+FSVZ47Jh2eOyYXBIRcGh1wYHHJhcMiFwSEXBodcGBxy+T+hXTkLJRIRbwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 144x144 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(xb[0].permute(1,2,0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python_main",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}