File size: 6,029 Bytes
ca1f33d 43aa94f ca1f33d 5e2b50c ca1f33d 61a1204 5e2b50c ca1f33d 5e2b50c d6c3c84 00d4300 5e2b50c 61a1204 5e2b50c 61a1204 00d4300 5e2b50c 00d4300 5e2b50c 61a1204 00d4300 5e2b50c 00d4300 5e2b50c 00d4300 5e2b50c 00d4300 94f7f47 00d4300 94f7f47 00d4300 94f7f47 00d4300 94f7f47 00d4300 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import functools

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from joblib import load
from PIL import Image
# Model paths below are relative to the working directory and are assumed to be
# set correctly for the deployment.
device = torch.device("cpu")

# Preprocessing: resize (shorter side -> 224), take a 224x224 centre crop, convert
# to a tensor, then normalise with the standard ImageNet per-channel mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Isolation Forest model; is_anomaly() uses its predict() output to reject outliers.
clf = load("Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib")
def is_anomaly(clf, feature_extractor, input_image):
    """
    Decide whether ``input_image`` is an outlier according to ``clf``.

    Parameters:
    clf - Fitted outlier detector exposing ``predict``; a label of -1 marks an outlier.
    feature_extractor - Torch module mapping the image tensor to features. As a side
        effect it is moved to the module-level ``device``.
    input_image - Preprocessed image tensor; all of its features are flattened into a
        single row, so a batch of one is presumably expected (TODO confirm).

    Returns:
    A truthy value (numpy bool) when the first predicted label is -1.
    """
    feature_extractor.to(device)
    with torch.no_grad():
        embedding = feature_extractor(input_image)
    # The detector scores one flattened feature row.
    row = embedding.cpu().numpy().reshape(1, -1)
    verdict = clf.predict(row)[0]
    return verdict == -1
# Feature extractor for anomaly detection: a ResNet-50 whose final ``fc`` layer is
# replaced by an identity, so its forward pass returns backbone features rather than
# class logits. nn.Identity has no parameters, exactly like the empty Sequential it
# replaces, so the saved state_dict loads unchanged.
feature_extractor = models.resnet50(weights=None)
feature_extractor.fc = nn.Identity()
feature_extractor.load_state_dict(
    torch.load("Models/feature_extractor.pth", map_location=device)
)
feature_extractor.to(device)
feature_extractor.eval()

# Classification model, loaded once at import time; process_image() passes this
# instance to generate_heatmap().
# NOTE(review): this unpickles a whole nn.Module. torch.load is pickle-based, so only
# load trusted files. On torch versions where weights_only defaults to True (2.6+)
# this call needs weights_only=False -- confirm the deployed torch version.
model_ft = torch.load("Gastric_Models/the_resnet_50_model.pth", map_location=device)
model_ft = model_ft.to(device)
model_ft.eval()
@functools.lru_cache(maxsize=4)
def _load_classifier(model_path, model_device):
    """Load a pickled classifier onto ``model_device`` in eval mode; cached per (path, device)."""
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    model = torch.load(model_path, map_location=model_device)
    model = model.to(model_device)
    model.eval()
    return model


