File size: 6,029 Bytes
ca1f33d
 
 
43aa94f
ca1f33d
 
 
 
5e2b50c
ca1f33d
61a1204
5e2b50c
ca1f33d
 
 
 
 
 
 
 
5e2b50c
d6c3c84
00d4300
 
 
 
 
 
 
 
 
5e2b50c
 
 
 
 
 
61a1204
 
5e2b50c
 
 
61a1204
00d4300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e2b50c
 
00d4300
 
 
 
5e2b50c
61a1204
00d4300
 
5e2b50c
00d4300
 
5e2b50c
00d4300
 
 
 
 
 
 
 
 
 
 
 
 
 
5e2b50c
00d4300
94f7f47
 
00d4300
94f7f47
 
00d4300
94f7f47
00d4300
 
 
 
 
 
94f7f47
00d4300
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import functools

import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
from joblib import load
from PIL import Image
import matplotlib.pyplot as plt

# All inference in this module runs on the CPU. Model paths below are relative,
# so they resolve against the process's current working directory.
device = torch.device("cpu")

# Per-channel normalisation statistics (the standard ImageNet mean/std values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize the shorter side to 224, take a 224x224
# centre crop, convert to a tensor, then normalise each channel.
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Outlier detector (an Isolation Forest, per the file name) that is_anomaly
# applies to feature-extractor embeddings.
clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')

# Anomaly detection function
def is_anomaly(clf, feature_extractor, input_image):
    """
    Decide whether an image is an outlier according to ``clf``.

    Parameters:
    clf - Fitted outlier detector exposing ``predict``; a label of -1 marks an outlier.
    feature_extractor - Network that embeds ``input_image``; it is moved to the
                        module-level ``device`` before use.
    input_image - Preprocessed image tensor (assumed to hold a single image — TODO confirm).

    Returns:
    A boolean that is true when the first predicted label is -1.
    """
    feature_extractor.to(device)
    with torch.no_grad():
        embedding = feature_extractor(input_image)

    # The detector scores 2-D input, so flatten the embedding into one row.
    row = embedding.cpu().numpy().reshape(1, -1)
    labels = clf.predict(row)
    return labels[0] == -1

# Feature extractor
# ResNet-50 backbone for is_anomaly. Its final `fc` layer is replaced with a
# pass-through, so the network returns the features that `fc` would have consumed.
feature_extractor = models.resnet50(weights=None)
feature_extractor.fc = nn.Identity()
feature_extractor.load_state_dict(
    torch.load('Models/feature_extractor.pth', map_location=device)
)
feature_extractor.to(device)
feature_extractor.eval()

# Whole pickled classification model; process_image uses this instance for heatmaps.
model_ft = torch.load('Gastric_Models/the_resnet_50_model.pth', map_location=device).to(device)
model_ft.eval()

@functools.lru_cache(maxsize=4)
def _load_classifier(model_path, device_name):
    """Load a pickled classification model once per (path, device) and cache it in eval mode."""
    # NOTE(review): torch.load unpickles arbitrary objects, so only point model_path
    # at trusted files. PyTorch 2.6+ defaults to weights_only=True, which rejects
    # whole-model pickles like this one; pin the torch version or pass
    # weights_only=False deliberately if loading starts to fail.
    model = torch.load(model_path, map_location=torch.device(device_name))
    model = model.to(device_name)
    model.eval()
    return model


