"""Gradio demo: classify a hand-drawn digit with a small LeNet-style CNN.

Model weights are downloaded from the Hugging Face Hub repo
``stanimirovb/ibob-lenet-v1`` when this module is executed, after which the
Gradio interface is launched.
"""

import os

import gradio as gr
import numpy as np
import torch
import torchvision
from huggingface_hub import snapshot_download
from torch import nn


class LeNet(nn.Module):
    """LeNet-style CNN for 28x28 single-channel images with 10 output classes.

    Two ``conv(5x5) -> tanh -> avg-pool(2)`` stages shrink a 28x28 input to a
    12-channel 4x4 feature map (28 -> 24 -> 12 -> 8 -> 4), which a single
    linear layer maps to 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Layer order/indices must stay as-is: they determine the state_dict
        # keys (e.g. ``convs.0.weight``, ``linear.0.weight``) of the
        # pretrained checkpoint loaded below.
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(in_channels=4, out_channels=12, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        )
        self.linear = nn.Sequential(
            nn.Linear(4 * 4 * 12, 10),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, 10)`` for ``x`` of shape ``(batch, 1, 28, 28)``."""
        x = self.convs(x)
        x = torch.flatten(x, 1)
        return self.linear(x)

    @torch.no_grad()
    def predict(self, input):
        """Return the 10 class probabilities for a single image.

        ``input`` must hold exactly 28*28 values; it is reshaped to
        ``(1, 1, 28, 28)`` before the forward pass, and the softmax over the
        single output row is returned as a 1-D tensor of length 10.
        """
        input = input.reshape(1, 1, 28, 28)
        out = self(input)
        return nn.functional.softmax(out[0], dim=0)


MODEL_REPO = "stanimirovb/ibob-lenet-v1"
WEIGHTS_FILE = "lenet-v1.pth"

lenet = LeNet()
lenet_pt = os.path.join(snapshot_download(MODEL_REPO), WEIGHTS_FILE)
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code if the Hub repo is ever compromised. Consider passing
# weights_only=True once the minimum supported torch version allows it.
lenet.load_state_dict(torch.load(lenet_pt, map_location="cpu"))
# The model has no dropout/batch-norm layers, so eval() changes nothing today;
# it keeps inference correct if such layers are added later.
lenet.eval()

resize = torchvision.transforms.Resize((28, 28), antialias=True)


def _format_ranking(probs):
    """Format class probabilities as one ``[digit, probability]`` line per class.

    Lines are ordered from most to least likely (ties keep digit order).
    Values are converted to Python floats and printed with 6 significant
    digits, so the text does not depend on NumPy's scalar repr (NumPy 2 would
    otherwise render ``np.float32(0.5)``).
    """
    ranked = sorted(enumerate(probs), key=lambda pair: -pair[1])
    return "\n".join(f"[{digit}, {float(prob):.6g}]" for digit, prob in ranked)


def on_submit(img):
    """Classify the drawing submitted from the Sketchpad.

    Args:
        img: Sketchpad value, read as a dict whose ``'composite'`` entry is the
            flattened drawing as a NumPy array. Assumed to be 2-D (H, W) for
            ``image_mode='P'`` -- TODO confirm against the Gradio version used.

    Returns:
        One ``[digit, probability]`` line per class, most likely first, or a
        short hint when there is no drawing to classify.
    """
    if img is None or img.get("composite") is None:
        return "Draw a digit to classify."
    with torch.no_grad():
        pixels = torch.from_numpy(img["composite"].astype(np.float32))
        # Add a leading channel axis so Resize treats the last two dims as H, W.
        pixels = resize(pixels.unsqueeze(0))
        result = lenet.predict(pixels)
    return _format_ranking(result.tolist())


iface = gr.Interface(
    title="LeNet",
    fn=on_submit,
    inputs=gr.Sketchpad(image_mode="P"),
    outputs=gr.Text(),
)
iface.launch()