halleewong commited on
Commit
b6bb35e
1 Parent(s): 26e8a2f

initial commit

Browse files
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: ScribblePrompt
3
- emoji: 👁
4
  colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.8.0
8
  app_file: app.py
9
- pinned: false
10
  license: apache-2.0
11
  ---
12
 
 
1
  ---
2
+ title: Scribbleprompt
3
+ emoji: 🩻
4
  colorFrom: blue
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.41.0
8
  app_file: app.py
9
+ pinned: true
10
  license: apache-2.0
11
  ---
12
 
app.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import os
6
+ import cv2
7
+ import pathlib
8
+
9
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+
11
+ from predictor import Predictor
12
+
13
+ RES = 256
14
+
15
+ test_example_dir = pathlib.Path("./test_examples")
16
+ test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]
17
+
18
+ val_example_dir = pathlib.Path("./val_od_examples")
19
+ val_examples = [str(val_example_dir / x) for x in sorted(os.listdir(val_example_dir))]
20
+
21
+ default_example = test_example_dir / "TotalSegmentator_2.jpg"
22
+ exp_dir = pathlib.Path('./checkpoints')
23
+ default_model = 'ScribblePrompt-Unet'
24
+
25
+ model_dict = {
26
+ 'ScribblePrompt-Unet': 'ScribblePrompt_unet_v1_nf192_res128.pt'
27
+ }
28
+
29
+ # -----------------------------------------------------------------------------
30
+ # Model initialization functions
31
+ # -----------------------------------------------------------------------------
32
+
33
+ def load_model(exp_key: str = default_model):
34
+ fpath = exp_dir / model_dict.get(exp_key)
35
+ exp = Predictor(fpath)
36
+ return exp, None
37
+
38
+ # -----------------------------------------------------------------------------
39
+ # Vizualization functions
40
+ # -----------------------------------------------------------------------------
41
+
42
+ def _get_overlay(img, lay, const_color="l_blue"):
43
+ """
44
+ Helper function for preparing overlay
45
+ """
46
+ assert lay.ndim==2, "Overlay must be 2D, got shape: " + str(lay.shape)
47
+
48
+ if img.ndim == 2:
49
+ img = np.repeat(img[...,None], 3, axis=-1)
50
+
51
+ assert img.ndim==3, "Image must be 3D, got shape: " + str(img.shape)
52
+
53
+ if const_color == "blue":
54
+ const_color = 255*np.array([0, 0, 1])
55
+ elif const_color == "green":
56
+ const_color = 255*np.array([0, 1, 0])
57
+ elif const_color == "red":
58
+ const_color = 255*np.array([1, 0, 0])
59
+ elif const_color == "l_blue":
60
+ const_color = np.array([31, 119, 180])
61
+ elif const_color == "orange":
62
+ const_color = np.array([255, 127, 14])
63
+ else:
64
+ raise NotImplementedError
65
+
66
+ x,y = np.nonzero(lay)
67
+ for i in range(img.shape[-1]):
68
+ img[x,y,i] = const_color[i]
69
+
70
+ return img
71
+
72
+ def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
73
+ """
74
+ Overlay the ground truth mask and scribbles on the image if provided
75
+ """
76
+ assert img.ndim == 2, "Image must be 2D, got shape: " + str(img.shape)
77
+ output = np.repeat(img[...,None], 3, axis=-1)
78
+
79
+ if mask is not None:
80
+
81
+ assert mask.ndim == 2, "Mask must be 2D, got shape: " + str(mask.shape)
82
+
83
+ if contour:
84
+ contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
85
+ cv2.drawContours(output, contours[0], -1, (0, 255, 0), 1)
86
+ else:
87
+ mask_overlay = _get_overlay(img, mask)
88
+ mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
89
+ output = cv2.convertScaleAbs(mask_overlay * mask2 + output * (1 - mask2))
90
+
91
+ if scribbles is not None:
92
+ pos_scribble_overlay = _get_overlay(output, scribbles[0,...], const_color="green")
93
+ cv2.addWeighted(pos_scribble_overlay, alpha, output, 1 - alpha, 0, output)
94
+ neg_scribble_overlay = _get_overlay(output, scribbles[1,...], const_color="red")
95
+ cv2.addWeighted(neg_scribble_overlay, alpha, output, 1 - alpha, 0, output)
96
+
97
+ return output
98
+
99
+
100
+ def viz_pred_mask(img, mask=None, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=True):
101
+ """
102
+ Visualize image with clicks, scribbles, predicted mask overlaid
103
+ """
104
+ assert isinstance(img, np.ndarray), "Image must be numpy array, got type: " + str(type(img))
105
+ if mask is not None:
106
+ if isinstance(mask, torch.Tensor):
107
+ mask = mask.cpu().numpy()
108
+
109
+ if binary and mask is not None:
110
+ mask = 1*(mask > 0.5)
111
+
112
+ out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)
113
+
114
+ if point_coords is not None:
115
+ for i,(col,row) in enumerate(point_coords):
116
+ if point_labels[i] == 1:
117
+ cv2.circle(out,(col, row), 2, (0,255,0), -1)
118
+ else:
119
+ cv2.circle(out,(col, row), 2, (255,0,0), -1)
120
+
121
+ if bbox_coords is not None:
122
+ for i in range(len(bbox_coords)//2):
123
+ cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0), 1)
124
+ if len(bbox_coords) % 2 == 1:
125
+ cv2.circle(out, tuple(bbox_coords[-1]), 2, (255,165,0), -1)
126
+
127
+ return out
128
+
129
+ # -----------------------------------------------------------------------------
130
+ # Collect scribbles
131
+ # -----------------------------------------------------------------------------
132
+
133
+ def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img, label: int):
134
+ """
135
+ Record scribbles
136
+ """
137
+ assert isinstance(seperate_scribble_masks, np.ndarray), "seperate_scribble_masks must be numpy array, got type: " + str(type(seperate_scribble_masks))
138
+
139
+ if scribble_img is not None:
140
+
141
+ color_mask = scribble_img.get('mask')
142
+ scribble_mask = color_mask[...,0]/255
143
+
144
+ not_same = (scribble_mask != last_scribble_mask)
145
+ if not isinstance(not_same, bool):
146
+ not_same = not_same.any()
147
+
148
+ if not_same:
149
+ # In case any scribbles were removed
150
+ corrected_scribble_masks = np.stack(2*[(scribble_mask > 0)], axis=0)*seperate_scribble_masks
151
+ corrected_last_scribble_mask = last_scribble_mask*(scribble_mask > 0)
152
+
153
+ delta = (scribble_mask - corrected_last_scribble_mask) > 0
154
+ new_scribbles = scribble_mask * delta
155
+ corrected_scribble_masks[label,...] = np.clip(corrected_scribble_masks[label,...] + new_scribbles, a_min=0, a_max=1)
156
+
157
+ last_scribble_mask = scribble_mask
158
+ seperate_scribble_masks = corrected_scribble_masks
159
+
160
+ return seperate_scribble_masks, last_scribble_mask
161
+
162
+ def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode):
163
+ """
164
+ Make predictions
165
+ """
166
+ box = None
167
+ if len(bbox_coords) == 1:
168
+ gr.Error("Please click a second time to define the bounding box")
169
+ box = None
170
+ elif len(bbox_coords) == 2:
171
+ box = torch.Tensor(bbox_coords).flatten()[None,None,...].int().to(device) # B x n x 4
172
+
173
+ if seperate_scribble_masks is not None:
174
+ scribble = torch.from_numpy(seperate_scribble_masks)[None,...].to(device)
175
+ else:
176
+ scribble = None
177
+
178
+ prompts = dict(
179
+ img=torch.from_numpy(input_img)[None,None,...].to(device)/255,
180
+ point_coords=torch.Tensor([click_coords]).int().to(device) if len(click_coords)>0 else None,
181
+ point_labels=torch.Tensor([click_labels]).int().to(device) if len(click_labels)>0 else None,
182
+ scribble=scribble,
183
+ mask_input=low_res_mask.to(device) if low_res_mask is not None else None,
184
+ box=box,
185
+ )
186
+
187
+ mask, img_features, low_res_mask = predictor.predict(prompts, img_features, multimask_mode=multimask_mode)
188
+
189
+ return mask, img_features, low_res_mask
190
+
191
+ def refresh_predictions(predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
192
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
193
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):
194
+
195
+ # Record any new scribbles
196
+ seperate_scribble_masks, last_scribble_mask = get_scribbles(
197
+ seperate_scribble_masks, last_scribble_mask, scribble_img,
198
+ label=(0 if brush_label == "Positive (green)" else 1) # current color of the brush
199
+ )
200
+
201
+ # Make prediction
202
+ best_mask, img_features, low_res_mask = get_predictions(
203
+ predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode
204
+ )
205
+
206
+ # Update input visualizations
207
+ mask_to_viz = best_mask.numpy()
208
+ click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)
209
+ scribble_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)
210
+
211
+ out_viz = [
212
+ viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
213
+ 255*(mask_to_viz[...,None].repeat(axis=2, repeats=3)>0.5) if binary_checkbox else mask_to_viz[...,None].repeat(axis=2, repeats=3),
214
+ ]
215
+
216
+ return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
217
+
218
+
219
+ def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
220
+ click_coords, click_labels, bbox_coords,
221
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
222
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox, evt: gr.SelectData):
223
+ """
224
+ Record user click and update the prediction
225
+ """
226
+ # Record click coordinates
227
+ if bbox_label:
228
+ bbox_coords.append(evt.index)
229
+ elif brush_label in ['Positive (green)', 'Negative (red)']:
230
+ click_coords.append(evt.index)
231
+ click_labels.append(1 if brush_label=='Positive (green)' else 0)
232
+ else:
233
+ raise TypeError("Invalid brush label: {brush_label}")
234
+
235
+ # Only make new prediction if not waiting for additional bounding box click
236
+ if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
237
+
238
+ click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
239
+ predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
240
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
241
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
242
+ )
243
+ return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
244
+
245
+ else:
246
+ click_input_viz = viz_pred_mask(
247
+ input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
248
+ )
249
+ scribble_input_viz = viz_pred_mask(
250
+ input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
251
+ )
252
+ # Don't update output image if waiting for additional bounding box click
253
+ return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
254
+
255
+
256
+ def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
257
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
258
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox):
259
+ """
260
+ Remove last click and then update the prediction
261
+ """
262
+ if bbox_label:
263
+ if len(bbox_coords) > 0:
264
+ bbox_coords.pop()
265
+ elif brush_label in ['Positive (green)', 'Negative (red)']:
266
+ if len(click_coords) > 0:
267
+ click_coords.pop()
268
+ click_labels.pop()
269
+ else:
270
+ raise TypeError("Invalid brush label: {brush_label}")
271
+
272
+ # Only make new prediction if not waiting for additional bounding box click
273
+ if (len(bbox_coords)==0 or len(bbox_coords)==2) and autopredict_checkbox:
274
+
275
+ click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
276
+ predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
277
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
278
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
279
+ )
280
+ return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
281
+
282
+ else:
283
+ click_input_viz = viz_pred_mask(
284
+ input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
285
+ )
286
+ scribble_input_viz = viz_pred_mask(
287
+ input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
288
+ )
289
+
290
+ # Don't update output image if waiting for additional bounding box click
291
+ return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
292
+
293
+
294
+
295
+ # --------------------------------------------------
296
+
297
+ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:
298
+
299
+ # State variables
300
+ seperate_scribble_masks = gr.State(np.zeros((2,RES,RES), dtype=np.float32))
301
+ last_scribble_mask = gr.State(np.zeros((RES,RES), dtype=np.float32))
302
+
303
+ click_coords = gr.State([])
304
+ click_labels = gr.State([])
305
+ bbox_coords = gr.State([])
306
+
307
+ # Load default model
308
+ predictor = gr.State(load_model()[0])
309
+ img_features = gr.State(None) # For SAM models
310
+ best_mask = gr.State(None)
311
+ low_res_mask = gr.State(None)
312
+
313
+ gr.HTML("""\
314
+ <h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmention for Any Medical Image</h1>
315
+ <p style="text-align: center; font-size: large;"><a href="https://scribbleprompt.csail.mit.edu">ScribblePrompt</a> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
316
+ </p>
317
+
318
+ """)
319
+
320
+ with gr.Accordion("Open for instructions!", open=False):
321
+ gr.Markdown(
322
+ """
323
+ * Select an input image from the examples below or upload your own image through the <b>'Input Image'</b> tab.
324
+ * Use the <b>'Scribbles'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> scribbles.
325
+ - Use the buttons in the top right hand corner of the canvas to undo or adjust the brush size
326
+ - Note: the app cannot detect new scribbles drawn on top of previous scribbles in a different color. Please undo/erase the scribble before drawing on the same pixel in a different color.
327
+ * Use the <b>'Clicks/Boxes'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> clicks and <span style='color:orange'>bounding boxes</span> by placing two clicks.
328
+ * The <b>'Output'</b> tab will show the model's prediction based on your current inputs and the previous prediction.
329
+ * The <b>'Clear Input Mask'</b> button will clear the latest prediction (which is used as an input to the model).
330
+ * The <b>'Clear All Inputs'</b> button will clear all inputs (including scribbles, clicks, bounding boxes, and the last prediction).
331
+ """
332
+ )
333
+
334
+
335
+ # Interface ------------------------------------
336
+
337
+ with gr.Row():
338
+ model_dropdown = gr.Dropdown(
339
+ label="Model",
340
+ choices = list(model_dict.keys()),
341
+ value=default_model,
342
+ multiselect=False,
343
+ interactive=False,
344
+ visible=False
345
+ )
346
+
347
+ with gr.Row():
348
+ with gr.Column(scale=1):
349
+ brush_label = gr.Radio(["Positive (green)", "Negative (red)"],
350
+ value="Positive (green)", label="Scribble/Click Label")
351
+ bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
352
+ with gr.Column(scale=1):
353
+ binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
354
+ autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
355
+ gr.Markdown("<span style='color:orange'>Troubleshooting:</span> If the image does not fully load in the Scribbles tab, click 'Clear Scribbles' or 'Clear All Inputs' to reload (it make take multiple tries). If you encounter an <span style='color:orange'>error</span> try clicking 'Clear All Inputs'.")
356
+ multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)
357
+
358
+ with gr.Row():
359
+ display_height = 500
360
+
361
+ with gr.Column(scale=1):
362
+ with gr.Tab("Scribbles"):
363
+ scribble_img = gr.Image(
364
+ label="Input",
365
+ brush_radius=3,
366
+ interactive=True,
367
+ brush_color="#00FF00",
368
+ tool="sketch",
369
+ height=display_height,
370
+ type='numpy',
371
+ value=default_example,
372
+ )
373
+ clear_scribble_button = gr.ClearButton([scribble_img], value="Clear Scribbles", variant="stop")
374
+
375
+ with gr.Tab("Clicks/Boxes") as click_tab:
376
+ click_img = gr.Image(
377
+ label="Input",
378
+ type='numpy',
379
+ value=default_example,
380
+ height=display_height
381
+ )
382
+ with gr.Row():
383
+ undo_click_button = gr.Button("Undo Last Click")
384
+ clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
385
+
386
+ with gr.Tab("Input Image"):
387
+ input_img = gr.Image(
388
+ label="Input",
389
+ image_mode="L",
390
+ visible=True,
391
+ value=default_example,
392
+ height=display_height
393
+ )
394
+ gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")
395
+
396
+ with gr.Column(scale=1):
397
+ with gr.Tab("Output"):
398
+ output_img = gr.Gallery(
399
+ label='Outputs',
400
+ columns=1,
401
+ elem_id="gallery",
402
+ preview=True,
403
+ object_fit="scale-down",
404
+ height=display_height+50
405
+ )
406
+
407
+ submit_button = gr.Button("Refresh Prediction", variant='primary')
408
+ clear_all_button = gr.ClearButton([scribble_img], value="Clear All Inputs", variant="stop")
409
+ clear_mask_button = gr.Button("Clear Input Mask")
410
+
411
+ # ----------------------------------------------
412
+ # Loading Models
413
+ # ----------------------------------------------
414
+
415
+ model_dropdown.change(fn=load_model,
416
+ inputs=[model_dropdown],
417
+ outputs=[predictor, img_features]
418
+ )
419
+
420
+ # ----------------------------------------------
421
+ # Loading Examples
422
+ # ----------------------------------------------
423
+
424
+ gr.Examples(examples=test_examples,
425
+ inputs=[input_img],
426
+ examples_per_page=10,
427
+ label='Unseen Examples from Test Datasets'
428
+ )
429
+
430
+ gr.Examples(examples=val_examples,
431
+ inputs=[input_img],
432
+ examples_per_page=10,
433
+ label='Unseen Examples from Validation Datasets'
434
+ )
435
+
436
+ # When clear scribble button is clicked
437
+ def clear_scribble_history(input_img):
438
+ if input_img is not None:
439
+ input_shape = input_img.shape[:2]
440
+ else:
441
+ input_shape = (RES, RES)
442
+ return input_img, input_img, np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None
443
+
444
+ clear_scribble_button.click(clear_scribble_history,
445
+ inputs=[input_img],
446
+ outputs=[click_img, scribble_img, seperate_scribble_masks, last_scribble_mask, best_mask, low_res_mask]
447
+ )
448
+
449
+ # When clear clicks button is clicked
450
+ def clear_click_history(input_img):
451
+ return input_img, input_img, [], [], [], None, None
452
+
453
+ clear_click_button.click(clear_click_history,
454
+ inputs=[input_img],
455
+ outputs=[click_img, scribble_img, click_coords, click_labels, bbox_coords, best_mask, low_res_mask])
456
+
457
+ # When clear all button is clicked
458
+ def clear_all_history(input_img):
459
+ if input_img is not None:
460
+ input_shape = input_img.shape[:2]
461
+ else:
462
+ input_shape = (RES, RES)
463
+ return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None
464
+
465
+ input_img.change(clear_all_history,
466
+ inputs=[input_img],
467
+ outputs=[click_img, scribble_img,
468
+ output_img, click_coords, click_labels, bbox_coords,
469
+ seperate_scribble_masks, last_scribble_mask,
470
+ best_mask, low_res_mask, img_features
471
+ ])
472
+
473
+ clear_all_button.click(clear_all_history,
474
+ inputs=[input_img],
475
+ outputs=[click_img, scribble_img,
476
+ output_img, click_coords, click_labels, bbox_coords,
477
+ seperate_scribble_masks, last_scribble_mask,
478
+ best_mask, low_res_mask, img_features
479
+ ])
480
+
481
+ # clear previous prediction mask
482
+ def clear_best_mask(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks):
483
+
484
+ click_input_viz = viz_pred_mask(
485
+ input_img, None, click_coords, click_labels, bbox_coords, seperate_scribble_masks
486
+ )
487
+ scribble_input_viz = viz_pred_mask(
488
+ input_img, None, click_coords, click_labels, bbox_coords, None
489
+ )
490
+
491
+ return None, None, click_input_viz, scribble_input_viz
492
+
493
+ clear_mask_button.click(
494
+ clear_best_mask,
495
+ inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks],
496
+ outputs=[best_mask, low_res_mask, click_img, scribble_img],
497
+ )
498
+
499
+ # ----------------------------------------------
500
+ # Clicks
501
+ # ----------------------------------------------
502
+
503
+ click_img.select(get_select_coords,
504
+ inputs=[
505
+ predictor,
506
+ input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
507
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
508
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox
509
+ ],
510
+ outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
511
+ click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
512
+ api_name = "get_select_coords"
513
+ )
514
+
515
+ submit_button.click(fn=refresh_predictions,
516
+ inputs=[
517
+ predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
518
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
519
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
520
+ ],
521
+ outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
522
+ seperate_scribble_masks, last_scribble_mask],
523
+ api_name="refresh_predictions"
524
+ )
525
+
526
+ undo_click_button.click(fn=undo_click,
527
+ inputs=[
528
+ predictor,
529
+ input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
530
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
531
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox
532
+ ],
533
+ outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
534
+ click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
535
+ api_name="undo_click"
536
+ )
537
+
538
+ def update_click_img(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox,
539
+ last_scribble_mask, scribble_img, brush_label, best_mask):
540
+ """
541
+ Draw scribbles in the click canvas
542
+ """
543
+ seperate_scribble_masks, last_scribble_mask = get_scribbles(
544
+ seperate_scribble_masks, last_scribble_mask, scribble_img,
545
+ label=(0 if brush_label == "Positive (green)" else 1) # previous color of the brush
546
+ )
547
+ click_input_viz = viz_pred_mask(
548
+ input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
549
+ )
550
+ return click_input_viz, seperate_scribble_masks, last_scribble_mask
551
+
552
+ click_tab.select(fn=update_click_img,
553
+ inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
554
+ binary_checkbox, last_scribble_mask, scribble_img, brush_label, best_mask],
555
+ outputs=[click_img, seperate_scribble_masks, last_scribble_mask],
556
+ api_name="update_click_img"
557
+ )
558
+
559
+ # ----------------------------------------------
560
+ # Scribbles
561
+ # ----------------------------------------------
562
+
563
+ def change_brush_color(seperate_scribble_masks, last_scribble_mask, scribble_img, label):
564
+ """
565
+ Recorn new scribbles when changing brush color
566
+ """
567
+ if label == "Negative (red)":
568
+ brush_update = gr.Image.update(brush_color = "#FF0000") # red
569
+ elif label == "Positive (green)":
570
+ brush_update = gr.Image.update(brush_color = "#00FF00") # green
571
+ else:
572
+ raise TypeError("Invalid brush color")
573
+
574
+ # Record latest scribbles
575
+ seperate_scribble_masks, last_scribble_mask = get_scribbles(
576
+ seperate_scribble_masks, last_scribble_mask, scribble_img,
577
+ label=(1 if label == "Positive (green)" else 0) # previous color of the brush
578
+ )
579
+
580
+ return seperate_scribble_masks, last_scribble_mask, brush_update
581
+
582
+ brush_label.change(fn=change_brush_color,
583
+ inputs=[seperate_scribble_masks, last_scribble_mask, scribble_img, brush_label],
584
+ outputs=[seperate_scribble_masks, last_scribble_mask, scribble_img],
585
+ api_name="change_brush_color"
586
+ )
587
+
588
+
589
+ if __name__ == "__main__":
590
+
591
+ demo.queue(api_open=False).launch(show_api=False)
checkpoints/ScribblePrompt_unet_v1_nf192_res128.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43f57ee8fa8ec529c31be281e06749f9e629b30157bbbcc9baf200cddec1acbe
3
+ size 15977486
network.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Any, List
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ # -----------------------------------------------------------------------------
6
+ # Blocks
7
+ # -----------------------------------------------------------------------------
8
+
9
+ class Conv2d(nn.Module):
10
+ """ Perform a 2D convolution
11
+
12
+ inputs are [b, c, h, w] where
13
+ b is the batch size
14
+ c is the number of channels
15
+ h is the height
16
+ w is the width
17
+ """
18
+ def __init__(self,
19
+ in_channels: int,
20
+ out_channels: int,
21
+ kernel_size: int,
22
+ padding: int,
23
+ do_activation: bool = True,
24
+ ):
25
+ super(Conv2d, self).__init__()
26
+
27
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
28
+ lst = [conv]
29
+
30
+ if do_activation:
31
+ lst.append(nn.PReLU())
32
+
33
+ self.conv = nn.Sequential(*lst)
34
+
35
+ def forward(self, x):
36
+ # x is [B, C, H, W]
37
+ return self.conv(x)
38
+
39
+ # -----------------------------------------------------------------------------
40
+ # Network
41
+ # -----------------------------------------------------------------------------
42
+
43
+ class _UNet(nn.Module):
44
+ def __init__(self,
45
+ in_channels: int = 1,
46
+ out_channels: int = 1,
47
+ features: List[int] = [64, 64, 64, 64, 64],
48
+ conv_kernel_size: int = 3,
49
+ conv: Optional[nn.Module] = None,
50
+ conv_kwargs: Dict[str,Any] = {}
51
+ ):
52
+ """
53
+ UNet (but can switch out the Conv)
54
+ """
55
+ super(_UNet, self).__init__()
56
+
57
+ self.in_channels = in_channels
58
+
59
+ padding = (conv_kernel_size - 1) // 2
60
+
61
+ self.ups = nn.ModuleList()
62
+ self.downs = nn.ModuleList()
63
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
64
+
65
+ # Down part of U-Net
66
+ for feat in features:
67
+ self.downs.append(
68
+ conv(
69
+ in_channels, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
70
+ )
71
+ )
72
+ in_channels = feat
73
+
74
+ # Up part of U-Net
75
+ for feat in reversed(features):
76
+ self.ups.append(nn.UpsamplingBilinear2d(scale_factor=2))
77
+ self.ups.append(
78
+ conv(
79
+ # Factor of 2 is for the skip connections
80
+ feat * 2, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
81
+ )
82
+ )
83
+
84
+ self.bottleneck = conv(
85
+ features[-1], features[-1], kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
86
+ )
87
+ self.final_conv = conv(
88
+ features[0], out_channels, kernel_size=1, padding=0, do_activation=False, **conv_kwargs
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ skip_connections = []
93
+ for down in self.downs:
94
+ x = down(x)
95
+ skip_connections.append(x)
96
+ x = self.pool(x)
97
+
98
+ x = self.bottleneck(x)
99
+ skip_connections = skip_connections[::-1]
100
+
101
+ for idx in range(0, len(self.ups), 2):
102
+ x = self.ups[idx](x)
103
+ skip_connection = skip_connections[idx // 2]
104
+
105
+ concat_skip = torch.cat((skip_connection, x), dim=1)
106
+ x = self.ups[idx + 1](concat_skip)
107
+
108
+ return self.final_conv(x)
109
+
110
+
111
+ class UNet(_UNet):
112
+ """
113
+ Unet with normal conv blocks
114
+
115
+ input shape: B x C x H x W
116
+ output shape: B x C x H x W
117
+ """
118
+ def __init__(self, **kwargs) -> None:
119
+ super().__init__(conv=Conv2d, **kwargs)
120
+
121
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
122
+ return super().forward(x)
123
+
predictor.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Dict, Tuple, Optional
4
+ import network
5
+
6
+
7
+ class Predictor:
8
+ """
9
+ Wrapper for ScribblePrompt Unet model
10
+ """
11
+ def __init__(self, path: str, verbose: bool = False):
12
+
13
+ self.verbose = verbose
14
+
15
+ assert path.exists(), f"Checkpoint {path} does not exist"
16
+ self.path = path
17
+
18
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ self.build_model()
20
+ self.load()
21
+ self.model.eval()
22
+ self.to_device()
23
+
24
+ def build_model(self):
25
+ """
26
+ Build the model
27
+ """
28
+ self.model = network.UNet(
29
+ in_channels = 5,
30
+ out_channels = 1,
31
+ features = [192, 192, 192, 192],
32
+ )
33
+
34
+ def load(self):
35
+ """
36
+ Load the state of the model from a checkpoint file.
37
+ """
38
+ with (self.path).open("rb") as f:
39
+ state = torch.load(f, map_location=self.device)
40
+ self.model.load_state_dict(state, strict=True)
41
+ if self.verbose:
42
+ print(
43
+ f"Loaded checkpoint from {self.path} to {self.device}"
44
+ )
45
+
46
+ def to_device(self):
47
+ """
48
+ Move the model to cpu or gpu
49
+ """
50
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
51
+ self.model = self.model.to(self.device)
52
+
53
+ def predict(self, prompts: Dict[str,any], img_features: Optional[torch.Tensor] = None, multimask_mode: bool = False):
54
+ """
55
+ Make predictions!
56
+
57
+ Returns:
58
+ mask (torch.Tensor): H x W
59
+ img_features (torch.Tensor): B x 1 x H x W (for SAM models)
60
+ low_res_mask (torch.Tensor): B x 1 x H x W logits
61
+ """
62
+ if self.verbose:
63
+ print("point_coords", prompts.get("point_coords", None))
64
+ print("point_labels", prompts.get("point_labels", None))
65
+ print("box", prompts.get("box", None))
66
+ print("img", prompts.get("img").shape, prompts.get("img").min(), prompts.get("img").max())
67
+ if prompts.get("scribble") is not None:
68
+ print("scribble", prompts.get("scribble", None).shape, prompts.get("scribble").min(), prompts.get("scribble").max())
69
+
70
+ original_shape = prompts.get('img').shape[-2:]
71
+
72
+ # Rescale to 128 x 128
73
+ prompts = rescale_inputs(prompts)
74
+
75
+ # Prepare inputs for ScribblePrompt unet (1 x 5 x 128 x 128)
76
+ x = prepare_inputs(prompts).float()
77
+
78
+ with torch.no_grad():
79
+ yhat = self.model(x.to(self.device)).cpu()
80
+
81
+ mask = torch.sigmoid(yhat)
82
+
83
+ # Resize for app resolution
84
+ mask = F.interpolate(mask, size=original_shape, mode='bilinear').squeeze()
85
+
86
+ # mask: H x W, yhat: 1 x 1 x H x W
87
+ return mask, None, yhat
88
+
89
+
90
+ # -----------------------------------------------------------------------------
91
+ # Prepare inputs
92
+ # -----------------------------------------------------------------------------
93
+
94
+ def rescale_inputs(inputs: Dict[str,any], res=128):
95
+ """
96
+ Rescale the inputs
97
+ """
98
+ h,w = inputs['img'].shape[-2:]
99
+ if h != res or w != res:
100
+
101
+ inputs.update(dict(
102
+ img = F.interpolate(inputs['img'], size=(res,res), mode='bilinear')
103
+ ))
104
+
105
+ if inputs.get('scribble') is not None:
106
+ inputs.update({
107
+ 'scribble': F.interpolate(inputs['scribble'], size=(res,res), mode='bilinear')
108
+ })
109
+
110
+ if inputs.get("box") is not None:
111
+ boxes = inputs.get("box").clone()
112
+ coords = boxes.reshape(-1, 2, 2)
113
+ coords[..., 0] = coords[..., 0] * (res / w)
114
+ coords[..., 1] = coords[..., 1] * (res / h)
115
+ inputs.update({'box': coords.reshape(1, -1, 4).int()})
116
+
117
+ if inputs.get("point_coords") is not None:
118
+ coords = inputs.get("point_coords").clone()
119
+ coords[..., 0] = coords[..., 0] * (res / w)
120
+ coords[..., 1] = coords[..., 1] * (res / h)
121
+ inputs.update({'point_coords': coords.int()})
122
+
123
+ return inputs
124
+
125
+ def prepare_inputs(inputs: Dict[str,torch.Tensor], device = None) -> torch.Tensor:
126
+ """
127
+ Prepare inputs for ScribblePrompt Unet
128
+
129
+ Returns:
130
+ x (torch.Tensor): B x 5 x H x W
131
+ """
132
+ img = inputs['img']
133
+ if device is None:
134
+ device = img.device
135
+
136
+ img = img.to(device)
137
+ shape = tuple(img.shape[-2:])
138
+
139
+ if inputs.get("box") is not None:
140
+ # Embed bounding box
141
+ # Input: B x 1 x 4
142
+ # Output: B x 1 x H x W
143
+ box_embed = bbox_shaded(inputs['box'], shape=shape, device=device)
144
+ else:
145
+ box_embed = torch.zeros(img.shape, device=device)
146
+
147
+ if inputs.get("point_coords") is not None:
148
+ # Encode points
149
+ # B x 2 x H x W
150
+ scribble_click_embed = click_onehot(inputs['point_coords'], inputs['point_labels'], shape=shape)
151
+ else:
152
+ scribble_click_embed = torch.zeros((img.shape[0], 2) + shape, device=device)
153
+
154
+ if inputs.get("scribble") is not None:
155
+ # Combine scribbles with click encoding
156
+ # B x 2 x H x W
157
+ scribble_click_embed = torch.clamp(scribble_click_embed + inputs.get('scribble'), min=0.0, max=1.0)
158
+
159
+ if inputs.get('mask_input') is not None:
160
+ # Previous prediction
161
+ mask_input = inputs['mask_input']
162
+ else:
163
+ # Initialize empty channel for mask input
164
+ mask_input = torch.zeros(img.shape, device=img.device)
165
+
166
+ x = torch.cat((img, box_embed, scribble_click_embed, mask_input), dim=-3)
167
+ # B x 5 x H x W
168
+
169
+ return x
170
+
171
+ # -----------------------------------------------------------------------------
172
+ # Encode clicks and bounding boxes
173
+ # -----------------------------------------------------------------------------
174
+
175
+ def click_onehot(point_coords, point_labels, shape: Tuple[int,int] = (128,128), indexing='xy'):
176
+ """
177
+ Represent clicks as two HxW binary masks (one for positive clicks and one for negative)
178
+ with 1 at the click locations and 0 otherwise
179
+
180
+ Args:
181
+ point_coords (torch.Tensor): BxNx2 tensor of xy coordinates
182
+ point_labels (torch.Tensor): BxN tensor of labels (0 or 1)
183
+ shape (tuple): output shape
184
+ Returns:
185
+ embed (torch.Tensor): Bx2xHxW tensor
186
+ """
187
+ assert indexing in ['xy','uv'], f"Invalid indexing: {indexing}"
188
+ assert len(point_coords.shape) == 3, "point_coords must be BxNx2"
189
+ assert point_coords.shape[-1] == 2, "point_coords must be BxNx2"
190
+ assert point_labels.shape[-1] == point_coords.shape[1], "point_labels must be BxN"
191
+ assert len(shape)==2, f"shape must be 2D: {shape}"
192
+
193
+ device = point_coords.device
194
+ batch_size = point_coords.shape[0]
195
+ n_points = point_coords.shape[1]
196
+
197
+ embed = torch.zeros((batch_size,2)+shape, device=device)
198
+ labels = point_labels.flatten().float()
199
+
200
+ idx_coords = torch.cat((
201
+ torch.arange(batch_size, device=device).reshape(-1,1).repeat(1,n_points)[...,None],
202
+ point_coords
203
+ ), axis=2).reshape(-1,3)
204
+
205
+ if indexing=='xy':
206
+ embed[ idx_coords[:,0], 0, idx_coords[:,2], idx_coords[:,1] ] = labels
207
+ embed[ idx_coords[:,0], 1, idx_coords[:,2], idx_coords[:,1] ] = 1.0-labels
208
+ else:
209
+ embed[ idx_coords[:,0], 0, idx_coords[:,1], idx_coords[:,2] ] = labels
210
+ embed[ idx_coords[:,0], 1, idx_coords[:,1], idx_coords[:,2] ] = 1.0-labels
211
+
212
+ return embed
213
+
214
+
215
+ def bbox_shaded(boxes, shape: Tuple[int,int] = (128,128), device='cpu'):
216
+ """
217
+ Represent bounding boxes as a binary mask with 1 inside boxes and 0 otherwise
218
+
219
+ Args:
220
+ boxes (torch.Tensor): Bx1x4 [x1, y1, x2, y2]
221
+ Returns:
222
+ bbox_embed (torch.Tesor): Bx1xHxW according to shape
223
+ """
224
+ assert len(shape)==2, "shape must be 2D"
225
+ if isinstance(boxes, torch.Tensor):
226
+ boxes = boxes.int().cpu().numpy()
227
+
228
+ batch_size = boxes.shape[0]
229
+ n_boxes = boxes.shape[1]
230
+ bbox_embed = torch.zeros((batch_size,1)+tuple(shape), device=device, dtype=torch.float32)
231
+
232
+ if boxes is not None:
233
+ for i in range(batch_size):
234
+ for j in range(n_boxes):
235
+ x1, y1, x2, y2 = boxes[i,j,:]
236
+ x_min = min(x1,x2)
237
+ x_max = max(x1,x2)
238
+ y_min = min(y1,y2)
239
+ y_max = max(y1,y2)
240
+ bbox_embed[ i, 0, y_min:y_max, x_min:x_max ] = 1.0
241
+
242
+ return bbox_embed
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ opencv-python
4
+ pathlib
test_examples/COBRE.jpg ADDED
test_examples/SCR.jpg ADDED
test_examples/TotalSegmentator.jpg ADDED
test_examples/TotalSegmentator_2.jpg ADDED
val_od_examples/ACDC.jpg ADDED
val_od_examples/BTCV.jpg ADDED
val_od_examples/BUID.jpg ADDED
val_od_examples/DRIVE.jpg ADDED
val_od_examples/HipXRay.jpg ADDED
val_od_examples/PanDental.jpg ADDED
val_od_examples/SCD.jpg ADDED
val_od_examples/SpineWeb.jpg ADDED
val_od_examples/WBC.jpg ADDED