luigi12345
commited on
Commit
β’
2c57916
1
Parent(s):
c4921c6
app.py
CHANGED
@@ -24,19 +24,22 @@ MODEL_OPTIONS = {
|
|
24 |
class GlaucomaModel(object):
|
25 |
def __init__(self,
|
26 |
cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
|
27 |
-
seg_model_path=None,
|
28 |
device=torch.device('cpu')):
|
29 |
self.device = device
|
30 |
-
|
|
|
|
|
31 |
self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
|
32 |
self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
|
33 |
|
34 |
-
# Segmentation model
|
35 |
-
|
36 |
-
self.
|
37 |
-
|
|
|
|
|
38 |
|
39 |
-
# Mapping for class labels
|
40 |
self.cls_id2label = self.cls_model.config.id2label
|
41 |
|
42 |
def glaucoma_pred(self, image):
|
@@ -54,19 +57,38 @@ class GlaucomaModel(object):
|
|
54 |
with torch.no_grad():
|
55 |
inputs.to(self.device)
|
56 |
outputs = self.seg_model(**inputs)
|
|
|
57 |
logits = outputs.logits.cpu()
|
58 |
upsampled_logits = nn.functional.interpolate(
|
59 |
logits, size=image.shape[:2], mode="bilinear", align_corners=False
|
60 |
)
|
61 |
-
seg_probs = F.softmax(upsampled_logits, dim=1)
|
62 |
-
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
63 |
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
cup_mask = pred_disc_cup == 2
|
67 |
disc_mask = pred_disc_cup == 1
|
68 |
|
69 |
-
# Get confidence only for pixels predicted as cup/disc
|
70 |
cup_confidence = seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
|
71 |
disc_confidence = seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0
|
72 |
|
@@ -225,107 +247,107 @@ def main():
|
|
225 |
|
226 |
if uploaded_files:
|
227 |
try:
|
228 |
-
#
|
229 |
-
st.
|
230 |
-
st.write(
|
231 |
-
st.write(f"
|
232 |
|
233 |
-
# Initialize model with selected segmentation model
|
234 |
model = GlaucomaModel(
|
235 |
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
236 |
seg_model_path=MODEL_OPTIONS[selected_model]
|
237 |
)
|
238 |
|
239 |
-
|
240 |
-
st.write("β
Models loaded successfully")
|
241 |
-
st.write(f"π₯οΈ Using: {'GPU' if torch.cuda.is_available() else 'CPU'} for processing")
|
242 |
st.write("---")
|
243 |
|
244 |
for file in uploaded_files:
|
245 |
try:
|
246 |
-
|
247 |
-
st.write(f"πΈ Processing image: {file.name}")
|
248 |
image = Image.open(file).convert('RGB')
|
249 |
image_np = np.array(image)
|
250 |
|
251 |
-
# Get predictions
|
252 |
disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
|
253 |
|
254 |
-
# Enhanced results display
|
255 |
st.write("---")
|
256 |
-
st.
|
257 |
|
258 |
-
#
|
259 |
-
st.
|
260 |
-
st.write(f"β’ Finding: {model.cls_id2label[disease_idx]}")
|
261 |
-
st.write(f"β’ AI Confidence: {cls_conf:.1f}% ({get_confidence_level(cls_conf)})")
|
262 |
|
263 |
-
#
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
- Higher confidence means clearer cup boundaries and more reliable VCDR
|
270 |
-
|
271 |
-
β’ For the Optic Disc (entire circular area):
|
272 |
-
- Indicates how sure the AI is about the green-outlined disc boundary
|
273 |
-
- Higher confidence suggests better disc margin visibility
|
274 |
-
|
275 |
-
Confidence scores are calculated by averaging the model's certainty
|
276 |
-
for each pixel it identified as cup or disc. A score of 100% would mean
|
277 |
-
the model is absolutely certain about every pixel's classification.
|
278 |
-
""")
|
279 |
-
|
280 |
-
st.write("\nπ **Current Segmentation Confidence Scores:**")
|
281 |
-
st.write(f"β’ Optic Cup Detection: {cup_conf:.1f}% - {get_confidence_level(cup_conf)}")
|
282 |
-
st.write(f"β’ Optic Disc Detection: {disc_conf:.1f}% - {get_confidence_level(disc_conf)}")
|
283 |
-
|
284 |
-
# Add interpretation guidance
|
285 |
-
if cup_conf >= 75 and disc_conf >= 75:
|
286 |
-
st.write("β
High confidence scores indicate reliable measurements")
|
287 |
-
elif cup_conf < 60 or disc_conf < 60:
|
288 |
-
st.write("""
|
289 |
-
β οΈ Lower confidence scores might be due to:
|
290 |
-
β’ Image quality issues (blur, poor contrast)
|
291 |
-
β’ Unusual anatomical variations
|
292 |
-
β’ Pathological changes affecting visibility
|
293 |
-
β’ Poor image centering or focus
|
294 |
-
|
295 |
-
Consider retaking the image if possible.
|
296 |
-
""")
|
297 |
|
298 |
-
#
|
299 |
-
st.write("\nπ **Clinical Measurements:**")
|
300 |
-
st.write(f"β’ Cup-to-Disc Ratio (VCDR): {vcdr:.3f}")
|
301 |
if vcdr > 0.7:
|
302 |
-
st.
|
303 |
elif vcdr > 0.5:
|
304 |
-
st.
|
305 |
else:
|
306 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
|
308 |
-
#
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
""
|
316 |
-
|
|
|
317 |
|
318 |
-
#
|
319 |
-
|
320 |
-
|
|
|
321 |
|
322 |
except Exception as e:
|
323 |
st.error(f"Error processing {file.name}: {str(e)}")
|
324 |
continue
|
325 |
|
326 |
-
#
|
327 |
-
|
328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
except Exception as e:
|
331 |
st.error(f"An error occurred: {str(e)}")
|
|
|
24 |
class GlaucomaModel(object):
|
25 |
def __init__(self,
|
26 |
cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
|
27 |
+
seg_model_path=None,
|
28 |
device=torch.device('cpu')):
|
29 |
self.device = device
|
30 |
+
self.seg_model_path = seg_model_path or MODEL_OPTIONS["Pamixsun"]
|
31 |
+
|
32 |
+
# Classification model setup remains the same
|
33 |
self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
|
34 |
self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
|
35 |
|
36 |
+
# Segmentation model setup with model type detection
|
37 |
+
self.seg_extractor = AutoImageProcessor.from_pretrained(self.seg_model_path)
|
38 |
+
self.seg_model = SegformerForSemanticSegmentation.from_pretrained(self.seg_model_path).to(device).eval()
|
39 |
+
|
40 |
+
# Detect model type
|
41 |
+
self.is_ferferefer = "ferferefer" in self.seg_model_path.lower()
|
42 |
|
|
|
43 |
self.cls_id2label = self.cls_model.config.id2label
|
44 |
|
45 |
def glaucoma_pred(self, image):
|
|
|
57 |
with torch.no_grad():
|
58 |
inputs.to(self.device)
|
59 |
outputs = self.seg_model(**inputs)
|
60 |
+
|
61 |
logits = outputs.logits.cpu()
|
62 |
upsampled_logits = nn.functional.interpolate(
|
63 |
logits, size=image.shape[:2], mode="bilinear", align_corners=False
|
64 |
)
|
|
|
|
|
65 |
|
66 |
+
if self.is_ferferefer:
|
67 |
+
# ferferefer model specific processing
|
68 |
+
seg_probs = F.softmax(upsampled_logits, dim=1)
|
69 |
+
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
70 |
+
|
71 |
+
# Map ferferefer model classes to match Pamixsun format
|
72 |
+
# Assuming ferferefer uses different class indices
|
73 |
+
class_mapping = {
|
74 |
+
0: 0, # background
|
75 |
+
1: 1, # disc
|
76 |
+
2: 2 # cup
|
77 |
+
}
|
78 |
+
|
79 |
+
pred_disc_cup_mapped = torch.zeros_like(pred_disc_cup)
|
80 |
+
for old_class, new_class in class_mapping.items():
|
81 |
+
pred_disc_cup_mapped[pred_disc_cup == old_class] = new_class
|
82 |
+
pred_disc_cup = pred_disc_cup_mapped
|
83 |
+
else:
|
84 |
+
# Pamixsun model processing (original logic)
|
85 |
+
seg_probs = F.softmax(upsampled_logits, dim=1)
|
86 |
+
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
87 |
+
|
88 |
+
# Calculate confidences
|
89 |
cup_mask = pred_disc_cup == 2
|
90 |
disc_mask = pred_disc_cup == 1
|
91 |
|
|
|
92 |
cup_confidence = seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
|
93 |
disc_confidence = seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0
|
94 |
|
|
|
247 |
|
248 |
if uploaded_files:
|
249 |
try:
|
250 |
+
# Model loading feedback with better visuals
|
251 |
+
st.info("π€ Loading AI Models")
|
252 |
+
st.write("Classification: pamixsun/swinv2_tiny_for_glaucoma_classification")
|
253 |
+
st.write(f"Segmentation: {selected_model}")
|
254 |
|
|
|
255 |
model = GlaucomaModel(
|
256 |
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
257 |
seg_model_path=MODEL_OPTIONS[selected_model]
|
258 |
)
|
259 |
|
260 |
+
st.success(f"β
Models loaded successfully - Using {'GPU' if torch.cuda.is_available() else 'CPU'}")
|
|
|
|
|
261 |
st.write("---")
|
262 |
|
263 |
for file in uploaded_files:
|
264 |
try:
|
265 |
+
st.info(f"πΈ Processing: {file.name}")
|
|
|
266 |
image = Image.open(file).convert('RGB')
|
267 |
image_np = np.array(image)
|
268 |
|
|
|
269 |
disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
|
270 |
|
|
|
271 |
st.write("---")
|
272 |
+
st.success(f"Results for: {file.name}")
|
273 |
|
274 |
+
# Key findings with better visuals
|
275 |
+
st.info("π Key Findings")
|
|
|
|
|
276 |
|
277 |
+
# Diagnosis with color-coded warning levels
|
278 |
+
diagnosis = model.cls_id2label[disease_idx]
|
279 |
+
if diagnosis == "Glaucoma":
|
280 |
+
st.warning(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")
|
281 |
+
else:
|
282 |
+
st.success(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
284 |
+
# VCDR with risk levels
|
|
|
|
|
285 |
if vcdr > 0.7:
|
286 |
+
st.warning(f"VCDR: {vcdr:.3f} - β οΈ High Risk")
|
287 |
elif vcdr > 0.5:
|
288 |
+
st.warning(f"VCDR: {vcdr:.3f} - β οΈ Borderline")
|
289 |
else:
|
290 |
+
st.success(f"VCDR: {vcdr:.3f} - β
Normal")
|
291 |
+
|
292 |
+
# Segmentation confidence
|
293 |
+
st.info("π Segmentation Confidence")
|
294 |
+
st.write("""
|
295 |
+
β’ Optic Cup (red area): Central depression
|
296 |
+
β’ Optic Disc (green outline): Entire nerve area
|
297 |
+
""")
|
298 |
|
299 |
+
# Cup and Disc confidence with warnings
|
300 |
+
if cup_conf < 60:
|
301 |
+
st.warning(f"Cup Detection: {cup_conf:.1f}% - Low Confidence")
|
302 |
+
else:
|
303 |
+
st.write(f"Cup Detection: {cup_conf:.1f}%")
|
304 |
+
|
305 |
+
if disc_conf < 60:
|
306 |
+
st.warning(f"Disc Detection: {disc_conf:.1f}% - Low Confidence")
|
307 |
+
else:
|
308 |
+
st.write(f"Disc Detection: {disc_conf:.1f}%")
|
309 |
|
310 |
+
# Images with clear sections
|
311 |
+
st.info("πΌοΈ Analysis Images")
|
312 |
+
st.image(disc_cup_image, caption="Green: Optic Disc | Red: Optic Cup")
|
313 |
+
st.image(cropped_image, caption="Region of Interest")
|
314 |
|
315 |
except Exception as e:
|
316 |
st.error(f"Error processing {file.name}: {str(e)}")
|
317 |
continue
|
318 |
|
319 |
+
# Download section
|
320 |
+
try:
|
321 |
+
st.info("π₯ Preparing Download")
|
322 |
+
results = []
|
323 |
+
original_images = []
|
324 |
+
for file in uploaded_files:
|
325 |
+
image = Image.open(file).convert('RGB')
|
326 |
+
image_np = np.array(image)
|
327 |
+
disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
|
328 |
+
results.append({
|
329 |
+
'file_name': file.name,
|
330 |
+
'diagnosis': model.cls_id2label[disease_idx],
|
331 |
+
'confidence': cls_conf,
|
332 |
+
'vcdr': vcdr,
|
333 |
+
'cup_conf': cup_conf,
|
334 |
+
'disc_conf': disc_conf,
|
335 |
+
'processed_image': disc_cup_image,
|
336 |
+
'cropped_image': cropped_image
|
337 |
+
})
|
338 |
+
original_images.append(image_np)
|
339 |
+
|
340 |
+
zip_data = save_results(results, original_images)
|
341 |
+
b64_zip = base64.b64encode(zip_data).decode()
|
342 |
+
|
343 |
+
st.success("β
Download Ready")
|
344 |
+
href = f'<a href="data:application/zip;base64,{b64_zip}" download="glaucoma_results.zip">π₯ Download All Results (ZIP)</a>'
|
345 |
+
st.markdown(href, unsafe_allow_html=True)
|
346 |
+
|
347 |
+
except Exception as e:
|
348 |
+
st.error(f"Error preparing download: {str(e)}")
|
349 |
+
|
350 |
+
st.success("β
All Processing Complete!")
|
351 |
|
352 |
except Exception as e:
|
353 |
st.error(f"An error occurred: {str(e)}")
|