import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import streamlit as st
from PIL import Image
import io
import zipfile
import pandas as pd
from datetime import datetime
import os
import tempfile
import base64

# Segmentation model options
MODEL_OPTIONS = {
    "Default (ferferefer/segformer)": "ferferefer/segformer",
    "Pamixsun": "pamixsun/segformer_for_optic_disc_cup_segmentation"
}


# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path=None,
                 device=torch.device('cpu')):
        self.device = device
        self.seg_model_path = seg_model_path or MODEL_OPTIONS["Pamixsun"]

        # Classification model setup
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()

        # Segmentation model setup with model type detection
        self.seg_extractor = AutoImageProcessor.from_pretrained(self.seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(self.seg_model_path).to(device).eval()

        # Detect model type
        self.is_ferferefer = "ferferefer" in self.seg_model_path.lower()
        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            outputs = self.cls_model(**inputs).logits
            probs = F.softmax(outputs, dim=-1)
            disease_idx = probs.cpu()[0, :].numpy().argmax()
            confidence = probs.cpu()[0, disease_idx].item() * 100
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            outputs = self.seg_model(**inputs)
            logits = outputs.logits.cpu()

        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode="bilinear",
            align_corners=False
        )

        if self.is_ferferefer:
            # ferferefer model specific processing
            seg_probs = F.softmax(upsampled_logits, dim=1)
            pred_disc_cup = upsampled_logits.argmax(dim=1)[0]

            # Map ferferefer model classes to match the Pamixsun format
            # (assuming ferferefer uses different class indices)
            class_mapping = {
                0: 0,  # background
                1: 1,  # disc
                2: 2   # cup
            }
            pred_disc_cup_mapped = torch.zeros_like(pred_disc_cup)
            for old_class, new_class in class_mapping.items():
                pred_disc_cup_mapped[pred_disc_cup == old_class] = new_class
            pred_disc_cup = pred_disc_cup_mapped
        else:
            # Pamixsun model processing (original logic)
            seg_probs = F.softmax(upsampled_logits, dim=1)
            pred_disc_cup = upsampled_logits.argmax(dim=1)[0]

        # Confidence = mean class probability over each predicted region
        cup_mask = pred_disc_cup == 2
        disc_mask = pred_disc_cup == 1
        cup_confidence = seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
        disc_confidence = seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0

        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)

        try:
            vcdr = simple_vcdr(disc_cup)
        except Exception:
            vcdr = np.nan

        # Crop a padded region of interest around the detected disc/cup
        mask = (disc_cup > 0).astype(np.uint8)
        x, y, w, h = cv2.boundingRect(mask)
        padding = max(50, int(0.2 * max(w, h)))
        x = max(x - padding, 0)
        y = max(y - padding, 0)
        w = min(w + 2 * padding, image.shape[1] - x)
        h = min(h + 2 * padding, image.shape[0] - y)
        cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()

        _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)

        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image


# --- Utility Functions ---
def simple_vcdr(mask):
    # Simple proxy: ratio of cup pixel count to disc pixel count
    disc_area = np.sum(mask == 1)
    cup_area = np.sum(mask == 2)
    if disc_area == 0:
        return np.nan
    vcdr = cup_area / disc_area
    return vcdr


def add_mask(image, mask, classes, colors, alpha=0.5):
    overlay = image.copy()
    for class_id, color in zip(classes, colors):
        overlay[mask == class_id] = color
    output = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    return output, overlay


def get_confidence_level(confidence):
    """Enhanced confidence descriptions for segmentation"""
    if confidence >= 90:
        return "Excellent (Model is very certain about the detected boundaries)"
    elif confidence >= 75:
        return "Good (Model is confident about most of the detected area)"
    elif confidence >= 60:
        return "Fair (Model has some uncertainty in parts of the detection)"
    elif confidence >= 45:
        return "Poor (Model is uncertain about many detected areas)"
    else:
        return "Very Poor (Model's detection is highly uncertain)"


def process_batch(model, images_data, progress_bar=None):
    results = []
    for idx, (file_name, image) in enumerate(images_data):
        try:
            disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image)
            results.append({
                'file_name': file_name,
                'diagnosis': model.cls_id2label[disease_idx],
                'confidence': cls_conf,
                'vcdr': vcdr,
                'cup_conf': cup_conf,
                'disc_conf': disc_conf,
                'processed_image': disc_cup_image,
                'cropped_image': cropped_image
            })
            if progress_bar:
                progress_bar.progress((idx + 1) / len(images_data))
        except Exception as e:
            st.error(f"Error processing {file_name}: {str(e)}")
    return results


def save_results(results, original_images):
    # Create temporary directory for results
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save report as CSV
        df = pd.DataFrame([{
            'File': r['file_name'],
            'Diagnosis': r['diagnosis'],
            'Confidence (%)': f"{r['confidence']:.1f}",
            'VCDR': f"{r['vcdr']:.3f}",
            'Cup Confidence (%)': f"{r['cup_conf']:.1f}",
            'Disc Confidence (%)': f"{r['disc_conf']:.1f}"
        } for r in results])
        report_path = os.path.join(temp_dir, 'report.csv')
        df.to_csv(report_path, index=False)

        # Save processed images
        for result, orig_img in zip(results, original_images):
            img_name = result['file_name']
            base_name = os.path.splitext(img_name)[0]

            # Save original
            orig_path = os.path.join(temp_dir, f"{base_name}_original.jpg")
            Image.fromarray(orig_img).save(orig_path)

            # Save segmentation
            seg_path = os.path.join(temp_dir, f"{base_name}_segmentation.jpg")
            Image.fromarray(result['processed_image']).save(seg_path)

            # Save ROI
            roi_path = os.path.join(temp_dir, f"{base_name}_roi.jpg")
            Image.fromarray(result['cropped_image']).save(roi_path)

        # Create ZIP file containing the report and all saved images
        zip_path = os.path.join(temp_dir, 'results.zip')
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for root, _, files in os.walk(temp_dir):
                for file in files:
                    if file != 'results.zip':
                        file_path = os.path.join(root, file)
                        arcname = os.path.basename(file_path)
                        zipf.write(file_path, arcname)

        with open(zip_path, 'rb') as f:
            return f.read()
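# Hypothetical wiring (not called anywhere in this file): get_confidence_level()
# could turn the raw segmentation percentages reported in main() into the
# descriptive labels defined above, e.g.:
#
#     st.write(f"Cup Detection: {cup_conf:.1f}% - {get_confidence_level(cup_conf)}")
#     st.write(f"Disc Detection: {disc_conf:.1f}% - {get_confidence_level(disc_conf)}")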
# --- Streamlit Interface ---
def main():
    # Use the old layout setting method
    st.set_page_config(layout="wide")

    # Use simple title instead of markdown
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write("Upload retinal images for automated glaucoma detection and optic disc/cup segmentation")

    # Add model selection in sidebar before file upload
    st.sidebar.title("Model Settings")
    selected_model = st.sidebar.selectbox(
        "Select Segmentation Model",
        list(MODEL_OPTIONS.keys()),
        index=1  # Default to Pamixsun
    )

    st.sidebar.title("Upload Images")
    st.set_option('deprecation.showfileUploaderEncoding', False)  # Important for old versions
    uploaded_files = st.sidebar.file_uploader(
        "Upload retinal images",
        type=['png', 'jpeg', 'jpg'],
        accept_multiple_files=True
    )

    # Simple explanation in sidebar
    st.sidebar.markdown("""
    ### Understanding Results:
    - Diagnosis Confidence: AI certainty level
    - VCDR: Cup to disc ratio (>0.7 high risk)
    - Segmentation: Accuracy of detection
    """)

    if uploaded_files:
        try:
            # Model loading feedback with better visuals
            st.info("🤖 Loading AI Models")
            st.write("Classification: pamixsun/swinv2_tiny_for_glaucoma_classification")
            st.write(f"Segmentation: {selected_model}")

            model = GlaucomaModel(
                device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
                seg_model_path=MODEL_OPTIONS[selected_model]
            )
            st.success(f"✅ Models loaded successfully - Using {'GPU' if torch.cuda.is_available() else 'CPU'}")
            st.write("---")

            for file in uploaded_files:
                try:
                    st.info(f"📸 Processing: {file.name}")
                    image = Image.open(file).convert('RGB')
                    image_np = np.array(image)

                    disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)

                    st.write("---")
                    st.success(f"Results for: {file.name}")

                    # Key findings with better visuals
                    st.info("📊 Key Findings")

                    # Diagnosis with color-coded warning levels
                    diagnosis = model.cls_id2label[disease_idx]
                    if diagnosis == "Glaucoma":
                        st.warning(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")
                    else:
                        st.success(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")

                    # VCDR with risk levels
                    if np.isnan(vcdr):
                        st.warning("VCDR could not be computed (no optic disc detected)")
                    elif vcdr > 0.7:
                        st.warning(f"VCDR: {vcdr:.3f} - ⚠️ High Risk")
                    elif vcdr > 0.5:
                        st.warning(f"VCDR: {vcdr:.3f} - ⚠️ Borderline")
                    else:
                        st.success(f"VCDR: {vcdr:.3f} - ✅ Normal")

                    # Segmentation confidence
                    st.info("🔍 Segmentation Confidence")
                    st.write("""
                    • Optic Cup (red area): Central depression
                    • Optic Disc (green outline): Entire nerve area
                    """)

                    # Cup and Disc confidence with warnings
                    if cup_conf < 60:
                        st.warning(f"Cup Detection: {cup_conf:.1f}% - Low Confidence")
                    else:
                        st.write(f"Cup Detection: {cup_conf:.1f}%")

                    if disc_conf < 60:
                        st.warning(f"Disc Detection: {disc_conf:.1f}% - Low Confidence")
                    else:
                        st.write(f"Disc Detection: {disc_conf:.1f}%")

                    # Images with clear sections
                    st.info("🖼️ Analysis Images")
                    st.image(disc_cup_image, caption="Green: Optic Disc | Red: Optic Cup")
                    st.image(cropped_image, caption="Region of Interest")

                except Exception as e:
                    st.error(f"Error processing {file.name}: {str(e)}")
                    continue

            # Download section
            try:
                st.info("📥 Preparing Download")
                results = []
                original_images = []

                for file in uploaded_files:
                    image = Image.open(file).convert('RGB')
                    image_np = np.array(image)
                    disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
                    results.append({
                        'file_name': file.name,
                        'diagnosis': model.cls_id2label[disease_idx],
                        'confidence': cls_conf,
                        'vcdr': vcdr,
                        'cup_conf': cup_conf,
                        'disc_conf': disc_conf,
                        'processed_image': disc_cup_image,
                        'cropped_image': cropped_image
                    })
                    original_images.append(image_np)

                zip_data = save_results(results, original_images)
                b64_zip = base64.b64encode(zip_data).decode()
                st.success("✅ Download Ready")
                # Download link embedding the ZIP as a base64 data URI
                href = f'<a href="data:application/zip;base64,{b64_zip}" download="results.zip">📥 Download All Results (ZIP)</a>'
                st.markdown(href, unsafe_allow_html=True)
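                # Alternative sketch (assumes a Streamlit version that provides
                # st.download_button): the ZIP bytes could be served directly
                # instead of embedding them in a base64 data-URI link, e.g.:
                #     st.download_button("📥 Download All Results (ZIP)", zip_data,
                #                        file_name="results.zip", mime="application/zip")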
            except Exception as e:
                st.error(f"Error preparing download: {str(e)}")

            st.success("✅ All Processing Complete!")

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
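# --- Usage notes (assumptions, not executed by the app) ---
# The script is assumed to be saved as app.py and launched with:
#     streamlit run app.py
#
# GlaucomaModel can also be used on its own, e.g. in a notebook ("fundus.jpg" is
# a placeholder path; process() expects an RGB numpy array):
#     model = GlaucomaModel(device=torch.device("cpu"))
#     image = np.array(Image.open("fundus.jpg").convert("RGB"))
#     (disease_idx, disc_cup_image, vcdr, cls_conf,
#      cup_conf, disc_conf, cropped_image) = model.process(image)
#     print(model.cls_id2label[disease_idx], f"VCDR={vcdr:.3f}")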