# app.py — "Update app.py" by luigi12345 (revision 7a986e7, 5.58 kB).
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma screening on a retinal fundus image.

    Wraps two Hugging Face models: a SwinV2 image classifier that predicts the
    glaucoma category, and a SegFormer model that segments the optic disc and
    optic cup. ``process`` runs both and derives the vertical cup-to-disc
    ratio plus a colour overlay of the segmentation.
    """

    # Segmentation label ids, matching the probability channels read in
    # ``optic_disc_cup_pred`` (1 = optic disc, 2 = optic cup; 0 is treated as
    # background).
    DISC_LABEL = 1
    CUP_LABEL = 2

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models and their image processors onto *device*.

        Args:
            cls_model_path: Hub id or local path of the classification model.
            seg_model_path: Hub id or local path of the segmentation model.
            device: torch device the models are moved to (in eval mode).
        """
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Index -> human-readable label maps from the model configs
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify *image*.

        Returns:
            ``(disease_idx, confidence)``: the predicted class index as an
            ``int`` (a key of ``cls_id2label``) and its softmax probability.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        probs = F.softmax(logits, dim=-1)[0].cpu()
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item()
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in *image* (an H x W [x C] array).

        Returns:
            ``(mask, cup_confidence, disc_confidence)``. ``mask`` is an H x W
            uint8 array of label ids. Each confidence is the mean softmax
            probability of that label over the pixels predicted as that label,
            or 0.0 when no pixel receives it. Averaging over the whole image
            would mostly measure how small the structure is.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # Resize the logits to the input image's height/width before taking labels.
        upsampled_logits = F.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        cup_confidence = self._region_confidence(seg_probs, pred_disc_cup, self.CUP_LABEL)
        disc_confidence = self._region_confidence(seg_probs, pred_disc_cup, self.DISC_LABEL)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        """Run classification and segmentation on *image*.

        Returns:
            ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence)``. ``disc_cup_image`` is *image*
            with the disc tinted green and the cup tinted red. ``vcdr`` is NaN
            when no disc is segmented.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        vcdr = self._vertical_cdr(disc_cup)
        disc_cup_image = self._overlay_mask(
            image, disc_cup,
            [self.DISC_LABEL, self.CUP_LABEL],
            [[0, 255, 0], [255, 0, 0]],
            0.2,
        )
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence

    @staticmethod
    def _region_confidence(probs, pred, label):
        """Mean of ``probs[label]`` over pixels where ``pred == label``; 0.0 if none.

        *probs* is a (classes, H, W) tensor and *pred* an (H, W) label tensor.
        """
        region = pred == label
        if not bool(region.any()):
            return 0.0
        return probs[label][region].mean().item()

    @classmethod
    def _vertical_cdr(cls, mask):
        """Vertical cup-to-disc ratio from an H x W label mask.

        The disc is every pixel labelled disc or cup; the cup is the cup label.
        Each height is the row span (top-most to bottom-most row) of its
        region. Returns NaN when the mask contains no disc and 0.0 when it
        contains a disc but no cup.
        """
        mask = np.asarray(mask)
        disc_rows = np.flatnonzero(np.isin(mask, (cls.DISC_LABEL, cls.CUP_LABEL)).any(axis=1))
        if disc_rows.size == 0:
            return np.nan
        cup_rows = np.flatnonzero((mask == cls.CUP_LABEL).any(axis=1))
        disc_height = disc_rows[-1] - disc_rows[0] + 1
        cup_height = cup_rows[-1] - cup_rows[0] + 1 if cup_rows.size else 0
        return float(cup_height / disc_height)

    @staticmethod
    def _overlay_mask(image, mask, labels, colors, alpha=0.5):
        """Alpha-blend a solid colour over each labelled region of *image*.

        Args:
            image: H x W x 3 uint8 array; it is not modified.
            mask: H x W label array.
            labels: label ids to tint, paired element-wise with *colors*.
            colors: RGB triples, one per label.
            alpha: weight of the colour in the blend (0 keeps the image).

        Returns:
            A new uint8 array with the same shape as *image*.
        """
        blended = np.asarray(image).astype(np.float32)
        mask = np.asarray(mask)
        for label, color in zip(labels, colors):
            region = mask == label
            blended[region] = (1.0 - alpha) * blended[region] + alpha * np.asarray(color, dtype=np.float32)
        return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
# --- Streamlit Interface ---
def main():
    """Render the Streamlit glaucoma-screening interface.

    The sidebar takes an uploaded fundus image. When "Analyze image" is
    pressed, the app runs ``GlaucomaModel.process`` on it and shows the
    segmentation overlay and a results table.
    """
    # Wide mode in Streamlit (must be the first Streamlit call).
    st.set_page_config(layout="wide")
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')

    def show_image(container, img):
        """Draw *img* without axes into *container*, then release the figure."""
        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.axis('off')
        container.pyplot(fig)
        # Close the figure so reruns don't accumulate open matplotlib figures.
        plt.close(fig)

    # Set columns for the interface (st.columns replaces the removed st.beta_columns).
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Optic disc and optic cup")
    cols[2].subheader("Classification Map")

    # File uploader. The old 'deprecation.showfileUploaderEncoding' option no
    # longer exists in Streamlit and setting it raises, so it is not set.
    st.sidebar.title("Image selection")
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    image = None
    if uploaded_file is not None:
        # Read and display the uploaded image as an RGB uint8 array.
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        show_image(cols[0], image)

    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    with st.spinner('Loading model...'):
        # Keep one model per browser session so reruns don't reload the weights.
        model = st.session_state.get("glaucoma_model")
        if model is None:
            run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            model = GlaucomaModel(device=run_device)
            st.session_state["glaucoma_model"] = model

    with st.spinner('Analyzing...'):
        (disease_idx, disc_cup_image, vcdr,
         cls_confidence, cup_confidence, disc_confidence) = model.process(image)

    # Display optic disc and cup overlay
    show_image(cols[1], disc_cup_image)
    # NOTE(review): no class activation map is computed anywhere; this column
    # shows the input image as a placeholder.
    show_image(cols[2], image)

    # Display results with confidence
    st.subheader("Screening results:")
    # Build the table line by line so source indentation cannot leak into the
    # markdown (4+ leading spaces would render it as a code block).
    table_rows = [
        "|Parameters|Outcomes|",
        "|---|---|",
        f"|Vertical cup-to-disc ratio|{vcdr:.04f}|",
        f"|Category|{model.cls_id2label[disease_idx]} ({cls_confidence*100:.02f}% confidence)|",
        f"|Optic Cup Segmentation Confidence|{cup_confidence*100:.02f}%|",
        f"|Optic Disc Segmentation Confidence|{disc_confidence*100:.02f}%|",
    ]
    st.markdown("\n".join(table_rows))
# Launch the app only when executed as a script (e.g. `streamlit run app.py`),
# not when this module is imported.
if __name__ == "__main__":
    main()