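# GastroHub AI gastric image classifier (Gradio app).
# Pipeline: an Isolation Forest over ResNet-50 features first screens the uploaded
# image for anomalies; non-anomalous images are classified as normal/abnormal by a
# ResNet-50 classifier, and a simple ReLU-activation heatmap is returned alongside
# the prediction.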
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
from joblib import load
from gradio import File
from PIL import Image
import matplotlib.pyplot as plt
import io
import numpy as np
import gradio as gr
# Device configuration
device = torch.device("cpu")
# Transformation for the input images
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
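# The Normalize mean/std above are the standard ImageNet statistics used by
# torchvision's pretrained ResNet models.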
# Load the Isolation Forest model
def load_isolation_forest():
    path = 'Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'
    return load(path)
# Load the feature extractor
def load_feature_extractor():
    feature_extractor_path = 'Models/feature_extractor.pth'
    feature_extractor = models.resnet50(weights=None)
    # Replace the classification head with an identity so the network outputs its 2048-dim pooled features
    feature_extractor.fc = nn.Sequential()
    feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
    feature_extractor.to(device)
    feature_extractor.eval()
    return feature_extractor
# Anomaly detection function
def is_anomaly(clf, feature_extractor, image):
    with torch.no_grad():
        image_features = feature_extractor(image)
    # Isolation Forest labels outliers as -1 and inliers as 1
    return clf.predict(image_features.cpu().numpy().reshape(1, -1))[0] == -1
# Classification function
def classify_image(model, image):
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
    class_names = ['abnormal', 'normal']
    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return predicted_class_name, predicted_probability
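# classify_image assumes index 0 = 'abnormal' and index 1 = 'normal'; this must match
# the label ordering used when the classifier was trained (e.g. torchvision
# ImageFolder's alphabetical ordering).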
# Load the classification model
def load_classification_model():
    model_path = 'Gastric_Models/the_resnet_50_model.pth'
    # The checkpoint stores the full model object, so torch.load returns a ready nn.Module
    model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()
    return model
# Function to process the image and get results
def process_image(image_path):
    # Convert to PIL and apply transforms
    image = Image.open(image_path).convert('RGB')
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Load models
    clf = load_isolation_forest()
    feature_extractor = load_feature_extractor()
    classification_model = load_classification_model()

    # Check for anomaly
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Classify image
    predicted_class, probability = classify_image(classification_model, input_image)
    result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"

    # Generate heatmap
    heatmap = generate_heatmap(classification_model, input_image)
    heatmap_image = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255))
    return result, heatmap_image
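# Note: process_image reloads the Isolation Forest, feature extractor and classifier
# on every request for simplicity; a deployed app could load them once at module
# level and reuse them across requests.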
# Function to generate heatmap
def generate_heatmap(model, image):
    activation = []
    handles = []

    def hook_fn(module, input, output):
        activation.append(output)

    # Register forward hooks on every ReLU so their outputs are captured during the forward pass
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.ReLU):
            handles.append(module.register_forward_hook(hook_fn))

    # Forward pass
    output = model(image)
    prediction = output.argmax(1)

    # Backpropagation to compute gradients for the predicted class
    model.zero_grad()
    one_hot_output = torch.zeros(1, output.size(-1), device=device)
    one_hot_output[0][prediction] = 1
    output.backward(gradient=one_hot_output)

    # Remove the hooks so they do not linger on the model
    for handle in handles:
        handle.remove()

    # Compute the heatmap from the first recorded ReLU activation
    if len(activation) > 0:
        activations = activation[0][0].detach().cpu().numpy()
        heatmap = activations.max(axis=0)  # channel-wise max -> 2-D map
        threshold = 0.5  # Adjust this threshold value as needed
        heatmap[heatmap < threshold] = 0
        return heatmap
    else:
        return np.zeros((224, 224))  # Return an empty heatmap if no activation is recorded
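# Note: the "heatmap" above is a raw activation map (the channel-wise maximum of the
# first hooked ReLU output, thresholded); the backward pass computes gradients but
# they are not used to weight the activations as Grad-CAM would.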
# Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=File(type="filepath"),
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Heatmap")],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image you can use to test the model. Click it to see the results.",
    examples=[
        ["/Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)
iface.launch()