# Source: Hugging Face Space by CindyBSydney, commit 00d4300 ("Update app.py").
import functools

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from joblib import load
from PIL import Image
# Assuming that the model paths and other configurations are correctly set.
# The module-level models and the preprocessed input tensors live on the CPU.
device = torch.device("cpu")

# Preprocessing pipeline: resize the shorter side to 224, centre-crop to 224x224,
# convert to a tensor and normalise per channel (these mean/std values are the
# usual ImageNet statistics).
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Isolation Forest consulted by is_anomaly() to reject images before classification.
# NOTE(review): joblib.load unpickles the file, so only ship trusted model files.
clf = load("Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib")
# Anomaly detection function
def is_anomaly(clf, feature_extractor, input_image):
    """
    Decide whether an image should be rejected as an anomaly.
    Parameters:
    clf - Fitted anomaly detector exposing ``predict``; a label of -1 is treated as an outlier.
    feature_extractor - Torch module mapping ``input_image`` to a feature tensor.
    input_image - Preprocessed image tensor (presumably a batch of one, since the
                  features are flattened into a single row - confirm with callers).
    Returns:
    True when the detector labels the features -1, otherwise False.
    """
    feature_extractor.to(device)
    with torch.no_grad():
        features = feature_extractor(input_image)
    # The detector expects a 2-D (n_samples, n_features) array.
    feature_row = features.cpu().numpy().reshape(1, -1)
    labels = clf.predict(feature_row)
    return labels[0] == -1
# Feature extractor for anomaly detection: a ResNet-50 whose `fc` head is replaced by
# an empty Sequential, so the forward pass returns the flattened pooled features
# instead of class logits. Its trained weights come from a saved state dict.
feature_extractor = models.resnet50(weights=None)
feature_extractor.fc = torch.nn.Sequential()
feature_extractor.load_state_dict(
    torch.load("Models/feature_extractor.pth", map_location=device)
)
feature_extractor.to(device)
feature_extractor.eval()

# Classification model, saved as a whole pickled nn.Module (not a state dict).
# NOTE(review): loading a full module relies on unpickling; on newer PyTorch
# releases where torch.load defaults to weights_only=True this may need an explicit
# weights_only=False - only do that for trusted files.
model_ft = torch.load("Gastric_Models/the_resnet_50_model.pth", map_location=device).to(device)
model_ft.eval()
def classify_image(clf, feature_extractor, input_image, model_path, class_names):
"""
Detects anomalies and classifies the image.
Parameters:
clf - Anomaly detection model.
feature_extractor - Feature extractor for the anomaly detection model.
input_image - The image to be classified.
model_path - Path to the classification model.
class_names - List of class names for classification.
Returns:
A tuple containing the predicted class name and its probability.
"""
# Anomaly detection
if is_anomaly(clf, feature_extractor, input_image):
print("Anomaly detected. Image will not be classified.")
return None, None
else:
print("No anomaly detected. Proceeding with classification.")
# Load classification model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(model_path, map_location=device)
model = model.to(device)
model.eval()
# Classification
with torch.no_grad():
outputs = model(input_image)
probabilities = torch.nn.functional.softmax(outputs, dim=1)
_, predicted = torch.max(outputs, 1)
predicted_class_index = predicted.item()
predicted_class_name = class_names[predicted_class_index]
predicted_probability = probabilities[0][predicted_class_index].item() * 100
return predicted_class_name, predicted_probability
def generate_heatmap(model, input_image, device, threshold=0.5):
    """
    Generates a heatmap from the activation of the first ReLU the model executes.
    Parameters:
    model - The classification model.
    input_image - The (batched) image tensor for which the heatmap is generated;
                  only the first item of the batch is used.
    device - The device (CPU or GPU) where the model is loaded; the input is moved there.
    threshold - Heatmap values below this are set to 0 (defaults to 0.5, the value
                that was previously hard-coded).
    Returns:
    A 2-D numpy array holding, for each spatial position, the maximum over channels
    of that ReLU's output, with values below ``threshold`` zeroed. Returns None
    (after printing a message) if the forward pass ran no ReLU.
    """
    captured = []

    def hook_fn(module, inputs, output):
        # Only the first ReLU output is used, so don't retain every later activation.
        # Clone so that a later in-place op in the model cannot alter what we keep.
        if not captured:
            captured.append(output.detach().clone())

    handles = [
        module.register_forward_hook(hook_fn)
        for module in model.modules()
        if isinstance(module, torch.nn.ReLU)
    ]
    try:
        # Only forward activations are read, so no autograd graph or backward pass
        # is needed.
        with torch.no_grad():
            model(input_image.to(device))
    finally:
        # Remove the hooks; otherwise they pile up on the (shared) model with every
        # call and keep every previous call's activation list alive.
        for handle in handles:
            handle.remove()

    if not captured:
        print("No activations recorded.")
        return None

    heatmap = captured[0][0].cpu().numpy().max(axis=0)
    # Thresholding the heatmap
    heatmap[heatmap < threshold] = 0
    return heatmap
