File size: 7,950 Bytes
7a986e7
f2c28c8
 
7a986e7
 
 
f2c28c8
 
 
b75380d
f2c28c8
7a986e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b75380d
7a986e7
 
 
 
 
 
 
 
 
 
 
 
 
 
b75380d
 
7a986e7
 
 
 
 
 
0eb72c7
7a986e7
0eb72c7
7a986e7
 
0eb72c7
 
 
 
b75380d
0eb72c7
b75380d
0eb72c7
 
 
 
 
 
b75380d
0eb72c7
 
7a986e7
0eb72c7
 
 
 
 
 
 
 
 
 
 
 
b75380d
0eb72c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a986e7
 
f2c28c8
 
 
 
7a986e7
 
0eb72c7
f2c28c8
dfbd972
0eb72c7
 
f2c28c8
7a986e7
f2c28c8
 
 
7a986e7
f2c28c8
7a986e7
f2c28c8
 
7a986e7
f2c28c8
 
 
 
 
 
 
 
 
7a986e7
f2c28c8
 
 
a4f839e
0eb72c7
f2c28c8
7a986e7
dfbd972
f2c28c8
 
 
7a986e7
 
dfbd972
 
 
0eb72c7
 
 
 
 
b75380d
 
 
 
 
 
 
 
 
 
7a986e7
 
dfbd972
a7e5224
dfbd972
a7e5224
b75380d
 
 
dfbd972
 
f2c28c8
 
a4f839e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import io

# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma screening pipeline built from two Hugging Face models.

    A SwinV2 image classifier predicts the glaucoma category, and a SegFormer
    semantic-segmentation model labels the optic disc and optic cup. `process`
    combines both into a category, a cup-to-disc ratio, an overlay image and a
    crop around the segmented region.
    """

    def __init__(self, 
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification", 
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models and their image processors.

        Args:
            cls_model_path: Hub id or local path of the classification model.
            seg_model_path: Hub id or local path of the segmentation model.
            device: torch device the models and their inputs are moved to.
        """
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()

        # Index -> label-name maps taken from each model's config.
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify *image*.

        Args:
            image: image array accepted by the classification processor
                (main() passes an RGB uint8 array).
        Returns:
            (disease_idx, confidence): the arg-max class index (a key of
            cls_id2label) and its softmax probability scaled to 0-100.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        inputs.to(self.device)  # moves the tensors in place
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        probs = F.softmax(logits, dim=-1)[0].cpu()
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item() * 100  # Scale to percentage
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment optic disc (label 1) and optic cup (label 2) in *image*.

        Args:
            image: image array whose first two dimensions are (H, W); the
                prediction is resized to that size.
        Returns:
            (mask, cup_confidence, disc_confidence): a uint8 (H, W) label
            mask, and for the cup and disc the mean softmax probability
            (0-100) over the pixels predicted as that label; NaN when no
            pixel received that label.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs.to(self.device)
            outputs = self.seg_model(**inputs)
        logits = outputs.logits.cpu()
        # Resize the logits to the input image's spatial size before labelling.
        upsampled_logits = nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        # Averaging over the whole image would be dominated by background
        # pixels, so only the pixels assigned to each label are used.
        cup_confidence = self._region_confidence(seg_probs, pred_disc_cup, 2)
        disc_confidence = self._region_confidence(seg_probs, pred_disc_cup, 1)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    @staticmethod
    def _region_confidence(probs, pred, class_id):
        """Mean probability (0-100) of *class_id* over pixels predicted as it, or NaN if none."""
        region = pred == class_id
        if not bool(region.any()):
            return float("nan")
        return probs[class_id][region].mean().item() * 100

    @staticmethod
    def _crop_around_mask(image, disc_cup, min_padding=50, pad_ratio=0.2, min_size=50):
        """Crop *image* to the labelled region of *disc_cup* plus padding, else return a full copy."""
        rows, cols = np.nonzero(disc_cup > 0)
        if rows.size == 0:
            # Nothing was segmented, so there is no region to crop around.
            return image.copy()
        top, bottom = int(rows.min()), int(rows.max()) + 1
        left, right = int(cols.min()), int(cols.max()) + 1
        # Padding scales with the region size but never drops below min_padding.
        padding = max(min_padding, int(pad_ratio * max(right - left, bottom - top)))
        top = max(top - padding, 0)
        left = max(left - padding, 0)
        bottom = min(bottom + padding, image.shape[0])
        right = min(right + padding, image.shape[1])
        # Only crop when the padded box is big enough to be useful.
        if right - left >= min_size and bottom - top >= min_size:
            return image[top:bottom, left:right]
        return image.copy()

    def process(self, image):
        """Run classification, segmentation and post-processing on one image.

        Args:
            image: image array (main() passes an RGB uint8 array of shape
                (H, W, 3)).
        Returns:
            (disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image).
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)

        try:
            vcdr = simple_vcdr(disc_cup)  # Calculate vertical cup-to-disc ratio
        except (ValueError, ZeroDivisionError, FloatingPointError):
            vcdr = np.nan

        cropped_image = self._crop_around_mask(image, disc_cup)

        # NOTE(review): add_mask returns (blended, overlay); the overlay with
        # opaque class colours is kept here, so the 0.2 alpha has no visible
        # effect. Confirm this is intended.
        _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)

        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image

# --- Utility Functions ---
def simple_vcdr(mask):
    """
    Calculate the vertical cup-to-disc ratio (VCDR) of a segmentation mask.

    Each region's height is the largest number of its pixels found in any
    single column. The ratio is cup height divided by disc height.

    Assumes:
    - mask is a 2-D label array with class 1 for optic disc and class 2 for
      optic cup. Labels are exclusive per pixel (argmax output), so the full
      disc, which contains the cup, is every pixel with a label > 0.

    Returns:
    - the ratio as a float, or np.nan when the mask has no disc/cup pixels.
    """
    mask = np.asarray(mask)
    disc = mask > 0
    if not disc.any():  # Avoid division by zero
        return np.nan
    disc_height = int(disc.sum(axis=0).max())
    cup_height = int((mask == 2).sum(axis=0).max())
    return cup_height / disc_height

def add_mask(image, mask, classes, colors, alpha=0.5):
    """
    Paint segmentation classes onto a copy of an image and alpha-blend it.

    Args:
    - image: array to paint on (the caller passes an RGB uint8 image); it is
      not modified
    - mask: label array whose entries select pixels of image
    - classes: label values to paint (e.g., [1, 2])
    - colors: one fill colour per entry of classes (e.g., [[0, 255, 0], [255, 0, 0]])
    - alpha: weight of the painted copy in the blend (default = 0.5)

    Returns:
    - (blended, painted): cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)
      and the painted copy itself
    """
    painted = image.copy()
    for label, fill in zip(classes, colors):
        selected = mask == label
        painted[selected] = fill
    blended = cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)
    return blended, painted

# --- Streamlit Interface ---
def _show_image(column, image):
    """Draw *image* without axes on a new figure and render it into *column*."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    column.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; close it once it has
    # been rendered so reruns of the script do not accumulate figures.
    plt.close(fig)


