luigi12345 commited on
Commit
c30c8d7
·
verified ·
1 Parent(s): 54ff357

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -105
app.py CHANGED
@@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
8
  import streamlit as st
9
  from PIL import Image
10
  import io
 
 
11
 
12
  # --- GlaucomaModel Class ---
13
  class GlaucomaModel(object):
@@ -22,20 +24,17 @@ class GlaucomaModel(object):
22
  # Segmentation model for optic disc and cup
23
  self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
24
  self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
25
-
26
- # Class activation map
27
  self.cls_id2label = self.cls_model.config.id2label
28
- self.seg_id2label = self.seg_model.config.id2label
29
 
30
  def glaucoma_pred(self, image):
31
  inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
32
  with torch.no_grad():
33
  inputs.to(self.device)
34
  outputs = self.cls_model(**inputs).logits
35
- # Softmax for probabilities
36
  probs = F.softmax(outputs, dim=-1)
37
  disease_idx = probs.cpu()[0, :].numpy().argmax()
38
- confidence = probs.cpu()[0, disease_idx].item() * 100 # Scale to percentage
39
  return disease_idx, confidence
40
 
41
  def optic_disc_cup_pred(self, image):
@@ -47,66 +46,44 @@ class GlaucomaModel(object):
47
  upsampled_logits = nn.functional.interpolate(
48
  logits, size=image.shape[:2], mode="bilinear", align_corners=False
49
  )
50
- # Softmax for segmentation confidence
51
  seg_probs = F.softmax(upsampled_logits, dim=1)
52
  pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
53
- cup_confidence = seg_probs[0, 2, :, :].mean().item() * 100 # Scale to percentage
54
- disc_confidence = seg_probs[0, 1, :, :].mean().item() * 100 # Scale to percentage
55
  return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
56
 
57
  def process(self, image):
58
- image_shape = image.shape[:2]
59
  disease_idx, cls_confidence = self.glaucoma_pred(image)
60
  disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
61
 
62
  try:
63
- vcdr = simple_vcdr(disc_cup) # Calculate vertical cup-to-disc ratio
64
  except:
65
  vcdr = np.nan
66
 
67
- # Mask for optic disc and cup
68
  mask = (disc_cup > 0).astype(np.uint8)
69
-
70
- # Get bounding box of the optic cup + disc and add dynamic padding
71
  x, y, w, h = cv2.boundingRect(mask)
72
- padding = max(50, int(0.2 * max(w, h))) # Dynamic padding (20% of width or height)
73
  x = max(x - padding, 0)
74
  y = max(y - padding, 0)
75
  w = min(w + 2 * padding, image.shape[1] - x)
76
  h = min(h + 2 * padding, image.shape[0] - y)
77
 
78
- # Ensure that the bounding box is large enough to avoid cropping errors
79
  cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()
80
-
81
- # Generate disc and cup visualization
82
  _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
83
 
84
  return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
85
 
86
  # --- Utility Functions ---
87
  def simple_vcdr(mask):
88
- """
89
- Simple function to calculate the vertical cup-to-disc ratio (VCDR).
90
- Assumes:
91
- - mask contains class 1 for optic disc and class 2 for optic cup.
92
- """
93
  disc_area = np.sum(mask == 1)
94
  cup_area = np.sum(mask == 2)
95
- if disc_area == 0: # Avoid division by zero
96
  return np.nan
97
  vcdr = cup_area / disc_area
98
  return vcdr
99
 
100
  def add_mask(image, mask, classes, colors, alpha=0.5):
101
- """
102
- Adds a transparent mask to the original image.
103
- Args:
104
- - image: the original RGB image
105
- - mask: the predicted segmentation mask
106
- - classes: a list of class indices to apply masks for (e.g., [1, 2])
107
- - colors: a list of colors for each class (e.g., [[0, 255, 0], [255, 0, 0]] for green and red)
108
- - alpha: transparency level (default = 0.5)
109
- """
110
  overlay = image.copy()
111
  for class_id, color in zip(classes, colors):
112
  overlay[mask == class_id] = color
@@ -116,79 +93,59 @@ def add_mask(image, mask, classes, colors, alpha=0.5):
116
  # --- Streamlit Interface ---
117
  def main():
118
  st.set_page_config(layout="wide")
