Vincentqyw
commited on
Commit
•
0520a5a
1
Parent(s):
b13cad2
update: lanet
Browse files- hloc/extractors/lanet.py +7 -2
hloc/extractors/lanet.py
CHANGED
@@ -34,10 +34,10 @@ class LANet(BaseModel):
|
|
34 |
|
35 |
# Scores & Descriptors
|
36 |
kpts_score = (
|
37 |
-
torch.cat([keypoints, scores], dim=1).view(3, -1).t()
|
38 |
)
|
39 |
descriptors = (
|
40 |
-
descriptors.view(256, Hc, Wc).view(256, -1).t()
|
41 |
)
|
42 |
|
43 |
# Filter based on confidence threshold
|
@@ -46,6 +46,11 @@ class LANet(BaseModel):
|
|
46 |
keypoints = kpts_score[:, 1:]
|
47 |
scores = kpts_score[:, 0]
|
48 |
|
|
|
|
|
|
|
|
|
|
|
49 |
return {
|
50 |
"keypoints": torch.from_numpy(keypoints)[None],
|
51 |
"scores": torch.from_numpy(scores)[None],
|
|
|
34 |
|
35 |
# Scores & Descriptors
|
36 |
kpts_score = (
|
37 |
+
torch.cat([keypoints, scores], dim=1).view(3, -1).t()
|
38 |
)
|
39 |
descriptors = (
|
40 |
+
descriptors.view(256, Hc, Wc).view(256, -1).t()
|
41 |
)
|
42 |
|
43 |
# Filter based on confidence threshold
|
|
|
46 |
keypoints = kpts_score[:, 1:]
|
47 |
scores = kpts_score[:, 0]
|
48 |
|
49 |
+
idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
|
50 |
+
keypoints = keypoints[idxs, :2]
|
51 |
+
descriptors = descriptors[idxs]
|
52 |
+
scores = scores[idxs]
|
53 |
+
|
54 |
return {
|
55 |
"keypoints": torch.from_numpy(keypoints)[None],
|
56 |
"scores": torch.from_numpy(scores)[None],
|