luigi12345 commited on
Commit
2c57916
β€’
1 Parent(s): c4921c6
Files changed (1) hide show
  1. app.py +106 -84
app.py CHANGED
@@ -24,19 +24,22 @@ MODEL_OPTIONS = {
24
  class GlaucomaModel(object):
25
  def __init__(self,
26
  cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
27
- seg_model_path=None, # Make this optional
28
  device=torch.device('cpu')):
29
  self.device = device
30
- # Classification model for glaucoma (always the same)
 
 
31
  self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
32
  self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
33
 
34
- # Segmentation model - use provided path or default
35
- seg_path = seg_model_path or MODEL_OPTIONS["Pamixsun"] # Default to Pamixsun if none provided
36
- self.seg_extractor = AutoImageProcessor.from_pretrained(seg_path)
37
- self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_path).to(device).eval()
 
 
38
 
39
- # Mapping for class labels
40
  self.cls_id2label = self.cls_model.config.id2label
41
 
42
  def glaucoma_pred(self, image):
@@ -54,19 +57,38 @@ class GlaucomaModel(object):
54
  with torch.no_grad():
55
  inputs.to(self.device)
56
  outputs = self.seg_model(**inputs)
 
57
  logits = outputs.logits.cpu()
58
  upsampled_logits = nn.functional.interpolate(
59
  logits, size=image.shape[:2], mode="bilinear", align_corners=False
60
  )
61
- seg_probs = F.softmax(upsampled_logits, dim=1)
62
- pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
63
 
64
- # Calculate segmentation confidence based on probability distribution
65
- # For each pixel classified as cup/disc, check how confident the model is
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  cup_mask = pred_disc_cup == 2
67
  disc_mask = pred_disc_cup == 1
68
 
69
- # Get confidence only for pixels predicted as cup/disc
70
  cup_confidence = seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
71
  disc_confidence = seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0
72
 
@@ -225,107 +247,107 @@ def main():
225
 
226
  if uploaded_files:
227
  try:
228
- # Enhanced model loading feedback
229
- st.write("πŸ€– Initializing AI models...")
230
- st.write(f"β€’ Loading classification model: pamixsun/swinv2_tiny_for_glaucoma_classification")
231
- st.write(f"β€’ Loading segmentation model: {selected_model}")
232
 
233
- # Initialize model with selected segmentation model
234
  model = GlaucomaModel(
235
  device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
236
  seg_model_path=MODEL_OPTIONS[selected_model]
237
  )
238
 
239
- # Show model loading completion
240
- st.write("βœ… Models loaded successfully")
241
- st.write(f"πŸ–₯️ Using: {'GPU' if torch.cuda.is_available() else 'CPU'} for processing")
242
  st.write("---")
243
 
244
  for file in uploaded_files:
245
  try:
246
- # Process each image with enhanced feedback
247
- st.write(f"πŸ“Έ Processing image: {file.name}")
248
  image = Image.open(file).convert('RGB')
249
  image_np = np.array(image)
250
 
251
- # Get predictions
252
  disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
253
 
254
- # Enhanced results display
255
  st.write("---")
256
- st.write(f"Results for {file.name}")
257
 
258
- # Diagnosis section
259
- st.write("πŸ“Š **Diagnosis Results:**")
260
- st.write(f"β€’ Finding: {model.cls_id2label[disease_idx]}")
261
- st.write(f"β€’ AI Confidence: {cls_conf:.1f}% ({get_confidence_level(cls_conf)})")
262
 
263
- # Enhanced Segmentation confidence section with detailed explanations
264
- st.write("\nπŸ” **Understanding Segmentation Confidence:**")
265
- st.write("""
266
- Segmentation confidence shows how certain the AI is about each pixel it classified:
267
- β€’ For the Optic Cup (central depression):
268
- - Measures the AI's certainty that the red-colored pixels are truly part of the cup
269
- - Higher confidence means clearer cup boundaries and more reliable VCDR
270
-
271
- β€’ For the Optic Disc (entire circular area):
272
- - Indicates how sure the AI is about the green-outlined disc boundary
273
- - Higher confidence suggests better disc margin visibility
274
-
275
- Confidence scores are calculated by averaging the model's certainty
276
- for each pixel it identified as cup or disc. A score of 100% would mean
277
- the model is absolutely certain about every pixel's classification.
278
- """)
279
-
280
- st.write("\nπŸ“Š **Current Segmentation Confidence Scores:**")
281
- st.write(f"β€’ Optic Cup Detection: {cup_conf:.1f}% - {get_confidence_level(cup_conf)}")
282
- st.write(f"β€’ Optic Disc Detection: {disc_conf:.1f}% - {get_confidence_level(disc_conf)}")
283
-
284
- # Add interpretation guidance
285
- if cup_conf >= 75 and disc_conf >= 75:
286
- st.write("βœ… High confidence scores indicate reliable measurements")
287
- elif cup_conf < 60 or disc_conf < 60:
288
- st.write("""
289
- ⚠️ Lower confidence scores might be due to:
290
- β€’ Image quality issues (blur, poor contrast)
291
- β€’ Unusual anatomical variations
292
- β€’ Pathological changes affecting visibility
293
- β€’ Poor image centering or focus
294
-
295
- Consider retaking the image if possible.
296
- """)
297
 
