"""Streamlit app for glaucoma screening from retinal fundus images.

Classifies an uploaded fundus image, segments the optic disc and optic cup,
and reports a cup-to-disc ratio together with model confidences.
"""

import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import io
# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma screening pipeline: classification plus disc/cup segmentation.

    Loads a Swin V2 image classifier and a SegFormer semantic-segmentation
    model with ``from_pretrained`` and exposes ``process`` to run both on one
    image.

    Attributes:
        device: torch device both models are moved to.
        cls_extractor, cls_model: image processor and classifier.
        seg_extractor, seg_model: image processor and segmentation model.
        cls_id2label, seg_id2label: id -> label mappings taken from the
            respective model configs.
    """

    # Segmentation label ids used throughout this file (1 = disc, 2 = cup).
    _DISC_LABEL = 1
    _CUP_LABEL = 2
    # Smallest crop side (pixels) accepted before falling back to the full image.
    _MIN_CROP_SIDE = 50

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models in eval mode onto ``device``.

        Args:
            cls_model_path: model id or path for the classifier.
            seg_model_path: model id or path for the segmentation model.
            device: torch device used for inference.
        """
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Human-readable label names for both models
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

    @staticmethod
    def _class_confidence(probs, labels, class_id):
        """Return the mean probability (percent) of ``class_id`` over its own pixels.

        Args:
            probs: per-class probabilities shaped (classes, H, W).
            labels: per-pixel predicted class ids shaped (H, W).
            class_id: class whose confidence is wanted.

        Returns:
            Mean probability in percent over the pixels predicted as
            ``class_id``, or 0.0 when no pixel was assigned to that class.
            Averaging over the whole image instead would mostly measure how
            much of the image the class covers, not how sure the model is.
        """
        selected = probs[class_id][labels == class_id]
        if selected.numel() == 0:
            return 0.0
        return selected.mean().item() * 100

    @staticmethod
    def _padded_bbox(mask, min_padding=50, padding_ratio=0.2):
        """Return a padded bounding box around the non-zero pixels of ``mask``.

        Args:
            mask: 2-D array; any non-zero / True entry counts as foreground.
            min_padding: minimum padding in pixels added on every side.
            padding_ratio: padding as a fraction of the larger box side.

        Returns:
            ``(x0, y0, x1, y1)`` with exclusive ``x1``/``y1``, clipped to the
            mask bounds, or ``None`` when the mask has no foreground.
        """
        rows = np.flatnonzero(mask.any(axis=1))
        cols = np.flatnonzero(mask.any(axis=0))
        if rows.size == 0:
            return None
        top, bottom = int(rows[0]), int(rows[-1]) + 1
        left, right = int(cols[0]), int(cols[-1]) + 1
        padding = max(min_padding, int(padding_ratio * max(right - left, bottom - top)))
        height, width = mask.shape[:2]
        # Clip each edge independently so clamping one side never widens the other.
        return (max(left - padding, 0), max(top - padding, 0),
                min(right + padding, width), min(bottom + padding, height))

    def glaucoma_pred(self, image):
        """Classify ``image`` and return ``(class_index, confidence_percent)``.

        Args:
            image: image array accepted by the classifier's image processor
                (``main`` passes an RGB uint8 array).

        Returns:
            Tuple of the arg-max class index (plain ``int``) and its softmax
            probability scaled to a percentage.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            logits = self.cls_model(**inputs).logits
            # Softmax for probabilities
            probs = F.softmax(logits, dim=-1)[0].cpu()
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item() * 100  # Scale to percentage
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment optic disc and cup at the resolution of ``image``.

        Args:
            image: array whose first two dimensions are (height, width).

        Returns:
            ``(label_map, cup_confidence, disc_confidence)`` where
            ``label_map`` is a uint8 array of per-pixel class ids and each
            confidence is the mean probability (percent) of that class over
            the pixels assigned to it.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            logits = self.seg_model(**inputs).logits.cpu()
            # The model's logits are resized here to the input image size.
            upsampled_logits = F.interpolate(
                logits, size=image.shape[:2], mode="bilinear", align_corners=False
            )
            # Softmax for segmentation confidence
            seg_probs = F.softmax(upsampled_logits, dim=1)[0]
            pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        cup_confidence = self._class_confidence(seg_probs, pred_disc_cup, self._CUP_LABEL)
        disc_confidence = self._class_confidence(seg_probs, pred_disc_cup, self._DISC_LABEL)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        """Run classification, segmentation, ratio computation and cropping.

        Args:
            image: RGB image array shaped (height, width, channels).

        Returns:
            ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image)``.
            ``cropped_image`` is the padded region around the segmented
            disc/cup, or a copy of the full image when nothing was segmented
            or the region is smaller than 50 pixels on a side.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        try:
            vcdr = simple_vcdr(disc_cup)  # Calculate vertical cup-to-disc ratio
        except Exception:
            # Best effort: a failed ratio must not abort the rest of the
            # screening, but interpreter-exit signals are no longer swallowed.
            vcdr = np.nan

        # Crop around the union of optic disc and cup, with dynamic padding.
        cropped_image = image.copy()
        box = self._padded_bbox(disc_cup > 0)
        if box is not None:
            x0, y0, x1, y1 = box
            if x1 - x0 >= self._MIN_CROP_SIDE and y1 - y0 >= self._MIN_CROP_SIDE:
                cropped_image = image[y0:y1, x0:x1].copy()

        # Generate disc and cup visualization. Only add_mask's second return
        # value (the opaque overlay) is used; the alpha-blended image is discarded.
        _, disc_cup_image = add_mask(
            image, disc_cup,
            [self._DISC_LABEL, self._CUP_LABEL],
            [[0, 255, 0], [255, 0, 0]],
            0.2,
        )
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
# --- Utility Functions ---
def simple_vcdr(mask):
    """Calculate the vertical cup-to-disc ratio (VCDR) from a label mask.

    The ratio is the vertical extent (number of image rows spanned) of the
    optic cup divided by the vertical extent of the optic disc. Cup pixels
    are counted as part of the disc, since each pixel carries one label and
    the cup region would otherwise be missing from the disc's extent. A plain
    ratio of pixel counts is not used: it measures area and can exceed 1.

    Args:
        mask: 2-D (rows, columns) array with class 1 for optic disc and
            class 2 for optic cup.

    Returns:
        The ratio as a float; ``nan`` when no pixel is labelled as disc, and
        0.0 when a disc is present but no cup pixel is.
    """
    mask = np.asarray(mask)
    cup = mask == 2
    if not np.any(mask == 1):  # No disc found: ratio is undefined
        return np.nan
    disc = (mask == 1) | cup
    disc_rows = np.flatnonzero(disc.any(axis=1))
    cup_rows = np.flatnonzero(cup.any(axis=1))
    if cup_rows.size == 0:
        return 0.0
    disc_height = int(disc_rows[-1] - disc_rows[0]) + 1
    cup_height = int(cup_rows[-1] - cup_rows[0]) + 1
    return cup_height / disc_height