def classify_image(clf, feature_extractor, input_image, model_path, class_names):
    """
    Detects anomalies and classifies the image.

    Parameters:
    clf - Anomaly detection model.
    feature_extractor - Feature extractor for the anomaly detection model.
    input_image - Preprocessed image tensor with a batch of one (``predicted.item()``
        requires a single prediction).
    model_path - Path to the pickled classification model. It is loaded once per
        (path, device) and then reused from an in-process cache.
    class_names - Class names, indexed by the model's output position.

    Returns:
    (predicted_class_name, predicted_probability), where the probability is the
    softmax score expressed as a percentage; (None, None) if an anomaly is detected.
    """
    if is_anomaly(clf, feature_extractor, input_image):
        print("Anomaly detected. Image will not be classified.")
        return None, None
    print("No anomaly detected. Proceeding with classification.")

    model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_classifier(model_path, model_device)
    # The caller may have prepared the tensor on a different device (process_image
    # uses the module-level CPU device), so move it to match the model.
    input_image = input_image.to(model_device)

    with torch.no_grad():
        outputs = model(input_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
    _, predicted = torch.max(outputs, 1)
    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return predicted_class_name, predicted_probability
def generate_heatmap(model, input_image, device, threshold=0.5):
    """
    Generates a heatmap from the output of the first ReLU that fires in ``model``.

    Parameters:
    model - The classification model.
    input_image - Batched image tensor; only the first sample is used for the map.
    device - The device (CPU or GPU) where the model is loaded; the input is moved here.
    threshold - Map values below this are zeroed (default 0.5, as before).

    Returns:
    A numpy array holding, per spatial position, the maximum activation across
    channels of the first sample, thresholded; None if no ReLU fired.
    Assumes the first ReLU output is (N, C, H, W) -- TODO confirm for other models.
    """
    first_activation = []

    def hook_fn(module, inputs, output):
        # Keep only the earliest ReLU output. Later ones are never used, and holding
        # them all would keep every intermediate feature map alive.
        if not first_activation:
            first_activation.append(output.detach())

    handles = [
        module.register_forward_hook(hook_fn)
        for module in model.modules()
        if isinstance(module, torch.nn.ReLU)
    ]
    try:
        # Only forward activations are read, so no autograd graph / backward pass is needed.
        with torch.no_grad():
            model(input_image.to(device))
    finally:
        # Always remove the hooks so repeated calls do not stack them on the model.
        for handle in handles:
            handle.remove()

    if not first_activation:
        print("No activations recorded.")
        return None

    # Channel-wise maximum of the first sample's activation.
    heatmap = first_activation[0][0].cpu().numpy().max(axis=0)
    heatmap[heatmap < threshold] = 0
    return heatmap
import gradio as gr
import torch
from PIL import Image
import io
def _render_heatmap(heatmap):
    """Render a 2-D heatmap array to a PIL PNG image using the 'hot' colormap."""
    from matplotlib.figure import Figure

    # A standalone Figure avoids pyplot's shared "current figure", so concurrent
    # Gradio requests cannot draw onto each other's plot.
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap, cmap='hot')
    ax.axis('off')
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    buffer.seek(0)
    return Image.open(buffer)


def process_image(image):
    """
    Processes the image and returns the classification results and heatmap.

    Parameters:
    image - PIL image from the Gradio ``gr.Image(type="pil")`` input, or None when
        the user submits without uploading anything.

    Returns:
    (result_text, heatmap_image), where heatmap_image is a PIL image or None.
    """
    if image is None:
        return "Please upload an image.", None

    # Uploads may arrive as RGBA/palette/grayscale depending on the source file and
    # Gradio's image_mode; the 3-channel Normalize in data_transforms needs RGB.
    input_image = data_transforms(image.convert("RGB")).unsqueeze(0).to(device)

    predicted_class_name, predicted_probability = classify_image(
        clf,
        feature_extractor,
        input_image,
        'Gastric_Models/the_resnet_50_model.pth',
        ['abnormal', 'normal'],
    )
    if predicted_class_name is None:
        return "Anomaly Detected - Image not classified", None

    heatmap = generate_heatmap(model_ft, input_image, device)
    if heatmap is None:
        # gr.Image cannot display a string, so report the missing heatmap in the text.
        return f"{predicted_class_name} ({predicted_probability:.2f}%) - No heatmap generated", None

    heatmap_image = _render_heatmap(heatmap)
    # In the 'hot' colormap the largest values are yellow/white and mid-range values red.
    # NOTE(review): the map is the first ReLU's activation, not a class-specific
    # attribution -- consider rewording "importance for the predicted class".
    description = "\n\nYellow/White Regions - highest importance for the predicted class.\nRed Regions - moderately high importance for the predicted class."
    return f"The predicted class is: {predicted_class_name} at a probability of ({predicted_probability:.2f}%) {description}", heatmap_image
# Clickable example inputs shown in the UI; paths are relative to the working directory.
_EXAMPLE_IMAGES = [
    ["Gastric_Images/Abnormal-04038-test2.png"],
    ["Gastric_Images/cancer_108_WSI.png"],
    ["Gastric_Images/Ladybug.png"],
    ["Gastric_Images/normal_78_WSI.png"],
    ["Gastric_Images/Normal-01006-test1.png"],
    ["Gastric_Images/Normal-07000.png"],
]

# Web UI: one PIL image in; classification text and heatmap image out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Classification Result"), gr.Image(label="Heatmap")],
    title="Gastrohub Cancer Detection",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above are a few sample images to test the results of the model. Click any to see the results.",
    examples=_EXAMPLE_IMAGES,
    allow_flagging="never",
)
demo.launch()