298
- # Clinical metrics
299
- st.write("\nπŸ“ **Clinical Measurements:**")
300
- st.write(f"β€’ Cup-to-Disc Ratio (VCDR): {vcdr:.3f}")
301
  if vcdr > 0.7:
302
- st.write(" ⚠️ High VCDR - Potential risk indicator")
303
  elif vcdr > 0.5:
304
- st.write(" ℹ️ Borderline VCDR - Follow-up recommended")
305
  else:
306
- st.write(" βœ… Normal VCDR range")
 
 
 
 
 
 
 
307
 
308
- # Image display with enhanced captions
309
- st.write("\nπŸ–ΌοΈ **Visual Analysis:**")
310
- st.image(disc_cup_image, caption="""
311
- Segmentation Overlay
312
- β€’ Green outline: Optic Disc boundary
313
- β€’ Red area: Optic Cup region
314
- β€’ Transparency shows underlying retina
315
- """)
316
- st.image(cropped_image, caption="Zoomed Region of Interest")
 
317
 
318
- # Add quality note if needed
319
- if cup_conf < 60 or disc_conf < 60:
320
- st.write("\n⚠️ Note: Low segmentation confidence. Image quality might affect measurements.")
 
321
 
322
  except Exception as e:
323
  st.error(f"Error processing {file.name}: {str(e)}")
324
  continue
325
 