def add_mask(image, mask, classes, colors, alpha=0.5):
    """Paint segmentation classes onto an image and blend with the original.

    Args:
        image: the original RGB image.
        mask: the predicted segmentation mask (per-pixel class ids).
        classes: class indices to paint (e.g., [1, 2]).
        colors: one color per class, paired with ``classes`` in order
            (e.g., [[0, 255, 0], [255, 0, 0]] for green and red).
        alpha: weight of the painted image in the blend (default = 0.5).

    Returns:
        ``(blended, painted)``: the alpha-blended image and the fully opaque
        painted copy. ``image`` itself is left unmodified.
    """
    painted = image.copy()
    class_colors = zip(classes, colors)
    for label, rgb in class_colors:
        selected = mask == label
        painted[selected] = rgb
    original_weight = 1 - alpha
    blended = cv2.addWeighted(painted, alpha, image, original_weight, 0)
    return blended, painted
# --- Streamlit Interface ---
def _show_image(column, image):
    """Render ``image`` in a Streamlit column using a fresh, axis-less figure."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    column.pyplot(fig)
    # Close explicitly: pyplot keeps figures alive globally, and Streamlit
    # re-runs this script on every interaction.
    plt.close(fig)


def main():
    """Run the Streamlit interface: upload an image, analyze it, show results."""
    st.set_page_config(layout="wide")
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')

    # Set columns for the interface
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Optic disc and optic cup")
    cols[2].subheader("Class activation map")
    cols[3].subheader("Cropped Image")

    # File uploader
    st.sidebar.title("Image selection")
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    image = None
    if uploaded_file is not None:
        # Read and display uploaded image
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        _show_image(cols[0], image)

    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    with st.spinner('Loading model...'):
        run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = GlaucomaModel(device=run_device)

    with st.spinner('Analyzing...'):
        # Get predictions from the model
        (disease_idx, disc_cup_image, vcdr, cls_confidence,
         cup_confidence, disc_confidence, cropped_image) = model.process(image)

    # Display optic disc and cup image
    _show_image(cols[1], disc_cup_image)
    # No activation map is computed here; this column shows the input image.
    _show_image(cols[2], image)
    # Display the cropped image
    _show_image(cols[3], cropped_image)

    # Make cropped image downloadable by converting it to bytes
    buf = io.BytesIO()
    Image.fromarray(cropped_image).save(buf, format="PNG")
    st.sidebar.download_button(
        label="Download Cropped Image",
        data=buf.getvalue(),
        file_name="cropped_image.png",
        mime="image/png",
    )

    # Display results with confidence. Rows are joined without indentation so
    # Markdown parses them as a table rather than as a code block.
    st.subheader("Screening results:")
    table_rows = [
        "|Parameters|Outcomes|",
        "|---|---|",
        f"|Vertical cup-to-disc ratio|{vcdr:.04f}|",
        f"|Category|{model.cls_id2label[disease_idx]} ({cls_confidence:.02f}% confidence)|",
        f"|Optic Cup Segmentation Confidence|{cup_confidence:.02f}%|",
        f"|Optic Disc Segmentation Confidence|{disc_confidence:.02f}%|",
    ]
    final_results_as_table = "\n" + "\n".join(table_rows) + "\n"
    st.markdown(final_results_as_table)
# Launch the Streamlit interface only when this file is executed as a script.
if __name__ == '__main__':
    main()