|
import cv2 |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation |
|
import streamlit as st |
|
from PIL import Image |
|
import io |
|
import zipfile |
|
import pandas as pd |
|
from datetime import datetime |
|
import os |
|
import tempfile |
|
import base64 |
|
|
|
|
|
class GlaucomaModel(object):
    """Glaucoma classifier plus optic disc/cup segmenter for fundus images.

    Wraps two Hugging Face checkpoints: a Swin-V2 image classifier and a
    SegFormer semantic-segmentation model. The segmentation output assigns
    each pixel exactly one label (argmax); this class treats label 1 as optic
    disc, label 2 as optic cup and any non-zero label as the region of
    interest.
    """

    # Segmentation label ids used for masks, overlays and confidences below.
    DISC_LABEL = 1
    CUP_LABEL = 2

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both checkpoints and their processors onto ``device``.

        Args:
            cls_model_path: Hub id or local path of the classification model.
            seg_model_path: Hub id or local path of the segmentation model.
            device: torch device the models (and inputs) are moved to.
        """
        self.device = device

        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()

        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()

        # Maps the classifier's output index to its human-readable label.
        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify one image.

        Args:
            image: image accepted by the classification processor (main()
                passes an RGB NumPy array).

        Returns:
            (disease_idx, confidence): the argmax class index as ``int`` and
            its softmax probability in percent.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        probs = F.softmax(logits, dim=-1)[0].cpu()
        # argmax returns the first maximal index on ties, like numpy's argmax.
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item() * 100
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment optic disc and cup at the input image's resolution.

        Args:
            image: NumPy image array; its first two dimensions (H, W) set the
                output size.

        Returns:
            (labels, cup_confidence, disc_confidence): ``labels`` is an H x W
            uint8 array of per-pixel argmax labels; each confidence is the mean
            softmax probability (percent) of that class over the pixels
            assigned to it, or 0.0 when no pixel is assigned.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()

        # The model predicts at reduced resolution; resample to (H, W).
        upsampled_logits = F.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        labels = upsampled_logits.argmax(dim=1)[0]

        cup_mask = labels == self.CUP_LABEL
        disc_mask = labels == self.DISC_LABEL
        cup_confidence = seg_probs[self.CUP_LABEL][cup_mask].mean().item() * 100 if cup_mask.any() else 0.0
        disc_confidence = seg_probs[self.DISC_LABEL][disc_mask].mean().item() * 100 if disc_mask.any() else 0.0

        return labels.numpy().astype(np.uint8), cup_confidence, disc_confidence

    @staticmethod
    def _crop_roi(image, disc_cup, min_padding=50, pad_ratio=0.2, min_size=50):
        """Crop ``image`` around the non-zero pixels of ``disc_cup`` with padding.

        The padding on each side is ``max(min_padding, pad_ratio * longest box
        side)``, clamped to the image. A copy of the whole image is returned
        when nothing was segmented or when the clamped box is smaller than
        ``min_size`` in either dimension.
        """
        region = disc_cup > 0
        rows = np.flatnonzero(region.any(axis=1))
        cols = np.flatnonzero(region.any(axis=0))
        if rows.size == 0:
            # Nothing segmented: there is no region to zoom into.
            return image.copy()

        y0, y1 = rows[0], rows[-1] + 1
        x0, x1 = cols[0], cols[-1] + 1
        padding = max(min_padding, int(pad_ratio * max(x1 - x0, y1 - y0)))

        # Clamp each edge independently so padding stays symmetric around
        # the detected region even when one side hits the image border.
        y0 = max(y0 - padding, 0)
        x0 = max(x0 - padding, 0)
        y1 = min(y1 + padding, image.shape[0])
        x1 = min(x1 + padding, image.shape[1])

        if (x1 - x0) < min_size or (y1 - y0) < min_size:
            return image.copy()
        return image[y0:y1, x0:x1]

    def process(self, image):
        """Run classification and segmentation on one image and derive outputs.

        Args:
            image: RGB NumPy array (main() builds it from a PIL RGB image).

        Returns:
            7-tuple ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image)`` where
            ``disc_cup_image`` is the image blended with the colored disc
            (green) and cup (red) regions, ``vcdr`` comes from
            ``simple_vcdr`` (NaN if it cannot be computed) and
            ``cropped_image`` is the padded region around the segmentation.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)

        # Best effort: a failed measurement becomes NaN instead of aborting
        # the image; only ordinary exceptions are caught.
        try:
            vcdr = simple_vcdr(disc_cup)
        except Exception:
            vcdr = np.nan

        cropped_image = self._crop_roi(image, disc_cup)

        # add_mask returns (blended, opaque); use the blend so the retina stays
        # visible underneath the colored regions.
        disc_cup_image, _ = add_mask(
            image,
            disc_cup,
            [self.DISC_LABEL, self.CUP_LABEL],
            [[0, 255, 0], [255, 0, 0]],
            0.2,
        )

        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
