|
import torch |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
import torch.nn as nn |
|
from joblib import load |
|
from gradio import File |
|
from PIL import Image |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import io |
|
|
|
|
|
# Inference runs on the CPU; models and input tensors in this script are
# moved to this device.
device = torch.device("cpu")

# Per-channel mean / std used by Normalize below. These are the standard
# ImageNet statistics used by torchvision's pretrained models.
# NOTE(review): presumably they match the normalization used when the
# anomaly extractor and classifier were trained — confirm.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before either model sees it:
#   1. resize so the shorter side is 224 px,
#   2. center-crop to 224x224,
#   3. convert the PIL image to a float tensor,
#   4. normalize each channel with the statistics above.
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
|
|
|
|
|
# Serialized anomaly detector (presumably a scikit-learn IsolationForest,
# judging by the file name — confirm).
ANOMALY_MODEL_PATH = 'Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'

# Gate used by is_anomaly(): clf.predict(...) returning -1 marks an outlier,
# and classify_image() refuses to classify such images.
# SECURITY: joblib.load unpickles the file, so this must only ever point at a
# trusted artifact.
clf = load(ANOMALY_MODEL_PATH)
|
|
|
|
|
# ResNet-50 backbone whose output feeds the anomaly detector (clf).
feature_extractor_path = 'Models/feature_extractor.pth'

feature_extractor = models.resnet50(weights=None)

# An empty nn.Sequential returns its input unchanged, so the network emits the
# pooled backbone features instead of class logits. The head must be swapped
# *before* loading: load_state_dict is strict by default, so the checkpoint's
# keys have to match this headless architecture.
feature_extractor.fc = nn.Sequential()

feature_extractor.load_state_dict(
    torch.load(feature_extractor_path, map_location=device)
)

# Module.to() returns the module itself, so the device move and the switch to
# inference mode (fixed BatchNorm statistics) can be chained.
feature_extractor.to(device).eval()
|
|
|
|
|
# Gastric normal/abnormal classifier. The file holds a complete pickled
# nn.Module (not just a state dict): the loaded object is used directly as a
# model below and in classify_image().
GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'

try:
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # whole model object, so opt out explicitly.
    # SECURITY: full unpickling can execute arbitrary code — only ever load a
    # trusted checkpoint from this path.
    model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device, weights_only=False)
except TypeError:
    # torch < 1.13 has no weights_only argument (it always unpickled fully).
    model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device)

model_ft.to(device)
# Inference mode: disables dropout and uses stored BatchNorm statistics.
model_ft.eval()
|
|
|
|
|
def is_anomaly(clf, feature_extractor, input_image):
    """Return True if the outlier detector flags ``input_image`` as anomalous.

    Args:
        clf: Fitted outlier detector exposing ``predict(X)``; a prediction of
            ``-1`` marks an outlier (scikit-learn convention, e.g.
            IsolationForest).
        feature_extractor: ``torch.nn.Module`` mapping the image tensor to a
            feature tensor.
        input_image: Preprocessed image tensor. Assumed to hold a single
            image (batch size 1), because its features are flattened into
            one row for ``clf``.

    Returns:
        ``True`` when ``clf`` predicts ``-1`` for the image, else ``False``.
    """
    # Put the extractor on the same device as the input instead of relying on
    # a module-level ``device`` global.
    feature_extractor.to(input_image.device)

    with torch.no_grad():
        image_features = feature_extractor(input_image)

    # clf expects a 2-D (n_samples, n_features) NumPy array; the input holds a
    # single image, so flatten all features into one row.
    is_outlier = clf.predict(image_features.cpu().numpy().reshape(1, -1))

    # Return a plain Python bool rather than numpy.bool_.
    return bool(is_outlier[0] == -1)
|
|
|
|
|
def classify_image(uploaded_image):
    """Classify an uploaded gastric image as 'abnormal' or 'normal'.

    The image is first screened by the anomaly detector; flagged images are
    not classified.

    Args:
        uploaded_image: Path of the uploaded file (the interface uses
            ``File(type="filepath")``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(message, None)`` tuple. ``message`` is a human-readable result
        string, and the second element is always ``None``.
    """
    if uploaded_image is None:
        return "Please upload an image.", None

    try:
        # The context manager closes the underlying file handle. convert()
        # returns a new, fully loaded image that remains valid after the close.
        with Image.open(uploaded_image) as raw_image:
            image = raw_image.convert('RGB')
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode; the File input accepts arbitrary uploads.
        return "Could not read the uploaded file as an image.", None

    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # The image must be passed to the detector; it was previously omitted.
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    with torch.no_grad():
        outputs = model_ft(input_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)

    # NOTE(review): index order presumably matches the label order used in
    # training (e.g. alphabetical ImageFolder classes) — confirm.
    class_names = ['abnormal', 'normal']
    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100

    return f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%", None
|
|
|
# classify_image returns a (message, None) tuple, so the interface must declare
# two output components, one per returned value. A single gr.Image would
# receive the whole tuple and never display the text result.
iface = gr.Interface(
    fn=classify_image,
    inputs=File(type="filepath"),
    outputs=[gr.Textbox(label="Result"), gr.Image()],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    examples=[
        ["Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

iface.launch()