Vincentqyw commited on
Commit
0520a5a
1 Parent(s): b13cad2

update: lanet

Browse files
Files changed (1) hide show
  1. hloc/extractors/lanet.py +7 -2
hloc/extractors/lanet.py CHANGED
@@ -34,10 +34,10 @@ class LANet(BaseModel):
34
 
35
  # Scores & Descriptors
36
  kpts_score = (
37
- torch.cat([keypoints, scores], dim=1).view(3, -1).t().cpu().detach().numpy()
38
  )
39
  descriptors = (
40
- descriptors.view(256, Hc, Wc).view(256, -1).t().cpu().detach().numpy()
41
  )
42
 
43
  # Filter based on confidence threshold
@@ -46,6 +46,11 @@ class LANet(BaseModel):
46
  keypoints = kpts_score[:, 1:]
47
  scores = kpts_score[:, 0]
48
 
 
 
 
 
 
49
  return {
50
  "keypoints": torch.from_numpy(keypoints)[None],
51
  "scores": torch.from_numpy(scores)[None],
 
34
 
35
  # Scores & Descriptors
36
  kpts_score = (
37
+ torch.cat([keypoints, scores], dim=1).view(3, -1).t()
38
  )
39
  descriptors = (
40
+ descriptors.view(256, Hc, Wc).view(256, -1).t()
41
  )
42
 
43
  # Filter based on confidence threshold
 
46
  keypoints = kpts_score[:, 1:]
47
  scores = kpts_score[:, 0]
48
 
49
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
50
+ keypoints = keypoints[idxs, :2]
51
+ descriptors = descriptors[idxs]
52
+ scores = scores[idxs]
53
+
54
  return {
55
  "keypoints": torch.from_numpy(keypoints)[None],
56
  "scores": torch.from_numpy(scores)[None],