File size: 4,582 Bytes
ca1f33d
 
 
43aa94f
ca1f33d
c8cc359
ca1f33d
 
 
61a1204
 
ca1f33d
61a1204
ca1f33d
61a1204
 
ca1f33d
 
 
 
 
 
 
 
61a1204
 
 
ca1f33d
61a1204
 
 
 
 
 
d6c3c84
61a1204
 
d6c3c84
61a1204
 
 
 
 
ca1f33d
61a1204
 
ca1f33d
61a1204
ca1f33d
 
 
 
61a1204
ca1f33d
 
 
61a1204
 
 
 
 
 
 
 
 
 
 
 
 
6499566
61a1204
 
 
 
 
 
 
 
 
6605792
61a1204
 
 
 
 
6605792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca1f33d
61a1204
ca1f33d
61a1204
ca1f33d
61a1204
ccca058
 
 
 
6605792
ccca058
 
ca1f33d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
from joblib import load
from gradio import File
from PIL import Image
import matplotlib.pyplot as plt
import io
import numpy as np
import gradio as gr

# All inference in this app runs on the CPU.
device = torch.device("cpu")

# Preprocessing applied to every uploaded image before it reaches the models:
# scale the shorter side to 224 px, take the central 224x224 crop, convert to
# a float tensor, then normalize each RGB channel. The mean/std values below
# match the commonly used ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Load the Isolation Forest model
def load_isolation_forest():
    """Load the serialized Isolation Forest used to flag anomalous images.

    Returns:
        The object reconstructed by ``joblib.load`` from
        ``Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib``; callers use
        its ``predict`` method.
    """
    # NOTE: joblib files are pickle-based; only load this trusted, bundled file.
    model_file = 'Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'
    forest = load(model_file)
    return forest

# Load the feature extractor
def load_feature_extractor():
    """Build the ResNet-50 backbone that embeds images for anomaly detection.

    The final ``fc`` layer is replaced by an empty ``nn.Sequential`` (which
    returns its input unchanged), so the network outputs the features that
    would otherwise feed the classification head. Weights are then restored
    from ``Models/feature_extractor.pth``.

    Returns:
        The backbone moved to ``device`` and switched to evaluation mode.
    """
    weights_file = 'Models/feature_extractor.pth'
    backbone = models.resnet50(weights=None)
    # Empty Sequential == identity: strip the classifier head.
    backbone.fc = nn.Sequential()
    state = torch.load(weights_file, map_location=device)
    backbone.load_state_dict(state)
    backbone.to(device)
    backbone.eval()
    return backbone

# Anomaly detection function
def is_anomaly(clf, feature_extractor, image):
    """Return whether the outlier detector rejects *image*.

    Args:
        clf: Outlier detector with a ``predict`` method; presumably a
            scikit-learn ``IsolationForest``, whose ``predict`` labels
            outliers ``-1`` and inliers ``1``.
        feature_extractor: Callable mapping the image batch to a feature tensor.
        image: Preprocessed image batch; the features are flattened into a
            single row, so this assumes a batch of one image.

    Returns:
        ``True`` when the detector's first prediction is ``-1``.
    """
    with torch.no_grad():
        features = feature_extractor(image)
    row = features.cpu().numpy().reshape(1, -1)
    verdict = clf.predict(row)[0]
    return verdict == -1

# Classification function
def classify_image(model, image):
    """Predict whether *image* is ``'abnormal'`` or ``'normal'``.

    Args:
        model: Classifier returning logits over the classes along dim 1;
            presumably index 0 is ``'abnormal'`` and index 1 ``'normal'``.
        image: Input batch; ``.item()`` below requires exactly one image.

    Returns:
        ``(label, confidence)`` where ``confidence`` is the softmax
        probability of the predicted label, as a percentage (0-100).
    """
    labels = ['abnormal', 'normal']
    with torch.no_grad():
        logits = model(image)
        scores = torch.nn.functional.softmax(logits, dim=1)
        winner = torch.max(logits, 1).indices

    idx = winner.item()
    confidence = scores[0][idx].item() * 100
    return labels[idx], confidence

