Realcat commited on
Commit
91bcc2a
·
1 Parent(s): bedfd11

fix: cache inference

Browse files
.gitignore CHANGED
@@ -18,3 +18,5 @@ gradio_cached_examples
18
  hloc/matchers/quadtree.py
19
  third_party/QuadTreeAttention
20
  desktop.ini
 
 
 
18
  hloc/matchers/quadtree.py
19
  third_party/QuadTreeAttention
20
  desktop.ini
21
+ experiments*
22
+ datasets/wxbs_benchmark
common/utils.py CHANGED
@@ -518,7 +518,7 @@ def run_matching(
518
  gr.Info(f"Matching images done using: {time.time()-t1:.3f}s")
519
  logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
520
  t1 = time.time()
521
- # plot images with keypoints
522
  titles = [
523
  "Image 0 - Keypoints",
524
  "Image 1 - Keypoints",
 
518
  gr.Info(f"Matching images done using: {time.time()-t1:.3f}s")
519
  logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
520
  t1 = time.time()
521
+ # plot images with keypoints\
522
  titles = [
523
  "Image 0 - Keypoints",
524
  "Image 1 - Keypoints",
common/viz.py CHANGED
@@ -293,7 +293,7 @@ def draw_matches_core(
293
  mkpts1,
294
  color,
295
  titles=titles,
296
- # text=texts,
297
  path=path,
298
  dpi=dpi,
299
  pad=pad,
@@ -308,7 +308,7 @@ def draw_matches_core(
308
  mkpts1,
309
  color,
310
  titles=titles,
311
- # text=texts,
312
  pad=pad,
313
  dpi=dpi,
314
  )
 
293
  mkpts1,
294
  color,
295
  titles=titles,
296
+ text=text,
297
  path=path,
298
  dpi=dpi,
299
  pad=pad,
 
308
  mkpts1,
309
  color,
310
  titles=titles,
311
+ text=text,
312
  pad=pad,
313
  dpi=dpi,
314
  )
hloc/extractors/superpoint.py CHANGED
@@ -44,4 +44,4 @@ class SuperPoint(BaseModel):
44
  self.net = superpoint.SuperPoint(conf)
45
 
46
  def _forward(self, data):
47
- return self.net(data)
 
44
  self.net = superpoint.SuperPoint(conf)
45
 
46
  def _forward(self, data):
47
+ return self.net(data, self.conf)
third_party/SuperGluePretrainedNetwork/models/superpoint.py CHANGED
@@ -83,9 +83,9 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
83
  """Interpolate descriptors at keypoint locations"""
84
  b, c, h, w = descriptors.shape
85
  keypoints = keypoints - s / 2 + 0.5
86
- keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to(
87
- keypoints
88
- )[None]
89
  keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
90
  args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
91
  descriptors = torch.nn.functional.grid_sample(
@@ -136,7 +136,11 @@ class SuperPoint(nn.Module):
136
 
137
  self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
138
  self.convDb = nn.Conv2d(
139
- c5, self.config["descriptor_dim"], kernel_size=1, stride=1, padding=0
 
 
 
 
140
  )
141
 
142
  path = Path(__file__).parent / "weights/superpoint_v1.pth"
@@ -148,8 +152,12 @@ class SuperPoint(nn.Module):
148
 
149
  print("Loaded SuperPoint model")
150
 
151
- def forward(self, data):
152
  """Compute keypoints, scores, descriptors for image"""
 
 
 
 
153
  # Shared Encoder
154
  x = self.relu(self.conv1a(data["image"]))
155
  x = self.relu(self.conv1b(x))
@@ -182,7 +190,9 @@ class SuperPoint(nn.Module):
182
  keypoints, scores = list(
183
  zip(
184
  *[
185
- remove_borders(k, s, self.config["remove_borders"], h * 8, w * 8)
 
 
186
  for k, s in zip(keypoints, scores)
187
  ]
188
  )
 
83
  """Interpolate descriptors at keypoint locations"""
84
  b, c, h, w = descriptors.shape
85
  keypoints = keypoints - s / 2 + 0.5
86
+ keypoints /= torch.tensor(
87
+ [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
88
+ ).to(keypoints)[None]
89
  keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
90
  args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
91
  descriptors = torch.nn.functional.grid_sample(
 
136
 
137
  self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
138
  self.convDb = nn.Conv2d(
139
+ c5,
140
+ self.config["descriptor_dim"],
141
+ kernel_size=1,
142
+ stride=1,
143
+ padding=0,
144
  )
145
 
146
  path = Path(__file__).parent / "weights/superpoint_v1.pth"
 
152
 
153
  print("Loaded SuperPoint model")
154
 
155
+ def forward(self, data, cfg={}):
156
  """Compute keypoints, scores, descriptors for image"""
157
+ self.config = {
158
+ **self.config,
159
+ **cfg,
160
+ }
161
  # Shared Encoder
162
  x = self.relu(self.conv1a(data["image"]))
163
  x = self.relu(self.conv1b(x))
 
190
  keypoints, scores = list(
191
  zip(
192
  *[
193
+ remove_borders(
194
+ k, s, self.config["remove_borders"], h * 8, w * 8
195
+ )
196
  for k, s in zip(keypoints, scores)
197
  ]
198
  )