luigi12345
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoImageProcessor, Swinv2ForImageClassification, Segfo
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import streamlit as st
|
9 |
from PIL import Image
|
|
|
10 |
|
11 |
# --- GlaucomaModel Class ---
|
12 |
class GlaucomaModel(object):
|
@@ -34,7 +35,7 @@ class GlaucomaModel(object):
|
|
34 |
# Softmax for probabilities
|
35 |
probs = F.softmax(outputs, dim=-1)
|
36 |
disease_idx = probs.cpu()[0, :].numpy().argmax()
|
37 |
-
confidence = probs.cpu()[0, disease_idx].item()
|
38 |
return disease_idx, confidence
|
39 |
|
40 |
def optic_disc_cup_pred(self, image):
|
@@ -49,8 +50,8 @@ class GlaucomaModel(object):
|
|
49 |
# Softmax for segmentation confidence
|
50 |
seg_probs = F.softmax(upsampled_logits, dim=1)
|
51 |
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
52 |
-
cup_confidence = seg_probs[0, 2, :, :].mean().item()
|
53 |
-
disc_confidence = seg_probs[0, 1, :, :].mean().item()
|
54 |
return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
|
55 |
|
56 |
def process(self, image):
|
@@ -66,20 +67,16 @@ class GlaucomaModel(object):
|
|
66 |
# Mask for optic disc and cup
|
67 |
mask = (disc_cup > 0).astype(np.uint8)
|
68 |
|
69 |
-
# Get bounding box of the optic cup + disc and add padding
|
70 |
x, y, w, h = cv2.boundingRect(mask)
|
71 |
-
padding =
|
72 |
x = max(x - padding, 0)
|
73 |
y = max(y - padding, 0)
|
74 |
w = min(w + 2 * padding, image.shape[1] - x)
|
75 |
h = min(h + 2 * padding, image.shape[0] - y)
|
76 |
|
77 |
# Ensure that the bounding box is large enough to avoid cropping errors
|
78 |
-
|
79 |
-
if w < min_size or h < min_size:
|
80 |
-
cropped_image = image.copy() # Fallback: if bounding box too small, return original image
|
81 |
-
else:
|
82 |
-
cropped_image = image[y:y+h, x:x+w]
|
83 |
|
84 |
# Generate disc and cup visualization
|
85 |
_, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
|
@@ -95,9 +92,7 @@ def simple_vcdr(mask):
|
|
95 |
"""
|
96 |
disc_area = np.sum(mask == 1)
|
97 |
cup_area = np.sum(mask == 2)
|
98 |
-
|
99 |
-
# Avoid division by zero
|
100 |
-
if disc_area == 0:
|
101 |
return np.nan
|
102 |
vcdr = cup_area / disc_area
|
103 |
return vcdr
|
@@ -115,15 +110,12 @@ def add_mask(image, mask, classes, colors, alpha=0.5):
|
|
115 |
overlay = image.copy()
|
116 |
for class_id, color in zip(classes, colors):
|
117 |
overlay[mask == class_id] = color
|
118 |
-
# Blend the overlay with the original image
|
119 |
output = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
|
120 |
return output, overlay
|
121 |
|
122 |
# --- Streamlit Interface ---
|
123 |
def main():
|
124 |
-
# Wide mode in Streamlit
|
125 |
st.set_page_config(layout="wide")
|
126 |
-
|
127 |
st.title("Glaucoma Screening from Retinal Fundus Images")
|
128 |
st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
|
129 |
|
@@ -153,12 +145,10 @@ def main():
|
|
153 |
st.sidebar.write("Please upload an image")
|
154 |
else:
|
155 |
with st.spinner('Loading model...'):
|
156 |
-
# Load the model on available device
|
157 |
run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
158 |
model = GlaucomaModel(device=run_device)
|
159 |
|
160 |
with st.spinner('Analyzing...'):
|
161 |
-
# Get predictions from the model
|
162 |
disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image = model.process(image)
|
163 |
|
164 |
# Display optic disc and cup image
|
@@ -176,17 +166,27 @@ def main():
|
|
176 |
ax.axis('off')
|
177 |
cols[3].pyplot(fig)
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
# Display results with confidence
|
180 |
st.subheader("Screening results:")
|
181 |
final_results_as_table = f"""
|
182 |
|Parameters|Outcomes|
|
183 |
|---|---|
|
184 |
|Vertical cup-to-disc ratio|{vcdr:.04f}|
|
185 |
-
|Category|{model.cls_id2label[disease_idx]} ({cls_confidence
|
186 |
-
|Optic Cup Segmentation Confidence|{cup_confidence
|
187 |
-
|Optic Disc Segmentation Confidence|{disc_confidence
|
188 |
"""
|
189 |
st.markdown(final_results_as_table)
|
190 |
|
191 |
if __name__ == '__main__':
|
192 |
-
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import streamlit as st
|
9 |
from PIL import Image
|
10 |
+
import io
|
11 |
|
12 |
# --- GlaucomaModel Class ---
|
13 |
class GlaucomaModel(object):
|
|
|
35 |
# Softmax for probabilities
|
36 |
probs = F.softmax(outputs, dim=-1)
|
37 |
disease_idx = probs.cpu()[0, :].numpy().argmax()
|
38 |
+
confidence = probs.cpu()[0, disease_idx].item() * 100 # Scale to percentage
|
39 |
return disease_idx, confidence
|
40 |
|
41 |
def optic_disc_cup_pred(self, image):
|
|
|
50 |
# Softmax for segmentation confidence
|
51 |
seg_probs = F.softmax(upsampled_logits, dim=1)
|
52 |
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
53 |
+
cup_confidence = seg_probs[0, 2, :, :].mean().item() * 100 # Scale to percentage
|
54 |
+
disc_confidence = seg_probs[0, 1, :, :].mean().item() * 100 # Scale to percentage
|
55 |
return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
|
56 |
|
57 |
def process(self, image):
|
|
|
67 |
# Mask for optic disc and cup
|
68 |
mask = (disc_cup > 0).astype(np.uint8)
|
69 |
|
70 |
+
# Get bounding box of the optic cup + disc and add dynamic padding
|
71 |
x, y, w, h = cv2.boundingRect(mask)
|
72 |
+
padding = max(50, int(0.2 * max(w, h))) # Dynamic padding (20% of width or height)
|
73 |
x = max(x - padding, 0)
|
74 |
y = max(y - padding, 0)
|
75 |
w = min(w + 2 * padding, image.shape[1] - x)
|
76 |
h = min(h + 2 * padding, image.shape[0] - y)
|
77 |
|
78 |
# Ensure that the bounding box is large enough to avoid cropping errors
|
79 |
+
cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()
|
|
|
|
|
|
|
|
|
80 |
|
81 |
# Generate disc and cup visualization
|
82 |
_, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
|
|
|
92 |
"""
|
93 |
disc_area = np.sum(mask == 1)
|
94 |
cup_area = np.sum(mask == 2)
|
95 |
+
if disc_area == 0: # Avoid division by zero
|
|
|
|
|
96 |
return np.nan
|
97 |
vcdr = cup_area / disc_area
|
98 |
return vcdr
|
|
|
110 |
overlay = image.copy()
|
111 |
for class_id, color in zip(classes, colors):
|
112 |
overlay[mask == class_id] = color
|
|
|
113 |
output = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
|
114 |
return output, overlay
|
115 |
|
116 |
# --- Streamlit Interface ---
|
117 |
def main():
|
|
|
118 |
st.set_page_config(layout="wide")
|
|
|
119 |
st.title("Glaucoma Screening from Retinal Fundus Images")
|
120 |
st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
|
121 |
|
|
|
145 |
st.sidebar.write("Please upload an image")
|
146 |
else:
|
147 |
with st.spinner('Loading model...'):
|
|
|
148 |
run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
149 |
model = GlaucomaModel(device=run_device)
|
150 |
|
151 |
with st.spinner('Analyzing...'):
|
|
|
152 |
disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image = model.process(image)
|
153 |
|
154 |
# Display optic disc and cup image
|
|
|
166 |
ax.axis('off')
|
167 |
cols[3].pyplot(fig)
|
168 |
|
169 |
+
# Make cropped image downloadable
|
170 |
+
buf = io.BytesIO()
|
171 |
+
Image.fromarray(cropped_image).save(buf, format="PNG")
|
172 |
+
st.sidebar.download_button(
|
173 |
+
label="Download Cropped Image",
|
174 |
+
data=buf.getvalue(),
|
175 |
+
file_name="cropped_image.png",
|
176 |
+
mime="image/png"
|
177 |
+
)
|
178 |
+
|
179 |
# Display results with confidence
|
180 |
st.subheader("Screening results:")
|
181 |
final_results_as_table = f"""
|
182 |
|Parameters|Outcomes|
|
183 |
|---|---|
|
184 |
|Vertical cup-to-disc ratio|{vcdr:.04f}|
|
185 |
+
|Category|{model.cls_id2label[disease_idx]} ({cls_confidence:.02f}% confidence)|
|
186 |
+
|Optic Cup Segmentation Confidence|{cup_confidence:.02f}%|
|
187 |
+
|Optic Disc Segmentation Confidence|{disc_confidence:.02f}%|
|
188 |
"""
|
189 |
st.markdown(final_results_as_table)
|
190 |
|
191 |
if __name__ == '__main__':
|
192 |
+
|