# Load the classification model
def load_classification_model():
    """Load the pickled gastric-image classifier and prepare it for inference.

    The checkpoint at ``Gastric_Models/the_resnet_50_model.pth`` is used
    directly as a module (``.to`` / ``.eval`` are called on the result of
    ``torch.load``), i.e. it stores a whole pickled model rather than a state
    dict. Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``,
    which refuses to unpickle such objects, so the flag is passed explicitly.

    Returns:
        The model moved to ``device`` and switched to evaluation mode.
    """
    model_path = 'Gastric_Models/the_resnet_50_model.pth'
    # SECURITY: weights_only=False runs full unpickling (arbitrary code); this
    # is acceptable only because the checkpoint is a trusted file bundled
    # with the app. Never point this at user-supplied files.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device)
    model.eval()
    return model

# Function to process the image and get results
def process_image(image_path):
    """Screen, classify and visualize a single uploaded image.

    Args:
        image_path: Filesystem path of the uploaded image, or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(message, heatmap_image)`` tuple. ``heatmap_image`` is ``None``
        when no image was supplied or when the Isolation Forest flags the
        image as an anomaly; otherwise it is an RGBA PIL image.
    """
    if image_path is None:
        return "Please upload an image.", None

    # Close the file promptly; convert() returns an independent image.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Screen first so the classifier is only loaded for accepted images.
    clf = load_isolation_forest()
    feature_extractor = load_feature_extractor()
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    classification_model = load_classification_model()
    predicted_class, probability = classify_image(classification_model, input_image)
    result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"

    heatmap = generate_heatmap(classification_model, input_image)
    heatmap_image = _heatmap_to_image(heatmap)

    return result, heatmap_image


def _heatmap_to_image(heatmap):
    """Render a 2-D array as an RGBA PIL image with the ``hot`` colormap.

    Matplotlib colormaps map floats in [0, 1]; anything above 1 is drawn with
    the top colour, so unscaled activation magnitudes would mostly saturate.
    The map is therefore divided by its peak first (an all-zero map is left
    unchanged and renders black).
    """
    heatmap = np.asarray(heatmap, dtype=np.float64)
    peak = heatmap.max() if heatmap.size else 0.0
    if peak > 0:
        heatmap = heatmap / peak
    return Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255))

# Function to generate heatmap
def generate_heatmap(model, image, threshold=0.5):
    """Build a 2-D activation map for *image* from the model's first ReLU.

    A temporary forward hook is attached to every ``torch.nn.ReLU`` in
    *model*. The output of the first ReLU call in the forward pass is kept.
    The map is that activation's per-pixel maximum over its leading
    (channel) axis, with values below *threshold* zeroed.

    This is an activation map, not a gradient-based saliency map. Only forward
    activations are read, so no backward pass is run and the forward pass
    executes without autograd. All hooks are removed before returning, so
    repeated calls on the same model do not accumulate hooks.

    Args:
        model: Module to inspect.
        image: Input batch; only the first sample's activation is used.
            Assumes the first ReLU output is (batch, channels, H, W) — TODO
            confirm for non-CNN models.
        threshold: Activations strictly below this value are set to 0.

    Returns:
        A numpy array with the spatial size of the first ReLU output, or a
        ``(224, 224)`` array of zeros when no ReLU fired.
    """
    captured = []

    def _keep_first(_module, _inputs, output):
        # Only the earliest ReLU output is used; don't retain later ones.
        # Clone so later in-place ops in the model cannot alter it.
        if not captured:
            captured.append(output.detach().clone())

    handles = [
        module.register_forward_hook(_keep_first)
        for module in model.modules()
        if isinstance(module, torch.nn.ReLU)
    ]
    try:
        with torch.no_grad():
            model(image)
    finally:
        # Remove the hooks even if the forward pass raises.
        for handle in handles:
            handle.remove()

    if not captured:
        return np.zeros((224, 224))  # no ReLU fired: empty heatmap

    activation = captured[0][0].cpu().numpy()  # first sample in the batch
    heatmap = activation.max(axis=0)  # strongest response across channels
    heatmap[heatmap < threshold] = 0
    return heatmap

# Gradio UI: one file upload in; the result text and heatmap image out.
# The outputs list must match the (result, heatmap_image) tuple returned by
# process_image -- Gradio raises when a handler returns fewer values than
# there are output components.
iface = gr.Interface(
    fn=process_image,
    inputs=File(type="filepath"),
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Heatmap")],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    # Relative path (like the Models/ paths used by the loaders): Gradio only
    # serves example files from the working directory, the temp directory, or
    # explicitly allowed paths, so a root-level absolute path fails.
    examples=[
        ["Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

iface.launch()