|
|
|
|
|
def simple_vcdr(mask, disc_label=1, cup_label=2):
    """Return the vertical cup-to-disc ratio (VCDR) of a disc/cup label mask.

    Each pixel carries a single label, so cup pixels are not also labelled as
    disc: the optic disc is the union of ``disc_label`` and ``cup_label``
    pixels. The ratio is the cup's vertical extent (inclusive row span from
    its topmost to its bottommost pixel) divided by the disc's.

    Args:
        mask: 2-D integer label array.
        disc_label: label id of disc pixels not covered by the cup.
        cup_label: label id of cup pixels.

    Returns:
        float in [0, 1]; 0.0 when no cup pixel exists; NaN when neither disc
        nor cup pixels exist.

    Raises:
        ValueError: if ``mask`` is not 2-D.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-D, got shape {mask.shape}")

    disc_rows = np.flatnonzero(((mask == disc_label) | (mask == cup_label)).any(axis=1))
    if disc_rows.size == 0:
        return np.nan

    cup_rows = np.flatnonzero((mask == cup_label).any(axis=1))
    if cup_rows.size == 0:
        return 0.0

    disc_height = disc_rows[-1] - disc_rows[0] + 1
    cup_height = cup_rows[-1] - cup_rows[0] + 1
    return float(cup_height / disc_height)
|
|
|
def add_mask(image, mask, classes, colors, alpha=0.5):
    """Paint labelled regions onto a copy of ``image`` and alpha-blend it.

    Args:
        image: image array; it is not modified.
        mask: label array whose shape matches the first dimensions of
            ``image`` (it is used as a boolean index into it).
        classes: label values to paint, paired element-wise with ``colors``.
        colors: one color per entry of ``classes``.
        alpha: weight of the painted copy in the blend; ``image`` gets
            ``1 - alpha``.

    Returns:
        ``(blended, painted)``: the ``cv2.addWeighted`` blend and the fully
        opaque painted copy.
    """
    painted = image.copy()
    for label, colour in zip(classes, colors):
        selection = mask == label
        painted[selection] = colour

    beta = 1 - alpha
    blended = cv2.addWeighted(painted, alpha, image, beta, 0)
    return blended, painted
|
|
|
def get_confidence_level(confidence):
    """Translate a 0-100 confidence percentage into a descriptive quality label.

    Each band is inclusive at its lower bound. Anything below 45 — including
    NaN, which fails every comparison — is reported as not reliable.
    """
    bands = (
        (90, "Excellent (Very clear boundaries)"),
        (75, "Good (Clear boundaries)"),
        (60, "Fair (Visible but some unclear areas)"),
        (45, "Poor (Difficult to determine)"),
    )
    for lower_bound, description in bands:
        if confidence >= lower_bound:
            return description
    return "Very Poor (Not reliable)"
|
|
|
def process_batch(model, images_data, progress_bar=None):
    """Run ``model.process`` on each image and collect per-image result dicts.

    Args:
        model: object providing ``process(image)`` (returning the 7-tuple of
            ``GlaucomaModel.process``) and a ``cls_id2label`` mapping.
        images_data: sized iterable of ``(file_name, image)`` pairs.
        progress_bar: optional object with a ``progress(fraction)`` method
            (e.g. what ``st.progress`` returns).

    Returns:
        List of result dicts for the images processed successfully. Failures
        are reported with ``st.error`` and skipped, so the list can be shorter
        than ``images_data``.
    """
    results = []
    total = len(images_data)
    for done, (file_name, image) in enumerate(images_data, start=1):
        try:
            (disease_idx, disc_cup_image, vcdr, cls_conf,
             cup_conf, disc_conf, cropped_image) = model.process(image)
            results.append({
                'file_name': file_name,
                'diagnosis': model.cls_id2label[disease_idx],
                'confidence': cls_conf,
                'vcdr': vcdr,
                'cup_conf': cup_conf,
                'disc_conf': disc_conf,
                'processed_image': disc_cup_image,
                'cropped_image': cropped_image,
            })
        except Exception as e:
            st.error(f"Error processing {file_name}: {str(e)}")
        # Advance for failed images too, so the bar still reaches 100%.
        if progress_bar is not None:
            progress_bar.progress(done / total)
    return results
|
|
|
def _encode_jpeg(array):
    """Return ``array`` encoded as JPEG bytes using Pillow's defaults."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format="JPEG")
    return buffer.getvalue()


