Realcat commited on
Commit
d46c0a9
·
1 Parent(s): d64f8f2

update:sift and update lightglue

Browse files
Files changed (43) hide show
  1. common/config.yaml +13 -1
  2. common/utils.py +13 -4
  3. hloc/extract_features.py +3 -2
  4. hloc/extractors/alike.py +2 -0
  5. hloc/extractors/d2net.py +4 -0
  6. hloc/extractors/darkfeat.py +2 -1
  7. hloc/extractors/dedode.py +2 -3
  8. hloc/extractors/example.py +1 -0
  9. hloc/extractors/lanet.py +2 -0
  10. hloc/extractors/r2d2.py +2 -0
  11. hloc/extractors/rekd.py +2 -0
  12. hloc/extractors/rord.py +4 -5
  13. hloc/extractors/sift.py +224 -0
  14. hloc/extractors/superpoint.py +2 -0
  15. hloc/match_dense.py +2 -5
  16. hloc/match_features.py +38 -14
  17. hloc/matchers/duster.py +36 -30
  18. hloc/matchers/lightglue.py +10 -0
  19. hloc/matchers/sgmnet.py +6 -2
  20. hloc/matchers/sold2.py +1 -0
  21. hloc/utils/viz.py +1 -0
  22. third_party/LightGlue/.flake8 +4 -0
  23. third_party/LightGlue/.github/workflows/code-quality.yml +24 -0
  24. third_party/LightGlue/.gitignore +162 -6
  25. third_party/LightGlue/LICENSE +1 -1
  26. third_party/LightGlue/README.md +71 -25
  27. third_party/LightGlue/assets/DSC_0410.JPG +0 -0
  28. third_party/LightGlue/assets/DSC_0411.JPG +0 -0
  29. third_party/LightGlue/assets/benchmark.png +3 -0
  30. third_party/LightGlue/assets/benchmark_cpu.png +3 -0
  31. third_party/LightGlue/benchmark.py +255 -0
  32. third_party/LightGlue/demo.ipynb +29 -22
  33. third_party/LightGlue/lightglue/__init__.py +7 -4
  34. third_party/LightGlue/lightglue/aliked.py +758 -0
  35. third_party/LightGlue/lightglue/disk.py +10 -24
  36. third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
  37. third_party/LightGlue/lightglue/lightglue.py +331 -146
  38. third_party/LightGlue/lightglue/sift.py +216 -0
  39. third_party/LightGlue/lightglue/superpoint.py +21 -36
  40. third_party/LightGlue/lightglue/utils.py +25 -10
  41. third_party/LightGlue/lightglue/viz2d.py +1 -1
  42. third_party/LightGlue/pyproject.toml +30 -0
  43. third_party/LightGlue/setup.py +0 -27
common/config.yaml CHANGED
@@ -25,7 +25,7 @@ matcher_zoo:
25
  source: "CVPR 2024"
26
  github: https://github.com/Vincentqyw/omniglue-onnx
27
  paper: https://arxiv.org/abs/2405.12979
28
- project: https://hwjiang1510.github.io/OmniGlue/
29
  display: true
30
  DUSt3R:
31
  # TODO: duster is under development
@@ -40,6 +40,7 @@ matcher_zoo:
40
  project: https://dust3r.europe.naverlabs.com
41
  display: true
42
  GIM(dkm):
 
43
  matcher: gim(dkm)
44
  dense: true
45
  info:
@@ -197,6 +198,17 @@ matcher_zoo:
197
  paper: https://arxiv.org/abs/1712.07629
198
  project: null
199
  display: false
 
 
 
 
 
 
 
 
 
 
 
200
  disk+lightglue:
201
  matcher: disk-lightglue
202
  feature: disk
 
25
  source: "CVPR 2024"
26
  github: https://github.com/Vincentqyw/omniglue-onnx
27
  paper: https://arxiv.org/abs/2405.12979
28
+ project: https://hwjiang1510.github.io/OmniGlue
29
  display: true
30
  DUSt3R:
31
  # TODO: duster is under development
 
40
  project: https://dust3r.europe.naverlabs.com
41
  display: true
42
  GIM(dkm):
43
+ enable: false
44
  matcher: gim(dkm)
45
  dense: true
46
  info:
 
198
  paper: https://arxiv.org/abs/1712.07629
199
  project: null
200
  display: false
201
+ sift+lightglue:
202
+ matcher: sift-lightglue
203
+ feature: sift
204
+ dense: false
205
+ info:
206
+ name: LightGlue #dispaly name
207
+ source: "ICCV 2023"
208
+ github: https://github.com/cvg/LightGlue
209
+ paper: https://arxiv.org/pdf/2306.13643
210
+ project: null
211
+ display: true
212
  disk+lightglue:
213
  matcher: disk-lightglue
214
  feature: disk
common/utils.py CHANGED
@@ -7,6 +7,7 @@ import psutil
7
  import shutil
8
  import numpy as np
9
  import gradio as gr
 
10
  from pathlib import Path
11
  import poselib
12
  from itertools import combinations
@@ -231,10 +232,10 @@ def gen_examples():
231
  return [pairs[i] for i in selected]
232
 
233
  # rotated examples
234
- def gen_rot_image_pairs(count: int = 5):
235
  path = ROOT / "datasets/sacre_coeur/mapping"
236
  path_rot = ROOT / "datasets/sacre_coeur/mapping_rot"
237
- rot_list = [45, 90, 135, 180, 225, 270]
238
  pairs = []
239
  for file in os.listdir(path):
240
  if file.lower().endswith((".jpg", ".jpeg", ".png")):
@@ -274,6 +275,7 @@ def gen_examples():
274
  # image pair path
275
  pairs = gen_images_pairs()
276
  pairs += gen_rot_image_pairs()
 
277
  pairs += gen_image_pairs_wxbs()
278
 
279
  match_setting_threshold = DEFAULT_SETTING_THRESHOLD
