Realcat commited on
Commit
34aafe5
·
1 Parent(s): c0283b3

add: matching API

Browse files
Files changed (6) hide show
  1. README.md +2 -1
  2. common/api.py +298 -0
  3. common/config.yaml +2 -0
  4. common/utils.py +15 -12
  5. common/viz.py +10 -6
  6. test_app_cli.py +16 -130
README.md CHANGED
@@ -38,7 +38,8 @@ The tool currently supports various popular image matching algorithms, namely:
38
  - [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024
39
  - [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024
40
  - [ ] [Mickey](https://github.com/nianticlabs/mickey), CVPR 2024
41
- - [ ] [GIM](https://github.com/xuelunshen/gim), ICLR 2024
 
42
  - [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
43
  - [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
44
  - [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
 
38
  - [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024
39
  - [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024
40
  - [ ] [Mickey](https://github.com/nianticlabs/mickey), CVPR 2024
41
+ - [x] [GIM](https://github.com/xuelunshen/gim), ICLR 2024
42
+ - [ ] [DUSt3R](https://github.com/naver/dust3r), arXiv 2023
43
  - [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
44
  - [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
45
  - [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
common/api.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import warnings
4
+ import numpy as np
5
+ from pathlib import Path
6
+ from typing import Dict, Any, Optional, Tuple, List, Union
7
+ from hloc import logger
8
+ from hloc import match_dense, match_features, extract_features
9
+ from hloc.utils.viz import add_text, plot_keypoints
10
+ from .utils import (
11
+ load_config,
12
+ get_model,
13
+ get_feature_model,
14
+ filter_matches,
15
+ device,
16
+ ROOT,
17
+ )
18
+ from .viz import (
19
+ fig2im,
20
+ plot_images,
21
+ display_matches,
22
+ )
23
+ import matplotlib.pyplot as plt
24
+
25
+ warnings.simplefilter("ignore")
26
+
27
+
28
+ class ImageMatchingAPI(torch.nn.Module):
29
+ default_conf = {
30
+ "dense": True,
31
+ "matcher": {
32
+ "model": {
33
+ "name": "topicfm",
34
+ "match_threshold": 0.2,
35
+ }
36
+ },
37
+ "feature": {
38
+ "model": {
39
+ "name": "xfeat",
40
+ "max_keypoints": 1024,
41
+ "keypoint_threshold": 0.015,
42
+ }
43
+ },
44
+ "ransac": {
45
+ "enable": True,
46
+ "estimator": "poselib",
47
+ "geometry": "homography",
48
+ "method": "RANSAC",
49
+ "reproj_threshold": 3,
50
+ "confidence": 0.9999,
51
+ "max_iter": 10000,
52
+ },
53
+ }
54
+
55
+ def __init__(
56
+ self,
57
+ conf: dict = {},
58
+ device: str = "cpu",
59
+ detect_threshold: float = 0.015,
60
+ max_keypoints: int = 1024,
61
+ match_threshold: float = 0.2,
62
+ ) -> None:
63
+ """
64
+ Initializes an instance of the ImageMatchingAPI class.
65
+
66
+ Args:
67
+ conf (dict): A dictionary containing the configuration parameters.
68
+ device (str, optional): The device to use for computation. Defaults to "cpu".
69
+ detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
70
+ max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
71
+ match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
72
+
73
+ Returns:
74
+ None
75
+ """
76
+ super().__init__()
77
+ self.device = device
78
+ self.conf = conf = {
79
+ **self.parse_match_config(self.default_conf),
80
+ **conf,
81
+ }
82
+ self._updata_config(detect_threshold, max_keypoints, match_threshold)
83
+ self._init_models()
84
+ self.pred = None
85
+
86
+ def parse_match_config(self, conf):
87
+ if conf["dense"]:
88
+ return {
89
+ **conf,
90
+ "matcher": match_dense.confs.get(
91
+ conf["matcher"]["model"]["name"]
92
+ ),
93
+ "dense": True,
94
+ }
95
+ else:
96
+ return {
97
+ **conf,
98
+ "feature": extract_features.confs.get(
99
+ conf["feature"]["model"]["name"]
100
+ ),
101
+ "matcher": match_features.confs.get(
102
+ conf["matcher"]["model"]["name"]
103
+ ),
104
+ "dense": False,
105
+ }
106
+
107
+ def _updata_config(
108
+ self,
109
+ detect_threshold: float = 0.015,
110
+ max_keypoints: int = 1024,
111
+ match_threshold: float = 0.2,
112
+ ):
113
+ self.dense = self.conf["dense"]
114
+ if self.conf["dense"]:
115
+ self.conf["matcher"]["model"]["match_threshold"] = match_threshold
116
+ else:
117
+ self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
118
+ self.conf["feature"]["model"][
119
+ "keypoint_threshold"
120
+ ] = detect_threshold
121
+ self.match_conf = self.conf["matcher"]
122
+ self.extract_conf = self.conf["feature"]
123
+
124
+ def _init_models(self):
125
+ # initialize matcher
126
+ self.matcher = get_model(self.conf["matcher"])
127
+ # initialize extractor
128
+ if self.dense:
129
+ self.extractor = None
130
+ else:
131
+ self.extractor = get_feature_model(self.conf["feature"])
132
+
133
+ def _forward(self, img0, img1):
134
+ if self.dense:
135
+ pred = match_dense.match_images(
136
+ self.matcher,
137
+ img0,
138
+ img1,
139
+ self.match_conf["preprocessing"],
140
+ device=self.device,
141
+ )
142
+ last_fixed = "{}".format(self.match_conf["model"]["name"])
143
+ else:
144
+ pred0 = extract_features.extract(
145
+ self.extractor, img0, self.extract_conf["preprocessing"]
146
+ )
147
+ pred1 = extract_features.extract(
148
+ self.extractor, img1, self.extract_conf["preprocessing"]
149
+ )
150
+ pred = match_features.match_images(self.matcher, pred0, pred1)
151
+ return pred
152
+
153
+ @torch.inference_mode()
154
+ def forward(
155
+ self,
156
+ img0: np.ndarray,
157
+ img1: np.ndarray,
158
+ ) -> Dict[str, np.ndarray]:
159
+ """
160
+ Forward pass of the image matching API.
161
+
162
+ Args:
163
+ img0: A 3D NumPy array of shape (H, W, C) representing the first image.
164
+ Values are in the range [0, 1] and are in RGB mode.
165
+ img1: A 3D NumPy array of shape (H, W, C) representing the second image.
166
+ Values are in the range [0, 1] and are in RGB mode.
167
+
168
+ Returns:
169
+ A dictionary containing the following keys:
170
+ - image0_orig: The original image 0.
171
+ - image1_orig: The original image 1.
172
+ - keypoints0_orig: The keypoints detected in image 0.
173
+ - keypoints1_orig: The keypoints detected in image 1.
174
+ - mkeypoints0_orig: The raw matches between image 0 and image 1.
175
+ - mkeypoints1_orig: The raw matches between image 1 and image 0.
176
+ - mmkeypoints0_orig: The RANSAC inliers in image 0.
177
+ - mmkeypoints1_orig: The RANSAC inliers in image 1.
178
+ - mconf: The confidence scores for the raw matches.
179
+ - mmconf: The confidence scores for the RANSAC inliers.
180
+ """
181
+ # Take as input a pair of images (not a batch)
182
+ assert isinstance(img0, np.ndarray)
183
+ assert isinstance(img1, np.ndarray)
184
+ self.pred = self._forward(img0, img1)
185
+ if self.conf["ransac"]["enable"]:
186
+ self.pred = self._geometry_check(self.pred)
187
+ return self.pred
188
+
189
+ def _geometry_check(
190
+ self,
191
+ pred: Dict[str, Any],
192
+ ) -> Dict[str, Any]:
193
+ """
194
+ Filter matches using RANSAC. If keypoints are available, filter by keypoints.
195
+ If lines are available, filter by lines. If both keypoints and lines are
196
+ available, filter by keypoints.
197
+
198
+ Args:
199
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
200
+ See :func:`filter_matches` for the expected keys.
201
+
202
+ Returns:
203
+ Dict[str, Any]: filtered matches
204
+ """
205
+ pred = filter_matches(
206
+ pred,
207
+ ransac_method=self.conf["ransac"]["method"],
208
+ ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"],
209
+ ransac_confidence=self.conf["ransac"]["confidence"],
210
+ ransac_max_iter=self.conf["ransac"]["max_iter"],
211
+ )
212
+ return pred
213
+
214
+ def visualize(
215
+ self,
216
+ log_path: Optional[Path] = None,
217
+ ) -> None:
218
+ """
219
+ Visualize the matches.
220
+
221
+ Args:
222
+ log_path (Path, optional): The directory to save the images. Defaults to None.
223
+
224
+ Returns:
225
+ None
226
+ """
227
+ if self.conf["dense"]:
228
+ postfix = str(self.conf["matcher"]["model"]["name"])
229
+ else:
230
+ postfix = "{}_{}".format(
231
+ str(self.conf["feature"]["model"]["name"]),
232
+ str(self.conf["matcher"]["model"]["name"]),
233
+ )
234
+ titles = [
235
+ "Image 0 - Keypoints",
236
+ "Image 1 - Keypoints",
237
+ ]
238
+ pred: Dict[str, Any] = self.pred
239
+ image0: np.ndarray = pred["image0_orig"]
240
+ image1: np.ndarray = pred["image1_orig"]
241
+ output_keypoints: np.ndarray = plot_images(
242
+ [image0, image1], titles=titles, dpi=300
243
+ )
244
+ if (
245
+ "keypoints0_orig" in pred.keys()
246
+ and "keypoints1_orig" in pred.keys()
247
+ ):
248
+ plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
249
+ text: str = (
250
+ f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
251
+ + f"# keypoints1: {len(pred['keypoints1_orig'])}"
252
+ )
253
+ add_text(0, text, fs=15)
254
+ output_keypoints = fig2im(output_keypoints)
255
+ # plot images with raw matches
256
+ titles = [
257
+ "Image 0 - Raw matched keypoints",
258
+ "Image 1 - Raw matched keypoints",
259
+ ]
260
+ output_matches_raw, num_matches_raw = display_matches(
261
+ pred, titles=titles, tag="KPTS_RAW"
262
+ )
263
+ # plot images with ransac matches
264
+ titles = [
265
+ "Image 0 - Ransac matched keypoints",
266
+ "Image 1 - Ransac matched keypoints",
267
+ ]
268
+ output_matches_ransac, num_matches_ransac = display_matches(
269
+ pred, titles=titles, tag="KPTS_RANSAC"
270
+ )
271
+ if log_path is not None:
272
+ img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png"
273
+ img_matches_raw_path: Path = (
274
+ log_path / f"img_matches_raw_{postfix}.png"
275
+ )
276
+ img_matches_ransac_path: Path = (
277
+ log_path / f"img_matches_ransac_{postfix}.png"
278
+ )
279
+ cv2.imwrite(
280
+ str(img_keypoints_path),
281
+ output_keypoints[:, :, ::-1].copy(), # RGB -> BGR
282
+ )
283
+ cv2.imwrite(
284
+ str(img_matches_raw_path),
285
+ output_matches_raw[:, :, ::-1].copy(), # RGB -> BGR
286
+ )
287
+ cv2.imwrite(
288
+ str(img_matches_ransac_path),
289
+ output_matches_ransac[:, :, ::-1].copy(), # RGB -> BGR
290
+ )
291
+ plt.close("all")
292
+
293
+
294
+ if __name__ == "__main__":
295
+ import argparse
296
+
297
+ config = load_config(ROOT / "common/config.yaml")
298
+ test_api(config)
common/config.yaml CHANGED
@@ -319,6 +319,7 @@ matcher_zoo:
319
  project: null
320
  display: true
321
  gluestick:
 
322
  matcher: gluestick
323
  dense: true
324
  info:
@@ -329,6 +330,7 @@ matcher_zoo:
329
  project: https://iago-suarez.com/gluestick
330
  display: true
331
  sold2:
 
332
  matcher: sold2
333
  dense: true
334
  info:
 
319
  project: null
320
  display: true
321
  gluestick:
322
+ enable: false
323
  matcher: gluestick
324
  dense: true
325
  info:
 
330
  project: https://iago-suarez.com/gluestick
331
  display: true
332
  sold2:
333
+ enable: false
334
  matcher: sold2
335
  dense: true
336
  info:
common/utils.py CHANGED
@@ -78,21 +78,24 @@ def get_matcher_zoo(
78
  """
79
  matcher_zoo_restored = {}
80
  for k, v in matcher_zoo.items():
81
- dense = v["dense"]
82
- if dense:
83
- matcher_zoo_restored[k] = {
84
- "matcher": match_dense.confs.get(v["matcher"]),
85
- "dense": dense,
86
- }
87
- else:
88
- matcher_zoo_restored[k] = {
89
- "feature": extract_features.confs.get(v["feature"]),
90
- "matcher": match_features.confs.get(v["matcher"]),
91
- "dense": dense,
92
- }
93
  return matcher_zoo_restored
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def get_model(match_conf: Dict[str, Any]):
97
  """
98
  Load a matcher model from the provided configuration.
 
78
  """
79
  matcher_zoo_restored = {}
80
  for k, v in matcher_zoo.items():
81
+ matcher_zoo_restored[k] = parse_match_config(v)
 
 
 
 
 
 
 
 
 
 
 
82
  return matcher_zoo_restored
83
 
84
 
85
+ def parse_match_config(conf):
86
+ if conf["dense"]:
87
+ return {
88
+ "matcher": match_dense.confs.get(conf["matcher"]),
89
+ "dense": True,
90
+ }
91
+ else:
92
+ return {
93
+ "feature": extract_features.confs.get(conf["feature"]),
94
+ "matcher": match_features.confs.get(conf["matcher"]),
95
+ "dense": False,
96
+ }
97
+
98
+
99
  def get_model(match_conf: Dict[str, Any]):
100
  """
101
  Load a matcher model from the provided configuration.
common/viz.py CHANGED
@@ -415,12 +415,17 @@ def display_matches(
415
  num_inliers = 0
416
  KPTS0_KEY = None
417
  KPTS1_KEY = None
 
418
  if tag == "KPTS_RAW":
419
  KPTS0_KEY = "mkeypoints0_orig"
420
  KPTS1_KEY = "mkeypoints1_orig"
 
 
421
  elif tag == "KPTS_RANSAC":
422
  KPTS0_KEY = "mmkeypoints0_orig"
423
  KPTS1_KEY = "mmkeypoints1_orig"
 
 
424
  else:
425
  # TODO: LINES_RAW, LINES_RANSAC
426
  raise ValueError(f"Unknown tag: {tag}")
@@ -434,16 +439,14 @@ def display_matches(
434
  mkpts0 = pred[KPTS0_KEY]
435
  mkpts1 = pred[KPTS1_KEY]
436
  num_inliers = len(mkpts0)
437
- if "mmconf" in pred:
438
- mmconf = pred["mmconf"]
439
- else:
440
- mmconf = np.ones(len(mkpts0))
441
  fig_mkpts = draw_matches_core(
442
  mkpts0,
443
  mkpts1,
444
  img0,
445
  img1,
446
- mmconf,
447
  dpi=dpi,
448
  titles=titles,
449
  texts=texts,
@@ -472,7 +475,8 @@ def display_matches(
472
  # keypoints
473
  mkpts0 = pred.get("line_keypoints0_orig")
474
  mkpts1 = pred.get("line_keypoints1_orig")
475
-
 
476
  if mkpts0 is not None and mkpts1 is not None:
477
  num_inliers = len(mkpts0)
478
  if "mconf" in pred:
 
415
  num_inliers = 0
416
  KPTS0_KEY = None
417
  KPTS1_KEY = None
418
+ confid = None
419
  if tag == "KPTS_RAW":
420
  KPTS0_KEY = "mkeypoints0_orig"
421
  KPTS1_KEY = "mkeypoints1_orig"
422
+ if "mconf" in pred:
423
+ confid = pred["mconf"]
424
  elif tag == "KPTS_RANSAC":
425
  KPTS0_KEY = "mmkeypoints0_orig"
426
  KPTS1_KEY = "mmkeypoints1_orig"
427
+ if "mmconf" in pred:
428
+ confid = pred["mmconf"]
429
  else:
430
  # TODO: LINES_RAW, LINES_RANSAC
431
  raise ValueError(f"Unknown tag: {tag}")
 
439
  mkpts0 = pred[KPTS0_KEY]
440
  mkpts1 = pred[KPTS1_KEY]
441
  num_inliers = len(mkpts0)
442
+ if confid is None:
443
+ confid = np.ones(len(mkpts0))
 
 
444
  fig_mkpts = draw_matches_core(
445
  mkpts0,
446
  mkpts1,
447
  img0,
448
  img1,
449
+ confid,
450
  dpi=dpi,
451
  titles=titles,
452
  texts=texts,
 
475
  # keypoints
476
  mkpts0 = pred.get("line_keypoints0_orig")
477
  mkpts1 = pred.get("line_keypoints1_orig")
478
+ fig = None
479
+ breakpoint()
480
  if mkpts0 is not None and mkpts1 is not None:
481
  num_inliers = len(mkpts0)
482
  if "mconf" in pred:
test_app_cli.py CHANGED
@@ -1,153 +1,39 @@
1
  import cv2
2
  import warnings
 
3
  from pathlib import Path
4
  from hloc import logger
5
- from hloc import matchers, extractors, logger
6
- from hloc import match_dense, match_features, extract_features
7
- from hloc.utils.viz import add_text, plot_keypoints
8
  from common.utils import (
9
- load_config,
10
- get_model,
11
- get_feature_model,
12
- ransac_zoo,
13
  get_matcher_zoo,
14
- filter_matches,
15
  device,
16
  ROOT,
17
  )
18
- from common.viz import (
19
- fig2im,
20
- plot_images,
21
- display_matches,
22
- plot_color_line_matches,
23
- )
24
- import time
25
- import matplotlib.pyplot as plt
26
-
27
- warnings.simplefilter("ignore")
28
 
29
-
30
- def test_modules(config: dict):
31
  img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
32
  img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
33
- image0 = cv2.imread(str(img_path1))
34
- image1 = cv2.imread(str(img_path2))
35
- keypoint_threshold = 0.0
36
- extract_max_keypoints = 2000
37
- match_threshold = 0.2
38
- log_path = ROOT / "experiments"
39
- log_path.mkdir(exist_ok=True, parents=True)
40
 
41
  matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"])
42
  for k, v in matcher_zoo_restored.items():
43
  if image0 is None or image1 is None:
44
  logger.error("Error: No images found! Please upload two images.")
45
- # init output
46
- output_keypoints = None
47
- output_matches_raw = None
48
- output_matches_ransac = None
49
- match_conf = v["matcher"]
50
-
51
- # update match config
52
- match_conf["model"]["match_threshold"] = match_threshold
53
- match_conf["model"]["max_keypoints"] = extract_max_keypoints
54
- matcher = get_model(match_conf)
55
- t1 = time.time()
56
- if v["dense"]:
57
- pred = match_dense.match_images(
58
- matcher,
59
- image0,
60
- image1,
61
- match_conf["preprocessing"],
62
- device=device,
63
- )
64
- del matcher
65
- extract_conf = None
66
- last_fixed = "{}".format(match_conf["model"]["name"])
67
  else:
68
- extract_conf = v["feature"]
69
-
70
- # update extract config
71
- extract_conf["model"]["max_keypoints"] = extract_max_keypoints
72
- extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
73
- extractor = get_feature_model(extract_conf)
74
- pred0 = extract_features.extract(
75
- extractor, image0, extract_conf["preprocessing"]
76
- )
77
- pred1 = extract_features.extract(
78
- extractor, image1, extract_conf["preprocessing"]
79
- )
80
- pred = match_features.match_images(matcher, pred0, pred1)
81
- del extractor
82
- last_fixed = "{}_{}".format(
83
- extract_conf["model"]["name"], match_conf["model"]["name"]
84
- )
85
-
86
- # keypoints on images
87
- logger.info(f"Match features done using: {time.time()-t1:.3f}s")
88
- t1 = time.time()
89
- texts = [
90
- f"image pairs: {img_path1.name} & {img_path2.name}",
91
- "",
92
- ]
93
- titles = [
94
- "Image 0 - Keypoints",
95
- "Image 1 - Keypoints",
96
- ]
97
- output_keypoints = plot_images([image0, image1], titles=titles, dpi=300)
98
- if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
99
- plot_keypoints([pred["keypoints0"], pred["keypoints1"]])
100
- text = (
101
- f"# keypoints0: {len(pred['keypoints0'])} \n"
102
- + f"# keypoints1: {len(pred['keypoints1'])}"
103
- )
104
- add_text(0, text, fs=15)
105
- output_keypoints = fig2im(output_keypoints)
106
-
107
- # plot images with raw matches
108
- titles = [
109
- "Image 0 - Raw matched keypoints",
110
- "Image 1 - Raw matched keypoints",
111
- ]
112
- output_matches_raw, num_matches_raw = display_matches(
113
- pred, titles=titles
114
- )
115
- logger.info(f"Plot keypoints done using: {time.time()-t1:.3f}s")
116
- t1 = time.time()
117
-
118
- filter_matches(
119
- pred,
120
- ransac_method=config["defaults"]["ransac_method"],
121
- ransac_reproj_threshold=config["defaults"][
122
- "ransac_reproj_threshold"
123
- ],
124
- ransac_confidence=config["defaults"]["ransac_confidence"],
125
- ransac_max_iter=config["defaults"]["ransac_max_iter"],
126
- )
127
- # plot images with ransac matches
128
- titles = [
129
- "Image 0 - Ransac matched keypoints",
130
- "Image 1 - Ransac matched keypoints",
131
- ]
132
- output_matches_ransac, num_matches_ransac = display_matches(
133
- pred, titles=titles
134
- )
135
- logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
136
-
137
- img_keypoints_path = log_path / f"img_keypoints_{last_fixed}.png"
138
- img_matches_raw_path = log_path / f"img_matches_raw_{last_fixed}.png"
139
- img_matches_ransac_path = (
140
- log_path / f"img_matches_ransac_{last_fixed}.png"
141
- )
142
- cv2.imwrite(str(img_keypoints_path), output_keypoints)
143
- cv2.imwrite(str(img_matches_raw_path), output_matches_raw)
144
- cv2.imwrite(str(img_matches_ransac_path), output_matches_ransac)
145
-
146
- plt.close("all")
147
-
148
 
149
  if __name__ == "__main__":
150
  import argparse
151
 
152
  config = load_config(ROOT / "common/config.yaml")
153
- test_modules(config)
 
1
  import cv2
2
  import warnings
3
+ import numpy as np
4
  from pathlib import Path
5
  from hloc import logger
 
 
 
6
  from common.utils import (
 
 
 
 
7
  get_matcher_zoo,
8
+ load_config,
9
  device,
10
  ROOT,
11
  )
12
+ from common.api import ImageMatchingAPI
 
 
 
 
 
 
 
 
 
13
 
14
+ def test_api(config: dict = None):
 
15
  img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
16
  img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
17
+ image0 = cv2.imread(str(img_path1))[:, :, ::-1]
18
+ image1 = cv2.imread(str(img_path2))[:, :, ::-1]
 
 
 
 
 
19
 
20
  matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"])
21
  for k, v in matcher_zoo_restored.items():
22
  if image0 is None or image1 is None:
23
  logger.error("Error: No images found! Please upload two images.")
24
+ enable = config["matcher_zoo"][k].get("enable", True)
25
+ if enable:
26
+ logger.info(f"Testing {k} ...")
27
+ api = ImageMatchingAPI(conf=v, device=device)
28
+ api(image0, image1)
29
+ log_path = ROOT / "experiments1"
30
+ log_path.mkdir(exist_ok=True, parents=True)
31
+ api.visualize(log_path=log_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  else:
33
+ logger.info(f"Skipping {k} ...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  if __name__ == "__main__":
36
  import argparse
37
 
38
  config = load_config(ROOT / "common/config.yaml")
39
+ test_api(config)