File size: 5,584 Bytes
7a986e7
f2c28c8
 
7a986e7
 
 
f2c28c8
 
 
 
7a986e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2c28c8
7a986e7
f2c28c8
 
 
 
7a986e7
 
dfbd972
f2c28c8
dfbd972
7a986e7
f2c28c8
7a986e7
f2c28c8
 
 
7a986e7
f2c28c8
7a986e7
f2c28c8
 
7a986e7
f2c28c8
 
 
 
 
 
 
 
 
7a986e7
 
f2c28c8
 
 
7a986e7
 
f2c28c8
7a986e7
dfbd972
f2c28c8
 
 
7a986e7
 
dfbd972
 
 
7a986e7
 
dfbd972
a7e5224
dfbd972
a7e5224
7a986e7
 
 
dfbd972
 
f2c28c8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image

# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma screening on retinal fundus images.

    Wraps two Hugging Face checkpoints: a SwinV2 image classifier that
    predicts the glaucoma category, and a SegFormer model that segments the
    optic disc and optic cup. Segmentation label ids used here are
    1 = optic disc and 2 = optic cup (the channels this class reads);
    confirm against ``seg_model.config.id2label`` if the checkpoint changes.
    """

    # Segmentation label ids (see class docstring).
    DISC_LABEL = 1
    CUP_LABEL = 2

    def __init__(self, 
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification", 
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models and their image processors onto *device*.

        Args:
            cls_model_path: Hub id or local path of the classification checkpoint.
            seg_model_path: Hub id or local path of the segmentation checkpoint.
            device: torch device the models are moved to and run on.
        """
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()

        # Label-id -> human-readable name mappings taken from the model configs.
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify *image* and return ``(class_index, probability)``.

        Args:
            image: RGB image as a NumPy array; it is copied before preprocessing.

        Returns:
            The argmax class index as a Python ``int`` (a key of
            ``cls_id2label``) and its softmax probability as a ``float``.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        probs = F.softmax(logits, dim=-1)[0].cpu()
        # torch.argmax returns the first maximal index on ties, like np.argmax.
        disease_idx = int(probs.argmax())
        return disease_idx, probs[disease_idx].item()

    @staticmethod
    def _region_confidence(probs, pred, label):
        """Mean probability of *label* over the pixels where *pred* equals *label*.

        *probs* is the (num_classes, H, W) softmax map and *pred* the (H, W)
        argmax map. Returns ``nan`` when *label* is never predicted, since
        there is then no region whose confidence could be reported.
        """
        region = pred == label
        if not bool(region.any()):
            return float('nan')
        return probs[label][region].mean().item()

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in *image*.

        Args:
            image: RGB image as a NumPy array of shape (H, W, 3).

        Returns:
            ``(label_mask, cup_confidence, disc_confidence)``. ``label_mask`` is
            a ``uint8`` array of shape (H, W) holding the per-pixel argmax label.
            Each confidence is the mean softmax probability of that class over
            the pixels predicted as that class, or ``nan`` if it is absent.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # The model predicts at reduced resolution; resize to the input image size.
        upsampled_logits = nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        cup_confidence = self._region_confidence(seg_probs, pred_disc_cup, self.CUP_LABEL)
        disc_confidence = self._region_confidence(seg_probs, pred_disc_cup, self.DISC_LABEL)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    @classmethod
    def _vertical_cdr(cls, mask):
        """Vertical cup-to-disc ratio of a label mask, or ``nan`` if disc/cup is missing.

        The ratio is the number of rows spanned by the cup divided by the
        number of rows spanned by the disc. Cup pixels are counted as part of
        the disc, assuming the cup lies inside the disc.
        """
        disc_rows = np.flatnonzero(((mask == cls.DISC_LABEL) | (mask == cls.CUP_LABEL)).any(axis=1))
        cup_rows = np.flatnonzero((mask == cls.CUP_LABEL).any(axis=1))
        if disc_rows.size == 0 or cup_rows.size == 0:
            return float('nan')
        disc_height = disc_rows[-1] - disc_rows[0] + 1
        cup_height = cup_rows[-1] - cup_rows[0] + 1
        return float(cup_height / disc_height)

    @staticmethod
    def _overlay_mask(image, mask, labels, colors, alpha):
        """Return a ``uint8`` copy of *image* with each labelled region tinted.

        Pixels where ``mask == labels[i]`` are blended toward ``colors[i]``
        with opacity *alpha*: 0 keeps the image, 1 paints the solid color.
        The input image is not modified.
        """
        blended = image.astype(np.float32)  # astype copies, so *image* is untouched
        for label, color in zip(labels, colors):
            region = mask == label
            blended[region] = (1.0 - alpha) * blended[region] + alpha * np.asarray(color, dtype=np.float32)
        return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    def process(self, image):
        """Run classification and disc/cup segmentation on *image*.

        Returns:
            ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence)``. ``disc_cup_image`` is *image*
            with the disc tinted green and the cup tinted red. ``vcdr`` is
            ``nan`` when the disc or cup is absent from the mask.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        vcdr = self._vertical_cdr(disc_cup)
        disc_cup_image = self._overlay_mask(
            image, disc_cup,
            [self.DISC_LABEL, self.CUP_LABEL],
            [[0, 255, 0], [255, 0, 0]],
            0.2,
        )
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence

# --- Streamlit Interface ---
def _show_image(column, image):
    """Render *image* with hidden axes into a Streamlit *column* via Matplotlib.

    The figure is closed afterwards so Streamlit reruns do not accumulate
    open Matplotlib figures.
    """
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    column.pyplot(fig)
    plt.close(fig)


def _load_model(device_name):
    """Construct a GlaucomaModel on the torch device named *device_name*."""
    return GlaucomaModel(device=torch.device(device_name))


def _get_model(device_name):
    """Return a GlaucomaModel, reusing a cached instance when Streamlit supports it.

    With ``st.cache_resource`` (newer Streamlit releases), one model per
    device name is kept across reruns, so the checkpoints are not reloaded on
    every click. Without it, the model is built fresh each time.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is None:
        return _load_model(device_name)
    return cache_resource(_load_model)(device_name)


