|
import io
from functools import lru_cache

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from gradio import File
from joblib import load
from PIL import Image
|
|
|
|
|
# All models and input tensors are placed on the CPU.
device = torch.device("cpu")

# Per-channel normalisation constants (the usual ImageNet mean / std values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before it reaches either the
# anomaly detector or the classifier: resize the shorter side to 224 px,
# center-crop to 224x224, convert to a float tensor, then normalise.
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
|
|
|
|
|
@lru_cache(maxsize=1)
def load_isolation_forest():
    """Load the anomaly-detection estimator from disk, once per process.

    The result is cached so that every request reuses the same estimator
    instead of re-reading the file.

    Returns:
        The object unpickled by joblib. It is used via ``.predict`` where
        ``-1`` marks an outlier (presumably a scikit-learn IsolationForest;
        confirm against the training code).
    """
    path = 'Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'
    # joblib files are pickles: only ever load a trusted, bundled file here.
    return load(path)
|
|
|
|
|
@lru_cache(maxsize=1)
def load_feature_extractor():
    """Build the ResNet-50 feature extractor used for anomaly detection.

    The classification head (``fc``) is replaced by an identity, so the
    network outputs the features that would otherwise feed that head. The
    weights are read from ``Models/feature_extractor.pth``. The model is
    built once and cached, so later calls return the same instance.

    Returns:
        ``torch.nn.Module`` in eval mode, on ``device``.
    """
    feature_extractor_path = 'Models/feature_extractor.pth'

    feature_extractor = models.resnet50(weights=None)
    # Identity head: same (empty) state_dict as the original empty Sequential,
    # but states the intent directly.
    feature_extractor.fc = nn.Identity()

    feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
    feature_extractor.to(device)
    feature_extractor.eval()
    return feature_extractor
|
|
|
|
|
def is_anomaly(clf, feature_extractor, image):
    """Decide whether *image* is an outlier according to *clf*.

    Args:
        clf: estimator with a ``predict`` method; a label of ``-1`` is
            treated as "anomaly".
        feature_extractor: callable mapping the image tensor to features.
        image: input tensor for ``feature_extractor`` (assumed to be a batch
            of one, since the features are flattened into a single row).

    Returns:
        The result of comparing the first predicted label with ``-1``.
    """
    with torch.no_grad():
        features = feature_extractor(image)

    # Flatten everything into a single sample row for the estimator.
    sample = features.cpu().numpy().reshape(1, -1)
    labels = clf.predict(sample)
    return labels[0] == -1
|
|
|
|
|
def classify_image(model, image):
    """Classify *image* with *model* as ``'abnormal'`` or ``'normal'``.

    Args:
        model: callable returning logits shaped (1, num_classes); index 0 is
            ``'abnormal'`` and index 1 is ``'normal'``.
        image: input tensor passed straight to ``model``.

    Returns:
        ``(label, percent)`` where ``percent`` is the softmax probability of
        the chosen class, scaled to 0-100.
    """
    labels = ('abnormal', 'normal')

    with torch.no_grad():
        logits = model(image)
        probs = torch.nn.functional.softmax(logits, dim=1)
        # Pick the winner from the raw logits (first index on ties).
        best = logits.argmax(dim=1)

    best_index = best.item()
    confidence = probs[0][best_index].item() * 100
    return labels[best_index], confidence
|
|
|
|
|
def load_classification_model():
    """Load the gastric image classifier from ``Gastric_Models/``.

    The checkpoint is loaded as a whole pickled module: the returned object
    is used directly as the model, not passed to ``load_state_dict``.

    Returns:
        The loaded model, moved to ``device`` and switched to eval mode.
    """
    model_path = 'Gastric_Models/the_resnet_50_model.pth'

    # PyTorch >= 2.6 defaults torch.load(weights_only=True), which refuses to
    # unpickle a full nn.Module, so opt out explicitly. This executes pickled
    # code: only point it at a trusted, bundled checkpoint.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device)
    model.eval()
    return model
|
|
|
|
|
def process_image(image_path):
    """Run the full pipeline on one uploaded image.

    The image is preprocessed and first screened by the isolation-forest
    anomaly detector. Only images that pass are classified and rendered as
    an activation heatmap.

    Args:
        image_path: path of the uploaded file (from ``gr.File(type="filepath")``),
            or ``None`` when the form is submitted without a file.

    Returns:
        ``(message, heatmap_image)``. ``heatmap_image`` is a PIL image, or
        ``None`` when no classification was performed.
    """
    if image_path is None:
        return "Please upload an image.", None

    # Context manager releases the file handle as soon as the pixels are copied.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    input_image = data_transforms(image).unsqueeze(0).to(device)

    clf = load_isolation_forest()
    feature_extractor = load_feature_extractor()
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Load the classifier only once the image has passed the anomaly gate.
    classification_model = load_classification_model()
    predicted_class, probability = classify_image(classification_model, input_image)
    result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"

    heatmap = generate_heatmap(classification_model, input_image)
    # The colormap returns RGBA floats in [0, 1]; scale to 8-bit for PIL.
    heatmap_image = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255))

    return result, heatmap_image
|
|
|
|
|
def generate_heatmap(model, image):
    """Build a coarse activation map from the first ReLU that fires in *model*.

    A temporary forward hook is attached to every ``torch.nn.ReLU`` module.
    Only the output of the first ReLU call during the forward pass is kept.
    Its maximum over the channel axis is taken, and values below 0.5 are
    zeroed. The hooks are always removed afterwards, so the caller's model
    is left unchanged.

    Note: this visualises forward activations, not gradients (no Grad-CAM
    weighting is applied).

    Args:
        model: ``torch.nn.Module`` to run.
        image: input tensor for ``model``. Only the first sample is used
            (assumed (1, C, H, W); TODO confirm for other inputs).

    Returns:
        2-D numpy array with the thresholded activation map, or
        ``np.zeros((224, 224))`` if no ReLU module fired.
    """
    first_activation = []

    def hook_fn(module, inputs, output):
        # Keep only the first ReLU output. Storing every call would pin all
        # intermediate activations in memory for no benefit.
        if not first_activation:
            first_activation.append(output.detach())

    handles = [
        submodule.register_forward_hook(hook_fn)
        for _, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.ReLU)
    ]
    try:
        # Only forward activations are needed, so skip building the autograd graph.
        with torch.no_grad():
            model(image)
    finally:
        # Detach the hooks so repeated calls on the same model don't stack them.
        for handle in handles:
            handle.remove()

    if not first_activation:
        return np.zeros((224, 224))

    activations = first_activation[0][0].cpu().numpy()
    heatmap = activations.max(axis=0)
    threshold = 0.5
    heatmap[heatmap < threshold] = 0
    return heatmap
|
|
|
|
|
# Gradio front end: a file upload goes in; a text verdict and a heatmap come out.
iface = gr.Interface(
    fn=process_image,
    inputs=File(type="filepath"),
    # process_image returns exactly two values (text, heatmap image), so exactly
    # two output components are declared; a third would make Gradio error out.
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Heatmap")],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    examples=[
        # Relative to the working directory, like the 'Models/...' paths the
        # loaders use; a filesystem-root path would not exist next to the app.
        ["Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

iface.launch()