import gradio as gr
import torch
from PIL import Image
import io
def _render_heatmap(heatmap):
    """Render a 2-D heatmap array to a PIL image using matplotlib's 'hot' colormap."""
    # Draw on a dedicated figure instead of pyplot's implicit "current" figure, so
    # leftover or concurrent pyplot state is less likely to leak into the image.
    fig, ax = plt.subplots()
    try:
        ax.imshow(heatmap, cmap='hot')
        ax.axis('off')
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        plt.close(fig)
    buffer.seek(0)
    return Image.open(buffer)


def process_image(image):
    """
    Processes the image and returns the classification results and heatmap.
    Parameters:
    image - PIL image from the ``gr.Image(type="pil")`` input, or None when the
            user submits without uploading an image.
    Returns:
    A (text, heatmap_image) pair for the Textbox and Image outputs; heatmap_image
    is a PIL image, or None when no heatmap is available.
    """
    if image is None:
        return "Please upload an image.", None

    # Transform the image. The Normalize step expects exactly 3 channels, while PNG
    # uploads may be RGBA, grayscale or palette images, so convert to RGB first.
    input_image = data_transforms(image.convert("RGB")).unsqueeze(0).to(device)

    # Classify the image
    predicted_class_name, predicted_probability = classify_image(
        clf,
        feature_extractor,
        input_image,
        'Gastric_Models/the_resnet_50_model.pth',
        ['abnormal', 'normal'],
    )
    if predicted_class_name is None:
        return "Anomaly Detected - Image not classified", None

    # Generate heatmap
    heatmap = generate_heatmap(model_ft, input_image, device)
    if heatmap is None:
        # The second output is a gr.Image, which needs an image (or None); a plain
        # string there would be interpreted as a file path. Report it in the text.
        return f"{predicted_class_name} ({predicted_probability:.2f}%) - No heatmap generated", None

    heatmap_image = _render_heatmap(heatmap)
    description = "\n\nRed Regions - highest importance for the predicted class.\nYellow Regions - moderately high importance for the predicted class."
    return f"The predicted class is: {predicted_class_name} at a probability of ({predicted_probability:.2f}%) {description}", heatmap_image
# Sample images offered as clickable examples (one input value per example row).
example_images = [
    "Gastric_Images/Abnormal-04038-test2.png",
    "Gastric_Images/cancer_108_WSI.png",
    "Gastric_Images/Ladybug.png",
    "Gastric_Images/normal_78_WSI.png",
    "Gastric_Images/Normal-01006-test1.png",
    "Gastric_Images/Normal-07000.png",
]

# Web UI: one PIL image in; the classification text and the heatmap image out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Classification Result"),
        gr.Image(label="Heatmap"),
    ],
    title="Gastrohub Cancer Detection",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above are a few sample images to test the results of the model. Click any to see the results.",
    examples=[[path] for path in example_images],
    allow_flagging="never",
)
demo.launch()