Spaces:
Runtime error
Runtime error
File size: 2,863 Bytes
0bbec58 338bbe8 0bbec58 8e35bc7 9490409 8e35bc7 0bbec58 9490409 8e35bc7 0bbec58 5993d2f 056ab4f 0bbec58 c32023c 338bbe8 c32023c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
#!/usr/bin/env python
# coding: utf-8
import torch
from torch import nn
import torch.nn.functional as F
from datasets import load_dataset
import fastcore.all as fc
import torchvision.transforms.functional as TF
from torch.utils.data import default_collate, DataLoader
import torch.optim as optim
def transform_ds(b, key='image'):
    """Convert the images stored under ``b[key]`` to tensors, in place.

    Presumably used as a batched ``datasets`` transform (e.g. via
    ``Dataset.with_transform``) -- confirm against the caller.

    Args:
        b: a batch dict mapping column names to lists of values.
        key: name of the image column. The original code read an undefined
            global ``x`` (a NameError on every call); ``'image'`` matches the
            column name used by the HF ``mnist`` / ``fashion_mnist`` datasets
            -- pass ``key`` explicitly for other schemas.

    Returns:
        The same dict ``b``, with ``b[key]`` replaced by a list of tensors
        produced by ``torchvision.transforms.functional.to_tensor``.
    """
    b[key] = [TF.to_tensor(ele) for ele in b[key]]
    return b
# Default batch size intended for building the train/valid DataLoaders.
bs = 2 ** 10  # 1024
class DataLoaders:
    """Pair of DataLoaders (``train`` shuffled, ``valid`` in order) sharing one config.

    Args:
        train_ds: dataset for the shuffled training loader.
        valid_ds: dataset for the non-shuffled validation loader.
        bs: batch size used by both loaders.
        collate_fn: collate function used by both loaders.
        **kwargs: any further ``DataLoader`` options, forwarded to both.
    """

    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
        shared = dict(batch_size=bs, collate_fn=collate_fn, **kwargs)
        self.train = DataLoader(train_ds, shuffle=True, **shared)
        self.valid = DataLoader(valid_ds, shuffle=False, **shared)
def collate_fn(b, x_key='image', y_key='label'):
    """Collate a list of example dicts into an ``(inputs, targets)`` tuple.

    The original code indexed the collated batch with undefined globals
    ``x`` / ``y`` (a NameError on every call); the keys are now explicit,
    defaulted parameters, so ``DataLoader`` can still call ``collate_fn(batch)``
    with a single argument.

    Args:
        b: list of per-example dicts, stacked via ``default_collate``.
        x_key: dict key holding the model input (default ``'image'``).
        y_key: dict key holding the target (default ``'label'``).

    Returns:
        ``(batch[x_key], batch[y_key])`` from the collated batch.
    """
    batch = default_collate(b)
    return batch[x_key], batch[y_key]
class Reshape(nn.Module):
    """Layer that reshapes every input to one fixed target shape."""

    def __init__(self, dim):
        super().__init__()
        # Target shape, handed unchanged to the input's ``.reshape``.
        self.dim = dim

    def forward(self, x):
        target_shape = self.dim
        return x.reshape(target_shape)
def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
    """Build ``Conv2d -> [norm] -> [act]`` wrapped in an ``nn.Sequential``.

    Args:
        ni: input channel count.
        nf: output channel count.
        ks: kernel size; padding is ``ks // 2`` (size-preserving at stride 1
            for odd ``ks``).
        s: convolution stride.
        act: activation *class*, instantiated here; any falsy value omits it.
        norm: an already-built normalization module, appended as-is when truthy.
    """
    modules = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks // 2)]
    modules += [norm] if norm else []
    modules += [act()] if act else []
    return nn.Sequential(*modules)
def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
    """Main path of a ResBlock: two ``conv`` layers in sequence.

    The first maps ``ni -> nf`` at stride 1 without normalization; the second
    maps ``nf -> nf`` at stride ``s`` and carries the optional ``norm``. Both
    apply ``act``, so this path ends with an activation before the residual sum.
    """
    widen = conv(ni, nf, ks=ks, s=1, norm=None, act=act)
    downsample = conv(nf, nf, ks=ks, s=s, norm=norm, act=act)
    return nn.Sequential(widen, downsample)
