Spaces: Running
halleewong committed
Commit b6bb35e
Parent(s): 26e8a2f
initial commit
Browse files
- LICENSE +201 -0
- README.md +5 -5
- app.py +591 -0
- checkpoints/ScribblePrompt_unet_v1_nf192_res128.pt +3 -0
- network.py +123 -0
- predictor.py +242 -0
- requirements.txt +4 -0
- test_examples/COBRE.jpg +0 -0
- test_examples/SCR.jpg +0 -0
- test_examples/TotalSegmentator.jpg +0 -0
- test_examples/TotalSegmentator_2.jpg +0 -0
- val_od_examples/ACDC.jpg +0 -0
- val_od_examples/BTCV.jpg +0 -0
- val_od_examples/BUID.jpg +0 -0
- val_od_examples/DRIVE.jpg +0 -0
- val_od_examples/HipXRay.jpg +0 -0
- val_od_examples/PanDental.jpg +0 -0
- val_od_examples/SCD.jpg +0 -0
- val_od_examples/SpineWeb.jpg +0 -0
- val_od_examples/WBC.jpg +0 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
-title:
+title: Scribbleprompt
-emoji:
+emoji: 🩻
 colorFrom: blue
-colorTo:
+colorTo: pink
 sdk: gradio
-sdk_version:
+sdk_version: 3.41.0
 app_file: app.py
-pinned:
+pinned: true
 license: apache-2.0
 ---
app.py
ADDED
@@ -0,0 +1,591 @@
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import os
import cv2
import pathlib

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from predictor import Predictor

RES = 256

test_example_dir = pathlib.Path("./test_examples")
test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]

val_example_dir = pathlib.Path("./val_od_examples")
val_examples = [str(val_example_dir / x) for x in sorted(os.listdir(val_example_dir))]

default_example = test_example_dir / "TotalSegmentator_2.jpg"
exp_dir = pathlib.Path('./checkpoints')
default_model = 'ScribblePrompt-Unet'

model_dict = {
    'ScribblePrompt-Unet': 'ScribblePrompt_unet_v1_nf192_res128.pt'
}

# -----------------------------------------------------------------------------
# Model initialization functions
# -----------------------------------------------------------------------------

def load_model(exp_key: str = default_model):
    fpath = exp_dir / model_dict.get(exp_key)
    exp = Predictor(fpath)
    return exp, None

# -----------------------------------------------------------------------------
# Visualization functions
# -----------------------------------------------------------------------------

def _get_overlay(img, lay, const_color="l_blue"):
    """
    Helper function for preparing an overlay
    """
    assert lay.ndim==2, "Overlay must be 2D, got shape: " + str(lay.shape)

    if img.ndim == 2:
        img = np.repeat(img[...,None], 3, axis=-1)

    assert img.ndim==3, "Image must be 3D, got shape: " + str(img.shape)

    if const_color == "blue":
        const_color = 255*np.array([0, 0, 1])
    elif const_color == "green":
        const_color = 255*np.array([0, 1, 0])
    elif const_color == "red":
        const_color = 255*np.array([1, 0, 0])
    elif const_color == "l_blue":
        const_color = np.array([31, 119, 180])
    elif const_color == "orange":
        const_color = np.array([255, 127, 14])
    else:
        raise NotImplementedError

    x,y = np.nonzero(lay)
    for i in range(img.shape[-1]):
        img[x,y,i] = const_color[i]

    return img

def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
    """
    Overlay the ground truth mask and scribbles on the image if provided
    """
    assert img.ndim == 2, "Image must be 2D, got shape: " + str(img.shape)
    output = np.repeat(img[...,None], 3, axis=-1)

    if mask is not None:

        assert mask.ndim == 2, "Mask must be 2D, got shape: " + str(mask.shape)

        if contour:
            contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(output, contours[0], -1, (0, 255, 0), 1)
        else:
            mask_overlay = _get_overlay(img, mask)
            mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
            output = cv2.convertScaleAbs(mask_overlay * mask2 + output * (1 - mask2))

    if scribbles is not None:
        pos_scribble_overlay = _get_overlay(output, scribbles[0,...], const_color="green")
        cv2.addWeighted(pos_scribble_overlay, alpha, output, 1 - alpha, 0, output)
        neg_scribble_overlay = _get_overlay(output, scribbles[1,...], const_color="red")
        cv2.addWeighted(neg_scribble_overlay, alpha, output, 1 - alpha, 0, output)

    return output


def viz_pred_mask(img, mask=None, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=True):
    """
    Visualize the image with clicks, scribbles, and the predicted mask overlaid
    """
    assert isinstance(img, np.ndarray), "Image must be numpy array, got type: " + str(type(img))
    if mask is not None:
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()

    if binary and mask is not None:
        mask = 1*(mask > 0.5)

    out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)

    if point_coords is not None:
        for i,(col,row) in enumerate(point_coords):
            if point_labels[i] == 1:
                cv2.circle(out, (col, row), 2, (0,255,0), -1)
            else:
                cv2.circle(out, (col, row), 2, (255,0,0), -1)

    if bbox_coords is not None:
        for i in range(len(bbox_coords)//2):
            cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0), 1)
        if len(bbox_coords) % 2 == 1:
            cv2.circle(out, tuple(bbox_coords[-1]), 2, (255,165,0), -1)

    return out

# -----------------------------------------------------------------------------
# Collect scribbles
# -----------------------------------------------------------------------------

def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img, label: int):
    """
    Record scribbles
    """
    assert isinstance(seperate_scribble_masks, np.ndarray), "seperate_scribble_masks must be numpy array, got type: " + str(type(seperate_scribble_masks))

    if scribble_img is not None:

        color_mask = scribble_img.get('mask')
        scribble_mask = color_mask[...,0]/255

        not_same = (scribble_mask != last_scribble_mask)
        if not isinstance(not_same, bool):
            not_same = not_same.any()

        if not_same:
            # In case any scribbles were removed
            corrected_scribble_masks = np.stack(2*[(scribble_mask > 0)], axis=0)*seperate_scribble_masks
            corrected_last_scribble_mask = last_scribble_mask*(scribble_mask > 0)

            delta = (scribble_mask - corrected_last_scribble_mask) > 0
            new_scribbles = scribble_mask * delta
            corrected_scribble_masks[label,...] = np.clip(corrected_scribble_masks[label,...] + new_scribbles, a_min=0, a_max=1)

            last_scribble_mask = scribble_mask
            seperate_scribble_masks = corrected_scribble_masks

    return seperate_scribble_masks, last_scribble_mask

def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode):
    """
    Make predictions
    """
    box = None
    if len(bbox_coords) == 1:
        gr.Error("Please click a second time to define the bounding box")
        box = None
    elif len(bbox_coords) == 2:
        box = torch.Tensor(bbox_coords).flatten()[None,None,...].int().to(device) # B x n x 4

    if seperate_scribble_masks is not None:
        scribble = torch.from_numpy(seperate_scribble_masks)[None,...].to(device)
    else:
        scribble = None

    prompts = dict(
        img=torch.from_numpy(input_img)[None,None,...].to(device)/255,
        point_coords=torch.Tensor([click_coords]).int().to(device) if len(click_coords)>0 else None,
        point_labels=torch.Tensor([click_labels]).int().to(device) if len(click_labels)>0 else None,
        scribble=scribble,
        mask_input=low_res_mask.to(device) if low_res_mask is not None else None,
        box=box,
    )

    mask, img_features, low_res_mask = predictor.predict(prompts, img_features, multimask_mode=multimask_mode)

    return mask, img_features, low_res_mask

def refresh_predictions(predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                        scribble_img, seperate_scribble_masks, last_scribble_mask,
                        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):

    # Record any new scribbles
    seperate_scribble_masks, last_scribble_mask = get_scribbles(
        seperate_scribble_masks, last_scribble_mask, scribble_img,
        label=(0 if brush_label == "Positive (green)" else 1) # current color of the brush
    )

    # Make prediction
    best_mask, img_features, low_res_mask = get_predictions(
        predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode
    )

    # Update input visualizations
    mask_to_viz = best_mask.numpy()
    click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)
    scribble_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)

    out_viz = [
        viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
        255*(mask_to_viz[...,None].repeat(axis=2, repeats=3)>0.5) if binary_checkbox else mask_to_viz[...,None].repeat(axis=2, repeats=3),
    ]

    return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask


def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
                      click_coords, click_labels, bbox_coords,
                      seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                      output_img, binary_checkbox, multimask_mode, autopredict_checkbox, evt: gr.SelectData):
    """
    Record user click and update the prediction
    """
    # Record click coordinates
    if bbox_label:
        bbox_coords.append(evt.index)
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        click_coords.append(evt.index)
        click_labels.append(1 if brush_label=='Positive (green)' else 0)
    else:
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make a new prediction if not waiting for an additional bounding box click
    if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:

        click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
        return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask

    else:
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        scribble_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
        )
        # Don't update the output image if waiting for an additional bounding box click
        return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask


def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
               seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
               output_img, binary_checkbox, multimask_mode, autopredict_checkbox):
    """
    Remove the last click and then update the prediction
    """
    if bbox_label:
        if len(bbox_coords) > 0:
            bbox_coords.pop()
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        if len(click_coords) > 0:
            click_coords.pop()
            click_labels.pop()
    else:
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make a new prediction if not waiting for an additional bounding box click
    if (len(bbox_coords)==0 or len(bbox_coords)==2) and autopredict_checkbox:

        click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
        return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask

    else:
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        scribble_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
        )

        # Don't update the output image if waiting for an additional bounding box click
        return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask


# --------------------------------------------------

with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:

    # State variables
    seperate_scribble_masks = gr.State(np.zeros((2,RES,RES), dtype=np.float32))
    last_scribble_mask = gr.State(np.zeros((RES,RES), dtype=np.float32))

    click_coords = gr.State([])
    click_labels = gr.State([])
    bbox_coords = gr.State([])

    # Load default model
    predictor = gr.State(load_model()[0])
    img_features = gr.State(None) # For SAM models
    best_mask = gr.State(None)
    low_res_mask = gr.State(None)

    gr.HTML("""\
    <h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Medical Image</h1>
    <p style="text-align: center; font-size: large;"><a href="https://scribbleprompt.csail.mit.edu">ScribblePrompt</a> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
    </p>
    """)

    with gr.Accordion("Open for instructions!", open=False):
        gr.Markdown(
            """
            * Select an input image from the examples below or upload your own image through the <b>'Input Image'</b> tab.
            * Use the <b>'Scribbles'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> scribbles.
                - Use the buttons in the top right hand corner of the canvas to undo or adjust the brush size.
                - Note: the app cannot detect new scribbles drawn on top of previous scribbles in a different color. Please undo/erase the scribble before drawing on the same pixel in a different color.
            * Use the <b>'Clicks/Boxes'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> clicks and <span style='color:orange'>bounding boxes</span> by placing two clicks.
            * The <b>'Output'</b> tab will show the model's prediction based on your current inputs and the previous prediction.
            * The <b>'Clear Input Mask'</b> button will clear the latest prediction (which is used as an input to the model).
            * The <b>'Clear All Inputs'</b> button will clear all inputs (including scribbles, clicks, bounding boxes, and the last prediction).
            """
        )

    # Interface ------------------------------------

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Model",
            choices = list(model_dict.keys()),
            value=default_model,
            multiselect=False,
            interactive=False,
            visible=False
        )

    with gr.Row():
        with gr.Column(scale=1):
            brush_label = gr.Radio(["Positive (green)", "Negative (red)"],
                                   value="Positive (green)", label="Scribble/Click Label")
            bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
        with gr.Column(scale=1):
            binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
            autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
            gr.Markdown("<span style='color:orange'>Troubleshooting:</span> If the image does not fully load in the Scribbles tab, click 'Clear Scribbles' or 'Clear All Inputs' to reload (it may take multiple tries). If you encounter an <span style='color:orange'>error</span> try clicking 'Clear All Inputs'.")
            multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)

    with gr.Row():
        display_height = 500

        with gr.Column(scale=1):
            with gr.Tab("Scribbles"):
                scribble_img = gr.Image(
                    label="Input",
                    brush_radius=3,
                    interactive=True,
                    brush_color="#00FF00",
                    tool="sketch",
                    height=display_height,
                    type='numpy',
                    value=default_example,
                )
                clear_scribble_button = gr.ClearButton([scribble_img], value="Clear Scribbles", variant="stop")

            with gr.Tab("Clicks/Boxes") as click_tab:
                click_img = gr.Image(
                    label="Input",
                    type='numpy',
                    value=default_example,
                    height=display_height
                )
                with gr.Row():
                    undo_click_button = gr.Button("Undo Last Click")
                    clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")

            with gr.Tab("Input Image"):
                input_img = gr.Image(
                    label="Input",
                    image_mode="L",
                    visible=True,
                    value=default_example,
                    height=display_height
                )
                gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")

        with gr.Column(scale=1):
            with gr.Tab("Output"):
                output_img = gr.Gallery(
                    label='Outputs',
                    columns=1,
                    elem_id="gallery",
                    preview=True,
                    object_fit="scale-down",
                    height=display_height+50
                )

                submit_button = gr.Button("Refresh Prediction", variant='primary')
                clear_all_button = gr.ClearButton([scribble_img], value="Clear All Inputs", variant="stop")
                clear_mask_button = gr.Button("Clear Input Mask")

    # ----------------------------------------------
    # Loading Models
    # ----------------------------------------------

    model_dropdown.change(fn=load_model,
                          inputs=[model_dropdown],
                          outputs=[predictor, img_features]
                          )

    # ----------------------------------------------
    # Loading Examples
    # ----------------------------------------------

    gr.Examples(examples=test_examples,
                inputs=[input_img],
                examples_per_page=10,
                label='Unseen Examples from Test Datasets'
                )

    gr.Examples(examples=val_examples,
                inputs=[input_img],
                examples_per_page=10,
                label='Unseen Examples from Validation Datasets'
                )

    # When the clear scribble button is clicked
    def clear_scribble_history(input_img):
        if input_img is not None:
            input_shape = input_img.shape[:2]
        else:
            input_shape = (RES, RES)
        return input_img, input_img, np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None

    clear_scribble_button.click(clear_scribble_history,
                                inputs=[input_img],
                                outputs=[click_img, scribble_img, seperate_scribble_masks, last_scribble_mask, best_mask, low_res_mask]
                                )

    # When the clear clicks button is clicked
    def clear_click_history(input_img):
        return input_img, input_img, [], [], [], None, None

    clear_click_button.click(clear_click_history,
                             inputs=[input_img],
                             outputs=[click_img, scribble_img, click_coords, click_labels, bbox_coords, best_mask, low_res_mask])

    # When the clear all button is clicked
    def clear_all_history(input_img):
        if input_img is not None:
            input_shape = input_img.shape[:2]
        else:
            input_shape = (RES, RES)
        return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None

    input_img.change(clear_all_history,
                     inputs=[input_img],
                     outputs=[click_img, scribble_img,
                              output_img, click_coords, click_labels, bbox_coords,
                              seperate_scribble_masks, last_scribble_mask,
                              best_mask, low_res_mask, img_features
                              ])

    clear_all_button.click(clear_all_history,
                           inputs=[input_img],
                           outputs=[click_img, scribble_img,
                                    output_img, click_coords, click_labels, bbox_coords,
                                    seperate_scribble_masks, last_scribble_mask,
                                    best_mask, low_res_mask, img_features
                                    ])

    # Clear the previous prediction mask
    def clear_best_mask(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks):

        click_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, seperate_scribble_masks
        )
        scribble_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, None
        )

        return None, None, click_input_viz, scribble_input_viz

    clear_mask_button.click(
        clear_best_mask,
        inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks],
        outputs=[best_mask, low_res_mask, click_img, scribble_img],
    )

    # ----------------------------------------------
    # Clicks
    # ----------------------------------------------

    click_img.select(get_select_coords,
                     inputs=[
                         predictor,
                         input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
                         seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                         output_img, binary_checkbox, multimask_mode, autopredict_checkbox
                     ],
                     outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                              click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                     api_name = "get_select_coords"
                     )

    submit_button.click(fn=refresh_predictions,
                        inputs=[
                            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                            scribble_img, seperate_scribble_masks, last_scribble_mask,
                            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
                        ],
                        outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                 seperate_scribble_masks, last_scribble_mask],
                        api_name="refresh_predictions"
                        )

    undo_click_button.click(fn=undo_click,
                            inputs=[
                                predictor,
                                input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
                                seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                                output_img, binary_checkbox, multimask_mode, autopredict_checkbox
                            ],
                            outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                     click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                            api_name="undo_click"
                            )

    def update_click_img(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox,
                         last_scribble_mask, scribble_img, brush_label, best_mask):
        """
        Draw scribbles in the click canvas
        """
        seperate_scribble_masks, last_scribble_mask = get_scribbles(
            seperate_scribble_masks, last_scribble_mask, scribble_img,
            label=(0 if brush_label == "Positive (green)" else 1) # previous color of the brush
        )
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        return click_input_viz, seperate_scribble_masks, last_scribble_mask

    click_tab.select(fn=update_click_img,
                     inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                             binary_checkbox, last_scribble_mask, scribble_img, brush_label, best_mask],
                     outputs=[click_img, seperate_scribble_masks, last_scribble_mask],
                     api_name="update_click_img"
                     )

    # ----------------------------------------------
    # Scribbles
    # ----------------------------------------------

    def change_brush_color(seperate_scribble_masks, last_scribble_mask, scribble_img, label):
        """
        Record new scribbles when changing the brush color
        """
        if label == "Negative (red)":
            brush_update = gr.Image.update(brush_color = "#FF0000") # red
        elif label == "Positive (green)":
            brush_update = gr.Image.update(brush_color = "#00FF00") # green
        else:
            raise TypeError("Invalid brush color")

        # Record latest scribbles
        seperate_scribble_masks, last_scribble_mask = get_scribbles(
            seperate_scribble_masks, last_scribble_mask, scribble_img,
            label=(1 if label == "Positive (green)" else 0) # previous color of the brush
        )

        return seperate_scribble_masks, last_scribble_mask, brush_update

    brush_label.change(fn=change_brush_color,
                       inputs=[seperate_scribble_masks, last_scribble_mask, scribble_img, brush_label],
                       outputs=[seperate_scribble_masks, last_scribble_mask, scribble_img],
                       api_name="change_brush_color"
                       )


if __name__ == "__main__":

    demo.queue(api_open=False).launch(show_api=False)
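
Aside (not part of the commit): the scribble bookkeeping is the least obvious part of app.py, so here is a toy sketch of what get_scribbles does, assuming get_scribbles from app.py above is in scope. The made-up 4x4 canvas stands in for the dict returned by Gradio's sketch tool; because the canvas accumulates every stroke, the function diffs it against the previous canvas and credits only the newly drawn pixels to the channel of the current brush (0 = positive, 1 = negative).

import numpy as np

canvas = np.zeros((4, 4), dtype=np.float32)       # toy stand-in for the sketch canvas (mask / 255)
seperate = np.zeros((2, 4, 4), dtype=np.float32)  # [positive, negative] scribble channels
last = np.zeros((4, 4), dtype=np.float32)         # canvas state after the previous call

canvas[1, 1] = 1.0                                # first stroke, drawn with the green (positive) brush
seperate, last = get_scribbles(seperate, last, {"mask": 255 * canvas[..., None]}, label=0)

canvas[3, 3] = 1.0                                # second stroke, after switching to the red brush
seperate, last = get_scribbles(seperate, last, {"mask": 255 * canvas[..., None]}, label=1)

print(seperate[0, 1, 1], seperate[1, 3, 3])       # 1.0 1.0 -- each stroke lands in its own channel
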
checkpoints/ScribblePrompt_unet_v1_nf192_res128.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43f57ee8fa8ec529c31be281e06749f9e629b30157bbbcc9baf200cddec1acbe
size 15977486
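
Aside (not part of the commit): this file is a Git LFS pointer, so only the sha256 oid and size are stored in the repo and the weights are fetched separately. A minimal sketch for checking that a locally downloaded copy matches the recorded oid, assuming the file sits at the path used by app.py:

import hashlib
import pathlib

ckpt = pathlib.Path("checkpoints/ScribblePrompt_unet_v1_nf192_res128.pt")
digest = hashlib.sha256(ckpt.read_bytes()).hexdigest()  # hash the downloaded weights
print(digest == "43f57ee8fa8ec529c31be281e06749f9e629b30157bbbcc9baf200cddec1acbe")
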
network.py
ADDED
@@ -0,0 +1,123 @@
from typing import Optional, Dict, Any, List
import torch
import torch.nn as nn

# -----------------------------------------------------------------------------
# Blocks
# -----------------------------------------------------------------------------

class Conv2d(nn.Module):
    """ Perform a 2D convolution

    inputs are [b, c, h, w] where
        b is the batch size
        c is the number of channels
        h is the height
        w is the width
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 do_activation: bool = True,
                 ):
        super(Conv2d, self).__init__()

        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        lst = [conv]

        if do_activation:
            lst.append(nn.PReLU())

        self.conv = nn.Sequential(*lst)

    def forward(self, x):
        # x is [B, C, H, W]
        return self.conv(x)

# -----------------------------------------------------------------------------
# Network
# -----------------------------------------------------------------------------

class _UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 features: List[int] = [64, 64, 64, 64, 64],
                 conv_kernel_size: int = 3,
                 conv: Optional[nn.Module] = None,
                 conv_kwargs: Dict[str,Any] = {}
                 ):
        """
        UNet (but can switch out the Conv)
        """
        super(_UNet, self).__init__()

        self.in_channels = in_channels

        padding = (conv_kernel_size - 1) // 2

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of U-Net
        for feat in features:
            self.downs.append(
                conv(
                    in_channels, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
                )
            )
            in_channels = feat

        # Up part of U-Net
        for feat in reversed(features):
            self.ups.append(nn.UpsamplingBilinear2d(scale_factor=2))
            self.ups.append(
                conv(
                    # Factor of 2 is for the skip connections
                    feat * 2, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
                )
            )

        self.bottleneck = conv(
            features[-1], features[-1], kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
        )
        self.final_conv = conv(
            features[0], out_channels, kernel_size=1, padding=0, do_activation=False, **conv_kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)


class UNet(_UNet):
    """
    UNet with normal conv blocks

    input shape: B x C x H x W
    output shape: B x C x H x W
    """
    def __init__(self, **kwargs) -> None:
        super().__init__(conv=Conv2d, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x)
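
Aside (not part of the commit): a minimal shape check, assuming the configuration that predictor.py uses for this checkpoint (5 input channels, one output channel, four levels of 192 features). It confirms that the UNet maps the 5-channel 128 x 128 prompt stack back to a full-resolution logit map.

import torch
from network import UNet

model = UNet(in_channels=5, out_channels=1, features=[192, 192, 192, 192])
with torch.no_grad():
    y = model(torch.randn(1, 5, 128, 128))    # B x 5 x H x W prompt stack
print(y.shape)                                # torch.Size([1, 1, 128, 128])
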
predictor.py
ADDED
@@ -0,0 +1,242 @@
import torch
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
import network


class Predictor:
    """
    Wrapper for the ScribblePrompt UNet model
    """
    def __init__(self, path: str, verbose: bool = False):

        self.verbose = verbose

        assert path.exists(), f"Checkpoint {path} does not exist"
        self.path = path

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_model()
        self.load()
        self.model.eval()
        self.to_device()

    def build_model(self):
        """
        Build the model
        """
        self.model = network.UNet(
            in_channels = 5,
            out_channels = 1,
            features = [192, 192, 192, 192],
        )

    def load(self):
        """
        Load the state of the model from a checkpoint file.
        """
        with (self.path).open("rb") as f:
            state = torch.load(f, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        if self.verbose:
            print(
                f"Loaded checkpoint from {self.path} to {self.device}"
            )

    def to_device(self):
        """
        Move the model to cpu or gpu
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def predict(self, prompts: Dict[str,any], img_features: Optional[torch.Tensor] = None, multimask_mode: bool = False):
        """
        Make predictions!

        Returns:
            mask (torch.Tensor): H x W
            img_features (torch.Tensor): B x 1 x H x W (for SAM models)
            low_res_mask (torch.Tensor): B x 1 x H x W logits
        """
        if self.verbose:
            print("point_coords", prompts.get("point_coords", None))
            print("point_labels", prompts.get("point_labels", None))
            print("box", prompts.get("box", None))
            print("img", prompts.get("img").shape, prompts.get("img").min(), prompts.get("img").max())
            if prompts.get("scribble") is not None:
                print("scribble", prompts.get("scribble", None).shape, prompts.get("scribble").min(), prompts.get("scribble").max())

        original_shape = prompts.get('img').shape[-2:]

        # Rescale to 128 x 128
        prompts = rescale_inputs(prompts)

        # Prepare inputs for the ScribblePrompt UNet (1 x 5 x 128 x 128)
        x = prepare_inputs(prompts).float()

        with torch.no_grad():
            yhat = self.model(x.to(self.device)).cpu()

        mask = torch.sigmoid(yhat)

        # Resize for app resolution
        mask = F.interpolate(mask, size=original_shape, mode='bilinear').squeeze()

        # mask: H x W, yhat: 1 x 1 x H x W
        return mask, None, yhat


# -----------------------------------------------------------------------------
# Prepare inputs
# -----------------------------------------------------------------------------

def rescale_inputs(inputs: Dict[str,any], res=128):
    """
    Rescale the inputs
    """
    h,w = inputs['img'].shape[-2:]
    if h != res or w != res:

        inputs.update(dict(
            img = F.interpolate(inputs['img'], size=(res,res), mode='bilinear')
        ))

        if inputs.get('scribble') is not None:
            inputs.update({
                'scribble': F.interpolate(inputs['scribble'], size=(res,res), mode='bilinear')
            })

        if inputs.get("box") is not None:
            boxes = inputs.get("box").clone()
            coords = boxes.reshape(-1, 2, 2)
            coords[..., 0] = coords[..., 0] * (res / w)
            coords[..., 1] = coords[..., 1] * (res / h)
            inputs.update({'box': coords.reshape(1, -1, 4).int()})

        if inputs.get("point_coords") is not None:
            coords = inputs.get("point_coords").clone()
            coords[..., 0] = coords[..., 0] * (res / w)
            coords[..., 1] = coords[..., 1] * (res / h)
            inputs.update({'point_coords': coords.int()})

    return inputs

def prepare_inputs(inputs: Dict[str,torch.Tensor], device = None) -> torch.Tensor:
    """
    Prepare inputs for the ScribblePrompt UNet

    Returns:
        x (torch.Tensor): B x 5 x H x W
    """
    img = inputs['img']
    if device is None:
        device = img.device

    img = img.to(device)
    shape = tuple(img.shape[-2:])

    if inputs.get("box") is not None:
        # Embed bounding box
        # Input: B x 1 x 4
        # Output: B x 1 x H x W
        box_embed = bbox_shaded(inputs['box'], shape=shape, device=device)
    else:
        box_embed = torch.zeros(img.shape, device=device)

    if inputs.get("point_coords") is not None:
        # Encode points
        # B x 2 x H x W
        scribble_click_embed = click_onehot(inputs['point_coords'], inputs['point_labels'], shape=shape)
    else:
        scribble_click_embed = torch.zeros((img.shape[0], 2) + shape, device=device)

    if inputs.get("scribble") is not None:
        # Combine scribbles with the click encoding
        # B x 2 x H x W
        scribble_click_embed = torch.clamp(scribble_click_embed + inputs.get('scribble'), min=0.0, max=1.0)

    if inputs.get('mask_input') is not None:
        # Previous prediction
        mask_input = inputs['mask_input']
    else:
        # Initialize an empty channel for the mask input
        mask_input = torch.zeros(img.shape, device=img.device)

    x = torch.cat((img, box_embed, scribble_click_embed, mask_input), dim=-3)
    # B x 5 x H x W

    return x

# -----------------------------------------------------------------------------
# Encode clicks and bounding boxes
# -----------------------------------------------------------------------------

def click_onehot(point_coords, point_labels, shape: Tuple[int,int] = (128,128), indexing='xy'):
    """
    Represent clicks as two HxW binary masks (one for positive clicks and one for negative)
    with 1 at the click locations and 0 otherwise

    Args:
        point_coords (torch.Tensor): BxNx2 tensor of xy coordinates
        point_labels (torch.Tensor): BxN tensor of labels (0 or 1)
        shape (tuple): output shape
    Returns:
        embed (torch.Tensor): Bx2xHxW tensor
    """
    assert indexing in ['xy','uv'], f"Invalid indexing: {indexing}"
    assert len(point_coords.shape) == 3, "point_coords must be BxNx2"
    assert point_coords.shape[-1] == 2, "point_coords must be BxNx2"
    assert point_labels.shape[-1] == point_coords.shape[1], "point_labels must be BxN"
    assert len(shape)==2, f"shape must be 2D: {shape}"

    device = point_coords.device
    batch_size = point_coords.shape[0]
    n_points = point_coords.shape[1]

    embed = torch.zeros((batch_size,2)+shape, device=device)
    labels = point_labels.flatten().float()

    idx_coords = torch.cat((
        torch.arange(batch_size, device=device).reshape(-1,1).repeat(1,n_points)[...,None],
        point_coords
    ), axis=2).reshape(-1,3)

    if indexing=='xy':
        embed[ idx_coords[:,0], 0, idx_coords[:,2], idx_coords[:,1] ] = labels
        embed[ idx_coords[:,0], 1, idx_coords[:,2], idx_coords[:,1] ] = 1.0-labels
    else:
        embed[ idx_coords[:,0], 0, idx_coords[:,1], idx_coords[:,2] ] = labels
        embed[ idx_coords[:,0], 1, idx_coords[:,1], idx_coords[:,2] ] = 1.0-labels

    return embed


def bbox_shaded(boxes, shape: Tuple[int,int] = (128,128), device='cpu'):
    """
    Represent bounding boxes as a binary mask with 1 inside the boxes and 0 otherwise

    Args:
        boxes (torch.Tensor): Bx1x4 [x1, y1, x2, y2]
    Returns:
        bbox_embed (torch.Tensor): Bx1xHxW according to shape
    """
    assert len(shape)==2, "shape must be 2D"
    if isinstance(boxes, torch.Tensor):
        boxes = boxes.int().cpu().numpy()

    batch_size = boxes.shape[0]
    n_boxes = boxes.shape[1]
    bbox_embed = torch.zeros((batch_size,1)+tuple(shape), device=device, dtype=torch.float32)

    if boxes is not None:
        for i in range(batch_size):
            for j in range(n_boxes):
                x1, y1, x2, y2 = boxes[i,j,:]
                x_min = min(x1,x2)
                x_max = max(x1,x2)
                y_min = min(y1,y2)
                y_max = max(y1,y2)
                bbox_embed[ i, 0, y_min:y_max, x_min:x_max ] = 1.0

    return bbox_embed
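
Aside (not part of the commit): a minimal sketch of the prompt encoding above, using made-up image, click, and box values. It runs without the checkpoint and shows how rescale_inputs and prepare_inputs assemble the 5-channel tensor (image, shaded box, positive/negative clicks, previous mask) that Predictor.predict feeds to the UNet.

import torch
from predictor import rescale_inputs, prepare_inputs

prompts = dict(
    img=torch.rand(1, 1, 256, 256),                 # B x 1 x H x W, intensities in [0, 1]
    point_coords=torch.tensor([[[100., 80.]]]),     # one positive click, (x, y)
    point_labels=torch.tensor([[1]]),
    box=torch.tensor([[[40., 30., 200., 180.]]]),   # one box, [x1, y1, x2, y2]
    scribble=None,
    mask_input=None,
)

prompts = rescale_inputs(prompts)       # image, clicks, and box rescaled to 128 x 128
x = prepare_inputs(prompts).float()     # concatenated prompt channels
print(x.shape)                          # torch.Size([1, 5, 128, 128])
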
requirements.txt
ADDED
@@ -0,0 +1,4 @@
numpy
torch
opencv-python
pathlib
test_examples/COBRE.jpg
ADDED
test_examples/SCR.jpg
ADDED
test_examples/TotalSegmentator.jpg
ADDED
test_examples/TotalSegmentator_2.jpg
ADDED
val_od_examples/ACDC.jpg
ADDED
val_od_examples/BTCV.jpg
ADDED
val_od_examples/BUID.jpg
ADDED
val_od_examples/DRIVE.jpg
ADDED
val_od_examples/HipXRay.jpg
ADDED
val_od_examples/PanDental.jpg
ADDED
val_od_examples/SCD.jpg
ADDED
val_od_examples/SpineWeb.jpg
ADDED
val_od_examples/WBC.jpg
ADDED