Vincentqyw commited on
Commit
b178479
·
1 Parent(s): 1906d47

update: ui

Browse files
Files changed (5) hide show
  1. app.py +186 -162
  2. common/utils.py +323 -12
  3. common/visualize_util.py +0 -642
  4. common/{plotting.py → viz.py} +116 -21
  5. style.css +18 -0
app.py CHANGED
@@ -1,59 +1,20 @@
1
  import argparse
2
  import gradio as gr
3
-
4
- from hloc import extract_features
5
  from common.utils import (
6
  matcher_zoo,
7
- device,
8
- match_dense,
9
- match_features,
10
- get_model,
11
- get_feature_model,
12
- display_matches,
13
  )
14
 
 
 
 
15
 
16
- def run_matching(
17
- match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
18
- ):
19
- # image0 and image1 is RGB mode
20
- if image0 is None or image1 is None:
21
- raise gr.Error("Error: No images found! Please upload two images.")
22
-
23
- model = matcher_zoo[key]
24
- match_conf = model["config"]
25
- # update match config
26
- match_conf["model"]["match_threshold"] = match_threshold
27
- match_conf["model"]["max_keypoints"] = extract_max_keypoints
28
 
29
- matcher = get_model(match_conf)
30
- if model["dense"]:
31
- pred = match_dense.match_images(
32
- matcher, image0, image1, match_conf["preprocessing"], device=device
33
- )
34
- del matcher
35
- extract_conf = None
36
- else:
37
- extract_conf = model["config_feature"]
38
- # update extract config
39
- extract_conf["model"]["max_keypoints"] = extract_max_keypoints
40
- extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
41
- extractor = get_feature_model(extract_conf)
42
- pred0 = extract_features.extract(
43
- extractor, image0, extract_conf["preprocessing"]
44
- )
45
- pred1 = extract_features.extract(
46
- extractor, image1, extract_conf["preprocessing"]
47
- )
48
- pred = match_features.match_images(matcher, pred0, pred1)
49
- del extractor
50
- fig, num_inliers = display_matches(pred)
51
- del pred
52
- return (
53
- fig,
54
- {"matches number": num_inliers},
55
- {"match_conf": match_conf, "extractor_conf": extract_conf},
56
- )
57
 
58
 
59
  def ui_change_imagebox(choice):
@@ -61,7 +22,18 @@ def ui_change_imagebox(choice):
61
 
62
 
63
  def ui_reset_state(
64
- match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
 
 
 
 
 
 
 
 
 
 
 
65
  ):
66
  match_threshold = 0.2
67
  extract_max_keypoints = 1000
@@ -69,31 +41,35 @@ def ui_reset_state(
69
  key = list(matcher_zoo.keys())[0]
70
  image0 = None
71
  image1 = None
 
72
  return (
 
 
73
  match_threshold,
74
  extract_max_keypoints,
75
  keypoint_threshold,
76
  key,
77
- image0,
78
- image1,
79
- {"value": None, "source": "upload", "__type__": "update"},
80
- {"value": None, "source": "upload", "__type__": "update"},
81
  "upload",
82
  None,
83
  {},
84
  {},
 
 
 
 
 
 
 
 
85
  )
86
 
87
 
 
88
  def run(config):
89
- with gr.Blocks(css="footer {visibility: hidden}") as app:
90
- gr.Markdown(
91
- """
92
- <p align="center">
93
- <h1 align="center">Image Matching WebUI</h1>
94
- </p>
95
- """
96
- )
97
 
98
  with gr.Row(equal_height=False):
99
  with gr.Column():
@@ -109,43 +85,6 @@ def run(config):
109
  label="Image Source",
110
  value="upload",
111
  )
112
-
113
- with gr.Row():
114
- match_setting_threshold = gr.Slider(
115
- minimum=0.0,
116
- maximum=1,
117
- step=0.001,
118
- label="Match threshold",
119
- value=0.1,
120
- )
121
- match_setting_max_features = gr.Slider(
122
- minimum=10,
123
- maximum=10000,
124
- step=10,
125
- label="Max number of features",
126
- value=1000,
127
- )
128
- # TODO: add line settings
129
- with gr.Row():
130
- detect_keypoints_threshold = gr.Slider(
131
- minimum=0,
132
- maximum=1,
133
- step=0.001,
134
- label="Keypoint threshold",
135
- value=0.015,
136
- )
137
- detect_line_threshold = gr.Slider(
138
- minimum=0.1,
139
- maximum=1,
140
- step=0.01,
141
- label="Line threshold",
142
- value=0.2,
143
- )
144
- # matcher_lists = gr.Radio(
145
- # ["NN-mutual", "Dual-Softmax"],
146
- # label="Matcher mode",
147
- # value="NN-mutual",
148
- # )
149
  with gr.Row():
150
  input_image0 = gr.Image(
151
  label="Image 0",
@@ -166,89 +105,147 @@ def run(config):
166
  label="Run Match", value="Run Match", variant="primary"
167
  )
168
 
169
- with gr.Accordion("Open for More!", open=False):
170
- gr.Markdown(
171
- f"""
172
- <h3>Supported Algorithms</h3>
173
- {", ".join(matcher_zoo.keys())}
174
- """
175
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
 
177
  # collect inputs
178
  inputs = [
 
 
179
  match_setting_threshold,
180
  match_setting_max_features,
181
  detect_keypoints_threshold,
182
  matcher_list,
183
- input_image0,
184
- input_image1,
 
 
 
 
185
  ]
186
 
187
  # Add some examples
188
  with gr.Row():
189
- examples = [
190
- [
191
- 0.1,
192
- 2000,
193
- 0.015,
194
- "disk+lightglue",
195
- "datasets/sacre_coeur/mapping/71295362_4051449754.jpg",
196
- "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
197
- ],
198
- [
199
- 0.1,
200
- 2000,
201
- 0.015,
202
- "loftr",
203
- "datasets/sacre_coeur/mapping/03903474_1471484089.jpg",
204
- "datasets/sacre_coeur/mapping/02928139_3448003521.jpg",
205
- ],
206
- [
207
- 0.1,
208
- 2000,
209
- 0.015,
210
- "disk",
211
- "datasets/sacre_coeur/mapping/10265353_3838484249.jpg",
212
- "datasets/sacre_coeur/mapping/51091044_3486849416.jpg",
213
- ],
214
- [
215
- 0.1,
216
- 2000,
217
- 0.015,
218
- "topicfm",
219
- "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
220
- "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
221
- ],
222
- [
223
- 0.1,
224
- 2000,
225
- 0.015,
226
- "superpoint+superglue",
227
- "datasets/sacre_coeur/mapping/17295357_9106075285.jpg",
228
- "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
229
- ],
230
- ]
231
  # Example inputs
232
  gr.Examples(
233
- examples=examples,
234
  inputs=inputs,
235
  outputs=[],
236
  fn=run_matching,
237
- cache_examples=True,
238
- label="Examples (click one of the images below to Run Match)",
 
 
 
 
 
 
 
 
 
 
239
  )
240
 
241
  with gr.Column():
242
- output_mkpts = gr.Image(label="Keypoints Matching", type="numpy")
243
- matches_result_info = gr.JSON(label="Matches Statistics")
244
- matcher_info = gr.JSON(label="Match info")
 
 
 
 
 
 
 
 
 
245
 
246
  # callbacks
247
  match_image_src.change(
248
- fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image0
 
 
249
  )
250
  match_image_src.change(
251
- fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image1
 
 
252
  )
253
 
254
  # collect outputs
@@ -256,34 +253,61 @@ def run(config):
256
  output_mkpts,
257
  matches_result_info,
258
  matcher_info,
 
 
259
  ]
260
  # button callbacks
261
  button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
262
 
263
  # Reset images
264
  reset_outputs = [
 
 
265
  match_setting_threshold,
266
  match_setting_max_features,
267
  detect_keypoints_threshold,
268
  matcher_list,
269
  input_image0,
270
  input_image1,
271
- input_image0,
272
- input_image1,
273
  match_image_src,
274
  output_mkpts,
275
  matches_result_info,
276
  matcher_info,
 
 
 
 
 
 
 
 
277
  ]