class ResBlock(nn.Module):
    """Residual block: ``act(convs(x) + idconv(pool(x)))``.

    Args:
        ni: input channel count.
        nf: output channel count.
        s: block stride. The main path downsamples via its second conv; the
            shortcut downsamples with an ``s``x``s`` average pool (skipped
            when ``s == 1``).
        ks: kernel size of the main-path convolutions.
        act: activation class (instantiated here). A falsy value disables
            the post-sum activation as well as the per-conv activations.
        norm: optional normalization module instance, applied after the
            second main-path convolution.
    """

    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):
        super().__init__()
        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)
        # 1x1 projection only when the channel count changes.
        self.idconv = fc.noop if ni == nf else conv(ni, nf, ks=1, s=1, act=None)
        # kernel = stride = s so the shortcut's spatial size, ceil(H / s),
        # matches the strided main path for odd ks; a fixed kernel of 2 only
        # matched when s == 2 (identical to the previous behaviour there).
        self.pool = fc.noop if s == 1 else nn.AvgPool2d(s, ceil_mode=True)
        # Mirror conv(): a falsy act means "no activation" rather than a crash.
        self.act = act() if act else fc.noop

    def forward(self, x):
        return self.act(self.convs(x) + self.idconv(self.pool(x)))
def cnn_classifier():
    """Build the residual CNN classifier as an ``nn.Sequential``.

    Five stride-2 ResBlocks (1 -> 8 -> 16 -> 32 -> 64 -> 64 channels) with
    LayerNorm, then a ``conv(64, 10)`` head without activation, flattened.
    The hard-coded LayerNorm shapes (14, 7, 4, 2, 1) imply a single-channel
    input of roughly 28x28 (e.g. MNIST-style digits) -- confirm with callers.
    The Sequential layout (indices 0..6) must stay fixed because the
    checkpoint is loaded via ``load_state_dict``.
    """
    # (in_channels, out_channels, spatial size after the block)
    stages = [(1, 8, 14), (8, 16, 7), (16, 32, 4), (32, 64, 2), (64, 64, 1)]
    res_blocks = [
        ResBlock(ni, nf, norm=nn.LayerNorm([nf, side, side]))
        for ni, nf, side in stages
    ]
    head = [conv(64, 10, act=False), nn.Flatten()]
    return nn.Sequential(*res_blocks, *head)
def kaiming_init(m):
    """Re-initialise the weight of a Conv1d/Conv2d/Conv3d module with Kaiming-normal.

    Intended for ``model.apply(kaiming_init)``. Any other module type, and
    all biases, are left untouched. Returns ``None``.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(m, conv_types):
        return
    nn.init.kaiming_normal_(m.weight)
# Build the architecture and restore the trained weights for inference.
# map_location='cpu' lets a checkpoint saved from a GPU session deserialize on
# a CPU-only host; the model itself is never moved off the CPU here anyway.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
loaded_model = cnn_classifier()
loaded_model.load_state_dict(torch.load('classifier.pth', map_location='cpu'))
loaded_model.eval()
def predict(img):
    """Classify a single unbatched input with ``loaded_model``.

    Args:
        img: one input sample; a leading batch axis is added before the
            forward pass. Presumably a (1, H, W) float tensor such as a
            1x28x28 digit -- TODO confirm against the caller.

    Returns:
        One dict per output class (10 for ``cnn_classifier``), in ascending
        digit order, each with:
          - ``"digit"``: the class index (int),
          - ``"prob"``: its softmax probability as a percentage string,
            e.g. ``"97.31%"``,
          - ``'logits'``: the raw logit for that class (0-d tensor).
    """
    with torch.no_grad():
        logits = loaded_model(img[None,])[0]
        probs = F.softmax(logits, dim=0)
        # enumerate already yields digits 0..N-1 in ascending order, so no
        # extra sort is needed.
        return [
            {"digit": i, "prob": f'{prob*100:.2f}%', 'logits': logits[i]}
            for i, prob in enumerate(probs)
        ]
|