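"""Gradio demo: recognise a handwritten Arabic letter or digit from a base64-encoded image file.

The script loads a small LeNet-style CNN from cnn_net.pt, decodes the uploaded
base64 text into an image, preprocesses it into a 1x32x32 grayscale tensor, and
returns the predicted class label, its probability, and the decoded image.
"""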
import base64
from pathlib import Path

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
|
|
|
device = 'cpu'

# Resolve the trained weights and the bundled base64 example relative to this script.
base_path = str(Path(__file__).parent)
path_of_model = base_path + "/cnn_net.pt"
default_img = base_path + "/base64_img.bin"
|
|
|
def load_the_model():
    """Build the LeNet-style CNN and load its trained weights from cnn_net.pt."""

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Two conv + max-pool stages followed by three fully connected layers.
            # The input is a 1-channel 32x32 image; the final layer scores 39 classes.
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 39)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = torch.flatten(x, 1)  # flatten every dimension except the batch
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    loaded_model = Net()
    loaded_model.load_state_dict(torch.load(path_of_model, map_location=torch.device(device)))
    loaded_model.eval()
    return loaded_model
|
|
|
def loads_data_classes() -> list:
    """
    Load the class labels for the prediction.

    Returns:
        list: A list of class labels.
    """
    class_labels = ['ا', 'ب', 'ت', 'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'س', 'ش', 'ص', 'ض',
                    'ط', 'ظ', 'ع', 'غ', 'ف', 'ق', 'ك', 'ل', 'لا', 'م', 'ن', 'ه', 'و', 'ي',
                    '٠', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩']
    return class_labels
|
|
|
def base64_to_image(base64_file):
    """Decode a base64 string, write it to disk as an image, and return the file path."""
    image_data = base64.b64decode(base64_file)

    image_path = 'decoded.png'
    with open(image_path, 'wb') as output_file:
        output_file.write(image_data)

    return image_path
|
|
|
def read_base64_file(file_path):
    """Read the base64 string stored in the uploaded .txt/.bin file."""
    with open(file_path, 'r') as file:
        base64_string = file.read()
    return base64_string
|
|
|
def predict_on_base64(model, base64_file):
    """Decode a base64 image, preprocess it, and return (label, probability, PIL image)."""
    path = base64_to_image(base64_file)
    img = Image.open(path)

    model.eval()
    with torch.inference_mode():
        # Read the decoded image as a float tensor scaled to [0, 1].
        custom_image = torchvision.io.read_image(str(path)).type(torch.float32)
        custom_image = custom_image / 255.

        # Match the training preprocessing: single channel, 32x32 pixels.
        transform_img = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((32, 32)),
        ])
        custom_image_transformed = transform_img(custom_image)

        # Add a batch dimension before the forward pass.
        custom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0)
        custom_image_pred = model(custom_image_transformed_with_batch_size)

        # Convert logits to probabilities and keep the top class.
        prob = torch.softmax(custom_image_pred, dim=1)
        sample_prob = round(prob[0].max().item(), 3)

        test_pred_labels = custom_image_pred.argmax(dim=1).item()
        labels = loads_data_classes()
        test_pred_labels = labels[test_pred_labels]

    return test_pred_labels, sample_prob, img
|
|
|
# Load the model once at startup so every request reuses the same weights.
model = load_the_model()
|
|
|
def predict(user_base_64_file):
    """Gradio callback: read the uploaded base64 file and classify the image it encodes."""
    base64_string = read_base64_file(user_base_64_file.name)
    prediction, probability, img = predict_on_base64(model=model, base64_file=base64_string)
    return prediction, probability, img
|
|
|
|
|
demo = gr.Interface(fn=predict,
                    inputs=gr.File(value=default_img, label="Upload a base64 image file (.txt in UTF-8, or .bin)"),
                    outputs=[gr.Textbox(label="Predicted Label"), gr.Textbox(label="Probability"), gr.Image(label="Image")],
                    title="Arabic Letter Recognition",
                    allow_flagging="never")
|
demo.launch(share=True) |
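# A minimal sketch of how an input file for this demo could be produced; the
# file name "sample_letter.png" is only an illustrative placeholder:
#
#   import base64
#   with open("sample_letter.png", "rb") as f:
#       encoded = base64.b64encode(f.read())
#   with open("sample_letter.txt", "wb") as f:
#       f.write(encoded)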