278
- button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
279
- app.queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  app.launch(share=False)
281
 
282
 
283
  if __name__ == "__main__":
284
  parser = argparse.ArgumentParser()
285
  parser.add_argument(
286
- "--config_path", type=str, default="config.yaml", help="configuration file path"
 
 
 
287
  )
288
  args = parser.parse_args()
289
  config = None
 
1
  import argparse
2
  import gradio as gr
 
 
3
  from common.utils import (
4
  matcher_zoo,
5
+ change_estimate_geom,
6
+ run_matching,
7
+ ransac_zoo,
8
+ gen_examples,
 
 
9
  )
10
 
11
+ DESCRIPTION = """
12
+ # Image Matching WebUI
13
+ This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
14
 
15
+ 🔎 For more details about supported local features and matchers, please refer to https://github.com/Vincentqyw/image-matching-webui
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def ui_change_imagebox(choice):
 
22
 
23
 
24
  def ui_reset_state(
25
+ image0,
26
+ image1,
27
+ match_threshold,
28
+ extract_max_keypoints,
29
+ keypoint_threshold,
30
+ key,
31
+ enable_ransac=False,
32
+ ransac_method="RANSAC",
33
+ ransac_reproj_threshold=8,
34
+ ransac_confidence=0.999,
35
+ ransac_max_iter=10000,
36
+ choice_estimate_geom="Homography",
37
  ):
38
  match_threshold = 0.2
39
  extract_max_keypoints = 1000
 
41
  key = list(matcher_zoo.keys())[0]
42
  image0 = None
43
  image1 = None
44
+ enable_ransac = False
45
  return (
46
+ image0,
47
+ image1,
48
  match_threshold,
49
  extract_max_keypoints,
50
  keypoint_threshold,
51
  key,
52
+ ui_change_imagebox("upload"),
53
+ ui_change_imagebox("upload"),
 
 
54
  "upload",
55
  None,
56
  {},
57
  {},
58
+ None,
59
+ {},
60
+ False,
61
+ "RANSAC",
62
+ 8,
63
+ 0.999,
64
+ 10000,
65
+ "Homography",
66
  )
67
 
68
 
69
+ # "footer {visibility: hidden}"
70
  def run(config):
71
+ with gr.Blocks(css="style.css") as app:
72
+ gr.Markdown(DESCRIPTION)
 
 
 
 
 
 
73
 
74
  with gr.Row(equal_height=False):
75
  with gr.Column():
 
85
  label="Image Source",
86
  value="upload",
87
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  with gr.Row():
89
  input_image0 = gr.Image(
90
  label="Image 0",
 
105
  label="Run Match", value="Run Match", variant="primary"
106
  )
107
 
108
+ with gr.Accordion("Advanced Setting", open=False):
109
+ with gr.Accordion("Matching Setting", open=True):
110
+ with gr.Row():
111
+ match_setting_threshold = gr.Slider(
112
+ minimum=0.0,
113
+ maximum=1,
114
+ step=0.001,
115
+ label="Match thres.",
116
+ value=0.1,
117
+ )
118
+ match_setting_max_features = gr.Slider(
119
+ minimum=10,
120
+ maximum=10000,
121
+ step=10,
122
+ label="Max features",
123
+ value=1000,
124
+ )
125
+ # TODO: add line settings
126
+ with gr.Row():
127
+ detect_keypoints_threshold = gr.Slider(
128
+ minimum=0,
129
+ maximum=1,
130
+ step=0.001,
131
+ label="Keypoint thres.",
132
+ value=0.015,
133
+ )
134
+ detect_line_threshold = gr.Slider(
135
+ minimum=0.1,
136
+ maximum=1,
137
+ step=0.01,
138
+ label="Line thres.",
139
+ value=0.2,
140
+ )
141
+ # matcher_lists = gr.Radio(
142
+ # ["NN-mutual", "Dual-Softmax"],
143
+ # label="Matcher mode",
144
+ # value="NN-mutual",
145
+ # )
146
+ with gr.Accordion("RANSAC Setting", open=False):
147
+ with gr.Row(equal_height=False):
148
+ enable_ransac = gr.Checkbox(label="Enable RANSAC")
149
+ ransac_method = gr.Dropdown(
150
+ choices=ransac_zoo.keys(),
151
+ value="RANSAC",
152
+ label="RANSAC Method",
153
+ interactive=True,
154
+ )
155
+ ransac_reproj_threshold = gr.Slider(
156
+ minimum=0.0,
157
+ maximum=12,
158
+ step=0.01,
159
+ label="Ransac Reproj threshold",
160
+ value=8.0,
161
+ )
162
+ ransac_confidence = gr.Slider(
163
+ minimum=0.0,
164
+ maximum=1,
165
+ step=0.00001,
166
+ label="Ransac Confidence",
167
+ value=0.99999,
168
+ )
169
+ ransac_max_iter = gr.Slider(
170
+ minimum=0.0,
171
+ maximum=100000,
172
+ step=100,
173
+ label="Ransac Iterations",
174
+ value=10000,
175
+ )
176
+
177
+ with gr.Accordion("Geometry Setting", open=True):
178
+ with gr.Row(equal_height=False):
179
+ # show_geom = gr.Checkbox(label="Show Geometry")
180
+ choice_estimate_geom = gr.Radio(
181
+ ["Fundamental", "Homography"],
182
+ label="Reconstruct Geometry",
183
+ value="Homography",
184
+ )
185
 
186
+ # with gr.Column():
187
  # collect inputs
188
  inputs = [
189
+ input_image0,
190
+ input_image1,
191
  match_setting_threshold,
192
  match_setting_max_features,
193
  detect_keypoints_threshold,
194
  matcher_list,
195
+ enable_ransac,
196
+ ransac_method,
197
+ ransac_reproj_threshold,
198
+ ransac_confidence,
199
+ ransac_max_iter,
200
+ choice_estimate_geom,
201
  ]
202
 
203
  # Add some examples
204
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  # Example inputs
206
  gr.Examples(
207
+ examples=gen_examples(),
208
  inputs=inputs,
209
  outputs=[],
210
  fn=run_matching,
211
+ cache_examples=False,
212
+ label=(
213
+ "Examples (click one of the images below to Run"
214
+ " Match)"
215
+ ),
216
+ )
217
+ with gr.Accordion("Open for More!", open=False):
218
+ gr.Markdown(
219
+ f"""
220
+ <h3>Supported Algorithms</h3>
221
+ {", ".join(matcher_zoo.keys())}
222
+ """
223
  )
224
 
225
  with gr.Column():
226
+ output_mkpts = gr.Image(
227
+ label="Keypoints Matching", type="numpy"
228
+ )
229
+ with gr.Accordion(
230
+ "Open for More: Matches Statistics", open=False
231
+ ):
232
+ matches_result_info = gr.JSON(label="Matches Statistics")
233
+ matcher_info = gr.JSON(label="Match info")
234
+
235
+ output_wrapped = gr.Image(label="Wrapped Pair", type="numpy")
236
+ with gr.Accordion("Open for More: Geometry info", open=False):
237
+ geometry_result = gr.JSON(label="Reconstructed Geometry")
238
 
239
  # callbacks
240
  match_image_src.change(
241
+ fn=ui_change_imagebox,
242
+ inputs=match_image_src,
243
+ outputs=input_image0,
244
  )
245
  match_image_src.change(
246
+ fn=ui_change_imagebox,
247
+ inputs=match_image_src,
248
+ outputs=input_image1,
249
  )
250
 
251
  # collect outputs
 
253
  output_mkpts,
254
  matches_result_info,
255
  matcher_info,
256
+ geometry_result,
257
+ output_wrapped,
258
  ]
259
  # button callbacks
260
  button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
261
 
262
  # Reset images
263
  reset_outputs = [
264
+ input_image0,
265
+ input_image1,
266
  match_setting_threshold,
267
  match_setting_max_features,
268
  detect_keypoints_threshold,
269
  matcher_list,
270
  input_image0,
271
  input_image1,
 
 
272
  match_image_src,
273
  output_mkpts,
274
  matches_result_info,
275
  matcher_info,
276
+ output_wrapped,
277
+ geometry_result,
278
+ enable_ransac,
279
+ ransac_method,
280
+ ransac_reproj_threshold,
281
+ ransac_confidence,
282
+ ransac_max_iter,
283
+ choice_estimate_geom,
284
  ]
285
+ button_reset.click(
286
+ fn=ui_reset_state, inputs=inputs, outputs=reset_outputs
287
+ )
288
+
289
+ # estimate geo
290
+ choice_estimate_geom.change(
291
+ fn=change_estimate_geom,
292
+ inputs=[
293
+ input_image0,
294
+ input_image1,
295
+ geometry_result,
296
+ choice_estimate_geom,
297
+ ],
298
+ outputs=[output_wrapped, geometry_result],
299
+ )
300
+
301
  app.launch(share=False)
302
 
303
 
304
  if __name__ == "__main__":
305
  parser = argparse.ArgumentParser()
306
  parser.add_argument(
307
+ "--config_path",
308
+ type=str,
309
+ default="config.yaml",
310
+ help="configuration file path",
311
  )
312
  args = parser.parse_args()
313
  config = None
common/utils.py CHANGED
@@ -1,11 +1,14 @@
1
- import torch
 
2
  import numpy as np
 
 
3
  import cv2
 
4
  from hloc import matchers, extractors
5
  from hloc.utils.base_model import dynamic_load
6
  from hloc import match_dense, match_features, extract_features
7
- from .plotting import draw_matches, fig2im
8
- from .visualize_util import plot_images, plot_color_line_matches
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
@@ -22,6 +25,217 @@ def get_feature_model(conf):
22
  return model
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def display_matches(pred: dict):
26
  img0 = pred["image0_orig"]
27
  img1 = pred["image1_orig"]
@@ -42,7 +256,10 @@ def display_matches(pred: dict):
42
  img1,
43
  mconf,
44
  dpi=300,
45
- titles=["Image 0 - matched keypoints", "Image 1 - matched keypoints"],
 
 
 
46
  )
47
  fig = fig_mkpts
48
  if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
@@ -69,13 +286,107 @@ def display_matches(pred: dict):
69
  else:
70
  mconf = np.ones(len(mkpts0))
71
  fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
72
- fig_lines = cv2.resize(fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0]))
 
 
73
  fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
74
  else:
75
  fig = fig_lines
76
  return fig, num_inliers
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # Matchers collections
80
  matcher_zoo = {
81
  "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
@@ -147,11 +458,11 @@ matcher_zoo = {
147
  "config_feature": extract_features.confs["d2net-ss"],
148
  "dense": False,
149
  },
150
- # "d2net-ms": {
151
- # "config": match_features.confs["NN-mutual"],
152
- # "config_feature": extract_features.confs["d2net-ms"],
153
- # "dense": False,
154
- # },
155
  "alike": {
156
  "config": match_features.confs["NN-mutual"],
157
  "config_feature": extract_features.confs["alike"],
@@ -177,6 +488,6 @@ matcher_zoo = {
177
  "config_feature": extract_features.confs["sift"],
178
  "dense": False,
179
  },
180
- # "roma": {"config": match_dense.confs["roma"], "dense": True},
181
- # "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
182
  }
 
1
+ import os
2
+ import random
3
  import numpy as np
4
+ import torch
5
+ from itertools import combinations
6
  import cv2
7
+ import gradio as gr
8
  from hloc import matchers, extractors
9
  from hloc.utils.base_model import dynamic_load
10
  from hloc import match_dense, match_features, extract_features
11
+ from .viz import draw_matches, fig2im, plot_images, plot_color_line_matches
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
 
25
  return model
26
 
27
 
28
+ def gen_examples():
29
+ random.seed(1)
30
+ example_matchers = [
31
+ "disk+lightglue",
32
+ "loftr",
33
+ "disk",
34
+ "d2net",
35
+ "topicfm",
36
+ "superpoint+superglue",
37
+ "disk+dualsoftmax",
38
+ "lanet",
39
+ ]
40
+
41
+ def gen_images_pairs(path: str, count: int = 5):
42
+ imgs_list = [
43
+ os.path.join(path, file)
44
+ for file in os.listdir(path)
45
+ if file.lower().endswith((".jpg", ".jpeg", ".png"))
46
+ ]
47
+ pairs = list(combinations(imgs_list, 2))
48
+ selected = random.sample(range(len(pairs)), count)
49
+ return [pairs[i] for i in selected]
50
+ # image pair path
51
+ path = "datasets/sacre_coeur/mapping"
52
+ pairs = gen_images_pairs(path, len(example_matchers))
53
+ match_setting_threshold = 0.1
54
+ match_setting_max_features = 2000
55
+ detect_keypoints_threshold = 0.01
56
+ enable_ransac = False
57
+ ransac_method = "RANSAC"
58
+ ransac_reproj_threshold = 8
59
+ ransac_confidence = 0.999
60
+ ransac_max_iter = 10000
61
+ input_lists = []
62
+ for pair, mt in zip(pairs, example_matchers):
63
+ input_lists.append(
64
+ [
65
+ pair[0],
66
+ pair[1],
67
+ match_setting_threshold,
68
+ match_setting_max_features,
69
+ detect_keypoints_threshold,
70
+ mt,
71
+ enable_ransac,
72
+ ransac_method,
73
+ ransac_reproj_threshold,
74
+ ransac_confidence,
75
+ ransac_max_iter,
76
+ ]
77
+ )
78
+ return input_lists
79
+
80
+
81
+ def filter_matches(
82
+ pred,
83
+ ransac_method="RANSAC",
84
+ ransac_reproj_threshold=8,
85
+ ransac_confidence=0.999,
86
+ ransac_max_iter=10000,
87
+ ):
88
+ mkpts0 = None
89
+ mkpts1 = None
90
+ feature_type = None
91
+ if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
92
+ mkpts0 = pred["keypoints0_orig"]
93
+ mkpts1 = pred["keypoints1_orig"]
94
+ feature_type = "KEYPOINT"
95
+ elif (
96
+ "line_keypoints0_orig" in pred.keys()
97
+ and "line_keypoints1_orig" in pred.keys()
98
+ ):
99
+ mkpts0 = pred["line_keypoints0_orig"]
100
+ mkpts1 = pred["line_keypoints1_orig"]
101
+ feature_type = "LINE"
102
+ else:
103
+ return pred
104
+ if mkpts0 is None or mkpts0 is None:
105
+ return pred
106
+ if ransac_method not in ransac_zoo.keys():
107
+ ransac_method = "RANSAC"
108
+ H, mask = cv2.findHomography(
109
+ mkpts0,
110
+ mkpts1,
111
+ method=ransac_zoo[ransac_method],
112
+ ransacReprojThreshold=ransac_reproj_threshold,
113
+ confidence=ransac_confidence,
114
+ maxIters=ransac_max_iter,
115
+ )
116
+ mask = np.array(mask.ravel().astype("bool"), dtype="bool")
117
+ if H is not None:
118
+ if feature_type == "KEYPOINT":
119
+ pred["keypoints0_orig"] = mkpts0[mask]
120
+ pred["keypoints1_orig"] = mkpts1[mask]
121
+ pred["mconf"] = pred["mconf"][mask]
122
+ elif feature_type == "LINE":
123
+ pred["line_keypoints0_orig"] = mkpts0[mask]
124
+ pred["line_keypoints1_orig"] = mkpts1[mask]
125
+ return pred
126
+
127
+
128
+ def compute_geom(
129
+ pred,
130
+ ransac_method="RANSAC",
131
+ ransac_reproj_threshold=8,
132
+ ransac_confidence=0.999,
133
+ ransac_max_iter=10000,
134
+ ) -> dict:
135
+ mkpts0 = None
136
+ mkpts1 = None
137
+
138
+ if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
139
+ mkpts0 = pred["keypoints0_orig"]
140
+ mkpts1 = pred["keypoints1_orig"]
141
+
142
+ if (
143
+ "line_keypoints0_orig" in pred.keys()
144
+ and "line_keypoints1_orig" in pred.keys()
145
+ ):
146
+ mkpts0 = pred["line_keypoints0_orig"]
147
+ mkpts1 = pred["line_keypoints1_orig"]
148
+
149
+ if mkpts0 is not None and mkpts1 is not None:
150
+ if len(mkpts0) < 8:
151
+ return {}
152
+ h1, w1, _ = pred["image0_orig"].shape
153
+ geo_info = {}
154
+ F, inliers = cv2.findFundamentalMat(
155
+ mkpts0,
156
+ mkpts1,
157
+ method=ransac_zoo[ransac_method],
158
+ ransacReprojThreshold=ransac_reproj_threshold,
159
+ confidence=ransac_confidence,
160
+ maxIters=ransac_max_iter,
161
+ )
162
+ geo_info["Fundamental"] = F.tolist()
163
+ H, _ = cv2.findHomography(
164
+ mkpts1,
165
+ mkpts0,
166
+ method=ransac_zoo[ransac_method],
167
+ ransacReprojThreshold=ransac_reproj_threshold,
168
+ confidence=ransac_confidence,
169
+ maxIters=ransac_max_iter,
170
+ )
171
+ geo_info["Homography"] = H.tolist()
172
+ _, H1, H2 = cv2.stereoRectifyUncalibrated(
173
+ mkpts0.reshape(-1, 2), mkpts1.reshape(-1, 2), F, imgSize=(w1, h1)
174
+ )
175
+ geo_info["H1"] = H1.tolist()
176
+ geo_info["H2"] = H2.tolist()
177
+ return geo_info
178
+ else:
179
+ return {}
180
+
181
+
182
+ def wrap_images(img0, img1, geo_info, geom_type):
183
+ h1, w1, _ = img0.shape
184
+ h2, w2, _ = img1.shape
185
+ result_matrix = None
186
+ if geo_info is not None and len(geo_info) != 0:
187
+ rectified_image0 = img0
188
+ rectified_image1 = None
189
+ H = np.array(geo_info["Homography"])
190
+ F = np.array(geo_info["Fundamental"])
191
+ title = []
192
+ if geom_type == "Homography":
193
+ rectified_image1 = cv2.warpPerspective(
194
+ img1, H, (img0.shape[1] + img1.shape[1], img0.shape[0])
195
+ )
196
+ result_matrix = H
197
+ title = ["Image 0", "Image 1 - warped"]
198
+ elif geom_type == "Fundamental":
199
+ H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
200
+ rectified_image0 = cv2.warpPerspective(img0, H1, (w1, h1))
201
+ rectified_image1 = cv2.warpPerspective(img1, H2, (w2, h2))
202
+ result_matrix = F
203
+ title = ["Image 0 - warped", "Image 1 - warped"]
204
+ else:
205
+ print("Error: Unknown geometry type")
206
+ fig = plot_images(
207
+ [rectified_image0.squeeze(), rectified_image1.squeeze()],
208
+ title,
209
+ dpi=300,
210
+ )
211
+ dictionary = {
212
+ "row1": result_matrix[0].tolist(),
213
+ "row2": result_matrix[1].tolist(),
214
+ "row3": result_matrix[2].tolist(),
215
+ }
216
+ return fig2im(fig), dictionary
217
+ else:
218
+ return None, None
219
+
220
+
221
+ def change_estimate_geom(input_image0, input_image1, matches_info, choice):
222
+ if (
223
+ matches_info is None
224
+ or len(matches_info) < 1
225
+ or "geom_info" not in matches_info.keys()
226
+ ):
227
+ return None, None
228
+ geom_info = matches_info["geom_info"]
229
+ wrapped_images = None
230
+ if choice != "No":
231
+ wrapped_images, _ = wrap_images(
232
+ input_image0, input_image1, geom_info, choice
233
+ )
234
+ return wrapped_images, matches_info
235
+ else:
236
+ return None, None
237
+
238
+
239
  def display_matches(pred: dict):
240
  img0 = pred["image0_orig"]
241
  img1 = pred["image1_orig"]
 
256
  img1,
257
  mconf,
258
  dpi=300,
259
+ titles=[
260
+ "Image 0 - matched keypoints",
261
+ "Image 1 - matched keypoints",
262
+ ],
263
  )
264
  fig = fig_mkpts
265
  if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
 
286
  else:
287
  mconf = np.ones(len(mkpts0))
288
  fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
289
+ fig_lines = cv2.resize(
290
+ fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
291
+ )
292
  fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
293
  else:
294
  fig = fig_lines
295
  return fig, num_inliers
296
 
297
 
298
+ def run_matching(
299
+ image0,
300
+ image1,
301
+ match_threshold,
302
+ extract_max_keypoints,
303
+ keypoint_threshold,
304
+ key,
305
+ enable_ransac=False,
306
+ ransac_method="RANSAC",
307
+ ransac_reproj_threshold=8,
308
+ ransac_confidence=0.999,
309
+ ransac_max_iter=10000,
310
+ choice_estimate_geom="Homography",
311
+ ):
312
+ # image0 and image1 is RGB mode
313
+ if image0 is None or image1 is None:
314
+ raise gr.Error("Error: No images found! Please upload two images.")
315
+
316
+ model = matcher_zoo[key]
317
+ match_conf = model["config"]
318
+ # update match config
319
+ match_conf["model"]["match_threshold"] = match_threshold
320
+ match_conf["model"]["max_keypoints"] = extract_max_keypoints
321
+
322
+ matcher = get_model(match_conf)
323
+ if model["dense"]:
324
+ pred = match_dense.match_images(
325
+ matcher, image0, image1, match_conf["preprocessing"], device=device
326
+ )
327
+ del matcher
328
+ extract_conf = None
329
+ else:
330
+ extract_conf = model["config_feature"]
331
+ # update extract config
332
+ extract_conf["model"]["max_keypoints"] = extract_max_keypoints
333
+ extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
334
+ extractor = get_feature_model(extract_conf)
335
+ pred0 = extract_features.extract(
336
+ extractor, image0, extract_conf["preprocessing"]
337
+ )
338
+ pred1 = extract_features.extract(
339
+ extractor, image1, extract_conf["preprocessing"]
340
+ )
341
+ pred = match_features.match_images(matcher, pred0, pred1)
342
+ del extractor
343
+
344
+ if enable_ransac:
345
+ filter_matches(
346
+ pred,
347
+ ransac_method=ransac_method,
348
+ ransac_reproj_threshold=ransac_reproj_threshold,
349
+ ransac_confidence=ransac_confidence,
350
+ ransac_max_iter=ransac_max_iter,
351
+ )
352
+
353
+ fig, num_inliers = display_matches(pred)
354
+ geom_info = compute_geom(pred)
355
+ output_wrapped, _ = change_estimate_geom(
356
+ pred["image0_orig"],
357
+ pred["image1_orig"],
358
+ {"geom_info": geom_info},
359
+ choice_estimate_geom,
360
+ )
361
+ del pred
362
+ return (
363
+ fig,
364
+ {"matches number": num_inliers},
365
+ {
366
+ "match_conf": match_conf,
367
+ "extractor_conf": extract_conf,
368
+ },
369
+ {
370
+ "geom_info": geom_info,
371
+ },
372
+ output_wrapped,
373
+ # geometry_result,
374
+ )
375
+
376
+
377
+ # @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html
378
+ # AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs
379
+ ransac_zoo = {
380
+ "RANSAC": cv2.RANSAC,
381
+ "USAC_MAGSAC": cv2.USAC_MAGSAC,
382
+ "USAC_DEFAULT": cv2.USAC_DEFAULT,
383
+ "USAC_FM_8PTS": cv2.USAC_FM_8PTS,
384
+ "USAC_PROSAC": cv2.USAC_PROSAC,
385
+ "USAC_FAST": cv2.USAC_FAST,
386
+ "USAC_ACCURATE": cv2.USAC_ACCURATE,
387
+ "USAC_PARALLEL": cv2.USAC_PARALLEL,
388
+ }
389
+
390
  # Matchers collections
391
  matcher_zoo = {
392
  "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
 
458
  "config_feature": extract_features.confs["d2net-ss"],
459
  "dense": False,
460
  },
461
+ "d2net-ms": {
462
+ "config": match_features.confs["NN-mutual"],
463
+ "config_feature": extract_features.confs["d2net-ms"],
464
+ "dense": False,
465
+ },
466
  "alike": {
467
  "config": match_features.confs["NN-mutual"],
468
  "config_feature": extract_features.confs["alike"],
 
488
  "config_feature": extract_features.confs["sift"],
489
  "dense": False,
490
  },
491
+ "roma": {"config": match_dense.confs["roma"], "dense": True},
492
+ "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
493
  }
common/visualize_util.py DELETED
@@ -1,642 +0,0 @@
1
- """ Organize some frequently used visualization functions. """
2
- import cv2
3
- import numpy as np
4
- import matplotlib
5
- import matplotlib.pyplot as plt
6
- import copy
7
- import seaborn as sns
8
-
9
-
10
- # Plot junctions onto the image (return a separate copy)
11
- def plot_junctions(input_image, junctions, junc_size=3, color=None):
12
- """
13
- input_image: can be 0~1 float or 0~255 uint8.
14
- junctions: Nx2 or 2xN np array.
15
- junc_size: the size of the plotted circles.
16
- """
17
- # Create image copy
18
- image = copy.copy(input_image)
19
- # Make sure the image is converted to 255 uint8
20
- if image.dtype == np.uint8:
21
- pass
22
- # A float type image ranging from 0~1
23
- elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
24
- image = (image * 255.0).astype(np.uint8)
25
- # A float type image ranging from 0.~255.
26
- elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
27
- image = image.astype(np.uint8)
28
- else:
29
- raise ValueError(
30
- "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
31
- )
32
-
33
- # Check whether the image is single channel
34
- if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
35
- # Squeeze to H*W first
36
- image = image.squeeze()
37
-
38
- # Stack to channle 3
39
- image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
40
-
41
- # Junction dimensions should be N*2
42
- if not len(junctions.shape) == 2:
43
- raise ValueError("[Error] junctions should be 2-dim array.")
44
-
45
- # Always convert to N*2
46
- if junctions.shape[-1] != 2:
47
- if junctions.shape[0] == 2:
48
- junctions = junctions.T
49
- else:
50
- raise ValueError("[Error] At least one of the two dims should be 2.")
51
-
52
- # Round and convert junctions to int (and check the boundary)
53
- H, W = image.shape[:2]
54
- junctions = (np.round(junctions)).astype(np.int)
55
- junctions[junctions < 0] = 0
56
- junctions[junctions[:, 0] >= H, 0] = H - 1 # (first dim) max bounded by H-1
57
- junctions[junctions[:, 1] >= W, 1] = W - 1 # (second dim) max bounded by W-1
58
-
59
- # Iterate through all the junctions
60
- num_junc = junctions.shape[0]
61
- if color is None:
62
- color = (0, 255.0, 0)
63
- for idx in range(num_junc):
64
- # Fetch one junction
65
- junc = junctions[idx, :]
66
- cv2.circle(
67
- image, tuple(np.flip(junc)), radius=junc_size, color=color, thickness=3
68
- )
69
-
70
- return image
71
-
72
-
73
- # Plot line segements given junctions and line adjecent map
74
- def plot_line_segments(
75
- input_image,
76
- junctions,
77
- line_map,
78
- junc_size=3,
79
- color=(0, 255.0, 0),
80
- line_width=1,
81
- plot_survived_junc=True,
82
- ):
83
- """
84
- input_image: can be 0~1 float or 0~255 uint8.
85
- junctions: Nx2 or 2xN np array.
86
- line_map: NxN np array
87
- junc_size: the size of the plotted circles.
88
- color: color of the line segments (can be string "random")
89
- line_width: width of the drawn segments.
90
- plot_survived_junc: whether we only plot the survived junctions.
91
- """
92
- # Create image copy
93
- image = copy.copy(input_image)
94
- # Make sure the image is converted to 255 uint8
95
- if image.dtype == np.uint8:
96
- pass
97
- # A float type image ranging from 0~1
98
- elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
99
- image = (image * 255.0).astype(np.uint8)
100
- # A float type image ranging from 0.~255.
101
- elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
102
- image = image.astype(np.uint8)
103
- else:
104
- raise ValueError(
105
- "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
106
- )
107
-
108
- # Check whether the image is single channel
109
- if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
110
- # Squeeze to H*W first
111
- image = image.squeeze()
112
-
113
- # Stack to channle 3
114
- image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
115
-
116
- # Junction dimensions should be 2
117
- if not len(junctions.shape) == 2:
118
- raise ValueError("[Error] junctions should be 2-dim array.")
119
-
120
- # Always convert to N*2
121
- if junctions.shape[-1] != 2:
122
- if junctions.shape[0] == 2:
123
- junctions = junctions.T
124
- else:
125
- raise ValueError("[Error] At least one of the two dims should be 2.")
126
-
127
- # line_map dimension should be 2
128
- if not len(line_map.shape) == 2:
129
- raise ValueError("[Error] line_map should be 2-dim array.")
130
-
131
- # Color should be "random" or a list or tuple with length 3
132
- if color != "random":
133
- if not (isinstance(color, tuple) or isinstance(color, list)):
134
- raise ValueError("[Error] color should have type list or tuple.")
135
- else:
136
- if len(color) != 3:
137
- raise ValueError(
138
- "[Error] color should be a list or tuple with length 3."
139
- )
140
-
141
- # Make a copy of the line_map
142
- line_map_tmp = copy.copy(line_map)
143
-
144
- # Parse line_map back to segment pairs
145
- segments = np.zeros([0, 4])
146
- for idx in range(junctions.shape[0]):
147
- # if no connectivity, just skip it
148
- if line_map_tmp[idx, :].sum() == 0:
149
- continue
150
- # record the line segment
151
- else:
152
- for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
153
- p1 = np.flip(junctions[idx, :]) # Convert to xy format
154
- p2 = np.flip(junctions[idx2, :]) # Convert to xy format
155
- segments = np.concatenate(
156
- (segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
157
- axis=0,
158
- )
159
-
160
- # Update line_map
161
- line_map_tmp[idx, idx2] = 0
162
- line_map_tmp[idx2, idx] = 0
163
-
164
- # Draw segment pairs
165
- for idx in range(segments.shape[0]):
166
- seg = np.round(segments[idx, :]).astype(np.int)
167
- # Decide the color
168
- if color != "random":
169
- color = tuple(color)
170
- else:
171
- color = tuple(
172
- np.random.rand(
173
- 3,
174
- )
175
- )
176
- cv2.line(
177
- image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width
178
- )
179
-
180
- # Also draw the junctions
181
- if not plot_survived_junc:
182
- num_junc = junctions.shape[0]
183
- for idx in range(num_junc):
184
- # Fetch one junction
185
- junc = junctions[idx, :]
186
- cv2.circle(
187
- image,
188
- tuple(np.flip(junc)),
189
- radius=junc_size,
190
- color=(0, 255.0, 0),
191
- thickness=3,
192
- )
193
- # Only plot the junctions which are part of a line segment
194
- else:
195
- for idx in range(segments.shape[0]):
196
- seg = np.round(segments[idx, :]).astype(np.int) # Already in HW format.
197
- cv2.circle(
198
- image,
199
- tuple(seg[:2]),
200
- radius=junc_size,
201
- color=(0, 255.0, 0),
202
- thickness=3,
203
- )
204
- cv2.circle(
205
- image,
206
- tuple(seg[2:]),
207
- radius=junc_size,
208
- color=(0, 255.0, 0),
209
- thickness=3,
210
- )
211
-
212
- return image
213
-
214
-
215
- # Plot line segments given Nx4 or Nx2x2 line segments
216
- def plot_line_segments_from_segments(
217
- input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1
218
- ):
219
- # Create image copy
220
- image = copy.copy(input_image)
221
- # Make sure the image is converted to 255 uint8
222
- if image.dtype == np.uint8:
223
- pass
224
- # A float type image ranging from 0~1
225
- elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
226
- image = (image * 255.0).astype(np.uint8)
227
- # A float type image ranging from 0.~255.
228
- elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
229
- image = image.astype(np.uint8)
230
- else:
231
- raise ValueError(
232
- "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
233
- )
234
-
235
- # Check whether the image is single channel
236
- if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
237
- # Squeeze to H*W first
238
- image = image.squeeze()
239
-
240
- # Stack to channle 3
241
- image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
242
-
243
- # Check the if line_segments are in (1) Nx4, or (2) Nx2x2.
244
- H, W, _ = image.shape
245
- # (1) Nx4 format
246
- if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4:
247
- # Round to int32
248
- line_segments = line_segments.astype(np.int32)
249
-
250
- # Clip H dimension
251
- line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H - 1)
252
- line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H - 1)
253
-
254
- # Clip W dimension
255
- line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W - 1)
256
- line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W - 1)
257
-
258
- # Convert to Nx2x2 format
259
- line_segments = np.concatenate(
260
- [
261
- np.expand_dims(line_segments[:, :2], axis=1),
262
- np.expand_dims(line_segments[:, 2:], axis=1),
263
- ],
264
- axis=1,
265
- )
266
-
267
- # (2) Nx2x2 format
268
- elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2:
269
- # Round to int32
270
- line_segments = line_segments.astype(np.int32)
271
-
272
- # Clip H dimension
273
- line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H - 1)
274
- line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W - 1)
275
-
276
- else:
277
- raise ValueError(
278
- "[Error] line_segments should be either Nx4 or Nx2x2 in HW format."
279
- )
280
-
281
- # Draw segment pairs (all segments should be in HW format)
282
- image = image.copy()
283
- for idx in range(line_segments.shape[0]):
284
- seg = np.round(line_segments[idx, :, :]).astype(np.int32)
285
- # Decide the color
286
- if color != "random":
287
- color = tuple(color)
288
- else:
289
- color = tuple(
290
- np.random.rand(
291
- 3,
292
- )
293
- )
294
- cv2.line(
295
- image,
296
- tuple(np.flip(seg[0, :])),
297
- tuple(np.flip(seg[1, :])),
298
- color=color,
299
- thickness=line_width,
300
- )
301
-
302
- # Also draw the junctions
303
- cv2.circle(
304
- image,
305
- tuple(np.flip(seg[0, :])),
306
- radius=junc_size,
307
- color=(0, 255.0, 0),
308
- thickness=3,
309
- )
310
- cv2.circle(
311
- image,
312
- tuple(np.flip(seg[1, :])),
313
- radius=junc_size,
314
- color=(0, 255.0, 0),
315
- thickness=3,
316
- )
317
-
318
- return image
319
-
320
-
321
- # Additional functions to visualize multiple images at the same time,
322
- # e.g. for line matching
323
- def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
324
- """Plot a set of images horizontally.
325
- Args:
326
- imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
327
- titles: a list of strings, as titles for each image.
328
- cmaps: colormaps for monochrome images.
329
- """
330
- n = len(imgs)
331
- if not isinstance(cmaps, (list, tuple)):
332
- cmaps = [cmaps] * n
333
- # figsize = (size*n, size*3/4) if size is not None else None
334
- figsize = (size * n, size * 6 / 5) if size is not None else None
335
- fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
336
-
337
- if n == 1:
338
- ax = [ax]
339
- for i in range(n):
340
- ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
341
- ax[i].get_yaxis().set_ticks([])
342
- ax[i].get_xaxis().set_ticks([])
343
- ax[i].set_axis_off()
344
- for spine in ax[i].spines.values(): # remove frame
345
- spine.set_visible(False)
346
- if titles:
347
- ax[i].set_title(titles[i])
348
- fig.tight_layout(pad=pad)
349
- return fig
350
-
351
-
352
- def plot_keypoints(kpts, colors="lime", ps=4):
353
- """Plot keypoints for existing images.
354
- Args:
355
- kpts: list of ndarrays of size (N, 2).
356
- colors: string, or list of list of tuples (one for each keypoints).
357
- ps: size of the keypoints as float.
358
- """
359
- if not isinstance(colors, list):
360
- colors = [colors] * len(kpts)
361
- axes = plt.gcf().axes
362
- for a, k, c in zip(axes, kpts, colors):
363
- a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
364
-
365
-
366
- def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
367
- """Plot matches for a pair of existing images.
368
- Args:
369
- kpts0, kpts1: corresponding keypoints of size (N, 2).
370
- color: color of each match, string or RGB tuple. Random if not given.
371
- lw: width of the lines.
372
- ps: size of the end points (no endpoint if ps=0)
373
- indices: indices of the images to draw the matches on.
374
- a: alpha opacity of the match lines.
375
- """
376
- fig = plt.gcf()
377
- ax = fig.axes
378
- assert len(ax) > max(indices)
379
- ax0, ax1 = ax[indices[0]], ax[indices[1]]
380
- fig.canvas.draw()
381
-
382
- assert len(kpts0) == len(kpts1)
383
- if color is None:
384
- color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
385
- elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
386
- color = [color] * len(kpts0)
387
-
388
- if lw > 0:
389
- # transform the points into the figure coordinate system
390
- transFigure = fig.transFigure.inverted()
391
- fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
392
- fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
393
- fig.lines += [
394
- matplotlib.lines.Line2D(
395
- (fkpts0[i, 0], fkpts1[i, 0]),
396
- (fkpts0[i, 1], fkpts1[i, 1]),
397
- zorder=1,
398
- transform=fig.transFigure,
399
- c=color[i],
400
- linewidth=lw,
401
- alpha=a,
402
- )
403
- for i in range(len(kpts0))
404
- ]
405
-
406
- # freeze the axes to prevent the transform to change
407
- ax0.autoscale(enable=False)
408
- ax1.autoscale(enable=False)
409
-
410
- if ps > 0:
411
- ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
412
- ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
413
-
414
-
415
- def plot_lines(
416
- lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
417
- ):
418
- """Plot lines and endpoints for existing images.
419
- Args:
420
- lines: list of ndarrays of size (N, 2, 2).
421
- colors: string, or list of list of tuples (one for each keypoints).
422
- ps: size of the keypoints as float pixels.
423
- lw: line width as float pixels.
424
- indices: indices of the images to draw the matches on.
425
- """
426
- if not isinstance(line_colors, list):
427
- line_colors = [line_colors] * len(lines)
428
- if not isinstance(point_colors, list):
429
- point_colors = [point_colors] * len(lines)
430
-
431
- fig = plt.gcf()
432
- ax = fig.axes
433
- assert len(ax) > max(indices)
434
- axes = [ax[i] for i in indices]
435
- fig.canvas.draw()
436
-
437
- # Plot the lines and junctions
438
- for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
439
- for i in range(len(l)):
440
- line = matplotlib.lines.Line2D(
441
- (l[i, 0, 0], l[i, 1, 0]),
442
- (l[i, 0, 1], l[i, 1, 1]),
443
- zorder=1,
444
- c=lc,
445
- linewidth=lw,
446
- )
447
- a.add_line(line)
448
- pts = l.reshape(-1, 2)
449
- a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2)
450
-
451
- return fig
452
-
453
-
454
- def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0):
455
- """Plot matches for a pair of existing images, parametrized by their middle point.
456
- Args:
457
- kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
458
- color: color of each match, string or RGB tuple. Random if not given.
459
- lw: width of the lines.
460
- indices: indices of the images to draw the matches on.
461
- a: alpha opacity of the match lines.
462
- """
463
- fig = plt.gcf()
464
- ax = fig.axes
465
- assert len(ax) > max(indices)
466
- ax0, ax1 = ax[indices[0]], ax[indices[1]]
467
- fig.canvas.draw()
468
-
469
- assert len(kpts0) == len(kpts1)
470
- if color is None:
471
- color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
472
- elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
473
- color = [color] * len(kpts0)
474
-
475
- if lw > 0:
476
- # transform the points into the figure coordinate system
477
- transFigure = fig.transFigure.inverted()
478
- fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
479
- fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
480
- fig.lines += [
481
- matplotlib.lines.Line2D(
482
- (fkpts0[i, 0], fkpts1[i, 0]),
483
- (fkpts0[i, 1], fkpts1[i, 1]),
484
- zorder=1,
485
- transform=fig.transFigure,
486
- c=color[i],
487
- linewidth=lw,
488
- alpha=a,
489
- )
490
- for i in range(len(kpts0))
491
- ]
492
-
493
- # freeze the axes to prevent the transform to change
494
- ax0.autoscale(enable=False)
495
- ax1.autoscale(enable=False)
496
-
497
-
498
- def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
499
- """Plot line matches for existing images with multiple colors.
500
- Args:
501
- lines: list of ndarrays of size (N, 2, 2).
502
- correct_matches: bool array of size (N,) indicating correct matches.
503
- lw: line width as float pixels.
504
- indices: indices of the images to draw the matches on.
505
- """
506
- n_lines = len(lines[0])
507
- colors = sns.color_palette("husl", n_colors=n_lines)
508
- np.random.shuffle(colors)
509
- alphas = np.ones(n_lines)
510
- # If correct_matches is not None, display wrong matches with a low alpha
511
- if correct_matches is not None:
512
- alphas[~np.array(correct_matches)] = 0.2
513
-
514
- fig = plt.gcf()
515
- ax = fig.axes
516
- assert len(ax) > max(indices)
517
- axes = [ax[i] for i in indices]
518
- fig.canvas.draw()
519
-
520
- # Plot the lines
521
- for a, l in zip(axes, lines):
522
- # Transform the points into the figure coordinate system
523
- transFigure = fig.transFigure.inverted()
524
- endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
525
- endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
526
- fig.lines += [
527
- matplotlib.lines.Line2D(
528
- (endpoint0[i, 0], endpoint1[i, 0]),
529
- (endpoint0[i, 1], endpoint1[i, 1]),
530
- zorder=1,
531
- transform=fig.transFigure,
532
- c=colors[i],
533
- alpha=alphas[i],
534
- linewidth=lw,
535
- )
536
- for i in range(n_lines)
537
- ]
538
-
539
- return fig
540
-
541
-
542
- def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)):
543
- """Plot line matches for existing images with multiple colors:
544
- green for correct matches, red for wrong ones, and blue for the rest.
545
- Args:
546
- lines: list of ndarrays of size (N, 2, 2).
547
- correct_matches: list of bool arrays of size N with correct matches.
548
- wrong_matches: list of bool arrays of size (N,) with correct matches.
549
- lw: line width as float pixels.
550
- indices: indices of the images to draw the matches on.
551
- """
552
- # palette = sns.color_palette()
553
- palette = sns.color_palette("hls", 8)
554
- blue = palette[5] # palette[0]
555
- red = palette[0] # palette[3]
556
- green = palette[2] # palette[2]
557
- colors = [np.array([blue] * len(l)) for l in lines]
558
- for i, c in enumerate(colors):
559
- c[np.array(correct_matches[i])] = green
560
- c[np.array(wrong_matches[i])] = red
561
-
562
- fig = plt.gcf()
563
- ax = fig.axes
564
- assert len(ax) > max(indices)
565
- axes = [ax[i] for i in indices]
566
- fig.canvas.draw()
567
-
568
- # Plot the lines
569
- for a, l, c in zip(axes, lines, colors):
570
- # Transform the points into the figure coordinate system
571
- transFigure = fig.transFigure.inverted()
572
- endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
573
- endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
574
- fig.lines += [
575
- matplotlib.lines.Line2D(
576
- (endpoint0[i, 0], endpoint1[i, 0]),
577
- (endpoint0[i, 1], endpoint1[i, 1]),
578
- zorder=1,
579
- transform=fig.transFigure,
580
- c=c[i],
581
- linewidth=lw,
582
- )
583
- for i in range(len(l))
584
- ]
585
-
586
-
587
- def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
588
- """Plot line matches for existing images with multiple colors and
589
- highlight the actually matched subsegments.
590
- Args:
591
- lines: list of ndarrays of size (N, 2, 2).
592
- subsegments: list of ndarrays of size (N, 2, 2).
593
- lw: line width as float pixels.
594
- indices: indices of the images to draw the matches on.
595
- """
596
- n_lines = len(lines[0])
597
- colors = sns.cubehelix_palette(
598
- start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines
599
- )
600
-
601
- fig = plt.gcf()
602
- ax = fig.axes
603
- assert len(ax) > max(indices)
604
- axes = [ax[i] for i in indices]
605
- fig.canvas.draw()
606
-
607
- # Plot the lines
608
- for a, l, ss in zip(axes, lines, subsegments):
609
- # Transform the points into the figure coordinate system
610
- transFigure = fig.transFigure.inverted()
611
-
612
- # Draw full line
613
- endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
614
- endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
615
- fig.lines += [
616
- matplotlib.lines.Line2D(
617
- (endpoint0[i, 0], endpoint1[i, 0]),
618
- (endpoint0[i, 1], endpoint1[i, 1]),
619
- zorder=1,
620
- transform=fig.transFigure,
621
- c="red",
622
- alpha=0.7,
623
- linewidth=lw,
624
- )
625
- for i in range(n_lines)
626
- ]
627
-
628
- # Draw matched subsegment
629
- endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
630
- endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
631
- fig.lines += [
632
- matplotlib.lines.Line2D(
633
- (endpoint0[i, 0], endpoint1[i, 0]),
634
- (endpoint0[i, 1], endpoint1[i, 1]),
635
- zorder=1,
636
- transform=fig.transFigure,
637
- c=colors[i],
638
- alpha=1,
639
- linewidth=lw,
640
- )
641
- for i in range(n_lines)
642
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
common/{plotting.py → viz.py} RENAMED
@@ -6,6 +6,7 @@ import matplotlib.cm as cm
6
  from PIL import Image
7
  import torch.nn.functional as F
8
  import torch
 
9
 
10
 
11
  def _compute_conf_thresh(data):
@@ -19,7 +20,77 @@ def _compute_conf_thresh(data):
19
  return thr
20
 
21
 
22
- # --- VISUALIZATION --- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def make_matching_figure(
@@ -57,7 +128,7 @@ def make_matching_figure(
57
  axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
58
 
59
  # draw matches
60
- if mkpts0.shape[0] > 1 and mkpts1.shape[0] > 1:
61
  fig.canvas.draw()
62
  transFigure = fig.transFigure.inverted()
63
  fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
@@ -105,8 +176,12 @@ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
105
  b_mask = data["m_bids"] == b_id
106
  conf_thr = _compute_conf_thresh(data)
107
 
108
- img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
109
- img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
 
 
 
 
110
  kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
111
  kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
112
 
@@ -131,8 +206,10 @@ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
131
 
132
  text = [
133
  f"#Matches {len(kpts0)}",
134
- f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
135
- f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
 
 
136
  ]
137
 
138
  # make the figure
@@ -188,7 +265,9 @@ def error_colormap(err, thr, alpha=1.0):
188
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
189
  x = 1 - np.clip(err / (thr * 2), 0, 1)
190
  return np.clip(
191
- np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
 
 
192
  0,
193
  1,
194
  )
@@ -200,9 +279,13 @@ np.random.shuffle(color_map)
200
 
201
 
202
  def draw_topics(
203
- data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None
 
 
 
 
 
204
  ):
205
-
206
  topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
207
  hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
208
  hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
@@ -237,7 +320,10 @@ def draw_topics(
237
  dim=-1, keepdim=True
238
  ) # .float() / (n_topics - 1) #* 255 + 1
239
  # topic1[~mask1_nonzero] = -1
240
- label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1
 
 
 
241
  for i, k in enumerate(top_topics):
242
  label_img0[topic0 == k] = color_map[k]
243
  label_img1[topic1 == k] = color_map[k]
@@ -312,24 +398,30 @@ def draw_topicfm_demo(
312
  opencv_display=False,
313
  opencv_title="",
314
  ):
315
- topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics)
316
-
317
- mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(
318
- topic_map1 >= 0, axis=-1
319
  )
320
 
 
 
 
 
321
  topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
322
- topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
323
- topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
 
 
 
 
324
  overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
325
  overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
326
 
327
  cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
328
  cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
329
 
330
- overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(
331
- np.uint8
332
- )
333
 
334
  h0, w0 = img0.shape[:2]
335
  h1, w1 = img1.shape[:2]
@@ -338,7 +430,9 @@ def draw_topicfm_demo(
338
  out_fig[:h0, :w0] = overlay0
339
  if h0 >= h1:
340
  start = (h0 - h1) // 2
341
- out_fig[start : (start + h1), (w0 + margin) : (w0 + margin + w1)] = overlay1
 
 
342
  else:
343
  start = (h1 - h0) // 2
344
  out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
@@ -358,7 +452,8 @@ def draw_topicfm_demo(
358
  img1[start : start + h0] * 255
359
  ).astype(np.uint8)
360
 
361
- # draw matching lines, this is inspried from https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
 
362
  mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
363
  mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
364
 
 
6
  from PIL import Image
7
  import torch.nn.functional as F
8
  import torch
9
+ import seaborn as sns
10
 
11
 
12
  def _compute_conf_thresh(data):
 
20
  return thr
21
 
22
 
23
+ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
24
+ """Plot a set of images horizontally.
25
+ Args:
26
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
27
+ titles: a list of strings, as titles for each image.
28
+ cmaps: colormaps for monochrome images.
29
+ """
30
+ n = len(imgs)
31
+ if not isinstance(cmaps, (list, tuple)):
32
+ cmaps = [cmaps] * n
33
+ # figsize = (size*n, size*3/4) if size is not None else None
34
+ figsize = (size * n, size * 6 / 5) if size is not None else None
35
+ fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
36
+
37
+ if n == 1:
38
+ ax = [ax]
39
+ for i in range(n):
40
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
41
+ ax[i].get_yaxis().set_ticks([])
42
+ ax[i].get_xaxis().set_ticks([])
43
+ ax[i].set_axis_off()
44
+ for spine in ax[i].spines.values(): # remove frame
45
+ spine.set_visible(False)
46
+ if titles:
47
+ ax[i].set_title(titles[i])
48
+ fig.tight_layout(pad=pad)
49
+ return fig
50
+
51
+
52
+ def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
53
+ """Plot line matches for existing images with multiple colors.
54
+ Args:
55
+ lines: list of ndarrays of size (N, 2, 2).
56
+ correct_matches: bool array of size (N,) indicating correct matches.
57
+ lw: line width as float pixels.
58
+ indices: indices of the images to draw the matches on.
59
+ """
60
+ n_lines = len(lines[0])
61
+ colors = sns.color_palette("husl", n_colors=n_lines)
62
+ np.random.shuffle(colors)
63
+ alphas = np.ones(n_lines)
64
+ # If correct_matches is not None, display wrong matches with a low alpha
65
+ if correct_matches is not None:
66
+ alphas[~np.array(correct_matches)] = 0.2
67
+
68
+ fig = plt.gcf()
69
+ ax = fig.axes
70
+ assert len(ax) > max(indices)
71
+ axes = [ax[i] for i in indices]
72
+ fig.canvas.draw()
73
+
74
+ # Plot the lines
75
+ for a, l in zip(axes, lines):
76
+ # Transform the points into the figure coordinate system
77
+ transFigure = fig.transFigure.inverted()
78
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
79
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
80
+ fig.lines += [
81
+ matplotlib.lines.Line2D(
82
+ (endpoint0[i, 0], endpoint1[i, 0]),
83
+ (endpoint0[i, 1], endpoint1[i, 1]),
84
+ zorder=1,
85
+ transform=fig.transFigure,
86
+ c=colors[i],
87
+ alpha=alphas[i],
88
+ linewidth=lw,
89
+ )
90
+ for i in range(n_lines)
91
+ ]
92
+
93
+ return fig
94
 
95
 
96
  def make_matching_figure(
 
128
  axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
129
 
130
  # draw matches
131
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
132
  fig.canvas.draw()
133
  transFigure = fig.transFigure.inverted()
134
  fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
 
176
  b_mask = data["m_bids"] == b_id
177
  conf_thr = _compute_conf_thresh(data)
178
 
179
+ img0 = (
180
+ (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
181
+ )
182
+ img1 = (
183
+ (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
184
+ )
185
  kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
186
  kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
187
 
 
206
 
207
  text = [
208
  f"#Matches {len(kpts0)}",
209
+ f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%):"
210
+ f" {n_correct}/{len(kpts0)}",
211
+ f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%):"
212
+ f" {n_correct}/{n_gt_matches}",
213
  ]
214
 
215
  # make the figure
 
265
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
266
  x = 1 - np.clip(err / (thr * 2), 0, 1)
267
  return np.clip(
268
+ np.stack(
269
+ [2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1
270
+ ),
271
  0,
272
  1,
273
  )
 
279
 
280
 
281
  def draw_topics(
282
+ data,
283
+ img0,
284
+ img1,
285
+ saved_folder="viz_topics",
286
+ show_n_topics=8,
287
+ saved_name=None,
288
  ):
 
289
  topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
290
  hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
291
  hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
 
320
  dim=-1, keepdim=True
321
  ) # .float() / (n_topics - 1) #* 255 + 1
322
  # topic1[~mask1_nonzero] = -1
323
+ label_img0, label_img1 = (
324
+ torch.zeros_like(topic0) - 1,
325
+ torch.zeros_like(topic1) - 1,
326
+ )
327
  for i, k in enumerate(top_topics):
328
  label_img0[topic0 == k] = color_map[k]
329
  label_img1[topic1 == k] = color_map[k]
 
398
  opencv_display=False,
399
  opencv_title="",
400
  ):
401
+ topic_map0, topic_map1 = draw_topics(
402
+ data, img0, img1, show_n_topics=show_n_topics
 
 
403
  )
404
 
405
+ mask_tm0, mask_tm1 = np.expand_dims(
406
+ topic_map0 >= 0, axis=-1
407
+ ), np.expand_dims(topic_map1 >= 0, axis=-1)
408
+
409
  topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
410
+ topic_cm0 = cv2.cvtColor(
411
+ topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
412
+ )
413
+ topic_cm1 = cv2.cvtColor(
414
+ topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
415
+ )
416
  overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
417
  overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
418
 
419
  cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
420
  cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
421
 
422
+ overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (
423
+ overlay1 * 255
424
+ ).astype(np.uint8)
425
 
426
  h0, w0 = img0.shape[:2]
427
  h1, w1 = img1.shape[:2]
 
430
  out_fig[:h0, :w0] = overlay0
431
  if h0 >= h1:
432
  start = (h0 - h1) // 2
433
+ out_fig[
434
+ start : (start + h1), (w0 + margin) : (w0 + margin + w1)
435
+ ] = overlay1
436
  else:
437
  start = (h1 - h0) // 2
438
  out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
 
452
  img1[start : start + h0] * 255
453
  ).astype(np.uint8)
454
 
455
+ # draw matching lines, this is inspried from
456
+ # https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
457
  mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
458
  mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
459
 
style.css ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: white;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
11
+
12
+ #component-0 {
13
+ /* max-width: 900px; */
14
+ margin: auto;
15
+ padding-top: 1.5rem;
16
+ }
17
+
18
+ footer {visibility: hidden}