Realcat commited on
Commit
7bec878
·
1 Parent(s): 7f38955

add: output file

Browse files
Files changed (3) hide show
  1. common/app_class.py +7 -0
  2. common/utils.py +18 -0
  3. hloc/match_features.py +1 -1
common/app_class.py CHANGED
@@ -254,6 +254,9 @@ class ImageMatchingApp:
254
  with gr.Accordion(
255
  "Open for More: Matches Statistics", open=False
256
  ):
 
 
 
257
  matches_result_info = gr.JSON(
258
  label="Matches Statistics"
259
  )
@@ -299,6 +302,7 @@ class ImageMatchingApp:
299
  geometry_result,
300
  output_wrapped,
301
  state_cache,
 
302
  ]
303
  # button callbacks
304
  button_run.click(
@@ -327,6 +331,7 @@ class ImageMatchingApp:
327
  ransac_confidence,
328
  ransac_max_iter,
329
  choice_geometry_type,
 
330
  ]
331
  button_reset.click(
332
  fn=self.ui_reset_state,
@@ -349,6 +354,7 @@ class ImageMatchingApp:
349
  output_matches_ransac,
350
  matches_result_info,
351
  output_wrapped,
 
352
  ],
353
  )
354
 
@@ -492,6 +498,7 @@ class ImageMatchingApp:
492
  ], # ransac_confidence: float
493
  self.cfg["defaults"]["ransac_max_iter"], # ransac_max_iter: int
494
  self.cfg["defaults"]["setting_geometry"], # geometry: str
 
495
  )
496
 
497
  def display_supported_algorithms(self, style="tab"):
 
254
  with gr.Accordion(
255
  "Open for More: Matches Statistics", open=False
256
  ):
257
+ output_pred = gr.File(
258
+ label="Outputs", elem_id="download"
259
+ )
260
  matches_result_info = gr.JSON(
261
  label="Matches Statistics"
262
  )
 
302
  geometry_result,
303
  output_wrapped,
304
  state_cache,
305
+ output_pred,
306
  ]
307
  # button callbacks
308
  button_run.click(
 
331
  ransac_confidence,
332
  ransac_max_iter,
333
  choice_geometry_type,
334
+ output_pred,
335
  ]
336
  button_reset.click(
337
  fn=self.ui_reset_state,
 
354
  output_matches_ransac,
355
  matches_result_info,
356
  output_wrapped,
357
+ output_pred,
358
  ],
359
  )
360
 
 
498
  ], # ransac_confidence: float
499
  self.cfg["defaults"]["ransac_max_iter"], # ransac_max_iter: int
500
  self.cfg["defaults"]["setting_geometry"], # geometry: str
501
+ None, # predictions
502
  )
503
 
504
  def display_supported_algorithms(self, style="tab"):
common/utils.py CHANGED
@@ -24,6 +24,8 @@ from .viz import (
24
  import time
25
  import matplotlib.pyplot as plt
26
  import warnings
 
 
27
 
28
  warnings.simplefilter("ignore")
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -774,6 +776,14 @@ def run_ransac(
774
 
775
  num_matches_raw = state_cache["num_matches_raw"]
776
  state_cache["wrapped_image"] = warped_image
 
 
 
 
 
 
 
 
777
  return (
778
  output_matches_ransac,
779
  {
@@ -781,6 +791,7 @@ def run_ransac(
781
  "num_matches_ransac": num_matches_ransac,
782
  },
783
  output_wrapped,
 
784
  )
785
 
786
 
@@ -961,6 +972,12 @@ def run_matching(
961
  state_cache["num_matches_raw"] = num_matches_raw
962
  state_cache["num_matches_ransac"] = num_matches_ransac
963
  state_cache["wrapped_image"] = warped_image
 
 
 
 
 
 
964
  return (
965
  output_keypoints,
966
  output_matches_raw,
@@ -978,6 +995,7 @@ def run_matching(
978
  },
979
  output_wrapped,
980
  state_cache,
 
981
  )
982
 
983
 
 
24
  import time
25
  import matplotlib.pyplot as plt
26
  import warnings
27
+ import tempfile
28
+ import pickle
29
 
30
  warnings.simplefilter("ignore")
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
776
 
777
  num_matches_raw = state_cache["num_matches_raw"]
778
  state_cache["wrapped_image"] = warped_image
779
+
780
+ # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False)
781
+ tmp_state_cache = "output.pkl"
782
+ with open(tmp_state_cache, "wb") as f:
783
+ pickle.dump(state_cache, f)
784
+
785
+ logger.info(f"Dump results done!")
786
+
787
  return (
788
  output_matches_ransac,
789
  {
 
791
  "num_matches_ransac": num_matches_ransac,
792
  },
793
  output_wrapped,
794
+ tmp_state_cache,
795
  )
796
 
797
 
 
972
  state_cache["num_matches_raw"] = num_matches_raw
973
  state_cache["num_matches_ransac"] = num_matches_ransac
974
  state_cache["wrapped_image"] = warped_image
975
+
976
+ # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False)
977
+ tmp_state_cache = "output.pkl"
978
+ with open(tmp_state_cache, "wb") as f:
979
+ pickle.dump(state_cache, f)
980
+ logger.info(f"Dump results done!")
981
  return (
982
  output_keypoints,
983
  output_matches_raw,
 
995
  },
996
  output_wrapped,
997
  state_cache,
998
+ tmp_state_cache,
999
  )
1000
 
1001
 
hloc/match_features.py CHANGED
@@ -386,7 +386,7 @@ def match_images(model, feat0, feat1):
386
  "mkeypoints1": mkpts1,
387
  "mkeypoints0_orig": mkpts0_origin.numpy(),
388
  "mkeypoints1_orig": mkpts1_origin.numpy(),
389
- "mconf": mconfid,
390
  }
391
  del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin
392
  torch.cuda.empty_cache()
 
386
  "mkeypoints1": mkpts1,
387
  "mkeypoints0_orig": mkpts0_origin.numpy(),
388
  "mkeypoints1_orig": mkpts1_origin.numpy(),
389
+ "mconf": mconfid.numpy(),
390
  }
391
  del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin
392
  torch.cuda.empty_cache()