Spaces:
Sleeping
Sleeping
import os
import time

import cv2
import numpy as np
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps
from torch import nn
# Page-level settings for the app: wide layout and the page title.
st.set_page_config(layout="wide", page_title="Digit Recognition")
class Network(nn.Module):
    """Small convolutional classifier producing 10 raw scores per image.

    The network is three unpadded, stride-1 convolutions (5x5, 5x5, 4x4) with
    2x2 average pooling after the first two, followed by two fully connected
    layers. For a 1x28x28 input the spatial size shrinks
    28 -> 24 -> 12 -> 8 -> 4 -> 1, so the last convolution yields exactly 120
    features per image, which is what the first linear layer expects.

    The attribute names of the parameterised layers are the keys of the
    module's ``state_dict``; renaming them would break loading of existing
    checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0))
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0))
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=(4, 4), stride=(1, 1), padding=(0, 0))
        self.fully_connected1 = nn.Linear(in_features=120, out_features=84)
        self.fully_connected2 = nn.Linear(in_features=84, out_features=10)
        # Parameter-free layers; one instance of each is reused in forward().
        self.pooling_layer = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return the class scores for ``x``.

        Args:
            x: Image tensor whose last three dimensions are (1, 28, 28).

        Returns:
            Tensor of shape (rows, 10) holding the output of the final linear
            layer (no softmax is applied).

        Raises:
            ValueError: If the input's spatial size does not reduce to 1x1
                after the last convolution (i.e. it is not 28x28).
        """
        # Convolution Layer 1
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pooling_layer(x)

        # Convolution Layer 2
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pooling_layer(x)
        x = self.dropout(x)

        # Convolution Layer 3
        x = self.conv3(x)
        x = self.relu(x)

        # Without this check, view(-1, 120) would silently spread the extra
        # spatial positions of a wrongly sized image across additional rows.
        if tuple(x.shape[-2:]) != (1, 1):
            raise ValueError(
                "Network expects 28x28 inputs; the final convolution produced "
                f"a {x.shape[-2]}x{x.shape[-1]} feature map instead of 1x1"
            )

        # Flatten to one 120-feature row per image
        x = x.view(-1, 120)

        # Fully connected layer 1
        x = self.fully_connected1(x)
        x = self.relu(x)

        # Fully connected layer 2
        x = self.fully_connected2(x)
        return x
# Use the GPU when one is available. The model and the input tensor are both
# moved to this device below so they always live on the same one.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Network()
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted source.
model.load_state_dict(torch.load('mnist_model.pth', map_location=torch.device(device)))
model.to(device)
# Inference only: eval() turns the dropout layer into a no-op.
model.eval()

st.title("MNIST Image Classification")
st.subheader("This is a simple image classification web application to predict handwritten digits")

st.sidebar.write('## Please upload an image file :camera:', unsafe_allow_html=True)
file = st.sidebar.file_uploader("## Upload", type=["png"])

# Until the user uploads something, classify the default image './0.png'.
imagefile = './0.png' if file is None else file

img = Image.open(imagefile)
# Keep a handle on the image as opened so it can be displayed unmodified.
img_copy = img

# Convert to single-channel greyscale with PIL. PIL arrays are RGB-ordered and
# PNGs may also be greyscale, palette or RGBA, so cv2's BGR2GRAY conversion
# either weighted the channels wrongly or failed on single-channel input.
gray = img.convert("L")

transform = T.Compose([
    T.ToTensor(),
    T.Resize((28, 28)),
])
img = transform(gray)

st.image(img_copy, width=150)

# Add the batch dimension explicitly and skip autograd bookkeeping.
with torch.no_grad():
    results = model(img.unsqueeze(0).to(device))

category = torch.argmax(results)
# .item() yields a plain Python int and works for CPU and CUDA tensors alike.
digit = category.item()
print(digit)
st.write(
    f'<hr><p style="font-size: 30px;">The image is digit {digit}</p>',
    unsafe_allow_html=True,
)