@@ -1015,8 +1017,15 @@ ransac_zoo = {
1015
 
1016
 
1017
  def rotate_image(input_path, degrees, output_path):
1018
- from PIL import Image
1019
-
1020
  img = Image.open(input_path)
1021
  img_rotated = img.rotate(-degrees)
1022
  img_rotated.save(output_path)
 
 
 
 
 
 
 
 
 
 
7
  import shutil
8
  import numpy as np
9
  import gradio as gr
10
+ from PIL import Image
11
  from pathlib import Path
12
  import poselib
13
  from itertools import combinations
 
232
  return [pairs[i] for i in selected]
233
 
234
  # rotated examples
235
+ def gen_rot_image_pairs(count: int = 10):
236
  path = ROOT / "datasets/sacre_coeur/mapping"
237
  path_rot = ROOT / "datasets/sacre_coeur/mapping_rot"
238
+ rot_list = [45, 180, 90, 225, 270]
239
  pairs = []
240
  for file in os.listdir(path):
241
  if file.lower().endswith((".jpg", ".jpeg", ".png")):
 
275
  # image pair path
276
  pairs = gen_images_pairs()
277
  pairs += gen_rot_image_pairs()
278
+ pairs += gen_scale_image_pairs()
279
  pairs += gen_image_pairs_wxbs()
280
 
281
  match_setting_threshold = DEFAULT_SETTING_THRESHOLD
 
1017
 
1018
 
1019
  def rotate_image(input_path, degrees, output_path):
 
 
1020
  img = Image.open(input_path)
1021
  img_rotated = img.rotate(-degrees)
1022
  img_rotated.save(output_path)
1023
+
1024
+
1025
+ def scale_image(input_path, scale_factor, output_path):
1026
+ img = Image.open(input_path)
1027
+ width, height = img.size
1028
+ new_width = int(width * scale_factor)
1029
+ new_height = int(height * scale_factor)
1030
+ img_resized = img.resize((new_width, new_height))
1031
+ img_resized.save(output_path)
hloc/extract_features.py CHANGED
@@ -131,6 +131,7 @@ confs = {
131
  "output": "feats-rootsift-n5000-r1600",
132
  "model": {
133
  "name": "dog",
 
134
  "max_keypoints": 5000,
135
  },
136
  "preprocessing": {
@@ -145,8 +146,8 @@ confs = {
145
  "sift": {
146
  "output": "feats-sift-n5000-r1600",
147
  "model": {
148
- "name": "dog",
149
- "descriptor": "sift",
150
  "max_keypoints": 5000,
151
  },
152
  "preprocessing": {
 
131
  "output": "feats-rootsift-n5000-r1600",
132
  "model": {
133
  "name": "dog",
134
+ "descriptor": "rootsift",
135
  "max_keypoints": 5000,
136
  },
137
  "preprocessing": {
 
146
  "sift": {
147
  "output": "feats-sift-n5000-r1600",
148
  "model": {
149
+ "name": "sift",
150
+ "rootsift": True,
151
  "max_keypoints": 5000,
152
  },
153
  "preprocessing": {
hloc/extractors/alike.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
3
  import torch
4
 
5
  from ..utils.base_model import BaseModel
 
6
 
7
  alike_path = Path(__file__).parent / "../../third_party/ALIKE"
8
  sys.path.append(str(alike_path))
@@ -33,6 +34,7 @@ class Alike(BaseModel):
33
  scores_th=conf["detection_threshold"],
34
  n_limit=conf["max_keypoints"],
35
  )
 
36
 
37
  def _forward(self, data):
38
  image = data["image"]
 
3
  import torch
4
 
5
  from ..utils.base_model import BaseModel
6
+ from hloc import logger
7
 
8
  alike_path = Path(__file__).parent / "../../third_party/ALIKE"
9
  sys.path.append(str(alike_path))
 
34
  scores_th=conf["detection_threshold"],
35
  n_limit=conf["max_keypoints"],
36
  )
37
+ logger.info(f"Load Alike model done.")
38
 
39
  def _forward(self, data):
40
  image = data["image"]
hloc/extractors/d2net.py CHANGED
@@ -4,13 +4,16 @@ import subprocess
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
 
7
 
8
  d2net_path = Path(__file__).parent / "../../third_party"
9
  sys.path.append(str(d2net_path))
10
  from d2net.lib.model_test import D2Net as _D2Net
11
  from d2net.lib.pyramid import process_multiscale
 
12
  d2net_path = Path(__file__).parent / "../../third_party/d2net"
13
 
 
14
  class D2Net(BaseModel):
15
  default_conf = {
16
  "model_name": "d2_tf.pth",
@@ -36,6 +39,7 @@ class D2Net(BaseModel):
36
  self.net = _D2Net(
37
  model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
38
  )
 
39
 
40
  def _forward(self, data):
41
  image = data["image"]
 
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
7
+ from hloc import logger
8
 
9
  d2net_path = Path(__file__).parent / "../../third_party"
10
  sys.path.append(str(d2net_path))
11
  from d2net.lib.model_test import D2Net as _D2Net
12
  from d2net.lib.pyramid import process_multiscale
13
+
14
  d2net_path = Path(__file__).parent / "../../third_party/d2net"
15
 
16
+
17
  class D2Net(BaseModel):
18
  default_conf = {
19
  "model_name": "d2_tf.pth",
 
39
  self.net = _D2Net(
40
  model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
41
  )
42
+ logger.info(f"Load D2Net model done.")
43
 
44
  def _forward(self, data):
45
  image = data["image"]
hloc/extractors/darkfeat.py CHANGED
@@ -2,7 +2,7 @@ import sys
2
  from pathlib import Path
3
  import subprocess
4
  from ..utils.base_model import BaseModel
5
- from .. import logger
6
 
7
  darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
8
  sys.path.append(str(darkfeat_path))
@@ -43,6 +43,7 @@ class DarkFeat(BaseModel):
43
  raise e
44
 
45
  self.net = DarkFeat_(model_path)
 
46
 
47
  def _forward(self, data):
48
  pred = self.net({"image": data["image"]})
 
2
  from pathlib import Path
3
  import subprocess
4
  from ..utils.base_model import BaseModel
5
+ from hloc import logger
6
 
7
  darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
8
  sys.path.append(str(darkfeat_path))
 
43
  raise e
44
 
45
  self.net = DarkFeat_(model_path)
46
+ logger.info(f"Load DarkFeat model done.")
47
 
48
  def _forward(self, data):
49
  pred = self.net({"image": data["image"]})
hloc/extractors/dedode.py CHANGED
@@ -4,7 +4,7 @@ import subprocess
4
  import torch
5
  from PIL import Image
6
  from ..utils.base_model import BaseModel
7
- from .. import logger
8
  import torchvision.transforms as transforms
9
 
10
  dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
@@ -15,6 +15,7 @@ from DeDoDe.utils import to_pixel_coords
15
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
 
18
  class DeDoDe(BaseModel):
19
  default_conf = {
20
  "name": "dedode",
@@ -61,8 +62,6 @@ class DeDoDe(BaseModel):
61
  )
62
  subprocess.run(cmd, check=True)
63
 
64
- logger.info(f"Loading DeDoDe model...")
65
-
66
  # load the model
67
  weights_detector = torch.load(model_detector_path, map_location="cpu")
68
  weights_descriptor = torch.load(
 
4
  import torch
5
  from PIL import Image
6
  from ..utils.base_model import BaseModel
7
+ from hloc import logger
8
  import torchvision.transforms as transforms
9
 
10
  dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
 
15
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+
19
  class DeDoDe(BaseModel):
20
  default_conf = {
21
  "name": "dedode",
 
62
  )
63
  subprocess.run(cmd, check=True)
64
 
 
 
65
  # load the model
66
  weights_detector = torch.load(model_detector_path, map_location="cpu")
67
  weights_descriptor = torch.load(
hloc/extractors/example.py CHANGED
@@ -13,6 +13,7 @@ sys.path.append(str(example_path))
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
 
16
  class Example(BaseModel):
17
  # change to your default configs
18
  default_conf = {
 
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+
17
  class Example(BaseModel):
18
  # change to your default configs
19
  default_conf = {
hloc/extractors/lanet.py CHANGED
@@ -4,6 +4,7 @@ import subprocess
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
 
7
 
8
  lanet_path = Path(__file__).parent / "../../third_party/lanet"
9
  sys.path.append(str(lanet_path))
@@ -29,6 +30,7 @@ class LANet(BaseModel):
29
  self.net = PointModel(is_test=True)
30
  state_dict = torch.load(model_path, map_location="cpu")
31
  self.net.load_state_dict(state_dict["model_state"])
 
32
 
33
  def _forward(self, data):
34
  image = data["image"]
 
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
7
+ from hloc import logger
8
 
9
  lanet_path = Path(__file__).parent / "../../third_party/lanet"
10
  sys.path.append(str(lanet_path))
 
30
  self.net = PointModel(is_test=True)
31
  state_dict = torch.load(model_path, map_location="cpu")
32
  self.net.load_state_dict(state_dict["model_state"])
33
+ logger.info(f"Load LANet model done.")
34
 
35
  def _forward(self, data):
36
  image = data["image"]
hloc/extractors/r2d2.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
3
  import torchvision.transforms as tvf
4
 
5
  from ..utils.base_model import BaseModel
 
6
 
7
  base_path = Path(__file__).parent / "../../third_party"
8
  sys.path.append(str(base_path))
@@ -34,6 +35,7 @@ class R2D2(BaseModel):
34
  rel_thr=conf["reliability_threshold"],
35
  rep_thr=conf["repetability_threshold"],
36
  )
 
37
 
38
  def _forward(self, data):
39
  img = data["image"]
 
3
  import torchvision.transforms as tvf
4
 
5
  from ..utils.base_model import BaseModel
6
+ from hloc import logger
7
 
8
  base_path = Path(__file__).parent / "../../third_party"
9
  sys.path.append(str(base_path))
 
35
  rel_thr=conf["reliability_threshold"],
36
  rep_thr=conf["repetability_threshold"],
37
  )
38
+ logger.info(f"Load R2D2 model done.")
39
 
40
  def _forward(self, data):
41
  img = data["image"]
hloc/extractors/rekd.py CHANGED
@@ -4,6 +4,7 @@ import subprocess
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
 
7
 
8
  rekd_path = Path(__file__).parent / "../../third_party"
9
  sys.path.append(str(rekd_path))
@@ -28,6 +29,7 @@ class REKD(BaseModel):
28
  self.net = REKD_(is_test=True)
29
  state_dict = torch.load(model_path, map_location="cpu")
30
  self.net.load_state_dict(state_dict["model_state"])
 
31
 
32
  def _forward(self, data):
33
  image = data["image"]
 
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
7
+ from hloc import logger
8
 
9
  rekd_path = Path(__file__).parent / "../../third_party"
10
  sys.path.append(str(rekd_path))
 
29
  self.net = REKD_(is_test=True)
30
  state_dict = torch.load(model_path, map_location="cpu")
31
  self.net.load_state_dict(state_dict["model_state"])
32
+ logger.info(f"Load REKD model done.")
33
 
34
  def _forward(self, data):
35
  image = data["image"]
hloc/extractors/rord.py CHANGED
@@ -4,13 +4,14 @@ import subprocess
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
7
- from .. import logger
8
 
9
  rord_path = Path(__file__).parent / "../../third_party"
10
  sys.path.append(str(rord_path))
11
  from RoRD.lib.model_test import D2Net as _RoRD
12
  from RoRD.lib.pyramid import process_multiscale
13
 
 
14
  class RoRD(BaseModel):
15
  default_conf = {
16
  "model_name": "rord.pth",
@@ -32,9 +33,7 @@ class RoRD(BaseModel):
32
  model_path.parent.mkdir(exist_ok=True)
33
  cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
34
  cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy]
35
- logger.info(
36
- f"Downloading the RoRD model with `{cmd_wo_proxy}`."
37
- )
38
  try:
39
  subprocess.run(cmd_wo_proxy, check=True)
40
  except subprocess.CalledProcessError as e:
@@ -44,10 +43,10 @@ class RoRD(BaseModel):
44
  except subprocess.CalledProcessError as e:
45
  logger.error(f"Failed to download the RoRD model.")
46
  raise e
47
- logger.info("RoRD model loaded.")
48
  self.net = _RoRD(
49
  model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
50
  )
 
51
 
52
  def _forward(self, data):
53
  image = data["image"]
 
4
  import torch
5
 
6
  from ..utils.base_model import BaseModel
7
+ from hloc import logger
8
 
9
  rord_path = Path(__file__).parent / "../../third_party"
10
  sys.path.append(str(rord_path))
11
  from RoRD.lib.model_test import D2Net as _RoRD
12
  from RoRD.lib.pyramid import process_multiscale
13
 
14
+
15
  class RoRD(BaseModel):
16
  default_conf = {
17
  "model_name": "rord.pth",
 
33
  model_path.parent.mkdir(exist_ok=True)
34
  cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
35
  cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy]
36
+ logger.info(f"Downloading the RoRD model with `{cmd_wo_proxy}`.")
 
 
37
  try:
38
  subprocess.run(cmd_wo_proxy, check=True)
39
  except subprocess.CalledProcessError as e:
 
43
  except subprocess.CalledProcessError as e:
44
  logger.error(f"Failed to download the RoRD model.")
45
  raise e
 
46
  self.net = _RoRD(
47
  model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
48
  )
49
+ logger.info(f"Load RoRD model done.")
50
 
51
  def _forward(self, data):
52
  image = data["image"]
hloc/extractors/sift.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from kornia.color import rgb_to_grayscale
7
+ from packaging import version
8
+ from omegaconf import OmegaConf
9
+
10
+ try:
11
+ import pycolmap
12
+ except ImportError:
13
+ pycolmap = None
14
+ from hloc import logger
15
+ from ..utils.base_model import BaseModel
16
+
17
+
18
+ def filter_dog_point(
19
+ points, scales, angles, image_shape, nms_radius, scores=None
20
+ ):
21
+ h, w = image_shape
22
+ ij = np.round(points - 0.5).astype(int).T[::-1]
23
+
24
+ # Remove duplicate points (identical coordinates).
25
+ # Pick highest scale or score
26
+ s = scales if scores is None else scores
27
+ buffer = np.zeros((h, w))
28
+ np.maximum.at(buffer, tuple(ij), s)
29
+ keep = np.where(buffer[tuple(ij)] == s)[0]
30
+
31
+ # Pick lowest angle (arbitrary).
32
+ ij = ij[:, keep]
33
+ buffer[:] = np.inf
34
+ o_abs = np.abs(angles[keep])
35
+ np.minimum.at(buffer, tuple(ij), o_abs)
36
+ mask = buffer[tuple(ij)] == o_abs
37
+ ij = ij[:, mask]
38
+ keep = keep[mask]
39
+
40
+ if nms_radius > 0:
41
+ # Apply NMS on the remaining points
42
+ buffer[:] = 0
43
+ buffer[tuple(ij)] = s[keep] # scores or scale
44
+
45
+ local_max = torch.nn.functional.max_pool2d(
46
+ torch.from_numpy(buffer).unsqueeze(0),
47
+ kernel_size=nms_radius * 2 + 1,
48
+ stride=1,
49
+ padding=nms_radius,
50
+ ).squeeze(0)
51
+ is_local_max = buffer == local_max.numpy()
52
+ keep = keep[is_local_max[tuple(ij)]]
53
+ return keep
54
+
55
+
56
+ def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
57
+ x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
58
+ x.clip_(min=eps).sqrt_()
59
+ return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
60
+
61
+
62
+ def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
63
+ """
64
+ Detect keypoints using OpenCV Detector.
65
+ Optionally, perform description.
66
+ Args:
67
+ features: OpenCV based keypoints detector and descriptor
68
+ image: Grayscale image of uint8 data type
69
+ Returns:
70
+ keypoints: 1D array of detected cv2.KeyPoint
71
+ scores: 1D array of responses
72
+ descriptors: 1D array of descriptors
73
+ """
74
+ detections, descriptors = features.detectAndCompute(image, None)
75
+ points = np.array([k.pt for k in detections], dtype=np.float32)
76
+ scores = np.array([k.response for k in detections], dtype=np.float32)
77
+ scales = np.array([k.size for k in detections], dtype=np.float32)
78
+ angles = np.deg2rad(
79
+ np.array([k.angle for k in detections], dtype=np.float32)
80
+ )
81
+ return points, scores, scales, angles, descriptors
82
+
83
+
84
+ class SIFT(BaseModel):
85
+ default_conf = {
86
+ "rootsift": True,
87
+ "nms_radius": 0, # None to disable filtering entirely.
88
+ "max_keypoints": 4096,
89
+ "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
90
+ "detection_threshold": 0.0066667, # from COLMAP
91
+ "edge_threshold": 10,
92
+ "first_octave": -1, # only used by pycolmap, the default of COLMAP
93
+ "num_octaves": 4,
94
+ }
95
+
96
+ required_data_keys = ["image"]
97
+
98
+ def _init(self, conf):
99
+ self.conf = OmegaConf.create(self.conf)
100
+ backend = self.conf.backend
101
+ if backend.startswith("pycolmap"):
102
+ if pycolmap is None:
103
+ raise ImportError(
104
+ "Cannot find module pycolmap: install it with pip"
105
+ "or use backend=opencv."
106
+ )
107
+ options = {
108
+ "peak_threshold": self.conf.detection_threshold,
109
+ "edge_threshold": self.conf.edge_threshold,
110
+ "first_octave": self.conf.first_octave,
111
+ "num_octaves": self.conf.num_octaves,
112
+ "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
113
+ }
114
+ device = (
115
+ "auto"
116
+ if backend == "pycolmap"
117
+ else backend.replace("pycolmap_", "")
118
+ )
119
+ if (
120
+ backend == "pycolmap_cpu" or not pycolmap.has_cuda
121
+ ) and pycolmap.__version__ < "0.5.0":
122
+ warnings.warn(
123
+ "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
124
+ "consider upgrading pycolmap or use the CUDA version.",
125
+ stacklevel=1,
126
+ )
127
+ else:
128
+ options["max_num_features"] = self.conf.max_keypoints
129
+ self.sift = pycolmap.Sift(options=options, device=device)
130
+ elif backend == "opencv":
131
+ self.sift = cv2.SIFT_create(
132
+ contrastThreshold=self.conf.detection_threshold,
133
+ nfeatures=self.conf.max_keypoints,
134
+ edgeThreshold=self.conf.edge_threshold,
135
+ nOctaveLayers=self.conf.num_octaves,
136
+ )
137
+ else:
138
+ backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
139
+ raise ValueError(
140
+ f"Unknown backend: {backend} not in "
141
+ f"{{{','.join(backends)}}}."
142
+ )
143
+ logger.info(f"Load SIFT model done.")
144
+
145
+ def extract_single_image(self, image: torch.Tensor):
146
+ image_np = image.cpu().numpy().squeeze(0)
147
+
148
+ if self.conf.backend.startswith("pycolmap"):
149
+ if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
150
+ detections, descriptors = self.sift.extract(image_np)
151
+ scores = None # Scores are not exposed by COLMAP anymore.
152
+ else:
153
+ detections, scores, descriptors = self.sift.extract(image_np)
154
+ keypoints = detections[:, :2] # Keep only (x, y).
155
+ scales, angles = detections[:, -2:].T
156
+ if scores is not None and (
157
+ self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
158
+ ):
159
+ # Set the scores as a combination of abs. response and scale.
160
+ scores = np.abs(scores) * scales
161
+ elif self.conf.backend == "opencv":
162
+ # TODO: Check if opencv keypoints are already in corner convention
163
+ keypoints, scores, scales, angles, descriptors = run_opencv_sift(
164
+ self.sift, (image_np * 255.0).astype(np.uint8)
165
+ )
166
+ pred = {
167
+ "keypoints": keypoints,
168
+ "scales": scales,
169
+ "oris": angles,
170
+ "descriptors": descriptors,
171
+ }
172
+ if scores is not None:
173
+ pred["scores"] = scores
174
+
175
+ # sometimes pycolmap returns points outside the image. We remove them
176
+ if self.conf.backend.startswith("pycolmap"):
177
+ is_inside = (
178
+ pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
179
+ ).all(-1)
180
+ pred = {k: v[is_inside] for k, v in pred.items()}
181
+
182
+ if self.conf.nms_radius is not None:
183
+ keep = filter_dog_point(
184
+ pred["keypoints"],
185
+ pred["scales"],
186
+ pred["oris"],
187
+ image_np.shape,
188
+ self.conf.nms_radius,
189
+ scores=pred.get("scores"),
190
+ )
191
+ pred = {k: v[keep] for k, v in pred.items()}
192
+
193
+ pred = {k: torch.from_numpy(v) for k, v in pred.items()}
194
+ if scores is not None:
195
+ # Keep the k keypoints with highest score
196
+ num_points = self.conf.max_keypoints
197
+ if num_points is not None and len(pred["keypoints"]) > num_points:
198
+ indices = torch.topk(pred["scores"], num_points).indices
199
+ pred = {k: v[indices] for k, v in pred.items()}
200
+ return pred
201
+
202
+ def _forward(self, data: dict) -> dict:
203
+ image = data["image"]
204
+ if image.shape[1] == 3:
205
+ image = rgb_to_grayscale(image)
206
+ device = image.device
207
+ image = image.cpu()
208
+ pred = []
209
+ for k in range(len(image)):
210
+ img = image[k]
211
+ if "image_size" in data.keys():
212
+ # avoid extracting points in padded areas
213
+ w, h = data["image_size"][k]
214
+ img = img[:, :h, :w]
215
+ p = self.extract_single_image(img)
216
+ pred.append(p)
217
+ pred = {
218
+ k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]
219
+ }
220
+ if self.conf.rootsift:
221
+ pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
222
+ pred["descriptors"] = pred["descriptors"].permute(0, 2, 1)
223
+ pred["keypoint_scores"] = pred["scores"].clone()
224
+ return pred
hloc/extractors/superpoint.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
3
  import torch
4
 
5
  from ..utils.base_model import BaseModel
 
6
 
7
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
8
  from SuperGluePretrainedNetwork.models import superpoint # noqa E402
@@ -42,6 +43,7 @@ class SuperPoint(BaseModel):
42
  if conf["fix_sampling"]:
43
  superpoint.sample_descriptors = sample_descriptors_fix_sampling
44
  self.net = superpoint.SuperPoint(conf)
 
45
 
46
  def _forward(self, data):
47
  return self.net(data, self.conf)
 
3
  import torch
4
 
5
  from ..utils.base_model import BaseModel
6
+ from hloc import logger
7
 
8
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
9
  from SuperGluePretrainedNetwork.models import superpoint # noqa E402
 
43
  if conf["fix_sampling"]:
44
  superpoint.sample_descriptors = sample_descriptors_fix_sampling
45
  self.net = superpoint.SuperPoint(conf)
46
+ logger.info(f"Load SuperPoint model done.")
47
 
48
  def _forward(self, data):
49
  return self.net(data, self.conf)
hloc/match_dense.py CHANGED
@@ -138,11 +138,8 @@ confs = {
138
  },
139
  "preprocessing": {
140
  "grayscale": False,
141
- "force_resize": True,
142
- "resize_max": 1024,
143
- "width": 512,
144
- "height": 512,
145
- "dfactor": 8,
146
  },
147
  },
148
  "xfeat_dense": {
 
138
  },
139
  "preprocessing": {
140
  "grayscale": False,
141
+ "resize_max": 512,
142
+ "dfactor": 16,
 
 
 
143
  },
144
  },
145
  "xfeat_dense": {
hloc/match_features.py CHANGED
@@ -63,7 +63,7 @@ confs = {
63
  },
64
  },
65
  "disk-lightglue": {
66
- "output": "matches-lightglue",
67
  "model": {
68
  "name": "lightglue",
69
  "match_threshold": 0.2,
@@ -79,6 +79,24 @@ confs = {
79
  "force_resize": False,
80
  },
81
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  "sgmnet": {
83
  "output": "matches-sgmnet",
84
  "model": {
@@ -339,19 +357,25 @@ def match_images(model, feat0, feat1):
339
  feat0["keypoints"] = feat0["keypoints"][0][None]
340
  if isinstance(feat1["keypoints"], list):
341
  feat1["keypoints"] = feat1["keypoints"][0][None]
342
-
343
- pred = model(
344
- {
345
- "image0": feat0["image"],
346
- "keypoints0": feat0["keypoints"],
347
- "scores0": feat0["scores"][0].unsqueeze(0),
348
- "descriptors0": desc0,
349
- "image1": feat1["image"],
350
- "keypoints1": feat1["keypoints"],
351
- "scores1": feat1["scores"][0].unsqueeze(0),
352
- "descriptors1": desc1,
353
- }
354
- )
 
 
 
 
 
 
355
  pred = {
356
  k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v
357
  for k, v in pred.items()
 
63
  },
64
  },
65
  "disk-lightglue": {
66
+ "output": "matches-disk-lightglue",
67
  "model": {
68
  "name": "lightglue",
69
  "match_threshold": 0.2,
 
79
  "force_resize": False,
80
  },
81
  },
82
+ "sift-lightglue": {
83
+ "output": "matches-sift-lightglue",
84
+ "model": {
85
+ "name": "lightglue",
86
+ "match_threshold": 0.2,
87
+ "width_confidence": 0.99, # for point pruning
88
+ "depth_confidence": 0.95, # for early stopping,
89
+ "features": "sift",
90
+ "add_scale_ori": True,
91
+ "model_name": "sift_lightglue.pth",
92
+ },
93
+ "preprocessing": {
94
+ "grayscale": True,
95
+ "resize_max": 1024,
96
+ "dfactor": 8,
97
+ "force_resize": False,
98
+ },
99
+ },
100
  "sgmnet": {
101
  "output": "matches-sgmnet",
102
  "model": {
 
357
  feat0["keypoints"] = feat0["keypoints"][0][None]
358
  if isinstance(feat1["keypoints"], list):
359
  feat1["keypoints"] = feat1["keypoints"][0][None]
360
+ input_dict = {
361
+ "image0": feat0["image"],
362
+ "keypoints0": feat0["keypoints"],
363
+ "scores0": feat0["scores"][0].unsqueeze(0),
364
+ "descriptors0": desc0,
365
+ "image1": feat1["image"],
366
+ "keypoints1": feat1["keypoints"],
367
+ "scores1": feat1["scores"][0].unsqueeze(0),
368
+ "descriptors1": desc1,
369
+ }
370
+ if "scales" in feat0:
371
+ input_dict = {**input_dict, "scales0": feat0["scales"]}
372
+ if "scales" in feat1:
373
+ input_dict = {**input_dict, "scales1": feat1["scales"]}
374
+ if "oris" in feat0:
375
+ input_dict = {**input_dict, "oris0": feat0["oris"]}
376
+ if "oris" in feat1:
377
+ input_dict = {**input_dict, "oris1": feat1["oris"]}
378
+ pred = model(input_dict)
379
  pred = {
380
  k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v
381
  for k, v in pred.items()
hloc/matchers/duster.py CHANGED
@@ -13,7 +13,7 @@ duster_path = Path(__file__).parent / "../../third_party/dust3r"
13
  sys.path.append(str(duster_path))
14
 
15
  from dust3r.inference import inference
16
- from dust3r.model import load_model
17
  from dust3r.image_pairs import make_pairs
18
  from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
19
  from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
@@ -33,7 +33,11 @@ class Duster(BaseModel):
33
  self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
34
  self.model_path = self.conf["model_path"]
35
  self.download_weights()
36
- self.net = load_model(self.model_path, device)
 
 
 
 
37
  logger.info(f"Loaded Dust3r model")
38
 
39
  def download_weights(self):
@@ -68,8 +72,11 @@ class Duster(BaseModel):
68
 
69
  def _forward(self, data):
70
  img0, img1 = data["image0"], data["image1"]
71
- # img0 = self.preprocess(img0)
72
- # img1 = self.preprocess(img1)
 
 
 
73
 
74
  images = [
75
  {"img": img0, "idx": 0, "instance": 0},
@@ -79,22 +86,13 @@ class Duster(BaseModel):
79
  images, scene_graph="complete", prefilter=None, symmetrize=True
80
  )
81
  output = inference(pairs, self.net, device, batch_size=1)
82
-
83
  scene = global_aligner(
84
  output, device=device, mode=GlobalAlignerMode.PairViewer
85
  )
86
- batch_size = 1
87
- schedule = "cosine"
88
- lr = 0.01
89
- niter = 300
90
- loss = scene.compute_global_alignment(
91
- init="mst", niter=niter, schedule=schedule, lr=lr
92
- )
93
-
94
  # retrieve useful values from scene:
 
95
  confidence_masks = scene.get_masks()
96
  pts3d = scene.get_pts3d()
97
- imgs = scene.imgs
98
  pts2d_list, pts3d_list = [], []
99
  for i in range(2):
100
  conf_i = confidence_masks[i].cpu().numpy()
@@ -102,21 +100,29 @@ class Duster(BaseModel):
102
  xy_grid(*imgs[i].shape[:2][::-1])[conf_i]
103
  ) # imgs[i].shape[:2] = (H, W)
104
  pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
105
- reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
106
- *pts3d_list
107
- )
108
- logger.info(f"Found {num_matches} matches")
109
- mkpts1 = pts2d_list[1][reciprocal_in_P2]
110
- mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
111
-
112
- top_k = self.conf["max_keypoints"]
113
- if top_k is not None and len(mkpts0) > top_k:
114
- keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
115
- mkpts0 = mkpts0[keep]
116
- mkpts1 = mkpts1[keep]
117
- pred = {
118
- "keypoints0": torch.from_numpy(mkpts0),
119
- "keypoints1": torch.from_numpy(mkpts1),
120
- }
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  return pred
 
13
  sys.path.append(str(duster_path))
14
 
15
  from dust3r.inference import inference
16
+ from dust3r.model import load_model, AsymmetricCroCo3DStereo
17
  from dust3r.image_pairs import make_pairs
18
  from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
19
  from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
 
33
  self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
34
  self.model_path = self.conf["model_path"]
35
  self.download_weights()
36
+ # self.net = load_model(self.model_path, device)
37
+ self.net = AsymmetricCroCo3DStereo.from_pretrained(
38
+ self.model_path
39
+ # "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
40
+ ).to(device)
41
  logger.info(f"Loaded Dust3r model")
42
 
43
  def download_weights(self):
 
72
 
73
  def _forward(self, data):
74
  img0, img1 = data["image0"], data["image1"]
75
+ mean = torch.tensor([0.5, 0.5, 0.5]).to(device)
76
+ std = torch.tensor([0.5, 0.5, 0.5]).to(device)
77
+
78
+ img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
79
+ img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
80
 
81
  images = [
82
  {"img": img0, "idx": 0, "instance": 0},
 
86
  images, scene_graph="complete", prefilter=None, symmetrize=True
87
  )
88
  output = inference(pairs, self.net, device, batch_size=1)
 
89
  scene = global_aligner(
90
  output, device=device, mode=GlobalAlignerMode.PairViewer
91
  )
 
 
 
 
 
 
 
 
92
  # retrieve useful values from scene:
93
+ imgs = scene.imgs
94
  confidence_masks = scene.get_masks()
95
  pts3d = scene.get_pts3d()
 
96
  pts2d_list, pts3d_list = [], []
97
  for i in range(2):
98
  conf_i = confidence_masks[i].cpu().numpy()
 
100
  xy_grid(*imgs[i].shape[:2][::-1])[conf_i]
101
  ) # imgs[i].shape[:2] = (H, W)
102
  pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ if len(pts3d_list[1]) == 0:
105
+ pred = {
106
+ "keypoints0": torch.zeros([0, 2]),
107
+ "keypoints1": torch.zeros([0, 2]),
108
+ }
109
+ logger.warning(f"Matched {0} points")
110
+ else:
111
+ reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
112
+ *pts3d_list
113
+ )
114
+ logger.info(f"Found {num_matches} matches")
115
+ mkpts1 = pts2d_list[1][reciprocal_in_P2]
116
+ mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
117
+ top_k = self.conf["max_keypoints"]
118
+ if top_k is not None and len(mkpts0) > top_k:
119
+ keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(
120
+ int
121
+ )
122
+ mkpts0 = mkpts0[keep]
123
+ mkpts1 = mkpts1[keep]
124
+ pred = {
125
+ "keypoints0": torch.from_numpy(mkpts0),
126
+ "keypoints1": torch.from_numpy(mkpts1),
127
+ }
128
  return pred
hloc/matchers/lightglue.py CHANGED
@@ -18,6 +18,7 @@ class LightGlue(BaseModel):
18
  "model_name": "superpoint_lightglue.pth",
19
  "flash": True, # enable FlashAttention if available.
20
  "mp": False, # enable mixed precision
 
21
  }
22
  required_inputs = [
23
  "image0",
@@ -44,9 +45,18 @@ class LightGlue(BaseModel):
44
  "keypoints": data["keypoints0"],
45
  "descriptors": data["descriptors0"].permute(0, 2, 1),
46
  }
 
 
 
 
 
47
  input["image1"] = {
48
  "image": data["image1"],
49
  "keypoints": data["keypoints1"],
50
  "descriptors": data["descriptors1"].permute(0, 2, 1),
51
  }
 
 
 
 
52
  return self.net(input)
 
18
  "model_name": "superpoint_lightglue.pth",
19
  "flash": True, # enable FlashAttention if available.
20
  "mp": False, # enable mixed precision
21
+ "add_scale_ori": False,
22
  }
23
  required_inputs = [
24
  "image0",
 
45
  "keypoints": data["keypoints0"],
46
  "descriptors": data["descriptors0"].permute(0, 2, 1),
47
  }
48
+ if "scales0" in data:
49
+ input["image0"] = {**input["image0"], "scales": data["scales0"]}
50
+ if "oris0" in data:
51
+ input["image0"] = {**input["image0"], "oris": data["oris0"]}
52
+
53
  input["image1"] = {
54
  "image": data["image1"],
55
  "keypoints": data["keypoints1"],
56
  "descriptors": data["descriptors1"].permute(0, 2, 1),
57
  }
58
+ if "scales1" in data:
59
+ input["image1"] = {**input["image1"], "scales": data["scales1"]}
60
+ if "oris1" in data:
61
+ input["image1"] = {**input["image1"], "oris": data["oris1"]}
62
  return self.net(input)
hloc/matchers/sgmnet.py CHANGED
@@ -99,8 +99,12 @@ class SGMNet(BaseModel):
99
  score2 = data["scores1"].reshape(-1, 1)
100
  desc1 = data["descriptors0"].permute(0, 2, 1) # 1 x N x 128
101
  desc2 = data["descriptors1"].permute(0, 2, 1)
102
- size1 = torch.tensor(data["image0"].shape[2:]).flip(0) # W x H -> x & y
103
- size2 = torch.tensor(data["image1"].shape[2:]).flip(0) # W x H
 
 
 
 
104
  norm_x1 = self.normalize_size(x1, size1)
105
  norm_x2 = self.normalize_size(x2, size2)
106
 
 
99
  score2 = data["scores1"].reshape(-1, 1)
100
  desc1 = data["descriptors0"].permute(0, 2, 1) # 1 x N x 128
101
  desc2 = data["descriptors1"].permute(0, 2, 1)
102
+ size1 = (
103
+ torch.tensor(data["image0"].shape[2:]).flip(0).to(x1.device)
104
+ ) # W x H -> x & y
105
+ size2 = (
106
+ torch.tensor(data["image1"].shape[2:]).flip(0).to(x2.device)
107
+ ) # W x H
108
  norm_x1 = self.normalize_size(x1, size1)
109
  norm_x2 = self.normalize_size(x2, size2)
110
 
hloc/matchers/sold2.py CHANGED
@@ -34,6 +34,7 @@ class SOLD2(BaseModel):
34
  weight_urls = {
35
  "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download",
36
  }
 
37
  # Initialize the line matcher
38
  def _init(self, conf):
39
  checkpoint_path = conf["checkpoint_dir"] / conf["weights"]
 
34
  weight_urls = {
35
  "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download",
36
  }
37
+
38
  # Initialize the line matcher
39
  def _init(self, conf):
40
  checkpoint_path = conf["checkpoint_dir"] / conf["weights"]
hloc/utils/viz.py CHANGED
@@ -71,6 +71,7 @@ def plot_keypoints(kpts, colors="lime", ps=4):
71
  except IndexError as e:
72
  pass
73
 
 
74
  def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
75
  """Plot matches for a pair of existing images.
76
  Args:
 
71
  except IndexError as e:
72
  pass
73
 
74
+
75
  def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
76
  """Plot matches for a pair of existing images.
77
  Args:
third_party/LightGlue/.flake8 ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [flake8]
2
+ max-line-length = 88
3
+ extend-ignore = E203
4
+ exclude = .git,__pycache__,build,.venv/
third_party/LightGlue/.github/workflows/code-quality.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Format and Lint Checks
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ paths:
7
+ - '*.py'
8
+ pull_request:
9
+ types: [ assigned, opened, synchronize, reopened ]
10
+ jobs:
11
+ check:
12
+ name: Format and Lint Checks
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - uses: actions/checkout@v3
16
+ - uses: actions/setup-python@v4
17
+ with:
18
+ python-version: '3.10'
19
+ cache: 'pip'
20
+ - run: python -m pip install --upgrade pip
21
+ - run: python -m pip install .[dev]
22
+ - run: python -m flake8 .
23
+ - run: python -m isort . --check-only --diff
24
+ - run: python -m black . --check --diff
third_party/LightGlue/.gitignore CHANGED
@@ -1,10 +1,166 @@
1
- *.egg-info
2
- *.pyc
3
- /.idea/
4
  /data/
5
  /outputs/
6
- __pycache__
7
  /lightglue/weights/
8
- lightglue/_flash/
9
  *-checkpoint.ipynb
10
- *.pth
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  /data/
2
  /outputs/
 
3
  /lightglue/weights/
 
4
  *-checkpoint.ipynb
5
+ *.pth
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/#use-with-ide
116
+ .pdm.toml
117
+
118
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119
+ __pypackages__/
120
+
121
+ # Celery stuff
122
+ celerybeat-schedule
123
+ celerybeat.pid
124
+
125
+ # SageMath parsed files
126
+ *.sage.py
127
+
128
+ # Environments
129
+ .env
130
+ .venv
131
+ env/
132
+ venv/
133
+ ENV/
134
+ env.bak/
135
+ venv.bak/
136
+
137
+ # Spyder project settings
138
+ .spyderproject
139
+ .spyproject
140
+
141
+ # Rope project settings
142
+ .ropeproject
143
+
144
+ # mkdocs documentation
145
+ /site
146
+
147
+ # mypy
148
+ .mypy_cache/
149
+ .dmypy.json
150
+ dmypy.json
151
+
152
+ # Pyre type checker
153
+ .pyre/
154
+
155
+ # pytype static type analyzer
156
+ .pytype/
157
+
158
+ # Cython debug symbols
159
+ cython_debug/
160
+
161
+ # PyCharm
162
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
165
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166
+ .idea/
third_party/LightGlue/LICENSE CHANGED
@@ -186,7 +186,7 @@
186
  same "printed page" as the copyright notice for easier
187
  identification within third-party archives.
188
 
189
- Copyright [yyyy] [name of copyright owner]
190
 
191
  Licensed under the Apache License, Version 2.0 (the "License");
192
  you may not use this file except in compliance with the License.
 
186
  same "printed page" as the copyright notice for easier
187
  identification within third-party archives.
188
 
189
+ Copyright 2023 ETH Zurich
190
 
191
  Licensed under the Apache License, Version 2.0 (the "License");
192
  you may not use this file except in compliance with the License.
third_party/LightGlue/README.md CHANGED
@@ -1,5 +1,5 @@
1
  <p align="center">
2
- <h1 align="center"><ins>LightGlue ⚡️</ins><br>Local Feature Matching at Light Speed</h1>
3
  <p align="center">
4
  <a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
5
  ·
@@ -7,15 +7,14 @@
7
  ·
8
  <a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc&nbsp;Pollefeys</a>
9
  </p>
10
- <!-- <p align="center">
11
- <img src="assets/larchitecture.svg" alt="Logo" height="40">
12
- </p> -->
13
- <!-- <h2 align="center">PrePrint 2023</h2> -->
14
- <h2 align="center"><p>
15
  <a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a> |
16
- <a href="https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb" align="center">Colab</a>
17
- </p></h2>
18
- <div align="center"></div>
 
 
19
  </p>
20
  <p align="center">
21
  <a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="example" width=80%></a>
@@ -27,8 +26,8 @@
27
 
28
  This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).
29
 
30
- We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629) and [DISK](https://arxiv.org/abs/2006.13566) local features.
31
- The training end evaluation code will be released in July in a separate repo. To be notified, subscribe to [issue #6](https://github.com/cvg/LightGlue/issues/6).
32
 
33
  ## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)
34
 
@@ -44,14 +43,14 @@ We provide a [demo notebook](demo.ipynb) which shows how to perform feature extr
44
  Here is a minimal script to match two images:
45
 
46
  ```python
47
- from lightglue import LightGlue, SuperPoint, DISK
48
  from lightglue.utils import load_image, rbd
49
 
50
  # SuperPoint+LightGlue
51
  extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor
52
  matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher
53
 
54
- # or DISK+LightGlue
55
  extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor
56
  matcher = LightGlue(features='disk').eval().cuda() # load the matcher
57
 
@@ -88,6 +87,18 @@ feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)
88
 
89
  ## Advanced configuration
90
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
92
  ```python
93
  extractor = SuperPoint(max_num_keypoints=None)
@@ -99,31 +110,62 @@ To increase the speed with a small drop of accuracy, decrease the number of keyp
99
  extractor = SuperPoint(max_num_keypoints=1024)
100
  matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
101
  ```
102
- The maximum speed is obtained with [FlashAttention](https://arxiv.org/abs/2205.14135), which is automatically used when ```torch >= 2.0``` or if it is [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  <details>
105
- <summary>[Detail of all parameters - click to expand]</summary>
106
 
107
- - [```n_layers```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L261): Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
108
- - [```flash```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L263): Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
109
- - [```mp```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L264): Enable mixed precision inference. Default: False (off)
110
- - [```depth_confidence```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L265): Controls the early stopping. A lower values stops more often at earlier layers. Default: 0.95, disable with -1.
111
- - [```width_confidence```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L266): Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
112
- - [```filter_threshold```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L267): Match confidence. Increase this value to obtain less, but stronger matches. Default: 0.1
113
 
114
  </details>
115
 
 
 
 
 
 
116
  ## Other links
117
  - [hloc - the visual localization toolbox](https://github.com/cvg/Hierarchical-Localization/): run LightGlue for Structure-from-Motion and visual localization.
118
- - [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange format.
119
  - [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue.
120
- - [kornia](kornia.readthedocs.io/) now exposes LightGlue via the interfaces [`LightGlue`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlue) and [`LightGlueMatcher`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlueMatcher).
121
 
122
- ## BibTeX Citation
123
  If you use any ideas from the paper or code from this repo, please consider citing:
124
 
125
  ```txt
126
- @inproceedings{lindenberger23lightglue,
127
  author = {Philipp Lindenberger and
128
  Paul-Edouard Sarlin and
129
  Marc Pollefeys},
@@ -132,3 +174,7 @@ If you use any ideas from the paper or code from this repo, please consider citi
132
  year = {2023}
133
  }
134
  ```
 
 
 
 
 
1
  <p align="center">
2
+ <h1 align="center"><ins>LightGlue</ins> ⚡️<br>Local Feature Matching at Light Speed</h1>
3
  <p align="center">
4
  <a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
5
  ·
 
7
  ·
8
  <a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc&nbsp;Pollefeys</a>
9
  </p>
10
+ <h2 align="center">
11
+ <p>ICCV 2023</p>
 
 
 
12
  <a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a> |
13
+ <a href="https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb" align="center">Colab</a> |
14
+ <a href="https://psarlin.com/assets/LightGlue_ICCV2023_poster_compressed.pdf" align="center">Poster</a> |
15
+ <a href="https://github.com/cvg/glue-factory" align="center">Train your own!</a>
16
+ </h2>
17
+
18
  </p>
19
  <p align="center">
20
  <a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="example" width=80%></a>
 
26
 
27
  This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).
28
 
29
+ We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629), [DISK](https://arxiv.org/abs/2006.13566), [ALIKED](https://arxiv.org/abs/2304.03608) and [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf) local features.
30
+ The training and evaluation code can be found in our library [glue-factory](https://github.com/cvg/glue-factory/).
31
 
32
  ## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)
33
 
 
43
  Here is a minimal script to match two images:
44
 
45
  ```python
46
+ from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
47
  from lightglue.utils import load_image, rbd
48
 
49
  # SuperPoint+LightGlue
50
  extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor
51
  matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher
52
 
53
+ # or DISK+LightGlue, ALIKED+LightGlue or SIFT+LightGlue
54
  extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor
55
  matcher = LightGlue(features='disk').eval().cuda() # load the matcher
56
 
 
87
 
88
  ## Advanced configuration
89
 
90
+ <details>
91
+ <summary>[Detail of all parameters - click to expand]</summary>
92
+
93
+ - ```n_layers```: Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
94
+ - ```flash```: Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
95
+ - ```mp```: Enable mixed precision inference. Default: False (off)
96
+ - ```depth_confidence```: Controls the early stopping. A lower values stops more often at earlier layers. Default: 0.95, disable with -1.
97
+ - ```width_confidence```: Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
98
+ - ```filter_threshold```: Match confidence. Increase this value to obtain less, but stronger matches. Default: 0.1
99
+
100
+ </details>
101
+
102
  The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
103
  ```python
104
  extractor = SuperPoint(max_num_keypoints=None)
 
110
  extractor = SuperPoint(max_num_keypoints=1024)
111
  matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
112
  ```
113
+
114
+ The maximum speed is obtained with a combination of:
115
+ - [FlashAttention](https://arxiv.org/abs/2205.14135): automatically used when ```torch >= 2.0``` or if [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).
116
+ - PyTorch compilation, available when ```torch >= 2.0```:
117
+ ```python
118
+ matcher = matcher.eval().cuda()
119
+ matcher.compile(mode='reduce-overhead')
120
+ ```
121
+ For inputs with fewer than 1536 keypoints (determined experimentally), this compiles LightGlue but disables point pruning (large overhead). For larger input sizes, it automatically falls backs to eager mode with point pruning. Adaptive depths is supported for any input size.
122
+
123
+ ## Benchmark
124
+
125
+
126
+ <p align="center">
127
+ <a><img src="assets/benchmark.png" alt="Logo" width=80%></a>
128
+ <br>
129
+ <em>Benchmark results on GPU (RTX 3080). With compilation and adaptivity, LightGlue runs at 150 FPS @ 1024 keypoints and 50 FPS @ 4096 keypoints per image. This is a 4-10x speedup over SuperGlue. </em>
130
+ </p>
131
+
132
+ <p align="center">
133
+ <a><img src="assets/benchmark_cpu.png" alt="Logo" width=80%></a>
134
+ <br>
135
+ <em>Benchmark results on CPU (Intel i7 10700K). LightGlue runs at 20 FPS @ 512 keypoints. </em>
136
+ </p>
137
+
138
+ Obtain the same plots for your setup using our [benchmark script](benchmark.py):
139
+ ```
140
+ python benchmark.py [--device cuda] [--add_superglue] [--num_keypoints 512 1024 2048 4096] [--compile]
141
+ ```
142
 
143
  <details>
144
+ <summary>[Performance tip - click to expand]</summary>
145
 
146
+ Note: **Point pruning** introduces an overhead that sometimes outweighs its benefits.
147
+ Point pruning is thus enabled only when the there are more than N keypoints in an image, where N is hardware-dependent.
148
+ We provide defaults optimized for current hardware (RTX 30xx GPUs).
149
+ We suggest running the benchmark script and adjusting the thresholds for your hardware by updating `LightGlue.pruning_keypoint_thresholds['cuda']`.
 
 
150
 
151
  </details>
152
 
153
+ ## Training and evaluation
154
+
155
+ With [Glue Factory](https://github.com/cvg/glue-factory), you can train LightGlue with your own local features, on your own dataset!
156
+ You can also evaluate it and other baselines on standard benchmarks like HPatches and MegaDepth.
157
+
158
  ## Other links
159
  - [hloc - the visual localization toolbox](https://github.com/cvg/Hierarchical-Localization/): run LightGlue for Structure-from-Motion and visual localization.
160
+ - [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange (ONNX) format with support for TensorRT and OpenVINO.
161
  - [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue.
162
+ - [kornia](https://kornia.readthedocs.io) now exposes LightGlue via the interfaces [`LightGlue`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlue) and [`LightGlueMatcher`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlueMatcher).
163
 
164
+ ## BibTeX citation
165
  If you use any ideas from the paper or code from this repo, please consider citing:
166
 
167
  ```txt
168
+ @inproceedings{lindenberger2023lightglue,
169
  author = {Philipp Lindenberger and
170
  Paul-Edouard Sarlin and
171
  Marc Pollefeys},
 
174
  year = {2023}
175
  }
176
  ```
177
+
178
+
179
+ ## License
180
+ The pre-trained weights of LightGlue and the code provided in this repository are released under the [Apache-2.0 license](./LICENSE). [DISK](https://github.com/cvlab-epfl/disk) follows this license as well but SuperPoint follows [a different, restrictive license](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE) (this includes its pre-trained weights and its [inference file](./lightglue/superpoint.py)). [ALIKED](https://github.com/Shiaoming/ALIKED) was published under a BSD-3-Clause license.
third_party/LightGlue/assets/DSC_0410.JPG CHANGED

Git LFS Details

  • SHA256: 1d6a86be44519faf4c86e9a869c5b298a5a7e1478f7479400c28aa2d018bd1b0
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
third_party/LightGlue/assets/DSC_0411.JPG CHANGED

Git LFS Details

  • SHA256: 7211ee48ec2fbc082d2dabf8dbf503c853a473712375c2ad32a29d538a168a47
  • Pointer size: 131 Bytes
  • Size of remote file: 421 kB
third_party/LightGlue/assets/benchmark.png ADDED

Git LFS Details

  • SHA256: cae077138caeca75aa99cb8047b198a94bc995e488f00db245d94ad77498142f
  • Pointer size: 130 Bytes
  • Size of remote file: 70.3 kB
third_party/LightGlue/assets/benchmark_cpu.png ADDED

Git LFS Details

  • SHA256: 1801a05606f1d316173365a9692f85c95f2d8aa53570c35415dd4805ac1d075d
  • Pointer size: 130 Bytes
  • Size of remote file: 56.4 kB
third_party/LightGlue/benchmark.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Benchmark script for LightGlue on real images
2
+ import argparse
3
+ import time
4
+ from collections import defaultdict
5
+ from pathlib import Path
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ import torch._dynamo
11
+
12
+ from lightglue import LightGlue, SuperPoint
13
+ from lightglue.utils import load_image
14
+
15
+ torch.set_grad_enabled(False)
16
+
17
+
18
+ def measure(matcher, data, device="cuda", r=100):
19
+ timings = np.zeros((r, 1))
20
+ if device.type == "cuda":
21
+ starter = torch.cuda.Event(enable_timing=True)
22
+ ender = torch.cuda.Event(enable_timing=True)
23
+ # warmup
24
+ for _ in range(10):
25
+ _ = matcher(data)
26
+ # measurements
27
+ with torch.no_grad():
28
+ for rep in range(r):
29
+ if device.type == "cuda":
30
+ starter.record()
31
+ _ = matcher(data)
32
+ ender.record()
33
+ # sync gpu
34
+ torch.cuda.synchronize()
35
+ curr_time = starter.elapsed_time(ender)
36
+ else:
37
+ start = time.perf_counter()
38
+ _ = matcher(data)
39
+ curr_time = (time.perf_counter() - start) * 1e3
40
+ timings[rep] = curr_time
41
+ mean_syn = np.sum(timings) / r
42
+ std_syn = np.std(timings)
43
+ return {"mean": mean_syn, "std": std_syn}
44
+
45
+
46
+ def print_as_table(d, title, cnames):
47
+ print()
48
+ header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
49
+ print(header)
50
+ print("-" * len(header))
51
+ for k, l in d.items():
52
+ print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
53
+
54
+
55
+ if __name__ == "__main__":
56
+ parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
57
+ parser.add_argument(
58
+ "--device",
59
+ choices=["auto", "cuda", "cpu", "mps"],
60
+ default="auto",
61
+ help="device to benchmark on",
62
+ )
63
+ parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
64
+ parser.add_argument(
65
+ "--no_flash", action="store_true", help="disable FlashAttention"
66
+ )
67
+ parser.add_argument(
68
+ "--no_prune_thresholds",
69
+ action="store_true",
70
+ help="disable pruning thresholds (i.e. always do pruning)",
71
+ )
72
+ parser.add_argument(
73
+ "--add_superglue",
74
+ action="store_true",
75
+ help="add SuperGlue to the benchmark (requires hloc)",
76
+ )
77
+ parser.add_argument(
78
+ "--measure", default="time", choices=["time", "log-time", "throughput"]
79
+ )
80
+ parser.add_argument(
81
+ "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
82
+ )
83
+ parser.add_argument(
84
+ "--num_keypoints",
85
+ nargs="+",
86
+ type=int,
87
+ default=[256, 512, 1024, 2048, 4096],
88
+ help="number of keypoints (list separated by spaces)",
89
+ )
90
+ parser.add_argument(
91
+ "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
92
+ )
93
+ parser.add_argument(
94
+ "--save", default=None, type=str, help="path where figure should be saved"
95
+ )
96
+ args = parser.parse_intermixed_args()
97
+
98
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99
+ if args.device != "auto":
100
+ device = torch.device(args.device)
101
+
102
+ print("Running benchmark on device:", device)
103
+
104
+ images = Path("assets")
105
+ inputs = {
106
+ "easy": (
107
+ load_image(images / "DSC_0411.JPG"),
108
+ load_image(images / "DSC_0410.JPG"),
109
+ ),
110
+ "difficult": (
111
+ load_image(images / "sacre_coeur1.jpg"),
112
+ load_image(images / "sacre_coeur2.jpg"),
113
+ ),
114
+ }
115
+
116
+ configs = {
117
+ "LightGlue-full": {
118
+ "depth_confidence": -1,
119
+ "width_confidence": -1,
120
+ },
121
+ # 'LG-prune': {
122
+ # 'width_confidence': -1,
123
+ # },
124
+ # 'LG-depth': {
125
+ # 'depth_confidence': -1,
126
+ # },
127
+ "LightGlue-adaptive": {},
128
+ }
129
+
130
+ if args.compile:
131
+ configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
132
+
133
+ sg_configs = {
134
+ # 'SuperGlue': {},
135
+ "SuperGlue-fast": {"sinkhorn_iterations": 5}
136
+ }
137
+
138
+ torch.set_float32_matmul_precision(args.matmul_precision)
139
+
140
+ results = {k: defaultdict(list) for k, v in inputs.items()}
141
+
142
+ extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
143
+ extractor = extractor.eval().to(device)
144
+ figsize = (len(inputs) * 4.5, 4.5)
145
+ fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
146
+ axes = axes if len(inputs) > 1 else [axes]
147
+ fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
148
+
149
+ for title, ax in zip(inputs.keys(), axes):
150
+ ax.set_xscale("log", base=2)
151
+ bases = [2**x for x in range(7, 16)]
152
+ ax.set_xticks(bases, bases)
153
+ ax.grid(which="major")
154
+ if args.measure == "log-time":
155
+ ax.set_yscale("log")
156
+ yticks = [10**x for x in range(6)]
157
+ ax.set_yticks(yticks, yticks)
158
+ mpos = [10**x * i for x in range(6) for i in range(2, 10)]
159
+ mlabel = [
160
+ 10**x * i if i in [2, 5] else None
161
+ for x in range(6)
162
+ for i in range(2, 10)
163
+ ]
164
+ ax.set_yticks(mpos, mlabel, minor=True)
165
+ ax.grid(which="minor", linewidth=0.2)
166
+ ax.set_title(title)
167
+
168
+ ax.set_xlabel("# keypoints")
169
+ if args.measure == "throughput":
170
+ ax.set_ylabel("Throughput [pairs/s]")
171
+ else:
172
+ ax.set_ylabel("Latency [ms]")
173
+
174
+ for name, conf in configs.items():
175
+ print("Run benchmark for:", name)
176
+ torch.cuda.empty_cache()
177
+ matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
178
+ if args.no_prune_thresholds:
179
+ matcher.pruning_keypoint_thresholds = {
180
+ k: -1 for k in matcher.pruning_keypoint_thresholds
181
+ }
182
+ matcher = matcher.eval().to(device)
183
+ if name.endswith("compile"):
184
+ import torch._dynamo
185
+
186
+ torch._dynamo.reset() # avoid buffer overflow
187
+ matcher.compile()
188
+ for pair_name, ax in zip(inputs.keys(), axes):
189
+ image0, image1 = [x.to(device) for x in inputs[pair_name]]
190
+ runtimes = []
191
+ for num_kpts in args.num_keypoints:
192
+ extractor.conf.max_num_keypoints = num_kpts
193
+ feats0 = extractor.extract(image0)
194
+ feats1 = extractor.extract(image1)
195
+ runtime = measure(
196
+ matcher,
197
+ {"image0": feats0, "image1": feats1},
198
+ device=device,
199
+ r=args.repeat,
200
+ )["mean"]
201
+ results[pair_name][name].append(
202
+ 1000 / runtime if args.measure == "throughput" else runtime
203
+ )
204
+ ax.plot(
205
+ args.num_keypoints, results[pair_name][name], label=name, marker="o"
206
+ )
207
+ del matcher, feats0, feats1
208
+
209
+ if args.add_superglue:
210
+ from hloc.matchers.superglue import SuperGlue
211
+
212
+ for name, conf in sg_configs.items():
213
+ print("Run benchmark for:", name)
214
+ matcher = SuperGlue(conf)
215
+ matcher = matcher.eval().to(device)
216
+ for pair_name, ax in zip(inputs.keys(), axes):
217
+ image0, image1 = [x.to(device) for x in inputs[pair_name]]
218
+ runtimes = []
219
+ for num_kpts in args.num_keypoints:
220
+ extractor.conf.max_num_keypoints = num_kpts
221
+ feats0 = extractor.extract(image0)
222
+ feats1 = extractor.extract(image1)
223
+ data = {
224
+ "image0": image0[None],
225
+ "image1": image1[None],
226
+ **{k + "0": v for k, v in feats0.items()},
227
+ **{k + "1": v for k, v in feats1.items()},
228
+ }
229
+ data["scores0"] = data["keypoint_scores0"]
230
+ data["scores1"] = data["keypoint_scores1"]
231
+ data["descriptors0"] = (
232
+ data["descriptors0"].transpose(-1, -2).contiguous()
233
+ )
234
+ data["descriptors1"] = (
235
+ data["descriptors1"].transpose(-1, -2).contiguous()
236
+ )
237
+ runtime = measure(matcher, data, device=device, r=args.repeat)[
238
+ "mean"
239
+ ]
240
+ results[pair_name][name].append(
241
+ 1000 / runtime if args.measure == "throughput" else runtime
242
+ )
243
+ ax.plot(
244
+ args.num_keypoints, results[pair_name][name], label=name, marker="o"
245
+ )
246
+ del matcher, data, image0, image1, feats0, feats1
247
+
248
+ for name, runtimes in results.items():
249
+ print_as_table(runtimes, name, args.num_keypoints)
250
+
251
+ axes[0].legend()
252
+ fig.tight_layout()
253
+ if args.save:
254
+ plt.savefig(args.save, dpi=fig.dpi)
255
+ plt.show()
third_party/LightGlue/demo.ipynb CHANGED
@@ -16,16 +16,19 @@
16
  "source": [
17
  "# If we are on colab: this clones the repo and installs the dependencies\n",
18
  "from pathlib import Path\n",
19
- "if Path.cwd().name != 'LightGlue':\n",
20
- " !git clone --quiet https://github.com/cvg/LightGlue/\n",
21
- " %cd LightGlue\n",
22
- " !pip install --progress-bar off --quiet -e .\n",
23
- " \n",
 
24
  "from lightglue import LightGlue, SuperPoint, DISK\n",
25
  "from lightglue.utils import load_image, rbd\n",
26
  "from lightglue import viz2d\n",
27
  "import torch\n",
28
- "images = Path('assets')"
 
 
29
  ]
30
  },
31
  {
@@ -51,10 +54,10 @@
51
  }
52
  ],
53
  "source": [
54
- "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 'mps', 'cpu'\n",
55
  "\n",
56
  "extractor = SuperPoint(max_num_keypoints=2048).eval().to(device) # load the extractor\n",
57
- "matcher = LightGlue(features='superpoint').eval().to(device)"
58
  ]
59
  },
60
  {
@@ -92,22 +95,24 @@
92
  }
93
  ],
94
  "source": [
95
- "image0 = load_image(images / 'DSC_0411.JPG')\n",
96
- "image1 = load_image(images / 'DSC_0410.JPG')\n",
97
  "\n",
98
  "feats0 = extractor.extract(image0.to(device))\n",
99
  "feats1 = extractor.extract(image1.to(device))\n",
100
- "matches01 = matcher({'image0': feats0, 'image1': feats1})\n",
101
- "feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension\n",
 
 
102
  "\n",
103
- "kpts0, kpts1, matches = feats0['keypoints'], feats1['keypoints'], matches01['matches']\n",
104
  "m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]\n",
105
  "\n",
106
  "axes = viz2d.plot_images([image0, image1])\n",
107
- "viz2d.plot_matches(m_kpts0, m_kpts1, color='lime', lw=0.2)\n",
108
  "viz2d.add_text(0, f'Stop after {matches01[\"stop\"]} layers', fs=20)\n",
109
  "\n",
110
- "kpc0, kpc1 = viz2d.cm_prune(matches01['prune0']), viz2d.cm_prune(matches01['prune1'])\n",
111
  "viz2d.plot_images([image0, image1])\n",
112
  "viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)"
113
  ]
@@ -147,22 +152,24 @@
147
  }
148
  ],
149
  "source": [
150
- "image0 = load_image(images / 'sacre_coeur1.jpg')\n",
151
- "image1 = load_image(images / 'sacre_coeur2.jpg')\n",
152
  "\n",
153
  "feats0 = extractor.extract(image0.to(device))\n",
154
  "feats1 = extractor.extract(image1.to(device))\n",
155
- "matches01 = matcher({'image0': feats0, 'image1': feats1})\n",
156
- "feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension\n",
 
 
157
  "\n",
158
- "kpts0, kpts1, matches = feats0['keypoints'], feats1['keypoints'], matches01['matches']\n",
159
  "m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]\n",
160
  "\n",
161
  "axes = viz2d.plot_images([image0, image1])\n",
162
- "viz2d.plot_matches(m_kpts0, m_kpts1, color='lime', lw=0.2)\n",
163
  "viz2d.add_text(0, f'Stop after {matches01[\"stop\"]} layers')\n",
164
  "\n",
165
- "kpc0, kpc1 = viz2d.cm_prune(matches01['prune0']), viz2d.cm_prune(matches01['prune1'])\n",
166
  "viz2d.plot_images([image0, image1])\n",
167
  "viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=6)"
168
  ]
 
16
  "source": [
17
  "# If we are on colab: this clones the repo and installs the dependencies\n",
18
  "from pathlib import Path\n",
19
+ "\n",
20
+ "if Path.cwd().name != \"LightGlue\":\n",
21
+ " !git clone --quiet https://github.com/cvg/LightGlue/\n",
22
+ " %cd LightGlue\n",
23
+ " !pip install --progress-bar off --quiet -e .\n",
24
+ "\n",
25
  "from lightglue import LightGlue, SuperPoint, DISK\n",
26
  "from lightglue.utils import load_image, rbd\n",
27
  "from lightglue import viz2d\n",
28
  "import torch\n",
29
+ "\n",
30
+ "torch.set_grad_enabled(False)\n",
31
+ "images = Path(\"assets\")"
32
  ]
33
  },
34
  {
 
54
  }
55
  ],
56
  "source": [
57
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # 'mps', 'cpu'\n",
58
  "\n",
59
  "extractor = SuperPoint(max_num_keypoints=2048).eval().to(device) # load the extractor\n",
60
+ "matcher = LightGlue(features=\"superpoint\").eval().to(device)"
61
  ]
62
  },
63
  {
 
95
  }
96
  ],
97
  "source": [
98
+ "image0 = load_image(images / \"DSC_0411.JPG\")\n",
99
+ "image1 = load_image(images / \"DSC_0410.JPG\")\n",
100
  "\n",
101
  "feats0 = extractor.extract(image0.to(device))\n",
102
  "feats1 = extractor.extract(image1.to(device))\n",
103
+ "matches01 = matcher({\"image0\": feats0, \"image1\": feats1})\n",
104
+ "feats0, feats1, matches01 = [\n",
105
+ " rbd(x) for x in [feats0, feats1, matches01]\n",
106
+ "] # remove batch dimension\n",
107
  "\n",
108
+ "kpts0, kpts1, matches = feats0[\"keypoints\"], feats1[\"keypoints\"], matches01[\"matches\"]\n",
109
  "m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]\n",
110
  "\n",
111
  "axes = viz2d.plot_images([image0, image1])\n",
112
+ "viz2d.plot_matches(m_kpts0, m_kpts1, color=\"lime\", lw=0.2)\n",
113
  "viz2d.add_text(0, f'Stop after {matches01[\"stop\"]} layers', fs=20)\n",
114
  "\n",
115
+ "kpc0, kpc1 = viz2d.cm_prune(matches01[\"prune0\"]), viz2d.cm_prune(matches01[\"prune1\"])\n",
116
  "viz2d.plot_images([image0, image1])\n",
117
  "viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)"
118
  ]
 
152
  }
153
  ],
154
  "source": [
155
+ "image0 = load_image(images / \"sacre_coeur1.jpg\")\n",
156
+ "image1 = load_image(images / \"sacre_coeur2.jpg\")\n",
157
  "\n",
158
  "feats0 = extractor.extract(image0.to(device))\n",
159
  "feats1 = extractor.extract(image1.to(device))\n",
160
+ "matches01 = matcher({\"image0\": feats0, \"image1\": feats1})\n",
161
+ "feats0, feats1, matches01 = [\n",
162
+ " rbd(x) for x in [feats0, feats1, matches01]\n",
163
+ "] # remove batch dimension\n",
164
  "\n",
165
+ "kpts0, kpts1, matches = feats0[\"keypoints\"], feats1[\"keypoints\"], matches01[\"matches\"]\n",
166
  "m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]\n",
167
  "\n",
168
  "axes = viz2d.plot_images([image0, image1])\n",
169
+ "viz2d.plot_matches(m_kpts0, m_kpts1, color=\"lime\", lw=0.2)\n",
170
  "viz2d.add_text(0, f'Stop after {matches01[\"stop\"]} layers')\n",
171
  "\n",
172
+ "kpc0, kpc1 = viz2d.cm_prune(matches01[\"prune0\"]), viz2d.cm_prune(matches01[\"prune1\"])\n",
173
  "viz2d.plot_images([image0, image1])\n",
174
  "viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=6)"
175
  ]
third_party/LightGlue/lightglue/__init__.py CHANGED
@@ -1,4 +1,7 @@
1
- from .lightglue import LightGlue
2
- from .superpoint import SuperPoint
3
- from .disk import DISK
4
- from .utils import match_pair
 
 
 
 
1
+ from .aliked import ALIKED # noqa
2
+ from .disk import DISK # noqa
3
+ from .dog_hardnet import DoGHardNet # noqa
4
+ from .lightglue import LightGlue # noqa
5
+ from .sift import SIFT # noqa
6
+ from .superpoint import SuperPoint # noqa
7
+ from .utils import match_pair # noqa
third_party/LightGlue/lightglue/aliked.py ADDED
@@ -0,0 +1,758 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BSD 3-Clause License
2
+
3
+ # Copyright (c) 2022, Zhao Xiaoming
4
+ # All rights reserved.
5
+
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+
31
+ # Authors:
32
+ # Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
33
+ # Code from https://github.com/Shiaoming/ALIKED
34
+
35
+ from typing import Callable, Optional
36
+
37
+ import torch
38
+ import torch.nn.functional as F
39
+ import torchvision
40
+ from kornia.color import grayscale_to_rgb
41
+ from torch import nn
42
+ from torch.nn.modules.utils import _pair
43
+ from torchvision.models import resnet
44
+
45
+ from .utils import Extractor
46
+
47
+
48
+ def get_patches(
49
+ tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
50
+ ) -> torch.Tensor:
51
+ c, h, w = tensor.shape
52
+ corner = (required_corners - ps / 2 + 1).long()
53
+ corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
54
+ corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
55
+ offset = torch.arange(0, ps)
56
+
57
+ kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
58
+ x, y = torch.meshgrid(offset, offset, **kw)
59
+ patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
60
+ patches = patches.to(corner) + corner[None, None]
61
+ pts = patches.reshape(-1, 2)
62
+ sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
63
+ sampled = sampled.reshape(ps, ps, -1, c)
64
+ assert sampled.shape[:3] == patches.shape[:3]
65
+ return sampled.permute(2, 3, 0, 1)
66
+
67
+
68
+ def simple_nms(scores: torch.Tensor, nms_radius: int):
69
+ """Fast Non-maximum suppression to remove nearby points"""
70
+
71
+ zeros = torch.zeros_like(scores)
72
+ max_mask = scores == torch.nn.functional.max_pool2d(
73
+ scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
74
+ )
75
+
76
+ for _ in range(2):
77
+ supp_mask = (
78
+ torch.nn.functional.max_pool2d(
79
+ max_mask.float(),
80
+ kernel_size=nms_radius * 2 + 1,
81
+ stride=1,
82
+ padding=nms_radius,
83
+ )
84
+ > 0
85
+ )
86
+ supp_scores = torch.where(supp_mask, zeros, scores)
87
+ new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
88
+ supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
89
+ )
90
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
91
+ return torch.where(max_mask, scores, zeros)
92
+
93
+
94
+ class DKD(nn.Module):
95
+ def __init__(
96
+ self,
97
+ radius: int = 2,
98
+ top_k: int = 0,
99
+ scores_th: float = 0.2,
100
+ n_limit: int = 20000,
101
+ ):
102
+ """
103
+ Args:
104
+ radius: soft detection radius, kernel size is (2 * radius + 1)
105
+ top_k: top_k > 0: return top k keypoints
106
+ scores_th: top_k <= 0 threshold mode:
107
+ scores_th > 0: return keypoints with scores>scores_th
108
+ else: return keypoints with scores > scores.mean()
109
+ n_limit: max number of keypoint in threshold mode
110
+ """
111
+ super().__init__()
112
+ self.radius = radius
113
+ self.top_k = top_k
114
+ self.scores_th = scores_th
115
+ self.n_limit = n_limit
116
+ self.kernel_size = 2 * self.radius + 1
117
+ self.temperature = 0.1 # tuned temperature
118
+ self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
119
+ # local xy grid
120
+ x = torch.linspace(-self.radius, self.radius, self.kernel_size)
121
+ # (kernel_size*kernel_size) x 2 : (w,h)
122
+ kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
123
+ self.hw_grid = (
124
+ torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
125
+ )
126
+
127
+ def forward(
128
+ self,
129
+ scores_map: torch.Tensor,
130
+ sub_pixel: bool = True,
131
+ image_size: Optional[torch.Tensor] = None,
132
+ ):
133
+ """
134
+ :param scores_map: Bx1xHxW
135
+ :param descriptor_map: BxCxHxW
136
+ :param sub_pixel: whether to use sub-pixel keypoint detection
137
+ :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1
138
+ """
139
+ b, c, h, w = scores_map.shape
140
+ scores_nograd = scores_map.detach()
141
+ nms_scores = simple_nms(scores_nograd, self.radius)
142
+
143
+ # remove border
144
+ nms_scores[:, :, : self.radius, :] = 0
145
+ nms_scores[:, :, :, : self.radius] = 0
146
+ if image_size is not None:
147
+ for i in range(scores_map.shape[0]):
148
+ w, h = image_size[i].long()
149
+ nms_scores[i, :, h.item() - self.radius :, :] = 0
150
+ nms_scores[i, :, :, w.item() - self.radius :] = 0
151
+ else:
152
+ nms_scores[:, :, -self.radius :, :] = 0
153
+ nms_scores[:, :, :, -self.radius :] = 0
154
+
155
+ # detect keypoints without grad
156
+ if self.top_k > 0:
157
+ topk = torch.topk(nms_scores.view(b, -1), self.top_k)
158
+ indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
159
+ else:
160
+ if self.scores_th > 0:
161
+ masks = nms_scores > self.scores_th
162
+ if masks.sum() == 0:
163
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
164
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
165
+ else:
166
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
167
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
168
+ masks = masks.reshape(b, -1)
169
+
170
+ indices_keypoints = [] # list, B x (any size)
171
+ scores_view = scores_nograd.reshape(b, -1)
172
+ for mask, scores in zip(masks, scores_view):
173
+ indices = mask.nonzero()[:, 0]
174
+ if len(indices) > self.n_limit:
175
+ kpts_sc = scores[indices]
176
+ sort_idx = kpts_sc.sort(descending=True)[1]
177
+ sel_idx = sort_idx[: self.n_limit]
178
+ indices = indices[sel_idx]
179
+ indices_keypoints.append(indices)
180
+
181
+ wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
182
+
183
+ keypoints = []
184
+ scoredispersitys = []
185
+ kptscores = []
186
+ if sub_pixel:
187
+ # detect soft keypoints with grad backpropagation
188
+ patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
189
+ self.hw_grid = self.hw_grid.to(scores_map) # to device
190
+ for b_idx in range(b):
191
+ patch = patches[b_idx].t() # (H*W) x (kernel**2)
192
+ indices_kpt = indices_keypoints[
193
+ b_idx
194
+ ] # one dimension vector, say its size is M
195
+ patch_scores = patch[indices_kpt] # M x (kernel**2)
196
+ keypoints_xy_nms = torch.stack(
197
+ [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
198
+ dim=1,
199
+ ) # Mx2
200
+
201
+ # max is detached to prevent undesired backprop loops in the graph
202
+ max_v = patch_scores.max(dim=1).values.detach()[:, None]
203
+ x_exp = (
204
+ (patch_scores - max_v) / self.temperature
205
+ ).exp() # M * (kernel**2), in [0, 1]
206
+
207
+ # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
208
+ xy_residual = (
209
+ x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
210
+ ) # Soft-argmax, Mx2
211
+
212
+ hw_grid_dist2 = (
213
+ torch.norm(
214
+ (self.hw_grid[None, :, :] - xy_residual[:, None, :])
215
+ / self.radius,
216
+ dim=-1,
217
+ )
218
+ ** 2
219
+ )
220
+ scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
221
+
222
+ # compute result keypoints
223
+ keypoints_xy = keypoints_xy_nms + xy_residual
224
+ keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
225
+
226
+ kptscore = torch.nn.functional.grid_sample(
227
+ scores_map[b_idx].unsqueeze(0),
228
+ keypoints_xy.view(1, 1, -1, 2),
229
+ mode="bilinear",
230
+ align_corners=True,
231
+ )[
232
+ 0, 0, 0, :
233
+ ] # CxN
234
+
235
+ keypoints.append(keypoints_xy)
236
+ scoredispersitys.append(scoredispersity)
237
+ kptscores.append(kptscore)
238
+ else:
239
+ for b_idx in range(b):
240
+ indices_kpt = indices_keypoints[
241
+ b_idx
242
+ ] # one dimension vector, say its size is M
243
+ # To avoid warning: UserWarning: __floordiv__ is deprecated
244
+ keypoints_xy_nms = torch.stack(
245
+ [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
246
+ dim=1,
247
+ ) # Mx2
248
+ keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
249
+ kptscore = torch.nn.functional.grid_sample(
250
+ scores_map[b_idx].unsqueeze(0),
251
+ keypoints_xy.view(1, 1, -1, 2),
252
+ mode="bilinear",
253
+ align_corners=True,
254
+ )[
255
+ 0, 0, 0, :
256
+ ] # CxN
257
+ keypoints.append(keypoints_xy)
258
+ scoredispersitys.append(kptscore) # for jit.script compatability
259
+ kptscores.append(kptscore)
260
+
261
+ return keypoints, scoredispersitys, kptscores
262
+
263
+
264
+ class InputPadder(object):
265
+ """Pads images such that dimensions are divisible by 8"""
266
+
267
+ def __init__(self, h: int, w: int, divis_by: int = 8):
268
+ self.ht = h
269
+ self.wd = w
270
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
271
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
272
+ self._pad = [
273
+ pad_wd // 2,
274
+ pad_wd - pad_wd // 2,
275
+ pad_ht // 2,
276
+ pad_ht - pad_ht // 2,
277
+ ]
278
+
279
+ def pad(self, x: torch.Tensor):
280
+ assert x.ndim == 4
281
+ return F.pad(x, self._pad, mode="replicate")
282
+
283
+ def unpad(self, x: torch.Tensor):
284
+ assert x.ndim == 4
285
+ ht = x.shape[-2]
286
+ wd = x.shape[-1]
287
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
288
+ return x[..., c[0] : c[1], c[2] : c[3]]
289
+
290
+
291
+ class DeformableConv2d(nn.Module):
292
+ def __init__(
293
+ self,
294
+ in_channels,
295
+ out_channels,
296
+ kernel_size=3,
297
+ stride=1,
298
+ padding=1,
299
+ bias=False,
300
+ mask=False,
301
+ ):
302
+ super(DeformableConv2d, self).__init__()
303
+
304
+ self.padding = padding
305
+ self.mask = mask
306
+
307
+ self.channel_num = (
308
+ 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
309
+ )
310
+ self.offset_conv = nn.Conv2d(
311
+ in_channels,
312
+ self.channel_num,
313
+ kernel_size=kernel_size,
314
+ stride=stride,
315
+ padding=self.padding,
316
+ bias=True,
317
+ )
318
+
319
+ self.regular_conv = nn.Conv2d(
320
+ in_channels=in_channels,
321
+ out_channels=out_channels,
322
+ kernel_size=kernel_size,
323
+ stride=stride,
324
+ padding=self.padding,
325
+ bias=bias,
326
+ )
327
+
328
+ def forward(self, x):
329
+ h, w = x.shape[2:]
330
+ max_offset = max(h, w) / 4.0
331
+
332
+ out = self.offset_conv(x)
333
+ if self.mask:
334
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
335
+ offset = torch.cat((o1, o2), dim=1)
336
+ mask = torch.sigmoid(mask)
337
+ else:
338
+ offset = out
339
+ mask = None
340
+ offset = offset.clamp(-max_offset, max_offset)
341
+ x = torchvision.ops.deform_conv2d(
342
+ input=x,
343
+ offset=offset,
344
+ weight=self.regular_conv.weight,
345
+ bias=self.regular_conv.bias,
346
+ padding=self.padding,
347
+ mask=mask,
348
+ )
349
+ return x
350
+
351
+
352
+ def get_conv(
353
+ inplanes,
354
+ planes,
355
+ kernel_size=3,
356
+ stride=1,
357
+ padding=1,
358
+ bias=False,
359
+ conv_type="conv",
360
+ mask=False,
361
+ ):
362
+ if conv_type == "conv":
363
+ conv = nn.Conv2d(
364
+ inplanes,
365
+ planes,
366
+ kernel_size=kernel_size,
367
+ stride=stride,
368
+ padding=padding,
369
+ bias=bias,
370
+ )
371
+ elif conv_type == "dcn":
372
+ conv = DeformableConv2d(
373
+ inplanes,
374
+ planes,
375
+ kernel_size=kernel_size,
376
+ stride=stride,
377
+ padding=_pair(padding),
378
+ bias=bias,
379
+ mask=mask,
380
+ )
381
+ else:
382
+ raise TypeError
383
+ return conv
384
+
385
+
386
+ class ConvBlock(nn.Module):
387
+ def __init__(
388
+ self,
389
+ in_channels,
390
+ out_channels,
391
+ gate: Optional[Callable[..., nn.Module]] = None,
392
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
393
+ conv_type: str = "conv",
394
+ mask: bool = False,
395
+ ):
396
+ super().__init__()
397
+ if gate is None:
398
+ self.gate = nn.ReLU(inplace=True)
399
+ else:
400
+ self.gate = gate
401
+ if norm_layer is None:
402
+ norm_layer = nn.BatchNorm2d
403
+ self.conv1 = get_conv(
404
+ in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
405
+ )
406
+ self.bn1 = norm_layer(out_channels)
407
+ self.conv2 = get_conv(
408
+ out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
409
+ )
410
+ self.bn2 = norm_layer(out_channels)
411
+
412
+ def forward(self, x):
413
+ x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
414
+ x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
415
+ return x
416
+
417
+
418
+ # modified based on torchvision\models\resnet.py#27->BasicBlock
419
+ class ResBlock(nn.Module):
420
+ expansion: int = 1
421
+
422
+ def __init__(
423
+ self,
424
+ inplanes: int,
425
+ planes: int,
426
+ stride: int = 1,
427
+ downsample: Optional[nn.Module] = None,
428
+ groups: int = 1,
429
+ base_width: int = 64,
430
+ dilation: int = 1,
431
+ gate: Optional[Callable[..., nn.Module]] = None,
432
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
433
+ conv_type: str = "conv",
434
+ mask: bool = False,
435
+ ) -> None:
436
+ super(ResBlock, self).__init__()
437
+ if gate is None:
438
+ self.gate = nn.ReLU(inplace=True)
439
+ else:
440
+ self.gate = gate
441
+ if norm_layer is None:
442
+ norm_layer = nn.BatchNorm2d
443
+ if groups != 1 or base_width != 64:
444
+ raise ValueError("ResBlock only supports groups=1 and base_width=64")
445
+ if dilation > 1:
446
+ raise NotImplementedError("Dilation > 1 not supported in ResBlock")
447
+ # Both self.conv1 and self.downsample layers
448
+ # downsample the input when stride != 1
449
+ self.conv1 = get_conv(
450
+ inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
451
+ )
452
+ self.bn1 = norm_layer(planes)
453
+ self.conv2 = get_conv(
454
+ planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
455
+ )
456
+ self.bn2 = norm_layer(planes)
457
+ self.downsample = downsample
458
+ self.stride = stride
459
+
460
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
461
+ identity = x
462
+
463
+ out = self.conv1(x)
464
+ out = self.bn1(out)
465
+ out = self.gate(out)
466
+
467
+ out = self.conv2(out)
468
+ out = self.bn2(out)
469
+
470
+ if self.downsample is not None:
471
+ identity = self.downsample(x)
472
+
473
+ out += identity
474
+ out = self.gate(out)
475
+
476
+ return out
477
+
478
+
479
+ class SDDH(nn.Module):
480
+ def __init__(
481
+ self,
482
+ dims: int,
483
+ kernel_size: int = 3,
484
+ n_pos: int = 8,
485
+ gate=nn.ReLU(),
486
+ conv2D=False,
487
+ mask=False,
488
+ ):
489
+ super(SDDH, self).__init__()
490
+ self.kernel_size = kernel_size
491
+ self.n_pos = n_pos
492
+ self.conv2D = conv2D
493
+ self.mask = mask
494
+
495
+ self.get_patches_func = get_patches
496
+
497
+ # estimate offsets
498
+ self.channel_num = 3 * n_pos if mask else 2 * n_pos
499
+ self.offset_conv = nn.Sequential(
500
+ nn.Conv2d(
501
+ dims,
502
+ self.channel_num,
503
+ kernel_size=kernel_size,
504
+ stride=1,
505
+ padding=0,
506
+ bias=True,
507
+ ),
508
+ gate,
509
+ nn.Conv2d(
510
+ self.channel_num,
511
+ self.channel_num,
512
+ kernel_size=1,
513
+ stride=1,
514
+ padding=0,
515
+ bias=True,
516
+ ),
517
+ )
518
+
519
+ # sampled feature conv
520
+ self.sf_conv = nn.Conv2d(
521
+ dims, dims, kernel_size=1, stride=1, padding=0, bias=False
522
+ )
523
+
524
+ # convM
525
+ if not conv2D:
526
+ # deformable desc weights
527
+ agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
528
+ self.register_parameter("agg_weights", agg_weights)
529
+ else:
530
+ self.convM = nn.Conv2d(
531
+ dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
532
+ )
533
+
534
+ def forward(self, x, keypoints):
535
+ # x: [B,C,H,W]
536
+ # keypoints: list, [[N_kpts,2], ...] (w,h)
537
+ b, c, h, w = x.shape
538
+ wh = torch.tensor([[w - 1, h - 1]], device=x.device)
539
+ max_offset = max(h, w) / 4.0
540
+
541
+ offsets = []
542
+ descriptors = []
543
+ # get offsets for each keypoint
544
+ for ib in range(b):
545
+ xi, kptsi = x[ib], keypoints[ib]
546
+ kptsi_wh = (kptsi / 2 + 0.5) * wh
547
+ N_kpts = len(kptsi)
548
+
549
+ if self.kernel_size > 1:
550
+ patch = self.get_patches_func(
551
+ xi, kptsi_wh.long(), self.kernel_size
552
+ ) # [N_kpts, C, K, K]
553
+ else:
554
+ kptsi_wh_long = kptsi_wh.long()
555
+ patch = (
556
+ xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
557
+ .permute(1, 0)
558
+ .reshape(N_kpts, c, 1, 1)
559
+ )
560
+
561
+ offset = self.offset_conv(patch).clamp(
562
+ -max_offset, max_offset
563
+ ) # [N_kpts, 2*n_pos, 1, 1]
564
+ if self.mask:
565
+ offset = (
566
+ offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
567
+ ) # [N_kpts, n_pos, 3]
568
+ offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
569
+ mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
570
+ else:
571
+ offset = (
572
+ offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
573
+ ) # [N_kpts, n_pos, 2]
574
+ offsets.append(offset) # for visualization
575
+
576
+ # get sample positions
577
+ pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
578
+ pos = 2.0 * pos / wh[None] - 1
579
+ pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
580
+
581
+ # sample features
582
+ features = F.grid_sample(
583
+ xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
584
+ ) # [1,C,(N_kpts*n_pos),1]
585
+ features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
586
+ 1, 0, 2, 3
587
+ ) # [N_kpts, C, n_pos, 1]
588
+ if self.mask:
589
+ features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
590
+
591
+ features = torch.selu_(self.sf_conv(features)).squeeze(
592
+ -1
593
+ ) # [N_kpts, C, n_pos]
594
+ # convM
595
+ if not self.conv2D:
596
+ descs = torch.einsum(
597
+ "ncp,pcd->nd", features, self.agg_weights
598
+ ) # [N_kpts, C]
599
+ else:
600
+ features = features.reshape(N_kpts, -1)[
601
+ :, :, None, None
602
+ ] # [N_kpts, C*n_pos, 1, 1]
603
+ descs = self.convM(features).squeeze() # [N_kpts, C]
604
+
605
+ # normalize
606
+ descs = F.normalize(descs, p=2.0, dim=1)
607
+ descriptors.append(descs)
608
+
609
+ return descriptors, offsets
610
+
611
+
612
+ class ALIKED(Extractor):
613
+ default_conf = {
614
+ "model_name": "aliked-n16",
615
+ "max_num_keypoints": -1,
616
+ "detection_threshold": 0.2,
617
+ "nms_radius": 2,
618
+ }
619
+
620
+ checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
621
+
622
+ n_limit_max = 20000
623
+
624
+ # c1, c2, c3, c4, dim, K, M
625
+ cfgs = {
626
+ "aliked-t16": [8, 16, 32, 64, 64, 3, 16],
627
+ "aliked-n16": [16, 32, 64, 128, 128, 3, 16],
628
+ "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
629
+ "aliked-n32": [16, 32, 64, 128, 128, 3, 32],
630
+ }
631
+ preprocess_conf = {
632
+ "resize": 1024,
633
+ }
634
+
635
+ required_data_keys = ["image"]
636
+
637
+ def __init__(self, **conf):
638
+ super().__init__(**conf) # Update with default configuration.
639
+ conf = self.conf
640
+ c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
641
+ conv_types = ["conv", "conv", "dcn", "dcn"]
642
+ conv2D = False
643
+ mask = False
644
+
645
+ # build model
646
+ self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
647
+ self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
648
+ self.norm = nn.BatchNorm2d
649
+ self.gate = nn.SELU(inplace=True)
650
+ self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
651
+ self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
652
+ self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
653
+ self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
654
+
655
+ self.conv1 = resnet.conv1x1(c1, dim // 4)
656
+ self.conv2 = resnet.conv1x1(c2, dim // 4)
657
+ self.conv3 = resnet.conv1x1(c3, dim // 4)
658
+ self.conv4 = resnet.conv1x1(dim, dim // 4)
659
+ self.upsample2 = nn.Upsample(
660
+ scale_factor=2, mode="bilinear", align_corners=True
661
+ )
662
+ self.upsample4 = nn.Upsample(
663
+ scale_factor=4, mode="bilinear", align_corners=True
664
+ )
665
+ self.upsample8 = nn.Upsample(
666
+ scale_factor=8, mode="bilinear", align_corners=True
667
+ )
668
+ self.upsample32 = nn.Upsample(
669
+ scale_factor=32, mode="bilinear", align_corners=True
670
+ )
671
+ self.score_head = nn.Sequential(
672
+ resnet.conv1x1(dim, 8),
673
+ self.gate,
674
+ resnet.conv3x3(8, 4),
675
+ self.gate,
676
+ resnet.conv3x3(4, 4),
677
+ self.gate,
678
+ resnet.conv3x3(4, 1),
679
+ )
680
+ self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
681
+ self.dkd = DKD(
682
+ radius=conf.nms_radius,
683
+ top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
684
+ scores_th=conf.detection_threshold,
685
+ n_limit=conf.max_num_keypoints
686
+ if conf.max_num_keypoints > 0
687
+ else self.n_limit_max,
688
+ )
689
+
690
+ state_dict = torch.hub.load_state_dict_from_url(
691
+ self.checkpoint_url.format(conf.model_name), map_location="cpu"
692
+ )
693
+ self.load_state_dict(state_dict, strict=True)
694
+
695
+ def get_resblock(self, c_in, c_out, conv_type, mask):
696
+ return ResBlock(
697
+ c_in,
698
+ c_out,
699
+ 1,
700
+ nn.Conv2d(c_in, c_out, 1),
701
+ gate=self.gate,
702
+ norm_layer=self.norm,
703
+ conv_type=conv_type,
704
+ mask=mask,
705
+ )
706
+
707
+ def extract_dense_map(self, image):
708
+ # Pads images such that dimensions are divisible by
709
+ div_by = 2**5
710
+ padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
711
+ image = padder.pad(image)
712
+
713
+ # ================================== feature encoder
714
+ x1 = self.block1(image) # B x c1 x H x W
715
+ x2 = self.pool2(x1)
716
+ x2 = self.block2(x2) # B x c2 x H/2 x W/2
717
+ x3 = self.pool4(x2)
718
+ x3 = self.block3(x3) # B x c3 x H/8 x W/8
719
+ x4 = self.pool4(x3)
720
+ x4 = self.block4(x4) # B x dim x H/32 x W/32
721
+ # ================================== feature aggregation
722
+ x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
723
+ x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
724
+ x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
725
+ x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
726
+ x2_up = self.upsample2(x2) # B x dim//4 x H x W
727
+ x3_up = self.upsample8(x3) # B x dim//4 x H x W
728
+ x4_up = self.upsample32(x4) # B x dim//4 x H x W
729
+ x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
730
+ # ================================== score head
731
+ score_map = torch.sigmoid(self.score_head(x1234))
732
+ feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
733
+
734
+ # Unpads images
735
+ feature_map = padder.unpad(feature_map)
736
+ score_map = padder.unpad(score_map)
737
+
738
+ return feature_map, score_map
739
+
740
+ def forward(self, data: dict) -> dict:
741
+ image = data["image"]
742
+ if image.shape[1] == 1:
743
+ image = grayscale_to_rgb(image)
744
+ feature_map, score_map = self.extract_dense_map(image)
745
+ keypoints, kptscores, scoredispersitys = self.dkd(
746
+ score_map, image_size=data.get("image_size")
747
+ )
748
+ descriptors, offsets = self.desc_head(feature_map, keypoints)
749
+
750
+ _, _, h, w = image.shape
751
+ wh = torch.tensor([w - 1, h - 1], device=image.device)
752
+ # no padding required
753
+ # we can set detection_threshold=-1 and conf.max_num_keypoints > 0
754
+ return {
755
+ "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
756
+ "descriptors": torch.stack(descriptors), # B x N x D
757
+ "keypoint_scores": torch.stack(kptscores), # B x N
758
+ }
third_party/LightGlue/lightglue/disk.py CHANGED
@@ -1,11 +1,10 @@
1
- import torch
2
- import torch.nn as nn
3
  import kornia
4
- from types import SimpleNamespace
5
- from .utils import ImagePreprocessor
6
 
 
7
 
8
- class DISK(nn.Module):
 
9
  default_conf = {
10
  "weights": "depth",
11
  "max_num_keypoints": None,
@@ -16,7 +15,6 @@ class DISK(nn.Module):
16
  }
17
 
18
  preprocess_conf = {
19
- **ImagePreprocessor.default_conf,
20
  "resize": 1024,
21
  "grayscale": False,
22
  }
@@ -24,9 +22,7 @@ class DISK(nn.Module):
24
  required_data_keys = ["image"]
25
 
26
  def __init__(self, **conf) -> None:
27
- super().__init__()
28
- self.conf = {**self.default_conf, **conf}
29
- self.conf = SimpleNamespace(**self.conf)
30
  self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
31
 
32
  def forward(self, data: dict) -> dict:
@@ -34,6 +30,8 @@ class DISK(nn.Module):
34
  for key in self.required_data_keys:
35
  assert key in data, f"Missing key {key} in data"
36
  image = data["image"]
 
 
37
  features = self.model(
38
  image,
39
  n=self.conf.max_num_keypoints,
@@ -51,19 +49,7 @@ class DISK(nn.Module):
51
  descriptors = torch.stack(descriptors, 0)
52
 
53
  return {
54
- "keypoints": keypoints.to(image),
55
- "keypoint_scores": scores.to(image),
56
- "descriptors": descriptors.to(image),
57
  }
58
-
59
- def extract(self, img: torch.Tensor, **conf) -> dict:
60
- """Perform extraction with online resizing"""
61
- if img.dim() == 3:
62
- img = img[None] # add batch dim
63
- assert img.dim() == 4 and img.shape[0] == 1
64
- shape = img.shape[-2:][::-1]
65
- img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
66
- feats = self.forward({"image": img})
67
- feats["image_size"] = torch.tensor(shape)[None].to(img).float()
68
- feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
69
- return feats
 
 
 
1
  import kornia
2
+ import torch
 
3
 
4
+ from .utils import Extractor
5
 
6
+
7
+ class DISK(Extractor):
8
  default_conf = {
9
  "weights": "depth",
10
  "max_num_keypoints": None,
 
15
  }
16
 
17
  preprocess_conf = {
 
18
  "resize": 1024,
19
  "grayscale": False,
20
  }
 
22
  required_data_keys = ["image"]
23
 
24
  def __init__(self, **conf) -> None:
25
+ super().__init__(**conf) # Update with default configuration.
 
 
26
  self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
27
 
28
  def forward(self, data: dict) -> dict:
 
30
  for key in self.required_data_keys:
31
  assert key in data, f"Missing key {key} in data"
32
  image = data["image"]
33
+ if image.shape[1] == 1:
34
+ image = kornia.color.grayscale_to_rgb(image)
35
  features = self.model(
36
  image,
37
  n=self.conf.max_num_keypoints,
 
49
  descriptors = torch.stack(descriptors, 0)
50
 
51
  return {
52
+ "keypoints": keypoints.to(image).contiguous(),
53
+ "keypoint_scores": scores.to(image).contiguous(),
54
+ "descriptors": descriptors.to(image).contiguous(),
55
  }
 
 
 
 
 
 
 
 
 
 
 
 
third_party/LightGlue/lightglue/dog_hardnet.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from kornia.color import rgb_to_grayscale
3
+ from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
4
+
5
+ from .sift import SIFT
6
+
7
+
8
+ class DoGHardNet(SIFT):
9
+ required_data_keys = ["image"]
10
+
11
+ def __init__(self, **conf):
12
+ super().__init__(**conf)
13
+ self.laf_desc = LAFDescriptor(HardNet(True)).eval()
14
+
15
+ def forward(self, data: dict) -> dict:
16
+ image = data["image"]
17
+ if image.shape[1] == 3:
18
+ image = rgb_to_grayscale(image)
19
+ device = image.device
20
+ self.laf_desc = self.laf_desc.to(device)
21
+ self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
22
+ pred = []
23
+ if "image_size" in data.keys():
24
+ im_size = data.get("image_size").long()
25
+ else:
26
+ im_size = None
27
+ for k in range(len(image)):
28
+ img = image[k]
29
+ if im_size is not None:
30
+ w, h = data["image_size"][k]
31
+ img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
32
+ p = self.extract_single_image(img)
33
+ lafs = laf_from_center_scale_ori(
34
+ p["keypoints"].reshape(1, -1, 2),
35
+ 6.0 * p["scales"].reshape(1, -1, 1, 1),
36
+ torch.rad2deg(p["oris"]).reshape(1, -1, 1),
37
+ ).to(device)
38
+ p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
39
+ pred.append(p)
40
+ pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
41
+ return pred
third_party/LightGlue/lightglue/lightglue.py CHANGED
@@ -1,11 +1,12 @@
 
1
  from pathlib import Path
2
  from types import SimpleNamespace
3
- import warnings
 
4
  import numpy as np
5
  import torch
6
- from torch import nn
7
  import torch.nn.functional as F
8
- from typing import Optional, List, Callable
9
 
10
  try:
11
  from flash_attn.modules.mha import FlashCrossAttention
@@ -21,15 +22,32 @@ torch.backends.cudnn.deterministic = True
21
 
22
 
23
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
24
- def normalize_keypoints(kpts: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
25
- if isinstance(size, torch.Size):
26
- size = torch.tensor(size)[None]
27
- shift = size.float().to(kpts) / 2
28
- scale = size.max(1).values.float().to(kpts) / 2
29
- kpts = (kpts - shift[:, None]) / scale[:, None, None]
 
 
 
 
 
30
  return kpts
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def rotate_half(x: torch.Tensor) -> torch.Tensor:
34
  x = x.unflatten(-1, (-1, 2))
35
  x1, x2 = x.unbind(dim=-1)
@@ -64,8 +82,8 @@ class TokenConfidence(nn.Module):
64
  def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
65
  """get confidence tokens"""
66
  return (
67
- self.token(desc0.detach().float()).squeeze(-1),
68
- self.token(desc1.detach().float()).squeeze(-1),
69
  )
70
 
71
 
@@ -79,29 +97,40 @@ class Attention(nn.Module):
79
  stacklevel=2,
80
  )
81
  self.enable_flash = allow_flash and FLASH_AVAILABLE
 
82
  if allow_flash and FlashCrossAttention:
83
  self.flash_ = FlashCrossAttention()
 
 
84
 
85
- def forward(self, q, k, v) -> torch.Tensor:
 
 
86
  if self.enable_flash and q.device.type == "cuda":
87
- if FlashCrossAttention:
88
- q, k, v = [x.transpose(-2, -3) for x in [q, k, v]]
89
- m = self.flash_(q.half(), torch.stack([k, v], 2).half())
90
- return m.transpose(-2, -3).to(q.dtype)
91
- else: # use torch 2.0 scaled_dot_product_attention with flash
92
  args = [x.half().contiguous() for x in [q, k, v]]
93
- with torch.backends.cuda.sdp_kernel(enable_flash=True):
94
- return F.scaled_dot_product_attention(*args).to(q.dtype)
95
- elif hasattr(F, "scaled_dot_product_attention"):
 
 
 
 
 
96
  args = [x.contiguous() for x in [q, k, v]]
97
- return F.scaled_dot_product_attention(*args).to(q.dtype)
 
98
  else:
99
  s = q.shape[-1] ** -0.5
100
- attn = F.softmax(torch.einsum("...id,...jd->...ij", q, k) * s, -1)
 
 
 
101
  return torch.einsum("...ij,...jd->...id", attn, v)
102
 
103
 
104
- class Transformer(nn.Module):
105
  def __init__(
106
  self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
107
  ) -> None:
@@ -120,22 +149,23 @@ class Transformer(nn.Module):
120
  nn.Linear(2 * embed_dim, embed_dim),
121
  )
122
 
123
- def _forward(self, x: torch.Tensor, encoding: Optional[torch.Tensor] = None):
 
 
 
 
 
124
  qkv = self.Wqkv(x)
125
  qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
126
  q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
127
- if encoding is not None:
128
- q = apply_cached_rotary_emb(encoding, q)
129
- k = apply_cached_rotary_emb(encoding, k)
130
- context = self.inner_attn(q, k, v)
131
  message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
132
  return x + self.ffn(torch.cat([x, message], -1))
133
 
134
- def forward(self, x0, x1, encoding0=None, encoding1=None):
135
- return self._forward(x0, encoding0), self._forward(x1, encoding1)
136
 
137
-
138
- class CrossTransformer(nn.Module):
139
  def __init__(
140
  self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
141
  ) -> None:
@@ -153,7 +183,6 @@ class CrossTransformer(nn.Module):
153
  nn.GELU(),
154
  nn.Linear(2 * embed_dim, embed_dim),
155
  )
156
-
157
  if flash and FLASH_AVAILABLE:
158
  self.flash = Attention(True)
159
  else:
@@ -162,23 +191,31 @@ class CrossTransformer(nn.Module):
162
  def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
163
  return func(x0), func(x1)
164
 
165
- def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
 
 
166
  qk0, qk1 = self.map_(self.to_qk, x0, x1)
167
  v0, v1 = self.map_(self.to_v, x0, x1)
168
  qk0, qk1, v0, v1 = map(
169
  lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
170
  (qk0, qk1, v0, v1),
171
  )
172
- if self.flash is not None:
173
- m0 = self.flash(qk0, qk1, v1)
174
- m1 = self.flash(qk1, qk0, v0)
 
 
175
  else:
176
  qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
177
- sim = torch.einsum("b h i d, b h j d -> b h i j", qk0, qk1)
 
 
178
  attn01 = F.softmax(sim, dim=-1)
179
  attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
180
  m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
181
  m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
 
 
182
  m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
183
  m0, m1 = self.map_(self.to_out, m0, m1)
184
  x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
@@ -186,6 +223,38 @@ class CrossTransformer(nn.Module):
186
  return x0, x1
187
 
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  def sigmoid_log_double_softmax(
190
  sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
191
  ) -> torch.Tensor:
@@ -219,29 +288,26 @@ class MatchAssignment(nn.Module):
219
  scores = sigmoid_log_double_softmax(sim, z0, z1)
220
  return scores, sim
221
 
222
- def scores(self, desc0: torch.Tensor, desc1: torch.Tensor):
223
- m0 = torch.sigmoid(self.matchability(desc0)).squeeze(-1)
224
- m1 = torch.sigmoid(self.matchability(desc1)).squeeze(-1)
225
- return m0, m1
226
 
227
 
228
  def filter_matches(scores: torch.Tensor, th: float):
229
  """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
230
  max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
231
  m0, m1 = max0.indices, max1.indices
232
- mutual0 = torch.arange(m0.shape[1]).to(m0)[None] == m1.gather(1, m0)
233
- mutual1 = torch.arange(m1.shape[1]).to(m1)[None] == m0.gather(1, m1)
 
 
234
  max0_exp = max0.values.exp()
235
  zero = max0_exp.new_tensor(0)
236
  mscores0 = torch.where(mutual0, max0_exp, zero)
237
  mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
238
- if th is not None:
239
- valid0 = mutual0 & (mscores0 > th)
240
- else:
241
- valid0 = mutual0
242
  valid1 = mutual1 & valid0.gather(1, m1)
243
- m0 = torch.where(valid0, m0, m0.new_tensor(-1))
244
- m1 = torch.where(valid1, m1, m1.new_tensor(-1))
245
  return m0, m1, mscores0, mscores1
246
 
247
 
@@ -250,6 +316,7 @@ class LightGlue(nn.Module):
250
  "name": "lightglue", # just for interfacing
251
  "input_dim": 256, # input descriptor dimension (autoselected from weights)
252
  "descriptor_dim": 256,
 
253
  "n_layers": 9,
254
  "num_heads": 4,
255
  "flash": True, # enable FlashAttention if available.
@@ -260,23 +327,56 @@ class LightGlue(nn.Module):
260
  "weights": None,
261
  }
262
 
 
 
 
 
 
 
 
 
 
263
  required_data_keys = ["image0", "image1"]
264
 
265
  version = "v0.1_arxiv"
266
  url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
267
 
268
  features = {
269
- "superpoint": ("superpoint_lightglue", 256),
270
- "disk": ("disk_lightglue", 128),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  }
272
 
273
  def __init__(self, features="superpoint", **conf) -> None:
274
  super().__init__()
275
- self.conf = {**self.default_conf, **conf}
276
  if features is not None:
277
- assert features in list(self.features.keys())
278
- self.conf["weights"], self.conf["input_dim"] = self.features[features]
279
- self.conf = conf = SimpleNamespace(**self.conf)
 
 
 
 
280
 
281
  if conf.input_dim != conf.descriptor_dim:
282
  self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
@@ -284,22 +384,30 @@ class LightGlue(nn.Module):
284
  self.input_proj = nn.Identity()
285
 
286
  head_dim = conf.descriptor_dim // conf.num_heads
287
- self.posenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim)
 
 
288
 
289
  h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
290
- self.self_attn = nn.ModuleList(
291
- [Transformer(d, h, conf.flash) for _ in range(n)]
292
- )
293
- self.cross_attn = nn.ModuleList(
294
- [CrossTransformer(d, h, conf.flash) for _ in range(n)]
295
  )
 
296
  self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
297
  self.token_confidence = nn.ModuleList(
298
  [TokenConfidence(d) for _ in range(n - 1)]
299
  )
 
 
 
 
 
 
300
 
 
301
  if features is not None:
302
- fname = f"{conf.weights}_{self.version}.pth".replace(".", "-")
303
  state_dict = torch.hub.load_state_dict_from_url(
304
  self.url.format(self.version, features), file_name=fname
305
  )
@@ -308,9 +416,35 @@ class LightGlue(nn.Module):
308
  path = Path(__file__).parent
309
  path = path / "weights/{}.pth".format(self.conf.weights)
310
  state_dict = torch.load(str(path), map_location="cpu")
 
 
 
 
 
 
 
 
311
  self.load_state_dict(state_dict, strict=False)
312
 
313
- print("Loaded LightGlue model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  def forward(self, data: dict) -> dict:
316
  """
@@ -326,12 +460,15 @@ class LightGlue(nn.Module):
326
  descriptors: [B x N x D]
327
  image: [B x C x H x W] or image_size: [B x 2]
328
  Output (dict):
329
- log_assignment: [B x M+1 x N+1]
330
  matches0: [B x M]
331
  matching_scores0: [B x M]
332
  matches1: [B x N]
333
  matching_scores1: [B x N]
334
- matches: List[[Si x 2]], scores: List[[Si]]
 
 
 
 
335
  """
336
  with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
337
  return self._forward(data)
@@ -340,20 +477,23 @@ class LightGlue(nn.Module):
340
  for key in self.required_data_keys:
341
  assert key in data, f"Missing key {key} in data"
342
  data0, data1 = data["image0"], data["image1"]
343
- kpts0_, kpts1_ = data0["keypoints"], data1["keypoints"]
344
- b, m, _ = kpts0_.shape
345
- b, n, _ = kpts1_.shape
 
346
  size0, size1 = data0.get("image_size"), data1.get("image_size")
347
- size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1]
348
- size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1]
349
- kpts0 = normalize_keypoints(kpts0_, size=size0)
350
- kpts1 = normalize_keypoints(kpts1_, size=size1)
351
 
352
- assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
353
- assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
354
-
355
- desc0 = data0["descriptors"].detach()
356
- desc1 = data1["descriptors"].detach()
 
 
 
 
357
 
358
  assert desc0.shape[-1] == self.conf.input_dim
359
  assert desc1.shape[-1] == self.conf.input_dim
@@ -362,109 +502,154 @@ class LightGlue(nn.Module):
362
  desc0 = desc0.half()
363
  desc1 = desc1.half()
364
 
 
 
 
 
 
 
 
 
 
365
  desc0 = self.input_proj(desc0)
366
  desc1 = self.input_proj(desc1)
367
-
368
  # cache positional embeddings
369
  encoding0 = self.posenc(kpts0)
370
  encoding1 = self.posenc(kpts1)
371
 
372
  # GNN + final_proj + assignment
373
- ind0 = torch.arange(0, m).to(device=kpts0.device)[None]
374
- ind1 = torch.arange(0, n).to(device=kpts0.device)[None]
375
- prune0 = torch.ones_like(ind0) # store layer where pruning is detected
376
- prune1 = torch.ones_like(ind1)
377
- dec, wic = self.conf.depth_confidence, self.conf.width_confidence
 
 
 
 
378
  token0, token1 = None, None
379
  for i in range(self.conf.n_layers):
380
- # self+cross attention
381
- desc0, desc1 = self.self_attn[i](desc0, desc1, encoding0, encoding1)
382
- desc0, desc1 = self.cross_attn[i](desc0, desc1)
 
 
383
  if i == self.conf.n_layers - 1:
384
  continue # no early stopping or adaptive width at last layer
385
- if dec > 0: # early stopping
 
386
  token0, token1 = self.token_confidence[i](desc0, desc1)
387
- if self.stop(token0, token1, self.conf_th(i), dec, m + n):
388
- break
389
- if wic > 0: # point pruning
390
- match0, match1 = self.log_assignment[i].scores(desc0, desc1)
391
- mask0 = self.get_mask(token0, match0, self.conf_th(i), 1 - wic)
392
- mask1 = self.get_mask(token1, match1, self.conf_th(i), 1 - wic)
393
- ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
394
- desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
395
- if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
396
  break
397
- encoding0 = encoding0[:, :, mask0][:, None]
398
- encoding1 = encoding1[:, :, mask1][:, None]
399
- prune0[:, ind0] += 1
400
- prune1[:, ind1] += 1
401
-
402
- if wic > 0: # scatter with indices after pruning
403
- scores_, _ = self.log_assignment[i](desc0, desc1)
404
- dt, dev = scores_.dtype, scores_.device
405
- scores = torch.zeros(b, m + 1, n + 1, dtype=dt, device=dev)
406
- scores[:, :-1, :-1] = -torch.inf
407
- scores[:, ind0[0], -1] = scores_[:, :-1, -1]
408
- scores[:, -1, ind1[0]] = scores_[:, -1, :-1]
409
- x, y = torch.meshgrid(ind0[0], ind1[0], indexing="ij")
410
- scores[:, x, y] = scores_[:, :-1, :-1]
411
- else:
412
- scores, _ = self.log_assignment[i](desc0, desc1)
413
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
415
-
416
  matches, mscores = [], []
417
  for k in range(b):
418
  valid = m0[k] > -1
419
- matches.append(torch.stack([torch.where(valid)[0], m0[k][valid]], -1))
 
 
 
 
 
420
  mscores.append(mscores0[k][valid])
421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  return {
423
- "log_assignment": scores,
424
  "matches0": m0,
425
  "matches1": m1,
426
  "matching_scores0": mscores0,
427
  "matching_scores1": mscores1,
428
  "stop": i + 1,
429
- "prune0": prune0,
430
- "prune1": prune1,
431
  "matches": matches,
432
  "scores": mscores,
 
 
433
  }
434
 
435
- def conf_th(self, i: int) -> float:
436
  """scaled confidence threshold"""
437
- return np.clip(0.8 + 0.1 * np.exp(-4.0 * i / self.conf.n_layers), 0, 1)
 
438
 
439
- def get_mask(
440
- self,
441
- confidence: torch.Tensor,
442
- match: torch.Tensor,
443
- conf_th: float,
444
- match_th: float,
445
  ) -> torch.Tensor:
446
  """mask points which should be removed"""
447
- if conf_th and confidence is not None:
448
- mask = (
449
- torch.where(confidence > conf_th, match, match.new_tensor(1.0))
450
- > match_th
451
- )
452
- else:
453
- mask = match > match_th
454
- return mask
455
 
456
- def stop(
457
  self,
458
- token0: torch.Tensor,
459
- token1: torch.Tensor,
460
- conf_th: float,
461
- inl_th: float,
462
- seql: int,
463
  ) -> torch.Tensor:
464
  """evaluate stopping condition"""
465
- tokens = torch.cat([token0, token1], -1)
466
- if conf_th:
467
- pos = 1.0 - (tokens < conf_th).float().sum() / seql
468
- return pos > inl_th
 
 
 
 
469
  else:
470
- return tokens.mean() > inl_th
 
1
+ import warnings
2
  from pathlib import Path
3
  from types import SimpleNamespace
4
+ from typing import Callable, List, Optional, Tuple
5
+
6
  import numpy as np
7
  import torch
 
8
  import torch.nn.functional as F
9
+ from torch import nn
10
 
11
  try:
12
  from flash_attn.modules.mha import FlashCrossAttention
 
22
 
23
 
24
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
25
+ def normalize_keypoints(
26
+ kpts: torch.Tensor, size: Optional[torch.Tensor] = None
27
+ ) -> torch.Tensor:
28
+ if size is None:
29
+ size = 1 + kpts.max(-2).values - kpts.min(-2).values
30
+ elif not isinstance(size, torch.Tensor):
31
+ size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
32
+ size = size.to(kpts)
33
+ shift = size / 2
34
+ scale = size.max(-1).values / 2
35
+ kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
36
  return kpts
37
 
38
 
39
+ def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
40
+ if length <= x.shape[-2]:
41
+ return x, torch.ones_like(x[..., :1], dtype=torch.bool)
42
+ pad = torch.ones(
43
+ *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
44
+ )
45
+ y = torch.cat([x, pad], dim=-2)
46
+ mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
47
+ mask[..., : x.shape[-2], :] = True
48
+ return y, mask
49
+
50
+
51
  def rotate_half(x: torch.Tensor) -> torch.Tensor:
52
  x = x.unflatten(-1, (-1, 2))
53
  x1, x2 = x.unbind(dim=-1)
 
82
  def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
83
  """get confidence tokens"""
84
  return (
85
+ self.token(desc0.detach()).squeeze(-1),
86
+ self.token(desc1.detach()).squeeze(-1),
87
  )
88
 
89
 
 
97
  stacklevel=2,
98
  )
99
  self.enable_flash = allow_flash and FLASH_AVAILABLE
100
+ self.has_sdp = hasattr(F, "scaled_dot_product_attention")
101
  if allow_flash and FlashCrossAttention:
102
  self.flash_ = FlashCrossAttention()
103
+ if self.has_sdp:
104
+ torch.backends.cuda.enable_flash_sdp(allow_flash)
105
 
106
+ def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
107
+ if q.shape[-2] == 0 or k.shape[-2] == 0:
108
+ return q.new_zeros((*q.shape[:-1], v.shape[-1]))
109
  if self.enable_flash and q.device.type == "cuda":
110
+ # use torch 2.0 scaled_dot_product_attention with flash
111
+ if self.has_sdp:
 
 
 
112
  args = [x.half().contiguous() for x in [q, k, v]]
113
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
114
+ return v if mask is None else v.nan_to_num()
115
+ else:
116
+ assert mask is None
117
+ q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
118
+ m = self.flash_(q.half(), torch.stack([k, v], 2).half())
119
+ return m.transpose(-2, -3).to(q.dtype).clone()
120
+ elif self.has_sdp:
121
  args = [x.contiguous() for x in [q, k, v]]
122
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask)
123
+ return v if mask is None else v.nan_to_num()
124
  else:
125
  s = q.shape[-1] ** -0.5
126
+ sim = torch.einsum("...id,...jd->...ij", q, k) * s
127
+ if mask is not None:
128
+ sim.masked_fill(~mask, -float("inf"))
129
+ attn = F.softmax(sim, -1)
130
  return torch.einsum("...ij,...jd->...id", attn, v)
131
 
132
 
133
+ class SelfBlock(nn.Module):
134
  def __init__(
135
  self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
136
  ) -> None:
 
149
  nn.Linear(2 * embed_dim, embed_dim),
150
  )
151
 
152
+ def forward(
153
+ self,
154
+ x: torch.Tensor,
155
+ encoding: torch.Tensor,
156
+ mask: Optional[torch.Tensor] = None,
157
+ ) -> torch.Tensor:
158
  qkv = self.Wqkv(x)
159
  qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
160
  q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
161
+ q = apply_cached_rotary_emb(encoding, q)
162
+ k = apply_cached_rotary_emb(encoding, k)
163
+ context = self.inner_attn(q, k, v, mask=mask)
 
164
  message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
165
  return x + self.ffn(torch.cat([x, message], -1))
166
 
 
 
167
 
168
+ class CrossBlock(nn.Module):
 
169
  def __init__(
170
  self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
171
  ) -> None:
 
183
  nn.GELU(),
184
  nn.Linear(2 * embed_dim, embed_dim),
185
  )
 
186
  if flash and FLASH_AVAILABLE:
187
  self.flash = Attention(True)
188
  else:
 
191
  def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
192
  return func(x0), func(x1)
193
 
194
+ def forward(
195
+ self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
196
+ ) -> List[torch.Tensor]:
197
  qk0, qk1 = self.map_(self.to_qk, x0, x1)
198
  v0, v1 = self.map_(self.to_v, x0, x1)
199
  qk0, qk1, v0, v1 = map(
200
  lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
201
  (qk0, qk1, v0, v1),
202
  )
203
+ if self.flash is not None and qk0.device.type == "cuda":
204
+ m0 = self.flash(qk0, qk1, v1, mask)
205
+ m1 = self.flash(
206
+ qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
207
+ )
208
  else:
209
  qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
210
+ sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
211
+ if mask is not None:
212
+ sim = sim.masked_fill(~mask, -float("inf"))
213
  attn01 = F.softmax(sim, dim=-1)
214
  attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
215
  m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
216
  m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
217
+ if mask is not None:
218
+ m0, m1 = m0.nan_to_num(), m1.nan_to_num()
219
  m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
220
  m0, m1 = self.map_(self.to_out, m0, m1)
221
  x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
 
223
  return x0, x1
224
 
225
 
226
+ class TransformerLayer(nn.Module):
227
+ def __init__(self, *args, **kwargs):
228
+ super().__init__()
229
+ self.self_attn = SelfBlock(*args, **kwargs)
230
+ self.cross_attn = CrossBlock(*args, **kwargs)
231
+
232
+ def forward(
233
+ self,
234
+ desc0,
235
+ desc1,
236
+ encoding0,
237
+ encoding1,
238
+ mask0: Optional[torch.Tensor] = None,
239
+ mask1: Optional[torch.Tensor] = None,
240
+ ):
241
+ if mask0 is not None and mask1 is not None:
242
+ return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
243
+ else:
244
+ desc0 = self.self_attn(desc0, encoding0)
245
+ desc1 = self.self_attn(desc1, encoding1)
246
+ return self.cross_attn(desc0, desc1)
247
+
248
+ # This part is compiled and allows padding inputs
249
+ def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
250
+ mask = mask0 & mask1.transpose(-1, -2)
251
+ mask0 = mask0 & mask0.transpose(-1, -2)
252
+ mask1 = mask1 & mask1.transpose(-1, -2)
253
+ desc0 = self.self_attn(desc0, encoding0, mask0)
254
+ desc1 = self.self_attn(desc1, encoding1, mask1)
255
+ return self.cross_attn(desc0, desc1, mask)
256
+
257
+
258
  def sigmoid_log_double_softmax(
259
  sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
260
  ) -> torch.Tensor:
 
288
  scores = sigmoid_log_double_softmax(sim, z0, z1)
289
  return scores, sim
290
 
291
+ def get_matchability(self, desc: torch.Tensor):
292
+ return torch.sigmoid(self.matchability(desc)).squeeze(-1)
 
 
293
 
294
 
295
  def filter_matches(scores: torch.Tensor, th: float):
296
  """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
297
  max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
298
  m0, m1 = max0.indices, max1.indices
299
+ indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
300
+ indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
301
+ mutual0 = indices0 == m1.gather(1, m0)
302
+ mutual1 = indices1 == m0.gather(1, m1)
303
  max0_exp = max0.values.exp()
304
  zero = max0_exp.new_tensor(0)
305
  mscores0 = torch.where(mutual0, max0_exp, zero)
306
  mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
307
+ valid0 = mutual0 & (mscores0 > th)
 
 
 
308
  valid1 = mutual1 & valid0.gather(1, m1)
309
+ m0 = torch.where(valid0, m0, -1)
310
+ m1 = torch.where(valid1, m1, -1)
311
  return m0, m1, mscores0, mscores1
312
 
313
 
 
316
  "name": "lightglue", # just for interfacing
317
  "input_dim": 256, # input descriptor dimension (autoselected from weights)
318
  "descriptor_dim": 256,
319
+ "add_scale_ori": False,
320
  "n_layers": 9,
321
  "num_heads": 4,
322
  "flash": True, # enable FlashAttention if available.
 
327
  "weights": None,
328
  }
329
 
330
+ # Point pruning involves an overhead (gather).
331
+ # Therefore, we only activate it if there are enough keypoints.
332
+ pruning_keypoint_thresholds = {
333
+ "cpu": -1,
334
+ "mps": -1,
335
+ "cuda": 1024,
336
+ "flash": 1536,
337
+ }
338
+
339
  required_data_keys = ["image0", "image1"]
340
 
341
  version = "v0.1_arxiv"
342
  url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
343
 
344
  features = {
345
+ "superpoint": {
346
+ "weights": "superpoint_lightglue",
347
+ "input_dim": 256,
348
+ },
349
+ "disk": {
350
+ "weights": "disk_lightglue",
351
+ "input_dim": 128,
352
+ },
353
+ "aliked": {
354
+ "weights": "aliked_lightglue",
355
+ "input_dim": 128,
356
+ },
357
+ "sift": {
358
+ "weights": "sift_lightglue",
359
+ "input_dim": 128,
360
+ "add_scale_ori": True,
361
+ },
362
+ "doghardnet": {
363
+ "weights": "doghardnet_lightglue",
364
+ "input_dim": 128,
365
+ "add_scale_ori": True,
366
+ },
367
  }
368
 
369
  def __init__(self, features="superpoint", **conf) -> None:
370
  super().__init__()
371
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
372
  if features is not None:
373
+ if features not in self.features:
374
+ raise ValueError(
375
+ f"Unsupported features: {features} not in "
376
+ f"{{{','.join(self.features)}}}"
377
+ )
378
+ for k, v in self.features[features].items():
379
+ setattr(conf, k, v)
380
 
381
  if conf.input_dim != conf.descriptor_dim:
382
  self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
 
384
  self.input_proj = nn.Identity()
385
 
386
  head_dim = conf.descriptor_dim // conf.num_heads
387
+ self.posenc = LearnableFourierPositionalEncoding(
388
+ 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
389
+ )
390
 
391
  h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
392
+
393
+ self.transformers = nn.ModuleList(
394
+ [TransformerLayer(d, h, conf.flash) for _ in range(n)]
 
 
395
  )
396
+
397
  self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
398
  self.token_confidence = nn.ModuleList(
399
  [TokenConfidence(d) for _ in range(n - 1)]
400
  )
401
+ self.register_buffer(
402
+ "confidence_thresholds",
403
+ torch.Tensor(
404
+ [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
405
+ ),
406
+ )
407
 
408
+ state_dict = None
409
  if features is not None:
410
+ fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
411
  state_dict = torch.hub.load_state_dict_from_url(
412
  self.url.format(self.version, features), file_name=fname
413
  )
 
416
  path = Path(__file__).parent
417
  path = path / "weights/{}.pth".format(self.conf.weights)
418
  state_dict = torch.load(str(path), map_location="cpu")
419
+
420
+ if state_dict:
421
+ # rename old state dict entries
422
+ for i in range(self.conf.n_layers):
423
+ pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
424
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
425
+ pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
426
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
427
  self.load_state_dict(state_dict, strict=False)
428
 
429
+ # static lengths LightGlue is compiled for (only used with torch.compile)
430
+ self.static_lengths = None
431
+
432
+ def compile(
433
+ self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
434
+ ):
435
+ if self.conf.width_confidence != -1:
436
+ warnings.warn(
437
+ "Point pruning is partially disabled for compiled forward.",
438
+ stacklevel=2,
439
+ )
440
+
441
+ torch._inductor.cudagraph_mark_step_begin()
442
+ for i in range(self.conf.n_layers):
443
+ self.transformers[i].masked_forward = torch.compile(
444
+ self.transformers[i].masked_forward, mode=mode, fullgraph=True
445
+ )
446
+
447
+ self.static_lengths = static_lengths
448
 
449
  def forward(self, data: dict) -> dict:
450
  """
 
460
  descriptors: [B x N x D]
461
  image: [B x C x H x W] or image_size: [B x 2]
462
  Output (dict):
 
463
  matches0: [B x M]
464
  matching_scores0: [B x M]
465
  matches1: [B x N]
466
  matching_scores1: [B x N]
467
+ matches: List[[Si x 2]]
468
+ scores: List[[Si]]
469
+ stop: int
470
+ prune0: [B x M]
471
+ prune1: [B x N]
472
  """
473
  with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
474
  return self._forward(data)
 
477
  for key in self.required_data_keys:
478
  assert key in data, f"Missing key {key} in data"
479
  data0, data1 = data["image0"], data["image1"]
480
+ kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
481
+ b, m, _ = kpts0.shape
482
+ b, n, _ = kpts1.shape
483
+ device = kpts0.device
484
  size0, size1 = data0.get("image_size"), data1.get("image_size")
485
+ kpts0 = normalize_keypoints(kpts0, size0).clone()
486
+ kpts1 = normalize_keypoints(kpts1, size1).clone()
 
 
487
 
488
+ if self.conf.add_scale_ori:
489
+ kpts0 = torch.cat(
490
+ [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
491
+ )
492
+ kpts1 = torch.cat(
493
+ [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
494
+ )
495
+ desc0 = data0["descriptors"].detach().contiguous()
496
+ desc1 = data1["descriptors"].detach().contiguous()
497
 
498
  assert desc0.shape[-1] == self.conf.input_dim
499
  assert desc1.shape[-1] == self.conf.input_dim
 
502
  desc0 = desc0.half()
503
  desc1 = desc1.half()
504
 
505
+ mask0, mask1 = None, None
506
+ c = max(m, n)
507
+ do_compile = self.static_lengths and c <= max(self.static_lengths)
508
+ if do_compile:
509
+ kn = min([k for k in self.static_lengths if k >= c])
510
+ desc0, mask0 = pad_to_length(desc0, kn)
511
+ desc1, mask1 = pad_to_length(desc1, kn)
512
+ kpts0, _ = pad_to_length(kpts0, kn)
513
+ kpts1, _ = pad_to_length(kpts1, kn)
514
  desc0 = self.input_proj(desc0)
515
  desc1 = self.input_proj(desc1)
 
516
  # cache positional embeddings
517
  encoding0 = self.posenc(kpts0)
518
  encoding1 = self.posenc(kpts1)
519
 
520
  # GNN + final_proj + assignment
521
+ do_early_stop = self.conf.depth_confidence > 0
522
+ do_point_pruning = self.conf.width_confidence > 0 and not do_compile
523
+ pruning_th = self.pruning_min_kpts(device)
524
+ if do_point_pruning:
525
+ ind0 = torch.arange(0, m, device=device)[None]
526
+ ind1 = torch.arange(0, n, device=device)[None]
527
+ # We store the index of the layer at which pruning is detected.
528
+ prune0 = torch.ones_like(ind0)
529
+ prune1 = torch.ones_like(ind1)
530
  token0, token1 = None, None
531
  for i in range(self.conf.n_layers):
532
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
533
+ break
534
+ desc0, desc1 = self.transformers[i](
535
+ desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
536
+ )
537
  if i == self.conf.n_layers - 1:
538
  continue # no early stopping or adaptive width at last layer
539
+
540
+ if do_early_stop:
541
  token0, token1 = self.token_confidence[i](desc0, desc1)
542
+ if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
 
 
 
 
 
 
 
 
543
  break
544
+ if do_point_pruning and desc0.shape[-2] > pruning_th:
545
+ scores0 = self.log_assignment[i].get_matchability(desc0)
546
+ prunemask0 = self.get_pruning_mask(token0, scores0, i)
547
+ keep0 = torch.where(prunemask0)[1]
548
+ ind0 = ind0.index_select(1, keep0)
549
+ desc0 = desc0.index_select(1, keep0)
550
+ encoding0 = encoding0.index_select(-2, keep0)
551
+ prune0[:, ind0] += 1
552
+ if do_point_pruning and desc1.shape[-2] > pruning_th:
553
+ scores1 = self.log_assignment[i].get_matchability(desc1)
554
+ prunemask1 = self.get_pruning_mask(token1, scores1, i)
555
+ keep1 = torch.where(prunemask1)[1]
556
+ ind1 = ind1.index_select(1, keep1)
557
+ desc1 = desc1.index_select(1, keep1)
558
+ encoding1 = encoding1.index_select(-2, keep1)
559
+ prune1[:, ind1] += 1
560
+
561
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
562
+ m0 = desc0.new_full((b, m), -1, dtype=torch.long)
563
+ m1 = desc1.new_full((b, n), -1, dtype=torch.long)
564
+ mscores0 = desc0.new_zeros((b, m))
565
+ mscores1 = desc1.new_zeros((b, n))
566
+ matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
567
+ mscores = desc0.new_empty((b, 0))
568
+ if not do_point_pruning:
569
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
570
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
571
+ return {
572
+ "matches0": m0,
573
+ "matches1": m1,
574
+ "matching_scores0": mscores0,
575
+ "matching_scores1": mscores1,
576
+ "stop": i + 1,
577
+ "matches": matches,
578
+ "scores": mscores,
579
+ "prune0": prune0,
580
+ "prune1": prune1,
581
+ }
582
+
583
+ desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
584
+ scores, _ = self.log_assignment[i](desc0, desc1)
585
  m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
 
586
  matches, mscores = [], []
587
  for k in range(b):
588
  valid = m0[k] > -1
589
+ m_indices_0 = torch.where(valid)[0]
590
+ m_indices_1 = m0[k][valid]
591
+ if do_point_pruning:
592
+ m_indices_0 = ind0[k, m_indices_0]
593
+ m_indices_1 = ind1[k, m_indices_1]
594
+ matches.append(torch.stack([m_indices_0, m_indices_1], -1))
595
  mscores.append(mscores0[k][valid])
596
 
597
+ # TODO: Remove when hloc switches to the compact format.
598
+ if do_point_pruning:
599
+ m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
600
+ m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
601
+ m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
602
+ m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
603
+ mscores0_ = torch.zeros((b, m), device=mscores0.device)
604
+ mscores1_ = torch.zeros((b, n), device=mscores1.device)
605
+ mscores0_[:, ind0] = mscores0
606
+ mscores1_[:, ind1] = mscores1
607
+ m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
608
+ else:
609
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
610
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
611
+
612
  return {
 
613
  "matches0": m0,
614
  "matches1": m1,
615
  "matching_scores0": mscores0,
616
  "matching_scores1": mscores1,
617
  "stop": i + 1,
 
 
618
  "matches": matches,
619
  "scores": mscores,
620
+ "prune0": prune0,
621
+ "prune1": prune1,
622
  }
623
 
624
+ def confidence_threshold(self, layer_index: int) -> float:
625
  """scaled confidence threshold"""
626
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
627
+ return np.clip(threshold, 0, 1)
628
 
629
+ def get_pruning_mask(
630
+ self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
 
 
 
 
631
  ) -> torch.Tensor:
632
  """mask points which should be removed"""
633
+ keep = scores > (1 - self.conf.width_confidence)
634
+ if confidences is not None: # Low-confidence points are never pruned.
635
+ keep |= confidences <= self.confidence_thresholds[layer_index]
636
+ return keep
 
 
 
 
637
 
638
+ def check_if_stop(
639
  self,
640
+ confidences0: torch.Tensor,
641
+ confidences1: torch.Tensor,
642
+ layer_index: int,
643
+ num_points: int,
 
644
  ) -> torch.Tensor:
645
  """evaluate stopping condition"""
646
+ confidences = torch.cat([confidences0, confidences1], -1)
647
+ threshold = self.confidence_thresholds[layer_index]
648
+ ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
649
+ return ratio_confident > self.conf.depth_confidence
650
+
651
+ def pruning_min_kpts(self, device: torch.device):
652
+ if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
653
+ return self.pruning_keypoint_thresholds["flash"]
654
  else:
655
+ return self.pruning_keypoint_thresholds[device.type]
third_party/LightGlue/lightglue/sift.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from kornia.color import rgb_to_grayscale
7
+ from packaging import version
8
+
9
+ try:
10
+ import pycolmap
11
+ except ImportError:
12
+ pycolmap = None
13
+
14
+ from .utils import Extractor
15
+
16
+
17
+ def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
18
+ h, w = image_shape
19
+ ij = np.round(points - 0.5).astype(int).T[::-1]
20
+
21
+ # Remove duplicate points (identical coordinates).
22
+ # Pick highest scale or score
23
+ s = scales if scores is None else scores
24
+ buffer = np.zeros((h, w))
25
+ np.maximum.at(buffer, tuple(ij), s)
26
+ keep = np.where(buffer[tuple(ij)] == s)[0]
27
+
28
+ # Pick lowest angle (arbitrary).
29
+ ij = ij[:, keep]
30
+ buffer[:] = np.inf
31
+ o_abs = np.abs(angles[keep])
32
+ np.minimum.at(buffer, tuple(ij), o_abs)
33
+ mask = buffer[tuple(ij)] == o_abs
34
+ ij = ij[:, mask]
35
+ keep = keep[mask]
36
+
37
+ if nms_radius > 0:
38
+ # Apply NMS on the remaining points
39
+ buffer[:] = 0
40
+ buffer[tuple(ij)] = s[keep] # scores or scale
41
+
42
+ local_max = torch.nn.functional.max_pool2d(
43
+ torch.from_numpy(buffer).unsqueeze(0),
44
+ kernel_size=nms_radius * 2 + 1,
45
+ stride=1,
46
+ padding=nms_radius,
47
+ ).squeeze(0)
48
+ is_local_max = buffer == local_max.numpy()
49
+ keep = keep[is_local_max[tuple(ij)]]
50
+ return keep
51
+
52
+
53
+ def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
54
+ x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
55
+ x.clip_(min=eps).sqrt_()
56
+ return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
57
+
58
+
59
+ def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
60
+ """
61
+ Detect keypoints using OpenCV Detector.
62
+ Optionally, perform description.
63
+ Args:
64
+ features: OpenCV based keypoints detector and descriptor
65
+ image: Grayscale image of uint8 data type
66
+ Returns:
67
+ keypoints: 1D array of detected cv2.KeyPoint
68
+ scores: 1D array of responses
69
+ descriptors: 1D array of descriptors
70
+ """
71
+ detections, descriptors = features.detectAndCompute(image, None)
72
+ points = np.array([k.pt for k in detections], dtype=np.float32)
73
+ scores = np.array([k.response for k in detections], dtype=np.float32)
74
+ scales = np.array([k.size for k in detections], dtype=np.float32)
75
+ angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
76
+ return points, scores, scales, angles, descriptors
77
+
78
+
79
+ class SIFT(Extractor):
80
+ default_conf = {
81
+ "rootsift": True,
82
+ "nms_radius": 0, # None to disable filtering entirely.
83
+ "max_num_keypoints": 4096,
84
+ "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
85
+ "detection_threshold": 0.0066667, # from COLMAP
86
+ "edge_threshold": 10,
87
+ "first_octave": -1, # only used by pycolmap, the default of COLMAP
88
+ "num_octaves": 4,
89
+ }
90
+
91
+ preprocess_conf = {
92
+ "resize": 1024,
93
+ }
94
+
95
+ required_data_keys = ["image"]
96
+
97
+ def __init__(self, **conf):
98
+ super().__init__(**conf) # Update with default configuration.
99
+ backend = self.conf.backend
100
+ if backend.startswith("pycolmap"):
101
+ if pycolmap is None:
102
+ raise ImportError(
103
+ "Cannot find module pycolmap: install it with pip"
104
+ "or use backend=opencv."
105
+ )
106
+ options = {
107
+ "peak_threshold": self.conf.detection_threshold,
108
+ "edge_threshold": self.conf.edge_threshold,
109
+ "first_octave": self.conf.first_octave,
110
+ "num_octaves": self.conf.num_octaves,
111
+ "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
112
+ }
113
+ device = (
114
+ "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
115
+ )
116
+ if (
117
+ backend == "pycolmap_cpu" or not pycolmap.has_cuda
118
+ ) and pycolmap.__version__ < "0.5.0":
119
+ warnings.warn(
120
+ "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
121
+ "consider upgrading pycolmap or use the CUDA version.",
122
+ stacklevel=1,
123
+ )
124
+ else:
125
+ options["max_num_features"] = self.conf.max_num_keypoints
126
+ self.sift = pycolmap.Sift(options=options, device=device)
127
+ elif backend == "opencv":
128
+ self.sift = cv2.SIFT_create(
129
+ contrastThreshold=self.conf.detection_threshold,
130
+ nfeatures=self.conf.max_num_keypoints,
131
+ edgeThreshold=self.conf.edge_threshold,
132
+ nOctaveLayers=self.conf.num_octaves,
133
+ )
134
+ else:
135
+ backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
136
+ raise ValueError(
137
+ f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
138
+ )
139
+
140
+ def extract_single_image(self, image: torch.Tensor):
141
+ image_np = image.cpu().numpy().squeeze(0)
142
+
143
+ if self.conf.backend.startswith("pycolmap"):
144
+ if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
145
+ detections, descriptors = self.sift.extract(image_np)
146
+ scores = None # Scores are not exposed by COLMAP anymore.
147
+ else:
148
+ detections, scores, descriptors = self.sift.extract(image_np)
149
+ keypoints = detections[:, :2] # Keep only (x, y).
150
+ scales, angles = detections[:, -2:].T
151
+ if scores is not None and (
152
+ self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
153
+ ):
154
+ # Set the scores as a combination of abs. response and scale.
155
+ scores = np.abs(scores) * scales
156
+ elif self.conf.backend == "opencv":
157
+ # TODO: Check if opencv keypoints are already in corner convention
158
+ keypoints, scores, scales, angles, descriptors = run_opencv_sift(
159
+ self.sift, (image_np * 255.0).astype(np.uint8)
160
+ )
161
+ pred = {
162
+ "keypoints": keypoints,
163
+ "scales": scales,
164
+ "oris": angles,
165
+ "descriptors": descriptors,
166
+ }
167
+ if scores is not None:
168
+ pred["keypoint_scores"] = scores
169
+
170
+ # sometimes pycolmap returns points outside the image. We remove them
171
+ if self.conf.backend.startswith("pycolmap"):
172
+ is_inside = (
173
+ pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
174
+ ).all(-1)
175
+ pred = {k: v[is_inside] for k, v in pred.items()}
176
+
177
+ if self.conf.nms_radius is not None:
178
+ keep = filter_dog_point(
179
+ pred["keypoints"],
180
+ pred["scales"],
181
+ pred["oris"],
182
+ image_np.shape,
183
+ self.conf.nms_radius,
184
+ scores=pred.get("keypoint_scores"),
185
+ )
186
+ pred = {k: v[keep] for k, v in pred.items()}
187
+
188
+ pred = {k: torch.from_numpy(v) for k, v in pred.items()}
189
+ if scores is not None:
190
+ # Keep the k keypoints with highest score
191
+ num_points = self.conf.max_num_keypoints
192
+ if num_points is not None and len(pred["keypoints"]) > num_points:
193
+ indices = torch.topk(pred["keypoint_scores"], num_points).indices
194
+ pred = {k: v[indices] for k, v in pred.items()}
195
+
196
+ return pred
197
+
198
+ def forward(self, data: dict) -> dict:
199
+ image = data["image"]
200
+ if image.shape[1] == 3:
201
+ image = rgb_to_grayscale(image)
202
+ device = image.device
203
+ image = image.cpu()
204
+ pred = []
205
+ for k in range(len(image)):
206
+ img = image[k]
207
+ if "image_size" in data.keys():
208
+ # avoid extracting points in padded areas
209
+ w, h = data["image_size"][k]
210
+ img = img[:, :h, :w]
211
+ p = self.extract_single_image(img)
212
+ pred.append(p)
213
+ pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
214
+ if self.conf.rootsift:
215
+ pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
216
+ return pred
third_party/LightGlue/lightglue/superpoint.py CHANGED
@@ -43,8 +43,10 @@
43
  # Adapted by Remi Pautrat, Philipp Lindenberger
44
 
45
  import torch
 
46
  from torch import nn
47
- from .utils import ImagePreprocessor
 
48
 
49
 
50
  def simple_nms(scores, nms_radius: int):
@@ -77,7 +79,9 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
77
  """Interpolate descriptors at keypoint locations"""
78
  b, c, h, w = descriptors.shape
79
  keypoints = keypoints - s / 2 + 0.5
80
- keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to(
 
 
81
  keypoints
82
  )[None]
83
  keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
@@ -91,7 +95,7 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
91
  return descriptors
92
 
93
 
94
- class SuperPoint(nn.Module):
95
  """SuperPoint Convolutional Detector and Descriptor
96
 
97
  SuperPoint: Self-Supervised Interest Point Detection and
@@ -109,17 +113,13 @@ class SuperPoint(nn.Module):
109
  }
110
 
111
  preprocess_conf = {
112
- **ImagePreprocessor.default_conf,
113
  "resize": 1024,
114
- "grayscale": True,
115
  }
116
 
117
  required_data_keys = ["image"]
118
 
119
  def __init__(self, **conf):
120
- super().__init__()
121
- self.conf = {**self.default_conf, **conf}
122
-
123
  self.relu = nn.ReLU(inplace=True)
124
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
125
  c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
@@ -138,26 +138,23 @@ class SuperPoint(nn.Module):
138
 
139
  self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
140
  self.convDb = nn.Conv2d(
141
- c5, self.conf["descriptor_dim"], kernel_size=1, stride=1, padding=0
142
  )
143
 
144
- url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth"
145
  self.load_state_dict(torch.hub.load_state_dict_from_url(url))
146
 
147
- mk = self.conf["max_num_keypoints"]
148
- if mk is not None and mk <= 0:
149
  raise ValueError("max_num_keypoints must be positive or None")
150
 
151
- print("Loaded SuperPoint model")
152
-
153
  def forward(self, data: dict) -> dict:
154
  """Compute keypoints, scores, descriptors for image"""
155
  for key in self.required_data_keys:
156
  assert key in data, f"Missing key {key} in data"
157
  image = data["image"]
158
- if image.shape[1] == 3: # RGB
159
- scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
160
- image = (image * scale).sum(1, keepdim=True)
161
  # Shared Encoder
162
  x = self.relu(self.conv1a(image))
163
  x = self.relu(self.conv1b(x))
@@ -178,18 +175,18 @@ class SuperPoint(nn.Module):
178
  b, _, h, w = scores.shape
179
  scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
180
  scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
181
- scores = simple_nms(scores, self.conf["nms_radius"])
182
 
183
  # Discard keypoints near the image borders
184
- if self.conf["remove_borders"]:
185
- pad = self.conf["remove_borders"]
186
  scores[:, :pad] = -1
187
  scores[:, :, :pad] = -1
188
  scores[:, -pad:] = -1
189
  scores[:, :, -pad:] = -1
190
 
191
  # Extract keypoints
192
- best_kp = torch.where(scores > self.conf["detection_threshold"])
193
  scores = scores[best_kp]
194
 
195
  # Separate into batches
@@ -199,11 +196,11 @@ class SuperPoint(nn.Module):
199
  scores = [scores[best_kp[0] == i] for i in range(b)]
200
 
201
  # Keep the k keypoints with highest score
202
- if self.conf["max_num_keypoints"] is not None:
203
  keypoints, scores = list(
204
  zip(
205
  *[
206
- top_k_keypoints(k, s, self.conf["max_num_keypoints"])
207
  for k, s in zip(keypoints, scores)
208
  ]
209
  )
@@ -226,17 +223,5 @@ class SuperPoint(nn.Module):
226
  return {
227
  "keypoints": torch.stack(keypoints, 0),
228
  "keypoint_scores": torch.stack(scores, 0),
229
- "descriptors": torch.stack(descriptors, 0).transpose(-1, -2),
230
  }
231
-
232
- def extract(self, img: torch.Tensor, **conf) -> dict:
233
- """Perform extraction with online resizing"""
234
- if img.dim() == 3:
235
- img = img[None] # add batch dim
236
- assert img.dim() == 4 and img.shape[0] == 1
237
- shape = img.shape[-2:][::-1]
238
- img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
239
- feats = self.forward({"image": img})
240
- feats["image_size"] = torch.tensor(shape)[None].to(img).float()
241
- feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
242
- return feats
 
43
  # Adapted by Remi Pautrat, Philipp Lindenberger
44
 
45
  import torch
46
+ from kornia.color import rgb_to_grayscale
47
  from torch import nn
48
+
49
+ from .utils import Extractor
50
 
51
 
52
  def simple_nms(scores, nms_radius: int):
 
79
  """Interpolate descriptors at keypoint locations"""
80
  b, c, h, w = descriptors.shape
81
  keypoints = keypoints - s / 2 + 0.5
82
+ keypoints /= torch.tensor(
83
+ [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
84
+ ).to(
85
  keypoints
86
  )[None]
87
  keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
 
95
  return descriptors
96
 
97
 
98
+ class SuperPoint(Extractor):
99
  """SuperPoint Convolutional Detector and Descriptor
100
 
101
  SuperPoint: Self-Supervised Interest Point Detection and
 
113
  }
114
 
115
  preprocess_conf = {
 
116
  "resize": 1024,
 
117
  }
118
 
119
  required_data_keys = ["image"]
120
 
121
  def __init__(self, **conf):
122
+ super().__init__(**conf) # Update with default configuration.
 
 
123
  self.relu = nn.ReLU(inplace=True)
124
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
125
  c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
 
138
 
139
  self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
140
  self.convDb = nn.Conv2d(
141
+ c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
142
  )
143
 
144
+ url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa
145
  self.load_state_dict(torch.hub.load_state_dict_from_url(url))
146
 
147
+ if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
 
148
  raise ValueError("max_num_keypoints must be positive or None")
149
 
 
 
150
  def forward(self, data: dict) -> dict:
151
  """Compute keypoints, scores, descriptors for image"""
152
  for key in self.required_data_keys:
153
  assert key in data, f"Missing key {key} in data"
154
  image = data["image"]
155
+ if image.shape[1] == 3:
156
+ image = rgb_to_grayscale(image)
157
+
158
  # Shared Encoder
159
  x = self.relu(self.conv1a(image))
160
  x = self.relu(self.conv1b(x))
 
175
  b, _, h, w = scores.shape
176
  scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
177
  scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
178
+ scores = simple_nms(scores, self.conf.nms_radius)
179
 
180
  # Discard keypoints near the image borders
181
+ if self.conf.remove_borders:
182
+ pad = self.conf.remove_borders
183
  scores[:, :pad] = -1
184
  scores[:, :, :pad] = -1
185
  scores[:, -pad:] = -1
186
  scores[:, :, -pad:] = -1
187
 
188
  # Extract keypoints
189
+ best_kp = torch.where(scores > self.conf.detection_threshold)
190
  scores = scores[best_kp]
191
 
192
  # Separate into batches
 
196
  scores = [scores[best_kp[0] == i] for i in range(b)]
197
 
198
  # Keep the k keypoints with highest score
199
+ if self.conf.max_num_keypoints is not None:
200
  keypoints, scores = list(
201
  zip(
202
  *[
203
+ top_k_keypoints(k, s, self.conf.max_num_keypoints)
204
  for k, s in zip(keypoints, scores)
205
  ]
206
  )
 
223
  return {
224
  "keypoints": torch.stack(keypoints, 0),
225
  "keypoint_scores": torch.stack(scores, 0),
226
+ "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
227
  }
 
 
 
 
 
 
 
 
 
 
 
 
third_party/LightGlue/lightglue/utils.py CHANGED
@@ -1,11 +1,12 @@
 
1
  from pathlib import Path
2
- import torch
3
- import kornia
 
4
  import cv2
 
5
  import numpy as np
6
- from typing import Union, List, Optional, Callable, Tuple
7
- import collections.abc as collections
8
- from types import SimpleNamespace
9
 
10
 
11
  class ImagePreprocessor:
@@ -15,7 +16,6 @@ class ImagePreprocessor:
15
  "interpolation": "bilinear",
16
  "align_corners": None,
17
  "antialias": True,
18
- "grayscale": False, # convert rgb to grayscale
19
  }
20
 
21
  def __init__(self, **conf) -> None:
@@ -35,10 +35,6 @@ class ImagePreprocessor:
35
  align_corners=self.conf.align_corners,
36
  )
37
  scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
38
- if self.conf.grayscale and img.shape[-3] == 3:
39
- img = kornia.color.rgb_to_grayscale(img)
40
- elif not self.conf.grayscale and img.shape[-3] == 1:
41
- img = kornia.color.grayscale_to_rgb(img)
42
  return img, scale
43
 
44
 
@@ -132,6 +128,25 @@ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
132
  return numpy_image_to_torch(image)
133
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def match_pair(
136
  extractor,
137
  matcher,
 
1
+ import collections.abc as collections
2
  from pathlib import Path
3
+ from types import SimpleNamespace
4
+ from typing import Callable, List, Optional, Tuple, Union
5
+
6
  import cv2
7
+ import kornia
8
  import numpy as np
9
+ import torch
 
 
10
 
11
 
12
  class ImagePreprocessor:
 
16
  "interpolation": "bilinear",
17
  "align_corners": None,
18
  "antialias": True,
 
19
  }
20
 
21
  def __init__(self, **conf) -> None:
 
35
  align_corners=self.conf.align_corners,
36
  )
37
  scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
 
 
 
 
38
  return img, scale
39
 
40
 
 
128
  return numpy_image_to_torch(image)
129
 
130
 
131
+ class Extractor(torch.nn.Module):
132
+ def __init__(self, **conf):
133
+ super().__init__()
134
+ self.conf = SimpleNamespace(**{**self.default_conf, **conf})
135
+
136
+ @torch.no_grad()
137
+ def extract(self, img: torch.Tensor, **conf) -> dict:
138
+ """Perform extraction with online resizing"""
139
+ if img.dim() == 3:
140
+ img = img[None] # add batch dim
141
+ assert img.dim() == 4 and img.shape[0] == 1
142
+ shape = img.shape[-2:][::-1]
143
+ img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
144
+ feats = self.forward({"image": img})
145
+ feats["image_size"] = torch.tensor(shape)[None].to(img).float()
146
+ feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
147
+ return feats
148
+
149
+
150
  def match_pair(
151
  extractor,
152
  matcher,
third_party/LightGlue/lightglue/viz2d.py CHANGED
@@ -6,8 +6,8 @@
6
  """
7
 
8
  import matplotlib
9
- import matplotlib.pyplot as plt
10
  import matplotlib.patheffects as path_effects
 
11
  import numpy as np
12
  import torch
13
 
 
6
  """
7
 
8
  import matplotlib
 
9
  import matplotlib.patheffects as path_effects
10
+ import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
13
 
third_party/LightGlue/pyproject.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "lightglue"
3
+ description = "LightGlue: Local Feature Matching at Light Speed"
4
+ version = "0.0"
5
+ authors = [
6
+ {name = "Philipp Lindenberger"},
7
+ {name = "Paul-Edouard Sarlin"},
8
+ ]
9
+ readme = "README.md"
10
+ requires-python = ">=3.6"
11
+ license = {file = "LICENSE"}
12
+ classifiers = [
13
+ "Programming Language :: Python :: 3",
14
+ "License :: OSI Approved :: Apache Software License",
15
+ "Operating System :: OS Independent",
16
+ ]
17
+ urls = {Repository = "https://github.com/cvg/LightGlue/"}
18
+ dynamic = ["dependencies"]
19
+
20
+ [project.optional-dependencies]
21
+ dev = ["black==23.12.1", "flake8", "isort"]
22
+
23
+ [tool.setuptools]
24
+ packages = ["lightglue"]
25
+
26
+ [tool.setuptools.dynamic]
27
+ dependencies = {file = ["requirements.txt"]}
28
+
29
+ [tool.isort]
30
+ profile = "black"
third_party/LightGlue/setup.py DELETED
@@ -1,27 +0,0 @@
1
- from pathlib import Path
2
- from setuptools import setup
3
-
4
- description = ["LightGlue"]
5
-
6
- with open(str(Path(__file__).parent / "README.md"), "r", encoding="utf-8") as f:
7
- readme = f.read()
8
- with open(str(Path(__file__).parent / "requirements.txt"), "r") as f:
9
- dependencies = f.read().split("\n")
10
-
11
- setup(
12
- name="lightglue",
13
- version="0.0",
14
- packages=["lightglue"],
15
- python_requires=">=3.6",
16
- install_requires=dependencies,
17
- author="Philipp Lindenberger, Paul-Edouard Sarlin",
18
- description=description,
19
- long_description=readme,
20
- long_description_content_type="text/markdown",
21
- url="https://github.com/cvg/LightGlue/",
22
- classifiers=[
23
- "Programming Language :: Python :: 3",
24
- "License :: OSI Approved :: Apache Software License",
25
- "Operating System :: OS Independent",
26
- ],
27
- )