Realcat commited on
Commit
f215948
·
1 Parent(s): a9516a9

add: doc strings and type

Browse files
Files changed (3) hide show
  1. app.py +48 -23
  2. common/utils.py +188 -49
  3. common/viz.py +151 -37
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import argparse
2
  from pathlib import Path
 
 
3
  import gradio as gr
4
  from common.utils import (
5
  matcher_zoo,
@@ -56,36 +58,59 @@ def ui_change_imagebox(choice):
56
  }
57
 
58
 
59
- def ui_reset_state(*args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  """
61
  Reset the state of the UI.
62
 
63
  Returns:
64
  tuple: A tuple containing the initial values for the UI state.
65
  """
66
- key = list(matcher_zoo.keys())[0] # Get the first key from matcher_zoo
67
  return (
68
- None, # image0
69
- None, # image1
70
- DEFAULT_MATCHING_THRESHOLD, # matching_threshold
71
- DEFAULT_SETTING_MAX_FEATURES, # max_features
72
- DEFAULT_DEFAULT_KEYPOINT_THRESHOLD, # keypoint_threshold
73
- key, # matcher
74
- ui_change_imagebox("upload"), # input image0
75
- ui_change_imagebox("upload"), # input image1
76
- "upload", # match_image_src
77
- None, # keypoints
78
- None, # raw matches
79
- None, # ransac matches
80
- {}, # matches result info
81
- {}, # matcher config
82
- None, # warped imageInstance of 'Radio' has no 'change' member
83
- {}, # geometry result
84
- DEFAULT_RANSAC_METHOD, # ransac_method
85
- DEFAULT_RANSAC_REPROJ_THRESHOLD, # ransac_reproj_threshold
86
- DEFAULT_RANSAC_CONFIDENCE, # ransac_confidence
87
- DEFAULT_RANSAC_MAX_ITER, # ransac_max_iter
88
- DEFAULT_SETTING_GEOMETRY, # geometry
89
  )
90
 
91
 
 
1
  import argparse
2
  from pathlib import Path
3
+ import numpy as np
4
+ from typing import Dict, Any, Optional, Tuple, List, Union
5
  import gradio as gr
6
  from common.utils import (
7
  matcher_zoo,
 
58
  }
59
 
60
 
61
+ def ui_reset_state(
62
+ *args: Any,
63
+ ) -> Tuple[
64
+ Optional[np.ndarray],
65
+ Optional[np.ndarray],
66
+ float,
67
+ int,
68
+ float,
69
+ str,
70
+ Dict[str, Any],
71
+ Dict[str, Any],
72
+ str,
73
+ Optional[np.ndarray],
74
+ Optional[np.ndarray],
75
+ Optional[np.ndarray],
76
+ Dict[str, Any],
77
+ Dict[str, Any],
78
+ Optional[np.ndarray],
79
+ Dict[str, Any],
80
+ str,
81
+ int,
82
+ float,
83
+ int,
84
+ ]:
85
  """
86
  Reset the state of the UI.
87
 
88
  Returns:
89
  tuple: A tuple containing the initial values for the UI state.
90
  """
91
+ key: str = list(matcher_zoo.keys())[0] # Get the first key from matcher_zoo
92
  return (
93
+ None, # image0: Optional[np.ndarray]
94
+ None, # image1: Optional[np.ndarray]
95
+ DEFAULT_MATCHING_THRESHOLD, # matching_threshold: float
96
+ DEFAULT_SETTING_MAX_FEATURES, # max_features: int
97
+ DEFAULT_DEFAULT_KEYPOINT_THRESHOLD, # keypoint_threshold: float
98
+ key, # matcher: str
99
+ ui_change_imagebox("upload"), # input image0: Dict[str, Any]
100
+ ui_change_imagebox("upload"), # input image1: Dict[str, Any]
101
+ "upload", # match_image_src: str
102
+ None, # keypoints: Optional[np.ndarray]
103
+ None, # raw matches: Optional[np.ndarray]
104
+ None, # ransac matches: Optional[np.ndarray]
105
+ {}, # matches result info: Dict[str, Any]
106
+ {}, # matcher config: Dict[str, Any]
107
+ None, # warped image: Optional[np.ndarray]
108
+ {}, # geometry result: Dict[str, Any]
109
+ DEFAULT_RANSAC_METHOD, # ransac_method: str
110
+ DEFAULT_RANSAC_REPROJ_THRESHOLD, # ransac_reproj_threshold: float
111
+ DEFAULT_RANSAC_CONFIDENCE, # ransac_confidence: float
112
+ DEFAULT_RANSAC_MAX_ITER, # ransac_max_iter: int
113
+ DEFAULT_SETTING_GEOMETRY, # geometry: str
114
  )
115
 
116
 
common/utils.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import cv2
6
  import gradio as gr
7
  from pathlib import Path
 
8
  from itertools import combinations
9
  from hloc import matchers, extractors, logger
10
  from hloc.utils.base_model import dynamic_load
@@ -25,19 +26,39 @@ DEFAULT_RANSAC_MAX_ITER = 10000
25
  DEFAULT_MIN_NUM_MATCHES = 4
26
  DEFAULT_MATCHING_THRESHOLD = 0.2
27
  DEFAULT_SETTING_GEOMETRY = "Homography"
28
- GRADIO_VERSION = gr.__version__.split('.')[0]
29
 
30
- def get_model(match_conf):
 
 
 
 
 
 
 
 
 
 
31
  Model = dynamic_load(matchers, match_conf["model"]["name"])
32
  model = Model(match_conf["model"]).eval().to(device)
33
  return model
34
 
35
 
36
- def get_feature_model(conf):
 
 
 
 
 
 
 
 
 
37
  Model = dynamic_load(extractors, conf["model"]["name"])
38
  model = Model(conf["model"]).eval().to(device)
39
  return model
40
 
 
41
  def gen_examples():
42
  random.seed(1)
43
  example_matchers = [
@@ -92,15 +113,30 @@ def gen_examples():
92
 
93
 
94
  def filter_matches(
95
- pred,
96
- ransac_method=DEFAULT_RANSAC_METHOD,
97
- ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
98
- ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
99
- ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
100
- ):
101
- mkpts0 = None
102
- mkpts1 = None
103
- feature_type = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
105
  mkpts0 = pred["keypoints0_orig"]
106
  mkpts1 = pred["keypoints1_orig"]
@@ -142,20 +178,33 @@ def filter_matches(
142
 
143
 
144
  def compute_geom(
145
- pred,
146
- ransac_method=DEFAULT_RANSAC_METHOD,
147
- ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
148
- ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
149
- ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
150
- ) -> dict:
151
- mkpts0 = None
152
- mkpts1 = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
155
  mkpts0 = pred["keypoints0_orig"]
156
  mkpts1 = pred["keypoints1_orig"]
157
-
158
- if (
159
  "line_keypoints0_orig" in pred.keys()
160
  and "line_keypoints1_orig" in pred.keys()
161
  ):
@@ -166,7 +215,7 @@ def compute_geom(
166
  if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
167
  return {}
168
  h1, w1, _ = pred["image0_orig"].shape
169
- geo_info = {}
170
  F, inliers = cv2.findFundamentalMat(
171
  mkpts0,
172
  mkpts1,
@@ -197,22 +246,39 @@ def compute_geom(
197
  geo_info["H1"] = H1.tolist()
198
  geo_info["H2"] = H2.tolist()
199
  except cv2.error as e:
200
- logger.error(f"e, skip")
201
  return geo_info
202
  else:
203
  return {}
204
 
205
 
206
- def wrap_images(img0, img1, geo_info, geom_type):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  h1, w1, _ = img0.shape
208
  h2, w2, _ = img1.shape
209
- result_matrix = None
210
  if geo_info is not None and len(geo_info) != 0:
211
  rectified_image0 = img0
212
  rectified_image1 = None
213
  H = np.array(geo_info["Homography"])
214
  F = np.array(geo_info["Fundamental"])
215
- title = []
216
  if geom_type == "Homography":
217
  rectified_image1 = cv2.warpPerspective(
218
  img1, H, (img0.shape[1], img0.shape[0])
@@ -242,15 +308,32 @@ def wrap_images(img0, img1, geo_info, geom_type):
242
  return None, None
243
 
244
 
245
- def change_estimate_geom(input_image0, input_image1, matches_info, choice):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  if (
247
  matches_info is None
248
  or len(matches_info) < 1
249
  or "geom_info" not in matches_info.keys()
250
  ):
251
  return None, None
252
- geom_info = matches_info["geom_info"]
253
- wrapped_images = None
254
  if choice != "No":
255
  wrapped_images, _ = wrap_images(
256
  input_image0, input_image1, geom_info, choice
@@ -260,16 +343,34 @@ def change_estimate_geom(input_image0, input_image1, matches_info, choice):
260
  return None, None
261
 
262
 
263
- def display_matches(pred: dict, titles=[], dpi=300):
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  img0 = pred["image0_orig"]
265
  img1 = pred["image1_orig"]
266
 
267
  num_inliers = 0
268
- if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
 
 
 
 
 
269
  mkpts0 = pred["keypoints0_orig"]
270
  mkpts1 = pred["keypoints1_orig"]
271
  num_inliers = len(mkpts0)
272
- if "mconf" in pred.keys():
273
  mconf = pred["mconf"]
274
  else:
275
  mconf = np.ones(len(mkpts0))
@@ -283,7 +384,12 @@ def display_matches(pred: dict, titles=[], dpi=300):
283
  titles=titles,
284
  )
285
  fig = fig_mkpts
286
- if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
 
 
 
 
 
287
  # lines
288
  mtlines0 = pred["line0_orig"]
289
  mtlines1 = pred["line1_orig"]
@@ -297,12 +403,12 @@ def display_matches(pred: dict, titles=[], dpi=300):
297
  fig_lines = fig2im(fig_lines)
298
 
299
  # keypoints
300
- mkpts0 = pred["line_keypoints0_orig"]
301
- mkpts1 = pred["line_keypoints1_orig"]
302
 
303
  if mkpts0 is not None and mkpts1 is not None:
304
  num_inliers = len(mkpts0)
305
- if "mconf" in pred.keys():
306
  mconf = pred["mconf"]
307
  else:
308
  mconf = np.ones(len(mkpts0))
@@ -317,18 +423,51 @@ def display_matches(pred: dict, titles=[], dpi=300):
317
 
318
 
319
  def run_matching(
320
- image0,
321
- image1,
322
- match_threshold,
323
- extract_max_keypoints,
324
- keypoint_threshold,
325
- key,
326
- ransac_method=DEFAULT_RANSAC_METHOD,
327
- ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
328
- ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
329
- ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
330
- choice_estimate_geom=DEFAULT_SETTING_GEOMETRY,
331
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  # image0 and image1 is RGB mode
333
  if image0 is None or image1 is None:
334
  raise gr.Error("Error: No images found! Please upload two images.")
 
5
  import cv2
6
  import gradio as gr
7
  from pathlib import Path
8
+ from typing import Dict, Any, Optional, Tuple, List, Union
9
  from itertools import combinations
10
  from hloc import matchers, extractors, logger
11
  from hloc.utils.base_model import dynamic_load
 
26
  DEFAULT_MIN_NUM_MATCHES = 4
27
  DEFAULT_MATCHING_THRESHOLD = 0.2
28
  DEFAULT_SETTING_GEOMETRY = "Homography"
29
+ GRADIO_VERSION = gr.__version__.split(".")[0]
30
 
31
+
32
+ def get_model(match_conf: Dict[str, Any]):
33
+ """
34
+ Load a matcher model from the provided configuration.
35
+
36
+ Args:
37
+ match_conf: A dictionary containing the model configuration.
38
+
39
+ Returns:
40
+ A matcher model instance.
41
+ """
42
  Model = dynamic_load(matchers, match_conf["model"]["name"])
43
  model = Model(match_conf["model"]).eval().to(device)
44
  return model
45
 
46
 
47
+ def get_feature_model(conf: Dict[str, Dict[str, Any]]):
48
+ """
49
+ Load a feature extraction model from the provided configuration.
50
+
51
+ Args:
52
+ conf: A dictionary containing the model configuration.
53
+
54
+ Returns:
55
+ A feature extraction model instance.
56
+ """
57
  Model = dynamic_load(extractors, conf["model"]["name"])
58
  model = Model(conf["model"]).eval().to(device)
59
  return model
60
 
61
+
62
  def gen_examples():
63
  random.seed(1)
64
  example_matchers = [
 
113
 
114
 
115
  def filter_matches(
116
+ pred: Dict[str, Any],
117
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
118
+ ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
119
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
120
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
121
+ ) -> Dict[str, Any]:
122
+ """
123
+ Filter matches using RANSAC. If keypoints are available, filter by keypoints.
124
+ If lines are available, filter by lines. If both keypoints and lines are
125
+ available, filter by keypoints.
126
+
127
+ Args:
128
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
129
+ ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
130
+ ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
131
+ ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
132
+ ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
133
+
134
+ Returns:
135
+ Dict[str, Any]: filtered matches.
136
+ """
137
+ mkpts0: Optional[np.ndarray] = None
138
+ mkpts1: Optional[np.ndarray] = None
139
+ feature_type: Optional[str] = None
140
  if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
141
  mkpts0 = pred["keypoints0_orig"]
142
  mkpts1 = pred["keypoints1_orig"]
 
178
 
179
 
180
  def compute_geom(
181
+ pred: Dict[str, Any],
182
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
183
+ ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
184
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
185
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
186
+ ) -> Dict[str, List[float]]:
187
+ """
188
+ Compute geometric information of matches, including Fundamental matrix,
189
+ Homography matrix, and rectification matrices (if available).
190
+
191
+ Args:
192
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
193
+ ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
194
+ ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
195
+ ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
196
+ ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
197
+
198
+ Returns:
199
+ Dict[str, List[float]]: geometric information in form of a dict.
200
+ """
201
+ mkpts0: Optional[np.ndarray] = None
202
+ mkpts1: Optional[np.ndarray] = None
203
 
204
  if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
205
  mkpts0 = pred["keypoints0_orig"]
206
  mkpts1 = pred["keypoints1_orig"]
207
+ elif (
 
208
  "line_keypoints0_orig" in pred.keys()
209
  and "line_keypoints1_orig" in pred.keys()
210
  ):
 
215
  if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
216
  return {}
217
  h1, w1, _ = pred["image0_orig"].shape
218
+ geo_info: Dict[str, List[float]] = {}
219
  F, inliers = cv2.findFundamentalMat(
220
  mkpts0,
221
  mkpts1,
 
246
  geo_info["H1"] = H1.tolist()
247
  geo_info["H2"] = H2.tolist()
248
  except cv2.error as e:
249
+ logger.error(f"{e}, skip")
250
  return geo_info
251
  else:
252
  return {}
253
 
254
 
255
+ def wrap_images(
256
+ img0: np.ndarray,
257
+ img1: np.ndarray,
258
+ geo_info: Optional[Dict[str, List[float]]],
259
+ geom_type: str,
260
+ ) -> Tuple[Optional[str], Optional[Dict[str, List[float]]]]:
261
+ """
262
+ Wraps the images based on the geometric transformation used to align them.
263
+
264
+ Args:
265
+ img0: numpy array representing the first image.
266
+ img1: numpy array representing the second image.
267
+ geo_info: dictionary containing the geometric transformation information.
268
+ geom_type: type of geometric transformation used to align the images.
269
+
270
+ Returns:
271
+ A tuple containing a base64 encoded image string and a dictionary with the transformation matrix.
272
+ """
273
  h1, w1, _ = img0.shape
274
  h2, w2, _ = img1.shape
275
+ result_matrix: Optional[np.ndarray] = None
276
  if geo_info is not None and len(geo_info) != 0:
277
  rectified_image0 = img0
278
  rectified_image1 = None
279
  H = np.array(geo_info["Homography"])
280
  F = np.array(geo_info["Fundamental"])
281
+ title: List[str] = []
282
  if geom_type == "Homography":
283
  rectified_image1 = cv2.warpPerspective(
284
  img1, H, (img0.shape[1], img0.shape[0])
 
308
  return None, None
309
 
310
 
311
+ def change_estimate_geom(
312
+ input_image0: np.ndarray,
313
+ input_image1: np.ndarray,
314
+ matches_info: Dict[str, Any],
315
+ choice: str,
316
+ ) -> Tuple[Optional[np.ndarray], Optional[Dict[str, Any]]]:
317
+ """
318
+ Changes the estimate of the geometric transformation used to align the images.
319
+
320
+ Args:
321
+ input_image0: First input image.
322
+ input_image1: Second input image.
323
+ matches_info: Dictionary containing information about the matches.
324
+ choice: Type of geometric transformation to use ('Homography' or 'Fundamental') or 'No' to disable.
325
+
326
+ Returns:
327
+ A tuple containing the updated images and the updated matches info.
328
+ """
329
  if (
330
  matches_info is None
331
  or len(matches_info) < 1
332
  or "geom_info" not in matches_info.keys()
333
  ):
334
  return None, None
335
+ geom_info: Dict[str, Any] = matches_info["geom_info"]
336
+ wrapped_images: Optional[np.ndarray] = None
337
  if choice != "No":
338
  wrapped_images, _ = wrap_images(
339
  input_image0, input_image1, geom_info, choice
 
343
  return None, None
344
 
345
 
346
+ def display_matches(
347
+ pred: Dict[str, np.ndarray], titles: List[str] = [], dpi: int = 300
348
+ ) -> Tuple[np.ndarray, int]:
349
+ """
350
+ Displays the matches between two images.
351
+
352
+ Args:
353
+ pred: Dictionary containing the original images and the matches.
354
+ titles: Optional titles for the plot.
355
+ dpi: Resolution of the plot.
356
+
357
+ Returns:
358
+ The resulting concatenated plot and the number of inliers.
359
+ """
360
  img0 = pred["image0_orig"]
361
  img1 = pred["image1_orig"]
362
 
363
  num_inliers = 0
364
+ if (
365
+ "keypoints0_orig" in pred
366
+ and "keypoints1_orig" in pred
367
+ and pred["keypoints0_orig"] is not None
368
+ and pred["keypoints1_orig"] is not None
369
+ ):
370
  mkpts0 = pred["keypoints0_orig"]
371
  mkpts1 = pred["keypoints1_orig"]
372
  num_inliers = len(mkpts0)
373
+ if "mconf" in pred:
374
  mconf = pred["mconf"]
375
  else:
376
  mconf = np.ones(len(mkpts0))
 
384
  titles=titles,
385
  )
386
  fig = fig_mkpts
387
+ if (
388
+ "line0_orig" in pred
389
+ and "line1_orig" in pred
390
+ and pred["line0_orig"] is not None
391
+ and pred["line1_orig"] is not None
392
+ ):
393
  # lines
394
  mtlines0 = pred["line0_orig"]
395
  mtlines1 = pred["line1_orig"]
 
403
  fig_lines = fig2im(fig_lines)
404
 
405
  # keypoints
406
+ mkpts0 = pred.get("line_keypoints0_orig")
407
+ mkpts1 = pred.get("line_keypoints1_orig")
408
 
409
  if mkpts0 is not None and mkpts1 is not None:
410
  num_inliers = len(mkpts0)
411
+ if "mconf" in pred:
412
  mconf = pred["mconf"]
413
  else:
414
  mconf = np.ones(len(mkpts0))
 
423
 
424
 
425
  def run_matching(
426
+ image0: np.ndarray,
427
+ image1: np.ndarray,
428
+ match_threshold: float,
429
+ extract_max_keypoints: int,
430
+ keypoint_threshold: float,
431
+ key: str,
432
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
433
+ ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
434
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
435
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
436
+ choice_estimate_geom: str = DEFAULT_SETTING_GEOMETRY,
437
+ ) -> Tuple[
438
+ np.ndarray,
439
+ np.ndarray,
440
+ np.ndarray,
441
+ Dict[str, int],
442
+ Dict[str, Dict[str, Any]],
443
+ Dict[str, Dict[str, float]],
444
+ np.ndarray,
445
+ ]:
446
+ """Match two images using the given parameters.
447
+
448
+ Args:
449
+ image0 (np.ndarray): RGB image 0.
450
+ image1 (np.ndarray): RGB image 1.
451
+ match_threshold (float): match threshold.
452
+ extract_max_keypoints (int): number of keypoints to extract.
453
+ keypoint_threshold (float): keypoint threshold.
454
+ key (str): key of the model to use.
455
+ ransac_method (str, optional): RANSAC method to use.
456
+ ransac_reproj_threshold (int, optional): RANSAC reprojection threshold.
457
+ ransac_confidence (float, optional): RANSAC confidence level.
458
+ ransac_max_iter (int, optional): RANSAC maximum number of iterations.
459
+ choice_estimate_geom (str, optional): setting of geometry estimation.
460
+
461
+ Returns:
462
+ tuple:
463
+ - output_keypoints (np.ndarray): image with keypoints.
464
+ - output_matches_raw (np.ndarray): image with raw matches.
465
+ - output_matches_ransac (np.ndarray): image with RANSAC matches.
466
+ - num_matches (Dict[str, int]): number of raw and RANSAC matches.
467
+ - configs (Dict[str, Dict[str, Any]]): match and feature extraction configs.
468
+ - geom_info (Dict[str, Dict[str, float]]): geometry information.
469
+ - output_wrapped (np.ndarray): wrapped images.
470
+ """
471
  # image0 and image1 is RGB mode
472
  if image0 is None or image1 is None:
473
  raise gr.Error("Error: No images found! Please upload two images.")
common/viz.py CHANGED
@@ -1,20 +1,35 @@
1
  import numpy as np
2
- import matplotlib.pyplot as plt
3
- import matplotlib
4
  import seaborn as sns
 
 
 
 
5
 
6
 
7
- def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
 
 
 
 
 
 
 
8
  """Plot a set of images horizontally.
9
  Args:
10
  imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
11
  titles: a list of strings, as titles for each image.
12
- cmaps: colormaps for monochrome images.
 
 
 
 
 
 
 
13
  """
14
  n = len(imgs)
15
- if not isinstance(cmaps, (list, tuple)):
16
  cmaps = [cmaps] * n
17
- # figsize = (size*n, size*3/4) if size is not None else None
18
  figsize = (size * n, size * 6 / 5) if size is not None else None
19
  fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
20
 
@@ -33,24 +48,33 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
33
  return fig
34
 
35
 
36
- def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
 
 
 
 
 
37
  """Plot line matches for existing images with multiple colors.
 
38
  Args:
39
- lines: list of ndarrays of size (N, 2, 2).
40
- correct_matches: bool array of size (N,) indicating correct matches.
41
- lw: line width as float pixels.
42
- indices: indices of the images to draw the matches on.
 
 
 
 
43
  """
44
- n_lines = len(lines[0])
45
  colors = sns.color_palette("husl", n_colors=n_lines)
46
  np.random.shuffle(colors)
47
  alphas = np.ones(n_lines)
48
- # If correct_matches is not None, display wrong matches with a low alpha
49
  if correct_matches is not None:
50
  alphas[~np.array(correct_matches)] = 0.2
51
 
52
  fig = plt.gcf()
53
- ax = fig.axes
54
  assert len(ax) > max(indices)
55
  axes = [ax[i] for i in indices]
56
  fig.canvas.draw()
@@ -78,21 +102,39 @@ def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
78
 
79
 
80
  def make_matching_figure(
81
- img0,
82
- img1,
83
- mkpts0,
84
- mkpts1,
85
- color,
86
- titles=None,
87
- kpts0=None,
88
- kpts1=None,
89
- text=[],
90
- dpi=75,
91
- path=None,
92
- pad=0,
93
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # draw image pair
95
- # assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
96
  fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
97
  axes[0].imshow(img0) # , cmap='gray')
98
  axes[1].imshow(img1) # , cmap='gray')
@@ -156,7 +198,20 @@ def make_matching_figure(
156
  return fig
157
 
158
 
159
- def error_colormap(err, thr, alpha=1.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
161
  x = 1 - np.clip(err / (thr * 2), 0, 1)
162
  return np.clip(
@@ -173,22 +228,57 @@ color_map = np.arange(100)
173
  np.random.shuffle(color_map)
174
 
175
 
176
- def fig2im(fig):
 
 
 
 
 
 
 
 
 
 
177
  fig.canvas.draw()
178
- w, h = fig.canvas.get_width_height()
179
  buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
180
- im = buf_ndarray.reshape(h, w, 3)
181
- return im
182
 
183
 
184
  def draw_matches(
185
- mkpts0, mkpts1, img0, img1, conf, titles=None, dpi=150, path=None, pad=0.5
186
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  thr = 5e-4
188
  thr = 0.5
189
  color = error_colormap(conf, thr, alpha=0.1)
190
  text = [
191
- f"image name",
192
  f"#Matches: {len(mkpts0)}",
193
  ]
194
  if path:
@@ -222,7 +312,31 @@ def draw_matches(
222
  )
223
 
224
 
225
- def draw_image_pairs(img0, img1, text=[], dpi=75, path=None, pad=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  # draw image pair
227
  fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
228
  axes[0].imshow(img0) # , cmap='gray')
 
1
  import numpy as np
 
 
2
  import seaborn as sns
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ from pathlib import Path
6
+ from typing import Dict, Any, Optional, Tuple, List, Union
7
 
8
 
9
+ def plot_images(
10
+ imgs: List[np.ndarray],
11
+ titles: Optional[List[str]] = None,
12
+ cmaps: Union[str, List[str]] = "gray",
13
+ dpi: int = 100,
14
+ size: Optional[int] = 5,
15
+ pad: float = 0.5,
16
+ ) -> plt.Figure:
17
  """Plot a set of images horizontally.
18
  Args:
19
  imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
20
  titles: a list of strings, as titles for each image.
21
+ cmaps: colormaps for monochrome images. If a single string is given,
22
+ it is used for all images.
23
+ dpi: DPI of the figure.
24
+ size: figure size in inches (width). If not provided, the figure
25
+ size is determined automatically.
26
+ pad: padding between subplots, in inches.
27
+ Returns:
28
+ The created figure.
29
  """
30
  n = len(imgs)
31
+ if not isinstance(cmaps, list):
32
  cmaps = [cmaps] * n
 
33
  figsize = (size * n, size * 6 / 5) if size is not None else None
34
  fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
35
 
 
48
  return fig
49
 
50
 
51
+ def plot_color_line_matches(
52
+ lines: List[np.ndarray],
53
+ correct_matches: Optional[np.ndarray] = None,
54
+ lw: float = 2.0,
55
+ indices: Tuple[int, int] = (0, 1),
56
+ ) -> matplotlib.figure.Figure:
57
  """Plot line matches for existing images with multiple colors.
58
+
59
  Args:
60
+ lines: List of ndarrays of size (N, 2, 2) representing line segments.
61
+ correct_matches: Optional bool array of size (N,) indicating correct
62
+ matches. If not None, display wrong matches with a low alpha.
63
+ lw: Line width as float pixels.
64
+ indices: Indices of the images to draw the matches on.
65
+
66
+ Returns:
67
+ The modified matplotlib figure.
68
  """
69
+ n_lines = lines[0].shape[0]
70
  colors = sns.color_palette("husl", n_colors=n_lines)
71
  np.random.shuffle(colors)
72
  alphas = np.ones(n_lines)
 
73
  if correct_matches is not None:
74
  alphas[~np.array(correct_matches)] = 0.2
75
 
76
  fig = plt.gcf()
77
+ ax = typing.cast(List[matplotlib.axes.Axes], fig.axes)
78
  assert len(ax) > max(indices)
79
  axes = [ax[i] for i in indices]
80
  fig.canvas.draw()
 
102
 
103
 
104
  def make_matching_figure(
105
+ img0: np.ndarray,
106
+ img1: np.ndarray,
107
+ mkpts0: np.ndarray,
108
+ mkpts1: np.ndarray,
109
+ color: np.ndarray,
110
+ titles: Optional[List[str]] = None,
111
+ kpts0: Optional[np.ndarray] = None,
112
+ kpts1: Optional[np.ndarray] = None,
113
+ text: List[str] = [],
114
+ dpi: int = 75,
115
+ path: Optional[Path] = None,
116
+ pad: float = 0.0,
117
+ ) -> Optional[plt.Figure]:
118
+ """Draw image pair with matches.
119
+
120
+ Args:
121
+ img0: image0 as HxWx3 numpy array.
122
+ img1: image1 as HxWx3 numpy array.
123
+ mkpts0: matched points in image0 as Nx2 numpy array.
124
+ mkpts1: matched points in image1 as Nx2 numpy array.
125
+ color: colors for the matches as Nx4 numpy array.
126
+ titles: titles for the two subplots.
127
+ kpts0: keypoints in image0 as Kx2 numpy array.
128
+ kpts1: keypoints in image1 as Kx2 numpy array.
129
+ text: list of strings to display in the top-left corner of the image.
130
+ dpi: dots per inch of the saved figure.
131
+ path: if not None, save the figure to this path.
132
+ pad: padding around the image as a fraction of the image size.
133
+
134
+ Returns:
135
+ The matplotlib Figure object if path is None.
136
+ """
137
  # draw image pair
 
138
  fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
139
  axes[0].imshow(img0) # , cmap='gray')
140
  axes[1].imshow(img1) # , cmap='gray')
 
198
  return fig
199
 
200
 
201
+ def error_colormap(
202
+ err: np.ndarray, thr: float, alpha: float = 1.0
203
+ ) -> np.ndarray:
204
+ """
205
+ Create a colormap based on the error values.
206
+
207
+ Args:
208
+ err: Error values as a numpy array of shape (N,).
209
+ thr: Threshold value for the error.
210
+ alpha: Alpha value for the colormap, between 0 and 1.
211
+
212
+ Returns:
213
+ Colormap as a numpy array of shape (N, 4) with values in [0, 1].
214
+ """
215
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
216
  x = 1 - np.clip(err / (thr * 2), 0, 1)
217
  return np.clip(
 
228
  np.random.shuffle(color_map)
229
 
230
 
231
+ def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray:
232
+ """
233
+ Convert a matplotlib figure to a numpy array with RGB values.
234
+
235
+ Args:
236
+ fig: A matplotlib figure.
237
+
238
+ Returns:
239
+ A numpy array with shape (height, width, 3) and dtype uint8 containing
240
+ the RGB values of the figure.
241
+ """
242
  fig.canvas.draw()
243
+ (width, height) = fig.canvas.get_width_height()
244
  buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
245
+ return buf_ndarray.reshape(height, width, 3)
 
246
 
247
 
248
  def draw_matches(
249
+ mkpts0: List[np.ndarray],
250
+ mkpts1: List[np.ndarray],
251
+ img0: np.ndarray,
252
+ img1: np.ndarray,
253
+ conf: np.ndarray,
254
+ titles: Optional[List[str]] = None,
255
+ dpi: int = 150,
256
+ path: Optional[str] = None,
257
+ pad: float = 0.5,
258
+ ) -> np.ndarray:
259
+ """
260
+ Draw matches between two images.
261
+
262
+ Args:
263
+ mkpts0: List of matches from the first image, with shape (N, 2)
264
+ mkpts1: List of matches from the second image, with shape (N, 2)
265
+ img0: First image, with shape (H, W, 3)
266
+ img1: Second image, with shape (H, W, 3)
267
+ conf: Confidence values for the matches, with shape (N,)
268
+ titles: Optional list of title strings for the plot
269
+ dpi: DPI for the saved image
270
+ path: Optional path to save the image to. If None, the image is not saved.
271
+ pad: Padding between subplots
272
+
273
+ Returns:
274
+ The figure as a numpy array with shape (height, width, 3) and dtype uint8
275
+ containing the RGB values of the figure.
276
+ """
277
  thr = 5e-4
278
  thr = 0.5
279
  color = error_colormap(conf, thr, alpha=0.1)
280
  text = [
281
+ "image name",
282
  f"#Matches: {len(mkpts0)}",
283
  ]
284
  if path:
 
312
  )
313
 
314
 
315
+ def draw_image_pairs(
316
+ img0: np.ndarray,
317
+ img1: np.ndarray,
318
+ text: List[str] = [],
319
+ dpi: int = 75,
320
+ path: Optional[str] = None,
321
+ pad: float = 0.5,
322
+ ) -> np.ndarray:
323
+ """Draw image pair horizontally.
324
+
325
+ Args:
326
+ img0: First image, with shape (H, W, 3)
327
+ img1: Second image, with shape (H, W, 3)
328
+ text: List of strings to print. Each string is a new line.
329
+ dpi: DPI of the figure.
330
+ path: Path to save the image to. If None, the image is not saved and
331
+ the function returns the figure as a numpy array with shape
332
+ (height, width, 3) and dtype uint8 containing the RGB values of the
333
+ figure.
334
+ pad: Padding between subplots
335
+
336
+ Returns:
337
+ The figure as a numpy array with shape (height, width, 3) and dtype uint8
338
+ containing the RGB values of the figure, or None if path is not None.
339
+ """
340
  # draw image pair
341
  fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
342
  axes[0].imshow(img0) # , cmap='gray')