def _unique_stem(file_name, used):
    """Return a stem for ``file_name`` not yet in ``used``, and record it.

    Directory components are dropped so archive entries stay flat. A numeric
    suffix is appended when two uploads share a stem (e.g. ``a.png`` and
    ``a.jpg``), so neither overwrites the other.
    """
    stem = os.path.splitext(os.path.basename(file_name))[0] or "image"
    candidate = stem
    counter = 1
    while candidate in used:
        candidate = f"{stem}_{counter}"
        counter += 1
    used.add(candidate)
    return candidate


def save_results(results, original_images):
    """Bundle a CSV report and per-image JPEGs into a ZIP archive in memory.

    Args:
        results: result dicts with the keys produced by ``process_batch``.
        original_images: original image arrays, paired positionally with
            ``results`` (``zip`` stops at the shorter). NOTE: ``process_batch``
            drops failed images, so callers must filter ``original_images`` the
            same way or images and results will be mismatched.

    Returns:
        Bytes of a ZIP archive containing ``report.csv`` and, for each result,
        ``<stem>_original.jpg``, ``<stem>_segmentation.jpg`` and
        ``<stem>_roi.jpg``.
    """
    columns = ['File', 'Diagnosis', 'Confidence (%)', 'VCDR',
               'Cup Confidence (%)', 'Disc Confidence (%)']
    # Explicit columns keep the header even when there are no results.
    report = pd.DataFrame(
        [{
            'File': r['file_name'],
            'Diagnosis': r['diagnosis'],
            'Confidence (%)': f"{r['confidence']:.1f}",
            'VCDR': f"{r['vcdr']:.3f}",
            'Cup Confidence (%)': f"{r['cup_conf']:.1f}",
            'Disc Confidence (%)': f"{r['disc_conf']:.1f}",
        } for r in results],
        columns=columns,
    )

    archive = io.BytesIO()
    with zipfile.ZipFile(archive, 'w') as zipf:
        zipf.writestr('report.csv', report.to_csv(index=False))
        used_stems = set()
        for result, orig_img in zip(results, original_images):
            stem = _unique_stem(result['file_name'], used_stems)
            zipf.writestr(f"{stem}_original.jpg", _encode_jpeg(orig_img))
            zipf.writestr(f"{stem}_segmentation.jpg", _encode_jpeg(result['processed_image']))
            zipf.writestr(f"{stem}_roi.jpg", _encode_jpeg(result['cropped_image']))
    # ZipFile.close() (end of the with) writes the central directory first.
    return archive.getvalue()
