Spaces:
Running
Running
Realcat
commited on
Commit
·
d41cdad
1
Parent(s):
34aafe5
add: omniglue
Browse files- README.md +1 -0
- common/app_class.py +290 -250
- common/config.yaml +11 -0
- env-docker.txt +2 -1
- hloc/match_dense.py +14 -0
- hloc/matchers/omniglue.py +81 -0
- requirements.txt +2 -1
- test_app_cli.py +2 -0
- third_party/omniglue/.gitignore +28 -0
- third_party/omniglue/CHANGELOG.md +31 -0
- third_party/omniglue/CONTRIBUTING.md +29 -0
- third_party/omniglue/LICENSE +202 -0
- third_party/omniglue/README.md +152 -0
- third_party/omniglue/__init__.py +19 -0
- third_party/omniglue/demo.py +91 -0
- third_party/omniglue/init_repo.sh +27 -0
- third_party/omniglue/pyproject.toml +62 -0
- third_party/omniglue/requirements.txt +8 -0
- third_party/omniglue/res/demo1.jpg +3 -0
- third_party/omniglue/res/demo2.jpg +3 -0
- third_party/omniglue/res/demo_output.png +3 -0
- third_party/omniglue/res/og_diagram.png +3 -0
- third_party/omniglue/res/result_tf_and_onnx.png +3 -0
- third_party/omniglue/src/omniglue/__init__.py +17 -0
- third_party/omniglue/src/omniglue/dino_extract.py +207 -0
- third_party/omniglue/src/omniglue/omniglue_extract.py +178 -0
- third_party/omniglue/src/omniglue/superpoint_extract.py +212 -0
- third_party/omniglue/src/omniglue/utils.py +282 -0
- third_party/omniglue/third_party/dinov2/__init__.py +0 -0
- third_party/omniglue/third_party/dinov2/dino.py +411 -0
- third_party/omniglue/third_party/dinov2/dino_utils.py +341 -0
README.md
CHANGED
@@ -34,6 +34,7 @@ Here is a demo of the tool:
|
|
34 |
![demo](assets/demo.gif)
|
35 |
|
36 |
The tool currently supports various popular image matching algorithms, namely:
|
|
|
37 |
- [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024
|
38 |
- [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024
|
39 |
- [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024
|
|
|
34 |
![demo](assets/demo.gif)
|
35 |
|
36 |
The tool currently supports various popular image matching algorithms, namely:
|
37 |
+
- [x] [OmniGlue](https://github.com/Vincentqyw/omniglue-onnx), CVPR 2024
|
38 |
- [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024
|
39 |
- [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024
|
40 |
- [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024
|
common/app_class.py
CHANGED
@@ -12,6 +12,7 @@ from common.utils import (
|
|
12 |
run_ransac,
|
13 |
gen_examples,
|
14 |
GRADIO_VERSION,
|
|
|
15 |
)
|
16 |
|
17 |
|
@@ -49,288 +50,327 @@ class ImageMatchingApp:
|
|
49 |
|
50 |
def init_interface(self):
|
51 |
with gr.Blocks() as self.app:
|
52 |
-
with gr.
|
53 |
-
with gr.
|
54 |
-
gr.
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
show_download_button=False,
|
60 |
-
)
|
61 |
-
with gr.Column(scale=3):
|
62 |
-
gr.Markdown(DESCRIPTION)
|
63 |
-
with gr.Row(equal_height=False):
|
64 |
-
with gr.Column():
|
65 |
-
with gr.Row():
|
66 |
-
matcher_list = gr.Dropdown(
|
67 |
-
choices=self.init_matcher_dropdown(),
|
68 |
-
value="disk+lightglue",
|
69 |
-
label="Matching Model",
|
70 |
-
interactive=True,
|
71 |
-
)
|
72 |
-
match_image_src = gr.Radio(
|
73 |
-
(
|
74 |
-
["upload", "webcam", "clipboard"]
|
75 |
-
if GRADIO_VERSION > "3"
|
76 |
-
else ["upload", "webcam", "canvas"]
|
77 |
),
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
input_image0 = gr.Image(
|
83 |
-
label="Image 0",
|
84 |
-
type="numpy",
|
85 |
-
image_mode="RGB",
|
86 |
-
height=300 if GRADIO_VERSION > "3" else None,
|
87 |
-
interactive=True,
|
88 |
-
)
|
89 |
-
input_image1 = gr.Image(
|
90 |
-
label="Image 1",
|
91 |
-
type="numpy",
|
92 |
-
image_mode="RGB",
|
93 |
-
height=300 if GRADIO_VERSION > "3" else None,
|
94 |
-
interactive=True,
|
95 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
minimum=0.0,
|
108 |
-
maximum=
|
109 |
-
step=0.
|
110 |
-
label="
|
111 |
-
value=0
|
112 |
-
)
|
113 |
-
match_setting_max_features = gr.Slider(
|
114 |
-
minimum=10,
|
115 |
-
maximum=10000,
|
116 |
-
step=10,
|
117 |
-
label="Max features",
|
118 |
-
value=1000,
|
119 |
)
|
120 |
-
|
121 |
-
|
122 |
-
detect_keypoints_threshold = gr.Slider(
|
123 |
-
minimum=0,
|
124 |
maximum=1,
|
125 |
-
step=0.
|
126 |
-
label="
|
127 |
-
value=
|
|
|
|
|
128 |
)
|
129 |
-
|
130 |
-
minimum=0.
|
131 |
-
maximum=
|
132 |
-
step=
|
133 |
-
label="
|
134 |
-
value=
|
|
|
|
|
135 |
)
|
136 |
-
|
137 |
-
|
138 |
-
# label="Matcher mode",
|
139 |
-
# value="NN-mutual",
|
140 |
-
# )
|
141 |
-
with gr.Accordion("RANSAC Setting", open=True):
|
142 |
-
with gr.Row(equal_height=False):
|
143 |
-
ransac_method = gr.Dropdown(
|
144 |
-
choices=ransac_zoo.keys(),
|
145 |
-
value=self.cfg["defaults"]["ransac_method"],
|
146 |
-
label="RANSAC Method",
|
147 |
-
interactive=True,
|
148 |
)
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
)
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
)
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
172 |
)
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
label="
|
178 |
-
value=self.cfg["defaults"][
|
179 |
-
"setting_geometry"
|
180 |
-
],
|
181 |
)
|
182 |
|
183 |
-
#
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
input_image0,
|
187 |
input_image1,
|
188 |
match_setting_threshold,
|
189 |
match_setting_max_features,
|
190 |
detect_keypoints_threshold,
|
191 |
matcher_list,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
ransac_method,
|
193 |
ransac_reproj_threshold,
|
194 |
ransac_confidence,
|
195 |
ransac_max_iter,
|
196 |
choice_geometry_type,
|
197 |
-
gr.State(self.matcher_zoo),
|
198 |
-
# state_cache,
|
199 |
]
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
-
#
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
self.display_supported_algorithms()
|
218 |
-
|
219 |
-
with gr.Column():
|
220 |
-
output_keypoints = gr.Image(label="Keypoints", type="numpy")
|
221 |
-
output_matches_raw = gr.Image(
|
222 |
-
label="Raw Matches",
|
223 |
-
type="numpy",
|
224 |
)
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
)
|
228 |
-
|
229 |
-
|
230 |
-
):
|
231 |
-
matches_result_info = gr.JSON(
|
232 |
-
label="Matches Statistics"
|
233 |
-
)
|
234 |
-
matcher_info = gr.JSON(label="Match info")
|
235 |
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
)
|
242 |
-
|
243 |
-
"
|
244 |
-
)
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
inputs=match_image_src,
|
253 |
-
outputs=input_image0,
|
254 |
-
)
|
255 |
-
match_image_src.change(
|
256 |
-
fn=self.ui_change_imagebox,
|
257 |
-
inputs=match_image_src,
|
258 |
-
outputs=input_image1,
|
259 |
-
)
|
260 |
-
|
261 |
-
# collect outputs
|
262 |
-
outputs = [
|
263 |
-
output_keypoints,
|
264 |
-
output_matches_raw,
|
265 |
-
output_matches_ransac,
|
266 |
-
matches_result_info,
|
267 |
-
matcher_info,
|
268 |
-
geometry_result,
|
269 |
-
output_wrapped,
|
270 |
-
state_cache,
|
271 |
-
]
|
272 |
-
# button callbacks
|
273 |
-
button_run.click(
|
274 |
-
fn=run_matching, inputs=inputs, outputs=outputs
|
275 |
-
)
|
276 |
-
|
277 |
-
# Reset images
|
278 |
-
reset_outputs = [
|
279 |
-
input_image0,
|
280 |
-
input_image1,
|
281 |
-
match_setting_threshold,
|
282 |
-
match_setting_max_features,
|
283 |
-
detect_keypoints_threshold,
|
284 |
-
matcher_list,
|
285 |
-
input_image0,
|
286 |
-
input_image1,
|
287 |
-
match_image_src,
|
288 |
-
output_keypoints,
|
289 |
-
output_matches_raw,
|
290 |
-
output_matches_ransac,
|
291 |
-
matches_result_info,
|
292 |
-
matcher_info,
|
293 |
-
output_wrapped,
|
294 |
-
geometry_result,
|
295 |
-
ransac_method,
|
296 |
-
ransac_reproj_threshold,
|
297 |
-
ransac_confidence,
|
298 |
-
ransac_max_iter,
|
299 |
-
choice_geometry_type,
|
300 |
-
]
|
301 |
-
button_reset.click(
|
302 |
-
fn=self.ui_reset_state, inputs=None, outputs=reset_outputs
|
303 |
-
)
|
304 |
-
|
305 |
-
# run ransac button action
|
306 |
-
button_ransac.click(
|
307 |
-
fn=run_ransac,
|
308 |
-
inputs=[
|
309 |
-
state_cache,
|
310 |
-
choice_geometry_type,
|
311 |
-
ransac_method,
|
312 |
-
ransac_reproj_threshold,
|
313 |
-
ransac_confidence,
|
314 |
-
ransac_max_iter,
|
315 |
-
],
|
316 |
-
outputs=[
|
317 |
-
output_matches_ransac,
|
318 |
-
matches_result_info,
|
319 |
-
output_wrapped,
|
320 |
-
],
|
321 |
-
)
|
322 |
-
|
323 |
-
# estimate geo
|
324 |
-
choice_geometry_type.change(
|
325 |
-
fn=generate_warp_images,
|
326 |
-
inputs=[
|
327 |
-
input_image0,
|
328 |
-
input_image1,
|
329 |
-
geometry_result,
|
330 |
-
choice_geometry_type,
|
331 |
-
],
|
332 |
-
outputs=[output_wrapped, geometry_result],
|
333 |
-
)
|
334 |
|
335 |
def run(self):
|
336 |
self.app.queue().launch(
|
|
|
12 |
run_ransac,
|
13 |
gen_examples,
|
14 |
GRADIO_VERSION,
|
15 |
+
ROOT,
|
16 |
)
|
17 |
|
18 |
|
|
|
50 |
|
51 |
def init_interface(self):
|
52 |
with gr.Blocks() as self.app:
|
53 |
+
with gr.Tab("Image Matching"):
|
54 |
+
with gr.Row():
|
55 |
+
with gr.Column(scale=1):
|
56 |
+
gr.Image(
|
57 |
+
str(
|
58 |
+
Path(__file__).parent.parent
|
59 |
+
/ "assets/logo.webp"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
),
|
61 |
+
elem_id="logo-img",
|
62 |
+
show_label=False,
|
63 |
+
show_share_button=False,
|
64 |
+
show_download_button=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
)
|
66 |
+
with gr.Column(scale=3):
|
67 |
+
gr.Markdown(DESCRIPTION)
|
68 |
+
with gr.Row(equal_height=False):
|
69 |
+
with gr.Column():
|
70 |
+
with gr.Row():
|
71 |
+
matcher_list = gr.Dropdown(
|
72 |
+
choices=self.init_matcher_dropdown(),
|
73 |
+
value="disk+lightglue",
|
74 |
+
label="Matching Model",
|
75 |
+
interactive=True,
|
76 |
+
)
|
77 |
+
match_image_src = gr.Radio(
|
78 |
+
(
|
79 |
+
["upload", "webcam", "clipboard"]
|
80 |
+
if GRADIO_VERSION > "3"
|
81 |
+
else ["upload", "webcam", "canvas"]
|
82 |
+
),
|
83 |
+
label="Image Source",
|
84 |
+
value="upload",
|
85 |
+
)
|
86 |
+
with gr.Row():
|
87 |
+
input_image0 = gr.Image(
|
88 |
+
label="Image 0",
|
89 |
+
type="numpy",
|
90 |
+
image_mode="RGB",
|
91 |
+
height=300 if GRADIO_VERSION > "3" else None,
|
92 |
+
interactive=True,
|
93 |
+
)
|
94 |
+
input_image1 = gr.Image(
|
95 |
+
label="Image 1",
|
96 |
+
type="numpy",
|
97 |
+
image_mode="RGB",
|
98 |
+
height=300 if GRADIO_VERSION > "3" else None,
|
99 |
+
interactive=True,
|
100 |
+
)
|
101 |
|
102 |
+
with gr.Row():
|
103 |
+
button_reset = gr.Button(value="Reset")
|
104 |
+
button_run = gr.Button(
|
105 |
+
value="Run Match", variant="primary"
|
106 |
+
)
|
107 |
|
108 |
+
with gr.Accordion("Advanced Setting", open=False):
|
109 |
+
with gr.Accordion("Matching Setting", open=True):
|
110 |
+
with gr.Row():
|
111 |
+
match_setting_threshold = gr.Slider(
|
112 |
+
minimum=0.0,
|
113 |
+
maximum=1,
|
114 |
+
step=0.001,
|
115 |
+
label="Match thres.",
|
116 |
+
value=0.1,
|
117 |
+
)
|
118 |
+
match_setting_max_features = gr.Slider(
|
119 |
+
minimum=10,
|
120 |
+
maximum=10000,
|
121 |
+
step=10,
|
122 |
+
label="Max features",
|
123 |
+
value=1000,
|
124 |
+
)
|
125 |
+
# TODO: add line settings
|
126 |
+
with gr.Row():
|
127 |
+
detect_keypoints_threshold = gr.Slider(
|
128 |
+
minimum=0,
|
129 |
+
maximum=1,
|
130 |
+
step=0.001,
|
131 |
+
label="Keypoint thres.",
|
132 |
+
value=0.015,
|
133 |
+
)
|
134 |
+
detect_line_threshold = gr.Slider(
|
135 |
+
minimum=0.1,
|
136 |
+
maximum=1,
|
137 |
+
step=0.01,
|
138 |
+
label="Line thres.",
|
139 |
+
value=0.2,
|
140 |
+
)
|
141 |
+
# matcher_lists = gr.Radio(
|
142 |
+
# ["NN-mutual", "Dual-Softmax"],
|
143 |
+
# label="Matcher mode",
|
144 |
+
# value="NN-mutual",
|
145 |
+
# )
|
146 |
+
with gr.Accordion("RANSAC Setting", open=True):
|
147 |
+
with gr.Row(equal_height=False):
|
148 |
+
ransac_method = gr.Dropdown(
|
149 |
+
choices=ransac_zoo.keys(),
|
150 |
+
value=self.cfg["defaults"][
|
151 |
+
"ransac_method"
|
152 |
+
],
|
153 |
+
label="RANSAC Method",
|
154 |
+
interactive=True,
|
155 |
+
)
|
156 |
+
ransac_reproj_threshold = gr.Slider(
|
157 |
minimum=0.0,
|
158 |
+
maximum=12,
|
159 |
+
step=0.01,
|
160 |
+
label="Ransac Reproj threshold",
|
161 |
+
value=8.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
)
|
163 |
+
ransac_confidence = gr.Slider(
|
164 |
+
minimum=0.0,
|
|
|
|
|
165 |
maximum=1,
|
166 |
+
step=0.00001,
|
167 |
+
label="Ransac Confidence",
|
168 |
+
value=self.cfg["defaults"][
|
169 |
+
"ransac_confidence"
|
170 |
+
],
|
171 |
)
|
172 |
+
ransac_max_iter = gr.Slider(
|
173 |
+
minimum=0.0,
|
174 |
+
maximum=100000,
|
175 |
+
step=100,
|
176 |
+
label="Ransac Iterations",
|
177 |
+
value=self.cfg["defaults"][
|
178 |
+
"ransac_max_iter"
|
179 |
+
],
|
180 |
)
|
181 |
+
button_ransac = gr.Button(
|
182 |
+
value="Rerun RANSAC", variant="primary"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
)
|
184 |
+
with gr.Accordion("Geometry Setting", open=False):
|
185 |
+
with gr.Row(equal_height=False):
|
186 |
+
choice_geometry_type = gr.Radio(
|
187 |
+
["Fundamental", "Homography"],
|
188 |
+
label="Reconstruct Geometry",
|
189 |
+
value=self.cfg["defaults"][
|
190 |
+
"setting_geometry"
|
191 |
+
],
|
192 |
+
)
|
193 |
+
|
194 |
+
# collect inputs
|
195 |
+
state_cache = gr.State({})
|
196 |
+
inputs = [
|
197 |
+
input_image0,
|
198 |
+
input_image1,
|
199 |
+
match_setting_threshold,
|
200 |
+
match_setting_max_features,
|
201 |
+
detect_keypoints_threshold,
|
202 |
+
matcher_list,
|
203 |
+
ransac_method,
|
204 |
+
ransac_reproj_threshold,
|
205 |
+
ransac_confidence,
|
206 |
+
ransac_max_iter,
|
207 |
+
choice_geometry_type,
|
208 |
+
gr.State(self.matcher_zoo),
|
209 |
+
# state_cache,
|
210 |
+
]
|
211 |
+
|
212 |
+
# Add some examples
|
213 |
+
with gr.Row():
|
214 |
+
# Example inputs
|
215 |
+
gr.Examples(
|
216 |
+
examples=gen_examples(),
|
217 |
+
inputs=inputs,
|
218 |
+
outputs=[],
|
219 |
+
fn=run_matching,
|
220 |
+
cache_examples=False,
|
221 |
+
label=(
|
222 |
+
"Examples (click one of the images below to Run"
|
223 |
+
" Match). Thx: WxBS"
|
224 |
+
),
|
225 |
)
|
226 |
+
with gr.Accordion("Supported Algorithms", open=False):
|
227 |
+
# add a table of supported algorithms
|
228 |
+
self.display_supported_algorithms()
|
229 |
+
|
230 |
+
with gr.Column():
|
231 |
+
output_keypoints = gr.Image(
|
232 |
+
label="Keypoints", type="numpy"
|
233 |
+
)
|
234 |
+
output_matches_raw = gr.Image(
|
235 |
+
label="Raw Matches",
|
236 |
+
type="numpy",
|
237 |
+
)
|
238 |
+
output_matches_ransac = gr.Image(
|
239 |
+
label="Ransac Matches", type="numpy"
|
240 |
+
)
|
241 |
+
with gr.Accordion(
|
242 |
+
"Open for More: Matches Statistics", open=False
|
243 |
+
):
|
244 |
+
matches_result_info = gr.JSON(
|
245 |
+
label="Matches Statistics"
|
246 |
)
|
247 |
+
matcher_info = gr.JSON(label="Match info")
|
248 |
+
|
249 |
+
with gr.Accordion(
|
250 |
+
"Open for More: Warped Image", open=False
|
251 |
+
):
|
252 |
+
output_wrapped = gr.Image(
|
253 |
+
label="Wrapped Pair", type="numpy"
|
254 |
)
|
255 |
+
with gr.Accordion(
|
256 |
+
"Open for More: Geometry info", open=False
|
257 |
+
):
|
258 |
+
geometry_result = gr.JSON(
|
259 |
+
label="Reconstructed Geometry"
|
|
|
|
|
|
|
260 |
)
|
261 |
|
262 |
+
# callbacks
|
263 |
+
match_image_src.change(
|
264 |
+
fn=self.ui_change_imagebox,
|
265 |
+
inputs=match_image_src,
|
266 |
+
outputs=input_image0,
|
267 |
+
)
|
268 |
+
match_image_src.change(
|
269 |
+
fn=self.ui_change_imagebox,
|
270 |
+
inputs=match_image_src,
|
271 |
+
outputs=input_image1,
|
272 |
+
)
|
273 |
+
|
274 |
+
# collect outputs
|
275 |
+
outputs = [
|
276 |
+
output_keypoints,
|
277 |
+
output_matches_raw,
|
278 |
+
output_matches_ransac,
|
279 |
+
matches_result_info,
|
280 |
+
matcher_info,
|
281 |
+
geometry_result,
|
282 |
+
output_wrapped,
|
283 |
+
state_cache,
|
284 |
+
]
|
285 |
+
# button callbacks
|
286 |
+
button_run.click(
|
287 |
+
fn=run_matching, inputs=inputs, outputs=outputs
|
288 |
+
)
|
289 |
+
|
290 |
+
# Reset images
|
291 |
+
reset_outputs = [
|
292 |
input_image0,
|
293 |
input_image1,
|
294 |
match_setting_threshold,
|
295 |
match_setting_max_features,
|
296 |
detect_keypoints_threshold,
|
297 |
matcher_list,
|
298 |
+
input_image0,
|
299 |
+
input_image1,
|
300 |
+
match_image_src,
|
301 |
+
output_keypoints,
|
302 |
+
output_matches_raw,
|
303 |
+
output_matches_ransac,
|
304 |
+
matches_result_info,
|
305 |
+
matcher_info,
|
306 |
+
output_wrapped,
|
307 |
+
geometry_result,
|
308 |
ransac_method,
|
309 |
ransac_reproj_threshold,
|
310 |
ransac_confidence,
|
311 |
ransac_max_iter,
|
312 |
choice_geometry_type,
|
|
|
|
|
313 |
]
|
314 |
+
button_reset.click(
|
315 |
+
fn=self.ui_reset_state,
|
316 |
+
inputs=None,
|
317 |
+
outputs=reset_outputs,
|
318 |
+
)
|
319 |
|
320 |
+
# run ransac button action
|
321 |
+
button_ransac.click(
|
322 |
+
fn=run_ransac,
|
323 |
+
inputs=[
|
324 |
+
state_cache,
|
325 |
+
choice_geometry_type,
|
326 |
+
ransac_method,
|
327 |
+
ransac_reproj_threshold,
|
328 |
+
ransac_confidence,
|
329 |
+
ransac_max_iter,
|
330 |
+
],
|
331 |
+
outputs=[
|
332 |
+
output_matches_ransac,
|
333 |
+
matches_result_info,
|
334 |
+
output_wrapped,
|
335 |
+
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
)
|
337 |
+
|
338 |
+
# estimate geo
|
339 |
+
choice_geometry_type.change(
|
340 |
+
fn=generate_warp_images,
|
341 |
+
inputs=[
|
342 |
+
input_image0,
|
343 |
+
input_image1,
|
344 |
+
geometry_result,
|
345 |
+
choice_geometry_type,
|
346 |
+
],
|
347 |
+
outputs=[output_wrapped, geometry_result],
|
348 |
)
|
349 |
+
with gr.Tab("Under construction"):
|
350 |
+
self.init_tab_sfm()
|
|
|
|
|
|
|
|
|
|
|
351 |
|
352 |
+
def init_tab_sfm(self):
|
353 |
+
with gr.Row():
|
354 |
+
with gr.Column():
|
355 |
+
with gr.Row():
|
356 |
+
gr.Textbox("Under construction", label="A", visible=True)
|
357 |
+
gr.Textbox("Under construction", label="B", visible=True)
|
358 |
+
gr.Textbox("Under construction", label="C", visible=True)
|
359 |
+
with gr.Row():
|
360 |
+
with gr.Accordion("Open for More", open=False):
|
361 |
+
gr.Textbox(
|
362 |
+
"Under construction", label="A1", visible=True
|
363 |
)
|
364 |
+
gr.Textbox(
|
365 |
+
"Under construction", label="B1", visible=True
|
366 |
+
)
|
367 |
+
gr.Textbox(
|
368 |
+
"Under construction", label="C1", visible=True
|
369 |
+
)
|
370 |
+
with gr.Column():
|
371 |
+
gr.Textbox("Under construction", label="D", visible=True)
|
372 |
+
gr.Textbox("Under construction", label="E", visible=True)
|
373 |
+
gr.Textbox("Under construction", label="F", visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
|
375 |
def run(self):
|
376 |
self.app.queue().launch(
|
common/config.yaml
CHANGED
@@ -16,6 +16,17 @@ defaults:
|
|
16 |
setting_geometry: Homography
|
17 |
|
18 |
matcher_zoo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
DUSt3R:
|
20 |
# TODO: duster is under development
|
21 |
enable: false
|
|
|
16 |
setting_geometry: Homography
|
17 |
|
18 |
matcher_zoo:
|
19 |
+
omniglue:
|
20 |
+
enable: true
|
21 |
+
matcher: omniglue
|
22 |
+
dense: true
|
23 |
+
info:
|
24 |
+
name: OmniGlue
|
25 |
+
source: "CVPR 2024"
|
26 |
+
github: https://github.com/Vincentqyw/omniglue-onnx
|
27 |
+
paper: https://arxiv.org/abs/2405.12979
|
28 |
+
project: https://hwjiang1510.github.io/OmniGlue/
|
29 |
+
display: true
|
30 |
DUSt3R:
|
31 |
# TODO: duster is under development
|
32 |
enable: false
|
env-docker.txt
CHANGED
@@ -29,4 +29,5 @@ tensorboardX==2.6.1
|
|
29 |
torchmetrics==0.6.0
|
30 |
torchvision==0.17.1
|
31 |
tqdm==4.65.0
|
32 |
-
yacs==0.1.8
|
|
|
|
29 |
torchmetrics==0.6.0
|
30 |
torchvision==0.17.1
|
31 |
tqdm==4.65.0
|
32 |
+
yacs==0.1.8
|
33 |
+
onnxruntime
|
hloc/match_dense.py
CHANGED
@@ -211,6 +211,20 @@ confs = {
|
|
211 |
"dfactor": 8,
|
212 |
},
|
213 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
"sold2": {
|
215 |
"output": "matches-sold2",
|
216 |
"model": {
|
|
|
211 |
"dfactor": 8,
|
212 |
},
|
213 |
},
|
214 |
+
"omniglue": {
|
215 |
+
"output": "matches-omniglue",
|
216 |
+
"model": {
|
217 |
+
"name": "omniglue",
|
218 |
+
"match_threshold": 0.2,
|
219 |
+
"features": "null",
|
220 |
+
},
|
221 |
+
"preprocessing": {
|
222 |
+
"grayscale": False,
|
223 |
+
"resize_max": 1024,
|
224 |
+
"dfactor": 8,
|
225 |
+
"force_resize": False,
|
226 |
+
},
|
227 |
+
},
|
228 |
"sold2": {
|
229 |
"output": "matches-sold2",
|
230 |
"model": {
|
hloc/matchers/omniglue.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import subprocess
|
4 |
+
import numpy as np
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from .. import logger
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
+
omniglue_path = Path(__file__).parent / "../../third_party/omniglue"
|
11 |
+
sys.path.append(str(omniglue_path))
|
12 |
+
from src import omniglue
|
13 |
+
|
14 |
+
|
15 |
+
class OmniGlue(BaseModel):
|
16 |
+
default_conf = {
|
17 |
+
"match_threshold": 0.02,
|
18 |
+
"max_keypoints": 2048,
|
19 |
+
}
|
20 |
+
required_inputs = ["image0", "image1"]
|
21 |
+
dino_v2_link_dict = {
|
22 |
+
"dinov2_vitb14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth"
|
23 |
+
}
|
24 |
+
|
25 |
+
def _init(self, conf):
|
26 |
+
logger.info(f"Loadeding OmniGlue model")
|
27 |
+
og_model_path = omniglue_path / "models" / "omniglue.onnx"
|
28 |
+
sp_model_path = omniglue_path / "models" / "sp_v6.onnx"
|
29 |
+
dino_model_path = (
|
30 |
+
omniglue_path / "models" / "dinov2_vitb14_pretrain.pth" # ~330MB
|
31 |
+
)
|
32 |
+
if not dino_model_path.exists():
|
33 |
+
link = self.dino_v2_link_dict.get(dino_model_path.name, None)
|
34 |
+
if link is not None:
|
35 |
+
cmd = ["wget", link, "-O", str(dino_model_path)]
|
36 |
+
logger.info(f"Downloading the dinov2 model with `{cmd}`.")
|
37 |
+
subprocess.run(cmd, check=True)
|
38 |
+
else:
|
39 |
+
logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
|
40 |
+
|
41 |
+
self.net = omniglue.OmniGlue(
|
42 |
+
og_export=str(og_model_path),
|
43 |
+
sp_export=str(sp_model_path),
|
44 |
+
dino_export=str(dino_model_path),
|
45 |
+
max_keypoints=self.conf["max_keypoints"] * 4,
|
46 |
+
)
|
47 |
+
logger.info(f"Loaded OmniGlue model done!")
|
48 |
+
|
49 |
+
def _forward(self, data):
|
50 |
+
image0_rgb_np = data["image0"][0].permute(1, 2, 0).cpu().numpy() * 255
|
51 |
+
image1_rgb_np = data["image1"][0].permute(1, 2, 0).cpu().numpy() * 255
|
52 |
+
image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255
|
53 |
+
image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255
|
54 |
+
match_kp0, match_kp1, match_confidences = self.net.FindMatches(
|
55 |
+
image0_rgb_np, image1_rgb_np
|
56 |
+
)
|
57 |
+
|
58 |
+
# filter matches
|
59 |
+
match_threshold = self.conf["match_threshold"]
|
60 |
+
keep_idx = []
|
61 |
+
for i in range(match_kp0.shape[0]):
|
62 |
+
if match_confidences[i] > match_threshold:
|
63 |
+
keep_idx.append(i)
|
64 |
+
num_filtered_matches = len(keep_idx)
|
65 |
+
scores = torch.from_numpy(match_confidences[keep_idx]).reshape(-1, 1)
|
66 |
+
pred = {
|
67 |
+
"keypoints0": torch.from_numpy(match_kp0[keep_idx]),
|
68 |
+
"keypoints1": torch.from_numpy(match_kp1[keep_idx]),
|
69 |
+
"mconf": scores,
|
70 |
+
}
|
71 |
+
|
72 |
+
top_k = self.conf["max_keypoints"]
|
73 |
+
if top_k is not None and len(scores) > top_k:
|
74 |
+
keep = torch.argsort(scores, descending=True)[:top_k]
|
75 |
+
scores = scores[keep]
|
76 |
+
pred["keypoints0"], pred["keypoints1"], pred["mconf"] = (
|
77 |
+
pred["keypoints0"][keep],
|
78 |
+
pred["keypoints1"][keep],
|
79 |
+
scores,
|
80 |
+
)
|
81 |
+
return pred
|
requirements.txt
CHANGED
@@ -30,4 +30,5 @@ tensorboardX==2.6.1
|
|
30 |
torchmetrics==0.6.0
|
31 |
torchvision==0.17.1
|
32 |
tqdm==4.65.0
|
33 |
-
yacs==0.1.8
|
|
|
|
30 |
torchmetrics==0.6.0
|
31 |
torchvision==0.17.1
|
32 |
tqdm==4.65.0
|
33 |
+
yacs==0.1.8
|
34 |
+
onnxruntime
|
test_app_cli.py
CHANGED
@@ -11,6 +11,7 @@ from common.utils import (
|
|
11 |
)
|
12 |
from common.api import ImageMatchingAPI
|
13 |
|
|
|
14 |
def test_api(config: dict = None):
|
15 |
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
16 |
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
@@ -32,6 +33,7 @@ def test_api(config: dict = None):
|
|
32 |
else:
|
33 |
logger.info(f"Skipping {k} ...")
|
34 |
|
|
|
35 |
if __name__ == "__main__":
|
36 |
import argparse
|
37 |
|
|
|
11 |
)
|
12 |
from common.api import ImageMatchingAPI
|
13 |
|
14 |
+
|
15 |
def test_api(config: dict = None):
|
16 |
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
17 |
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
|
|
33 |
else:
|
34 |
logger.info(f"Skipping {k} ...")
|
35 |
|
36 |
+
|
37 |
if __name__ == "__main__":
|
38 |
import argparse
|
39 |
|
third_party/omniglue/.gitignore
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Compiled python modules.
|
2 |
+
*.pyc
|
3 |
+
|
4 |
+
# Byte-compiled
|
5 |
+
_pycache__/
|
6 |
+
.cache/
|
7 |
+
|
8 |
+
# Poetry, setuptools, PyPI distribution artifacts.
|
9 |
+
/*.egg-info
|
10 |
+
.eggs/
|
11 |
+
build/
|
12 |
+
dist/
|
13 |
+
poetry.lock
|
14 |
+
|
15 |
+
# Tests
|
16 |
+
.pytest_cache/
|
17 |
+
|
18 |
+
# Type checking
|
19 |
+
.pytype/
|
20 |
+
|
21 |
+
# Other
|
22 |
+
*.DS_Store
|
23 |
+
|
24 |
+
# PyCharm
|
25 |
+
.idea
|
26 |
+
models/sp_v6*
|
27 |
+
models/og_export*
|
28 |
+
models/dinov2_vitb14_pretrain.pth
|
third_party/omniglue/CHANGELOG.md
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changelog
|
2 |
+
|
3 |
+
<!--
|
4 |
+
|
5 |
+
Changelog follow the https://keepachangelog.com/ standard (at least the headers)
|
6 |
+
|
7 |
+
This allow to:
|
8 |
+
|
9 |
+
* auto-parsing release notes during the automated releases from github-action:
|
10 |
+
https://github.com/marketplace/actions/pypi-github-auto-release
|
11 |
+
* Have clickable headers in the rendered markdown
|
12 |
+
|
13 |
+
To release a new version (e.g. from `1.0.0` -> `2.0.0`):
|
14 |
+
|
15 |
+
* Create a new `# [2.0.0] - YYYY-MM-DD` header and add the current
|
16 |
+
`[Unreleased]` notes.
|
17 |
+
* At the end of the file:
|
18 |
+
* Define the new link url:
|
19 |
+
`[2.0.0]: https://github.com/google-research/omniglue/compare/v1.0.0...v2.0.0`
|
20 |
+
* Update the `[Unreleased]` url: `v1.0.0...HEAD` -> `v2.0.0...HEAD`
|
21 |
+
|
22 |
+
-->
|
23 |
+
|
24 |
+
## [Unreleased]
|
25 |
+
|
26 |
+
## [0.1.0] - 2022-01-01
|
27 |
+
|
28 |
+
* Initial release
|
29 |
+
|
30 |
+
[Unreleased]: https://github.com/google-research/omniglue/compare/v0.1.0...HEAD
|
31 |
+
[0.1.0]: https://github.com/google-research/omniglue/releases/tag/v0.1.0
|
third_party/omniglue/CONTRIBUTING.md
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# How to Contribute
|
2 |
+
|
3 |
+
We'd love to accept your patches and contributions to this project. There are
|
4 |
+
just a few small guidelines you need to follow.
|
5 |
+
|
6 |
+
## Contributor License Agreement
|
7 |
+
|
8 |
+
Contributions to this project must be accompanied by a Contributor License
|
9 |
+
Agreement (CLA). You (or your employer) retain the copyright to your
|
10 |
+
contribution; this simply gives us permission to use and redistribute your
|
11 |
+
contributions as part of the project. Head over to
|
12 |
+
<https://cla.developers.google.com/> to see your current agreements on file or
|
13 |
+
to sign a new one.
|
14 |
+
|
15 |
+
You generally only need to submit a CLA once, so if you've already submitted one
|
16 |
+
(even if it was for a different project), you probably don't need to do it
|
17 |
+
again.
|
18 |
+
|
19 |
+
## Code Reviews
|
20 |
+
|
21 |
+
All submissions, including submissions by project members, require review. We
|
22 |
+
use GitHub pull requests for this purpose. Consult
|
23 |
+
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
24 |
+
information on using pull requests.
|
25 |
+
|
26 |
+
## Community Guidelines
|
27 |
+
|
28 |
+
This project follows
|
29 |
+
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
|
third_party/omniglue/LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
third_party/omniglue/README.md
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
# \[CVPR'24\] Code release for OmniGlue(ONNX)
|
4 |
+
|
5 |
+
[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/Realcat/image-matching-webui)
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
<a href="https://hwjiang1510.github.io/">Hanwen Jiang</a>,
|
9 |
+
<a href="https://scholar.google.com/citations?user=jgSItF4AAAAJ">Arjun Karpur</a>,
|
10 |
+
<a href="https://scholar.google.com/citations?user=7EeSOcgAAAAJ">Bingyi Cao</a>,
|
11 |
+
<a href="https://www.cs.utexas.edu/~huangqx/">Qixing Huang</a>,
|
12 |
+
<a href="https://andrefaraujo.github.io/">Andre Araujo</a>
|
13 |
+
</p>
|
14 |
+
|
15 |
+
</div>
|
16 |
+
|
17 |
+
--------------------------------------------------------------------------------
|
18 |
+
|
19 |
+
<div align="center">
|
20 |
+
<a href="https://hwjiang1510.github.io/OmniGlue/"><strong>Project Page</strong></a> |
|
21 |
+
<a href="https://arxiv.org/abs/2405.12979"><strong>Paper</strong></a> |
|
22 |
+
<a href="#installation"><strong>Usage</strong></a> |
|
23 |
+
<a href="https://huggingface.co/spaces/qubvel-hf/omniglue"><strong>Demo</strong></a>
|
24 |
+
</div>
|
25 |
+
|
26 |
+
<br>
|
27 |
+
|
28 |
+
ONNX-compatible release for the CVPR 2024 paper: **OmniGlue: Generalizable Feature
|
29 |
+
Matching with Foundation Model Guidance**.
|
30 |
+
|
31 |
+
![og_diagram.png](res/og_diagram.png "og_diagram.png")
|
32 |
+
|
33 |
+
**Abstract:** The image matching field has been witnessing a continuous
|
34 |
+
emergence of novel learnable feature matching techniques, with ever-improving
|
35 |
+
performance on conventional benchmarks. However, our investigation shows that
|
36 |
+
despite these gains, their potential for real-world applications is restricted
|
37 |
+
by their limited generalization capabilities to novel image domains. In this
|
38 |
+
paper, we introduce OmniGlue, the first learnable image matcher that is designed
|
39 |
+
with generalization as a core principle. OmniGlue leverages broad knowledge from
|
40 |
+
a vision foundation model to guide the feature matching process, boosting
|
41 |
+
generalization to domains not seen at training time. Additionally, we propose a
|
42 |
+
novel keypoint position-guided attention mechanism which disentangles spatial
|
43 |
+
and appearance information, leading to enhanced matching descriptors. We perform
|
44 |
+
comprehensive experiments on a suite of 6 datasets with varied image domains,
|
45 |
+
including scene-level, object-centric and aerial images. OmniGlue’s novel
|
46 |
+
components lead to relative gains on unseen domains of 18.8% with respect to a
|
47 |
+
directly comparable reference model, while also outperforming the recent
|
48 |
+
LightGlue method by 10.1% relatively.
|
49 |
+
|
50 |
+
|
51 |
+
## Installation
|
52 |
+
|
53 |
+
First, use pip to install `omniglue`:
|
54 |
+
|
55 |
+
```sh
|
56 |
+
conda create -n omniglue pip
|
57 |
+
conda activate omniglue
|
58 |
+
|
59 |
+
git clone https://github.com/google-research/omniglue.git
|
60 |
+
cd omniglue
|
61 |
+
pip install -e .
|
62 |
+
```
|
63 |
+
|
64 |
+
Then, download the following models to `./models/`
|
65 |
+
|
66 |
+
```sh
|
67 |
+
# Download to ./models/ dir.
|
68 |
+
mkdir models
|
69 |
+
cd models
|
70 |
+
|
71 |
+
# SuperPoint.
|
72 |
+
git clone https://github.com/rpautrat/SuperPoint.git
|
73 |
+
mv SuperPoint/pretrained_models/sp_v6.tgz . && rm -rf SuperPoint
|
74 |
+
tar zxvf sp_v6.tgz && rm sp_v6.tgz
|
75 |
+
|
76 |
+
# DINOv2 - vit-b14.
|
77 |
+
wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth
|
78 |
+
|
79 |
+
# OmniGlue.
|
80 |
+
wget https://storage.googleapis.com/omniglue/og_export.zip
|
81 |
+
unzip og_export.zip && rm og_export.zip
|
82 |
+
```
|
83 |
+
|
84 |
+
Direct download links:
|
85 |
+
|
86 |
+
- [[SuperPoint weights]](https://github.com/rpautrat/SuperPoint/tree/master/pretrained_models): from [github.com/rpautrat/SuperPoint](https://github.com/rpautrat/SuperPoint)
|
87 |
+
- [[DINOv2 weights]](https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth): from [github.com/facebookresearch/dinov2](https://github.com/facebookresearch/dinov2) (ViT-B/14 distilled backbone without register).
|
88 |
+
- [[OmniGlue weights]](https://storage.googleapis.com/omniglue/og_export.zip)
|
89 |
+
|
90 |
+
## Usage
|
91 |
+
The code snippet below outlines how you can perform OmniGlue inference in your
|
92 |
+
own python codebase.
|
93 |
+
|
94 |
+
```py
|
95 |
+
|
96 |
+
from src import omniglue
|
97 |
+
|
98 |
+
image0 = ... # load images from file into np.array
|
99 |
+
image1 = ...
|
100 |
+
|
101 |
+
og = omniglue.OmniGlue(
|
102 |
+
og_export="./models/omniglue.onnx",
|
103 |
+
sp_export="./models/sp_v6.onnx",
|
104 |
+
dino_export="./models/dinov2_vitb14_pretrain.pth",
|
105 |
+
)
|
106 |
+
|
107 |
+
match_kp0s, match_kp1s, match_confidences = og.FindMatches(image0, image1)
|
108 |
+
# Output:
|
109 |
+
# match_kp0: (N, 2) array of (x,y) coordinates in image0.
|
110 |
+
# match_kp1: (N, 2) array of (x,y) coordinates in image1.
|
111 |
+
# match_confidences: N-dim array of each of the N match confidence scores.
|
112 |
+
```
|
113 |
+
|
114 |
+
## Demo
|
115 |
+
|
116 |
+
`demo.py` contains example usage of the `omniglue` module. To try with your own
|
117 |
+
images, replace `./res/demo1.jpg` and `./res/demo2.jpg` with your own
|
118 |
+
filepaths.
|
119 |
+
|
120 |
+
```sh
|
121 |
+
conda activate omniglue
|
122 |
+
python demo.py ./res/demo1.jpg ./res/demo2.jpg
|
123 |
+
# <see output in './demo_output.png'>
|
124 |
+
```
|
125 |
+
|
126 |
+
Expected output:
|
127 |
+
![demo_output.png](res/demo_output.png "demo_output.png")
|
128 |
+
|
129 |
+
Comparison of Results Between TensorFlow and ONNX:
|
130 |
+
![result_tf_and_onnx.png](res/result_tf_and_onnx.png "result_tf_and_onnx.png")
|
131 |
+
|
132 |
+
|
133 |
+
## Repo TODOs
|
134 |
+
|
135 |
+
- ~~Provide `demo.py` example usage script.~~
|
136 |
+
- Support matching for pre-extracted features.
|
137 |
+
- Release eval pipelines for in-domain (MegaDepth).
|
138 |
+
- Release eval pipelines for all out-of-domain datasets.
|
139 |
+
|
140 |
+
## BibTex
|
141 |
+
```
|
142 |
+
@inproceedings{jiang2024Omniglue,
|
143 |
+
title={OmniGlue: Generalizable Feature Matching with Foundation Model Guidance},
|
144 |
+
author={Jiang, Hanwen and Karpur, Arjun and Cao, Bingyi and Huang, Qixing and Araujo, Andre},
|
145 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
146 |
+
year={2024},
|
147 |
+
}
|
148 |
+
```
|
149 |
+
|
150 |
+
--------------------------------------------------------------------------------
|
151 |
+
|
152 |
+
This is not an officially supported Google product.
|
third_party/omniglue/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""omniglue API."""
|
16 |
+
|
17 |
+
# A new PyPI release will be pushed every time `__version__` is increased.
|
18 |
+
# When changing this, also update the CHANGELOG.md.
|
19 |
+
__version__ = "0.1.0"
|
third_party/omniglue/demo.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2024 Google LLC
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Demo script for performing OmniGlue inference."""
|
17 |
+
|
18 |
+
import sys
|
19 |
+
import time
|
20 |
+
import matplotlib.pyplot as plt
|
21 |
+
import numpy as np
|
22 |
+
from src import omniglue
|
23 |
+
from src.omniglue import utils
|
24 |
+
from PIL import Image
|
25 |
+
|
26 |
+
|
27 |
+
def main(argv) -> None:
|
28 |
+
if len(argv) != 3:
|
29 |
+
print("error - usage: python demo.py <img1_fp> <img2_fp>")
|
30 |
+
return
|
31 |
+
|
32 |
+
# Load images.
|
33 |
+
print("> Loading images...")
|
34 |
+
image0 = np.array(Image.open(argv[1]))
|
35 |
+
image1 = np.array(Image.open(argv[2]))
|
36 |
+
|
37 |
+
# Load models.
|
38 |
+
print("> Loading OmniGlue (and its submodules: SuperPoint & DINOv2)...")
|
39 |
+
start = time.time()
|
40 |
+
og = omniglue.OmniGlue(
|
41 |
+
og_export="./models/omniglue.onnx",
|
42 |
+
sp_export="./models/sp_v6.onnx",
|
43 |
+
dino_export="./models/dinov2_vitb14_pretrain.pth",
|
44 |
+
)
|
45 |
+
print(f"> \tTook {time.time() - start} seconds.")
|
46 |
+
|
47 |
+
# Perform inference.
|
48 |
+
print("> Finding matches...")
|
49 |
+
start = time.time()
|
50 |
+
match_kp0, match_kp1, match_confidences = og.FindMatches(image0, image1)
|
51 |
+
num_matches = match_kp0.shape[0]
|
52 |
+
print(f"> \tFound {num_matches} matches.")
|
53 |
+
print(f"> \tTook {time.time() - start} seconds.")
|
54 |
+
|
55 |
+
# Filter by confidence (0.02).
|
56 |
+
print("> Filtering matches...")
|
57 |
+
match_threshold = 0.02 # Choose any value [0.0, 1.0).
|
58 |
+
keep_idx = []
|
59 |
+
for i in range(match_kp0.shape[0]):
|
60 |
+
if match_confidences[i] > match_threshold:
|
61 |
+
keep_idx.append(i)
|
62 |
+
num_filtered_matches = len(keep_idx)
|
63 |
+
match_kp0 = match_kp0[keep_idx]
|
64 |
+
match_kp1 = match_kp1[keep_idx]
|
65 |
+
match_confidences = match_confidences[keep_idx]
|
66 |
+
print(
|
67 |
+
f"> \tFound {num_filtered_matches}/{num_matches} above threshold {match_threshold}"
|
68 |
+
)
|
69 |
+
|
70 |
+
# Visualize.
|
71 |
+
print("> Visualizing matches...")
|
72 |
+
viz = utils.visualize_matches(
|
73 |
+
image0,
|
74 |
+
image1,
|
75 |
+
match_kp0,
|
76 |
+
match_kp1,
|
77 |
+
np.eye(num_filtered_matches),
|
78 |
+
show_keypoints=True,
|
79 |
+
highlight_unmatched=True,
|
80 |
+
title=f"{num_filtered_matches} matches",
|
81 |
+
line_width=2,
|
82 |
+
)
|
83 |
+
plt.figure(figsize=(20, 10), dpi=100, facecolor="w", edgecolor="k")
|
84 |
+
plt.axis("off")
|
85 |
+
plt.imshow(viz)
|
86 |
+
plt.imsave("./demo_output.png", viz)
|
87 |
+
print("> \tSaved visualization to ./demo_output.png")
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
main(sys.argv)
|
third_party/omniglue/init_repo.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
mkdir models
|
3 |
+
cd models
|
4 |
+
|
5 |
+
# SuperPoint.
|
6 |
+
git clone https://github.com/rpautrat/SuperPoint.git
|
7 |
+
mv SuperPoint/pretrained_models/sp_v6.tgz . && rm -rf SuperPoint
|
8 |
+
tar zxvf sp_v6.tgz && rm sp_v6.tgz
|
9 |
+
|
10 |
+
# DINOv2 - vit-b14.
|
11 |
+
wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth
|
12 |
+
|
13 |
+
# OmniGlue.
|
14 |
+
wget https://storage.googleapis.com/omniglue/og_export.zip
|
15 |
+
unzip og_export.zip && rm og_export.zip
|
16 |
+
|
17 |
+
cd ..
|
18 |
+
|
19 |
+
saved_model=./models/og_export
|
20 |
+
output_onnx=./models/omniglue.onnx
|
21 |
+
python -m tf2onnx.convert --saved-model ${saved_model} --output ${output_onnx} --tag serve
|
22 |
+
|
23 |
+
|
24 |
+
saved_model=./models/sp_v6
|
25 |
+
output_onnx=./models/sp_v6.onnx
|
26 |
+
python -m tf2onnx.convert --saved-model ${saved_model} --output ${output_onnx} --tag serve
|
27 |
+
|
third_party/omniglue/pyproject.toml
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
# Project metadata. Available keys are documented at:
|
3 |
+
# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
|
4 |
+
name = "omniglue"
|
5 |
+
description = "Official code release for CVPR'24 paper 'OmniGlue: Generalizable Feature Matching with Foundation Model Guidance"
|
6 |
+
readme = "README.md"
|
7 |
+
requires-python = ">=3.8"
|
8 |
+
license = {file = "LICENSE"}
|
9 |
+
authors = [{name = "OmniGlue authors"}]
|
10 |
+
classifiers = [ # List of https://pypi.org/classifiers/
|
11 |
+
"License :: OSI Approved :: Apache Software License",
|
12 |
+
"Intended Audience :: Science/Research",
|
13 |
+
]
|
14 |
+
keywords = ["feature matching"]
|
15 |
+
dynamic = ["version", "dependencies"]
|
16 |
+
|
17 |
+
# pip dependencies of the project
|
18 |
+
# Installed locally with `pip install -e .`
|
19 |
+
[tool.setuptools.dynamic]
|
20 |
+
dependencies = {file = ["requirements.txt"]}
|
21 |
+
|
22 |
+
[project.urls]
|
23 |
+
homepage = "https://github.com/google-research/omniglue"
|
24 |
+
repository = "https://github.com/google-research/omniglue"
|
25 |
+
changelog = "https://github.com/google-research/omniglue/blob/main/CHANGELOG.md"
|
26 |
+
# documentation = ""
|
27 |
+
|
28 |
+
[tool.setuptools.packages.find]
|
29 |
+
where = ["src", "third_party"]
|
30 |
+
include = ["omniglue*", "dinov2*"]
|
31 |
+
|
32 |
+
[project.optional-dependencies]
|
33 |
+
# Development deps (unittest, linting, formating,...)
|
34 |
+
# Installed through `pip install -e .[dev]`
|
35 |
+
dev = [
|
36 |
+
"pytest",
|
37 |
+
"pytest-xdist",
|
38 |
+
"pylint>=2.6.0",
|
39 |
+
"pyink",
|
40 |
+
]
|
41 |
+
|
42 |
+
[tool.pyink]
|
43 |
+
# Formatting configuration to follow Google style-guide
|
44 |
+
line-length = 80
|
45 |
+
unstable = true
|
46 |
+
pyink-indentation = 2
|
47 |
+
pyink-use-majority-quotes = true
|
48 |
+
|
49 |
+
[build-system]
|
50 |
+
# Build system specify which backend is used to build/install the project (flit,
|
51 |
+
# poetry, setuptools,...). All backends are supported by `pip install`
|
52 |
+
requires = ["setuptools", "wheel"]
|
53 |
+
build-backend = "setuptools.build_meta"
|
54 |
+
|
55 |
+
[tool.flit.sdist]
|
56 |
+
# Flit specific options (files to exclude from the PyPI package).
|
57 |
+
# If using another build backend (setuptools, poetry), you can remove this
|
58 |
+
# section.
|
59 |
+
exclude = [
|
60 |
+
# Do not release tests files on PyPI
|
61 |
+
"**/*_test.py",
|
62 |
+
]
|
third_party/omniglue/requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib
|
2 |
+
numpy
|
3 |
+
opencv-python
|
4 |
+
Pillow
|
5 |
+
torch
|
6 |
+
gdown
|
7 |
+
tf2onnx
|
8 |
+
onnxruntime
|
third_party/omniglue/res/demo1.jpg
ADDED
Git LFS Details
|
third_party/omniglue/res/demo2.jpg
ADDED
Git LFS Details
|
third_party/omniglue/res/demo_output.png
ADDED
Git LFS Details
|
third_party/omniglue/res/og_diagram.png
ADDED
Git LFS Details
|
third_party/omniglue/res/result_tf_and_onnx.png
ADDED
Git LFS Details
|
third_party/omniglue/src/omniglue/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from . import omniglue_extract
|
16 |
+
|
17 |
+
OmniGlue = omniglue_extract.OmniGlue
|
third_party/omniglue/src/omniglue/dino_extract.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Wrapper for performing DINOv2 inference."""
|
16 |
+
|
17 |
+
import cv2
|
18 |
+
import numpy as np
|
19 |
+
from third_party.dinov2 import dino
|
20 |
+
|
21 |
+
from . import utils
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
class DINOExtract:
|
26 |
+
"""Class to initialize DINO model and extract features from an image."""
|
27 |
+
|
28 |
+
def __init__(self, cpt_path: str, feature_layer: int = 1):
|
29 |
+
self.feature_layer = feature_layer
|
30 |
+
self.model = dino.vit_base()
|
31 |
+
state_dict_raw = torch.load(cpt_path, map_location="cpu")
|
32 |
+
|
33 |
+
# state_dict = {}
|
34 |
+
# for k, v in state_dict_raw.items():
|
35 |
+
# state_dict[k.replace('blocks', 'blocks.0')] = v
|
36 |
+
|
37 |
+
self.model.load_state_dict(state_dict_raw)
|
38 |
+
self.model.eval()
|
39 |
+
|
40 |
+
self.image_size_max = 630
|
41 |
+
|
42 |
+
self.h_down_rate = self.model.patch_embed.patch_size[0]
|
43 |
+
self.w_down_rate = self.model.patch_embed.patch_size[1]
|
44 |
+
|
45 |
+
def __call__(self, image) -> np.ndarray:
|
46 |
+
return self.forward(image)
|
47 |
+
|
48 |
+
def forward(self, image: np.ndarray) -> np.ndarray:
|
49 |
+
"""Feeds image through DINO ViT model to extract features.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
image: (H, W, 3) numpy array, decoded image bytes, value range [0, 255].
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
features: (H // 14, W // 14, C) numpy array image features.
|
56 |
+
"""
|
57 |
+
image = self._resize_input_image(image)
|
58 |
+
image_processed = self._process_image(image)
|
59 |
+
image_processed = image_processed.unsqueeze(0).float()
|
60 |
+
features = self.extract_feature(image_processed)
|
61 |
+
features = features.squeeze(0).permute(1, 2, 0).cpu().numpy()
|
62 |
+
return features
|
63 |
+
|
64 |
+
def _resize_input_image(
|
65 |
+
self, image: np.ndarray, interpolation=cv2.INTER_LINEAR
|
66 |
+
):
|
67 |
+
"""Resizes image such that both dimensions are divisble by down_rate."""
|
68 |
+
h_image, w_image = image.shape[:2]
|
69 |
+
h_larger_flag = h_image > w_image
|
70 |
+
large_side_image = max(h_image, w_image)
|
71 |
+
|
72 |
+
# resize the image with the largest side length smaller than a threshold
|
73 |
+
# to accelerate ViT backbone inference (which has quadratic complexity).
|
74 |
+
if large_side_image > self.image_size_max:
|
75 |
+
if h_larger_flag:
|
76 |
+
h_image_target = self.image_size_max
|
77 |
+
w_image_target = int(self.image_size_max * w_image / h_image)
|
78 |
+
else:
|
79 |
+
w_image_target = self.image_size_max
|
80 |
+
h_image_target = int(self.image_size_max * h_image / w_image)
|
81 |
+
else:
|
82 |
+
h_image_target = h_image
|
83 |
+
w_image_target = w_image
|
84 |
+
|
85 |
+
h, w = (
|
86 |
+
h_image_target // self.h_down_rate,
|
87 |
+
w_image_target // self.w_down_rate,
|
88 |
+
)
|
89 |
+
h_resize, w_resize = h * self.h_down_rate, w * self.w_down_rate
|
90 |
+
image = cv2.resize(
|
91 |
+
image, (w_resize, h_resize), interpolation=interpolation
|
92 |
+
)
|
93 |
+
return image
|
94 |
+
|
95 |
+
def _process_image(self, image: np.ndarray) -> torch.Tensor:
|
96 |
+
"""Turn image into pytorch tensor and normalize it."""
|
97 |
+
mean = np.array([0.485, 0.456, 0.406])
|
98 |
+
std = np.array([0.229, 0.224, 0.225])
|
99 |
+
|
100 |
+
image_processed = image / 255.0
|
101 |
+
image_processed = (image_processed - mean) / std
|
102 |
+
image_processed = torch.from_numpy(image_processed).permute(2, 0, 1)
|
103 |
+
return image_processed
|
104 |
+
|
105 |
+
def extract_feature(self, image):
|
106 |
+
"""Extracts features from image.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
image: (B, 3, H, W) torch tensor, normalized with ImageNet mean/std.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
features: (B, C, H//14, W//14) torch tensor image features.
|
113 |
+
"""
|
114 |
+
b, _, h_origin, w_origin = image.shape
|
115 |
+
out = self.model.get_intermediate_layers(image, n=self.feature_layer)[0]
|
116 |
+
h = int(h_origin / self.h_down_rate)
|
117 |
+
w = int(w_origin / self.w_down_rate)
|
118 |
+
dim = out.shape[-1]
|
119 |
+
out = out.reshape(b, h, w, dim).permute(0, 3, 1, 2).detach()
|
120 |
+
return out
|
121 |
+
|
122 |
+
|
123 |
+
def _preprocess_shape(
|
124 |
+
h_image, w_image, image_size_max=630, h_down_rate=14, w_down_rate=14
|
125 |
+
):
|
126 |
+
h_image = h_image.squeeze()
|
127 |
+
w_image = w_image.squeeze()
|
128 |
+
|
129 |
+
h_larger_flag = h_image > w_image
|
130 |
+
large_side_image = max(h_image, w_image)
|
131 |
+
|
132 |
+
def resize_h_larger():
|
133 |
+
h_image_target = image_size_max
|
134 |
+
w_image_target = int(image_size_max * w_image / h_image)
|
135 |
+
return h_image_target, w_image_target
|
136 |
+
|
137 |
+
def resize_w_larger_or_equal():
|
138 |
+
w_image_target = image_size_max
|
139 |
+
h_image_target = int(image_size_max * h_image / w_image)
|
140 |
+
return h_image_target, w_image_target
|
141 |
+
|
142 |
+
def keep_original():
|
143 |
+
return h_image, w_image
|
144 |
+
|
145 |
+
if large_side_image > image_size_max:
|
146 |
+
if h_larger_flag:
|
147 |
+
h_image_target, w_image_target = resize_h_larger()
|
148 |
+
else:
|
149 |
+
h_image_target, w_image_target = resize_w_larger_or_equal()
|
150 |
+
else:
|
151 |
+
h_image_target, w_image_target = keep_original()
|
152 |
+
|
153 |
+
h = h_image_target // h_down_rate
|
154 |
+
w = w_image_target // w_down_rate
|
155 |
+
h_resize = torch.tensor(h * h_down_rate)
|
156 |
+
w_resize = torch.tensor(w * w_down_rate)
|
157 |
+
|
158 |
+
h_resize = h_resize.unsqueeze(0)
|
159 |
+
w_resize = w_resize.unsqueeze(0)
|
160 |
+
|
161 |
+
return h_resize, w_resize
|
162 |
+
|
163 |
+
|
164 |
+
def get_dino_descriptors(dino_features, keypoints, height, width, feature_dim):
|
165 |
+
"""Get DINO descriptors using Superpoint keypoints.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
dino_features: DINO features in 1-D.
|
169 |
+
keypoints: Superpoint keypoint locations, in format (x, y), in pixels, shape
|
170 |
+
(N, 2).
|
171 |
+
height: image height, type torch int32.
|
172 |
+
width: image width, type torch int32.
|
173 |
+
feature_dim: DINO feature channel size, type torch int32.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
Interpolated DINO descriptors.
|
177 |
+
"""
|
178 |
+
height_1d = height.reshape([1])
|
179 |
+
width_1d = width.reshape([1])
|
180 |
+
|
181 |
+
height_1d_resized, width_1d_resized = _preprocess_shape(
|
182 |
+
height_1d, width_1d, image_size_max=630, h_down_rate=14, w_down_rate=14
|
183 |
+
)
|
184 |
+
|
185 |
+
height_feat = height_1d_resized // 14
|
186 |
+
width_feat = width_1d_resized // 14
|
187 |
+
feature_dim_1d = torch.tensor(feature_dim).reshape([1])
|
188 |
+
|
189 |
+
dino_features = dino_features.reshape(
|
190 |
+
height_feat, width_feat, feature_dim_1d
|
191 |
+
)
|
192 |
+
|
193 |
+
img_size = torch.cat([width_1d, height_1d], dim=0).float()
|
194 |
+
feature_size = torch.cat([width_feat, height_feat], dim=0).float()
|
195 |
+
keypoints_feature = (
|
196 |
+
keypoints[0] / img_size.unsqueeze(0) * feature_size.unsqueeze(0)
|
197 |
+
)
|
198 |
+
|
199 |
+
dino_descriptors = []
|
200 |
+
for kp in keypoints_feature:
|
201 |
+
dino_descriptors.append(
|
202 |
+
utils.lookup_descriptor_bilinear(kp.numpy(), dino_features)
|
203 |
+
)
|
204 |
+
dino_descriptors = torch.tensor(
|
205 |
+
np.array(dino_descriptors), dtype=torch.float32
|
206 |
+
)
|
207 |
+
return dino_descriptors
|
third_party/omniglue/src/omniglue/omniglue_extract.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Wrapper for performing OmniGlue inference, plus (optionally) SP/DINO."""
|
16 |
+
import cv2
|
17 |
+
import torch
|
18 |
+
import numpy as np
|
19 |
+
import onnxruntime
|
20 |
+
|
21 |
+
from . import dino_extract
|
22 |
+
from . import superpoint_extract
|
23 |
+
from . import utils
|
24 |
+
|
25 |
+
|
26 |
+
DINO_FEATURE_DIM = 768
|
27 |
+
MATCH_THRESHOLD = 1e-3
|
28 |
+
|
29 |
+
|
30 |
+
class OmniGlue:
|
31 |
+
# TODO(omniglue): class docstring
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
og_export: str,
|
36 |
+
sp_export: str | None = None,
|
37 |
+
dino_export: str | None = None,
|
38 |
+
max_keypoints: int = 2048,
|
39 |
+
) -> None:
|
40 |
+
self.max_keypoints = max_keypoints
|
41 |
+
self.matcher = onnxruntime.InferenceSession(og_export)
|
42 |
+
if sp_export is not None:
|
43 |
+
self.sp_extract = superpoint_extract.SuperPointExtract(sp_export)
|
44 |
+
if dino_export is not None:
|
45 |
+
self.dino_extract = dino_extract.DINOExtract(
|
46 |
+
dino_export, feature_layer=1
|
47 |
+
)
|
48 |
+
|
49 |
+
def FindMatches(self, image0: np.ndarray, image1: np.ndarray):
|
50 |
+
"""TODO(omniglue): docstring."""
|
51 |
+
height0, width0 = image0.shape[:2]
|
52 |
+
height1, width1 = image1.shape[:2]
|
53 |
+
# TODO: numpy to torch inputs
|
54 |
+
sp_features0 = self.sp_extract(image0, num_features=self.max_keypoints)
|
55 |
+
sp_features1 = self.sp_extract(image1, num_features=self.max_keypoints)
|
56 |
+
dino_features0 = self.dino_extract(image0)
|
57 |
+
dino_features1 = self.dino_extract(image1)
|
58 |
+
dino_descriptors0 = dino_extract.get_dino_descriptors(
|
59 |
+
dino_features0,
|
60 |
+
sp_features0,
|
61 |
+
torch.tensor(height0),
|
62 |
+
torch.tensor(width0),
|
63 |
+
DINO_FEATURE_DIM,
|
64 |
+
)
|
65 |
+
dino_descriptors1 = dino_extract.get_dino_descriptors(
|
66 |
+
dino_features1,
|
67 |
+
sp_features1,
|
68 |
+
torch.tensor(height1),
|
69 |
+
torch.tensor(width1),
|
70 |
+
DINO_FEATURE_DIM,
|
71 |
+
)
|
72 |
+
|
73 |
+
inputs = self._construct_inputs(
|
74 |
+
width0,
|
75 |
+
height0,
|
76 |
+
width1,
|
77 |
+
height1,
|
78 |
+
sp_features0,
|
79 |
+
sp_features1,
|
80 |
+
dino_descriptors0,
|
81 |
+
dino_descriptors1,
|
82 |
+
)
|
83 |
+
|
84 |
+
og_outputs = self.matcher.run(None, inputs)
|
85 |
+
soft_assignment = torch.from_numpy(og_outputs[0][:, :-1, :-1])
|
86 |
+
|
87 |
+
match_matrix = (
|
88 |
+
utils.soft_assignment_to_match_matrix(
|
89 |
+
soft_assignment, MATCH_THRESHOLD
|
90 |
+
)
|
91 |
+
.numpy()
|
92 |
+
.squeeze()
|
93 |
+
)
|
94 |
+
|
95 |
+
# Filter out any matches with 0.0 confidence keypoints.
|
96 |
+
match_indices = np.argwhere(match_matrix)
|
97 |
+
keep = []
|
98 |
+
for i in range(match_indices.shape[0]):
|
99 |
+
match = match_indices[i, :]
|
100 |
+
if (sp_features0[2][match[0]] > 0.0) and (
|
101 |
+
sp_features1[2][match[1]] > 0.0
|
102 |
+
):
|
103 |
+
keep.append(i)
|
104 |
+
match_indices = match_indices[keep]
|
105 |
+
|
106 |
+
# Format matches in terms of keypoint locations.
|
107 |
+
match_kp0s = []
|
108 |
+
match_kp1s = []
|
109 |
+
match_confidences = []
|
110 |
+
for match in match_indices:
|
111 |
+
match_kp0s.append(sp_features0[0][match[0], :])
|
112 |
+
match_kp1s.append(sp_features1[0][match[1], :])
|
113 |
+
match_confidences.append(soft_assignment[0, match[0], match[1]])
|
114 |
+
match_kp0s = np.array(match_kp0s)
|
115 |
+
match_kp1s = np.array(match_kp1s)
|
116 |
+
match_confidences = np.array(match_confidences)
|
117 |
+
return match_kp0s, match_kp1s, match_confidences
|
118 |
+
|
119 |
+
### Private methods ###
|
120 |
+
|
121 |
+
def _construct_inputs(
|
122 |
+
self,
|
123 |
+
width0,
|
124 |
+
height0,
|
125 |
+
width1,
|
126 |
+
height1,
|
127 |
+
sp_features0,
|
128 |
+
sp_features1,
|
129 |
+
dino_descriptors0,
|
130 |
+
dino_descriptors1,
|
131 |
+
):
|
132 |
+
keypoints0 = sp_features0[0]
|
133 |
+
keypoints1 = sp_features1[0]
|
134 |
+
descriptors0 = sp_features0[1]
|
135 |
+
descriptors1 = sp_features1[1]
|
136 |
+
scores0 = sp_features0[2]
|
137 |
+
scores1 = sp_features1[2]
|
138 |
+
descriptors0_dino = dino_descriptors0
|
139 |
+
descriptors1_dino = dino_descriptors1
|
140 |
+
if isinstance(keypoints0, torch.Tensor):
|
141 |
+
keypoints0 = keypoints0.detach().numpy()
|
142 |
+
if isinstance(keypoints1, torch.Tensor):
|
143 |
+
keypoints1 = keypoints1.detach().numpy()
|
144 |
+
if isinstance(descriptors0, torch.Tensor):
|
145 |
+
descriptors0 = descriptors0.detach().numpy()
|
146 |
+
if isinstance(descriptors1, torch.Tensor):
|
147 |
+
descriptors1 = descriptors1.detach().numpy()
|
148 |
+
if isinstance(scores0, torch.Tensor):
|
149 |
+
scores0 = scores0.detach().numpy()
|
150 |
+
if isinstance(scores1, torch.Tensor):
|
151 |
+
scores1 = scores1.detach().numpy()
|
152 |
+
if isinstance(descriptors0_dino, torch.Tensor):
|
153 |
+
descriptors0_dino = descriptors0_dino.detach().numpy()
|
154 |
+
if isinstance(descriptors1_dino, torch.Tensor):
|
155 |
+
descriptors1_dino = descriptors1_dino.detach().numpy()
|
156 |
+
inputs = {
|
157 |
+
"keypoints0": np.expand_dims(keypoints0, axis=0).astype(np.float32),
|
158 |
+
"keypoints1": np.expand_dims(keypoints1, axis=0).astype(np.float32),
|
159 |
+
"descriptors0": np.expand_dims(descriptors0, axis=0).astype(
|
160 |
+
np.float32
|
161 |
+
),
|
162 |
+
"descriptors1": np.expand_dims(descriptors1, axis=0).astype(
|
163 |
+
np.float32
|
164 |
+
),
|
165 |
+
"scores0": np.expand_dims(
|
166 |
+
np.expand_dims(scores0, axis=0), axis=-1
|
167 |
+
).astype(np.float32),
|
168 |
+
"scores1": np.expand_dims(
|
169 |
+
np.expand_dims(scores1, axis=0), axis=-1
|
170 |
+
).astype(np.float32),
|
171 |
+
"descriptors0_dino": np.expand_dims(descriptors0_dino, axis=0),
|
172 |
+
"descriptors1_dino": np.expand_dims(descriptors1_dino, axis=0),
|
173 |
+
"width0": np.expand_dims(width0, axis=0).astype(np.int32),
|
174 |
+
"width1": np.expand_dims(width1, axis=0).astype(np.int32),
|
175 |
+
"height0": np.expand_dims(height0, axis=0).astype(np.int32),
|
176 |
+
"height1": np.expand_dims(height1, axis=0).astype(np.int32),
|
177 |
+
}
|
178 |
+
return inputs
|
third_party/omniglue/src/omniglue/superpoint_extract.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Wrapper for performing SuperPoint inference."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
from typing import Optional, Tuple
|
19 |
+
|
20 |
+
import cv2
|
21 |
+
import numpy as np
|
22 |
+
from . import utils
|
23 |
+
import onnxruntime
|
24 |
+
|
25 |
+
|
26 |
+
class SuperPointExtract:
|
27 |
+
"""Class to initialize SuperPoint model and extract features from an image.
|
28 |
+
|
29 |
+
To stay consistent with SuperPoint training and eval configurations, resize
|
30 |
+
images to (320x240) or (640x480).
|
31 |
+
|
32 |
+
Attributes
|
33 |
+
model_path: string, filepath to saved SuperPoint ONNX model weights.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, model_path: str):
|
37 |
+
self.model_path = model_path
|
38 |
+
self.net = onnxruntime.InferenceSession(self.model_path)
|
39 |
+
|
40 |
+
def __call__(
|
41 |
+
self,
|
42 |
+
image,
|
43 |
+
segmentation_mask=None,
|
44 |
+
num_features=1024,
|
45 |
+
pad_random_features=False,
|
46 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
47 |
+
return self.compute(
|
48 |
+
image,
|
49 |
+
segmentation_mask=segmentation_mask,
|
50 |
+
num_features=num_features,
|
51 |
+
pad_random_features=pad_random_features,
|
52 |
+
)
|
53 |
+
|
54 |
+
def compute(
|
55 |
+
self,
|
56 |
+
image: np.ndarray,
|
57 |
+
segmentation_mask: Optional[np.ndarray] = None,
|
58 |
+
num_features: int = 1024,
|
59 |
+
pad_random_features: bool = False,
|
60 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
61 |
+
"""Feeds image through SuperPoint model to extract keypoints and features.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
image: (H, W, 3) numpy array, decoded image bytes.
|
65 |
+
segmentation_mask: (H, W) binary numpy array or None. If not None,
|
66 |
+
extracted keypoints are restricted to being within the mask.
|
67 |
+
num_features: max number of features to extract (or 0 to indicate keeping
|
68 |
+
all extracted features).
|
69 |
+
pad_random_features: if True, adds randomly sampled keypoints to the
|
70 |
+
output such that there are exactly 'num_features' keypoints. Descriptors
|
71 |
+
for these sampled keypoints are taken from the network's descriptor map
|
72 |
+
output, and scores are set to 0. No action taken if num_features = 0.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
keypoints: (N, 2) numpy array, coordinates of keypoints as floats.
|
76 |
+
descriptors: (N, 256) numpy array, descriptors for keypoints as floats.
|
77 |
+
scores: (N, 1) numpy array, confidence values for keypoints as floats.
|
78 |
+
"""
|
79 |
+
|
80 |
+
# Resize image so both dimensions are divisible by 8.
|
81 |
+
image, keypoint_scale_factors = self._resize_input_image(image)
|
82 |
+
if segmentation_mask is not None:
|
83 |
+
segmentation_mask, _ = self._resize_input_image(
|
84 |
+
segmentation_mask, interpolation=cv2.INTER_NEAREST
|
85 |
+
)
|
86 |
+
assert (
|
87 |
+
segmentation_mask is None
|
88 |
+
or image.shape[:2] == segmentation_mask.shape[:2]
|
89 |
+
)
|
90 |
+
|
91 |
+
# Preprocess and feed-forward image.
|
92 |
+
image_preprocessed = self._preprocess_image(image)
|
93 |
+
out = self.net.run(
|
94 |
+
None,
|
95 |
+
{
|
96 |
+
self.net.get_inputs()[0].name: np.expand_dims(
|
97 |
+
image_preprocessed, 0
|
98 |
+
)
|
99 |
+
},
|
100 |
+
)
|
101 |
+
# Format output from network.
|
102 |
+
keypoint_map = np.squeeze(out[5])
|
103 |
+
descriptor_map = np.squeeze(out[0])
|
104 |
+
if segmentation_mask is not None:
|
105 |
+
keypoint_map = np.where(segmentation_mask, keypoint_map, 0.0)
|
106 |
+
keypoints, descriptors, scores = self._extract_superpoint_output(
|
107 |
+
keypoint_map, descriptor_map, num_features, pad_random_features
|
108 |
+
)
|
109 |
+
|
110 |
+
# Rescale keypoint locations to match original input image size, and return.
|
111 |
+
keypoints = keypoints / keypoint_scale_factors
|
112 |
+
return (keypoints, descriptors, scores)
|
113 |
+
|
114 |
+
def _resize_input_image(self, image, interpolation=cv2.INTER_LINEAR):
|
115 |
+
"""Resizes image such that both dimensions are divisble by 8."""
|
116 |
+
|
117 |
+
# Calculate new image dimensions and per-dimension resizing scale factor.
|
118 |
+
new_dim = [-1, -1]
|
119 |
+
keypoint_scale_factors = [1.0, 1.0]
|
120 |
+
for i in range(2):
|
121 |
+
dim_size = image.shape[i]
|
122 |
+
mod_eight = dim_size % 8
|
123 |
+
if mod_eight < 4:
|
124 |
+
# Round down to nearest multiple of 8.
|
125 |
+
new_dim[i] = dim_size - mod_eight
|
126 |
+
elif mod_eight >= 4:
|
127 |
+
# Round up to nearest multiple of 8.
|
128 |
+
new_dim[i] = dim_size + (8 - mod_eight)
|
129 |
+
keypoint_scale_factors[i] = (new_dim[i] - 1) / (dim_size - 1)
|
130 |
+
|
131 |
+
# Resize and return image + scale factors.
|
132 |
+
new_dim = new_dim[::-1] # Convert from (row, col) to (x,y).
|
133 |
+
keypoint_scale_factors = keypoint_scale_factors[::-1]
|
134 |
+
image = cv2.resize(image, tuple(new_dim), interpolation=interpolation)
|
135 |
+
return image, keypoint_scale_factors
|
136 |
+
|
137 |
+
def _preprocess_image(self, image):
|
138 |
+
"""Converts image to grayscale and normalizes values for model input."""
|
139 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
140 |
+
image = np.expand_dims(image, 2)
|
141 |
+
image = image.astype(np.float32)
|
142 |
+
image = image / 255.0
|
143 |
+
return image
|
144 |
+
|
145 |
+
def _extract_superpoint_output(
|
146 |
+
self,
|
147 |
+
keypoint_map,
|
148 |
+
descriptor_map,
|
149 |
+
keep_k_points=512,
|
150 |
+
pad_random_features=False,
|
151 |
+
):
|
152 |
+
"""Converts from raw SuperPoint output (feature maps) into numpy arrays.
|
153 |
+
|
154 |
+
If keep_k_points is 0, then keep all detected keypoints. Otherwise, sort by
|
155 |
+
confidence and keep only the top k confidence keypoints.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
keypoint_map: (H, W, 1) numpy array, raw output confidence values from
|
159 |
+
SuperPoint model.
|
160 |
+
descriptor_map: (H, W, 256) numpy array, raw output descriptors from
|
161 |
+
SuperPoint model.
|
162 |
+
keep_k_points: int, number of keypoints to keep (or 0 to indicate keeping
|
163 |
+
all detected keypoints).
|
164 |
+
pad_random_features: if True, adds randomly sampled keypoints to the
|
165 |
+
output such that there are exactly 'num_features' keypoints. Descriptors
|
166 |
+
for these sampled keypoints are taken from the network's descriptor map
|
167 |
+
output, and scores are set to 0. No action taken if keep_k_points = 0.
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
keypoints: (N, 2) numpy array, image coordinates (x, y) of keypoints as
|
171 |
+
floats.
|
172 |
+
descriptors: (N, 256) numpy array, descriptors for keypoints as floats.
|
173 |
+
scores: (N, 1) numpy array, confidence values for keypoints as floats.
|
174 |
+
"""
|
175 |
+
|
176 |
+
def _select_k_best(points, k):
|
177 |
+
sorted_prob = points[points[:, 2].argsort(), :]
|
178 |
+
start = min(k, points.shape[0])
|
179 |
+
return sorted_prob[-start:, :2], sorted_prob[-start:, 2]
|
180 |
+
|
181 |
+
keypoints = np.where(keypoint_map > 0)
|
182 |
+
prob = keypoint_map[keypoints[0], keypoints[1]]
|
183 |
+
keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1)
|
184 |
+
|
185 |
+
# Keep only top k points, or all points if keep_k_points param is 0.
|
186 |
+
if keep_k_points == 0:
|
187 |
+
keep_k_points = keypoints.shape[0]
|
188 |
+
keypoints, scores = _select_k_best(keypoints, keep_k_points)
|
189 |
+
|
190 |
+
# Optionally, pad with random features (and confidence scores of 0).
|
191 |
+
image_shape = np.array(keypoint_map.shape[:2])
|
192 |
+
if pad_random_features and (keep_k_points > keypoints.shape[0]):
|
193 |
+
num_pad = keep_k_points - keypoints.shape[0]
|
194 |
+
keypoints_pad = (image_shape - 1) * np.random.uniform(
|
195 |
+
size=(num_pad, 2)
|
196 |
+
)
|
197 |
+
keypoints = np.concatenate((keypoints, keypoints_pad))
|
198 |
+
scores_pad = np.zeros((num_pad))
|
199 |
+
scores = np.concatenate((scores, scores_pad))
|
200 |
+
|
201 |
+
# Lookup descriptors via bilinear interpolation.
|
202 |
+
# TODO: batch descriptor lookup with bilinear interpolation.
|
203 |
+
keypoints[:, [0, 1]] = keypoints[
|
204 |
+
:, [1, 0]
|
205 |
+
] # Swap from (row,col) to (x,y).
|
206 |
+
descriptors = []
|
207 |
+
for kp in keypoints:
|
208 |
+
descriptors.append(
|
209 |
+
utils.lookup_descriptor_bilinear(kp, descriptor_map)
|
210 |
+
)
|
211 |
+
descriptors = np.array(descriptors)
|
212 |
+
return keypoints, descriptors, scores
|
third_party/omniglue/src/omniglue/utils.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Shared utility functions for OmniGlue."""
|
16 |
+
import cv2
|
17 |
+
import torch
|
18 |
+
import math
|
19 |
+
import numpy as np
|
20 |
+
from typing import Optional
|
21 |
+
|
22 |
+
|
23 |
+
def lookup_descriptor_bilinear(
|
24 |
+
keypoint: np.ndarray, descriptor_map: np.ndarray
|
25 |
+
) -> np.ndarray:
|
26 |
+
"""Looks up descriptor value for keypoint from a dense descriptor map.
|
27 |
+
|
28 |
+
Uses bilinear interpolation to find descriptor value at non-integer
|
29 |
+
positions.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
keypoint: 2-dim numpy array containing (x, y) keypoint image coordinates.
|
33 |
+
descriptor_map: (H, W, D) numpy array representing a dense descriptor map.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
D-dim descriptor value at the input 'keypoint' location.
|
37 |
+
|
38 |
+
Raises:
|
39 |
+
ValueError, if kepoint position is out of bounds.
|
40 |
+
"""
|
41 |
+
height, width = descriptor_map.shape[:2]
|
42 |
+
if (
|
43 |
+
keypoint[0] < 0
|
44 |
+
or keypoint[0] > width
|
45 |
+
or keypoint[1] < 0
|
46 |
+
or keypoint[1] > height
|
47 |
+
):
|
48 |
+
raise ValueError(
|
49 |
+
"Keypoint position (%f, %f) is out of descriptor map bounds (%i w x"
|
50 |
+
" %i h)." % (keypoint[0], keypoint[1], width, height)
|
51 |
+
)
|
52 |
+
|
53 |
+
x_range = [math.floor(keypoint[0])]
|
54 |
+
if not keypoint[0].is_integer() and keypoint[0] < width - 1:
|
55 |
+
x_range.append(x_range[0] + 1)
|
56 |
+
y_range = [math.floor(keypoint[1])]
|
57 |
+
if not keypoint[1].is_integer() and keypoint[1] < height - 1:
|
58 |
+
y_range.append(y_range[0] + 1)
|
59 |
+
|
60 |
+
bilinear_descriptor = np.zeros(descriptor_map.shape[2])
|
61 |
+
for curr_x in x_range:
|
62 |
+
for curr_y in y_range:
|
63 |
+
curr_descriptor = descriptor_map[curr_y, curr_x, :]
|
64 |
+
bilinear_scalar = (1.0 - abs(keypoint[0] - curr_x)) * (
|
65 |
+
1.0 - abs(keypoint[1] - curr_y)
|
66 |
+
)
|
67 |
+
bilinear_descriptor += bilinear_scalar * curr_descriptor
|
68 |
+
return bilinear_descriptor
|
69 |
+
|
70 |
+
|
71 |
+
def soft_assignment_to_match_matrix(
|
72 |
+
soft_assignment: torch.Tensor, match_threshold: float
|
73 |
+
) -> torch.Tensor:
|
74 |
+
"""Converts a matrix of soft assignment values to binary yes/no match matrix.
|
75 |
+
|
76 |
+
Searches soft_assignment for row- and column-maximum values, which indicate
|
77 |
+
mutual nearest neighbor matches between two unique sets of keypoints. Also,
|
78 |
+
ensures that score values for matches are above the specified threshold.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
soft_assignment: (B, N, M) tensor, contains matching likelihood value
|
82 |
+
between features of different sets. N is number of features in image0, and
|
83 |
+
M is number of features in image1. Higher value indicates more likely to
|
84 |
+
match.
|
85 |
+
match_threshold: float, thresholding value to consider a match valid.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
(B, N, M) tensor of binary values. A value of 1 at index (x, y) indicates
|
89 |
+
a match between index 'x' (out of N) in image0 and index 'y' (out of M) in
|
90 |
+
image 1.
|
91 |
+
"""
|
92 |
+
|
93 |
+
def _range_like(x, dim):
|
94 |
+
return torch.arange(x.shape[dim], dtype=x.dtype)
|
95 |
+
|
96 |
+
matches = []
|
97 |
+
for i in range(soft_assignment.shape[0]):
|
98 |
+
scores = soft_assignment[i, :].unsqueeze(0)
|
99 |
+
|
100 |
+
max0 = torch.max(scores, dim=2)[0]
|
101 |
+
indices0 = torch.argmax(scores, dim=2)
|
102 |
+
indices1 = torch.argmax(scores, dim=1)
|
103 |
+
|
104 |
+
mutual = _range_like(indices0, 1).unsqueeze(0) == indices1.gather(
|
105 |
+
1, indices0
|
106 |
+
)
|
107 |
+
|
108 |
+
kp_ind_pairs = torch.stack(
|
109 |
+
[_range_like(indices0, 1), indices0.squeeze()], dim=1
|
110 |
+
)
|
111 |
+
mutual_max0 = torch.where(
|
112 |
+
mutual, max0, torch.zeros_like(max0)
|
113 |
+
).squeeze()
|
114 |
+
sparse = torch.sparse_coo_tensor(
|
115 |
+
kp_ind_pairs.t(), mutual_max0, scores.shape[1:]
|
116 |
+
)
|
117 |
+
match_matrix = sparse.to_dense()
|
118 |
+
matches.append(match_matrix)
|
119 |
+
|
120 |
+
match_matrix = torch.stack(matches)
|
121 |
+
match_matrix = match_matrix > match_threshold
|
122 |
+
return match_matrix
|
123 |
+
|
124 |
+
|
125 |
+
def visualize_matches(
|
126 |
+
image0: np.ndarray,
|
127 |
+
image1: np.ndarray,
|
128 |
+
kp0: np.ndarray,
|
129 |
+
kp1: np.ndarray,
|
130 |
+
match_matrix: np.ndarray,
|
131 |
+
match_labels: Optional[np.ndarray] = None,
|
132 |
+
show_keypoints: bool = False,
|
133 |
+
highlight_unmatched: bool = False,
|
134 |
+
title: Optional[str] = None,
|
135 |
+
line_width: int = 1,
|
136 |
+
circle_radius: int = 4,
|
137 |
+
circle_thickness: int = 2,
|
138 |
+
rng: Optional["np.random.Generator"] = None,
|
139 |
+
):
|
140 |
+
"""Generates visualization of keypoints and matches for two images.
|
141 |
+
|
142 |
+
Stacks image0 and image1 horizontally. In case the two images have different
|
143 |
+
heights, scales image1 (and its keypoints) to match image0's height. Note
|
144 |
+
that keypoints must be in (x, y) format, NOT (row, col). If match_matrix
|
145 |
+
includes unmatched dustbins, the dustbins will be removed before visualizing
|
146 |
+
matches.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
image0: (H, W, 3) array containing image0 contents.
|
150 |
+
image1: (H, W, 3) array containing image1 contents.
|
151 |
+
kp0: (N, 2) array where each row represents (x, y) coordinates of keypoints
|
152 |
+
in image0.
|
153 |
+
kp1: (M, 2) array, where each row represents (x, y) coordinates of keypoints
|
154 |
+
in image1.
|
155 |
+
match_matrix: (N, M) binary array, where values are non-zero for keypoint
|
156 |
+
indices making up a match.
|
157 |
+
match_labels: (N, M) binary array, where values are non-zero for keypoint
|
158 |
+
indices making up a ground-truth match. When None, matches from
|
159 |
+
'match_matrix' are colored randomly. Otherwise, matches from
|
160 |
+
'match_matrix' are colored according to accuracy (compared to labels).
|
161 |
+
show_keypoints: if True, all image0 and image1 keypoints (including
|
162 |
+
unmatched ones) are visualized.
|
163 |
+
highlight_unmatched: if True, highlights unmatched keypoints in blue.
|
164 |
+
title: if not None, adds title text to top left of visualization.
|
165 |
+
line_width: width of correspondence line, in pixels.
|
166 |
+
circle_radius: radius of keypoint circles, if visualized.
|
167 |
+
circle_thickness: thickness of keypoint circles, if visualized.
|
168 |
+
rng: np random number generator to generate the line colors.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
Numpy array of image0 and image1 side-by-side, with lines between matches
|
172 |
+
according to match_matrix. If show_keypoints is True, keypoints from both
|
173 |
+
images are also visualized.
|
174 |
+
"""
|
175 |
+
# initialize RNG
|
176 |
+
if rng is None:
|
177 |
+
rng = np.random.default_rng()
|
178 |
+
|
179 |
+
# Make copy of input param that may be modified in this function.
|
180 |
+
kp1 = np.copy(kp1)
|
181 |
+
|
182 |
+
# Detect unmatched dustbins.
|
183 |
+
has_unmatched_dustbins = (match_matrix.shape[0] == kp0.shape[0] + 1) and (
|
184 |
+
match_matrix.shape[1] == kp1.shape[0] + 1
|
185 |
+
)
|
186 |
+
|
187 |
+
# If necessary, resize image1 so that the pair can be stacked horizontally.
|
188 |
+
height0 = image0.shape[0]
|
189 |
+
height1 = image1.shape[0]
|
190 |
+
if height0 != height1:
|
191 |
+
scale_factor = height0 / height1
|
192 |
+
if scale_factor <= 1.0:
|
193 |
+
interp_method = cv2.INTER_AREA
|
194 |
+
else:
|
195 |
+
interp_method = cv2.INTER_LINEAR
|
196 |
+
new_dim1 = (int(image1.shape[1] * scale_factor), height0)
|
197 |
+
image1 = cv2.resize(image1, new_dim1, interpolation=interp_method)
|
198 |
+
kp1 *= scale_factor
|
199 |
+
|
200 |
+
# Create side-by-side image and add lines for all matches.
|
201 |
+
viz = cv2.hconcat([image0, image1])
|
202 |
+
w0 = image0.shape[1]
|
203 |
+
matches = np.argwhere(
|
204 |
+
match_matrix[:-1, :-1] if has_unmatched_dustbins else match_matrix
|
205 |
+
)
|
206 |
+
for match in matches:
|
207 |
+
mpt0 = kp0[match[0]]
|
208 |
+
mpt1 = kp1[match[1]]
|
209 |
+
if isinstance(mpt0, torch.Tensor):
|
210 |
+
mpt0 = mpt0.numpy()
|
211 |
+
if isinstance(mpt1, torch.Tensor):
|
212 |
+
mpt1 = mpt1.numpy()
|
213 |
+
pt0 = (int(mpt0[0]), int(mpt0[1]))
|
214 |
+
pt1 = (int(mpt1[0] + w0), int(mpt1[1]))
|
215 |
+
if match_labels is None:
|
216 |
+
color = tuple(rng.integers(0, 255, size=3).tolist())
|
217 |
+
else:
|
218 |
+
if match_labels[match[0], match[1]]:
|
219 |
+
color = (0, 255, 0)
|
220 |
+
else:
|
221 |
+
color = (255, 0, 0)
|
222 |
+
cv2.line(viz, pt0, pt1, color, line_width)
|
223 |
+
|
224 |
+
# Optionally, add circles to output image to represent each keypoint.
|
225 |
+
if show_keypoints:
|
226 |
+
for i in range(np.shape(kp0)[0]):
|
227 |
+
kp = kp0[i].numpy() if isinstance(kp0[i], torch.Tensor) else kp0[i]
|
228 |
+
if (
|
229 |
+
highlight_unmatched
|
230 |
+
and has_unmatched_dustbins
|
231 |
+
and match_matrix[i, -1]
|
232 |
+
):
|
233 |
+
cv2.circle(
|
234 |
+
viz,
|
235 |
+
tuple(kp.astype(np.int32).tolist()),
|
236 |
+
circle_radius,
|
237 |
+
(255, 0, 0),
|
238 |
+
circle_thickness,
|
239 |
+
)
|
240 |
+
else:
|
241 |
+
cv2.circle(
|
242 |
+
viz,
|
243 |
+
tuple(kp.astype(np.int32).tolist()),
|
244 |
+
circle_radius,
|
245 |
+
(0, 0, 255),
|
246 |
+
circle_thickness,
|
247 |
+
)
|
248 |
+
for j in range(np.shape(kp1)[0]):
|
249 |
+
kp = kp1[j].numpy() if isinstance(kp1[j], torch.Tensor) else kp1[j]
|
250 |
+
kp[0] += w0
|
251 |
+
if (
|
252 |
+
highlight_unmatched
|
253 |
+
and has_unmatched_dustbins
|
254 |
+
and match_matrix[-1, j]
|
255 |
+
):
|
256 |
+
cv2.circle(
|
257 |
+
viz,
|
258 |
+
tuple(kp.astype(np.int32).tolist()),
|
259 |
+
circle_radius,
|
260 |
+
(255, 0, 0),
|
261 |
+
circle_thickness,
|
262 |
+
)
|
263 |
+
else:
|
264 |
+
cv2.circle(
|
265 |
+
viz,
|
266 |
+
tuple(kp.astype(np.int32).tolist()),
|
267 |
+
circle_radius,
|
268 |
+
(0, 0, 255),
|
269 |
+
circle_thickness,
|
270 |
+
)
|
271 |
+
if title is not None:
|
272 |
+
viz = cv2.putText(
|
273 |
+
viz,
|
274 |
+
title,
|
275 |
+
(5, 30),
|
276 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
277 |
+
1,
|
278 |
+
(0, 0, 255),
|
279 |
+
2,
|
280 |
+
cv2.LINE_AA,
|
281 |
+
)
|
282 |
+
return viz
|
third_party/omniglue/third_party/dinov2/__init__.py
ADDED
File without changes
|
third_party/omniglue/third_party/dinov2/dino.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
from functools import partial
|
12 |
+
import math
|
13 |
+
from typing import Callable, Sequence, Tuple, Union
|
14 |
+
|
15 |
+
from third_party.dinov2 import dino_utils
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
|
22 |
+
def named_apply(
|
23 |
+
fn: Callable,
|
24 |
+
module: nn.Module,
|
25 |
+
name="",
|
26 |
+
depth_first=True,
|
27 |
+
include_root=False,
|
28 |
+
) -> nn.Module:
|
29 |
+
if not depth_first and include_root:
|
30 |
+
fn(module=module, name=name)
|
31 |
+
for child_name, child_module in module.named_children():
|
32 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
33 |
+
named_apply(
|
34 |
+
fn=fn,
|
35 |
+
module=child_module,
|
36 |
+
name=child_name,
|
37 |
+
depth_first=depth_first,
|
38 |
+
include_root=True,
|
39 |
+
)
|
40 |
+
if depth_first and include_root:
|
41 |
+
fn(module=module, name=name)
|
42 |
+
return module
|
43 |
+
|
44 |
+
|
45 |
+
class BlockChunk(nn.ModuleList):
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
for b in self:
|
49 |
+
x = b(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class DinoVisionTransformer(nn.Module):
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
img_size=518,
|
58 |
+
patch_size=16,
|
59 |
+
in_chans=3,
|
60 |
+
embed_dim=768,
|
61 |
+
depth=12,
|
62 |
+
num_heads=12,
|
63 |
+
mlp_ratio=4.0,
|
64 |
+
qkv_bias=True,
|
65 |
+
ffn_bias=True,
|
66 |
+
proj_bias=True,
|
67 |
+
drop_path_rate=0.0,
|
68 |
+
drop_path_uniform=False,
|
69 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
70 |
+
embed_layer=dino_utils.PatchEmbed,
|
71 |
+
act_layer=nn.GELU,
|
72 |
+
block_fn=dino_utils.Block,
|
73 |
+
ffn_layer="mlp",
|
74 |
+
block_chunks=0,
|
75 |
+
):
|
76 |
+
"""Args:
|
77 |
+
|
78 |
+
img_size (int, tuple): input image size
|
79 |
+
patch_size (int, tuple): patch size
|
80 |
+
in_chans (int): number of input channels
|
81 |
+
embed_dim (int): embedding dimension
|
82 |
+
depth (int): depth of transformer
|
83 |
+
num_heads (int): number of attention heads
|
84 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
85 |
+
qkv_bias (bool): enable bias for qkv if True
|
86 |
+
proj_bias (bool): enable bias for proj in attn if True
|
87 |
+
ffn_bias (bool): enable bias for ffn if True
|
88 |
+
drop_path_rate (float): stochastic depth rate
|
89 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
90 |
+
weight_init (str): weight init scheme
|
91 |
+
init_values (float): layer-scale init values
|
92 |
+
embed_layer (nn.Module): patch embedding layer
|
93 |
+
act_layer (nn.Module): MLP activation layer
|
94 |
+
block_fn (nn.Module): transformer block class
|
95 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
96 |
+
block_chunks: (int) split block sequence into block_chunks units for
|
97 |
+
FSDP wrap
|
98 |
+
"""
|
99 |
+
super().__init__()
|
100 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
101 |
+
|
102 |
+
self.num_features = self.embed_dim = (
|
103 |
+
embed_dim # num_features for consistency with other models
|
104 |
+
)
|
105 |
+
self.num_tokens = 1
|
106 |
+
self.n_blocks = depth
|
107 |
+
self.num_heads = num_heads
|
108 |
+
self.patch_size = patch_size
|
109 |
+
|
110 |
+
self.patch_embed = embed_layer(
|
111 |
+
img_size=img_size,
|
112 |
+
patch_size=patch_size,
|
113 |
+
in_chans=in_chans,
|
114 |
+
embed_dim=embed_dim,
|
115 |
+
)
|
116 |
+
num_patches = self.patch_embed.num_patches
|
117 |
+
|
118 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
119 |
+
self.pos_embed = nn.Parameter(
|
120 |
+
torch.zeros(1, num_patches + self.num_tokens, embed_dim)
|
121 |
+
)
|
122 |
+
|
123 |
+
if drop_path_uniform is True:
|
124 |
+
dpr = [drop_path_rate] * depth
|
125 |
+
else:
|
126 |
+
dpr = [
|
127 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
128 |
+
] # stochastic depth decay rule
|
129 |
+
|
130 |
+
if ffn_layer == "mlp":
|
131 |
+
ffn_layer = dino_utils.Mlp
|
132 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
133 |
+
# ffn_layer = SwiGLUFFNFused
|
134 |
+
raise NotImplementedError("FFN only support mlp but using swiglu")
|
135 |
+
elif ffn_layer == "identity":
|
136 |
+
|
137 |
+
def f(*args, **kwargs):
|
138 |
+
return nn.Identity()
|
139 |
+
|
140 |
+
ffn_layer = f
|
141 |
+
else:
|
142 |
+
raise NotImplementedError
|
143 |
+
|
144 |
+
blocks_list = [
|
145 |
+
block_fn(
|
146 |
+
dim=embed_dim,
|
147 |
+
num_heads=num_heads,
|
148 |
+
mlp_ratio=mlp_ratio,
|
149 |
+
qkv_bias=qkv_bias,
|
150 |
+
proj_bias=proj_bias,
|
151 |
+
ffn_bias=ffn_bias,
|
152 |
+
drop_path=dpr[i],
|
153 |
+
norm_layer=norm_layer,
|
154 |
+
act_layer=act_layer,
|
155 |
+
ffn_layer=ffn_layer,
|
156 |
+
init_values=init_values,
|
157 |
+
)
|
158 |
+
for i in range(depth)
|
159 |
+
]
|
160 |
+
if block_chunks > 0:
|
161 |
+
self.chunked_blocks = True
|
162 |
+
chunked_blocks = []
|
163 |
+
chunksize = depth // block_chunks
|
164 |
+
for i in range(0, depth, chunksize):
|
165 |
+
# this is to keep the block index consistent if we chunk the block list
|
166 |
+
chunked_blocks.append(
|
167 |
+
[nn.Identity()] * i + blocks_list[i : i + chunksize]
|
168 |
+
)
|
169 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
170 |
+
else:
|
171 |
+
self.chunked_blocks = False
|
172 |
+
self.blocks = nn.ModuleList(blocks_list)
|
173 |
+
|
174 |
+
self.norm = norm_layer(embed_dim)
|
175 |
+
self.head = nn.Identity()
|
176 |
+
|
177 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
178 |
+
|
179 |
+
self.init_weights()
|
180 |
+
|
181 |
+
def init_weights(self):
|
182 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
183 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
184 |
+
named_apply(init_weights_vit_timm, self)
|
185 |
+
|
186 |
+
def interpolate_pos_encoding(self, x, w, h):
|
187 |
+
previous_dtype = x.dtype
|
188 |
+
npatch = x.shape[1] - 1
|
189 |
+
N = self.pos_embed.shape[1] - 1
|
190 |
+
if npatch == N and w == h:
|
191 |
+
return self.pos_embed
|
192 |
+
pos_embed = self.pos_embed.float()
|
193 |
+
class_pos_embed = pos_embed[:, 0]
|
194 |
+
patch_pos_embed = pos_embed[:, 1:]
|
195 |
+
dim = x.shape[-1]
|
196 |
+
w0 = w // self.patch_size
|
197 |
+
h0 = h // self.patch_size
|
198 |
+
# we add a small number to avoid floating point error in the interpolation
|
199 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
200 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
201 |
+
|
202 |
+
patch_pos_embed = nn.functional.interpolate(
|
203 |
+
patch_pos_embed.reshape(
|
204 |
+
1, int(math.sqrt(N)), int(math.sqrt(N)), dim
|
205 |
+
).permute(0, 3, 1, 2),
|
206 |
+
size=None,
|
207 |
+
scale_factor=[w0 / math.sqrt(N), h0 / math.sqrt(N)],
|
208 |
+
mode="bicubic",
|
209 |
+
)
|
210 |
+
|
211 |
+
assert (
|
212 |
+
int(w0) == patch_pos_embed.shape[-2]
|
213 |
+
and int(h0) == patch_pos_embed.shape[-1]
|
214 |
+
)
|
215 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
216 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
|
217 |
+
previous_dtype
|
218 |
+
)
|
219 |
+
|
220 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
221 |
+
B, nc, w, h = x.shape
|
222 |
+
x = self.patch_embed(x)
|
223 |
+
if masks is not None:
|
224 |
+
x = torch.where(
|
225 |
+
masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
|
226 |
+
)
|
227 |
+
|
228 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
229 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
230 |
+
|
231 |
+
return x
|
232 |
+
|
233 |
+
def forward_features_list(self, x_list, masks_list):
|
234 |
+
x = [
|
235 |
+
self.prepare_tokens_with_masks(x, masks)
|
236 |
+
for x, masks in zip(x_list, masks_list)
|
237 |
+
]
|
238 |
+
for blk in self.blocks:
|
239 |
+
x = blk(x)
|
240 |
+
|
241 |
+
all_x = x
|
242 |
+
output = []
|
243 |
+
for x, masks in zip(all_x, masks_list):
|
244 |
+
x_norm = self.norm(x)
|
245 |
+
output.append({
|
246 |
+
"x_norm_clstoken": x_norm[:, 0],
|
247 |
+
"x_norm_patchtokens": x_norm[:, 1:],
|
248 |
+
"x_prenorm": x,
|
249 |
+
"masks": masks,
|
250 |
+
})
|
251 |
+
return output
|
252 |
+
|
253 |
+
def forward_features(self, x, masks=None):
|
254 |
+
if isinstance(x, list):
|
255 |
+
return self.forward_features_list(x, masks)
|
256 |
+
|
257 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
258 |
+
|
259 |
+
for blk in self.blocks:
|
260 |
+
x = blk(x)
|
261 |
+
|
262 |
+
x_norm = self.norm(x)
|
263 |
+
return {
|
264 |
+
"x_norm_clstoken": x_norm[:, 0],
|
265 |
+
"x_norm_patchtokens": x_norm[:, 1:],
|
266 |
+
"x_prenorm": x,
|
267 |
+
"masks": masks,
|
268 |
+
}
|
269 |
+
|
270 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
271 |
+
x = self.prepare_tokens_with_masks(x)
|
272 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
273 |
+
output, total_block_len = [], len(self.blocks)
|
274 |
+
blocks_to_take = (
|
275 |
+
range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
276 |
+
)
|
277 |
+
for i, blk in enumerate(self.blocks):
|
278 |
+
x = blk(x)
|
279 |
+
if i in blocks_to_take:
|
280 |
+
output.append(x)
|
281 |
+
assert len(output) == len(
|
282 |
+
blocks_to_take
|
283 |
+
), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
284 |
+
return output
|
285 |
+
|
286 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
287 |
+
x = self.prepare_tokens_with_masks(x)
|
288 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
289 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
290 |
+
blocks_to_take = (
|
291 |
+
range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
292 |
+
)
|
293 |
+
for block_chunk in self.blocks:
|
294 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
295 |
+
x = blk(x)
|
296 |
+
if i in blocks_to_take:
|
297 |
+
output.append(x)
|
298 |
+
i += 1
|
299 |
+
assert len(output) == len(
|
300 |
+
blocks_to_take
|
301 |
+
), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
302 |
+
return output
|
303 |
+
|
304 |
+
def get_intermediate_layers(
|
305 |
+
self,
|
306 |
+
x: torch.Tensor,
|
307 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
308 |
+
reshape: bool = False,
|
309 |
+
return_class_token: bool = False,
|
310 |
+
norm=True,
|
311 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
312 |
+
if self.chunked_blocks:
|
313 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
314 |
+
else:
|
315 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
316 |
+
if norm:
|
317 |
+
outputs = [self.norm(out) for out in outputs]
|
318 |
+
class_tokens = [out[:, 0] for out in outputs]
|
319 |
+
outputs = [out[:, 1:] for out in outputs]
|
320 |
+
if reshape:
|
321 |
+
B, _, w, h = x.shape
|
322 |
+
outputs = [
|
323 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
|
324 |
+
.permute(0, 3, 1, 2)
|
325 |
+
.contiguous()
|
326 |
+
for out in outputs
|
327 |
+
]
|
328 |
+
if return_class_token:
|
329 |
+
return tuple(zip(outputs, class_tokens))
|
330 |
+
return tuple(outputs)
|
331 |
+
|
332 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
333 |
+
return self.get_intermediate_layers(
|
334 |
+
x, n=1, reshape=True, return_class_token=False, norm=True
|
335 |
+
)[0]
|
336 |
+
|
337 |
+
# def forward(self, *args, is_training=False, **kwargs):
|
338 |
+
# ret = self.forward_features(*args, **kwargs)
|
339 |
+
# if is_training:
|
340 |
+
# return ret
|
341 |
+
# else:
|
342 |
+
# return self.head(ret["x_norm_clstoken"])
|
343 |
+
|
344 |
+
|
345 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
346 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
347 |
+
if isinstance(module, nn.Linear):
|
348 |
+
trunc_normal_(module.weight, std=0.02)
|
349 |
+
if module.bias is not None:
|
350 |
+
nn.init.zeros_(module.bias)
|
351 |
+
|
352 |
+
|
353 |
+
def vit_small(patch_size=14, **kwargs):
|
354 |
+
model = DinoVisionTransformer(
|
355 |
+
img_size=518,
|
356 |
+
patch_size=patch_size,
|
357 |
+
embed_dim=384,
|
358 |
+
depth=12,
|
359 |
+
num_heads=6,
|
360 |
+
mlp_ratio=4,
|
361 |
+
init_values=1e-5,
|
362 |
+
block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
|
363 |
+
**kwargs,
|
364 |
+
)
|
365 |
+
return model
|
366 |
+
|
367 |
+
|
368 |
+
def vit_base(patch_size=14, **kwargs):
|
369 |
+
model = DinoVisionTransformer(
|
370 |
+
img_size=518,
|
371 |
+
patch_size=patch_size,
|
372 |
+
embed_dim=768,
|
373 |
+
depth=12,
|
374 |
+
num_heads=12,
|
375 |
+
mlp_ratio=4,
|
376 |
+
init_values=1e-5,
|
377 |
+
block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
|
378 |
+
**kwargs,
|
379 |
+
)
|
380 |
+
return model
|
381 |
+
|
382 |
+
|
383 |
+
def vit_large(patch_size=14, **kwargs):
|
384 |
+
model = DinoVisionTransformer(
|
385 |
+
img_size=518,
|
386 |
+
patch_size=patch_size,
|
387 |
+
embed_dim=1024,
|
388 |
+
depth=24,
|
389 |
+
num_heads=16,
|
390 |
+
mlp_ratio=4,
|
391 |
+
init_values=1e-5,
|
392 |
+
block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
|
393 |
+
**kwargs,
|
394 |
+
)
|
395 |
+
return model
|
396 |
+
|
397 |
+
|
398 |
+
def vit_giant2(patch_size=14, **kwargs):
|
399 |
+
"""Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64"""
|
400 |
+
model = DinoVisionTransformer(
|
401 |
+
img_size=518,
|
402 |
+
patch_size=patch_size,
|
403 |
+
embed_dim=1536,
|
404 |
+
depth=40,
|
405 |
+
num_heads=24,
|
406 |
+
mlp_ratio=4,
|
407 |
+
init_values=1e-5,
|
408 |
+
block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
|
409 |
+
**kwargs,
|
410 |
+
)
|
411 |
+
return model
|
third_party/omniglue/third_party/dinov2/dino_utils.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
#
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/eval/segmentation_m2f/models/backbones/vit.py
|
8 |
+
|
9 |
+
from typing import Callable, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
in_features: int,
|
20 |
+
hidden_features: Optional[int] = None,
|
21 |
+
out_features: Optional[int] = None,
|
22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
23 |
+
drop: float = 0.0,
|
24 |
+
bias: bool = True,
|
25 |
+
) -> None:
|
26 |
+
super().__init__()
|
27 |
+
out_features = out_features or in_features
|
28 |
+
hidden_features = hidden_features or in_features
|
29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
30 |
+
self.act = act_layer()
|
31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
32 |
+
self.drop = nn.Dropout(drop)
|
33 |
+
|
34 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
35 |
+
x = self.fc1(x)
|
36 |
+
x = self.act(x)
|
37 |
+
x = self.drop(x)
|
38 |
+
x = self.fc2(x)
|
39 |
+
x = self.drop(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
def make_2tuple(x):
|
44 |
+
if isinstance(x, tuple):
|
45 |
+
assert len(x) == 2
|
46 |
+
return x
|
47 |
+
|
48 |
+
assert isinstance(x, int)
|
49 |
+
return (x, x)
|
50 |
+
|
51 |
+
|
52 |
+
class PatchEmbed(nn.Module):
|
53 |
+
"""2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
54 |
+
|
55 |
+
Args:
|
56 |
+
img_size: Image size.
|
57 |
+
patch_size: Patch token size.
|
58 |
+
in_chans: Number of input image channels.
|
59 |
+
embed_dim: Number of linear projection output channels.
|
60 |
+
norm_layer: Normalization layer.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
66 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
67 |
+
in_chans: int = 3,
|
68 |
+
embed_dim: int = 768,
|
69 |
+
norm_layer: Optional[Callable] = None,
|
70 |
+
flatten_embedding: bool = True,
|
71 |
+
) -> None:
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
image_HW = make_2tuple(img_size)
|
75 |
+
patch_HW = make_2tuple(patch_size)
|
76 |
+
patch_grid_size = (
|
77 |
+
image_HW[0] // patch_HW[0],
|
78 |
+
image_HW[1] // patch_HW[1],
|
79 |
+
)
|
80 |
+
|
81 |
+
self.img_size = image_HW
|
82 |
+
self.patch_size = patch_HW
|
83 |
+
self.patches_resolution = patch_grid_size
|
84 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
85 |
+
|
86 |
+
self.in_chans = in_chans
|
87 |
+
self.embed_dim = embed_dim
|
88 |
+
|
89 |
+
self.flatten_embedding = flatten_embedding
|
90 |
+
|
91 |
+
self.proj = nn.Conv2d(
|
92 |
+
in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
|
93 |
+
)
|
94 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
95 |
+
|
96 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
97 |
+
_, _, H, W = x.shape
|
98 |
+
patch_H, patch_W = self.patch_size
|
99 |
+
|
100 |
+
assert (
|
101 |
+
H % patch_H == 0
|
102 |
+
), f"Input image height {H} is not a multiple of patch height {patch_H}"
|
103 |
+
assert (
|
104 |
+
W % patch_W == 0
|
105 |
+
), f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
106 |
+
|
107 |
+
x = self.proj(x) # B C H W
|
108 |
+
H, W = x.size(2), x.size(3)
|
109 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
110 |
+
x = self.norm(x)
|
111 |
+
if not self.flatten_embedding:
|
112 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
113 |
+
return x
|
114 |
+
|
115 |
+
def flops(self) -> float:
|
116 |
+
Ho, Wo = self.patches_resolution
|
117 |
+
flops = (
|
118 |
+
Ho
|
119 |
+
* Wo
|
120 |
+
* self.embed_dim
|
121 |
+
* self.in_chans
|
122 |
+
* (self.patch_size[0] * self.patch_size[1])
|
123 |
+
)
|
124 |
+
if self.norm is not None:
|
125 |
+
flops += Ho * Wo * self.embed_dim
|
126 |
+
return flops
|
127 |
+
|
128 |
+
|
129 |
+
XFORMERS_AVAILABLE = False
|
130 |
+
|
131 |
+
|
132 |
+
class Attention(nn.Module):
|
133 |
+
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
dim: int,
|
137 |
+
num_heads: int = 8,
|
138 |
+
qkv_bias: bool = False,
|
139 |
+
proj_bias: bool = True,
|
140 |
+
attn_drop: float = 0.0,
|
141 |
+
proj_drop: float = 0.0,
|
142 |
+
) -> None:
|
143 |
+
super().__init__()
|
144 |
+
self.num_heads = num_heads
|
145 |
+
head_dim = dim // num_heads
|
146 |
+
self.scale = head_dim**-0.5
|
147 |
+
|
148 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
149 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
150 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
151 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
152 |
+
|
153 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
154 |
+
B, N, C = x.shape
|
155 |
+
qkv = (
|
156 |
+
self.qkv(x)
|
157 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
158 |
+
.permute(2, 0, 3, 1, 4)
|
159 |
+
)
|
160 |
+
|
161 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
162 |
+
attn = q @ k.transpose(-2, -1)
|
163 |
+
|
164 |
+
attn = attn.softmax(dim=-1)
|
165 |
+
attn = self.attn_drop(attn)
|
166 |
+
|
167 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
168 |
+
x = self.proj(x)
|
169 |
+
x = self.proj_drop(x)
|
170 |
+
return x
|
171 |
+
|
172 |
+
|
173 |
+
class MemEffAttention(Attention):
|
174 |
+
|
175 |
+
def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
176 |
+
if not XFORMERS_AVAILABLE:
|
177 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
178 |
+
return super().forward(x)
|
179 |
+
else:
|
180 |
+
raise NotImplementedError("MemEffAttention do not support xFormer")
|
181 |
+
# B, N, C = x.shape
|
182 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
183 |
+
|
184 |
+
# q, k, v = unbind(qkv, 2)
|
185 |
+
|
186 |
+
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
187 |
+
# x = x.reshape([B, N, C])
|
188 |
+
|
189 |
+
# x = self.proj(x)
|
190 |
+
# x = self.proj_drop(x)
|
191 |
+
# return x
|
192 |
+
|
193 |
+
|
194 |
+
class LayerScale(nn.Module):
|
195 |
+
|
196 |
+
def __init__(
|
197 |
+
self,
|
198 |
+
dim: int,
|
199 |
+
init_values: Union[float, torch.Tensor] = 1e-5,
|
200 |
+
inplace: bool = False,
|
201 |
+
) -> None:
|
202 |
+
super().__init__()
|
203 |
+
self.inplace = inplace
|
204 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
205 |
+
|
206 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
207 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
208 |
+
|
209 |
+
|
210 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
211 |
+
if drop_prob == 0.0 or not training:
|
212 |
+
return x
|
213 |
+
keep_prob = 1 - drop_prob
|
214 |
+
shape = (x.shape[0],) + (1,) * (
|
215 |
+
x.ndim - 1
|
216 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
217 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
218 |
+
if keep_prob > 0.0:
|
219 |
+
random_tensor.div_(keep_prob)
|
220 |
+
output = x * random_tensor
|
221 |
+
return output
|
222 |
+
|
223 |
+
|
224 |
+
class DropPath(nn.Module):
|
225 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
226 |
+
|
227 |
+
def __init__(self, drop_prob=None):
|
228 |
+
super(DropPath, self).__init__()
|
229 |
+
self.drop_prob = drop_prob
|
230 |
+
|
231 |
+
def forward(self, x):
|
232 |
+
return drop_path(x, self.drop_prob, self.training)
|
233 |
+
|
234 |
+
|
235 |
+
class Block(nn.Module):
|
236 |
+
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
dim: int,
|
240 |
+
num_heads: int,
|
241 |
+
mlp_ratio: float = 4.0,
|
242 |
+
qkv_bias: bool = False,
|
243 |
+
proj_bias: bool = True,
|
244 |
+
ffn_bias: bool = True,
|
245 |
+
drop: float = 0.0,
|
246 |
+
attn_drop: float = 0.0,
|
247 |
+
init_values=None,
|
248 |
+
drop_path: float = 0.0,
|
249 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
250 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
251 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
252 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
253 |
+
) -> None:
|
254 |
+
super().__init__()
|
255 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
256 |
+
self.norm1 = norm_layer(dim)
|
257 |
+
self.attn = attn_class(
|
258 |
+
dim,
|
259 |
+
num_heads=num_heads,
|
260 |
+
qkv_bias=qkv_bias,
|
261 |
+
proj_bias=proj_bias,
|
262 |
+
attn_drop=attn_drop,
|
263 |
+
proj_drop=drop,
|
264 |
+
)
|
265 |
+
self.ls1 = (
|
266 |
+
LayerScale(dim, init_values=init_values)
|
267 |
+
if init_values
|
268 |
+
else nn.Identity()
|
269 |
+
)
|
270 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
271 |
+
|
272 |
+
self.norm2 = norm_layer(dim)
|
273 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
274 |
+
self.mlp = ffn_layer(
|
275 |
+
in_features=dim,
|
276 |
+
hidden_features=mlp_hidden_dim,
|
277 |
+
act_layer=act_layer,
|
278 |
+
drop=drop,
|
279 |
+
bias=ffn_bias,
|
280 |
+
)
|
281 |
+
self.ls2 = (
|
282 |
+
LayerScale(dim, init_values=init_values)
|
283 |
+
if init_values
|
284 |
+
else nn.Identity()
|
285 |
+
)
|
286 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
287 |
+
|
288 |
+
self.sample_drop_ratio = drop_path
|
289 |
+
|
290 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
291 |
+
def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
|
292 |
+
return self.ls1(self.attn(self.norm1(x)))
|
293 |
+
|
294 |
+
def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
|
295 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
296 |
+
|
297 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
298 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
299 |
+
x = drop_add_residual_stochastic_depth(
|
300 |
+
x,
|
301 |
+
residual_func=attn_residual_func,
|
302 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
303 |
+
)
|
304 |
+
x = drop_add_residual_stochastic_depth(
|
305 |
+
x,
|
306 |
+
residual_func=ffn_residual_func,
|
307 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
308 |
+
)
|
309 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
310 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
311 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
312 |
+
else:
|
313 |
+
x = x + attn_residual_func(x)
|
314 |
+
x = x + ffn_residual_func(x)
|
315 |
+
return x
|
316 |
+
|
317 |
+
|
318 |
+
def drop_add_residual_stochastic_depth(
|
319 |
+
x: torch.Tensor,
|
320 |
+
residual_func: Callable[[torch.Tensor], torch.Tensor],
|
321 |
+
sample_drop_ratio: float = 0.0,
|
322 |
+
) -> torch.Tensor:
|
323 |
+
# 1) extract subset using permutation
|
324 |
+
b, n, d = x.shape
|
325 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
326 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
327 |
+
x_subset = x[brange]
|
328 |
+
|
329 |
+
# 2) apply residual_func to get residual
|
330 |
+
residual = residual_func(x_subset)
|
331 |
+
|
332 |
+
x_flat = x.flatten(1)
|
333 |
+
residual = residual.flatten(1)
|
334 |
+
|
335 |
+
residual_scale_factor = b / sample_subset_size
|
336 |
+
|
337 |
+
# 3) add the residual
|
338 |
+
x_plus_residual = torch.index_add(
|
339 |
+
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
340 |
+
)
|
341 |
+
return x_plus_residual.view_as(x)
|