Realcat commited on
Commit
a9c6c59
·
1 Parent(s): 391d6b1

fix: cached model

Browse files
Files changed (4) hide show
  1. common/app_class.py +2 -2
  2. common/config.yaml +20 -20
  3. common/utils.py +6 -7
  4. common/viz.py +7 -5
common/app_class.py CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
5
  from typing import Dict, Any, Optional, Tuple, List, Union
6
  from common.utils import (
7
  ransac_zoo,
8
- change_estimate_geom,
9
  load_config,
10
  get_matcher_zoo,
11
  run_matching,
@@ -290,7 +290,7 @@ class ImageMatchingApp:
290
 
291
  # estimate geo
292
  choice_geometry_type.change(
293
- fn=change_estimate_geom,
294
  inputs=[
295
  input_image0,
296
  input_image1,
 
5
  from typing import Dict, Any, Optional, Tuple, List, Union
6
  from common.utils import (
7
  ransac_zoo,
8
+ generate_warp_images,
9
  load_config,
10
  get_matcher_zoo,
11
  run_matching,
 
290
 
291
  # estimate geo
292
  choice_geometry_type.change(
293
+ fn=generate_warp_images,
294
  inputs=[
295
  input_image0,
296
  input_image1,
common/config.yaml CHANGED
@@ -16,26 +16,26 @@ defaults:
16
  setting_geometry: Homography
17
 
18
  matcher_zoo:
19
- roma:
20
- matcher: roma
21
- dense: true
22
- info:
23
- name: RoMa #dispaly name
24
- source: "CVPR 2024"
25
- github: https://github.com/Parskatt/RoMa
26
- paper: https://arxiv.org/abs/2305.15404
27
- project: https://parskatt.github.io/RoMa
28
- display: true
29
- dkm:
30
- matcher: dkm
31
- dense: true
32
- info:
33
- name: DKM #dispaly name
34
- source: "CVPR 2023"
35
- github: https://github.com/Parskatt/DKM
36
- paper: https://arxiv.org/abs/2202.00667
37
- project: https://parskatt.github.io/DKM
38
- display: true
39
  loftr:
40
  matcher: loftr
41
  dense: true
 
16
  setting_geometry: Homography
17
 
18
  matcher_zoo:
19
+ # roma:
20
+ # matcher: roma
21
+ # dense: true
22
+ # info:
23
+ # name: RoMa #dispaly name
24
+ # source: "CVPR 2024"
25
+ # github: https://github.com/Parskatt/RoMa
26
+ # paper: https://arxiv.org/abs/2305.15404
27
+ # project: https://parskatt.github.io/RoMa
28
+ # display: true
29
+ # dkm:
30
+ # matcher: dkm
31
+ # dense: true
32
+ # info:
33
+ # name: DKM #dispaly name
34
+ # source: "CVPR 2023"
35
+ # github: https://github.com/Parskatt/DKM
36
+ # paper: https://arxiv.org/abs/2202.00667
37
+ # project: https://parskatt.github.io/DKM
38
+ # display: true
39
  loftr:
40
  matcher: loftr
41
  dense: true
common/utils.py CHANGED
@@ -12,7 +12,6 @@ from hloc.utils.base_model import dynamic_load
12
  from hloc import match_dense, match_features, extract_features
13
  from hloc.utils.viz import add_text, plot_keypoints
14
  from .viz import (
15
- draw_matches,
16
  fig2im,
17
  plot_images,
18
  display_matches,
@@ -242,7 +241,7 @@ def filter_matches(
242
  return pred
243
 
244
 
245
- def compute_geom(
246
  pred: Dict[str, Any],
247
  ransac_method: str = DEFAULT_RANSAC_METHOD,
248
  ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
@@ -373,7 +372,7 @@ def wrap_images(
373
  return None, None
374
 
375
 
376
- def change_estimate_geom(
377
  input_image0: np.ndarray,
378
  input_image1: np.ndarray,
379
  matches_info: Dict[str, Any],
@@ -475,7 +474,7 @@ def run_matching(
475
  match_conf["model"]["match_threshold"] = match_threshold
476
  match_conf["model"]["max_keypoints"] = extract_max_keypoints
477
  t0 = time.time()
478
- cache_key = match_conf["model"]["name"]
479
  if cache_key in models_already_loaded:
480
  matcher = models_already_loaded[cache_key]
481
  matcher.conf["max_keypoints"] = extract_max_keypoints
@@ -499,7 +498,7 @@ def run_matching(
499
  # update extract config
500
  extract_conf["model"]["max_keypoints"] = extract_max_keypoints
501
  extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
502
- cache_key = extract_conf["model"]["name"]
503
  if cache_key in models_already_loaded:
504
  extractor = models_already_loaded[cache_key]
505
  extractor.conf["max_keypoints"] = extract_max_keypoints
@@ -567,8 +566,8 @@ def run_matching(
567
 
568
  t1 = time.time()
569
  # plot wrapped images
570
- geom_info = compute_geom(pred)
571
- output_wrapped, _ = change_estimate_geom(
572
  pred["image0_orig"],
573
  pred["image1_orig"],
574
  {"geom_info": geom_info},
 
12
  from hloc import match_dense, match_features, extract_features
13
  from hloc.utils.viz import add_text, plot_keypoints
14
  from .viz import (
 
15
  fig2im,
16
  plot_images,
17
  display_matches,
 
241
  return pred
242
 
243
 
244
+ def compute_geometry(
245
  pred: Dict[str, Any],
246
  ransac_method: str = DEFAULT_RANSAC_METHOD,
247
  ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
 
372
  return None, None
373
 
374
 
375
+ def generate_warp_images(
376
  input_image0: np.ndarray,
377
  input_image1: np.ndarray,
378
  matches_info: Dict[str, Any],
 
474
  match_conf["model"]["match_threshold"] = match_threshold
475
  match_conf["model"]["max_keypoints"] = extract_max_keypoints
476
  t0 = time.time()
477
+ cache_key = "{}_{}".format(key, match_conf["model"]["name"])
478
  if cache_key in models_already_loaded:
479
  matcher = models_already_loaded[cache_key]
480
  matcher.conf["max_keypoints"] = extract_max_keypoints
 
498
  # update extract config
499
  extract_conf["model"]["max_keypoints"] = extract_max_keypoints
500
  extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
501
+ cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
502
  if cache_key in models_already_loaded:
503
  extractor = models_already_loaded[cache_key]
504
  extractor.conf["max_keypoints"] = extract_max_keypoints
 
566
 
567
  t1 = time.time()
568
  # plot wrapped images
569
+ geom_info = compute_geometry(pred)
570
+ output_wrapped, _ = generate_warp_images(
571
  pred["image0_orig"],
572
  pred["image1_orig"],
573
  {"geom_info": geom_info},
common/viz.py CHANGED
@@ -247,7 +247,7 @@ def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray:
247
  return buf_ndarray.reshape(height, width, 3)
248
 
249
 
250
- def draw_matches(
251
  mkpts0: List[np.ndarray],
252
  mkpts1: List[np.ndarray],
253
  img0: np.ndarray,
@@ -293,7 +293,7 @@ def draw_matches(
293
  mkpts1,
294
  color,
295
  titles=titles,
296
- text=text,
297
  path=path,
298
  dpi=dpi,
299
  pad=pad,
@@ -308,7 +308,7 @@ def draw_matches(
308
  mkpts1,
309
  color,
310
  titles=titles,
311
- text=text,
312
  pad=pad,
313
  dpi=dpi,
314
  )
@@ -406,7 +406,7 @@ def display_matches(
406
  mconf = pred["mconf"]
407
  else:
408
  mconf = np.ones(len(mkpts0))
409
- fig_mkpts = draw_matches(
410
  mkpts0,
411
  mkpts1,
412
  img0,
@@ -445,7 +445,9 @@ def display_matches(
445
  mconf = pred["mconf"]
446
  else:
447
  mconf = np.ones(len(mkpts0))
448
- fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
 
 
449
  fig_lines = cv2.resize(
450
  fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
451
  )
 
247
  return buf_ndarray.reshape(height, width, 3)
248
 
249
 
250
+ def draw_matches_core(
251
  mkpts0: List[np.ndarray],
252
  mkpts1: List[np.ndarray],
253
  img0: np.ndarray,
 
293
  mkpts1,
294
  color,
295
  titles=titles,
296
+ # text=texts,
297
  path=path,
298
  dpi=dpi,
299
  pad=pad,
 
308
  mkpts1,
309
  color,
310
  titles=titles,
311
+ # text=texts,
312
  pad=pad,
313
  dpi=dpi,
314
  )
 
406
  mconf = pred["mconf"]
407
  else:
408
  mconf = np.ones(len(mkpts0))
409
+ fig_mkpts = draw_matches_core(
410
  mkpts0,
411
  mkpts1,
412
  img0,
 
445
  mconf = pred["mconf"]
446
  else:
447
  mconf = np.ones(len(mkpts0))
448
+ fig_mkpts = draw_matches_core(
449
+ mkpts0, mkpts1, img0, img1, mconf, dpi=300
450
+ )
451
  fig_lines = cv2.resize(
452
  fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
453
  )