Spaces:
Running
Running
Commit
Β·
86e6672
1
Parent(s):
2958b13
upgrade to gradio4
Browse files- README.md +6 -6
- app.py +98 -86
- predictor.py +1 -1
- {val_od_examples β test_examples}/ACDC.jpg +0 -0
- {val_od_examples β test_examples}/BTCV.jpg +0 -0
- {val_od_examples β test_examples}/BUID.jpg +0 -0
- {val_od_examples β test_examples}/DRIVE.jpg +0 -0
- {val_od_examples β test_examples}/HipXRay.jpg +0 -0
- {val_od_examples β test_examples}/PanDental.jpg +0 -0
- {val_od_examples β test_examples}/SCD.jpg +0 -0
- test_examples/SCR.jpg +0 -0
- {val_od_examples β test_examples}/SpineWeb.jpg +0 -0
- {val_od_examples β test_examples}/WBC.jpg +0 -0
README.md
CHANGED
@@ -4,19 +4,19 @@ emoji: π©»
|
|
4 |
colorFrom: blue
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: true
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
-
This demo uses the ScribblePrompt-UNet model described in ["ScribblePrompt: Fast and Flexible Interactive Segmentation for Any
|
14 |
|
15 |
```
|
16 |
-
@article{
|
17 |
-
title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any
|
18 |
author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
|
19 |
-
journal={
|
20 |
-
year={
|
21 |
}
|
22 |
```
|
|
|
4 |
colorFrom: blue
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.41.0
|
8 |
app_file: app.py
|
9 |
pinned: true
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
+
This demo uses the ScribblePrompt-UNet model described in ["ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image"](https://arxiv.org/abs/2312.07381)
|
14 |
|
15 |
```
|
16 |
+
@article{wong2024scribbleprompt,
|
17 |
+
title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image},
|
18 |
author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
|
19 |
+
journal={European Conference on Computer Vision (ECCV)},
|
20 |
+
year={2024},
|
21 |
}
|
22 |
```
|
app.py
CHANGED
@@ -5,20 +5,19 @@ import torch.nn.functional as F
|
|
5 |
import os
|
6 |
import cv2
|
7 |
import pathlib
|
|
|
8 |
|
9 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
10 |
|
11 |
from predictor import Predictor
|
12 |
|
13 |
-
|
|
|
14 |
|
15 |
test_example_dir = pathlib.Path("./test_examples")
|
16 |
test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]
|
17 |
|
18 |
-
|
19 |
-
val_examples = [str(val_example_dir / x) for x in sorted(os.listdir(val_example_dir))]
|
20 |
-
|
21 |
-
default_example = test_example_dir / "TotalSegmentator_2.jpg"
|
22 |
exp_dir = pathlib.Path('./checkpoints')
|
23 |
default_model = 'ScribblePrompt-Unet'
|
24 |
|
@@ -82,7 +81,7 @@ def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
|
|
82 |
|
83 |
if contour:
|
84 |
contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
85 |
-
cv2.drawContours(output, contours[0], -1, (0, 255, 0),
|
86 |
else:
|
87 |
mask_overlay = _get_overlay(img, mask)
|
88 |
mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
|
@@ -111,26 +110,29 @@ def viz_pred_mask(img, mask=None, point_coords=None, point_labels=None, bbox_coo
|
|
111 |
|
112 |
out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)
|
113 |
|
|
|
|
|
|
|
114 |
if point_coords is not None:
|
115 |
for i,(col,row) in enumerate(point_coords):
|
116 |
if point_labels[i] == 1:
|
117 |
-
cv2.circle(out,(col, row),
|
118 |
else:
|
119 |
-
cv2.circle(out,(col, row),
|
120 |
|
121 |
if bbox_coords is not None:
|
122 |
for i in range(len(bbox_coords)//2):
|
123 |
-
cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0),
|
124 |
if len(bbox_coords) % 2 == 1:
|
125 |
cv2.circle(out, tuple(bbox_coords[-1]), 2, (255,165,0), -1)
|
126 |
|
127 |
-
return out
|
128 |
|
129 |
# -----------------------------------------------------------------------------
|
130 |
# Collect scribbles
|
131 |
# -----------------------------------------------------------------------------
|
132 |
|
133 |
-
def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img
|
134 |
"""
|
135 |
Record scribbles
|
136 |
"""
|
@@ -138,28 +140,19 @@ def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img, lab
|
|
138 |
|
139 |
if scribble_img is not None:
|
140 |
|
141 |
-
|
142 |
-
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
# In case any scribbles were removed
|
150 |
-
corrected_scribble_masks = np.stack(2*[(scribble_mask > 0)], axis=0)*seperate_scribble_masks
|
151 |
-
corrected_last_scribble_mask = last_scribble_mask*(scribble_mask > 0)
|
152 |
-
|
153 |
-
delta = (scribble_mask - corrected_last_scribble_mask) > 0
|
154 |
-
new_scribbles = scribble_mask * delta
|
155 |
-
corrected_scribble_masks[label,...] = np.clip(corrected_scribble_masks[label,...] + new_scribbles, a_min=0, a_max=1)
|
156 |
-
|
157 |
-
last_scribble_mask = scribble_mask
|
158 |
-
seperate_scribble_masks = corrected_scribble_masks
|
159 |
|
160 |
return seperate_scribble_masks, last_scribble_mask
|
161 |
|
162 |
-
def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
|
|
|
163 |
"""
|
164 |
Make predictions
|
165 |
"""
|
@@ -194,8 +187,7 @@ def refresh_predictions(predictor, input_img, output_img, click_coords, click_la
|
|
194 |
|
195 |
# Record any new scribbles
|
196 |
seperate_scribble_masks, last_scribble_mask = get_scribbles(
|
197 |
-
seperate_scribble_masks, last_scribble_mask, scribble_img
|
198 |
-
label=(0 if brush_label == "Positive (green)" else 1) # current color of the brush
|
199 |
)
|
200 |
|
201 |
# Make prediction
|
@@ -206,12 +198,33 @@ def refresh_predictions(predictor, input_img, output_img, click_coords, click_la
|
|
206 |
# Update input visualizations
|
207 |
mask_to_viz = best_mask.numpy()
|
208 |
click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)
|
209 |
-
scribble_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)
|
210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
out_viz = [
|
212 |
viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
|
213 |
input_img,
|
214 |
-
|
215 |
]
|
216 |
|
217 |
return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
|
@@ -298,8 +311,8 @@ def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res
|
|
298 |
with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:
|
299 |
|
300 |
# State variables
|
301 |
-
seperate_scribble_masks = gr.State(np.zeros((2,
|
302 |
-
last_scribble_mask = gr.State(np.zeros((
|
303 |
|
304 |
click_coords = gr.State([])
|
305 |
click_labels = gr.State([])
|
@@ -312,10 +325,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
312 |
low_res_mask = gr.State(None)
|
313 |
|
314 |
gr.HTML("""\
|
315 |
-
<h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmention for Any
|
316 |
-
<p style="text-align: center; font-size: large;"
|
317 |
-
|
318 |
-
|
|
|
319 |
""")
|
320 |
|
321 |
with gr.Accordion("Open for instructions!", open=False):
|
@@ -351,34 +365,42 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
351 |
value="Positive (green)", label="Scribble/Click Label")
|
352 |
bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
|
353 |
with gr.Column(scale=1):
|
|
|
354 |
binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
|
355 |
autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
|
356 |
-
gr.
|
|
|
357 |
multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)
|
358 |
|
359 |
with gr.Row():
|
360 |
display_height = 500
|
361 |
|
|
|
|
|
|
|
362 |
with gr.Column(scale=1):
|
363 |
with gr.Tab("Scribbles"):
|
364 |
-
scribble_img = gr.
|
365 |
label="Input",
|
366 |
-
|
367 |
-
|
368 |
-
brush_color="#00FF00",
|
369 |
-
tool="sketch",
|
370 |
-
height=display_height,
|
371 |
type='numpy',
|
372 |
-
value=default_example,
|
|
|
|
|
|
|
|
|
373 |
)
|
374 |
-
clear_scribble_button = gr.ClearButton([scribble_img], value="Clear Scribbles", variant="stop")
|
375 |
|
376 |
with gr.Tab("Clicks/Boxes") as click_tab:
|
377 |
click_img = gr.Image(
|
378 |
label="Input",
|
379 |
type='numpy',
|
380 |
value=default_example,
|
381 |
-
|
|
|
|
|
|
|
382 |
)
|
383 |
with gr.Row():
|
384 |
undo_click_button = gr.Button("Undo Last Click")
|
@@ -388,21 +410,20 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
388 |
input_img = gr.Image(
|
389 |
label="Input",
|
390 |
image_mode="L",
|
391 |
-
visible=True,
|
392 |
value=default_example,
|
393 |
-
height=display_height
|
394 |
)
|
395 |
gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")
|
396 |
|
397 |
with gr.Column(scale=1):
|
398 |
with gr.Tab("Output"):
|
399 |
output_img = gr.Gallery(
|
400 |
-
label='
|
401 |
columns=1,
|
402 |
elem_id="gallery",
|
403 |
preview=True,
|
404 |
object_fit="scale-down",
|
405 |
-
height=display_height
|
406 |
)
|
407 |
|
408 |
submit_button = gr.Button("Refresh Prediction", variant='primary')
|
@@ -424,28 +445,9 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
424 |
|
425 |
gr.Examples(examples=test_examples,
|
426 |
inputs=[input_img],
|
427 |
-
examples_per_page=
|
428 |
-
label='
|
429 |
-
)
|
430 |
-
|
431 |
-
gr.Examples(examples=val_examples,
|
432 |
-
inputs=[input_img],
|
433 |
-
examples_per_page=10,
|
434 |
-
label='Unseen Examples from Validation Datasets'
|
435 |
)
|
436 |
-
|
437 |
-
# When clear scribble button is clicked
|
438 |
-
def clear_scribble_history(input_img):
|
439 |
-
if input_img is not None:
|
440 |
-
input_shape = input_img.shape[:2]
|
441 |
-
else:
|
442 |
-
input_shape = (RES, RES)
|
443 |
-
return input_img, input_img, np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None
|
444 |
-
|
445 |
-
clear_scribble_button.click(clear_scribble_history,
|
446 |
-
inputs=[input_img],
|
447 |
-
outputs=[click_img, scribble_img, seperate_scribble_masks, last_scribble_mask, best_mask, low_res_mask]
|
448 |
-
)
|
449 |
|
450 |
# When clear clicks button is clicked
|
451 |
def clear_click_history(input_img):
|
@@ -460,9 +462,25 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
460 |
if input_img is not None:
|
461 |
input_shape = input_img.shape[:2]
|
462 |
else:
|
463 |
-
input_shape = (
|
464 |
return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
input_img.change(clear_all_history,
|
467 |
inputs=[input_img],
|
468 |
outputs=[click_img, scribble_img,
|
@@ -527,7 +545,8 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
527 |
undo_click_button.click(fn=undo_click,
|
528 |
inputs=[
|
529 |
predictor,
|
530 |
-
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords,
|
|
|
531 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
532 |
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
533 |
],
|
@@ -542,8 +561,7 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
542 |
Draw scribbles in the click canvas
|
543 |
"""
|
544 |
seperate_scribble_masks, last_scribble_mask = get_scribbles(
|
545 |
-
seperate_scribble_masks, last_scribble_mask, scribble_img
|
546 |
-
label=(0 if brush_label == "Positive (green)" else 1) # previous color of the brush
|
547 |
)
|
548 |
click_input_viz = viz_pred_mask(
|
549 |
input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
|
@@ -566,17 +584,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
566 |
Recorn new scribbles when changing brush color
|
567 |
"""
|
568 |
if label == "Negative (red)":
|
569 |
-
brush_update = gr.
|
570 |
elif label == "Positive (green)":
|
571 |
-
brush_update = gr.
|
572 |
else:
|
573 |
raise TypeError("Invalid brush color")
|
574 |
-
|
575 |
-
# Record latest scribbles
|
576 |
-
seperate_scribble_masks, last_scribble_mask = get_scribbles(
|
577 |
-
seperate_scribble_masks, last_scribble_mask, scribble_img,
|
578 |
-
label=(1 if label == "Positive (green)" else 0) # previous color of the brush
|
579 |
-
)
|
580 |
|
581 |
return seperate_scribble_masks, last_scribble_mask, brush_update
|
582 |
|
|
|
5 |
import os
|
6 |
import cv2
|
7 |
import pathlib
|
8 |
+
import math
|
9 |
|
10 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
11 |
|
12 |
from predictor import Predictor
|
13 |
|
14 |
+
H = 256
|
15 |
+
W = 256
|
16 |
|
17 |
test_example_dir = pathlib.Path("./test_examples")
|
18 |
test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]
|
19 |
|
20 |
+
default_example = test_examples[0]
|
|
|
|
|
|
|
21 |
exp_dir = pathlib.Path('./checkpoints')
|
22 |
default_model = 'ScribblePrompt-Unet'
|
23 |
|
|
|
81 |
|
82 |
if contour:
|
83 |
contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
84 |
+
cv2.drawContours(output, contours[0], -1, (0, 255, 0), 2)
|
85 |
else:
|
86 |
mask_overlay = _get_overlay(img, mask)
|
87 |
mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
|
|
|
110 |
|
111 |
out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)
|
112 |
|
113 |
+
H,W = img.shape[:2]
|
114 |
+
marker_size = min(H,W)//100
|
115 |
+
|
116 |
if point_coords is not None:
|
117 |
for i,(col,row) in enumerate(point_coords):
|
118 |
if point_labels[i] == 1:
|
119 |
+
cv2.circle(out,(col, row), marker_size, (0,255,0), -1)
|
120 |
else:
|
121 |
+
cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
|
122 |
|
123 |
if bbox_coords is not None:
|
124 |
for i in range(len(bbox_coords)//2):
|
125 |
+
cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0), marker_size)
|
126 |
if len(bbox_coords) % 2 == 1:
|
127 |
cv2.circle(out, tuple(bbox_coords[-1]), 2, (255,165,0), -1)
|
128 |
|
129 |
+
return out.astype(np.uint8)
|
130 |
|
131 |
# -----------------------------------------------------------------------------
|
132 |
# Collect scribbles
|
133 |
# -----------------------------------------------------------------------------
|
134 |
|
135 |
+
def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img):
|
136 |
"""
|
137 |
Record scribbles
|
138 |
"""
|
|
|
140 |
|
141 |
if scribble_img is not None:
|
142 |
|
143 |
+
# Only use first layer
|
144 |
+
color_mask = scribble_img.get('layers')[0]
|
145 |
|
146 |
+
positive_scribbles = 1.0*(color_mask[...,1] > 128)
|
147 |
+
negative_scribbles = 1.0*(color_mask[...,0] > 128)
|
148 |
+
|
149 |
+
seperate_scribble_masks = np.stack([positive_scribbles, negative_scribbles], axis=0)
|
150 |
+
last_scribble_mask = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
return seperate_scribble_masks, last_scribble_mask
|
153 |
|
154 |
+
def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
|
155 |
+
low_res_mask, img_features, multimask_mode):
|
156 |
"""
|
157 |
Make predictions
|
158 |
"""
|
|
|
187 |
|
188 |
# Record any new scribbles
|
189 |
seperate_scribble_masks, last_scribble_mask = get_scribbles(
|
190 |
+
seperate_scribble_masks, last_scribble_mask, scribble_img
|
|
|
191 |
)
|
192 |
|
193 |
# Make prediction
|
|
|
198 |
# Update input visualizations
|
199 |
mask_to_viz = best_mask.numpy()
|
200 |
click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)
|
|
|
201 |
|
202 |
+
empty_channel = np.zeros(input_img.shape[:2]).astype(np.uint8)
|
203 |
+
full_channel = 255*np.ones(input_img.shape[:2]).astype(np.uint8)
|
204 |
+
gray_mask = (255*mask_to_viz).astype(np.uint8)
|
205 |
+
|
206 |
+
bg = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)
|
207 |
+
old_scribbles = scribble_img.get('layers')[0]
|
208 |
+
|
209 |
+
scribble_mask = 255*(old_scribbles > 0).any(-1)
|
210 |
+
|
211 |
+
scribble_input_viz = {
|
212 |
+
"background": np.stack([bg[...,i] for i in range(3)]+[full_channel], axis=-1),
|
213 |
+
["layers"][0]: [np.stack([
|
214 |
+
(255*seperate_scribble_masks[1]).astype(np.uint8),
|
215 |
+
(255*seperate_scribble_masks[0]).astype(np.uint8),
|
216 |
+
empty_channel,
|
217 |
+
scribble_mask
|
218 |
+
], axis=-1)],
|
219 |
+
"composite": np.stack([click_input_viz[...,i] for i in range(3)]+[empty_channel], axis=-1),
|
220 |
+
}
|
221 |
+
|
222 |
+
mask_img = 255*(mask_to_viz[...,None].repeat(axis=2, repeats=3)>0.5) if binary_checkbox else mask_to_viz[...,None].repeat(axis=2, repeats=3)
|
223 |
+
|
224 |
out_viz = [
|
225 |
viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
|
226 |
input_img,
|
227 |
+
mask_img,
|
228 |
]
|
229 |
|
230 |
return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
|
|
|
311 |
with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:
|
312 |
|
313 |
# State variables
|
314 |
+
seperate_scribble_masks = gr.State(np.zeros((2, H, W), dtype=np.float32))
|
315 |
+
last_scribble_mask = gr.State(np.zeros((H, W), dtype=np.float32))
|
316 |
|
317 |
click_coords = gr.State([])
|
318 |
click_labels = gr.State([])
|
|
|
325 |
low_res_mask = gr.State(None)
|
326 |
|
327 |
gr.HTML("""\
|
328 |
+
<h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmention for Any Biomedical Image</h1>
|
329 |
+
<p style="text-align: center; font-size: large;">
|
330 |
+
<b>ScribblePrompt</b> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
|
331 |
+
[<a href="https://arxiv.org/abs/2312.07381">paper</a> | <a href="https://scribbleprompt.csail.mit.edu">website</a> | <a href="https://github.com/halleewong/ScribblePrompt">code</a>]
|
332 |
+
</p>
|
333 |
""")
|
334 |
|
335 |
with gr.Accordion("Open for instructions!", open=False):
|
|
|
365 |
value="Positive (green)", label="Scribble/Click Label")
|
366 |
bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
|
367 |
with gr.Column(scale=1):
|
368 |
+
|
369 |
binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
|
370 |
autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
|
371 |
+
with gr.Accordion("Troubleshooting tips", open=False):
|
372 |
+
gr.Markdown("<span style='color:orange'>If you encounter an <span style='color:orange'>error</span> try clicking 'Clear All Inputs'.")
|
373 |
multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)
|
374 |
|
375 |
with gr.Row():
|
376 |
display_height = 500
|
377 |
|
378 |
+
green_brush = gr.Brush(colors=["#00FF00"], color_mode="fixed", default_size=2)
|
379 |
+
red_brush = gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2)
|
380 |
+
|
381 |
with gr.Column(scale=1):
|
382 |
with gr.Tab("Scribbles"):
|
383 |
+
scribble_img = gr.ImageEditor(
|
384 |
label="Input",
|
385 |
+
image_mode="RGB",
|
386 |
+
brush=green_brush,
|
|
|
|
|
|
|
387 |
type='numpy',
|
388 |
+
value=default_example,
|
389 |
+
transforms=(),
|
390 |
+
sources=(),
|
391 |
+
show_download_button=True,
|
392 |
+
# height=display_height
|
393 |
)
|
|
|
394 |
|
395 |
with gr.Tab("Clicks/Boxes") as click_tab:
|
396 |
click_img = gr.Image(
|
397 |
label="Input",
|
398 |
type='numpy',
|
399 |
value=default_example,
|
400 |
+
show_download_button=True,
|
401 |
+
sources=(),
|
402 |
+
container=True,
|
403 |
+
# height=display_height-50
|
404 |
)
|
405 |
with gr.Row():
|
406 |
undo_click_button = gr.Button("Undo Last Click")
|
|
|
410 |
input_img = gr.Image(
|
411 |
label="Input",
|
412 |
image_mode="L",
|
|
|
413 |
value=default_example,
|
414 |
+
# height=display_height
|
415 |
)
|
416 |
gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")
|
417 |
|
418 |
with gr.Column(scale=1):
|
419 |
with gr.Tab("Output"):
|
420 |
output_img = gr.Gallery(
|
421 |
+
label='Output',
|
422 |
columns=1,
|
423 |
elem_id="gallery",
|
424 |
preview=True,
|
425 |
object_fit="scale-down",
|
426 |
+
# height=display_height
|
427 |
)
|
428 |
|
429 |
submit_button = gr.Button("Refresh Prediction", variant='primary')
|
|
|
445 |
|
446 |
gr.Examples(examples=test_examples,
|
447 |
inputs=[input_img],
|
448 |
+
examples_per_page=12,
|
449 |
+
label='Examples from datasets unseen during training'
|
|
|
|
|
|
|
|
|
|
|
|
|
450 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
# When clear clicks button is clicked
|
453 |
def clear_click_history(input_img):
|
|
|
462 |
if input_img is not None:
|
463 |
input_shape = input_img.shape[:2]
|
464 |
else:
|
465 |
+
input_shape = (H, W)
|
466 |
return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None
|
467 |
|
468 |
+
# def clear_history_and_pad_input(input_img):
|
469 |
+
# if input_img is not None:
|
470 |
+
# h,w = input_img.shape[:2]
|
471 |
+
# if h != w:
|
472 |
+
# # Pad to square
|
473 |
+
# pad = abs(h-w)
|
474 |
+
# if h > w:
|
475 |
+
# padding = [(0,0), (math.ceil(pad/2),math.floor(pad/2))]
|
476 |
+
# else:
|
477 |
+
# padding = [(math.ceil(pad/2),math.floor(pad/2)), (0,0)]
|
478 |
+
|
479 |
+
# input_img = np.pad(input_img, padding, mode='constant', constant_values=0)
|
480 |
+
|
481 |
+
# return clear_all_history(input_img)
|
482 |
+
|
483 |
+
|
484 |
input_img.change(clear_all_history,
|
485 |
inputs=[input_img],
|
486 |
outputs=[click_img, scribble_img,
|
|
|
545 |
undo_click_button.click(fn=undo_click,
|
546 |
inputs=[
|
547 |
predictor,
|
548 |
+
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords,
|
549 |
+
click_labels, bbox_coords,
|
550 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
551 |
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
552 |
],
|
|
|
561 |
Draw scribbles in the click canvas
|
562 |
"""
|
563 |
seperate_scribble_masks, last_scribble_mask = get_scribbles(
|
564 |
+
seperate_scribble_masks, last_scribble_mask, scribble_img
|
|
|
565 |
)
|
566 |
click_input_viz = viz_pred_mask(
|
567 |
input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
|
|
|
584 |
Recorn new scribbles when changing brush color
|
585 |
"""
|
586 |
if label == "Negative (red)":
|
587 |
+
brush_update = gr.update(brush=red_brush)
|
588 |
elif label == "Positive (green)":
|
589 |
+
brush_update = gr.update(brush=green_brush)
|
590 |
else:
|
591 |
raise TypeError("Invalid brush color")
|
|
|
|
|
|
|
|
|
|
|
|
|
592 |
|
593 |
return seperate_scribble_masks, last_scribble_mask, brush_update
|
594 |
|
predictor.py
CHANGED
@@ -3,7 +3,6 @@ import torch.nn.functional as F
|
|
3 |
from typing import Dict, Tuple, Optional
|
4 |
import network
|
5 |
|
6 |
-
|
7 |
class Predictor:
|
8 |
"""
|
9 |
Wrapper for ScribblePrompt Unet model
|
@@ -96,6 +95,7 @@ def rescale_inputs(inputs: Dict[str,any], res=128):
|
|
96 |
Rescale the inputs
|
97 |
"""
|
98 |
h,w = inputs['img'].shape[-2:]
|
|
|
99 |
if h != res or w != res:
|
100 |
|
101 |
inputs.update(dict(
|
|
|
3 |
from typing import Dict, Tuple, Optional
|
4 |
import network
|
5 |
|
|
|
6 |
class Predictor:
|
7 |
"""
|
8 |
Wrapper for ScribblePrompt Unet model
|
|
|
95 |
Rescale the inputs
|
96 |
"""
|
97 |
h,w = inputs['img'].shape[-2:]
|
98 |
+
|
99 |
if h != res or w != res:
|
100 |
|
101 |
inputs.update(dict(
|
{val_od_examples β test_examples}/ACDC.jpg
RENAMED
File without changes
|
{val_od_examples β test_examples}/BTCV.jpg
RENAMED
File without changes
|
{val_od_examples β test_examples}/BUID.jpg
RENAMED
File without changes
|
{val_od_examples β test_examples}/DRIVE.jpg
RENAMED
File without changes
|
{val_od_examples β test_examples}/HipXRay.jpg
RENAMED
File without changes
|
{val_od_examples β test_examples}/PanDental.jpg
RENAMED
File without changes
|
{val_od_examples β test_examples}/SCD.jpg
RENAMED
File without changes
|
test_examples/SCR.jpg
DELETED
Binary file (13.6 kB)
|
|
{val_od_examples β test_examples}/SpineWeb.jpg
RENAMED
File without changes
|
{val_od_examples β test_examples}/WBC.jpg
RENAMED
File without changes
|