def main():
    """Streamlit entry point: upload a fundus image and display screening results."""
    from streamlit.errors import StreamlitAPIException

    # Wide mode in Streamlit
    st.set_page_config(layout="wide")

    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')

    # `st.beta_columns` was renamed to `st.columns` and later removed; support both.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Optic disc and optic cup")
    cols[2].subheader("Classification Map")

    # File uploader
    st.sidebar.title("Image selection")
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except StreamlitAPIException:
        # Newer Streamlit releases removed this option and reject setting it.
        pass
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    image = None
    if uploaded_file is not None:
        try:
            image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for unreadable files.
            st.sidebar.error("Could not read the uploaded file as an image")
        else:
            _show_image(cols[0], image)

    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    with st.spinner('Loading model...'):
        # Load the model on available device
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        model = _get_model(device_name)

    with st.spinner('Analyzing...'):
        (disease_idx, disc_cup_image, vcdr,
         cls_confidence, cup_confidence, disc_confidence) = model.process(image)

        # Display optic disc and cup image
        _show_image(cols[1], disc_cup_image)
        # NOTE(review): no classification/activation map is computed yet, so this
        # column shows the input image as a placeholder.
        _show_image(cols[2], image)

        # Display results with confidence; rows are joined explicitly so the
        # Markdown table does not depend on source indentation.
        st.subheader("Screening results:")
        results_table = "\n".join([
            "|Parameters|Outcomes|",
            "|---|---|",
            f"|Vertical cup-to-disc ratio|{vcdr:.04f}|",
            f"|Category|{model.cls_id2label[disease_idx]} ({cls_confidence*100:.02f}% confidence)|",
            f"|Optic Cup Segmentation Confidence|{cup_confidence*100:.02f}%|",
            f"|Optic Disc Segmentation Confidence|{disc_confidence*100:.02f}%|",
        ])
        st.markdown(results_table)


if __name__ == '__main__':
    main()