import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import streamlit as st
from PIL import Image
import io
import zipfile
import pandas as pd
from datetime import datetime
import os
import tempfile


# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Mapping for class labels
        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs.to(self.device)
            outputs = self.cls_model(**inputs).logits
        probs = F.softmax(outputs, dim=-1)
        disease_idx = probs.cpu()[0, :].numpy().argmax()
        confidence = probs.cpu()[0, disease_idx].item() * 100
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs.to(self.device)
            outputs = self.seg_model(**inputs)
        logits = outputs.logits.cpu()
        # Upsample logits to the original image resolution before taking the argmax
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode="bilinear",
            align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]

        # Calculate segmentation confidence based on the probability distribution:
        # for each pixel classified as cup/disc, check how confident the model is.
        cup_mask = pred_disc_cup == 2
        disc_mask = pred_disc_cup == 1
        # Get confidence only for pixels predicted as cup/disc
        cup_confidence = seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
        disc_confidence = seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0

        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        try:
            vcdr = simple_vcdr(disc_cup)
        except Exception:
            vcdr = np.nan

        # Crop a padded region of interest around the segmented disc/cup
        mask = (disc_cup > 0).astype(np.uint8)
        x, y, w, h = cv2.boundingRect(mask)
        padding = max(50, int(0.2 * max(w, h)))
        x = max(x - padding, 0)
        y = max(y - padding, 0)
        w = min(w + 2 * padding, image.shape[1] - x)
        h = min(h + 2 * padding, image.shape[0] - y)
        cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()

        _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image


# --- Utility Functions ---
def simple_vcdr(mask):
    disc_area = np.sum(mask == 1)
    cup_area = np.sum(mask == 2)
    if disc_area == 0:
        return np.nan
    vcdr = cup_area / disc_area
    return vcdr


def add_mask(image, mask, classes, colors, alpha=0.5):
    overlay = image.copy()
    for class_id, color in zip(classes, colors):
        overlay[mask == class_id] = color
    output = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    return output, overlay


def get_confidence_level(confidence):
    if confidence >= 90:
        return "Very High"
    elif confidence >= 75:
        return "High"
    elif confidence >= 60:
        return "Moderate"
    elif confidence >= 45:
        return "Low"
    else:
        return "Very Low"


def process_batch(model, images_data, progress_bar=None):
    results = []
    for idx, (file_name, image) in enumerate(images_data):
        try:
            disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image)
            results.append({
                'file_name': file_name,
                'diagnosis': model.cls_id2label[disease_idx],
                'confidence': cls_conf,
                'vcdr': vcdr,
                'cup_conf': cup_conf,
                'disc_conf': disc_conf,
                'processed_image': disc_cup_image,
                'cropped_image': cropped_image
            })
            if progress_bar:
                progress_bar.progress((idx + 1) / len(images_data))
        except Exception as e:
            st.error(f"Error processing {file_name}: {str(e)}")
    return results


def save_results(results, original_images):
    # Create temporary directory for results
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save report as CSV
        df = pd.DataFrame([{
            'File': r['file_name'],
            'Diagnosis': r['diagnosis'],
            'Confidence (%)': f"{r['confidence']:.1f}",
            'VCDR': f"{r['vcdr']:.3f}",
            'Cup Confidence (%)': f"{r['cup_conf']:.1f}",
            'Disc Confidence (%)': f"{r['disc_conf']:.1f}"
        } for r in results])
        report_path = os.path.join(temp_dir, 'report.csv')
        df.to_csv(report_path, index=False)

        # Save processed images
        for result, orig_img in zip(results, original_images):
            img_name = result['file_name']
            base_name = os.path.splitext(img_name)[0]

            # Save original
            orig_path = os.path.join(temp_dir, f"{base_name}_original.jpg")
            Image.fromarray(orig_img).save(orig_path)

            # Save segmentation
            seg_path = os.path.join(temp_dir, f"{base_name}_segmentation.jpg")
            Image.fromarray(result['processed_image']).save(seg_path)

            # Save ROI
            roi_path = os.path.join(temp_dir, f"{base_name}_roi.jpg")
            Image.fromarray(result['cropped_image']).save(roi_path)

        # Create ZIP file
        zip_path = os.path.join(temp_dir, 'results.zip')
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for root, _, files in os.walk(temp_dir):
                for file in files:
                    if file != 'results.zip':
                        file_path = os.path.join(root, file)
                        arcname = os.path.basename(file_path)
                        zipf.write(file_path, arcname)

        with open(zip_path, 'rb') as f:
            return f.read()