326
- # Simple summary at the end
327
- st.write("---")
328
- st.write("Processing complete!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  except Exception as e:
331
  st.error(f"An error occurred: {str(e)}")
 
24
  class GlaucomaModel(object):
25
  def __init__(self,
26
  cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
27
+ seg_model_path=None,
28
  device=torch.device('cpu')):
29
  self.device = device
30
+ self.seg_model_path = seg_model_path or MODEL_OPTIONS["Pamixsun"]
31
+
32
+ # Classification model setup remains the same
33
  self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
34
  self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
35
 
36
+ # Segmentation model setup with model type detection
37
+ self.seg_extractor = AutoImageProcessor.from_pretrained(self.seg_model_path)
38
+ self.seg_model = SegformerForSemanticSegmentation.from_pretrained(self.seg_model_path).to(device).eval()
39
+
40
+ # Detect model type
41
+ self.is_ferferefer = "ferferefer" in self.seg_model_path.lower()
42
 
 
43
  self.cls_id2label = self.cls_model.config.id2label
44
 
45
  def glaucoma_pred(self, image):
 
57
  with torch.no_grad():
58
  inputs.to(self.device)
59
  outputs = self.seg_model(**inputs)
60
+
61
  logits = outputs.logits.cpu()
62
  upsampled_logits = nn.functional.interpolate(
63
  logits, size=image.shape[:2], mode="bilinear", align_corners=False
64
  )
 
 
65
 
66
+ if self.is_ferferefer:
67
+ # ferferefer model specific processing
68
+ seg_probs = F.softmax(upsampled_logits, dim=1)
69
+ pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
70
+
71
+ # Map ferferefer model classes to match Pamixsun format
72
+ # Assuming ferferefer uses different class indices
73
+ class_mapping = {
74
+ 0: 0, # background
75
+ 1: 1, # disc
76
+ 2: 2 # cup
77
+ }
78
+
79
+ pred_disc_cup_mapped = torch.zeros_like(pred_disc_cup)
80
+ for old_class, new_class in class_mapping.items():
81
+ pred_disc_cup_mapped[pred_disc_cup == old_class] = new_class
82
+ pred_disc_cup = pred_disc_cup_mapped
83
+ else:
84
+ # Pamixsun model processing (original logic)
85
+ seg_probs = F.softmax(upsampled_logits, dim=1)
86
+ pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
87
+
88
+ # Calculate confidences
89
  cup_mask = pred_disc_cup == 2
90
  disc_mask = pred_disc_cup == 1
91
 
 
92
  cup_confidence = seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
93
  disc_confidence = seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0
94
 
 
247
 
248
  if uploaded_files:
249
  try:
250
+ # Model loading feedback with better visuals
251
+ st.info("πŸ€– Loading AI Models")
252
+ st.write("Classification: pamixsun/swinv2_tiny_for_glaucoma_classification")
253
+ st.write(f"Segmentation: {selected_model}")
254
 
 
255
  model = GlaucomaModel(
256
  device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
257
  seg_model_path=MODEL_OPTIONS[selected_model]
258
  )
259
 
260
+ st.success(f"βœ… Models loaded successfully - Using {'GPU' if torch.cuda.is_available() else 'CPU'}")
 
 
261
  st.write("---")
262
 
263
  for file in uploaded_files:
264
  try:
265
+ st.info(f"πŸ“Έ Processing: {file.name}")
 
266
  image = Image.open(file).convert('RGB')
267
  image_np = np.array(image)
268
 
 
269
  disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
270
 
 
271
  st.write("---")
272
+ st.success(f"Results for: {file.name}")
273
 
274
+ # Key findings with better visuals
275
+ st.info("πŸ“Š Key Findings")
 
 
276
 
277
+ # Diagnosis with color-coded warning levels
278
+ diagnosis = model.cls_id2label[disease_idx]
279
+ if diagnosis == "Glaucoma":
280
+ st.warning(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")
281
+ else:
282
+ st.success(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
+ # VCDR with risk levels
 
 
285
  if vcdr > 0.7:
286
+ st.warning(f"VCDR: {vcdr:.3f} - ⚠️ High Risk")
287
  elif vcdr > 0.5:
288
+ st.warning(f"VCDR: {vcdr:.3f} - ⚠️ Borderline")
289
  else:
290
+ st.success(f"VCDR: {vcdr:.3f} - βœ… Normal")
291
+
292
+ # Segmentation confidence
293
+ st.info("πŸ” Segmentation Confidence")
294
+ st.write("""
295
+ β€’ Optic Cup (red area): Central depression
296
+ β€’ Optic Disc (green outline): Entire nerve area
297
+ """)
298
 
299
+ # Cup and Disc confidence with warnings
300
+ if cup_conf < 60:
301
+ st.warning(f"Cup Detection: {cup_conf:.1f}% - Low Confidence")
302
+ else:
303
+ st.write(f"Cup Detection: {cup_conf:.1f}%")
304
+
305
+ if disc_conf < 60:
306
+ st.warning(f"Disc Detection: {disc_conf:.1f}% - Low Confidence")
307
+ else:
308
+ st.write(f"Disc Detection: {disc_conf:.1f}%")
309
 
310
+ # Images with clear sections
311
+ st.info("πŸ–ΌοΈ Analysis Images")
312
+ st.image(disc_cup_image, caption="Green: Optic Disc | Red: Optic Cup")
313
+ st.image(cropped_image, caption="Region of Interest")
314
 
315
  except Exception as e:
316
  st.error(f"Error processing {file.name}: {str(e)}")
317
  continue
318
 
319
+ # Download section
320
+ try:
321
+ st.info("πŸ“₯ Preparing Download")
322
+ results = []
323
+ original_images = []
324
+ for file in uploaded_files:
325
+ image = Image.open(file).convert('RGB')
326
+ image_np = np.array(image)
327
+ disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
328
+ results.append({
329
+ 'file_name': file.name,
330
+ 'diagnosis': model.cls_id2label[disease_idx],
331
+ 'confidence': cls_conf,
332
+ 'vcdr': vcdr,
333
+ 'cup_conf': cup_conf,
334
+ 'disc_conf': disc_conf,
335
+ 'processed_image': disc_cup_image,
336
+ 'cropped_image': cropped_image
337
+ })
338
+ original_images.append(image_np)
339
+
340
+ zip_data = save_results(results, original_images)
341
+ b64_zip = base64.b64encode(zip_data).decode()
342
+
343
+ st.success("βœ… Download Ready")
344
+ href = f'<a href="data:application/zip;base64,{b64_zip}" download="glaucoma_results.zip">πŸ“₯ Download All Results (ZIP)</a>'
345
+ st.markdown(href, unsafe_allow_html=True)
346
+
347
+ except Exception as e:
348
+ st.error(f"Error preparing download: {str(e)}")
349
+
350
+ st.success("βœ… All Processing Complete!")
351
 
352
  except Exception as e:
353
  st.error(f"An error occurred: {str(e)}")