Spaces:
Runtime error
Runtime error
File size: 49,730 Bytes
a1ece17 338bbe8 a1ece17 5fe67f2 a1ece17 d64050c 55abd01 c772c06 c32023c 338bbe8 c32023c 55abd01 a1ece17 55abd01 338bbe8 c32023c 0bbec58 9c06ef3 0bbec58 55abd01 338bbe8 55abd01 c32023c 55abd01 c32023c 338bbe8 c32023c 0bbec58 338bbe8 0bbec58 338bbe8 0bbec58 338bbe8 0bbec58 338bbe8 55abd01 338bbe8 55abd01 c32023c 338bbe8 c32023c 338bbe8 c32023c 55abd01 9c06ef3 55abd01 9c06ef3 8e35bc7 9c06ef3 8e35bc7 9c06ef3 55abd01 9c06ef3 8e35bc7 55abd01 6b67354 55abd01 9c06ef3 55abd01 5993d2f 056ab4f 55abd01 9c06ef3 c32023c b85505a 9c06ef3 b85505a 55abd01 5993d2f 55abd01 5993d2f 55abd01 338bbe8 55abd01 338bbe8 55abd01 338bbe8 55abd01 9c06ef3 338bbe8 9c06ef3 338bbe8 9c06ef3 338bbe8 9c06ef3 338bbe8 9c06ef3 53075d2 c004d97 55abd01 c32023c 9c06ef3 c32023c 338bbe8 c32023c 9c06ef3 338bbe8 9c06ef3 338bbe8 9c06ef3 338bbe8 9c06ef3 338bbe8 9c06ef3 c32023c 9c06ef3 c32023c 9c06ef3 c32023c 9c06ef3 c32023c 9c06ef3 c32023c 338bbe8 c32023c 55abd01 0bbec58 53075d2 48f5984 0bbec58 48f5984 0bbec58 87d9e93 53075d2 0bbec58 9490409 87d9e93 0bbec58 8e35bc7 0bbec58 c32023c 9490409 28478b9 9490409 a1ece17 48f5984 a1ece17 d64050c a1ece17 d64050c a1ece17 d64050c 5fe67f2 a1ece17 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from datasets import load_dataset\n",
"import fastcore.all as fc\n",
"import torchvision.transforms.functional as TF\n",
"from torch.utils.data import default_collate, DataLoader\n",
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"plt.rcParams['figure.figsize'] = [2, 2]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
"100%|██████████| 2/2 [00:00<00:00, 75.69it/s]\n"
]
}
],
"source": [
"dataset_nm = 'mnist'\n",
"x,y = 'image', 'label'\n",
"ds = load_dataset(dataset_nm)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def transform_ds(b):\n",
" b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
" return b"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3
+kraZ+9VyFfUK7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dst = ds.with_transform(transform_ds)\n",
"plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"bs = 1024\n",
"class DataLoaders:\n",
" def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
" self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
" self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
"\n",
"def collate_fn(b):\n",
" collate = default_collate(b)\n",
" return (collate[x], collate[y])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
"xb,yb = next(iter(dls.train))\n",
"xb.shape, yb.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class Reshape(nn.Module):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.dim = dim\n",
" \n",
" def forward(self, x):\n",
" return x.reshape(self.dim)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
" layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n",
" if norm:\n",
" layers.append(norm)\n",
" if act:\n",
" layers.append(act())\n",
" return nn.Sequential(*layers)\n",
"\n",
"def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
" return nn.Sequential(\n",
" conv(ni, nf, ks=ks, s=1, norm=None, act=act),\n",
" conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
" )\n",
"\n",
"class ResBlock(nn.Module):\n",
" def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n",
" super().__init__()\n",
" self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n",
" self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n",
" self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
" self.act = act()\n",
" \n",
" def forward(self, x):\n",
" return self.act(self.convs(x) + self.idconv(self.pool(x)))"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"def cnn_classifier():\n",
" return nn.Sequential(\n",
" ResBlock(1, 8, norm=nn.LayerNorm([8, 14, 14])),\n",
" ResBlock(8, 16, norm=nn.LayerNorm([16, 7, 7])),\n",
" ResBlock(16, 32, norm=nn.LayerNorm([32, 4, 4])),\n",
" ResBlock(32, 64, norm=nn.LayerNorm([64, 2, 2])),\n",
" ResBlock(64, 64, norm=nn.LayerNorm([64, 1, 1])),\n",
" conv(64, 10, act=False),\n",
" nn.Flatten(),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"def kaiming_init(m):\n",
" if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
" nn.init.kaiming_normal_(m.weight) "
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 1.8902, accuracy: 0.3183\n",
"eval, epoch:1, loss: 1.0976, accuracy: 0.6274\n",
"train, epoch:2, loss: 0.5929, accuracy: 0.8003\n",
"eval, epoch:2, loss: 0.2895, accuracy: 0.9102\n",
"train, epoch:3, loss: 0.2396, accuracy: 0.9264\n",
"eval, epoch:3, loss: 0.1343, accuracy: 0.9597\n",
"train, epoch:4, loss: 0.1139, accuracy: 0.9651\n",
"eval, epoch:4, loss: 0.0801, accuracy: 0.9763\n",
"train, epoch:5, loss: 0.1368, accuracy: 0.9582\n",
"eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
]
}
],
"source": [
"model = cnn_classifier()\n",
"model.apply(kaiming_init)\n",
"lr = 0.1\n",
"max_lr = 0.3\n",
"epochs = 5\n",
"opt = optim.AdamW(model.parameters(), lr=lr)\n",
"sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
"for epoch in range(epochs):\n",
" for train in (True, False):\n",
" accuracy = 0\n",
" total_loss = 0\n",
" dl = dls.train if train else dls.valid\n",
" for xb,yb in dl:\n",
" preds = model(xb)\n",
" loss = F.cross_entropy(preds, yb)\n",
" if train:\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
" with torch.no_grad():\n",
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
" total_loss += loss.item()\n",
" if train:\n",
" sched.step()\n",
" accuracy /= len(dl)\n",
" total_loss /= len(dl)\n",
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"eval, epoch:1, loss: 0.0882, accuracy: 0.9722\n",
"eval, epoch:2, loss: 0.0882, accuracy: 0.9722\n",
"eval, epoch:3, loss: 0.0882, accuracy: 0.9722\n",
"eval, epoch:4, loss: 0.0882, accuracy: 0.9722\n",
"eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
]
}
],
"source": [
"for epoch in range(epochs):\n",
" train = False\n",
" accuracy = 0\n",
" total_loss = 0\n",
" dl = dls.valid\n",
" for xb,yb in dl:\n",
" preds = model(xb)\n",
" loss = F.cross_entropy(preds, yb)\n",
" with torch.no_grad():\n",
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
" total_loss += loss.item()\n",
" accuracy /= len(dl)\n",
" total_loss /= len(dl)\n",
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAc5klEQVR4nO3dd3yV1RkH8N+ThL0CGCDISBDCFEQUQaVOtG7UOhARq1itiijWWm21rlZxIVKUqihacWCduGitG1GGIgioTBkyIwHCkiSnf7yXc97nmje8SW6Sm7y/7+fDh+fkvOvek5ucnCnGGBARERFFRUpVPwARERFRZWLlh4iIiCKFlR8iIiKKFFZ+iIiIKFJY+SEiIqJIYeWHiIiIIqXKKz8iskJEjq/q56gsInKbiDxb1c9RUVieNQfLsmZhedYcLMvyq/LKT2nE3oA9IpLv+9ehgu95sYh8WpH3iLufEZHtvtf3RGXdu7KJSLqIPC0iG2L/bquEe1Z2eZ4mIt/EyvIzEelWWfeuTCJyrYgsE5GtIvKjiIwRkbQKvmellaWI7Cci00UkV0TyRGSGiBxRGfeuCuIZHXu9uSJyr4hIBd+zsj+bB4nIHBHZEfv/oMq6d2Xi783iJbTyU9E/7GJeNMY09P1bVgn3rGy9fK9veFU9RCWU5xgA9QFkAegLYKiI/LaC71lpRKQTgMkArgCQDmAqgDcq6XMS/ywVfc+pAA42xjQG0ANALwDXVPA9K1M+gEsAZABoCmA0gKlVUZZApZTn7wAMgleOPQGcCuDyCr5npRGR2gBeB/AsvPJ8GsDrsa9X9rPw92ZilOr35j4rP7HmtZtEZKGIbBaRp0SkbizvaBFZLSI3isg6AE+JSIqI/ElElsb+YpgiIs181xsqIj/E8v5crpdaCiJygIi8H7vvJhGZLCLpvvy2IvKKiGyMHfMPEekKYAKA/rHaZF7s2A9FZLjvXFXLFZGxIrIq9lfwHBEZUFmvc1+SrDxPA3CvMWaHMWYFgInwfsGEeR3VoTxPBPCJMeZTY0wBvF+Y+wM4KuT5JUqmsjTGLDXG5O29FIAiAB1Dvo6kL0tjzC5jzHfGmKLY6yuE90uzWclnhpdM5QlgGIAHjDGrjTFrADwA4OKQryPpyxPA0QDSADxkjNltjHkYXrkeG/L8EiVZWZbndVSHsiyTsC0/Q+D9ID8AQA6Av/jyWsH7AdAe3l8L18D7i+EoAK0BbAYwHgDEa/J/FMDQWF5zAG32XkhEjtz7RpXgNBH5SUQWiMjvQz4/4H1j3x27b1cAbQHcFrtvKoA3AfwArxVifwAvGGMWwfurfUasNpke8l6zABwE7315DsBLe7/xf/FQIvNE5IK4L38sIuti31RZIe9ZGslUnhIX9wj5GqpDeQp++fpK8xrDSJqyFJELRGQrgE3wWgz+GfI1VIeytF8DsAvAGwCeMMZsCHnfsJKlPLsD+NqX/jr2tTCqQ3l2BzDP6P2d5iH8awwjWcoS4O/NXzLGlPgPwAoAV/jSJwNYGouPBvAzgLq+/EUAjvOlMwHsgVfLvhXem7M3r0Hs/OP39Ryx47vBK4RUAIcDWAtgcJhzi7nWIABfxeL+ADYCSCvmuIsBfBr3tQ8BDC/pmLjjN8NrkgO8b5xnSzj2VwBqw+sm+QeAb4p7rrL+S7LyfBbAKwAawWslWApgd00pTwBdAGyPva+1AdwCr0XkpppWlnHP1QnAnQBa1ZSyjDunLoDBAIYlohyTsTzhtWx1iStTA0BqQnnGPosvxH1tMoDbamBZ8vdmMf/C9jWu8sU/xN7IvTYaY3b50u0BvCoiRb6vFQJoGTvPXssYs11EckM+A4wxC33Jz0RkLIDfAHh+X+eKSAsADwMYAO+XbQq8NxfwarM/GK9rotxE5HoAw+G9XgOgMYD9wpxrjPk4Fv4sIiMBbIVX456fiGeLSYryhPfXzjgAiwHkwivHwWFOrA7laYz5VkSGwfswZs
Kr7C0EsDoRzxWTLGVpGWMWi8gCAI8AOGtfx1eHsvSLvafPi8giEZlrjPl6nyeFlyzlmQ/vvdmrMYB8E/tNU5JqUp7xrw+x9LZEPFdMUpQlf28WL2y3V1tf3A7Aj/77xh27CsBJxph037+6xus3Xuu/lojUh9eEV1YGuluhJHfHju9pvEGZF/rOXQWgnRQ/8Ky4D/t2eAN192q1N4j1U94I4FwATY3X5LelFM9Z3P0TPcsiKcrTGPOTMWaIMaaVMaY7vO/HmSFPrxblaYz5tzGmhzGmOYC/wvshNyvMuSElRVkWIw1ec38Y1aIsi1ELQKJnzSRLeS6A13W5V6/Y18KoDuW5AEBPETWDrSfCv8YwkqUs4/H3JsJXfq4SkTaxAVg3A3ixhGMnAPibiLQHABHJEJEzYnn/BnBqrI+yNoA7SvEMEJEzRKSpePrCazl43Zf/oQRPl24Er7afJyL7A7jBlzcT3jfYPSLSQETqipvGuh5AG9GzAOYCOEtE6otIRwCXxt2nALHmQBG5Fb/8CyPo9XUXb/plqog0hDfIcA28JtFESpbyPEBEmsde70nw+r7v8uVX6/KMvYY+sdeXAW8MzFRjzLdhzw8hWcpyeOyvxL1jFG4C8D9ffrUuSxHpt/e9EZF6InIjvL/KvwhzfikkRXkCeAbAKBHZX0RaA7gewKS9mdW9POF1wRQCuEZE6ojI1bGvvx/y/DCSoiz5e7N4Yd/A5wD8B8Cy2L+7Sjh2LLzBgP8RkW0APgdwGAAYYxYAuCp2vbXwms9sF4CIDBCR/BKufT6AJfCaJp8BMNoY87Qvvy2A6QHn3g7gYHi1ybfgjTVB7LkK4c086ghgZeyZzotlvw/vr4F1IrIp9rUx8Ppc18ObIjnZd59pAN4B8D28ps5d0M2fingD0IbEki3hfUC2wnufswCcaozZE3R+GSVLefaB1yy5Dd5fGENi19yrupcn4L1/eQC+i/1/WdC5ZZQsZXkEgPkish3A27F/N/vyq3tZ1oE3ADUX3g/WkwGcYoz5Mej8MkqW8vwnvOUL5sMbP/EW9AD2al2expif4Y1fuQje5/ISAINiX0+UZClL/t4s7hr76sIVkRXwBim9V+KBVUxE2gB4yRjTv6qfJZmxPGsOlmXNwvKsOViWya9KFuiqCMaY1fBGn1MNwPKsOViWNQvLs+aIcllWq+0tiIiIiMprn91eRERERDUJW36IiIgoUlj5ISIiokgJNeB5YMo57BurYv8teikhCx2yLKteosoSYHkmA342aw5+NmuWksqTLT9EREQUKaz8EBERUaSw8kNERESRwsoPERERRQorP0RERBQprPwQERFRpLDyQ0RERJHCyg8RERFFCis/REREFCms/BAREVGksPJDREREkcLKDxEREUVKqI1NiYgSpeDYPja+e+IEldendqqNU0X/bVZoigKvecuGg2z80d2Hq7yGUz4vy2MSRU5aZisbrzmng8pLHbjJxqe2W6DyVu5sZuN5GzMDr/9+76dVemJedxu/d96hKq9wwXchnrjs2PJDREREkcLKDxEREUUKu72IKOHSstur9MIbW9r42RNcV1fv2vrvryIYF5vC0Pe7vcVXNh48ooXK2zYl9GWoGJsu76/SM28db+M+o6+2ccuZ+eq45YMa2PiLIQ+ovKap9W18wItXqLxOf3JlaXbvLsMTU0kKjuuj0k88OdbGDVLExi1SG6BM2pWUWU+lRjVbZuO3M45WeamoWGz5ISIiokhh5YeIiIgihZUfIiIiipSkHPOTtn9rG+cerTsQN5++I/C80zvNt/Hy7c1tvOjtHHVcxrw9Nq7z1qwyPydR1Eit2i7Ro5PKK6rrfpwc9/gnKu+1pq+Euv6Wol02PvSt61RenQ1uFMCI37yp8n7XZIWNj2mup8i+ldHFxoUbN4Z6jqhLbdzYxnXPXK/y/OOyZt04zsb35XZTxy3+6GgbbywyKq9hihvP9e2541VeTgM3BqjlR3rkR5PJXL
agvGpt0eOoJuUdZuMFW9009fw9ddRx36924/ZqrdR5vm8JZM4oUFn1F+faeMX5rVReu3e22jhtvp4+r79jEo8tP0RERBQprPwQERFRpCRFt9ePN+gVWe+5/Ekbn1Bve+B5KRCVLgpqKLt6WuA1Jm7R3WqFvvrgfzfqZtyF0/WKl36tP3VNfQ3m/Rh4XLzCjW7VTE7rdFLTm9h445muHHafkaeOm9f3eRvHrwDce9YQGzea3FjlNXzpi0Q8ZuSsvv4QG381YlwJR4bzcv5+Kn3TJ2fbOOeKmSrPP32+39ClcVdy3SMf5HZWOezqKoMWbtjAhwe+FOqUG5ov1Omz/Om6oW/9/SluKYQPjtXn3bnzEhvXf1V/f8BUdEdJzWBmf6PSs884wMZ5E1yVYFh73cU47t1BNm5972eh7+dfsKLtnctUXlWWGFt+iIiIKFJY+SEiIqJIYeWHiIiIIiUpxvzUytc9fyWN8/Hr9tRVKt1ksYvz27rxQN1O/F4d93wHNwbosiarVJ5/3JB/+iwAoGPws6QMc/cLHHtUjEPuH2HjVmPC96PWNKnd9TiNonFuqfx72j9m48veGa6OO36sGwOQlqfHTBWc4sYNHX7TDJVX62Y3PuiDe/WYs8bPczptkB2ZwTurh3XM/HNs3OgPtVVezjezA89bN9AtgdGzdvDi95t2NlTpOtgUcCTtZfr3UumLJr2R0Ouft/TXKn1I+kobx48V8jum3i6dHveIjY81V6q8+q9yHF8YRUcepNJ3/Mv9fO1cy32++z46Sh2XNcGNFQq/8UzyYssPERERRQorP0RERBQpSdHt1fKzvNDH9vyn6ybKuiO4m6ipL86/R7/MQS1PC3Wvtafpnam3tw11WolqbdXT89s8MsfGUZuo6Z/OfsVrU1XemBUDbXzf+RfYuNOs4Kbt+PevrdscGgv/pQtv2f3pNj77xukqb87CHjYu+npR4P2iqPFS9/fSXzbo3aFvznBdi4MXn63yFn3XxsaZH7hrFH2juxi3Du7nEvqjgqEj3wl8Lv/K0BibEZe7IvA88uxqqVfsTU8NXkm/TNe/tJFKf9DS7RS/Z7zuwrx5v/kI48Cbv1bppa+W8eEiIO8i934/cvtYlefvQj7kvmtt3PYh/fu1JnR1+bHlh4iIiCKFlR8iIiKKFFZ+iIiIKFKqbMxPWpbbVqL/v74KPC5+imT2I27H5rB9kKZA7zJbsCbc9hMZE/Rx8SMJEiFK43xS6tdX6bTX69n4lgVnqLzWQ9xU2KLtP5T73gU/6CUN2p3j0p+c1l/lbb3D7TSceUkzlVeY+1O5n6U6a/mwGwcw7yk9jmNw66EusV5vKVH/9+5Hzabe7rt+wJ/09e9qoXf4DuvQqW4H+Jy3ZpZwJBVn/aF63E2P2rm+VD2EsblIT0sfMOkPNs5e+aXKS1nstjl4Z43eRijsmB8K74DLv7Vxnzpxy0t8NMzG2Q9FZ7kVtvwQERFRpLDyQ0RERJFSZd1eBRPdSpI3/WKFTzfHde2EA1TO9otcfa3Zd3sS/lz1PnbNg0XbtiX8+lG29Em9RPaijpNsfPohJ6u8gu3hVvkOK35V05Unuab8rD/r1Z/r+mbd17TpnYkU//lIqeeWExg+U3dln97g/Qp9lm73b7BxQQnHkZPaMdvGWf10t3DL1HBdXWsLd9r4pAl/VHlZf3ddKFHq3k8GRUf1Vunnsp+y8Q3rdF6HoQtsHKVyYssPERERRQorP0RERBQprPwQERFRpFTZmJ+3u7hdg0vaBf3j+4KnvqbErX9fmt3Ug67xQr6b0H7L53r6ddcb19i4YN36Ut8ritL2dztxLxowSeX1fvBqG2euTfwUS+nd3cb1/rZO5dXZ3DT+cCql1O6dVXrIi9NsfHqDzeW+/vg8Pd5v1S637MA9rWapvCW/zbRx1i0ryn3vmmj53XpJh4793BISb+ZMjT88lC93t7Jxm7+H/wzvHNTXxld2eKVM9z63mV7S4I/DLrdx06
dnxB9e46U0cktPFNySG3jc6+/2U+nsgui9VwBbfoiIiChiWPkhIiKiSEmKXd23xK0MetbCC2289utWKq/pAgTa7Ho5kNlrXeBxx7T63sa3xq0mem5DN2X23OMfV3nj3utk4yde0CtPt70rOitjlsb3I9vbeNZu3S2Z+UBi37Nt5+vm3HNvedfGV6UvVXk9HxuR0HtH0eZeuuvQ/9kJa85unR7y2lU27jhK7/ie2t19/vAf3e314bD7bHzhByNVXtr7c0r9XDXRrwfOVukHMj8PODK8B2+4wMb1EH5l7dyu7lfP4EZlG0Lw3rbuKp3xnuvGi+JyBxvP72Hj2d0eVXnTd7mlZTreE/dLtLNbgmRrj+Y2Lqyth4SsG1CEIJkfuXaU9Pf1z9rCjRvjD08KbPkhIiKiSGHlh4iIiCKlyrq9Tu97qksU6ea0emuX27gDliOssPN3ZjZId8+RfmrgcUuubK/Sc4aNsfHg381TeYdnX2vjnEt183KU+Gd3AcBDZ7mVRS/7eqjKa434lb2LJ2m+b9MD9QyjXaPdStBrV+rG7of/d6KNs06arPKyx7qmX67iXDaNftDd1VevOdLGHyzrpPIyXnYrBtfNdeVUZ32+Oq7jgrJ1xeznW5F45a/1xo0dKnZx6aRWeMzBNj644ZsJuebIH4+wcaO5bnhBSV1NqRl6W2jTZ2vAkeH9tKeBSofdsLqmyusSPNs5PcX1L58wY6XKO7GBGx7QtbbefDq0QS58a0ddlTVqzrk2PuByvUl1Yd6Wst0vAdjyQ0RERJHCyg8RERFFCis/REREFClVt6t7FfbPFvl2DC8qYffwrD/rZ+y7Y5SNH7lkgsqbdvxYG1/b9WKVV7hocVkes1oqbKFHXp1Qz72/t9XSowLixwf5Lbs0y8a/OtXtEJ5eS0/T/GS0m97e+eUvVd6BX+yx8Z+fukjltcnj0gTlJdPnqvSKfqk2zi6ahzBKNd5q/SYbxq/+7F/KoFd//XnTe89Hy7Kzatl4SKO1ZbrGDesO09e8wr33ZkUJa4/4zxmvlyyZ339SmZ6FyqZ77Xq+eIXKW13gxtxO3OLK6W+zTg68Xq2VdVS6oL0b//fBUeNU3ncDnrFxh9uvUHmdRpZ/uYWyYssPERERRQorP0RERBQpSbHCc3XR9m+uq2TEbt18N3/UIzZedI3u+sn5fcU+VzJJWb5apc9c7JYS+Kz38/rYmW4F0fhNaQcuONvG3955oI3rTtWryDaCazZddo/euLHtHtf10n7iEpXH6e0VoKiC39UC122aX1i3hAOja9epfVX63dMf9KXK9p5NX9tBpZvNCdfVlZbtlgqZ0GdyCUeWzaz17VS6Gb4PODIast9w09l7ZA9RebuWu01PW+jF0ZE+262wXbjELS3TCXoYQVi/z7pApV/8dIqNz/mV7uaaW6Y7JAZbfoiIiChSWPkhIiKiSGHlh4iIiCKFY37KSOJWEi80brrgiAHvqbxpaFwZj5QU4pcrNye76eZ9L9E7qRvfpsGZE+eqvDo7VvhSKxAkrW0bG78z+D6VN2j8H23cej2ntlel1E6+cSO1awUfWIIdWe5zdGPz4D0rvvyyo0p3wqaAI2uewrp6J+7stLKN89lUuNPGuz7aLy433NiaFfc3tPERdfeUcGTZ7HfeGpUO3nM8GlI+ckuC7P9R+PMSPVJvR+cWKl1HyvZ5r2hs+SEiIqJIYeWHiIiIIoXdXhVg3MfHq3QOZgYcWfMV7dhh4xb/CO56KmuT9Y4n3arCJ3yiu9U6Peje9+D9jqk05FC37MBP3RuqvNxe7l3u2lvv3jwu263y2iatHhJtdG53d++H9ErGJe02XtM0maVXpf/VPLej9sc9p8QfHmhVoVvBt/V9wZ/btDb72zj3MV2uf+2YmF3k/Tq/cqWNc3bPSfj1qWwkzVUlhj/8qsqrJe5n9Etf91F5nVB1ZciWHyIiIooUVn6IiIgoUlj5ISIiokjhmJ9SSKlf38YF/bcGHt
f8y9TAPCqf3accqtLPdB5j46FjR6k8UxCl0R4Vo/CYg1X6romP27hPnfijgy342U13PW/u+TZOTdGjvbo03WDjx9p+GPr6OXXdOJ9Plkd364uCH1ap9Jbph9t4dbedKq+ksVfDx420ccujdgYed9x4N6d6RNPFoZ+zJOt90+yPef4GlZdzq9tygZ/vqpOa3kSl97zslqEY0mi2ylu6J9/GXe/IVXlVWYJs+SEiIqJIYeWHiIiIIoXdXqWw7KZeNv6m/z9U3ujcbjbOePEblRf1lUfLKzUjw8bjxz+s8k58wq3i3O5NruKcaM3u1FPW/V1dX/2sv7MvG+u6Smrl68UFWrzmVgVuusl1j6iVnwHMPsNNpcd1H4Z+zmPruSne916md5Vu/viM0Nepadre5T4TYwYdo/IeyPw8/nBrzvXjKuyZ4p2w8CyVXjujtY073KrLjktWOCkNGqh0nu+zkz51gcor2rat3Nf/uV8XG2/9Q57Km9H15cDrXHzd9Tauv+yLUj9HRWHLDxEREUUKKz9EREQUKaz8EBERUaRwzE8c/9Te5afVVnkzzrnfxinQ00Qf//QoG+dsi+52FhVh6Ui3S/fc3W1UXvZzbopzoncnjqrU7p1t/GC7iXG57vv+vPeuVDldJ7lxBkXZupyWjswp9l6XnzlNpUuaLl3kGz23vnC3ymsk7u+4VkNWqLw9j4OS2LYprVW6fYTHaJXGkr/2VOnFFz5q44GXnqbyNj/Xw8YtPl4feM1vr3HjK/90/FSV97sm0wPPW13gprMP9I3DBIB2ryXPOB8/tvwQERFRpLDyQ0RERJFSY7u9/Ksxo1N7lbf0vHQbjxr0hsr7TSM3xbNJil4pdnWBm2h5xOSrVF7XB5bYmN0v5edfQfS6s1wZ/X3Seeq4Nks4vT3h1m+y4ZStumnd3y31/ckT9HknJ/Yx4qfSX3HvNTbOeFR3jRQdeZCNa+VuT+yD1BBLTmmqv/Bl8cdVhL9s0Lt5v/pufxt3eGqWyuN09nA6PrtZpU8/7Nc2npLzospremd9lNYeo3+TXbv2MBu/Mb+XyssZ/7ON282qHj+T2fJDREREkcLKDxEREUUKKz9EREQUKVU25ufn/7pxOPm79fbQ8u/moa6xsZ/uk2zfwe0IfVDz1Ta+r9WzgddIgaj0Zt8wg5xpl6u8Lg/vsHGHuXrMAcf5JNaie93U6EvTvrJx+9c3qeP4vide4Sa38/LLtw9UeYMfmGfj/VKDdwWP9+BPbmn8x+YeGXicyXPLS3R55CeVl7EoeAp0yqdzbczvieIV5W1R6cPmuG1AvujzXMLvN3DB2Tauf4HeXiF7kytLjvEpm6J536r0brfaCi7oeYnKW3FmMxuPGaqXr6gre2z8249/a+O2r6eq4+q95pZw6YQ5Kq86liFbfoiIiChSWPkhIiKiSKmybq/3ur1q46L4RrPe4a4R32Xlv87ygl02vn7t0eq4adMOsXHmpwUqr8HcVTbOWTs77vpUUeJ3EB599BQb3/bYhTZuvbB6TKOsKRq+pFdnvfil4C6rsDriq30fBHZfJZrZrVfFzrzSrcp7bF+9Wvdlf3e7dA9uFLwi8KH3jLBxozW6xBp/tc7GBb6uVKp48V1i7VxvNcbc3jXwvJy47qyajC0/REREFCms/BAREVGksPJDREREkVJlY36Ou8xNIy+op+tgG87ZaeNmjXeovIuz3BTJB+cdp/JSFzW0cfu33dRKM2u+Oi4LwVNmCwJzqCKtvlIvl96l9v9s3PZtN72d40CIEqNg9Rob1/fFADD5lTYuRhsEaYngMXj8WUrJjC0/REREFCms/BAREVGkVFm3V5233U6+deLysl9GoFeR4Y7DvMDjquOKk1E26cqHVPo3L1xn4+yFwd2UREREpcWWHyIiIooUVn6IiIgoUlj5ISIiokipsjE/RH43Z/dV6ewSliMgIiIqD7b8EBERUaSw8kNERESRIsZwUjgRERFFB1t+iIiIKFJY+S
EiIqJIYeWHiIiIIoWVHyIiIooUVn6IiIgoUlj5ISIiokj5P0TtsIzRFepPAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 720x720 with 5 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xbv,ybv = next(iter(dls.train))\n",
"logits = model(xbv)\n",
"probs = F.softmax(logits, dim=1)\n",
"idx = 5\n",
"_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
"for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
" ax.imshow(im)\n",
" ax.set_axis_off()\n",
" ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"torch.save(model.state_dict(), 'classifier.pth')"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"loaded_model = cnn_classifier()\n",
"loaded_model.load_state_dict(torch.load('classifier.pth'));\n",
"loaded_model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaPklEQVR4nO3dd3hVRfoH8O+bRopEOpEiCCSAiiKIgg1RdBXrsmtDEVQUxMK6Cuy6P8W26C6yrqhrxQa6CIrYXRFsKAKiSBMQpKj0SO9J5vfHOZk5c5/ccJLcfr6f58nzvJM55565d3Jv5k47opQCERERUVCkxbsARERERLHExg8REREFChs/REREFChs/BAREVGgsPFDREREgcLGDxEREQVK3Bs/IrJKRHrGuxyxIiL3iMj4eJcjWlifqYN1mVpYn6mDdVlzcW/8VIWI/ElEfhKR7SKyVkQeEZGMKF+zv4jMiOY1PNdqICJfikixiGwVkZkicnIsrh0PItJDRD4RkW0isipG14xZfbrXe0ZElopImYj0j9V140VEskRkiYj8EoNrxbQuPdftJyJKRAbE+tqxIiIfiMhOz89+EVkQ5WvG8rP21JDnt9Ot0z/E4vqxJiKdRORz93luEJEhUb5erD9n00XkAbddsENEvhOROpWdE9HGT7QbIgDeAdBJKZUP4GgAxwK4NcrXjKWdAK4F0BBAXQD/APBODF7XCsXgursAPA9gaJSvE0/fAxgM4Nt4FiKGf0NDAWyM0bViTkTqAvgrgEVxLkdU61Mpda5S6pDyHwBfAZgUzWvGklLqi5Dndz6cz98PY12WGHyBbwDneT0NoD6ANgA+iuY14+BeACcB6AYgH0BfAHsrO+GgjR+3e+2vIrJYRLaIyAsiku3mnS4iv4jIcBFZD+AFEUkTkb+IyAq3B2OiiNTzPF5fEVnt5v2tKs9OKbVCKbW1/KEAlMGpyIMSkdYiMt297mYRecXbMhSR5iIyWUQ2ucc8LiLtATwFoJvbYt7qHvup91tfaCtXRB4VkZ/dHqq5InKqz+e3Vym1VClV5j6/UjiNoHqVn+lfgtXnbKXUOAA/VeN5JHx9us/xCaXUNBzkjVgdiVSX7vlHALgKwINVPC8p6tL1IIAxADZX8byDSrT69DxOSwCnAhjn8/hkqs9y/QC8rpTaVc3zLQlWl38G8D+l1CtKqX1KqR1KqR98Po+Er0txvpD8CcD1SqnVyrFQKVWzxo/rSgC/A9AaQBGA//PkFcD559wCwA1wemIuBtAdQBMAWwA84RbySABPwmmVNYHTCm3meRKnlL9Q4YhIHxHZDufD51g4rVk/BM4HVxMA7QE0B3CP+5jpAN4FsBpASwBNAUxw/0AGAZjpfkOo4/NacwB0hPO6vApgUvkffgXPZ76I9An9HZx/lm8DeE4pFelv0glTnzWQNPUZZYlUl48BuBPAnio+h6SoSxE5AcDxcD7YoyWR6rPc1QC+UEqt9Hl8UtSn5/e5AP4I4CWf1/QrUeqyK4DfROQrEdkoIu+IyOE+n0My1GUHACUA/igi60VkmYjcdNCrKaUq/QGwCsAgT7oXgBVufDqA/QCyPfk/ADjTkz4MwAEAGQDudl+c8rw89/yeBytHBeUqBHA/gIKqnuuefzGA79y4G4BNADIqOK4/gBkhv/sUwIDKjgk5fguAY934HgDjfZQvG8AVAPpV5/klU30C6AlgVQ2fV6LX5wwA/VO1LgH8HsCHnmv/kkp1CSAdwDcAulV0nVSrz5ByLa/J324i1mfIOX0BrAQgqViXAJYB2AqgC5z/K2MAfJkqdQmgDwAFYCyAHADHuOU6q7Ln4nes8WdPvBpOK7DcJmV3L7UA8KaIlHl+VwqgsXuefiyl1C4RKfZZBotS6kcRWQTgPwB6H+x4EWkEp9JPBVAbTq/XFje7OYDVSqmS6pSlgmvdDmAAnOer4IxBNqjKY7iv6X9F5AcRmaeU+j4SZXMlXH1WVb
LVZxTFvS5FJA/AP+F8wFdZktTlYADzlVIzI1GOSsS9Pr1E5BQ4vRSvV+GcZKhPr34AXlbuf9IISpS63APgTaXUHAAQkXsBbBaRQ5VS2yo7MUnqsryn+T6l1B4A80VkApzPo6nhTvI77NXcEx8OYK0nHfoH8zOAc5VSdTw/2UqpXwGs8z6W291Y32cZKpIBp0vRjwfdsh6jnAnTV8Hp0isv8+FS8cSzit4QuwDketIF5YE7TjkcwKUA6iqny2+b51pVlQmgVTXPDSdR67MqkrU+Iy0R6rIQTrf3F+4chskADnO7oFv6OD8Z6vJMAL93n9N6OJMrR4vI4z7OrYpEqE+vfgAmK6V2VuGcZKjP8sdoDqcn5mW/51RBotTl/JDrlcd+XqdkqMv5lVwzLL+Nn5tEpJk4E7DuBPBaJcc+BeDvItICAESkoYhc5Oa9DuB8d4wyC8B9VSgDRGSA2xItHwf9K4BpnvxPReSeMKfXhjObf6uINIW9wmg2nD+wh0QkT0SyxSwx3wCgmVvecvMA9BaRXBFpA+C6kOuUwO0OFJG74bRg/Ty/ruWvjYjkiMhwOC3/WX7Or4JEqc80d0w300lKtvd1Tvb6dJ9DlvscBUCmW5ZIrrJMhLpcCOfDuaP7MwDO69wR7jfWFKjL/nDmPJQ/x2/grDCp9kTiMBKhPuE+Xg6ASwC8WEFestdnub4AvlJKrajieX4kSl2+AKfh3lFEMgHcBWe4aat7raSuS7fuvgDwNxGpJc6E68vgzEcKy+8L+CqcpXE/uT8PVHLso3Am6n4kIjsAfA3gRLeQiwDc5D7eOjjdZ3o/EHH3XqjksU8GsEBEdgF43/2505PfHMCXYc69F0AnOK3J9+B8O4VbrlIAF8BZObbGLdNlbvZ0OMta14tI+QqPR+CMuW6AM0nuFc91/gfgAzjjrKvhTFz2dn9aRGSRiFzpJmvBmeRWDOBXON125yml1oY7v5oSpT5Pg9Nl+T6cb0Z7YC/BTPb6hPt89sDpKXjGjU8Ld341xL0ulVIlSqn15T8AfgNQ5qZL3cOSui6VUltDnuN+ANsPNmxQDXGvT4+L4dTJJxXkJXV9elyNyE90LpcQdamUmg7n/+R7cLahaANnnky5VKjLK+AMHRa75bxLOatsw5KDDXOKs/ncAKXUx5UeGGci0gzAJKVUt3iXJZGxPlMH6zK1sD5TB+sy8cVl87xoUEr9Amf2OaUA1mfqYF2mFtZn6ghyXSbV7S2IiIiIauqgw15EREREqYQ9P0RERBQobPwQERFRoPia8HxW2iUcG4uzqWWTIrKpHusy/iJVlwDrMxHwvZk6+N5MLZXVJ3t+iIiIKFBSZqk7ERERxV7GES10PGK6fRu4y9++RceFQ76OWZkOhj0/REREFChs/BAREVGgsPFDREREgcI5P0RERFRtW59M1/FxWXafSq3fErOPJTFLRURERBQlbPwQERFRoHDYi4iIiPw7oYOVHNN2rI43lO638lq+tVXHZVEtVNWw54eIiIgChY0fIiIiChQ2foiIiChQUmvOj2ccctk1OTp+85zHrMM6ZGXqOF3s9l/v5WfpeMddzay8tM++i0gxiYiIklXOqI1W+pgss9S9y0PDrLzG876KSZmqij0/REREFChs/BAREVGgJPWw17rbT7LS9w16Wcfn5W7T8du7GlrHzduXreO0kMV3r7X+UMfH3tLPymv+WfXLStF3+/JFVvrMnH06Dh3ePK/zOTouWbc+ugUjIkpym2/opuOPW4228qbtqavjJhOXW3ml0S1WtbHnh4iIiAKFjR8iIiIKFDZ+iIiIKFCSbs5PRkFjHY8YON7K887zOXL8zTouHLXMOq50c7GOJTPLyps4tYuOJx3/rJU3oM9tOs5/9euqFJti4ADSrXQZlIlVyMizUqDgKb6um53uWqLjomf32gfPXhCLIhElJO8cHwCYe8+TOl60X6y8Mef00nHphpXRLViEsOeHiIiIAoWNHyIiIgqUpB
v2WnFjax1fmPe+lXf24t46LnzYLLfzDnOFUgfsO9AuXnOYjosK7SGx4gt36zj/VZ8FpqhKyzbbFqQj/FDWUeNuttKtiudGrUxBlV6/no53n9Daysv5zGxDULZ7N2Jp37lmKHtLD3toa0mPZ3T8SY9DrLxH2rSPbsESmLcuMybbn4OFh5jdfRfcbHbV//X0POu4Wt3M5+6szpH/wDzlL/Z7us64mRG/RtDIcUfp+OO77OXspcp81t74pyFWXs7y2dEtWBSw54eIiIgChY0fIiIiChQ2foiIiChQkm7OT+YuE/93R2MrL+eKnTqubJ6PV8mZna30hFOf0vEPB+zl0S0eZ1sx3tLbtrHS+c//puMzc8LPJclfYadD53pRzf3at52OZw991Mq7d6N5n809LrrvI+l8lJU+aeQsHY9oZM/1+mC32ZZ/xONXW3kFSMy7UUdDxmEFVrp4rJn/9Fmb/4Y97+Nxi3XcM2eHlZfm+W5dFnIboUjY3Xubla4zLuKXCISMpk10fM74GTo+JK2WdVzhmzfquOhd+32UjBuH8L85ERERBQobP0RERBQoSTfs1fRfZknda5NOtvJKN6/29Rhl3Y/T8SNjn7Dy2mdm6vjoGddZeS1nzPNbTIqg9MaNdJzxzE4rb1zLqbEuDoWxp3H4zm/vcFOvMwbpOGN6ZLYc8A51DZ7wppV3bq4ZjgkdfBk212yP0Wb8EisvUe9GXW1i78qb3qCBjre9kGvlfdYh/FCXV+hQl19rS/bpONMuFhqm1wJFT+hdDZY8ZKaPTKnzjo7/UWwPHxfebIaPk3GYKxR7foiIiChQ2PghIiKiQGHjh4iIiAIl6eb8qBJzF+aSleHn+EgtM27849gjrbxZ3R8Pe17PQbfq+IgP59nX9ltIiqw6+Tp8o81rvk/r+u0VOm48ZbmVl3LzOeLhhA5W8o6L39JxWsj3qk/2mK3xIzHPR53c0UoPe9mscz4t297G4IAytd1r8aVW3hGXz9dxKv5NeLeGWDqogZW3+NLHavz4M/eaz9lr37/ezvTO5Qn58Gw90dTR6vOyrbwFfcf4unb69Dq+jiPbxuvs7V2WnGH+H64sMbd/mXmeva0I8Es0ixVz7PkhIiKiQGHjh4iIiAIlIYe90vPNMIfUr2vlbTzd7EZZd6m9o++P15pl6k+f/pKOe+SE7tRqulmvXHm2lZP3mVnuWspdgBPC2rMbHfwg2F22ALD/c9PNX7ppWUTLFFQlZ5gu8xNHz7Hy+uWbYei5++zvVSPuMttG5OPral17x2VddfyXB1628k7JNnUfupzdO9RV6+xV1bp2stpVaO7OXpVhruHru+l4wZYmVl7m0No6Tt9qtp4oXDkLfpWe3sk8RuudlRxpaz/Z3Mm98An/1wu67X3Me2fOXfb2Lr+W7tFx/+F36Lj2z9V7nyYL9vwQERFRoLDxQ0RERIHCxg8REREFSkLO+Vl1y9E6/n5wzZdjVuaVIz6y0kM/OVHHU1fZ23vXG2/udJz7JsebY+WWwZN9Hdf7yaFWuumo4NyVO1aGPG1ue+C9bQRgz7X5YMcxVl7dj8ycq+ouKd/TZ6uOf5e7Lexxz21rZaWDNs/Hr6UHTE3cvuISKy/rWhNnrF5j5XlXrZfAH9XtWCt923Ov6riyW2R4l9IDQNvntuu4rCwVNyeIjuxr1um4LGTfge7Thui4aEJqz/PxYs8PERERBQobP0RERBQoCTns1XJKsY779epp5eVl+Ft+/tUU083asMdaK2/1yoY67tlxsZV3eQPT7Teqqz209V0n07F/u9xs5eVO5jBYvB3+xnorzU7xmts/tYWV7p4z05Oy7w5998YuOl54vr08urTYfg/6sXbYSVZ6Ssd/elL2cIh3qOu9i08IeaQVVb52qshbZj5LO7x0q5VXMNu8Q3KmzLby/A5nVWb/OebvYfPAXVae37vBX/PJtVa66Ptval6wANhwi/3emXuk2cX5gc32zuxtBy3Usd+7GGQ0b2
alSxscqmMptT95y+YvQSJizw8REREFChs/REREFCgJOexVumipjotPtvOK4U8zeFb6PGjnFWGVju11DMDDx5hVDzcMy7HylvR4TsejR9s3R72j7CYdh3YhU9XtO9d0mbfOGhvHkqQ+ybSHr9SHZkftj9qFrrQzxz6zraWVM+84byr8MJf3Zpt7W9Sx84Zv0PG37eyVnpliVlt6b1YKAP8Zf4GOmy3jKr9ypcvMkN8Rd8Z2+G/jcWbH/W9PeKmSI21HTrxFx22HfWvl8ebS4aXl5em4zaXhd7Qf/0F3K91qnxnKTm9opoRsvNi+sam60Pz3HX3k61beydkHdLxb2VNTRmw4VceLhhxt5aXNmBe2nNHGnh8iIiIKFDZ+iIiIKFDY+CEiIqJAScg5P/HkXZbX9u+FVt60rrk6PjPHvqP8ukvNOGerKdEpWyo7cPbxVvry0e/r2DueHKr9pwN0XLQhuEuaa+LnO+zX/tt2j+o49A7pXltK8qz0+iGe5bViH1vrrE06/rLjBM/jh79CaM4Bz4SP0PlGLV9apeNILNOmqksvtHfWHt5voq/zQndxbjfG7EZccsDf1iYErP6z2d7l+1ahd0Ywb8jWr223cpY+abaGePdc894vCpkLmOZ5jDn77NlXnWb31fGzx46z8kYVmG1g5r5sbwlzX5ezdFxa/BtiiT0/REREFChs/BAREVGgcNirEqU//GilB7/fX8dL//AfK49LMGtmW8tMK33doaGbEDiWH9hnpetNzdZx6fbtoYeTD2MGPF2t84bXX2Slhw5b4PPM6n3nGunZmfbdR+3luvV+nRl6OMWAd6ir+2S7/q+obbYtCB3C7PC5Ga5u+qL93s9ayV2cq0N5hprTQsad08W859571x6W8tpeZv6TtZ12vZXX5C1TT3lv2MNXTWDulHDN3bdYeQsHmm1humbb733JsofWYok9P0RERBQobPwQERFRoLDxQ0RERIES3Tk/aek6TD/EXhabjPMzCr7yjKP+IX7lCJo1JXt0fPkjw6y8ghd5K4OaGjSrr5Ve2P3ZqF5vxEZzH4wRjeb6Pm/aiFN0XG8K5/gkglWXFuj4rXqTrLxMMZ//H+3OtvKajTX/ejI/5hyfSGg8x2wJUjYwZBaqMrOuykJmqM7eZ/6vDRt+m44LJ9nzevw6qdd8K+293oA1p9p523dU6xqRwJ4fIiIiChQ2foiIiChQojrsVdrd7Dj5t7HPW3kjL7lSx2quvWQ2UaQXtbbSw+4fH/ZYCZtDNbVgv+laL3iUw1yR1qrPPCt9Wt9bdbytjf2XnXXsFh0PLJph5T026QKE02qM2Tn9h4fM++reXt+FPafXkoutdM6U2WGPpdhZc7fZyfv96/6p4zLYOzV7h7pGDulv5dX6eE50ChdguV8u1fF1a3pYeSOamB3zD8/IsfKOzNyr4y1tzVBl7eOO8n3tH6+ureMXmzxs5a0sMZ8hv9xh/09N2zXP9zUijT0/REREFChs/BAREVGgsPFDREREgRLVOT8/9TbbYXerVWrlbeySr+OG/le7Rl1athmnXn9GIyvvvNxtYc+r90FO2DyqmHjuGryzRRwLQpY648wy8jqVHPc26lvpFgg/H8v77q/buKGOQ+/q/t0+830s61r7MXi39vhIy7O3KZl8rZnT0SSjVujh2sDp/XVc9B7n+ESbd/uYTde3s/KmTSrS8VW1V1l5h6SZOvz+Rs/d4G+0H997y4zQ5fJeB5R9y4rz7x+i4wYzEmeLCvb8EBERUaCw8UNERESBEtVhr7ZjPbs39rbznho+Rse3b7jZyst9s3o7S/qV3raNjpcObGDlnXaSWXb/dvPHEU7nOVdZ6WZvmfNKQw+mCqU3Mq/9wv7hX2tKftLZLJud1fllHYfe7bvPOzfpuHD119EuFoWx+YZuOm7bb4mV1yozM/TwChXdwKGueClbaNfZG+3NFI6nBl1k5eX3XqfjqUe94evx5+6z01d+OUDHh71lD3s1eD
1xhrq82PNDREREgcLGDxEREQUKGz9EREQUKFGd85O2bZeOp+3JtfLOzNmt43//6zErb8JdJ4Z9zNe/7qLj+t+lhz2uuLOZeXNUu5+tvAFNP9RxZcvXlx3Yb6WvHvlnHTebuNjKS8a71BPFyu7meRX+/pM9h1jpdk//pmPOnYudjJaHW+nTB5p5lyMLws/BXH7AbEBw6/LL7MfEmgiVjiKp4VMhc3CeMuH56Fytx2yD8LepSVTs+SEiIqJAYeOHiIiIAiWqw14lK1fr+LGzzrHybh7cRMev/fFRK29k42/CPubIizx5F4U9zLejPre3kc2fbobnGn220cprsMx0F7JLPnae7OvdJ2F+3MpB/qXn51vpEaPH6jhTzHD14P/1t44rXBzdbS7ISG9odtpOf8ke4q9sqOuXErPOefiqS3Sc0ZPDXJQ82PNDREREgcLGDxEREQUKGz9EREQUKFGd8+Plnf8DAK2HmvQd026y8lZdaO4eO+uCR6y8UZtO0XFlc4M6ze6r453r7OW07f9drOMjfgyZQ6LM3Wo5rycxpO01y2lDb4dACaqgoZWsnbZXx5tLzfySptNiViIKsfx2c5ufBW3GVHKk7bIHhuq4/rOJeesCooNhzw8REREFChs/REREFCgxG/aqTNaH9t1/i8wGzOg7+OSQo82wVGW7UTbB4rB5HM5KDKWbzfBju4n20GdZnqml9j8tjVmZKDJKl62w0utLDtXxw7+aO4bnTubS9ljZeYm9c/70PqM8qVphz7t/Uycr3XDCQh1zGJqSFXt+iIiIKFDY+CEiIqJAYeOHiIiIAiUh5vxQMKl9Zpv8Nrd9HfY4ztFKfk//7mwdq23b41iSYCk7paOOH/zH01Zew/Tw83yOef5WHbceZ9/mp2zHitDDiZIOe36IiIgoUNj4ISIiokDhsBcRRV3JT6viXYRAylq7VceTtxxv5Z1YYHZnHlXcwcpr/epmHYduW0CUCtjzQ0RERIHCxg8REREFChs/REREFCic80NElKK8c61+CLkb0IXoUsmZP0alPESJgj0/REREFChs/BAREVGgiFLq4EcRERERpQj2/BAREVGgsPFDREREgcLGDxEREQUKGz9EREQUKGz8EBERUaCw8UNERESB8v9D4PCrWQgE7QAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 720x720 with 5 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the first `idx` images of one batch from dls.train, titled with the\n",
"# class predicted by `loaded_model` and the matching entry of the batch targets.\n",
"with torch.no_grad():  # inference only, no autograd graph needed\n",
"    xbv, ybv = next(iter(dls.train))\n",
"    logits = loaded_model(xbv)\n",
"    probs = F.softmax(logits, dim=1)\n",
"    idx = 5  # number of images to display\n",
"    _, axs = plt.subplots(1, idx, figsize=(10, 10))\n",
"    # permute moves dim 1 to the end for imshow (presumably channels -- TODO confirm)\n",
"    for actual, pred, im, ax in zip(ybv[:idx], probs[:idx], xbv.permute(0, 2, 3, 1)[:idx], axs.flat):\n",
"        ax.imshow(im)\n",
"        ax.set_axis_off()\n",
"        ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaPklEQVR4nO3dd3hVRfoH8O+bRopEOpEiCCSAiiKIgg1RdBXrsmtDEVQUxMK6Cuy6P8W26C6yrqhrxQa6CIrYXRFsKAKiSBMQpKj0SO9J5vfHOZk5c5/ccJLcfr6f58nzvJM55565d3Jv5k47opQCERERUVCkxbsARERERLHExg8REREFChs/REREFChs/BAREVGgsPFDREREgcLGDxEREQVK3Bs/IrJKRHrGuxyxIiL3iMj4eJcjWlifqYN1mVpYn6mDdVlzcW/8VIWI/ElEfhKR7SKyVkQeEZGMKF+zv4jMiOY1PNdqICJfikixiGwVkZkicnIsrh0PItJDRD4RkW0isipG14xZfbrXe0ZElopImYj0j9V140VEskRkiYj8EoNrxbQuPdftJyJKRAbE+tqxIiIfiMhOz89+EVkQ5WvG8rP21JDnt9Ot0z/E4vqxJiKdRORz93luEJEhUb5erD9n00XkAbddsENEvhOROpWdE9HGT7QbIgDeAdBJKZUP4GgAxwK4NcrXjKWdAK4F0BBAXQD/APBODF7XCsXgursAPA9gaJSvE0/fAxgM4Nt4FiKGf0NDAWyM0bViTkTqAvgrgEVxLkdU61Mpda5S6pDyHwBfAZgUzWvGklLqi5Dndz6cz98PY12WGHyBbwDneT0NoD6ANgA+iuY14+BeACcB6AYgH0BfAHsrO+GgjR+3e+2vIrJYRLaIyAsiku3mnS4iv4jIcBFZD+AFEUkTkb+IyAq3B2OiiNTzPF5fEVnt5v2tKs9OKbVCKbW1/KEAlMGpyIMSkdYiMt297mYRecXbMhSR5iIyWUQ2ucc8LiLtATwFoJvbYt7qHvup91tfaCtXRB4VkZ/dHqq5InKqz+e3Vym1VClV5j6/UjiNoHqVn+lfgtXnbKXUOAA/VeN5JHx9us/xCaXUNBzkjVgdiVSX7vlHALgKwINVPC8p6tL1IIAxADZX8byDSrT69DxOSwCnAhjn8/hkqs9y/QC8rpTaVc3zLQlWl38G8D+l1CtKqX1KqR1KqR98Po+Er0txvpD8CcD1SqnVyrFQKVWzxo/rSgC/A9AaQBGA//PkFcD559wCwA1wemIuBtAdQBMAWwA84RbySABPwmmVNYHTCm3meRKnlL9Q4YhIHxHZDufD51g4rVk/BM4HVxMA7QE0B3CP+5jpAN4FsBpASwBNAUxw/0AGAZjpfkOo4/NacwB0hPO6vApgUvkffgXPZ76I9An9HZx/lm8DeE4pFelv0glTnzWQNPUZZYlUl48BuBPAnio+h6SoSxE5AcDxcD7YoyWR6rPc1QC+UEqt9Hl8UtSn5/e5AP4I4CWf1/QrUeqyK4DfROQrEdkoIu+IyOE+n0My1GUHACUA/igi60VkmYjcdNCrKaUq/QGwCsAgT7oXgBVufDqA/QCyPfk/ADjTkz4MwAEAGQDudl+c8rw89/yeBytHBeUqBHA/gIKqnuuefzGA79y4G4BNADIqOK4/gBkhv/sUwIDKjgk5fguAY934HgDjfZQvG8AVAPpV5/klU30C6AlgVQ2fV6LX5wwA/VO1LgH8HsCHnmv/kkp1CSAdwDcAulV0nVSrz5ByLa/J324i1mfIOX0BrAQgqViXAJYB2AqgC5z/K2MAfJkqdQmgDwAFYCyAHADHuOU6q7Ln4nes8WdPvBpOK7DcJmV3L7UA8KaIlHl+VwqgsXuefiyl1C4RKfZZBotS6kcRWQTgPwB6H+x4EWkEp9JPBVAbTq/XFje7OYDVSqmS6pSlgmvdDmAAnOer4IxBNqjKY7iv6X9F5AcRmaeU+j4SZXMlXH1WVb
LVZxTFvS5FJA/AP+F8wFdZktTlYADzlVIzI1GOSsS9Pr1E5BQ4vRSvV+GcZKhPr34AXlbuf9IISpS63APgTaXUHAAQkXsBbBaRQ5VS2yo7MUnqsryn+T6l1B4A80VkApzPo6nhTvI77NXcEx8OYK0nHfoH8zOAc5VSdTw/2UqpXwGs8z6W291Y32cZKpIBp0vRjwfdsh6jnAnTV8Hp0isv8+FS8cSzit4QuwDketIF5YE7TjkcwKUA6iqny2+b51pVlQmgVTXPDSdR67MqkrU+Iy0R6rIQTrf3F+4chskADnO7oFv6OD8Z6vJMAL93n9N6OJMrR4vI4z7OrYpEqE+vfgAmK6V2VuGcZKjP8sdoDqcn5mW/51RBotTl/JDrlcd+XqdkqMv5lVwzLL+Nn5tEpJk4E7DuBPBaJcc+BeDvItICAESkoYhc5Oa9DuB8d4wyC8B9VSgDRGSA2xItHwf9K4BpnvxPReSeMKfXhjObf6uINIW9wmg2nD+wh0QkT0SyxSwx3wCgmVvecvMA9BaRXBFpA+C6kOuUwO0OFJG74bRg/Ty/ruWvjYjkiMhwOC3/WX7Or4JEqc80d0w300lKtvd1Tvb6dJ9DlvscBUCmW5ZIrrJMhLpcCOfDuaP7MwDO69wR7jfWFKjL/nDmPJQ/x2/grDCp9kTiMBKhPuE+Xg6ASwC8WEFestdnub4AvlJKrajieX4kSl2+AKfh3lFEMgHcBWe4aat7raSuS7fuvgDwNxGpJc6E68vgzEcKy+8L+CqcpXE/uT8PVHLso3Am6n4kIjsAfA3gRLeQiwDc5D7eOjjdZ3o/EHH3XqjksU8GsEBEdgF43/2505PfHMCXYc69F0AnOK3J9+B8O4VbrlIAF8BZObbGLdNlbvZ0OMta14tI+QqPR+CMuW6AM0nuFc91/gfgAzjjrKvhTFz2dn9aRGSRiFzpJmvBmeRWDOBXON125yml1oY7v5oSpT5Pg9Nl+T6cb0Z7YC/BTPb6hPt89sDpKXjGjU8Ld341xL0ulVIlSqn15T8AfgNQ5qZL3cOSui6VUltDnuN+ANsPNmxQDXGvT4+L4dTJJxXkJXV9elyNyE90LpcQdamUmg7n/+R7cLahaANnnky5VKjLK+AMHRa75bxLOatsw5KDDXOKs/ncAKXUx5UeGGci0gzAJKVUt3iXJZGxPlMH6zK1sD5TB+sy8cVl87xoUEr9Amf2OaUA1mfqYF2mFtZn6ghyXSbV7S2IiIiIauqgw15EREREqYQ9P0RERBQobPwQERFRoPia8HxW2iUcG4uzqWWTIrKpHusy/iJVlwDrMxHwvZk6+N5MLZXVJ3t+iIiIKFBSZqk7ERERxV7GES10PGK6fRu4y9++RceFQ76OWZkOhj0/REREFChs/BAREVGgsPFDREREgcI5P0RERFRtW59M1/FxWXafSq3fErOPJTFLRURERBQlbPwQERFRoHDYi4iIiPw7oYOVHNN2rI43lO638lq+tVXHZVEtVNWw54eIiIgChY0fIiIiChQ2foiIiChQUmvOj2ccctk1OTp+85zHrMM6ZGXqOF3s9l/v5WfpeMddzay8tM++i0gxiYiIklXOqI1W+pgss9S9y0PDrLzG876KSZmqij0/REREFChs/BAREVGgJPWw17rbT7LS9w16Wcfn5W7T8du7GlrHzduXreO0kMV3r7X+UMfH3tLPymv+WfXLStF3+/JFVvrMnH06Dh3ePK/zOTouWbc+ugUjIkpym2/opuOPW4228qbtqavjJhOXW3ml0S1WtbHnh4iIiAKFjR8iIiIKFDZ+iIiIKFCSbs5PRkFjHY8YON7K887zOXL8zTouHLXMOq50c7GOJTPLyps4tYuOJx3/rJU3oM9tOs5/9euqFJti4ADSrXQZlIlVyMizUqDgKb6um53uWqLjomf32gfPXhCLIhElJO8cHwCYe8+TOl60X6y8Mef00nHphpXRLViEsOeHiIiIAoWNHyIiIgqUpB
v2WnFjax1fmPe+lXf24t46LnzYLLfzDnOFUgfsO9AuXnOYjosK7SGx4gt36zj/VZ8FpqhKyzbbFqQj/FDWUeNuttKtiudGrUxBlV6/no53n9Daysv5zGxDULZ7N2Jp37lmKHtLD3toa0mPZ3T8SY9DrLxH2rSPbsESmLcuMybbn4OFh5jdfRfcbHbV//X0POu4Wt3M5+6szpH/wDzlL/Z7us64mRG/RtDIcUfp+OO77OXspcp81t74pyFWXs7y2dEtWBSw54eIiIgChY0fIiIiChQ2foiIiChQkm7OT+YuE/93R2MrL+eKnTqubJ6PV8mZna30hFOf0vEPB+zl0S0eZ1sx3tLbtrHS+c//puMzc8LPJclfYadD53pRzf3at52OZw991Mq7d6N5n809LrrvI+l8lJU+aeQsHY9oZM/1+mC32ZZ/xONXW3kFSMy7UUdDxmEFVrp4rJn/9Fmb/4Y97+Nxi3XcM2eHlZfm+W5dFnIboUjY3Xubla4zLuKXCISMpk10fM74GTo+JK2WdVzhmzfquOhd+32UjBuH8L85ERERBQobP0RERBQoSTfs1fRfZknda5NOtvJKN6/29Rhl3Y/T8SNjn7Dy2mdm6vjoGddZeS1nzPNbTIqg9MaNdJzxzE4rb1zLqbEuDoWxp3H4zm/vcFOvMwbpOGN6ZLYc8A51DZ7wppV3bq4ZjgkdfBk212yP0Wb8EisvUe9GXW1i78qb3qCBjre9kGvlfdYh/FCXV+hQl19rS/bpONMuFhqm1wJFT+hdDZY8ZKaPTKnzjo7/UWwPHxfebIaPk3GYKxR7foiIiChQ2PghIiKiQGHjh4iIiAIl6eb8qBJzF+aSleHn+EgtM27849gjrbxZ3R8Pe17PQbfq+IgP59nX9ltIiqw6+Tp8o81rvk/r+u0VOm48ZbmVl3LzOeLhhA5W8o6L39JxWsj3qk/2mK3xIzHPR53c0UoPe9mscz4t297G4IAytd1r8aVW3hGXz9dxKv5NeLeGWDqogZW3+NLHavz4M/eaz9lr37/ezvTO5Qn58Gw90dTR6vOyrbwFfcf4unb69Dq+jiPbxuvs7V2WnGH+H64sMbd/mXmeva0I8Es0ixVz7PkhIiKiQGHjh4iIiAIlIYe90vPNMIfUr2vlbTzd7EZZd6m9o++P15pl6k+f/pKOe+SE7tRqulmvXHm2lZP3mVnuWspdgBPC2rMbHfwg2F22ALD/c9PNX7ppWUTLFFQlZ5gu8xNHz7Hy+uWbYei5++zvVSPuMttG5OPral17x2VddfyXB1628k7JNnUfupzdO9RV6+xV1bp2stpVaO7OXpVhruHru+l4wZYmVl7m0No6Tt9qtp4oXDkLfpWe3sk8RuudlRxpaz/Z3Mm98An/1wu67X3Me2fOXfb2Lr+W7tFx/+F36Lj2z9V7nyYL9vwQERFRoLDxQ0RERIHCxg8REREFSkLO+Vl1y9E6/n5wzZdjVuaVIz6y0kM/OVHHU1fZ23vXG2/udJz7JsebY+WWwZN9Hdf7yaFWuumo4NyVO1aGPG1ue+C9bQRgz7X5YMcxVl7dj8ycq+ouKd/TZ6uOf5e7Lexxz21rZaWDNs/Hr6UHTE3cvuISKy/rWhNnrF5j5XlXrZfAH9XtWCt923Ov6riyW2R4l9IDQNvntuu4rCwVNyeIjuxr1um4LGTfge7Thui4aEJqz/PxYs8PERERBQobP0RERBQoCTns1XJKsY779epp5eVl+Ft+/tUU083asMdaK2/1yoY67tlxsZV3eQPT7Teqqz209V0n07F/u9xs5eVO5jBYvB3+xnorzU7xmts/tYWV7p4z05Oy7w5998YuOl54vr08urTYfg/6sXbYSVZ6Ssd/elL2cIh3qOu9i08IeaQVVb52qshbZj5LO7x0q5VXMNu8Q3KmzLby/A5nVWb/OebvYfPAXVae37vBX/PJtVa66Ptval6wANhwi/3emXuk2cX5gc32zuxtBy3Usd+7GGQ0b2
alSxscqmMptT95y+YvQSJizw8REREFChs/REREFCgJOexVumipjotPtvOK4U8zeFb6PGjnFWGVju11DMDDx5hVDzcMy7HylvR4TsejR9s3R72j7CYdh3YhU9XtO9d0mbfOGhvHkqQ+ybSHr9SHZkftj9qFrrQzxz6zraWVM+84byr8MJf3Zpt7W9Sx84Zv0PG37eyVnpliVlt6b1YKAP8Zf4GOmy3jKr9ypcvMkN8Rd8Z2+G/jcWbH/W9PeKmSI21HTrxFx22HfWvl8ebS4aXl5em4zaXhd7Qf/0F3K91qnxnKTm9opoRsvNi+sam60Pz3HX3k61beydkHdLxb2VNTRmw4VceLhhxt5aXNmBe2nNHGnh8iIiIKFDZ+iIiIKFDY+CEiIqJAScg5P/HkXZbX9u+FVt60rrk6PjPHvqP8ukvNOGerKdEpWyo7cPbxVvry0e/r2DueHKr9pwN0XLQhuEuaa+LnO+zX/tt2j+o49A7pXltK8qz0+iGe5bViH1vrrE06/rLjBM/jh79CaM4Bz4SP0PlGLV9apeNILNOmqksvtHfWHt5voq/zQndxbjfG7EZccsDf1iYErP6z2d7l+1ahd0Ywb8jWr223cpY+abaGePdc894vCpkLmOZ5jDn77NlXnWb31fGzx46z8kYVmG1g5r5sbwlzX5ezdFxa/BtiiT0/REREFChs/BAREVGgcNirEqU//GilB7/fX8dL//AfK49LMGtmW8tMK33doaGbEDiWH9hnpetNzdZx6fbtoYeTD2MGPF2t84bXX2Slhw5b4PPM6n3nGunZmfbdR+3luvV+nRl6OMWAd6ir+2S7/q+obbYtCB3C7PC5Ga5u+qL93s9ayV2cq0N5hprTQsad08W859571x6W8tpeZv6TtZ12vZXX5C1TT3lv2MNXTWDulHDN3bdYeQsHmm1humbb733JsofWYok9P0RERBQobPwQERFRoLDxQ0RERIES3Tk/aek6TD/EXhabjPMzCr7yjKP+IX7lCJo1JXt0fPkjw6y8ghd5K4OaGjSrr5Ve2P3ZqF5vxEZzH4wRjeb6Pm/aiFN0XG8K5/gkglWXFuj4rXqTrLxMMZ//H+3OtvKajTX/ejI/5hyfSGg8x2wJUjYwZBaqMrOuykJmqM7eZ/6vDRt+m44LJ9nzevw6qdd8K+293oA1p9p523dU6xqRwJ4fIiIiChQ2foiIiChQojrsVdrd7Dj5t7HPW3kjL7lSx2quvWQ2UaQXtbbSw+4fH/ZYCZtDNbVgv+laL3iUw1yR1qrPPCt9Wt9bdbytjf2XnXXsFh0PLJph5T026QKE02qM2Tn9h4fM++reXt+FPafXkoutdM6U2WGPpdhZc7fZyfv96/6p4zLYOzV7h7pGDulv5dX6eE50ChdguV8u1fF1a3pYeSOamB3zD8/IsfKOzNyr4y1tzVBl7eOO8n3tH6+ureMXmzxs5a0sMZ8hv9xh/09N2zXP9zUijT0/REREFChs/BAREVGgsPFDREREgRLVOT8/9TbbYXerVWrlbeySr+OG/le7Rl1athmnXn9GIyvvvNxtYc+r90FO2DyqmHjuGryzRRwLQpY648wy8jqVHPc26lvpFgg/H8v77q/buKGOQ+/q/t0+830s61r7MXi39vhIy7O3KZl8rZnT0SSjVujh2sDp/XVc9B7n+ESbd/uYTde3s/KmTSrS8VW1V1l5h6SZOvz+Rs/d4G+0H997y4zQ5fJeB5R9y4rz7x+i4wYzEmeLCvb8EBERUaCw8UNERESBEtVhr7ZjPbs39rbznho+Rse3b7jZyst9s3o7S/qV3raNjpcObGDlnXaSWXb/dvPHEU7nOVdZ6WZvmfNKQw+mCqU3Mq/9wv7hX2tKftLZLJud1fllHYfe7bvPOzfpuHD119EuFoWx+YZuOm7bb4mV1yozM/TwChXdwKGueClbaNfZG+3NFI6nBl1k5eX3XqfjqUe94evx5+6z01d+OUDHh71lD3s1eD
1xhrq82PNDREREgcLGDxEREQUKGz9EREQUKFGd85O2bZeOp+3JtfLOzNmt43//6zErb8JdJ4Z9zNe/7qLj+t+lhz2uuLOZeXNUu5+tvAFNP9RxZcvXlx3Yb6WvHvlnHTebuNjKS8a71BPFyu7meRX+/pM9h1jpdk//pmPOnYudjJaHW+nTB5p5lyMLws/BXH7AbEBw6/LL7MfEmgiVjiKp4VMhc3CeMuH56Fytx2yD8LepSVTs+SEiIqJAYeOHiIiIAiWqw14lK1fr+LGzzrHybh7cRMev/fFRK29k42/CPubIizx5F4U9zLejPre3kc2fbobnGn220cprsMx0F7JLPnae7OvdJ2F+3MpB/qXn51vpEaPH6jhTzHD14P/1t44rXBzdbS7ISG9odtpOf8ke4q9sqOuXErPOefiqS3Sc0ZPDXJQ82PNDREREgcLGDxEREQUKGz9EREQUKFGd8+Plnf8DAK2HmvQd026y8lZdaO4eO+uCR6y8UZtO0XFlc4M6ze6r453r7OW07f9drOMjfgyZQ6LM3Wo5rycxpO01y2lDb4dACaqgoZWsnbZXx5tLzfySptNiViIKsfx2c5ufBW3GVHKk7bIHhuq4/rOJeesCooNhzw8REREFChs/REREFCgxG/aqTNaH9t1/i8wGzOg7+OSQo82wVGW7UTbB4rB5HM5KDKWbzfBju4n20GdZnqml9j8tjVmZKDJKl62w0utLDtXxw7+aO4bnTubS9ljZeYm9c/70PqM8qVphz7t/Uycr3XDCQh1zGJqSFXt+iIiIKFDY+CEiIqJAYeOHiIiIAiUh5vxQMKl9Zpv8Nrd9HfY4ztFKfk//7mwdq23b41iSYCk7paOOH/zH01Zew/Tw83yOef5WHbceZ9/mp2zHitDDiZIOe36IiIgoUNj4ISIiokDhsBcRRV3JT6viXYRAylq7VceTtxxv5Z1YYHZnHlXcwcpr/epmHYduW0CUCtjzQ0RERIHCxg8REREFChs/REREFCic80NElKK8c61+CLkb0IXoUsmZP0alPESJgj0/REREFChs/BAREVGgiFLq4EcRERERpQj2/BAREVGgsPFDREREgcLGDxEREQUKGz9EREQUKGz8EBERUaCw8UNERESB8v9D4PCrWQgE7QAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 720x720 with 5 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the first `idx` images of the batch `xbv` with the predictions of the\n",
"# in-memory `model`. `xbv` and `ybv` must already exist in the kernel.\n",
"with torch.no_grad():  # inference only, no autograd graph needed\n",
"    logits = model(xbv)\n",
"    probs = F.softmax(logits, dim=1)\n",
"idx = 5  # number of images to display\n",
"_, axs = plt.subplots(1, idx, figsize=(10, 10))\n",
"# permute moves dim 1 to the end for imshow (presumably channels -- TODO confirm)\n",
"for actual, pred, im, ax in zip(ybv[:idx], probs[:idx], xbv.permute(0, 2, 3, 1)[:idx], axs.flat):\n",
"    ax.imshow(im)\n",
"    ax.set_axis_off()\n",
"    ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"def predict(img):\n",
"    \"\"\"Run `loaded_model` on a single image and report every class score.\n",
"\n",
"    Args:\n",
"        img: one un-batched image tensor; a leading batch dimension is added\n",
"            here before the forward pass (the demo call in this notebook\n",
"            passes shape (1, 28, 28) -- TODO confirm what other callers pass).\n",
"\n",
"    Returns:\n",
"        A list with one dict per model output, ordered by class index, with\n",
"        keys \"digit\" (int index), \"prob\" (softmax probability formatted as a\n",
"        percentage string with two decimals) and \"logits\" (the raw score as\n",
"        a 0-d tensor).\n",
"    \"\"\"\n",
"    with torch.no_grad():\n",
"        logits = loaded_model(img[None,])[0]\n",
"        probs = F.softmax(logits, dim=0)\n",
"    # enumerate() already yields ascending class indices, so no sort is needed.\n",
"    return [\n",
"        {\"digit\": digit, \"prob\": f'{prob*100:.2f}%', 'logits': logits[digit]}\n",
"        for digit, prob in enumerate(probs)\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(3)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
]
},
{
"data": {
"text/plain": [
"[{'digit': 0, 'prob': '0.00%', 'logits': tensor(-5.5980)},\n",
" {'digit': 1, 'prob': '0.00%', 'logits': tensor(-0.4972)},\n",
" {'digit': 2, 'prob': '0.02%', 'logits': tensor(1.2516)},\n",
" {'digit': 3, 'prob': '99.95%', 'logits': tensor(9.9263)},\n",
" {'digit': 4, 'prob': '0.00%', 'logits': tensor(-5.5094)},\n",
" {'digit': 5, 'prob': '0.01%', 'logits': tensor(0.2367)},\n",
" {'digit': 6, 'prob': '0.00%', 'logits': tensor(-9.4633)},\n",
" {'digit': 7, 'prob': '0.00%', 'logits': tensor(-2.4315)},\n",
" {'digit': 8, 'prob': '0.02%', 'logits': tensor(1.4733)},\n",
" {'digit': 9, 'prob': '0.00%', 'logits': tensor(-0.0205)}]"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Try predict() on one sample; the entry of yb printed first is presumably its label.\n",
"print(yb[1])\n",
"img = xb[1].reshape(1, 28, 28)\n",
"predict(img)"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": [
"exclude"
]
},
"source": [
"#### Export the notebook to a .py script for deployment"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
"[NbConvertApp] Writing 2920 bytes to mnist_classifier.py\n"
]
}
],
"source": [
"# Export this notebook as mnist_classifier.py, leaving out every cell tagged \"exclude\".\n",
"!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.1046) tensor(1.) tensor(0.) tensor(0.3062)\n",
"tensor(0.1435) tensor(1.) tensor(0.) tensor(0.3220)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFjklEQVR4nO3dz6tUdRjH8fcn8wfZJsvCVNKFRK4KJItaBCHd3Ngm0CBcCG4KClqk9Q+4atdGSHQRRlRgC0FKiggilJDS5KoF5k3JhCBXofC4mENM0713zjxn5syZOZ8XXGbOd+5wvosPz/me78x9riICs0HdNe4J2GRycCzFwbEUB8dSHBxLcXAspVJwJM1ImpV0SdK+YU3Kmk/ZfRxJS4ALwDZgDjgF7IqIn4c3PWuquyu890ngUkT8CiDpI2AHsGBwlml5rGBlhVNa3W7y142IWN07XiU4a4ErXcdzwNbF3rCClWzV8xVOaXX7Mj65PN94leBonrH/Xfck7QX2AqzgngqnsyapsjieA9Z3Ha8Drvb+UkQcjIgtEbFlKcsrnM6apErFOQVskrQR+B3YCbwylFlNmRNXzyz6+gsPP17LPIYpHZyIuC3pdeAEsAQ4FBHnhjYza7QqFYeIOA4cH9JcbIJ459hSKlUcm1+/Nc00cMWxFAfHUhwcS/EaZwwmcd+mlyuOpTg4luLgWIrXOEMwjZ9F9eOKYykOjqX4UpXUho8VFuOKYykOjqU4OJbiNU5Jg6xppvH2u5crjqU4OJbi4FiK1zhD0IY1TS9XHEtxcCzFwbEUr3EW0MavSgzCFcdSHBxLcXAsxWucQtu/XzMoVxxL6RscSYckXZd0tmtslaQvJF0sHu8b7TStacpUnMPATM/YPuBkRGwCThbH1iJ91zgR8Y2kDT3DO4DniudHgK+Bt4c5saZp+75Nr+wa56GIuAZQPD44vCnZJBj5XZXb1U6nbMX5Q9IagOLx+kK/6Ha10ykbnM+B3cXz3cCx4UzHJkWZ2/GjwHfAo5LmJO0BDgDbJF2k809ADox2mtY0Ze6qdi3wkv8pQ4t559hS/FlVUtv/zsoVx1IcHEtxcCzFa5ySqnxfp/e907DmccWxFAfHUnypGoLeS08bvobqimMpDo6lODiW4jVO0mK31P3WPNNwe+6KYykOjqU4OJbi4FiKg2MpDo6lODiW4n2ckiZxr2WUXHEsxcGxFAfHUhwcS3FwLMXBsRQHx1K8j1PwPs1gXHEspUx/nPWSvpJ0XtI5SW8U425Z22JlKs5t4K2IeAx4CnhN0mbcsrbV+gYnIq5FxA/F85vAeWAtnZa1R4pfOwK8NKI5WgMNtMYp+h0/AXyPW9a2WungSLoX+BR4MyL+HuB9eyWdlnT6Fv9k5mgNVCo4kpbSCc2HEfFZMVyqZa3b1U6nMndVAj4AzkfEe10vuWVti5XZAHwGeBX4SdKZYuwdOi1qPy7a1/4GvDySGVojlWlX+y2gBV52y9qW8s6xpTg4luLgWIqDYykOjqU4OJbi4FiKg2Mp/uroCLhdrdkCHBxLcXAsxWucMZiGP8VxxbEUB8dSHBxL8RpnCPrt20zDmqaXK46lODiW4uBYitc4SW34PGoxrjiW4uBYioNjKYqI+k4m/QlcBh4AbtR24sF4bv/1SESs7h2sNTj/nlQ6HRFbaj9xCZ5bOb5UWYqDYynjCs7BMZ23DM+thLGscWzy+VJlKbUGR9KMpFlJlySNtb2tpEOSrks62zXWiN7Nk9BburbgSFoCvA+8CGwGdhX9ksflMDDTM9aU3s3N7y0dEbX8AE8DJ7qO9wP76zr/AnPaAJztOp4F1hTP1wCz45xf17yOAduaNL86L1VrgStdx3PFWJM0rndzU3tL1xmc+foI+pZuEdne0nWoMzhzwPqu43XA1RrPX0ap3s11qNJbug51BucUsEnSRknLgJ10eiU3SSN6N09Eb+maF3nbgQvAL8C7Y15wHgWuAbfoVM
M9wP107lYuFo+rxjS3Z+lcxn8EzhQ/25syv4jwzrHleOfYUhwcS3FwLMXBsRQHx1IcHEtxcCzFwbGUO+5QVpdhOxaLAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from PIL import Image\n",
"import numpy as np\n",
"import glob\n",
"import os\n",
"\n",
"# Compare the pixel statistics of one uploaded PNG with those of xb[0].\n",
"# glob returns files in arbitrary order, so sort to make the chosen file reproducible.\n",
"png_files = sorted(glob.glob(os.path.join('./uploads', '*.png')))\n",
"png_path = png_files[1]  # second upload; raises IndexError if fewer than two PNGs exist\n",
"with Image.open(png_path) as pil_image:  # release the file handle once converted\n",
"    # permute moves dim 0 to the end for imshow (presumably C,H,W -> H,W,C -- TODO confirm)\n",
"    image = TF.to_tensor(pil_image).permute(1, 2, 0)\n",
"print(image.mean(), image.max(), image.min(), image.std())\n",
"print(xb[0].mean(), xb[0].max(), xb[0].min(), xb[0].std())\n",
"plt.imshow(image);"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python_main",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
|