|
import cv2 |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation |
|
import streamlit as st |
|
from PIL import Image |
|
import io |
|
import zipfile |
|
import pandas as pd |
|
from datetime import datetime |
|
import os |
|
import tempfile |
|
import base64 |
|
|
|
|
|
# Segmentation checkpoints offered in the sidebar. Keys are the labels shown
# to the user; values are the model identifiers handed to ``from_pretrained``.
MODEL_OPTIONS = {
    "Default (ferferefer/segformer)": "ferferefer/segformer",
    "Pamixsun": "pamixsun/segformer_for_optic_disc_cup_segmentation",
}
|
|
|
|
|
class GlaucomaModel(object):
    """Bundle a glaucoma classifier with an optic disc/cup segmenter.

    The classifier is a ``Swinv2ForImageClassification`` checkpoint and the
    segmenter a ``SegformerForSemanticSegmentation`` checkpoint; both are
    loaded with ``from_pretrained`` in ``__init__`` and put in eval mode on
    ``device``.

    Attributes:
        device: torch device both models were moved to.
        seg_model_path: identifier of the segmentation checkpoint in use.
        is_ferferefer: True when ``seg_model_path`` contains "ferferefer".
            Retained for callers that inspect it; decoding is the same for
            every checkpoint.
        cls_id2label: ``id2label`` mapping taken from the classifier config.
    """

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path=None,
                 device=torch.device('cpu')):
        """Load both checkpoints.

        Args:
            cls_model_path: identifier of the classification checkpoint.
            seg_model_path: identifier of the segmentation checkpoint; a
                falsy value selects ``MODEL_OPTIONS["Pamixsun"]``.
            device: torch device to run both models on.
        """
        self.device = device
        self.seg_model_path = seg_model_path or MODEL_OPTIONS["Pamixsun"]

        # Classification model and its preprocessor.
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = (
            Swinv2ForImageClassification.from_pretrained(cls_model_path)
            .to(device)
            .eval()
        )

        # Segmentation model and its preprocessor.
        self.seg_extractor = AutoImageProcessor.from_pretrained(self.seg_model_path)
        self.seg_model = (
            SegformerForSemanticSegmentation.from_pretrained(self.seg_model_path)
            .to(device)
            .eval()
        )

        self.is_ferferefer = "ferferefer" in self.seg_model_path.lower()

        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify one image.

        Args:
            image: array-like image accepted by the classifier's processor
                (``main`` passes an RGB ``numpy`` array).

        Returns:
            ``(disease_idx, confidence)``: the arg-max class index as an
            ``int`` and its softmax probability scaled to 0-100.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            logits = self.cls_model(**inputs).logits
            # First (only) item of the batch, moved to CPU once.
            probs = F.softmax(logits, dim=-1)[0].cpu()
        # Plain int so it can be used directly as an ``id2label`` key.
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item() * 100
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment one image into background / disc / cup labels.

        Args:
            image: image array; ``image.shape[:2]`` is used as the output
                height and width.

        Returns:
            ``(mask, cup_confidence, disc_confidence)`` where ``mask`` is a
            ``uint8`` array of per-pixel arg-max labels at the input
            resolution (1 is treated as disc, 2 as cup), and each confidence
            is the mean softmax probability (0-100) of that class over the
            pixels assigned to it, or 0 when no pixel got that label.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            logits = self.seg_model(**inputs).logits.cpu()

        # Bring the logits back to the input resolution before decoding.
        upsampled_logits = F.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]

        cup_mask = pred_disc_cup == 2
        disc_mask = pred_disc_cup == 1

        cup_confidence = (
            seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
        )
        disc_confidence = (
            seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0
        )

        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        """Run classification, segmentation, VCDR and ROI extraction.

        Args:
            image: image array indexed as ``[row, column, ...]``.

        Returns:
            ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image)``. ``vcdr`` is
            ``nan`` when it cannot be computed. ``cropped_image`` is a padded
            crop around the segmented region, or a copy of the whole image
            when nothing was segmented or the crop would be under 50 px on a
            side.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)

        # Best effort: a failed ratio must not abort the whole analysis.
        try:
            vcdr = simple_vcdr(disc_cup)
        except Exception:
            vcdr = np.nan

        mask = (disc_cup > 0).astype(np.uint8)
        if mask.any():
            x, y, w, h = cv2.boundingRect(mask)
            # Pad by 20% of the larger side, but never less than 50 px.
            padding = max(50, int(0.2 * max(w, h)))
            x = max(x - padding, 0)
            y = max(y - padding, 0)
            w = min(w + 2 * padding, image.shape[1] - x)
            h = min(h + 2 * padding, image.shape[0] - y)
            roi_ok = w >= 50 and h >= 50
        else:
            # Nothing segmented: the bounding box is empty, and padding it
            # would just crop the top-left corner of the image.
            roi_ok = False

        cropped_image = image[y:y + h, x:x + w] if roi_ok else image.copy()

        # NOTE(review): the second return value of add_mask is kept, i.e. the
        # fully painted overlay rather than the alpha-blended image — confirm
        # that this is the intended rendering.
        _, disc_cup_image = add_mask(
            image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2
        )

        return (disease_idx, disc_cup_image, vcdr, cls_confidence,
                cup_confidence, disc_confidence, cropped_image)
|
|
|
|
|
def simple_vcdr(mask):
    """Estimate the vertical cup-to-disc ratio (VCDR) from a label mask.

    The ratio is the vertical extent (in rows) of the cup divided by the
    vertical extent of the whole disc. Because every pixel carries a single
    label, pixels labelled 2 (cup) are not labelled 1, so the disc extent is
    measured over the union of both labels.

    Args:
        mask: 2-D integer array where 1 marks optic-disc pixels and 2 marks
            optic-cup pixels; any other value is ignored.

    Returns:
        The ratio as a float, ``0.0`` when there is a disc but no cup, or
        ``nan`` when the mask contains no disc-labelled pixels.
    """
    mask = np.asarray(mask)
    disc_only = mask == 1
    cup = mask == 2

    # Without any disc pixels there is no denominator to speak of.
    if not disc_only.any():
        return np.nan

    cup_rows = np.flatnonzero(cup.any(axis=1))
    if cup_rows.size == 0:
        return 0.0

    disc_rows = np.flatnonzero((disc_only | cup).any(axis=1))
    cup_height = int(cup_rows[-1] - cup_rows[0]) + 1
    disc_height = int(disc_rows[-1] - disc_rows[0]) + 1
    return cup_height / disc_height
|
|
|
def add_mask(image, mask, classes, colors, alpha=0.5):
    """Paint selected mask classes onto an image.

    Args:
        image: image array; it is copied, never modified.
        mask: label array used to select pixels via ``mask == class_id``.
        classes: class ids to paint.
        colors: one colour per entry of ``classes``, paired in order.
        alpha: weight of the painted copy in the blended result; the
            untouched image gets ``1 - alpha``.

    Returns:
        ``(blended, painted)``: the weighted mix of the painted copy with the
        original image, and the painted copy itself.
    """
    painted = image.copy()
    for label, rgb in zip(classes, colors):
        selected = mask == label
        painted[selected] = rgb
    image_weight = 1 - alpha
    blended = cv2.addWeighted(painted, alpha, image, image_weight, 0)
    return blended, painted
|
|
|
def get_confidence_level(confidence):
    """Translate a segmentation confidence score into a descriptive label.

    Args:
        confidence: score compared against the thresholds 90, 75, 60 and 45.

    Returns:
        The description for the highest threshold that ``confidence``
        reaches, or the "Very Poor" description when it reaches none.
    """
    tiers = (
        (90, "Excellent (Model is very certain about the detected boundaries)"),
        (75, "Good (Model is confident about most of the detected area)"),
        (60, "Fair (Model has some uncertainty in parts of the detection)"),
        (45, "Poor (Model is uncertain about many detected areas)"),
    )
    # Tiers are ordered from highest to lowest, so the first hit wins.
    for floor, description in tiers:
        if confidence >= floor:
            return description
    return "Very Poor (Model's detection is highly uncertain)"
|
|
|
def process_batch(model, images_data, progress_bar=None):
    """Run ``model.process`` over a batch of images and collect the results.

    An image that raises is reported with ``st.error`` and skipped; the
    remaining images are still processed.

    Args:
        model: object exposing ``process(image)`` (returning the 7-tuple
            unpacked below) and a ``cls_id2label`` mapping.
        images_data: sized sequence of ``(file_name, image)`` pairs.
        progress_bar: optional object with a ``progress(fraction)`` method,
            called after every image with the fraction handled so far.

    Returns:
        A list with one dict per successfully processed image, holding the
        keys ``file_name``, ``diagnosis``, ``confidence``, ``vcdr``,
        ``cup_conf``, ``disc_conf``, ``processed_image`` and
        ``cropped_image``.
    """
    results = []
    total = len(images_data)
    for position, (file_name, image) in enumerate(images_data, start=1):
        try:
            (disease_idx, disc_cup_image, vcdr, cls_conf,
             cup_conf, disc_conf, cropped_image) = model.process(image)
            results.append({
                'file_name': file_name,
                'diagnosis': model.cls_id2label[disease_idx],
                'confidence': cls_conf,
                'vcdr': vcdr,
                'cup_conf': cup_conf,
                'disc_conf': disc_conf,
                'processed_image': disc_cup_image,
                'cropped_image': cropped_image,
            })
        except Exception as e:
            st.error(f"Error processing {file_name}: {str(e)}")
        finally:
            # Advance on failures too, so the bar always reaches 100%.
            if progress_bar is not None:
                progress_bar.progress(position / total)
    return results
|
|
|
def _unique_base_name(file_name, used_names):
    """Return a flat, not-yet-used stem for ``file_name`` and record it."""
    # basename() drops directory components so archive entries stay flat.
    stem = os.path.splitext(os.path.basename(file_name))[0] or "image"
    candidate = stem
    suffix = 1
    while candidate in used_names:
        suffix += 1
        candidate = f"{stem}_{suffix}"
    used_names.add(candidate)
    return candidate


def _encode_jpeg(pixels):
    """Encode an image array as JPEG and return the bytes."""
    buffer = io.BytesIO()
    Image.fromarray(pixels).save(buffer, format="JPEG")
    return buffer.getvalue()


def save_results(results, original_images):
    """Pack a CSV report and the result images into a ZIP archive.

    The archive is assembled in memory. It holds ``report.csv`` plus, for
    each result, ``<name>_original.jpg``, ``<name>_segmentation.jpg`` and
    ``<name>_roi.jpg``, where ``<name>`` is the upload's file name without
    directory or extension. Uploads that share a stem get a numeric suffix
    so that none of them overwrites another.

    Args:
        results: list of dicts with the keys ``file_name``, ``diagnosis``,
            ``confidence``, ``vcdr``, ``cup_conf``, ``disc_conf``,
            ``processed_image`` and ``cropped_image``.
        original_images: image arrays paired with ``results`` by position.

    Returns:
        The bytes of the ZIP archive.
    """
    report = pd.DataFrame([{
        'File': r['file_name'],
        'Diagnosis': r['diagnosis'],
        'Confidence (%)': f"{r['confidence']:.1f}",
        'VCDR': f"{r['vcdr']:.3f}",
        'Cup Confidence (%)': f"{r['cup_conf']:.1f}",
        'Disc Confidence (%)': f"{r['disc_conf']:.1f}",
    } for r in results])

    archive = io.BytesIO()
    used_names = set()
    with zipfile.ZipFile(archive, 'w') as zipf:
        zipf.writestr('report.csv', report.to_csv(index=False))

        for result, orig_img in zip(results, original_images):
            base_name = _unique_base_name(result['file_name'], used_names)
            variants = (
                ("original", orig_img),
                ("segmentation", result['processed_image']),
                ("roi", result['cropped_image']),
            )
            for label, pixels in variants:
                zipf.writestr(f"{base_name}_{label}.jpg", _encode_jpeg(pixels))

    return archive.getvalue()
|
|
|
|
|
def _render_result(result):
    """Write one image's findings and pictures to the Streamlit page.

    Args:
        result: dict with the keys ``file_name``, ``diagnosis``,
            ``confidence``, ``vcdr``, ``cup_conf``, ``disc_conf``,
            ``processed_image`` and ``cropped_image``.
    """
    st.write("---")
    st.success(f"Results for: {result['file_name']}")

    st.info("🔍 Key Findings")

    diagnosis = result['diagnosis']
    cls_conf = result['confidence']
    if diagnosis == "Glaucoma":
        st.warning(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")
    else:
        st.success(f"Diagnosis: {diagnosis} ({cls_conf:.1f}% confidence)")

    vcdr = result['vcdr']
    if np.isnan(vcdr):
        # NaN fails both comparisons below and would be shown as "Normal".
        st.warning("VCDR: not available (optic disc not detected)")
    elif vcdr > 0.7:
        st.warning(f"VCDR: {vcdr:.3f} - ⚠️ High Risk")
    elif vcdr > 0.5:
        st.warning(f"VCDR: {vcdr:.3f} - ⚠️ Borderline")
    else:
        st.success(f"VCDR: {vcdr:.3f} - ✅ Normal")

    st.info("📊 Segmentation Confidence")
    st.write("""
    • Optic Cup (red area): Central depression
    • Optic Disc (green outline): Entire nerve area
    """)

    cup_conf = result['cup_conf']
    if cup_conf < 60:
        st.warning(f"Cup Detection: {cup_conf:.1f}% - Low Confidence")
    else:
        st.write(f"Cup Detection: {cup_conf:.1f}%")

    disc_conf = result['disc_conf']
    if disc_conf < 60:
        st.warning(f"Disc Detection: {disc_conf:.1f}% - Low Confidence")
    else:
        st.write(f"Disc Detection: {disc_conf:.1f}%")

    st.info("🖼️ Analysis Images")
    st.image(result['processed_image'], caption="Green: Optic Disc | Red: Optic Cup")
    st.image(result['cropped_image'], caption="Region of Interest")


def main():
    """Streamlit entry point: upload images, analyse them, offer a ZIP.

    Each uploaded image is analysed once; the same results feed both the
    on-page display and the downloadable archive. A failure on one image is
    reported and does not stop the others.
    """
    st.set_page_config(layout="wide")

    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write("Upload retinal images for automated glaucoma detection and optic disc/cup segmentation")

    st.sidebar.title("Model Settings")
    selected_model = st.sidebar.selectbox(
        "Select Segmentation Model",
        list(MODEL_OPTIONS.keys()),
        index=1,
    )

    st.sidebar.title("Upload Images")
    # NOTE(review): newer Streamlit releases appear to reject this option key
    # as unrecognized — confirm against the pinned version. The call is kept
    # for older releases and must not take the app down where it is gone.
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except Exception:
        pass
    uploaded_files = st.sidebar.file_uploader(
        "Upload retinal images",
        type=['png', 'jpeg', 'jpg'],
        accept_multiple_files=True,
    )

    st.sidebar.markdown("""
    ### Understanding Results:
    - Diagnosis Confidence: AI certainty level
    - VCDR: Cup to disc ratio (>0.7 high risk)
    - Segmentation: Accuracy of detection
    """)

    if not uploaded_files:
        return

    try:
        st.info("🤖 Loading AI Models")
        st.write("Classification: pamixsun/swinv2_tiny_for_glaucoma_classification")
        st.write(f"Segmentation: {selected_model}")

        use_gpu = torch.cuda.is_available()
        model = GlaucomaModel(
            device=torch.device("cuda:0" if use_gpu else "cpu"),
            seg_model_path=MODEL_OPTIONS[selected_model],
        )

        st.success(f"✅ Models loaded successfully - Using {'GPU' if use_gpu else 'CPU'}")
        st.write("---")

        # Collected while displaying, so the download needs no second pass.
        results = []
        original_images = []

        for uploaded in uploaded_files:
            try:
                st.info(f"📸 Processing: {uploaded.name}")
                image_np = np.array(Image.open(uploaded).convert('RGB'))

                (disease_idx, disc_cup_image, vcdr, cls_conf,
                 cup_conf, disc_conf, cropped_image) = model.process(image_np)

                result = {
                    'file_name': uploaded.name,
                    'diagnosis': model.cls_id2label[disease_idx],
                    'confidence': cls_conf,
                    'vcdr': vcdr,
                    'cup_conf': cup_conf,
                    'disc_conf': disc_conf,
                    'processed_image': disc_cup_image,
                    'cropped_image': cropped_image,
                }
                # Record before rendering: a display hiccup should not drop
                # an otherwise valid result from the archive.
                results.append(result)
                original_images.append(image_np)

                _render_result(result)
            except Exception as e:
                st.error(f"Error processing {uploaded.name}: {str(e)}")
                continue

        if results:
            try:
                st.info("📥 Preparing Download")
                zip_data = save_results(results, original_images)
                b64_zip = base64.b64encode(zip_data).decode()

                st.success("✅ Download Ready")
                href = f'<a href="data:application/zip;base64,{b64_zip}" download="glaucoma_results.zip">📥 Download All Results (ZIP)</a>'
                st.markdown(href, unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Error preparing download: {str(e)}")

        st.success("✅ All Processing Complete!")

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
|
|