File size: 5,584 Bytes
7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 dfbd972 f2c28c8 dfbd972 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 dfbd972 f2c28c8 7a986e7 dfbd972 7a986e7 dfbd972 a7e5224 dfbd972 a7e5224 7a986e7 dfbd972 f2c28c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
# --- GlaucomaModel Class ---
class GlaucomaModel:
    """Glaucoma classifier plus optic disc / optic cup segmenter for fundus images.

    Wraps two Hugging Face checkpoints: a Swin-V2 image classifier
    (``cls_model_path``) and a SegFormer semantic-segmentation model
    (``seg_model_path``). Both are moved to ``device`` and put in eval mode.
    """

    # Segmentation label ids used by this class: 1 = optic disc, 2 = optic cup.
    # NOTE(review): these ids follow how this file indexes the SegFormer output;
    # confirm against ``self.seg_id2label`` for the checkpoint in use.
    _DISC_LABEL = 1
    _CUP_LABEL = 2
    # RGB tint per overlaid label (same order as the labels): disc green, cup red.
    _OVERLAY_COLORS = ((0, 255, 0), (255, 0, 0))
    # Opacity of the colour tint blended over segmented pixels.
    _OVERLAY_ALPHA = 0.2

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models and their image processors.

        Args:
            cls_model_path: Hub id or local path of the classification checkpoint.
            seg_model_path: Hub id or local path of the segmentation checkpoint.
            device: torch device the models (and their inputs) are placed on.
        """
        self.device = device
        # Classification model for glaucoma.
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup.
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Label-id -> label-name mappings taken from each model's config.
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify *image*.

        Args:
            image: RGB image as a numpy array (the form ``main`` passes in).

        Returns:
            ``(disease_idx, confidence)``: the arg-max class index (a key of
            ``cls_id2label``) and its softmax probability.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        # Probabilities of the single image in the batch.
        probs = F.softmax(logits, dim=-1)[0].cpu()
        disease_idx = int(probs.argmax())
        return disease_idx, probs[disease_idx].item()

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and optic cup in *image*.

        Args:
            image: RGB image as a numpy array; ``image.shape[:2]`` sets the output size.

        Returns:
            ``(label_mask, cup_confidence, disc_confidence)``. ``label_mask`` is a
            uint8 array of shape ``image.shape[:2]`` holding the arg-max label per
            pixel. Each confidence is the mean softmax probability of that class
            over the pixels predicted as that class (NaN if no pixel is).
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # Resize logits to the input image's height/width before labelling pixels.
        upsampled_logits = F.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        # Averaging a class probability over the whole image would mostly measure
        # how much of the image is background, so only that class's pixels count.
        cup_confidence = self._class_confidence(seg_probs, pred_disc_cup, self._CUP_LABEL)
        disc_confidence = self._class_confidence(seg_probs, pred_disc_cup, self._DISC_LABEL)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        """Run classification and disc/cup segmentation on one RGB image.

        Returns:
            ``(disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence,
            disc_confidence)``. ``disc_cup_image`` is a copy of ``image`` with the
            disc tinted green and the cup tinted red. ``vcdr`` is NaN when no optic
            disc is segmented.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        vcdr = self._simple_vcdr(disc_cup)
        disc_cup_image = self._overlay_mask(
            image, disc_cup, (self._DISC_LABEL, self._CUP_LABEL),
            self._OVERLAY_COLORS, self._OVERLAY_ALPHA,
        )
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence

    @staticmethod
    def _class_confidence(probs, labels, class_idx):
        """Mean probability of ``class_idx`` over pixels labelled ``class_idx``; NaN if none.

        ``probs`` is a (C, H, W) tensor of per-class probabilities and ``labels``
        the matching (H, W) arg-max label tensor.
        """
        region = labels == class_idx
        if not bool(region.any()):
            return float('nan')
        return probs[class_idx][region].mean().item()

    @staticmethod
    def _vertical_extent(region):
        """Rows spanned (first to last, inclusive) by a 2-D boolean mask; 0 if empty."""
        rows = np.flatnonzero(np.asarray(region).any(axis=1))
        return int(rows[-1] - rows[0] + 1) if rows.size else 0

    @classmethod
    def _simple_vcdr(cls, mask):
        """Vertical cup-to-disc ratio of a label mask, or NaN when no disc is present.

        After the arg-max, cup pixels carry only the cup label. The disc's vertical
        extent is therefore measured over disc and cup labels together.
        """
        disc_height = cls._vertical_extent(np.isin(mask, (cls._DISC_LABEL, cls._CUP_LABEL)))
        if disc_height == 0:
            return float('nan')
        return cls._vertical_extent(mask == cls._CUP_LABEL) / disc_height

    @staticmethod
    def _overlay_mask(image, mask, labels, colors, alpha):
        """Return a uint8 copy of *image* with each ``labels[i]`` region tinted ``colors[i]``.

        Each pixel in a tinted region becomes ``alpha * color + (1 - alpha) * pixel``.
        All other pixels, and *image* itself, are left unchanged.
        """
        blended = image.astype(np.float32)  # astype returns a copy
        for label, color in zip(labels, colors):
            region = mask == label
            tint = np.asarray(color, dtype=np.float32)
            blended[region] = alpha * tint + (1.0 - alpha) * blended[region]
        return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
# --- Streamlit Interface ---
def _show_image(column, image):
    """Draw *image* (axes hidden) into the Streamlit *column* using a fresh figure."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    column.pyplot(fig)
    # Streamlit re-executes the script on every interaction; close each figure so
    # pyplot does not keep an ever-growing set of open figures alive.
    plt.close(fig)


def main():
    """Build the Streamlit page: upload a fundus image, then screen it on demand."""
    # Wide layout in Streamlit.
    st.set_page_config(layout="wide")
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')

    # `st.beta_columns` was superseded by `st.columns` and later removed; only use
    # the beta name on old Streamlit releases that lack `st.columns`.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Optic disc and optic cup")
    cols[2].subheader("Classification Map")

    # File uploader. The obsolete 'deprecation.showfileUploaderEncoding' option is
    # no longer set: newer Streamlit releases do not define it.
    st.sidebar.title("Image selection")
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
    image = None
    if uploaded_file is not None:
        # Decode as an RGB uint8 array; the same array feeds both models and the plots.
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        _show_image(cols[0], image)

    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    with st.spinner('Loading model...'):
        # Load the model on the available device.
        # NOTE(review): both checkpoints are reloaded on every click; consider caching.
        run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = GlaucomaModel(device=run_device)

    with st.spinner('Analyzing...'):
        disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence = model.process(image)

    # Optic disc / cup overlay.
    _show_image(cols[1], disc_cup_image)
    # NOTE(review): no class activation map is computed in this file, so the
    # "Classification Map" column shows the input image itself.
    _show_image(cols[2], image)

    # Results table with confidences. The table rows stay at column 0 inside the
    # string: 4+ leading spaces would turn them into a Markdown code block.
    st.subheader("Screening results:")
    final_results_as_table = f"""
|Parameters|Outcomes|
|---|---|
|Vertical cup-to-disc ratio|{vcdr:.04f}|
|Category|{model.cls_id2label[disease_idx]} ({cls_confidence*100:.02f}% confidence)|
|Optic Cup Segmentation Confidence|{cup_confidence*100:.02f}%|
|Optic Disc Segmentation Confidence|{disc_confidence*100:.02f}%|
"""
    st.markdown(final_results_as_table)
# Launch the Streamlit app when run as a script (e.g. via `streamlit run`).
if __name__ == '__main__':
    main()