# --- Streamlit Interface ---
def main():
    st.set_page_config(layout="wide", page_title="Glaucoma Screening Tool")
    print("Starting app...")  # Debug print

    st.markdown("""
Upload retinal images for automated glaucoma detection and optic disc/cup segmentation
""", unsafe_allow_html=True) print("Header rendered...") # Debug print # Add session state for better state management if 'processed_count' not in st.session_state: st.session_state.processed_count = 0 # Add a more informative sidebar with st.sidebar: st.markdown("### 📤 Upload Images") uploaded_files = st.file_uploader( "Upload Retinal Images", type=['png', 'jpeg', 'jpg'], accept_multiple_files=True, help="Support multiple images in PNG, JPEG formats" ) st.markdown("### 📊 Processing Stats") if 'processed_count' in st.session_state: st.metric("Images Processed", st.session_state.processed_count) st.markdown("---") # Add batch size limit max_batch = st.number_input("Max Batch Size", min_value=1, max_value=100, value=20, help="Maximum number of images to process in one batch") if uploaded_files: # Validate batch size if len(uploaded_files) > max_batch: st.warning(f"⚠️ Please upload maximum {max_batch} images at once. Current: {len(uploaded_files)}") return # Add loading animation with st.spinner('🔄 Initializing model...'): model = GlaucomaModel(device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")) # Add summary metrics at the top col1, col2 = st.columns(2) with col1: st.info(f"📁 Total images: {len(uploaded_files)}") with col2: st.info(f"⚙️ Using: {'GPU' if torch.cuda.is_available() else 'CPU'}") # Prepare images data images_data = [] original_images = [] for file in uploaded_files: try: image = Image.open(file).convert('RGB') image_np = np.array(image) images_data.append((file.name, image_np)) original_images.append(image_np) except Exception as e: st.error(f"Error loading {file.name}: {str(e)}") continue progress_bar = st.progress(0) st.write(f"Processing {len(images_data)} images...") # Process all images results = process_batch(model, images_data, progress_bar) if results: # Add summary statistics st.markdown("### 📊 Summary Statistics") glaucoma_count = sum(1 for r in results if r['diagnosis'] == 'Glaucoma') normal_count = len(results) - glaucoma_count cols = st.columns(4) with cols[0]: st.metric("Total Processed", len(results)) with cols[1]: st.metric("Glaucoma Detected", glaucoma_count) with cols[2]: st.metric("Normal", normal_count) with cols[3]: avg_conf = sum(r['confidence'] for r in results) / len(results) st.metric("Avg Confidence", f"{avg_conf:.1f}%") # Add filter options st.markdown("### 🔍 Filter Results") show_only = st.multiselect( "Show cases:", ["All", "Glaucoma", "Normal"], default=["All"] ) # Filter results based on selection filtered_results = results if "All" not in show_only: filtered_results = [ r for r in results if (r['diagnosis'] == 'Glaucoma' and 'Glaucoma' in show_only) or (r['diagnosis'] == 'Normal' and 'Normal' in show_only) ] # Display filtered results for result in filtered_results: with st.expander( f"📋 {result['file_name']} - {result['diagnosis']} ({result['confidence']:.1f}% confidence)" ): cols = st.columns(3) with cols[0]: st.image(result['processed_image'], caption="Segmentation", use_column_width=True) with cols[1]: st.image(result['cropped_image'], caption="ROI", use_column_width=True) with cols[2]: st.write("### Metrics") st.write(f"Diagnosis: {result['diagnosis']}") st.write(f"Confidence: {result['confidence']:.1f}%") st.write(f"VCDR: {result['vcdr']:.3f}") st.write(f"Cup Confidence: {result['cup_conf']:.1f}%") st.write(f"Disc Confidence: {result['disc_conf']:.1f}%") # Update session state st.session_state.processed_count += len(results) # Add export options st.markdown("### 📥 Export Options") col1, col2 = st.columns(2) with col1: 
                # Build the ZIP archive (CSV report plus processed images)
                zip_data = save_results(results, original_images)
                st.download_button(
                    label="📥 Download All Results (ZIP)",
                    data=zip_data,
                    file_name=f"glaucoma_screening_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip",
                    mime="application/zip"
                )
            with col2:
                # Add CSV-only download
                csv_data = pd.DataFrame([{
                    'File': r['file_name'],
                    'Diagnosis': r['diagnosis'],
                    'Confidence': r['confidence'],
                    'VCDR': r['vcdr']
                } for r in results]).to_csv(index=False)
                st.download_button(
                    label="📊 Download Report (CSV)",
                    data=csv_data,
                    file_name=f"glaucoma_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                    mime="text/csv"
                )


# Add this at the end of the file
if __name__ == "__main__":
    print("Running main...")  # Debug print
    main()