File size: 7,755 Bytes
7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 0eb72c7 7a986e7 0eb72c7 7a986e7 0eb72c7 7a986e7 0eb72c7 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 0eb72c7 f2c28c8 dfbd972 0eb72c7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 f2c28c8 7a986e7 0eb72c7 f2c28c8 7a986e7 dfbd972 f2c28c8 7a986e7 dfbd972 0eb72c7 7a986e7 dfbd972 a7e5224 dfbd972 a7e5224 7a986e7 dfbd972 f2c28c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma classifier and optic disc/cup segmenter for fundus images.

    Wraps two Hugging Face models: a SwinV2 image classifier and a SegFormer
    semantic segmenter. Both are loaded in ``__init__``, moved to ``device``
    and put in eval mode.

    Images passed to the methods are expected to be NumPy arrays whose first
    two dimensions are (height, width); the Streamlit front end in this file
    passes uint8 RGB arrays.
    """

    # Label ids used by the segmentation output (1 = optic disc, 2 = optic cup).
    DISC_ID = 1
    CUP_ID = 2

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models and their image processors.

        Args:
            cls_model_path: Hub id or local path of the classification model.
            seg_model_path: Hub id or local path of the segmentation model.
            device: torch device both models are moved to.
        """
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Index -> label-name mappings taken from the model configs
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify ``image``.

        Returns:
            ``(disease_idx, confidence)``: the index of the most probable class
            as a plain ``int`` (a key of ``cls_id2label``) and its softmax
            probability as a ``float``.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        # Softmax for probabilities; only the first (single) batch item is used
        probs = F.softmax(logits, dim=-1)[0].cpu()
        # Plain int so the index is a well-behaved dict key / serialisable value
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item()
        return disease_idx, confidence

    @staticmethod
    def _mean_class_confidence(probs, labels, class_id):
        """Mean probability of ``class_id`` over the pixels predicted as it.

        Args:
            probs: tensor indexed as ``[class, row, col]`` of per-pixel
                probabilities.
            labels: tensor indexed as ``[row, col]`` of predicted class ids.
            class_id: class whose confidence is wanted.

        Returns:
            The mean as a ``float``, or ``0.0`` when no pixel was assigned to
            ``class_id``.
        """
        selected = probs[class_id][labels == class_id]
        if selected.numel() == 0:
            return 0.0
        return selected.mean().item()

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in ``image``.

        Returns:
            ``(mask, cup_confidence, disc_confidence)`` where ``mask`` is a
            uint8 array with the same height and width as ``image`` holding
            the per-pixel argmax class id, and each confidence is the mean
            softmax probability of that class over the pixels predicted as
            that class (``0.0`` if the class was not predicted anywhere).
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # Bring the logits back to the input resolution before taking argmax
        upsampled_logits = F.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        # Averaging over the whole image would mostly measure how small the
        # structure is (background pixels dominate), so restrict the average
        # to the pixels actually assigned to each class.
        cup_confidence = self._mean_class_confidence(seg_probs, pred_disc_cup, self.CUP_ID)
        disc_confidence = self._mean_class_confidence(seg_probs, pred_disc_cup, self.DISC_ID)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    @staticmethod
    def _crop_to_mask(image, mask, padding=20, min_size=50):
        """Crop ``image`` to the bounding box of ``mask`` plus ``padding``.

        The padding is clamped independently on every side so a box touching
        one border does not grow by extra pixels on the opposite side. When
        the mask is empty, or the padded box is narrower or shorter than
        ``min_size``, a copy of the full image is returned instead.
        """
        if not mask.any():
            return image.copy()
        x, y, w, h = cv2.boundingRect(mask)
        height, width = image.shape[:2]
        left = max(x - padding, 0)
        top = max(y - padding, 0)
        right = min(x + w + padding, width)
        bottom = min(y + h + padding, height)
        if right - left < min_size or bottom - top < min_size:
            return image.copy()
        return image[top:bottom, left:right]

    def process(self, image):
        """Run classification, segmentation and post-processing on ``image``.

        Returns:
            A 7-tuple ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image)``. ``vcdr`` is the
            value returned by ``simple_vcdr`` (``np.nan`` if it fails),
            ``disc_cup_image`` is the input with disc and cup pixels painted
            green and red, and ``cropped_image`` is the region around the
            segmented disc/cup (or a copy of the whole image as a fallback).
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        try:
            vcdr = simple_vcdr(disc_cup)
        except Exception:
            # Best effort: a failed ratio must not discard the other results,
            # but KeyboardInterrupt/SystemExit are no longer swallowed.
            vcdr = np.nan
        # Binary mask covering both optic disc and optic cup
        mask = (disc_cup > 0).astype(np.uint8)
        cropped_image = self._crop_to_mask(image, mask, padding=20, min_size=50)
        # add_mask returns (blended, overlay); the opaque overlay is displayed
        _, disc_cup_image = add_mask(
            image, disc_cup, [self.DISC_ID, self.CUP_ID], [[0, 255, 0], [255, 0, 0]], 0.2
        )
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
# --- Utility Functions ---
def simple_vcdr(mask):
    """Return the vertical cup-to-disc ratio (VCDR) of a segmentation mask.

    The ratio is the vertical extent (number of rows spanned) of the optic
    cup divided by the vertical extent of the optic disc. Because the labels
    are mutually exclusive and the cup lies inside the disc, the disc extent
    is measured over rim and cup pixels together.

    Args:
        mask: 2-D array indexed as ``[row, col]`` where 1 marks optic disc
            pixels and 2 marks optic cup pixels.

    Returns:
        The ratio as a ``float``; ``0.0`` when a disc but no cup is present,
        and ``np.nan`` when the mask contains no disc (class 1) pixels.
    """
    mask = np.asarray(mask)
    # No disc rim found: the ratio is undefined (also avoids division by zero)
    if not (mask == 1).any():
        return np.nan

    cup_rows = np.flatnonzero((mask == 2).any(axis=1))
    if cup_rows.size == 0:
        return 0.0
    disc_rows = np.flatnonzero(((mask == 1) | (mask == 2)).any(axis=1))

    cup_height = cup_rows[-1] - cup_rows[0] + 1
    disc_height = disc_rows[-1] - disc_rows[0] + 1
    return float(cup_height / disc_height)
def add_mask(image, mask, classes, colors, alpha=0.5):
    """Paint selected mask classes onto a copy of ``image``.

    Args:
        image: the original RGB image; it is not modified.
        mask: the predicted segmentation mask, compared against each class id.
        classes: class indices to paint (e.g. ``[1, 2]``).
        colors: one color per class, paired positionally with ``classes``
            (e.g. ``[[0, 255, 0], [255, 0, 0]]`` for green and red).
        alpha: weight of the painted copy in the blended result
            (default 0.5); the original image gets ``1 - alpha``.

    Returns:
        ``(blended, painted)``: the weighted blend of the painted copy with
        the original image, and the fully opaque painted copy itself.
    """
    painted = image.copy()
    for label, rgb in zip(classes, colors):
        region = mask == label
        painted[region] = rgb

    # Mix the opaque painting back into the untouched original
    original_weight = 1 - alpha
    blended = cv2.addWeighted(painted, alpha, image, original_weight, 0)
    return blended, painted
# --- Streamlit Interface ---
def _show_image(column, img):
    """Draw ``img`` in ``column`` on its own axis-free matplotlib figure."""
    # A fresh figure per panel: stacking several imshow() calls on one shared
    # axes keeps the earlier images and the union of their extents, so a
    # smaller image would be drawn on top of a larger one.
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis('off')
    column.pyplot(fig)
    # Release the figure so reruns of the script do not accumulate figures
    plt.close(fig)


def main():
    """Streamlit entry point: upload a fundus image and show the screening.

    Lays out four columns (input image, disc/cup overlay, a panel titled
    "Class activation map" that currently shows the unmodified input, and the
    cropped region), and prints a results table once "Analyze image" is
    pressed with an image uploaded.
    """
    # Wide mode in Streamlit
    st.set_page_config(layout="wide")
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')

    # `st.columns` replaced `st.beta_columns`; fall back to the old name only
    # on Streamlit versions that predate the rename.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1, 1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Optic disc and optic cup")
    cols[2].subheader("Class activation map")
    cols[3].subheader("Cropped Image")

    # File uploader
    st.sidebar.title("Image selection")
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except Exception:
        # The option only silences a warning on old Streamlit versions;
        # versions that no longer define it reject the key, which must not
        # take the whole app down.
        pass
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    image = None
    if uploaded_file is not None:
        # Read and display uploaded image
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        _show_image(cols[0], image)

    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    with st.spinner('Loading model...'):
        # Load the model on available device
        run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = GlaucomaModel(device=run_device)

    with st.spinner('Analyzing...'):
        # Get predictions from the model
        (disease_idx, disc_cup_image, vcdr, cls_confidence,
         cup_confidence, disc_confidence, cropped_image) = model.process(image)

    # Optic disc and cup overlay
    _show_image(cols[1], disc_cup_image)
    # No activation map is computed here; the input image is shown as-is
    _show_image(cols[2], image)
    # Region around the segmented disc/cup
    _show_image(cols[3], cropped_image)

    # Display results with confidence
    st.subheader("Screening results:")
    # Rows are joined without leading whitespace so Markdown renders a table
    # regardless of how this function is indented.
    table_rows = [
        "|Parameters|Outcomes|",
        "|---|---|",
        f"|Vertical cup-to-disc ratio|{vcdr:.04f}|",
        f"|Category|{model.cls_id2label[disease_idx]} ({cls_confidence*100:.02f}% confidence)|",
        f"|Optic Cup Segmentation Confidence|{cup_confidence*100:.02f}%|",
        f"|Optic Disc Segmentation Confidence|{disc_confidence*100:.02f}%|",
    ]
    st.markdown("\n".join(table_rows))
# Run the Streamlit app only when this file is executed as a script.
if __name__ == '__main__':
    main()