luigi12345
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
|
|
8 |
import streamlit as st
|
9 |
from PIL import Image
|
10 |
import io
|
|
|
|
|
11 |
|
12 |
# --- GlaucomaModel Class ---
|
13 |
class GlaucomaModel(object):
|
@@ -22,20 +24,17 @@ class GlaucomaModel(object):
|
|
22 |
# Segmentation model for optic disc and cup
|
23 |
self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
|
24 |
self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
|
25 |
-
|
26 |
-
# Class activation map
|
27 |
self.cls_id2label = self.cls_model.config.id2label
|
28 |
-
self.seg_id2label = self.seg_model.config.id2label
|
29 |
|
30 |
def glaucoma_pred(self, image):
|
31 |
inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
|
32 |
with torch.no_grad():
|
33 |
inputs.to(self.device)
|
34 |
outputs = self.cls_model(**inputs).logits
|
35 |
-
# Softmax for probabilities
|
36 |
probs = F.softmax(outputs, dim=-1)
|
37 |
disease_idx = probs.cpu()[0, :].numpy().argmax()
|
38 |
-
confidence = probs.cpu()[0, disease_idx].item() * 100
|
39 |
return disease_idx, confidence
|
40 |
|
41 |
def optic_disc_cup_pred(self, image):
|
@@ -47,66 +46,44 @@ class GlaucomaModel(object):
|
|
47 |
upsampled_logits = nn.functional.interpolate(
|
48 |
logits, size=image.shape[:2], mode="bilinear", align_corners=False
|
49 |
)
|
50 |
-
# Softmax for segmentation confidence
|
51 |
seg_probs = F.softmax(upsampled_logits, dim=1)
|
52 |
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
53 |
-
cup_confidence = seg_probs[0, 2, :, :].mean().item() * 100
|
54 |
-
disc_confidence = seg_probs[0, 1, :, :].mean().item() * 100
|
55 |
return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
|
56 |
|
57 |
def process(self, image):
|
58 |
-
image_shape = image.shape[:2]
|
59 |
disease_idx, cls_confidence = self.glaucoma_pred(image)
|
60 |
disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
|
61 |
|
62 |
try:
|
63 |
-
vcdr = simple_vcdr(disc_cup)
|
64 |
except:
|
65 |
vcdr = np.nan
|
66 |
|
67 |
-
# Mask for optic disc and cup
|
68 |
mask = (disc_cup > 0).astype(np.uint8)
|
69 |
-
|
70 |
-
# Get bounding box of the optic cup + disc and add dynamic padding
|
71 |
x, y, w, h = cv2.boundingRect(mask)
|
72 |
-
padding = max(50, int(0.2 * max(w, h)))
|
73 |
x = max(x - padding, 0)
|
74 |
y = max(y - padding, 0)
|
75 |
w = min(w + 2 * padding, image.shape[1] - x)
|
76 |
h = min(h + 2 * padding, image.shape[0] - y)
|
77 |
|
78 |
-
# Ensure that the bounding box is large enough to avoid cropping errors
|
79 |
cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()
|
80 |
-
|
81 |
-
# Generate disc and cup visualization
|
82 |
_, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
|
83 |
|
84 |
return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
|
85 |
|
86 |
# --- Utility Functions ---
|
87 |
def simple_vcdr(mask):
|
88 |
-
"""
|
89 |
-
Simple function to calculate the vertical cup-to-disc ratio (VCDR).
|
90 |
-
Assumes:
|
91 |
-
- mask contains class 1 for optic disc and class 2 for optic cup.
|
92 |
-
"""
|
93 |
disc_area = np.sum(mask == 1)
|
94 |
cup_area = np.sum(mask == 2)
|
95 |
-
if disc_area == 0:
|
96 |
return np.nan
|
97 |
vcdr = cup_area / disc_area
|
98 |
return vcdr
|
99 |
|
100 |
def add_mask(image, mask, classes, colors, alpha=0.5):
|
101 |
-
"""
|
102 |
-
Adds a transparent mask to the original image.
|
103 |
-
Args:
|
104 |
-
- image: the original RGB image
|
105 |
-
- mask: the predicted segmentation mask
|
106 |
-
- classes: a list of class indices to apply masks for (e.g., [1, 2])
|
107 |
-
- colors: a list of colors for each class (e.g., [[0, 255, 0], [255, 0, 0]] for green and red)
|
108 |
-
- alpha: transparency level (default = 0.5)
|
109 |
-
"""
|
110 |
overlay = image.copy()
|
111 |
for class_id, color in zip(classes, colors):
|
112 |
overlay[mask == class_id] = color
|
@@ -116,79 +93,59 @@ def add_mask(image, mask, classes, colors, alpha=0.5):
|
|
116 |
# --- Streamlit Interface ---
|
117 |
def main():
|
118 |
st.set_page_config(layout="wide")
|
119 |
-
st.title("Glaucoma Screening from Retinal Fundus Images")
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
Image.fromarray(cropped_image).save(buf, format="PNG")
|
173 |
-
byte_img = buf.getvalue()
|
174 |
-
st.sidebar.download_button(
|
175 |
-
label="Download Cropped Image",
|
176 |
-
data=byte_img,
|
177 |
-
file_name="cropped_image.png",
|
178 |
-
mime="image/png"
|
179 |
-
)
|
180 |
-
|
181 |
-
# Display results with confidence
|
182 |
-
st.subheader("Screening results:")
|
183 |
-
final_results_as_table = f"""
|
184 |
-
|Parameters|Outcomes|
|
185 |
-
|---|---|
|
186 |
-
|Vertical cup-to-disc ratio|{vcdr:.04f}|
|
187 |
-
|Category|{model.cls_id2label[disease_idx]} ({cls_confidence:.02f}% confidence)|
|
188 |
-
|Optic Cup Segmentation Confidence|{cup_confidence:.02f}%|
|
189 |
-
|Optic Disc Segmentation Confidence|{disc_confidence:.02f}%|
|
190 |
-
"""
|
191 |
-
st.markdown(final_results_as_table)
|
192 |
|
193 |
if __name__ == '__main__':
|
194 |
main()
|
|
|
8 |
import streamlit as st
|
9 |
from PIL import Image
|
10 |
import io
|
11 |
+
import zipfile
|
12 |
+
import os
|
13 |
|
14 |
# --- GlaucomaModel Class ---
|
15 |
class GlaucomaModel(object):
|
|
|
24 |
# Segmentation model for optic disc and cup
|
25 |
self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
|
26 |
self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
|
27 |
+
# Mapping for class labels
|
|
|
28 |
self.cls_id2label = self.cls_model.config.id2label
|
|
|
29 |
|
30 |
def glaucoma_pred(self, image):
|
31 |
inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
|
32 |
with torch.no_grad():
|
33 |
inputs.to(self.device)
|
34 |
outputs = self.cls_model(**inputs).logits
|
|
|
35 |
probs = F.softmax(outputs, dim=-1)
|
36 |
disease_idx = probs.cpu()[0, :].numpy().argmax()
|
37 |
+
confidence = probs.cpu()[0, disease_idx].item() * 100
|
38 |
return disease_idx, confidence
|
39 |
|
40 |
def optic_disc_cup_pred(self, image):
|
|
|
46 |
upsampled_logits = nn.functional.interpolate(
|
47 |
logits, size=image.shape[:2], mode="bilinear", align_corners=False
|
48 |
)
|
|
|
49 |
seg_probs = F.softmax(upsampled_logits, dim=1)
|
50 |
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
51 |
+
cup_confidence = seg_probs[0, 2, :, :].mean().item() * 100
|
52 |
+
disc_confidence = seg_probs[0, 1, :, :].mean().item() * 100
|
53 |
return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
|
54 |
|
55 |
def process(self, image):
|
|
|
56 |
disease_idx, cls_confidence = self.glaucoma_pred(image)
|
57 |
disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
|
58 |
|
59 |
try:
|
60 |
+
vcdr = simple_vcdr(disc_cup)
|
61 |
except:
|
62 |
vcdr = np.nan
|
63 |
|
|
|
64 |
mask = (disc_cup > 0).astype(np.uint8)
|
|
|
|
|
65 |
x, y, w, h = cv2.boundingRect(mask)
|
66 |
+
padding = max(50, int(0.2 * max(w, h)))
|
67 |
x = max(x - padding, 0)
|
68 |
y = max(y - padding, 0)
|
69 |
w = min(w + 2 * padding, image.shape[1] - x)
|
70 |
h = min(h + 2 * padding, image.shape[0] - y)
|
71 |
|
|
|
72 |
cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()
|
|
|
|
|
73 |
_, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
|
74 |
|
75 |
return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
|
76 |
|
77 |
# --- Utility Functions ---
|
78 |
def simple_vcdr(mask):
|
|
|
|
|
|
|
|
|
|
|
79 |
disc_area = np.sum(mask == 1)
|
80 |
cup_area = np.sum(mask == 2)
|
81 |
+
if disc_area == 0:
|
82 |
return np.nan
|
83 |
vcdr = cup_area / disc_area
|
84 |
return vcdr
|
85 |
|
86 |
def add_mask(image, mask, classes, colors, alpha=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
overlay = image.copy()
|
88 |
for class_id, color in zip(classes, colors):
|
89 |
overlay[mask == class_id] = color
|
|
|
93 |
# --- Streamlit Interface ---
|
94 |
def main():
|
95 |
st.set_page_config(layout="wide")
|
96 |
+
st.title("Batch Glaucoma Screening from Retinal Fundus Images")
|
97 |
+
|
98 |
+
st.sidebar.title("Settings")
|
99 |
+
confidence_threshold = st.sidebar.slider("Confidence Threshold (%)", 0, 100, 70)
|
100 |
+
uploaded_files = st.sidebar.file_uploader("Upload Images", type=['png', 'jpeg', 'jpg'], accept_multiple_files=True)
|
101 |
+
|
102 |
+
confident_images = []
|
103 |
+
download_confident_images = []
|
104 |
+
|
105 |
+
if uploaded_files:
|
106 |
+
for uploaded_file in uploaded_files:
|
107 |
+
image = Image.open(uploaded_file).convert('RGB')
|
108 |
+
image_np = np.array(image).astype(np.uint8)
|
109 |
+
|
110 |
+
with st.spinner(f'Processing {uploaded_file.name}...'):
|
111 |
+
model = GlaucomaModel(device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
|
112 |
+
disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
|
113 |
+
|
114 |
+
# Confidence-based grouping
|
115 |
+
is_confident = cls_conf >= confidence_threshold
|
116 |
+
if is_confident:
|
117 |
+
confident_images.append(uploaded_file.name)
|
118 |
+
download_confident_images.append((cropped_image, uploaded_file.name))
|
119 |
+
|
120 |
+
# Display Results
|
121 |
+
with st.expander(f"Results for {uploaded_file.name}", expanded=False):
|
122 |
+
cols = st.columns(4)
|
123 |
+
cols[0].image(image_np, caption="Input Image", use_column_width=True)
|
124 |
+
cols[1].image(disc_cup_image, caption="Disc/Cup Segmentation", use_column_width=True)
|
125 |
+
cols[2].image(image_np, caption="Class Activation Map", use_column_width=True)
|
126 |
+
cols[3].image(cropped_image, caption="Cropped Image", use_column_width=True)
|
127 |
+
|
128 |
+
st.write(f"**Vertical cup-to-disc ratio:** {vcdr:.04f}")
|
129 |
+
st.write(f"**Category:** {model.cls_id2label[disease_idx]} ({cls_conf:.02f}% confidence)")
|
130 |
+
st.write(f"**Optic Cup Segmentation Confidence:** {cup_conf:.02f}%")
|
131 |
+
st.write(f"**Optic Disc Segmentation Confidence:** {disc_conf:.02f}%")
|
132 |
+
st.write(f"**Confidence Group:** {'Confident' if is_confident else 'Not Confident'}")
|
133 |
+
|
134 |
+
# Download Button for Confident Images
|
135 |
+
if download_confident_images:
|
136 |
+
with zipfile.ZipFile("confident_cropped_images.zip", "w") as zf:
|
137 |
+
for cropped_image, name in download_confident_images:
|
138 |
+
img_buffer = io.BytesIO()
|
139 |
+
Image.fromarray(cropped_image).save(img_buffer, format="PNG")
|
140 |
+
zf.writestr(f"{name}_cropped.png", img_buffer.getvalue())
|
141 |
+
|
142 |
+
# Provide a markdown link to the ZIP file
|
143 |
+
st.sidebar.markdown(
|
144 |
+
f"[Download Confident Cropped Images](./confident_cropped_images.zip)",
|
145 |
+
unsafe_allow_html=True
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
st.sidebar.info("Upload images to begin analysis.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
if __name__ == '__main__':
|
151 |
main()
|