def classify_image(clf, feature_extractor, input_image, model_path, class_names):
    """
    Detects anomalies and classifies the image.

    Parameters:
    clf - Anomaly detection model.
    feature_extractor - Feature extractor for the anomaly detection model.
    input_image - The preprocessed image tensor to be classified (batch of one).
    model_path - Path to the classification model; it is loaded once per
                 (path, device) pair and then reused from a cache.
    class_names - List of class names, indexed by the model's output index.

    Returns:
    A tuple (predicted class name, probability in percent), or (None, None)
    when the image is flagged as an anomaly.
    """
    if is_anomaly(clf, feature_extractor, input_image):
        print("Anomaly detected. Image will not be classified.")
        return None, None

    print("No anomaly detected. Proceeding with classification.")

    # Use the GPU when present. The input is moved to match: the caller builds
    # it on the module's CPU device, and a mismatch would make the forward pass raise.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_classifier(model_path, str(run_device))

    with torch.no_grad():
        outputs = model(input_image.to(run_device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)

    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100

    return predicted_class_name, predicted_probability

def generate_heatmap(model, input_image, device, threshold=0.5):
    """
    Generates a thresholded activation map from the model's first ReLU output.

    Forward hooks are attached temporarily to every ``torch.nn.ReLU`` module, and
    the output of the first one to fire is kept. The map is the channel-wise
    maximum of that activation for the first image in the batch. Values below
    ``threshold`` are set to 0. A backward pass with a one-hot gradient on the
    predicted class is also run. It populates the parameters' ``.grad``, but
    those gradients are not used to build the map.

    Parameters:
    model - The classification model.
    input_image - Batched image tensor fed to ``model``.
    device - Kept for backward compatibility; the one-hot gradient is now created
             on the same device as the model output.
    threshold - Activations below this value are zeroed (default 0.5).

    Returns:
    A numpy array, (H, W) when the first ReLU output is a (C, H, W) feature map,
    or None if no ReLU module fired.
    """
    captured = []

    def hook_fn(module, inputs, output):
        # Only the first ReLU output is used, so later ones are not retained.
        if not captured:
            captured.append(output)

    handles = [
        module.register_forward_hook(hook_fn)
        for module in model.modules()
        if isinstance(module, torch.nn.ReLU)
    ]
    try:
        model.zero_grad()
        output = model(input_image)
    finally:
        # Remove the hooks so repeated calls don't keep stacking them on a shared model.
        for handle in handles:
            handle.remove()

    prediction = output.argmax(1)

    # One-hot gradient selecting each sample's predicted class. It matches the
    # output's shape, dtype and device, so backward() accepts it for any batch size.
    one_hot_output = torch.zeros_like(output)
    one_hot_output.scatter_(1, prediction.unsqueeze(1), 1.0)
    output.backward(gradient=one_hot_output)

    if not captured:
        print("No activations recorded.")
        return None

    activations = captured[0][0].detach().cpu().numpy()
    heatmap = activations.max(axis=0)
    heatmap[heatmap < threshold] = 0
    return heatmap

import gradio as gr
import torch
from PIL import Image
import io

def process_image(image):
    """
    Processes the image and returns the classification results and heatmap.

    Parameters:
    image - PIL image from the Gradio input, or None when nothing was uploaded.

    Returns:
    A (text, image) pair for the two Gradio outputs: the result message, and
    either a PIL image of the heatmap or None when no heatmap is available.
    """
    # Gradio passes None when the user submits without uploading anything.
    if image is None:
        return "Please upload an image.", None

    # Normalize expects exactly 3 channels. Convert defensively in case the image
    # arrives as RGBA/greyscale (e.g. when called outside the RGB-mode Gradio input).
    input_image = data_transforms(image.convert("RGB")).unsqueeze(0).to(device)

    predicted_class_name, predicted_probability = classify_image(
        clf,
        feature_extractor,
        input_image,
        'Gastric_Models/the_resnet_50_model.pth',
        ['abnormal', 'normal'],
    )

    if predicted_class_name is None:
        return "Anomaly Detected - Image not classified", None

    heatmap = generate_heatmap(model_ft, input_image, device)
    if heatmap is None:
        # The second output is a gr.Image, which cannot display a text message
        # (a string there is treated as a file path), so report it in the text.
        return f"{predicted_class_name} ({predicted_probability:.2f}%) - No heatmap generated", None

    # Use an explicit figure and always close it so failed renders don't leak
    # pyplot state between requests.
    fig, ax = plt.subplots()
    try:
        ax.imshow(heatmap, cmap='hot')
        ax.axis('off')
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        plt.close(fig)
    buffer.seek(0)
    heatmap_image = Image.open(buffer)

    # NOTE(review): with the 'hot' colormap, yellow/white mark higher values than
    # red, so this legend appears inverted — confirm the intended wording.
    description = "\n\nRed Regions -  highest importance for the predicted class.\nYellow Regions - moderately high importance for the predicted class."
    return f"The predicted class is: {predicted_class_name} at a probability of ({predicted_probability:.2f}%) {description}", heatmap_image

# Example inputs offered under the interface; Gradio expects one list per example row.
_EXAMPLE_IMAGES = (
    "Gastric_Images/Abnormal-04038-test2.png",
    "Gastric_Images/cancer_108_WSI.png",
    "Gastric_Images/Ladybug.png",
    "Gastric_Images/normal_78_WSI.png",
    "Gastric_Images/Normal-01006-test1.png",
    "Gastric_Images/Normal-07000.png",
)

# Web UI: one PIL image in; a result textbox and a heatmap image out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Classification Result"),
        gr.Image(label="Heatmap"),
    ],
    title="Gastrohub Cancer Detection",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above are a few sample images to test the results of the model. Click any to see the results.",
    examples=[[path] for path in _EXAMPLE_IMAGES],
    allow_flagging="never",
)
demo.launch()