|
import torch |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
import torch.nn as nn |
|
from joblib import load |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Device used for the module-level models and the preprocessed input tensors.
device = torch.device("cpu")

# Preprocessing pipeline: resize so the shorter side is 224 px, take the central
# 224x224 crop, convert to a float tensor, then normalise each channel with the
# mean/std values below (they match torchvision's ImageNet normalisation values).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Pre-fitted anomaly detector, deserialised with joblib from a local model file.
clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')
|
|
|
|
|
def is_anomaly(clf, feature_extractor, input_image):
    """
    Return True when the anomaly detector flags the image as an outlier.

    Parameters:
        clf - Fitted anomaly detector exposing ``predict``. Outliers are
              recognised by a prediction of -1 (the scikit-learn
              IsolationForest convention; the loaded model file is named
              as one, confirm if it is ever swapped).
        feature_extractor - torch module mapping the image tensor to features.
        input_image - Batched image tensor; only the first item's verdict
                      is returned.

    Returns:
        A plain ``bool``: True if ``clf`` labels the first image -1.
    """
    feature_extractor.to(device)
    # Inference only: make sure batch-norm/dropout layers run in eval mode.
    feature_extractor.eval()

    with torch.no_grad():
        # Move the input alongside the extractor so a caller's tensor on
        # another device cannot trigger a device-mismatch error.
        image_features = feature_extractor(input_image.to(device))

    # One flattened feature row per image in the batch.
    features = image_features.cpu().numpy().reshape(image_features.shape[0], -1)
    is_outlier = clf.predict(features)

    # Convert numpy.bool_ to a real bool for callers.
    return bool(is_outlier[0] == -1)
|
|
|
|
|
# Anomaly-detection backbone: a ResNet-50 built without pretrained weights whose
# final fully-connected layer is swapped for an empty Sequential (a pass-through),
# so the network outputs the features that would have fed that layer. Saved
# weights are then loaded and the module is switched to inference mode.
feature_extractor = models.resnet50(weights=None)
feature_extractor.fc = torch.nn.Sequential()
feature_extractor.load_state_dict(
    torch.load('Models/feature_extractor.pth', map_location=device)
)
feature_extractor = feature_extractor.to(device)
feature_extractor.eval()

# Classification model: the file holds a whole pickled module (not a state dict),
# so torch.load unpickles it directly. Only load trusted files this way.
model_ft = torch.load('Gastric_Models/the_resnet_50_model.pth', map_location=device).to(device)
model_ft.eval()
|
|
|
# Classification models already loaded by classify_image, keyed by
# (model path, device), so each file is read from disk once rather than per request.
_CLASSIFIER_CACHE = {}


def _load_classifier(model_path, device):
    """Return the pickled model at model_path moved to device in eval mode, loading it at most once."""
    key = (model_path, str(device))
    model = _CLASSIFIER_CACHE.get(key)
    if model is None:
        # torch.load unpickles a whole module here; only use trusted model files.
        model = torch.load(model_path, map_location=device).to(device)
        model.eval()
        _CLASSIFIER_CACHE[key] = model
    return model


def classify_image(clf, feature_extractor, input_image, model_path, class_names):
    """
    Detects anomalies and classifies the image.

    Parameters:
        clf - Anomaly detection model.
        feature_extractor - Feature extractor for the anomaly detection model.
        input_image - The batched image tensor to be classified; the first
                      item of the batch is reported.
        model_path - Path to the classification model (loaded once per
                     path/device pair and then reused).
        class_names - List of class names, indexed by the model's output
                      class index.

    Returns:
        A tuple (predicted class name, its probability in percent), or
        (None, None) when the image is flagged as an anomaly.
    """
    if is_anomaly(clf, feature_extractor, input_image):
        print("Anomaly detected. Image will not be classified.")
        return None, None
    print("No anomaly detected. Proceeding with classification.")

    # Classification runs on the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_classifier(model_path, device)

    with torch.no_grad():
        # The input may live on a different device (e.g. the CPU) than the
        # model; moving it avoids a device-mismatch error on CUDA machines.
        outputs = model(input_image.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)

    predicted_class_index = predicted[0].item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100

    return predicted_class_name, predicted_probability
