# Author: MostafaAhmed98
# Last commit: 7398a9f "Update app.py" (verified)
import torch
import torchvision
import base64
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from pathlib import Path
from torch import nn
from torchvision import transforms
# Inference is pinned to the CPU; automatic CUDA selection is intentionally disabled.
device = 'cpu'
# Resource files are resolved relative to this source file, not the working directory.
base_path = str(Path(__file__).parent)
path_of_model = str(Path(base_path) / "cnn_net.pt")  # checkpoint passed to load_state_dict
default_img = str(Path(base_path) / "base64_img.bin")  # Base64 text of the default demo image
class Net(nn.Module):
    """LeNet-style CNN mapping a 1x32x32 grayscale image to 39 class logits.

    Per-sample shape trace: 1x32x32 -conv5-> 6x28x28 -pool-> 6x14x14
    -conv5-> 16x10x10 -pool-> 16x5x5 -flatten-> 400 -> 120 -> 84 -> 39.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 39)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def load_the_model(model_path=None, map_location="cpu"):
    """Build a ``Net`` and load trained weights into it, ready for inference.

    Args:
        model_path: Path to a saved ``state_dict`` for ``Net``. Defaults to
            the module-level ``path_of_model``.
        map_location: Passed to ``torch.load``; defaults to ``"cpu"`` so the
            checkpoint loads on machines without a GPU.

    Returns:
        The ``Net`` instance with weights loaded, switched to eval mode.
    """
    if model_path is None:
        model_path = path_of_model
    loaded_model = Net()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On PyTorch >= 1.13, weights_only=True restricts it to tensors.
    state_dict = torch.load(model_path, map_location=map_location)
    loaded_model.load_state_dict(state_dict)
    loaded_model.eval()
    return loaded_model
def loads_data_classes() -> list:
    """Return the classifier's output labels, ordered by logit index.

    Element ``i`` is the label for output ``i`` of the network. The first 29
    entries are the letter classes (including the two-character ``'لا'``),
    followed by the ten digits ``'٠'`` .. ``'٩'``. A new list is built on
    every call, so callers may mutate the result freely.
    """
    letters = [
        'ا', 'ب', 'ت', 'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر',
        'ز', 'س', 'ش', 'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ف',
        'ق', 'ك', 'ل', 'لا', 'م', 'ن', 'ه', 'و', 'ي',
    ]
    digits = ['٠', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩']
    return letters + digits
def _strip_data_uri(payload):
    """Remove a leading ``data:<mime>;base64,`` header from *payload*, if any."""
    if isinstance(payload, str):
        head, sep, tail = payload.partition(",")
        if sep and head.lstrip().startswith("data:"):
            return tail
    elif isinstance(payload, (bytes, bytearray)):
        head, sep, tail = bytes(payload).partition(b",")
        if sep and head.lstrip().startswith(b"data:"):
            return tail
    return payload


def base64_to_image(base64_file, image_path='decoded.png'):
    """Decode a Base64 image payload and write the raw bytes to disk.

    Args:
        base64_file: Base64 text (``str``) or bytes. It may optionally start
            with a data-URI header such as ``data:image/png;base64,``.
        image_path: Destination file. Defaults to ``decoded.png`` in the
            current working directory, which is overwritten on every call.

    Returns:
        ``image_path``, the file the decoded bytes were written to.

    Raises:
        ValueError: (``binascii.Error``) if the payload is not decodable Base64.
    """
    # b64decode discards characters outside the Base64 alphabet instead of
    # rejecting them, so an unstripped data-URI header would be decoded into
    # garbage bytes (or a padding error) at the start of the image.
    image_data = base64.b64decode(_strip_data_uri(base64_file))
    with open(image_path, 'wb') as output_file:
        output_file.write(image_data)
    return image_path
def read_base64_file(file_path):
    """Return the full text content of *file_path*, expected to hold Base64 data.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    default encoding. The ``utf-8-sig`` codec drops a leading byte-order mark;
    if kept, the BOM would reach ``base64.b64decode`` as a non-ASCII character
    and make it raise ``ValueError``.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        return file.read()
def predict_on_base64(model, base64_file):
    """Decode a Base64 image, run it through *model*, and return the top class.

    Args:
        model: Classifier called on the float batch of shape (1, 1, 32, 32)
            built below. Its logits are indexed with ``loads_data_classes()``.
            It is switched to eval mode as a side effect.
        base64_file: Base64 text of the image. ``base64_to_image`` writes it to
            disk before it is decoded.

    Returns:
        ``(label, probability, image)``:
            - the predicted label;
            - the softmax probability of that same label, rounded to 3 decimals;
            - the decoded PIL image, for display.
    """
    path = base64_to_image(base64_file)
    img = Image.open(path)
    # PIL opens lazily; read the pixels now, not when the UI later renders the
    # image, by which time the decoded file may have been overwritten.
    img.load()
    model.eval()
    # Grayscale, then resize to the 32x32 input size the network expects.
    transform_img = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((32, 32)),
    ])
    with torch.inference_mode():
        # Pixel values scaled from [0, 255] into [0, 1].
        custom_image = torchvision.io.read_image(str(path)).type(torch.float32) / 255.
        # Grayscale() accepts only 1- or 3-channel tensors, so gray+alpha (2)
        # and RGBA (4) images used to raise here; drop the alpha channel instead.
        # NOTE(review): alpha is discarded, not composited over a background;
        # confirm this matches the images the model was trained on.
        if custom_image.shape[0] in (2, 4):
            custom_image = custom_image[:-1]
        batch = transform_img(custom_image).unsqueeze(dim=0)  # add batch dimension
        logits = model(batch)
        prob = torch.softmax(logits, dim=1)
        pred_idx = logits.argmax(dim=1).item()
        # Report the probability of the label that is actually returned.
        sample_prob = round(prob[0, pred_idx].item(), 3)
        test_pred_labels = loads_data_classes()[pred_idx]
    return test_pred_labels, sample_prob, img
# Build the classifier once at import time; predict() reuses this instance for every request.
model = load_the_model()
def predict(user_base_64_file):
    """Gradio callback: classify the uploaded Base64 image file.

    Args:
        user_base_64_file: The value Gradio supplies for the ``gr.File`` input.
            This is either a file-like object exposing ``.name`` (what this
            code originally assumed) or a plain path string (what ``gr.File``
            passes with ``type="filepath"``). Both are accepted.

    Returns:
        ``(label, probability, image)`` as produced by ``predict_on_base64``.

    Raises:
        ValueError: If no file was provided.
    """
    if user_base_64_file is None:
        raise ValueError("No file uploaded: please upload a Base64 image file.")
    file_path = getattr(user_base_64_file, "name", user_base_64_file)
    base64_string = read_base64_file(file_path)  # raw Base64 text of the upload
    prediction, probability, img = predict_on_base64(model=model, base64_file=base64_string)
    return prediction, probability, img
# Web UI: one file input and three outputs matching predict()'s return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=gr.File(value=default_img, label="Upload a Base64 Image File with .txt(utf-8 format) or .bin"),
    outputs=[
        gr.Textbox(label="Predicted Label"),
        gr.Textbox(label="Probability"),
        gr.Image(label="Image"),
    ],
    title="Arabic Letter Recognition",
    # allow_flagging takes a string ("never" / "auto" / "manual"); a boolean
    # False is at best mapped to "never" with a deprecation warning.
    allow_flagging="never",
)
# NOTE(review): the server starts as soon as this module is executed or imported
# (there is no __main__ guard).
demo.launch(share=True)