|
|
|
|
|
def _load_glaucoma_model():
    """Create a GlaucomaModel on ``cuda:0`` when CUDA is available, else on CPU."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return GlaucomaModel(device=device)


def _render_file_result(model, file):
    """Process one uploaded file and write its report to the Streamlit page."""
    st.write(f"Processing: {file.name}")
    image_np = np.array(Image.open(file).convert('RGB'))

    (disease_idx, disc_cup_image, vcdr, cls_conf,
     cup_conf, disc_conf, cropped_image) = model.process(image_np)

    st.write("---")
    st.write(f"Results for {file.name}")

    st.write("📊 **Diagnosis Results:**")
    st.write(f"• Finding: {model.cls_id2label[disease_idx]}")
    st.write(f"• AI Confidence: {cls_conf:.1f}% ({get_confidence_level(cls_conf)})")

    st.write("\n🔍 **Segmentation Quality:**")
    st.write(f"• Optic Cup Detection: {cup_conf:.1f}% - {get_confidence_level(cup_conf)}")
    st.write(f"• Optic Disc Detection: {disc_conf:.1f}% - {get_confidence_level(disc_conf)}")

    st.write("\n📏 **Clinical Measurements:**")
    st.write(f"• Cup-to-Disc Ratio (VCDR): {vcdr:.3f}")
    # NaN fails every comparison, so it must be checked first; otherwise a
    # missing measurement would be reported as the "normal" range.
    if np.isnan(vcdr):
        st.write(" ⚠️ VCDR could not be measured - check the segmentation")
    elif vcdr > 0.7:
        st.write(" ⚠️ High VCDR - Potential risk indicator")
    elif vcdr > 0.5:
        st.write(" ℹ️ Borderline VCDR - Follow-up recommended")
    else:
        st.write(" ✅ Normal VCDR range")

    st.write("\n🖼️ **Visual Analysis:**")
    st.image(disc_cup_image, caption=(
        "Segmentation Overlay\n"
        "• Green area: Optic Disc (outside the cup)\n"
        "• Red area: Optic Cup region\n"
        "• Transparency shows underlying retina"
    ))
    st.image(cropped_image, caption="Zoomed Region of Interest")

    if cup_conf < 60 or disc_conf < 60:
        st.write("\n⚠️ Note: Low segmentation confidence. Image quality might affect measurements.")


def main():
    """Render the Streamlit glaucoma-screening app.

    The sidebar accepts fundus images. Each uploaded image is run through
    GlaucomaModel, and its diagnosis, segmentation confidences, VCDR and
    overlay images are displayed. An error on one file is shown and that file
    is skipped. An error while loading the model is shown once.
    """
    st.set_page_config(layout="wide")

    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write("Upload retinal images for automated glaucoma detection and optic disc/cup segmentation")

    st.sidebar.title("Upload Images")
    uploaded_files = st.sidebar.file_uploader(
        "Upload retinal images",
        type=['png', 'jpeg', 'jpg'],
        accept_multiple_files=True,
    )

    st.sidebar.markdown("""
### Understanding Results:
- Diagnosis Confidence: AI certainty level
- VCDR: Cup to disc ratio (>0.7 high risk)
- Segmentation: Accuracy of detection
""")

    if not uploaded_files:
        return

    st.write("Loading AI models...")

    try:
        # Streamlit re-runs this script on every interaction. Cache the model
        # when the running Streamlit supports it, so the checkpoints are not
        # reloaded on each rerun.
        cache_resource = getattr(st, "cache_resource", None)
        load_model = cache_resource(_load_glaucoma_model) if cache_resource else _load_glaucoma_model
        model = load_model()

        for file in uploaded_files:
            try:
                _render_file_result(model, file)
            except Exception as e:
                st.error(f"Error processing {file.name}: {str(e)}")
                continue

        st.write("---")
        st.write("Processing complete!")

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
|
|
|
# Script entry point: render the app when this file is executed as __main__
# (`streamlit run` executes the script that way).
if __name__ == "__main__":
    main()
|
|