|
|
|
def generate_heatmap(model, input_image, device, threshold=0.5):
    """
    Generates a heatmap based on the model's activation.

    A forward hook is attached to every ``torch.nn.ReLU`` module. The output
    of the first ReLU that fires during the forward pass is collapsed by
    taking the maximum over its channel axis (axis 0 after dropping the batch
    axis), and values below ``threshold`` are set to 0. The hooks are removed
    again before returning, so repeated calls do not stack hooks on the model.

    Parameters:
        model - The classification model.
        input_image - Batched input tensor; only the first batch item is used.
        device - Kept for backward compatibility; the model and input are used
                 on whatever device they already live on.
        threshold - Activation values below this are zeroed (default 0.5).

    Returns:
        The heatmap as a numpy array (for a conv feature map shaped
        (N, C, H, W) this is (H, W)), or None if no ReLU activation was
        recorded.
    """
    activations = []

    def hook_fn(module, inputs, output):
        # Only the first ReLU output in execution order is used, so keep just that.
        if not activations:
            activations.append(output)

    handles = [
        module.register_forward_hook(hook_fn)
        for module in model.modules()
        if isinstance(module, torch.nn.ReLU)
    ]
    try:
        # Only forward activations are read, so no autograd graph or
        # backward pass is needed.
        with torch.no_grad():
            model(input_image)
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for handle in handles:
            handle.remove()

    if not activations:
        print("No activations recorded.")
        return None

    feature_maps = activations[0][0].detach().cpu().numpy()
    heatmap = feature_maps.max(axis=0)
    heatmap[heatmap < threshold] = 0
    return heatmap
|
|
|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
import io |
|
|
|
def process_image(image):
    """
    Processes the image and returns the classification results and heatmap.

    Parameters:
        image - PIL image supplied by the Gradio input component, or None when
                the user submits without uploading anything.

    Returns:
        A (message, heatmap_image) pair matching the interface's two outputs:
        the result text and a PIL image of the heatmap, or None when there is
        no heatmap to show.
    """
    if image is None:
        return "Please upload an image.", None

    # The Normalize step uses three per-channel statistics, so force 3-channel
    # RGB; grayscale, palette or RGBA inputs would otherwise fail to normalise.
    input_image = data_transforms(image.convert("RGB")).unsqueeze(0).to(device)

    predicted_class_name, predicted_probability = classify_image(
        clf,
        feature_extractor,
        input_image,
        'Gastric_Models/the_resnet_50_model.pth',
        ['abnormal', 'normal'],
    )
    if predicted_class_name is None:
        return "Anomaly Detected - Image not classified", None

    heatmap = generate_heatmap(model_ft, input_image, device)
    if heatmap is None:
        # The second output is a gr.Image, which cannot display a text
        # message, so return no image instead.
        return f"{predicted_class_name} ({predicted_probability:.2f}%)", None

    # Render into a dedicated figure (not pyplot's implicit current figure)
    # and close it even if saving fails.
    fig, ax = plt.subplots()
    buffer = io.BytesIO()
    try:
        ax.imshow(heatmap, cmap='hot')
        ax.axis('off')
        fig.savefig(buffer, format='png')
    finally:
        plt.close(fig)
    buffer.seek(0)
    heatmap_image = Image.open(buffer)

    description = "\n\nRed Regions - highest importance for the predicted class.\nYellow Regions - moderately high importance for the predicted class."
    return f"The predicted class is: {predicted_class_name} at a probability of ({predicted_probability:.2f}%) {description}", heatmap_image
|
|
|
# Gradio UI wiring process_image to an image input and two outputs (result
# text and heatmap image). It is launched only when this file is run as a
# script, so importing the module does not start a web server.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Classification Result"), gr.Image(label="Heatmap")],
    title="Gastrohub Cancer Detection",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above are a few sample images to test the results of the model. Click any to see the results.",
    examples=[
        ["Gastric_Images/Abnormal-04038-test2.png"],
        ["Gastric_Images/cancer_108_WSI.png"],
        ["Gastric_Images/Ladybug.png"],
        ["Gastric_Images/normal_78_WSI.png"],
        ["Gastric_Images/Normal-01006-test1.png"],
        ["Gastric_Images/Normal-07000.png"],
    ],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()