def main():
    """Build the Streamlit page and run the screening pipeline on request.

    The uploaded image is shown immediately. When "Analyze image" is pressed,
    GlaucomaModel is loaded and the page shows the disc/cup overlay and the
    crop, offers the crop as a PNG download, and prints a results table.
    """
    st.set_page_config(layout="wide")
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')

    # Four equal-width result panels (st.columns replaces the removed
    # st.beta_columns).
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Optic disc and optic cup")
    cols[2].subheader("Class activation map")
    cols[3].subheader("Cropped Image")

    # File uploader
    st.sidebar.title("Image selection")
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    image = None
    if uploaded_file is not None:
        # Read and display uploaded image
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        _show_image(cols[0], image)

    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    with st.spinner('Loading model...'):
        run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = GlaucomaModel(device=run_device)

    with st.spinner('Analyzing...'):
        # Get predictions from the model
        (disease_idx, disc_cup_image, vcdr, cls_confidence,
         cup_confidence, disc_confidence, cropped_image) = model.process(image)

        # Each panel gets its own figure so images are not drawn over each other.
        _show_image(cols[1], disc_cup_image)
        # NOTE(review): no class activation map is computed anywhere in this
        # file; this panel shows the input image.
        _show_image(cols[2], image)
        _show_image(cols[3], cropped_image)

        # Make cropped image downloadable
        buf = io.BytesIO()
        Image.fromarray(cropped_image).save(buf, format="PNG")
        st.sidebar.download_button(
            label="Download Cropped Image",
            data=buf.getvalue(),
            file_name="cropped_image.png",
            mime="image/png"
        )

        # Display results with confidence
        st.subheader("Screening results:")
        final_results_as_table = f"""
                |Parameters|Outcomes|
                |---|---|
                |Vertical cup-to-disc ratio|{vcdr:.04f}|
                |Category|{model.cls_id2label[disease_idx]} ({cls_confidence:.02f}% confidence)|
                |Optic Cup Segmentation Confidence|{cup_confidence:.02f}%|
                |Optic Disc Segmentation Confidence|{disc_confidence:.02f}%|
                """
        st.markdown(final_results_as_table)

# Start the app only when this file is executed as the main script, not on import.
if __name__ == "__main__":
    main()