119
- st.title("Glaucoma Screening from Retinal Fundus Images")
120
- st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
121
-
122
- # Set columns for the interface
123
- cols = st.beta_columns((1, 1, 1, 1))
124
- cols[0].subheader("Input image")
125
- cols[1].subheader("Optic disc and optic cup")
126
- cols[2].subheader("Class activation map")
127
- cols[3].subheader("Cropped Image")
128
-
129
- # File uploader
130
- st.sidebar.title("Image selection")
131
- st.set_option('deprecation.showfileUploaderEncoding', False)
132
- uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
133
-
134
- if uploaded_file is not None:
135
- # Read and display uploaded image
136
- image = Image.open(uploaded_file).convert('RGB')
137
- image = np.array(image).astype(np.uint8)
138
- fig, ax = plt.subplots()
139
- ax.imshow(image)
140
- ax.axis('off')
141
- cols[0].pyplot(fig)
142
-
143
- if st.sidebar.button("Analyze image"):
144
- if uploaded_file is None:
145
- st.sidebar.write("Please upload an image")
146
- else:
147
- with st.spinner('Loading model...'):
148
- run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
149
- model = GlaucomaModel(device=run_device)
150
-
151
- with st.spinner('Analyzing...'):
152
- # Get predictions from the model
153
- disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image = model.process(image)
154
-
155
- # Display optic disc and cup image
156
- ax.imshow(disc_cup_image)
157
- ax.axis('off')
158
- cols[1].pyplot(fig)
159
-
160
- # Display classification map
161
- ax.imshow(image)
162
- ax.axis('off')
163
- cols[2].pyplot(fig)
164
-
165
- # Display the cropped image
166
- ax.imshow(cropped_image)
167
- ax.axis('off')
168
- cols[3].pyplot(fig)
169
-
170
- # Make cropped image downloadable by converting it to bytes
171
- buf = io.BytesIO()
172
- Image.fromarray(cropped_image).save(buf, format="PNG")
173
- byte_img = buf.getvalue()
174
- st.sidebar.download_button(
175
- label="Download Cropped Image",
176
- data=byte_img,
177
- file_name="cropped_image.png",
178
- mime="image/png"
179
- )
180
-
181
- # Display results with confidence
182
- st.subheader("Screening results:")
183
- final_results_as_table = f"""
184
- |Parameters|Outcomes|
185
- |---|---|
186
- |Vertical cup-to-disc ratio|{vcdr:.04f}|
187
- |Category|{model.cls_id2label[disease_idx]} ({cls_confidence:.02f}% confidence)|
188
- |Optic Cup Segmentation Confidence|{cup_confidence:.02f}%|
189
- |Optic Disc Segmentation Confidence|{disc_confidence:.02f}%|
190
- """
191
- st.markdown(final_results_as_table)
192
 
193
  if __name__ == '__main__':
194
  main()
 
8
  import streamlit as st
9
  from PIL import Image
10
  import io
11
+ import zipfile
12
+ import os
13
 
14
  # --- GlaucomaModel Class ---
15
  class GlaucomaModel(object):
 
24
  # Segmentation model for optic disc and cup
25
  self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
26
  self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
27
+ # Mapping for class labels
 
28
  self.cls_id2label = self.cls_model.config.id2label
 
29
 
30
  def glaucoma_pred(self, image):
31
  inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
32
  with torch.no_grad():
33
  inputs.to(self.device)
34
  outputs = self.cls_model(**inputs).logits
 
35
  probs = F.softmax(outputs, dim=-1)
36
  disease_idx = probs.cpu()[0, :].numpy().argmax()
37
+ confidence = probs.cpu()[0, disease_idx].item() * 100
38
  return disease_idx, confidence
39
 
40
  def optic_disc_cup_pred(self, image):
 
46
  upsampled_logits = nn.functional.interpolate(
47
  logits, size=image.shape[:2], mode="bilinear", align_corners=False
48
  )
 
49
  seg_probs = F.softmax(upsampled_logits, dim=1)
50
  pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
51
+ cup_confidence = seg_probs[0, 2, :, :].mean().item() * 100
52
+ disc_confidence = seg_probs[0, 1, :, :].mean().item() * 100
53
  return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
54
 
55
  def process(self, image):
 
56
  disease_idx, cls_confidence = self.glaucoma_pred(image)
57
  disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
58
 
59
  try:
60
+ vcdr = simple_vcdr(disc_cup)
61
  except:
62
  vcdr = np.nan
63
 
 
64
  mask = (disc_cup > 0).astype(np.uint8)
 
 
65
  x, y, w, h = cv2.boundingRect(mask)
66
+ padding = max(50, int(0.2 * max(w, h)))
67
  x = max(x - padding, 0)
68
  y = max(y - padding, 0)
69
  w = min(w + 2 * padding, image.shape[1] - x)
70
  h = min(h + 2 * padding, image.shape[0] - y)
71
 
 
72
  cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()
 
 
73
  _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
74
 
75
  return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
76
 
77
  # --- Utility Functions ---
78
  def simple_vcdr(mask):
 
 
 
 
 
79
  disc_area = np.sum(mask == 1)
80
  cup_area = np.sum(mask == 2)
81
+ if disc_area == 0:
82
  return np.nan
83
  vcdr = cup_area / disc_area
84
  return vcdr
85
 
86
  def add_mask(image, mask, classes, colors, alpha=0.5):
 
 
 
 
 
 
 
 
 
87
  overlay = image.copy()
88
  for class_id, color in zip(classes, colors):
89
  overlay[mask == class_id] = color
 
93
  # --- Streamlit Interface ---
94
  def main():
95
  st.set_page_config(layout="wide")
96
+ st.title("Batch Glaucoma Screening from Retinal Fundus Images")
97
+
98
+ st.sidebar.title("Settings")
99
+ confidence_threshold = st.sidebar.slider("Confidence Threshold (%)", 0, 100, 70)
100
+ uploaded_files = st.sidebar.file_uploader("Upload Images", type=['png', 'jpeg', 'jpg'], accept_multiple_files=True)
101
+
102
+ confident_images = []
103
+ download_confident_images = []
104
+
105
+ if uploaded_files:
106
+ for uploaded_file in uploaded_files:
107
+ image = Image.open(uploaded_file).convert('RGB')
108
+ image_np = np.array(image).astype(np.uint8)
109
+
110
+ with st.spinner(f'Processing {uploaded_file.name}...'):
111
+ model = GlaucomaModel(device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
112
+ disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)
113
+
114
+ # Confidence-based grouping
115
+ is_confident = cls_conf >= confidence_threshold
116
+ if is_confident:
117
+ confident_images.append(uploaded_file.name)
118
+ download_confident_images.append((cropped_image, uploaded_file.name))
119
+
120
+ # Display Results
121
+ with st.expander(f"Results for {uploaded_file.name}", expanded=False):
122
+ cols = st.columns(4)
123
+ cols[0].image(image_np, caption="Input Image", use_column_width=True)
124
+ cols[1].image(disc_cup_image, caption="Disc/Cup Segmentation", use_column_width=True)
125
+ cols[2].image(image_np, caption="Class Activation Map", use_column_width=True)
126
+ cols[3].image(cropped_image, caption="Cropped Image", use_column_width=True)
127
+
128
+ st.write(f"**Vertical cup-to-disc ratio:** {vcdr:.04f}")
129
+ st.write(f"**Category:** {model.cls_id2label[disease_idx]} ({cls_conf:.02f}% confidence)")
130
+ st.write(f"**Optic Cup Segmentation Confidence:** {cup_conf:.02f}%")
131
+ st.write(f"**Optic Disc Segmentation Confidence:** {disc_conf:.02f}%")
132
+ st.write(f"**Confidence Group:** {'Confident' if is_confident else 'Not Confident'}")
133
+
134
+ # Download Button for Confident Images
135
+ if download_confident_images:
136
+ with zipfile.ZipFile("confident_cropped_images.zip", "w") as zf:
137
+ for cropped_image, name in download_confident_images:
138
+ img_buffer = io.BytesIO()
139
+ Image.fromarray(cropped_image).save(img_buffer, format="PNG")
140
+ zf.writestr(f"{name}_cropped.png", img_buffer.getvalue())
141
+
142
+ # Provide a markdown link to the ZIP file
143
+ st.sidebar.markdown(
144
+ f"[Download Confident Cropped Images](./confident_cropped_images.zip)",
145
+ unsafe_allow_html=True
146
+ )
147
+ else:
148
+ st.sidebar.info("Upload images to begin analysis.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  if __name__ == '__main__':
151
  main()