Realcat commited on
Commit
3e3d5ea
·
1 Parent(s): e948c12

update: api

Browse files
Files changed (2) hide show
  1. common/api.py +20 -6
  2. test_app_cli.py +45 -11
common/api.py CHANGED
@@ -39,7 +39,15 @@ class ImageMatchingAPI(torch.nn.Module):
39
  "name": "xfeat",
40
  "max_keypoints": 1024,
41
  "keypoint_threshold": 0.015,
42
- }
 
 
 
 
 
 
 
 
43
  },
44
  "ransac": {
45
  "enable": True,
@@ -75,12 +83,18 @@ class ImageMatchingAPI(torch.nn.Module):
75
  """
76
  super().__init__()
77
  self.device = device
78
- self.conf = conf = {
79
- **self.parse_match_config(self.default_conf),
80
- **conf,
81
- }
82
  self._updata_config(detect_threshold, max_keypoints, match_threshold)
83
  self._init_models()
 
 
 
 
 
 
 
 
 
84
  self.pred = None
85
 
86
  def parse_match_config(self, conf):
@@ -123,7 +137,7 @@ class ImageMatchingAPI(torch.nn.Module):
123
 
124
  def _init_models(self):
125
  # initialize matcher
126
- self.matcher = get_model(self.conf["matcher"])
127
  # initialize extractor
128
  if self.dense:
129
  self.extractor = None
 
39
  "name": "xfeat",
40
  "max_keypoints": 1024,
41
  "keypoint_threshold": 0.015,
42
+ },
43
+ "preprocessing": {
44
+ "grayscale": False,
45
+ "resize_max": 1600,
46
+ "force_resize": True,
47
+ "width": 640,
48
+ "height": 480,
49
+ "dfactor": 8,
50
+ },
51
  },
52
  "ransac": {
53
  "enable": True,
 
83
  """
84
  super().__init__()
85
  self.device = device
86
+ self.conf = self.parse_match_config(conf)
 
 
 
87
  self._updata_config(detect_threshold, max_keypoints, match_threshold)
88
  self._init_models()
89
+ if device == "cuda":
90
+ memory_allocated = torch.cuda.memory_allocated(device)
91
+ memory_reserved = torch.cuda.memory_reserved(device)
92
+ logger.info(
93
+ f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB"
94
+ )
95
+ logger.info(
96
+ f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB"
97
+ )
98
  self.pred = None
99
 
100
  def parse_match_config(self, conf):
 
137
 
138
  def _init_models(self):
139
  # initialize matcher
140
+ self.matcher = get_model(self.match_conf)
141
  # initialize extractor
142
  if self.dense:
143
  self.extractor = None
test_app_cli.py CHANGED
@@ -39,23 +39,57 @@ def test_one():
39
  img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
40
  image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB
41
  image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB
42
-
43
  conf = {
 
44
  "matcher": {
45
- "output": "matches-omniglue",
46
  "model": {
47
- "name": "omniglue",
48
  "match_threshold": 0.2,
49
- "features": "null",
50
- },
51
- "preprocessing": {
52
- "grayscale": False,
53
- "resize_max": 1024,
54
- "dfactor": 8,
55
- "force_resize": False,
56
- },
57
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  "dense": True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  }
60
  api = ImageMatchingAPI(conf=conf, device=device)
61
  api(image0, image1)
 
39
  img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
40
  image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB
41
  image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB
42
+ # sparse
43
  conf = {
44
+ "dense": False,
45
  "matcher": {
 
46
  "model": {
47
+ "name": "NN-mutual",
48
  "match_threshold": 0.2,
49
+ }
50
+ },
51
+ "feature": {
52
+ "model": {
53
+ "name": "xfeat",
54
+ "max_keypoints": 1024,
55
+ "keypoint_threshold": 0.015,
56
+ }
57
  },
58
+ "ransac": {
59
+ "enable": True,
60
+ "estimator": "poselib",
61
+ "geometry": "homography",
62
+ "method": "RANSAC",
63
+ "reproj_threshold": 3,
64
+ "confidence": 0.9999,
65
+ "max_iter": 10000,
66
+ },
67
+ }
68
+ api = ImageMatchingAPI(conf=conf, device=device)
69
+ api(image0, image1)
70
+ log_path = ROOT / "experiments" / "one"
71
+ log_path.mkdir(exist_ok=True, parents=True)
72
+ api.visualize(log_path=log_path)
73
+
74
+ # dense
75
+ conf = {
76
  "dense": True,
77
+ "matcher": {
78
+ "model": {
79
+ "name": "loftr",
80
+ "match_threshold": 0.2,
81
+ }
82
+ },
83
+ "feature": {},
84
+ "ransac": {
85
+ "enable": True,
86
+ "estimator": "poselib",
87
+ "geometry": "homography",
88
+ "method": "RANSAC",
89
+ "reproj_threshold": 3,
90
+ "confidence": 0.9999,
91
+ "max_iter": 10000,
92
+ },
93
  }
94
  api = ImageMatchingAPI(conf=conf, device=device)
95
  api(image0, image1)