|
import torch |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
import torch.nn as nn |
|
from joblib import load |
|
from gradio import File |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import io |
|
import numpy as np |
|
import gradio as gr |
|
|
|
|
|
# All inference runs on the CPU.
device = torch.device("cpu")

# Per-channel normalization statistics; these values are the standard ImageNet
# RGB mean/std used with torchvision's pretrained backbones.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before it reaches the models:
# resize so the shorter edge is 224 px, take the central 224x224 crop, convert
# to a float tensor in [0, 1], then normalize each channel.
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
|
|
|
|
|
def load_isolation_forest():
    """Load the pre-fitted anomaly detector serialized with joblib.

    Returns:
        Whatever object was stored in the .joblib file; callers use its
        ``predict`` method (see ``is_anomaly``).

    Note:
        joblib files are pickle-based, so this must only ever point at a
        trusted, locally shipped model file.
    """
    return load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')
|
|
|
|
|
def load_feature_extractor():
    """Build a headless ResNet-50 and load its weights from disk.

    The classification head (``fc``) is replaced by an empty ``nn.Sequential``,
    which passes its input through unchanged, so the forward pass yields the
    flattened pooled backbone features instead of class logits.

    Returns:
        The ResNet-50 module, moved to ``device`` and switched to eval mode.
    """
    backbone = models.resnet50(weights=None)
    # Drop the classifier head so the network acts as a feature extractor.
    backbone.fc = nn.Sequential()

    checkpoint_path = 'Models/feature_extractor.pth'
    state = torch.load(checkpoint_path, map_location=device)
    backbone.load_state_dict(state)

    backbone.to(device)
    # Inference only: freeze batch-norm statistics / disable dropout.
    backbone.eval()
    return backbone
|
|
|
|
|
def is_anomaly(clf, feature_extractor, image):
    """Return whether the outlier detector flags ``image`` as anomalous.

    Args:
        clf: Fitted detector exposing ``predict``; assumed to follow the
            scikit-learn outlier convention where -1 marks an outlier
            (the loaded model is an Isolation Forest per its file name).
        feature_extractor: Callable mapping the image tensor to features.
        image: Preprocessed image tensor (assumed to hold a single image).

    Returns:
        A (numpy) boolean that is true when the first prediction equals -1.
    """
    with torch.no_grad():
        embedding = feature_extractor(image)

    # Flatten everything into one row: the detector scores a single sample.
    sample = embedding.cpu().numpy().reshape(1, -1)
    first_label = clf.predict(sample)[0]
    return first_label == -1
|
|
|
|
|
def classify_image(model, image, class_names=("abnormal", "normal")):
    """Classify ``image`` with ``model`` and report the winning class.

    Args:
        model: Callable returning logits of shape (batch, num_classes).
        image: Input tensor passed straight to ``model``; only the first item
            of the batch is reported (the app always sends a batch of one).
        class_names: Label for each output index, in model output order.
            Defaults to the binary ('abnormal', 'normal') labels this app's
            classifier was built for; pass a different sequence to reuse the
            function with another model.

    Returns:
        ``(class_name, probability_percent)`` for the first batch item, where
        the probability is the softmax score of the predicted class in [0, 100].
    """
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        # argmax over the raw logits (softmax is monotonic, so the index is the
        # same as for the probabilities).
        _, predicted = torch.max(outputs, 1)

    # .item() requires a single element, i.e. a batch of one.
    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100

    return predicted_class_name, predicted_probability
|
|
|
|
|
def load_classification_model():
    """Load the fully pickled ResNet-50 classifier from disk.

    The checkpoint stores a whole ``nn.Module`` (not a ``state_dict``).
    Starting with torch 2.6, ``torch.load`` defaults to ``weights_only=True``,
    which refuses to unpickle arbitrary objects such as a full model, so the
    load would fail. ``weights_only=False`` is therefore passed explicitly when
    the installed torch supports that parameter (it was added in 1.13).

    Security: ``weights_only=False`` runs the full pickle machinery; this
    path must only ever point at a trusted, locally shipped checkpoint.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    import inspect

    model_path = 'Gastric_Models/the_resnet_50_model.pth'

    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    model = torch.load(model_path, **load_kwargs)
    model.to(device)
    # Inference only: freeze batch-norm statistics / disable dropout.
    model.eval()
    return model
|
|
|
|
|
# Memoized models, filled on the first request by _get_models().
_MODEL_CACHE = {}


def _get_models():
    """Return (anomaly_detector, feature_extractor, classifier), loading once.

    The checkpoints are read from disk on first use and then reused, instead of
    being reloaded on every request. If loading raises, nothing is cached and
    the next request retries. Two concurrent first requests may both load; that
    only wastes work, the result is the same.
    """
    if "models" not in _MODEL_CACHE:
        _MODEL_CACHE["models"] = (
            load_isolation_forest(),
            load_feature_extractor(),
            load_classification_model(),
        )
    return _MODEL_CACHE["models"]


def _open_image(upload):
    """Open an uploaded image as an RGB PIL image.

    Accepts a filesystem path (``str`` / ``os.PathLike``), which is what
    ``gr.File(type="filepath")`` delivers, a readable binary file-like object,
    or a tempfile-style wrapper that only exposes the path via ``.name``.

    Raises:
        OSError: if the file is missing or is not a readable image
            (PIL's ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    import os

    if isinstance(upload, (str, os.PathLike)):
        path = upload
    elif hasattr(upload, "read"):
        return Image.open(io.BytesIO(upload.read())).convert('RGB')
    else:
        path = upload.name

    # convert() returns a new, fully loaded image, so the file can be closed.
    with Image.open(path) as img:
        return img.convert('RGB')


def process_image(image_path):
    """Gradio handler: screen the upload for anomalies, then classify it.

    Args:
        image_path: The value Gradio passes for the upload, normally a path
            string given ``File(type="filepath")``; ``None`` when nothing was
            uploaded.

    Returns:
        A 3-tuple matching the interface outputs: a result message followed by
        two image slots, which are always ``None`` here.
    """
    if image_path is None:
        return "Please upload an image.", None, None

    try:
        image = _open_image(image_path)
    except OSError:
        return "Could not read the uploaded file as an image.", None, None

    # Add a batch dimension: the models expect a batch of one image.
    input_image = data_transforms(image).unsqueeze(0).to(device)

    clf, feature_extractor, classification_model = _get_models()

    # Out-of-distribution inputs are rejected before classification.
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None, None

    predicted_class, probability = classify_image(classification_model, input_image)
    result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"

    return result, None, None
|
|
|
|
|
|
|
# Input: the upload is handed to process_image as a filesystem path.
_upload_input = File(type="filepath")

# Outputs, in the order process_image returns them: a text verdict followed by
# two image panels. NOTE(review): process_image returns None for both image
# slots, so "Heatmap" and "Additional Output" currently stay empty.
_result_outputs = [
    gr.Textbox(label="Result"),
    gr.Image(label="Heatmap"),
    gr.Image(label="Additional Output"),
]

_sample_images = [
    ["Gastric_Images/Ladybug.png"],
]

iface = gr.Interface(
    fn=process_image,
    inputs=_upload_input,
    outputs=_result_outputs,
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    examples=_sample_images,
    allow_flagging="never",
)

# Start the web server (runs whenever this module is executed/imported).
iface.launch()