Realcat commited on
Commit
8811cfe
·
1 Parent(s): b9d8129

update: moving lfs files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +1 -1
  2. hloc/__init__.py +3 -0
  3. hloc/extractors/alike.py +9 -1
  4. hloc/extractors/d2net.py +10 -14
  5. hloc/extractors/darkfeat.py +8 -25
  6. hloc/extractors/dedode.py +12 -27
  7. hloc/extractors/dir.py +1 -0
  8. hloc/extractors/lanet.py +12 -5
  9. hloc/extractors/r2d2.py +8 -3
  10. hloc/extractors/rekd.py +7 -3
  11. hloc/extractors/rord.py +7 -21
  12. hloc/extractors/sfd2.py +7 -2
  13. hloc/matchers/aspanformer.py +9 -48
  14. hloc/matchers/cotr.py +9 -7
  15. hloc/matchers/dkm.py +10 -22
  16. hloc/matchers/duster.py +11 -19
  17. hloc/matchers/eloftr.py +12 -3
  18. hloc/matchers/gim.py +8 -25
  19. hloc/matchers/gluestick.py +7 -18
  20. hloc/matchers/imp.py +9 -3
  21. hloc/matchers/lightglue.py +8 -3
  22. hloc/matchers/mast3r.py +13 -23
  23. hloc/matchers/mickey.py +8 -21
  24. hloc/matchers/omniglue.py +15 -13
  25. hloc/matchers/roma.py +13 -25
  26. hloc/matchers/sgmnet.py +9 -40
  27. hloc/matchers/sold2.py +10 -17
  28. hloc/matchers/topicfm.py +10 -2
  29. hloc/utils/base_model.py +9 -1
  30. requirements.txt +1 -1
  31. third_party/ALIKE/assets/ALIKE_code.zip +0 -3
  32. third_party/ALIKE/assets/alike.png +0 -3
  33. third_party/ALIKE/assets/kitti.gif +0 -3
  34. third_party/ALIKE/assets/kitti/000100.png +0 -3
  35. third_party/ALIKE/assets/kitti/000101.png +0 -3
  36. third_party/ALIKE/assets/kitti/000102.png +0 -3
  37. third_party/ALIKE/assets/kitti/000103.png +0 -3
  38. third_party/ALIKE/assets/kitti/000104.png +0 -3
  39. third_party/ALIKE/assets/kitti/000105.png +0 -3
  40. third_party/ALIKE/assets/kitti/000106.png +0 -3
  41. third_party/ALIKE/assets/kitti/000107.png +0 -3
  42. third_party/ALIKE/assets/kitti/000108.png +0 -3
  43. third_party/ALIKE/assets/kitti/000109.png +0 -3
  44. third_party/ALIKE/assets/kitti/000110.png +0 -3
  45. third_party/ALIKE/assets/kitti/000111.png +0 -3
  46. third_party/ALIKE/assets/kitti/000112.png +0 -3
  47. third_party/ALIKE/assets/kitti/000113.png +0 -3
  48. third_party/ALIKE/assets/kitti/000114.png +0 -3
  49. third_party/ALIKE/assets/kitti/000115.png +0 -3
  50. third_party/ALIKE/assets/kitti/000116.png +0 -3
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🤗
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
 
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.4.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
hloc/__init__.py CHANGED
@@ -61,3 +61,6 @@ else:
61
  )
62
 
63
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
61
  )
62
 
63
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+
65
+ # model hub: https://huggingface.co/Realcat/imatchui_checkpoint
66
+ MODEL_REPO_ID = "Realcat/imatchui_checkpoints"
hloc/extractors/alike.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path
3
 
4
  import torch
5
 
6
- from hloc import logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
@@ -29,12 +29,20 @@ class Alike(BaseModel):
29
  required_inputs = ["image"]
30
 
31
  def _init(self, conf):
 
 
 
 
 
 
 
32
  self.net = Alike_(
33
  **configs[conf["model_name"]],
34
  device=device,
35
  top_k=conf["top_k"],
36
  scores_th=conf["detection_threshold"],
37
  n_limit=conf["max_keypoints"],
 
38
  )
39
  logger.info("Load Alike model done.")
40
 
 
3
 
4
  import torch
5
 
6
+ from hloc import MODEL_REPO_ID, logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
 
29
  required_inputs = ["image"]
30
 
31
  def _init(self, conf):
32
+ model_path = self._download_model(
33
+ repo_id=MODEL_REPO_ID,
34
+ filename="{}/{}.pth".format(
35
+ Path(__file__).stem, self.conf["model_name"]
36
+ ),
37
+ )
38
+ logger.info("Loaded Alike model from {}".format(model_path))
39
  self.net = Alike_(
40
  **configs[conf["model_name"]],
41
  device=device,
42
  top_k=conf["top_k"],
43
  scores_th=conf["detection_threshold"],
44
  n_limit=conf["max_keypoints"],
45
+ model_path=model_path,
46
  )
47
  logger.info("Load Alike model done.")
48
 
hloc/extractors/d2net.py CHANGED
@@ -1,10 +1,9 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
  import torch
6
 
7
- from hloc import logger
8
 
9
  from ..utils.base_model import BaseModel
10
 
@@ -25,20 +24,17 @@ class D2Net(BaseModel):
25
  required_inputs = ["image"]
26
 
27
  def _init(self, conf):
28
- model_file = conf["checkpoint_dir"] / conf["model_name"]
29
- if not model_file.exists():
30
- model_file.parent.mkdir(exist_ok=True)
31
- cmd = [
32
- "wget",
33
- "--quiet",
34
- "https://dusmanu.com/files/d2-net/" + conf["model_name"],
35
- "-O",
36
- str(model_file),
37
- ]
38
- subprocess.run(cmd, check=True)
39
 
 
 
 
 
 
 
 
 
40
  self.net = _D2Net(
41
- model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
42
  )
43
  logger.info("Load D2Net model done.")
44
 
 
 
1
  import sys
2
  from pathlib import Path
3
 
4
  import torch
5
 
6
+ from hloc import MODEL_REPO_ID, logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
 
24
  required_inputs = ["image"]
25
 
26
  def _init(self, conf):
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ logger.info("Loading D2Net model...")
29
+ model_path = self._download_model(
30
+ repo_id=MODEL_REPO_ID,
31
+ filename="{}/{}".format(
32
+ Path(__file__).stem, self.conf["model_name"]
33
+ ),
34
+ )
35
+ logger.info(f"Loading model from {model_path}...")
36
  self.net = _D2Net(
37
+ model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
38
  )
39
  logger.info("Load D2Net model done.")
40
 
hloc/extractors/darkfeat.py CHANGED
@@ -2,7 +2,7 @@ import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
- from hloc import logger
6
 
7
  from ..utils.base_model import BaseModel
8
 
@@ -18,33 +18,16 @@ class DarkFeat(BaseModel):
18
  "detection_threshold": 0.5,
19
  "sub_pixel": False,
20
  }
21
- weight_urls = {
22
- "DarkFeat.pth": "https://drive.google.com/uc?id=1Thl6m8NcmQ7zSAF-1_xaFs3F4H8UU6HX&confirm=t",
23
- }
24
- proxy = "http://localhost:1080"
25
  required_inputs = ["image"]
26
 
27
  def _init(self, conf):
28
- model_path = darkfeat_path / "checkpoints" / conf["model_name"]
29
- link = self.weight_urls[conf["model_name"]]
30
- if not model_path.exists():
31
- model_path.parent.mkdir(exist_ok=True)
32
- cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
33
- cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy]
34
- logger.info(
35
- f"Downloading the DarkFeat model with `{cmd_wo_proxy}`."
36
- )
37
- try:
38
- subprocess.run(cmd_wo_proxy, check=True)
39
- except subprocess.CalledProcessError as e:
40
- logger.info(f"Downloading the model failed `{e}`.")
41
- logger.info(f"Downloading the DarkFeat model with `{cmd}`.")
42
- try:
43
- subprocess.run(cmd, check=True)
44
- except subprocess.CalledProcessError as e:
45
- logger.error("Failed to download the DarkFeat model.")
46
- raise e
47
-
48
  self.net = DarkFeat_(model_path)
49
  logger.info("Load DarkFeat model done.")
50
 
 
2
  import sys
3
  from pathlib import Path
4
 
5
+ from hloc import MODEL_REPO_ID, logger
6
 
7
  from ..utils.base_model import BaseModel
8
 
 
18
  "detection_threshold": 0.5,
19
  "sub_pixel": False,
20
  }
 
 
 
 
21
  required_inputs = ["image"]
22
 
23
  def _init(self, conf):
24
+ model_path = self._download_model(
25
+ repo_id=MODEL_REPO_ID,
26
+ filename="{}/{}.pth".format(
27
+ Path(__file__).stem, self.conf["model_name"]
28
+ ),
29
+ )
30
+ logger.info("Loaded DarkFeat model: {}".format(model_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  self.net = DarkFeat_(model_path)
32
  logger.info("Load DarkFeat model done.")
33
 
hloc/extractors/dedode.py CHANGED
@@ -1,11 +1,10 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
  import torch
6
  import torchvision.transforms as transforms
7
 
8
- from hloc import logger
9
 
10
  from ..utils.base_model import BaseModel
11
 
@@ -30,39 +29,25 @@ class DeDoDe(BaseModel):
30
  required_inputs = [
31
  "image",
32
  ]
33
- weight_urls = {
34
- "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
35
- "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
36
- }
37
 
38
  # Initialize the line matcher
39
  def _init(self, conf):
40
- model_detector_path = (
41
- dedode_path / "pretrained" / conf["model_detector_name"]
 
 
 
42
  )
43
- model_descriptor_path = (
44
- dedode_path / "pretrained" / conf["model_descriptor_name"]
 
 
 
45
  )
46
-
47
  self.normalizer = transforms.Normalize(
48
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
49
  )
50
- # Download the model.
51
- if not model_detector_path.exists():
52
- model_detector_path.parent.mkdir(exist_ok=True)
53
- link = self.weight_urls[conf["model_detector_name"]]
54
- cmd = ["wget", "--quiet", link, "-O", str(model_detector_path)]
55
- logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
56
- subprocess.run(cmd, check=True)
57
-
58
- if not model_descriptor_path.exists():
59
- model_descriptor_path.parent.mkdir(exist_ok=True)
60
- link = self.weight_urls[conf["model_descriptor_name"]]
61
- cmd = ["wget", "--quiet", link, "-O", str(model_descriptor_path)]
62
- logger.info(
63
- f"Downloading the DeDoDe descriptor model with `{cmd}`."
64
- )
65
- subprocess.run(cmd, check=True)
66
 
67
  # load the model
68
  weights_detector = torch.load(model_detector_path, map_location="cpu")
 
 
1
  import sys
2
  from pathlib import Path
3
 
4
  import torch
5
  import torchvision.transforms as transforms
6
 
7
+ from hloc import MODEL_REPO_ID, logger
8
 
9
  from ..utils.base_model import BaseModel
10
 
 
29
  required_inputs = [
30
  "image",
31
  ]
 
 
 
 
32
 
33
  # Initialize the line matcher
34
  def _init(self, conf):
35
+ model_detector_path = self._download_model(
36
+ repo_id=MODEL_REPO_ID,
37
+ filename="{}/{}".format(
38
+ Path(__file__).stem, conf["model_detector_name"]
39
+ ),
40
  )
41
+ model_descriptor_path = self._download_model(
42
+ repo_id=MODEL_REPO_ID,
43
+ filename="{}/{}".format(
44
+ Path(__file__).stem, conf["model_descriptor_name"]
45
+ ),
46
  )
47
+ logger.info("Loaded DarkFeat model: {}".format(model_detector_path))
48
  self.normalizer = transforms.Normalize(
49
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
50
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # load the model
53
  weights_detector = torch.load(model_detector_path, map_location="cpu")
hloc/extractors/dir.py CHANGED
@@ -43,6 +43,7 @@ class DIR(BaseModel):
43
  }
44
 
45
  def _init(self, conf):
 
46
  checkpoint = Path(
47
  torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
48
  )
 
43
  }
44
 
45
  def _init(self, conf):
46
+ # todo: download from google drive -> huggingface models
47
  checkpoint = Path(
48
  torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
49
  )
hloc/extractors/lanet.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path
3
 
4
  import torch
5
 
6
- from hloc import logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
@@ -18,18 +18,25 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
  class LANet(BaseModel):
20
  default_conf = {
21
- "model_name": "v0",
22
  "keypoint_threshold": 0.1,
23
  "max_keypoints": 1024,
24
  }
25
  required_inputs = ["image"]
26
 
27
  def _init(self, conf):
28
- model_path = (
29
- lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
 
 
 
 
 
30
  )
31
  if not model_path.exists():
32
- logger.warning(f"No model found at {model_path}, start downloading")
 
 
33
  self.net = PointModel(is_test=True)
34
  state_dict = torch.load(model_path, map_location="cpu")
35
  self.net.load_state_dict(state_dict["model_state"])
 
3
 
4
  import torch
5
 
6
+ from hloc import MODEL_REPO_ID, logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
 
18
 
19
  class LANet(BaseModel):
20
  default_conf = {
21
+ "model_name": "PointModel_v0.pth",
22
  "keypoint_threshold": 0.1,
23
  "max_keypoints": 1024,
24
  }
25
  required_inputs = ["image"]
26
 
27
  def _init(self, conf):
28
+ logger.info("Loading LANet model...")
29
+
30
+ model_path = self._download_model(
31
+ repo_id=MODEL_REPO_ID,
32
+ filename="{}/{}".format(
33
+ Path(__file__).stem, self.conf["model_name"]
34
+ ),
35
  )
36
  if not model_path.exists():
37
+ logger.warning(
38
+ f"No model found at {model_path}, please download it first."
39
+ )
40
  self.net = PointModel(is_test=True)
41
  state_dict = torch.load(model_path, map_location="cpu")
42
  self.net.load_state_dict(state_dict["model_state"])
hloc/extractors/r2d2.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path
3
 
4
  import torchvision.transforms as tvf
5
 
6
- from hloc import logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
@@ -27,11 +27,16 @@ class R2D2(BaseModel):
27
  required_inputs = ["image"]
28
 
29
  def _init(self, conf):
30
- model_fn = r2d2_path / "models" / conf["model_name"]
 
 
 
 
 
31
  self.norm_rgb = tvf.Normalize(
32
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
33
  )
34
- self.net = load_network(model_fn)
35
  self.detector = NonMaxSuppression(
36
  rel_thr=conf["reliability_threshold"],
37
  rep_thr=conf["repetability_threshold"],
 
3
 
4
  import torchvision.transforms as tvf
5
 
6
+ from hloc import MODEL_REPO_ID, logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
 
27
  required_inputs = ["image"]
28
 
29
  def _init(self, conf):
30
+ model_path = self._download_model(
31
+ repo_id=MODEL_REPO_ID,
32
+ filename="{}/{}".format(
33
+ Path(__file__).stem, self.conf["model_name"]
34
+ ),
35
+ )
36
  self.norm_rgb = tvf.Normalize(
37
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
38
  )
39
+ self.net = load_network(model_path)
40
  self.detector = NonMaxSuppression(
41
  rel_thr=conf["reliability_threshold"],
42
  rep_thr=conf["repetability_threshold"],
hloc/extractors/rekd.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path
3
 
4
  import torch
5
 
6
- from hloc import logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
@@ -22,8 +22,12 @@ class REKD(BaseModel):
22
  required_inputs = ["image"]
23
 
24
  def _init(self, conf):
25
- model_path = (
26
- rekd_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
 
 
 
 
27
  )
28
  if not model_path.exists():
29
  print(f"No model found at {model_path}")
 
3
 
4
  import torch
5
 
6
+ from hloc import MODEL_REPO_ID, logger
7
 
8
  from ..utils.base_model import BaseModel
9
 
 
22
  required_inputs = ["image"]
23
 
24
  def _init(self, conf):
25
+ # TODO: download model
26
+ model_path = self._download_model(
27
+ repo_id=MODEL_REPO_ID,
28
+ filename="{}/{}".format(
29
+ Path(__file__).stem, self.conf["model_name"]
30
+ ),
31
  )
32
  if not model_path.exists():
33
  print(f"No model found at {model_path}")
hloc/extractors/rord.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
4
 
5
  import torch
6
 
7
- from hloc import logger
8
 
9
  from ..utils.base_model import BaseModel
10
 
@@ -23,28 +23,14 @@ class RoRD(BaseModel):
23
  "max_keypoints": 1024,
24
  }
25
  required_inputs = ["image"]
26
- weight_urls = {
27
- "rord.pth": "https://drive.google.com/uc?id=12414ZGKwgPAjNTGtNrlB4VV9l7W76B2o&confirm=t",
28
- }
29
- proxy = "http://localhost:1080"
30
 
31
  def _init(self, conf):
32
- model_path = conf["checkpoint_dir"] / conf["model_name"]
33
- link = self.weight_urls[conf["model_name"]]
34
- if not model_path.exists():
35
- model_path.parent.mkdir(exist_ok=True)
36
- cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
37
- cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy]
38
- logger.info(f"Downloading the RoRD model with `{cmd_wo_proxy}`.")
39
- try:
40
- subprocess.run(cmd_wo_proxy, check=True)
41
- except subprocess.CalledProcessError as e:
42
- logger.info(f"Downloading failed {e}.")
43
- logger.info(f"Downloading the RoRD model with {cmd}.")
44
- try:
45
- subprocess.run(cmd, check=True)
46
- except subprocess.CalledProcessError as e:
47
- logger.error(f"Failed to download the RoRD model: {e}")
48
  self.net = _RoRD(
49
  model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
50
  )
 
4
 
5
  import torch
6
 
7
+ from hloc import MODEL_REPO_ID, logger
8
 
9
  from ..utils.base_model import BaseModel
10
 
 
23
  "max_keypoints": 1024,
24
  }
25
  required_inputs = ["image"]
 
 
 
 
26
 
27
  def _init(self, conf):
28
+ model_path = self._download_model(
29
+ repo_id=MODEL_REPO_ID,
30
+ filename="{}/{}".format(
31
+ Path(__file__).stem, self.conf["model_name"]
32
+ ),
33
+ )
 
 
 
 
 
 
 
 
 
 
34
  self.net = _RoRD(
35
  model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
36
  )
hloc/extractors/sfd2.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path
3
 
4
  import torchvision.transforms as tvf
5
 
6
- from .. import logger
7
  from ..utils.base_model import BaseModel
8
 
9
  tp_path = Path(__file__).parent / "../../third_party"
@@ -24,7 +24,12 @@ class SFD2(BaseModel):
24
  self.norm_rgb = tvf.Normalize(
25
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
26
  )
27
- model_path = tp_path / "pram" / "weights" / self.conf["model_name"]
 
 
 
 
 
28
  self.net = load_sfd2(weight_path=model_path).eval()
29
 
30
  logger.info("Load SFD2 model done.")
 
3
 
4
  import torchvision.transforms as tvf
5
 
6
+ from .. import MODEL_REPO_ID, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  tp_path = Path(__file__).parent / "../../third_party"
 
24
  self.norm_rgb = tvf.Normalize(
25
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
26
  )
27
+ model_path = self._download_model(
28
+ repo_id=MODEL_REPO_ID,
29
+ filename="{}/{}".format(
30
+ Path(__file__).stem, self.conf["model_name"]
31
+ ),
32
+ )
33
  self.net = load_sfd2(weight_path=model_path).eval()
34
 
35
  logger.info("Load SFD2 model done.")
hloc/matchers/aspanformer.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
4
 
5
  import torch
6
 
7
- from hloc import logger
8
  from hloc.utils.base_model import BaseModel
9
 
10
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
@@ -17,59 +17,15 @@ aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
17
 
18
  class ASpanFormer(BaseModel):
19
  default_conf = {
20
- "weights": "outdoor",
21
  "match_threshold": 0.2,
22
  "sinkhorn_iterations": 20,
23
  "max_keypoints": 2048,
24
  "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
25
- "model_name": "weights_aspanformer.tar",
26
  }
27
  required_inputs = ["image0", "image1"]
28
- proxy = "http://localhost:1080"
29
- aspanformer_models = {
30
- "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
31
- }
32
 
33
  def _init(self, conf):
34
- model_path = (
35
- aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
36
- )
37
- # Download the model.
38
- if not model_path.exists():
39
- # model_path.parent.mkdir(exist_ok=True)
40
- tar_path = aspanformer_path / conf["model_name"]
41
- if not tar_path.exists():
42
- link = self.aspanformer_models[conf["model_name"]]
43
- cmd = [
44
- "gdown",
45
- link,
46
- "-O",
47
- str(tar_path),
48
- "--proxy",
49
- self.proxy,
50
- ]
51
- cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
52
- logger.info(
53
- f"Downloading the Aspanformer model with `{cmd_wo_proxy}`."
54
- )
55
- try:
56
- subprocess.run(cmd_wo_proxy, check=True)
57
- except subprocess.CalledProcessError as e:
58
- logger.info(f"Downloading failed {e}.")
59
- logger.info(
60
- f"Downloading the Aspanformer model with `{cmd}`."
61
- )
62
- try:
63
- subprocess.run(cmd, check=True)
64
- except subprocess.CalledProcessError as e:
65
- logger.error(
66
- f"Failed to download the Aspanformer model: {e}"
67
- )
68
-
69
- cmd = ["tar", "-xvf", str(tar_path), "-C", str(aspanformer_path)]
70
- logger.info(f"Unzip model file `{cmd}`.")
71
- subprocess.run(cmd, check=True)
72
-
73
  config = get_cfg_defaults()
74
  config.merge_from_file(conf["config_path"])
75
  _config = lower_config(config)
@@ -81,8 +37,13 @@ class ASpanFormer(BaseModel):
81
  ]
82
 
83
  self.net = _ASpanFormer(config=_config["aspan"])
84
- weight_path = model_path
85
- state_dict = torch.load(str(weight_path), map_location="cpu")[
 
 
 
 
 
86
  "state_dict"
87
  ]
88
  self.net.load_state_dict(state_dict, strict=False)
 
4
 
5
  import torch
6
 
7
+ from hloc import MODEL_REPO_ID, logger
8
  from hloc.utils.base_model import BaseModel
9
 
10
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
 
17
 
18
  class ASpanFormer(BaseModel):
19
  default_conf = {
20
+ "model_name": "outdoor.ckpt",
21
  "match_threshold": 0.2,
22
  "sinkhorn_iterations": 20,
23
  "max_keypoints": 2048,
24
  "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
 
25
  }
26
  required_inputs = ["image0", "image1"]
 
 
 
 
27
 
28
  def _init(self, conf):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  config = get_cfg_defaults()
30
  config.merge_from_file(conf["config_path"])
31
  _config = lower_config(config)
 
37
  ]
38
 
39
  self.net = _ASpanFormer(config=_config["aspan"])
40
+ model_path = self._download_model(
41
+ repo_id=MODEL_REPO_ID,
42
+ filename="{}/{}".format(
43
+ Path(__file__).stem, self.conf["model_name"]
44
+ ),
45
+ )
46
+ state_dict = torch.load(str(model_path), map_location="cpu")[
47
  "state_dict"
48
  ]
49
  self.net.load_state_dict(state_dict, strict=False)
hloc/matchers/cotr.py CHANGED
@@ -6,6 +6,8 @@ import numpy as np
6
  import torch
7
  from torchvision.transforms import ToPILImage
8
 
 
 
9
  from ..utils.base_model import BaseModel
10
 
11
  sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
@@ -18,16 +20,13 @@ from COTR.utils import utils as utils_cotr
18
  utils_cotr.fix_randomness(0)
19
  torch.set_grad_enabled(False)
20
 
21
- cotr_path = Path(__file__).parent / "../../third_party/COTR"
22
-
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
-
25
 
26
  class COTR(BaseModel):
27
  default_conf = {
28
  "weights": "out/default",
29
  "match_threshold": 0.2,
30
  "max_keypoints": -1,
 
31
  }
32
  required_inputs = ["image0", "image1"]
33
 
@@ -36,8 +35,11 @@ class COTR(BaseModel):
36
  set_COTR_arguments(parser) # noqa: F405
37
  opt = parser.parse_args()
38
  opt.command = " ".join(sys.argv)
39
- opt.load_weights_path = str(
40
- cotr_path / conf["weights"] / "checkpoint.pth.tar"
 
 
 
41
  )
42
 
43
  layer_2_channels = {
@@ -49,7 +51,7 @@ class COTR(BaseModel):
49
  opt.dim_feedforward = layer_2_channels[opt.layer]
50
 
51
  model = build_model(opt)
52
- model = model.to(device)
53
  weights = torch.load(opt.load_weights_path, map_location="cpu")[
54
  "model_state_dict"
55
  ]
 
6
  import torch
7
  from torchvision.transforms import ToPILImage
8
 
9
+ from hloc import DEVICE, MODEL_REPO_ID
10
+
11
  from ..utils.base_model import BaseModel
12
 
13
  sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
 
20
  utils_cotr.fix_randomness(0)
21
  torch.set_grad_enabled(False)
22
 
 
 
 
 
23
 
24
  class COTR(BaseModel):
25
  default_conf = {
26
  "weights": "out/default",
27
  "match_threshold": 0.2,
28
  "max_keypoints": -1,
29
+ "model_name": "checkpoint.pth.tar",
30
  }
31
  required_inputs = ["image0", "image1"]
32
 
 
35
  set_COTR_arguments(parser) # noqa: F405
36
  opt = parser.parse_args()
37
  opt.command = " ".join(sys.argv)
38
+ opt.load_weights_path = self._download_model(
39
+ repo_id=MODEL_REPO_ID,
40
+ filename="{}/{}".format(
41
+ Path(__file__).stem, self.conf["model_name"]
42
+ ),
43
  )
44
 
45
  layer_2_channels = {
 
51
  opt.dim_feedforward = layer_2_channels[opt.layer]
52
 
53
  model = build_model(opt)
54
+ model = model.to(DEVICE)
55
  weights = torch.load(opt.load_weights_path, map_location="cpu")[
56
  "model_state_dict"
57
  ]
hloc/matchers/dkm.py CHANGED
@@ -1,48 +1,36 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
  import torch
6
  from PIL import Image
7
 
8
- from .. import logger
9
- from ..utils.base_model import BaseModel
10
 
11
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
12
  from DKM.dkm import DKMv3_outdoor
13
 
14
- dkm_path = Path(__file__).parent / "../../third_party/DKM"
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
-
17
 
18
  class DKMv3(BaseModel):
19
  default_conf = {
20
  "model_name": "DKMv3_outdoor.pth",
21
  "match_threshold": 0.2,
22
- "checkpoint_dir": dkm_path / "pretrained",
23
  "max_keypoints": -1,
24
  }
25
  required_inputs = [
26
  "image0",
27
  "image1",
28
  ]
29
- # Models exported using
30
- dkm_models = {
31
- "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
32
- "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
33
- }
34
 
35
  def _init(self, conf):
36
- model_path = dkm_path / "pretrained" / conf["model_name"]
 
 
 
 
 
37
 
38
- # Download the model.
39
- if not model_path.exists():
40
- model_path.parent.mkdir(exist_ok=True)
41
- link = self.dkm_models[conf["model_name"]]
42
- cmd = ["wget", "--quiet", link, "-O", str(model_path)]
43
- logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
44
- subprocess.run(cmd, check=True)
45
- self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device)
46
  logger.info("Loading DKMv3 model done")
47
 
48
  def _forward(self, data):
@@ -55,7 +43,7 @@ class DKMv3(BaseModel):
55
  W_A, H_A = img0.size
56
  W_B, H_B = img1.size
57
 
58
- warp, certainty = self.net.match(img0, img1, device=device)
59
  matches, certainty = self.net.sample(
60
  warp, certainty, num=self.conf["max_keypoints"]
61
  )
 
 
1
  import sys
2
  from pathlib import Path
3
 
4
  import torch
5
  from PIL import Image
6
 
7
+ from hloc import DEVICE, MODEL_REPO_ID, logger
8
+ from hloc.utils.base_model import BaseModel
9
 
10
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
11
  from DKM.dkm import DKMv3_outdoor
12
 
 
 
 
13
 
14
  class DKMv3(BaseModel):
15
  default_conf = {
16
  "model_name": "DKMv3_outdoor.pth",
17
  "match_threshold": 0.2,
 
18
  "max_keypoints": -1,
19
  }
20
  required_inputs = [
21
  "image0",
22
  "image1",
23
  ]
 
 
 
 
 
24
 
25
  def _init(self, conf):
26
+ model_path = self._download_model(
27
+ repo_id=MODEL_REPO_ID,
28
+ filename="{}/{}".format(
29
+ Path(__file__).stem, self.conf["model_name"]
30
+ ),
31
+ )
32
 
33
+ self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=DEVICE)
 
 
 
 
 
 
 
34
  logger.info("Loading DKMv3 model done")
35
 
36
  def _forward(self, data):
 
43
  W_A, H_A = img0.size
44
  W_B, H_B = img1.size
45
 
46
+ warp, certainty = self.net.match(img0, img1, device=DEVICE)
47
  matches, certainty = self.net.sample(
48
  warp, certainty, num=self.conf["max_keypoints"]
49
  )
hloc/matchers/duster.py CHANGED
@@ -1,13 +1,11 @@
1
- import os
2
  import sys
3
- import urllib.request
4
  from pathlib import Path
5
 
6
  import numpy as np
7
  import torch
8
  import torchvision.transforms as tfm
9
 
10
- from .. import logger
11
  from ..utils.base_model import BaseModel
12
 
13
  duster_path = Path(__file__).parent / "../../third_party/dust3r"
@@ -25,30 +23,24 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  class Duster(BaseModel):
26
  default_conf = {
27
  "name": "Duster3r",
28
- "model_path": duster_path / "model_weights/duster_vit_large.pth",
29
  "max_keypoints": 3000,
30
  "vit_patch_size": 16,
31
  }
32
 
33
  def _init(self, conf):
34
  self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
35
- self.model_path = self.conf["model_path"]
36
- self.download_weights()
37
- # self.net = load_model(self.model_path, device)
38
- self.net = AsymmetricCroCo3DStereo.from_pretrained(
39
- self.model_path
40
- # "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
41
- ).to(device)
 
 
42
  logger.info("Loaded Dust3r model")
43
 
44
- def download_weights(self):
45
- url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
46
-
47
- self.model_path.parent.mkdir(parents=True, exist_ok=True)
48
- if not os.path.isfile(self.model_path):
49
- logger.info("Downloading Duster(ViT large)... (takes a while)")
50
- urllib.request.urlretrieve(url, self.model_path)
51
-
52
  def preprocess(self, img):
53
  # the super-class already makes sure that img0,img1 have
54
  # same resolution and that h == w
 
 
1
  import sys
 
2
  from pathlib import Path
3
 
4
  import numpy as np
5
  import torch
6
  import torchvision.transforms as tfm
7
 
8
+ from .. import MODEL_REPO_ID, logger
9
  from ..utils.base_model import BaseModel
10
 
11
  duster_path = Path(__file__).parent / "../../third_party/dust3r"
 
23
  class Duster(BaseModel):
24
  default_conf = {
25
  "name": "Duster3r",
26
+ "model_name": "duster_vit_large.pth",
27
  "max_keypoints": 3000,
28
  "vit_patch_size": 16,
29
  }
30
 
31
  def _init(self, conf):
32
  self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
33
+ model_path = self._download_model(
34
+ repo_id=MODEL_REPO_ID,
35
+ filename="{}/{}".format(
36
+ Path(__file__).stem, self.conf["model_name"]
37
+ ),
38
+ )
39
+ self.net = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(
40
+ device
41
+ )
42
  logger.info("Loaded Dust3r model")
43
 
 
 
 
 
 
 
 
 
44
  def preprocess(self, img):
45
  # the super-class already makes sure that img0,img1 have
46
  # same resolution and that h == w
hloc/matchers/eloftr.py CHANGED
@@ -5,6 +5,8 @@ from pathlib import Path
5
 
6
  import torch
7
 
 
 
8
  tp_path = Path(__file__).parent / "../../third_party"
9
  sys.path.append(str(tp_path))
10
 
@@ -22,7 +24,7 @@ from ..utils.base_model import BaseModel
22
 
23
  class ELoFTR(BaseModel):
24
  default_conf = {
25
- "weights": "weights/eloftr_outdoor.ckpt",
26
  "match_threshold": 0.2,
27
  # "sinkhorn_iterations": 20,
28
  "max_keypoints": -1,
@@ -44,7 +46,14 @@ class ELoFTR(BaseModel):
44
  _default_cfg["mp"] = True
45
  elif self.conf["precision"] == "fp16":
46
  _default_cfg["half"] = True
47
- model_path = tp_path / "EfficientLoFTR" / self.conf["weights"]
 
 
 
 
 
 
 
48
  cfg = _default_cfg
49
  cfg["match_coarse"]["thr"] = conf["match_threshold"]
50
  # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
@@ -55,7 +64,7 @@ class ELoFTR(BaseModel):
55
 
56
  if self.conf["precision"] == "fp16":
57
  self.net = self.net.half()
58
- logger.info(f"Loaded Efficient LoFTR with weights {conf['weights']}")
59
 
60
  def _forward(self, data):
61
  # For consistency with hloc pairs, we refine kpts in image0!
 
5
 
6
  import torch
7
 
8
+ from hloc import MODEL_REPO_ID
9
+
10
  tp_path = Path(__file__).parent / "../../third_party"
11
  sys.path.append(str(tp_path))
12
 
 
24
 
25
  class ELoFTR(BaseModel):
26
  default_conf = {
27
+ "model_name": "eloftr_outdoor.ckpt",
28
  "match_threshold": 0.2,
29
  # "sinkhorn_iterations": 20,
30
  "max_keypoints": -1,
 
46
  _default_cfg["mp"] = True
47
  elif self.conf["precision"] == "fp16":
48
  _default_cfg["half"] = True
49
+
50
+ model_path = self._download_model(
51
+ repo_id=MODEL_REPO_ID,
52
+ filename="{}/{}".format(
53
+ Path(__file__).stem, self.conf["model_name"]
54
+ ),
55
+ )
56
+
57
  cfg = _default_cfg
58
  cfg["match_coarse"]["thr"] = conf["match_threshold"]
59
  # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
 
64
 
65
  if self.conf["precision"] == "fp16":
66
  self.net = self.net.half()
67
+ logger.info(f"Loaded Efficient LoFTR with weights {conf['model_name']}")
68
 
69
  def _forward(self, data):
70
  # For consistency with hloc pairs, we refine kpts in image0!
hloc/matchers/gim.py CHANGED
@@ -1,11 +1,9 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
- import gdown
6
  import torch
7
 
8
- from .. import logger
9
  from ..utils.base_model import BaseModel
10
 
11
  gim_path = Path(__file__).parent / "../../third_party/gim"
@@ -13,12 +11,10 @@ sys.path.append(str(gim_path))
13
 
14
  from dkm.models.model_zoo.DKMv3 import DKMv3
15
 
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
 
19
  class GIM(BaseModel):
20
  default_conf = {
21
- "model_name": "gim_dkm_100h.ckpt",
22
  "match_threshold": 0.2,
23
  "checkpoint_dir": gim_path / "weights",
24
  }
@@ -26,27 +22,14 @@ class GIM(BaseModel):
26
  "image0",
27
  "image1",
28
  ]
29
- model_dict = {
30
- "gim_lightglue_100h.ckpt": "https://github.com/xuelunshen/gim/blob/main/weights/gim_lightglue_100h.ckpt",
31
- "gim_dkm_100h.ckpt": "https://drive.google.com/file/d/1gk97V4IROnR1Nprq10W9NCFUv2mxXR_-/view",
32
- }
33
 
34
  def _init(self, conf):
35
- conf["model_name"] = str(conf["weights"])
36
- if conf["model_name"] not in self.model_dict:
37
- raise ValueError(f"Unknown GIM model {conf['model_name']}.")
38
- model_path = conf["checkpoint_dir"] / conf["model_name"]
39
-
40
- # Download the model.
41
- if not model_path.exists():
42
- model_path.parent.mkdir(exist_ok=True)
43
- model_link = self.model_dict[conf["model_name"]]
44
- if "drive.google.com" in model_link:
45
- gdown.download(model_link, output=str(model_path), fuzzy=True)
46
- else:
47
- cmd = ["wget", "--quiet", model_link, "-O", str(model_path)]
48
- subprocess.run(cmd, check=True)
49
- logger.info("Downloaded GIM model succeeed!")
50
 
51
  self.aspect_ratio = 896 / 672
52
  model = DKMv3(None, 672, 896, upsample_preds=True)
 
 
1
  import sys
2
  from pathlib import Path
3
 
 
4
  import torch
5
 
6
+ from .. import MODEL_REPO_ID, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  gim_path = Path(__file__).parent / "../../third_party/gim"
 
11
 
12
  from dkm.models.model_zoo.DKMv3 import DKMv3
13
 
 
 
14
 
15
  class GIM(BaseModel):
16
  default_conf = {
17
+ "model_name": "gim_lightglue_100h.ckpt",
18
  "match_threshold": 0.2,
19
  "checkpoint_dir": gim_path / "weights",
20
  }
 
22
  "image0",
23
  "image1",
24
  ]
 
 
 
 
25
 
26
  def _init(self, conf):
27
+ model_path = self._download_model(
28
+ repo_id=MODEL_REPO_ID,
29
+ filename="{}/{}".format(
30
+ Path(__file__).stem, self.conf["model_name"]
31
+ ),
32
+ )
 
 
 
 
 
 
 
 
 
33
 
34
  self.aspect_ratio = 896 / 672
35
  model = DKMv3(None, 672, 896, upsample_preds=True)
hloc/matchers/gluestick.py CHANGED
@@ -1,10 +1,9 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
  import torch
6
 
7
- from .. import logger
8
  from ..utils.base_model import BaseModel
9
 
10
  gluestick_path = Path(__file__).parent / "../../third_party/GlueStick"
@@ -13,8 +12,6 @@ sys.path.append(str(gluestick_path))
13
  from gluestick import batch_to_np
14
  from gluestick.models.two_view_pipeline import TwoViewPipeline
15
 
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
 
19
  class GlueStick(BaseModel):
20
  default_conf = {
@@ -30,23 +27,15 @@ class GlueStick(BaseModel):
30
  "image1",
31
  ]
32
 
33
- gluestick_models = {
34
- "checkpoint_GlueStick_MD.tar": "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar",
35
- }
36
-
37
  # Initialize the line matcher
38
  def _init(self, conf):
39
- model_path = (
40
- gluestick_path / "resources" / "weights" / conf["model_name"]
41
- )
42
-
43
  # Download the model.
44
- if not model_path.exists():
45
- model_path.parent.mkdir(exist_ok=True)
46
- link = self.gluestick_models[conf["model_name"]]
47
- cmd = ["wget", "--quiet", link, "-O", str(model_path)]
48
- logger.info(f"Downloading the Gluestick model with `{cmd}`.")
49
- subprocess.run(cmd, check=True)
50
  logger.info("Loading GlueStick model...")
51
 
52
  gluestick_conf = {
 
 
1
  import sys
2
  from pathlib import Path
3
 
4
  import torch
5
 
6
+ from .. import MODEL_REPO_ID, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  gluestick_path = Path(__file__).parent / "../../third_party/GlueStick"
 
12
  from gluestick import batch_to_np
13
  from gluestick.models.two_view_pipeline import TwoViewPipeline
14
 
 
 
15
 
16
  class GlueStick(BaseModel):
17
  default_conf = {
 
27
  "image1",
28
  ]
29
 
 
 
 
 
30
  # Initialize the line matcher
31
  def _init(self, conf):
 
 
 
 
32
  # Download the model.
33
+ model_path = self._download_model(
34
+ repo_id=MODEL_REPO_ID,
35
+ filename="{}/{}".format(
36
+ Path(__file__).stem, self.conf["model_name"]
37
+ ),
38
+ )
39
  logger.info("Loading GlueStick model...")
40
 
41
  gluestick_conf = {
hloc/matchers/imp.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path
3
 
4
  import torch
5
 
6
- from .. import DEVICE, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  tp_path = Path(__file__).parent / "../../third_party"
@@ -31,11 +31,17 @@ class IMP(BaseModel):
31
 
32
  def _init(self, conf):
33
  self.conf = {**self.default_conf, **conf}
34
- weight_path = tp_path / "pram" / "weights" / self.conf["model_name"]
 
 
 
 
 
 
35
  # self.net = nets.gml(self.conf).eval().to(DEVICE)
36
  self.net = GML(self.conf).eval().to(DEVICE)
37
  self.net.load_state_dict(
38
- torch.load(weight_path, map_location="cpu")["model"], strict=True
39
  )
40
  logger.info("Load IMP model done.")
41
 
 
3
 
4
  import torch
5
 
6
+ from .. import DEVICE, MODEL_REPO_ID, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  tp_path = Path(__file__).parent / "../../third_party"
 
31
 
32
  def _init(self, conf):
33
  self.conf = {**self.default_conf, **conf}
34
+ model_path = self._download_model(
35
+ repo_id=MODEL_REPO_ID,
36
+ filename="{}/{}".format(
37
+ Path(__file__).stem, self.conf["model_name"]
38
+ ),
39
+ )
40
+
41
  # self.net = nets.gml(self.conf).eval().to(DEVICE)
42
  self.net = GML(self.conf).eval().to(DEVICE)
43
  self.net.load_state_dict(
44
+ torch.load(model_path, map_location="cpu")["model"], strict=True
45
  )
46
  logger.info("Load IMP model done.")
47
 
hloc/matchers/lightglue.py CHANGED
@@ -1,7 +1,7 @@
1
  import sys
2
  from pathlib import Path
3
 
4
- from .. import logger
5
  from ..utils.base_model import BaseModel
6
 
7
  lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
@@ -33,8 +33,13 @@ class LightGlue(BaseModel):
33
  ]
34
 
35
  def _init(self, conf):
36
- weight_path = lightglue_path / "weights" / conf["model_name"]
37
- conf["weights"] = str(weight_path)
 
 
 
 
 
38
  conf["filter_threshold"] = conf["match_threshold"]
39
  self.net = LG(**conf)
40
  logger.info("Load lightglue model done.")
 
1
  import sys
2
  from pathlib import Path
3
 
4
+ from .. import MODEL_REPO_ID, logger
5
  from ..utils.base_model import BaseModel
6
 
7
  lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
 
33
  ]
34
 
35
  def _init(self, conf):
36
+ model_path = self._download_model(
37
+ repo_id=MODEL_REPO_ID,
38
+ filename="{}/{}".format(
39
+ Path(__file__).stem, self.conf["model_name"]
40
+ ),
41
+ )
42
+ conf["weights"] = str(model_path)
43
  conf["filter_threshold"] = conf["match_threshold"]
44
  self.net = LG(**conf)
45
  logger.info("Load lightglue model done.")
hloc/matchers/mast3r.py CHANGED
@@ -1,13 +1,11 @@
1
- import os
2
  import sys
3
- import urllib.request
4
  from pathlib import Path
5
 
6
  import numpy as np
7
  import torch
8
  import torchvision.transforms as tfm
9
 
10
- from .. import logger
11
 
12
  mast3r_path = Path(__file__).parent / "../../third_party/mast3r"
13
  sys.path.append(str(mast3r_path))
@@ -22,38 +20,30 @@ from mast3r.model import AsymmetricMASt3R
22
 
23
  from hloc.matchers.duster import Duster
24
 
25
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
-
27
 
28
  class Mast3r(Duster):
29
  default_conf = {
30
  "name": "Mast3r",
31
- "model_path": mast3r_path
32
- / "model_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
33
  "max_keypoints": 2000,
34
  "vit_patch_size": 16,
35
  }
36
 
37
  def _init(self, conf):
38
  self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
39
- self.model_path = self.conf["model_path"]
40
- self.download_weights()
41
- self.net = AsymmetricMASt3R.from_pretrained(self.model_path).to(device)
 
 
 
 
42
  logger.info("Loaded Mast3r model")
43
 
44
- def download_weights(self):
45
- url = "https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
46
-
47
- self.model_path.parent.mkdir(parents=True, exist_ok=True)
48
- if not os.path.isfile(self.model_path):
49
- logger.info("Downloading Mast3r(ViT large)... (takes a while)")
50
- urllib.request.urlretrieve(url, self.model_path)
51
- logger.info("Downloading Mast3r(ViT large)... done!")
52
-
53
  def _forward(self, data):
54
  img0, img1 = data["image0"], data["image1"]
55
- mean = torch.tensor([0.5, 0.5, 0.5]).to(device)
56
- std = torch.tensor([0.5, 0.5, 0.5]).to(device)
57
 
58
  img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
59
  img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
@@ -65,7 +55,7 @@ class Mast3r(Duster):
65
  pairs = make_pairs(
66
  images, scene_graph="complete", prefilter=None, symmetrize=True
67
  )
68
- output = inference(pairs, self.net, device, batch_size=1)
69
 
70
  # at this stage, you have the raw dust3r predictions
71
  _, pred1 = output["view1"], output["pred1"]
@@ -81,7 +71,7 @@ class Mast3r(Duster):
81
  desc1,
82
  desc2,
83
  subsample_or_initxy1=2,
84
- device=device,
85
  dist="dot",
86
  block_size=2**13,
87
  )
 
 
1
  import sys
 
2
  from pathlib import Path
3
 
4
  import numpy as np
5
  import torch
6
  import torchvision.transforms as tfm
7
 
8
+ from .. import DEVICE, MODEL_REPO_ID, logger
9
 
10
  mast3r_path = Path(__file__).parent / "../../third_party/mast3r"
11
  sys.path.append(str(mast3r_path))
 
20
 
21
  from hloc.matchers.duster import Duster
22
 
 
 
23
 
24
  class Mast3r(Duster):
25
  default_conf = {
26
  "name": "Mast3r",
27
+ "model_name": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
 
28
  "max_keypoints": 2000,
29
  "vit_patch_size": 16,
30
  }
31
 
32
  def _init(self, conf):
33
  self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
34
+ model_path = self._download_model(
35
+ repo_id=MODEL_REPO_ID,
36
+ filename="{}/{}".format(
37
+ Path(__file__).stem, self.conf["model_name"]
38
+ ),
39
+ )
40
+ self.net = AsymmetricMASt3R.from_pretrained(model_path).to(DEVICE)
41
  logger.info("Loaded Mast3r model")
42
 
 
 
 
 
 
 
 
 
 
43
  def _forward(self, data):
44
  img0, img1 = data["image0"], data["image1"]
45
+ mean = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE)
46
+ std = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE)
47
 
48
  img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
49
  img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
 
55
  pairs = make_pairs(
56
  images, scene_graph="complete", prefilter=None, symmetrize=True
57
  )
58
+ output = inference(pairs, self.net, DEVICE, batch_size=1)
59
 
60
  # at this stage, you have the raw dust3r predictions
61
  _, pred1 = output["view1"], output["pred1"]
 
71
  desc1,
72
  desc2,
73
  subsample_or_initxy1=2,
74
+ device=DEVICE,
75
  dist="dot",
76
  block_size=2**13,
77
  )
hloc/matchers/mickey.py CHANGED
@@ -1,10 +1,9 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
  import torch
6
 
7
- from .. import logger
8
  from ..utils.base_model import BaseModel
9
 
10
  mickey_path = Path(__file__).parent / "../../third_party"
@@ -13,8 +12,6 @@ sys.path.append(str(mickey_path))
13
  from mickey.config.default import cfg
14
  from mickey.lib.models.builder import build_model
15
 
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
 
19
  class Mickey(BaseModel):
20
  default_conf = {
@@ -26,33 +23,23 @@ class Mickey(BaseModel):
26
  "image0",
27
  "image1",
28
  ]
29
- weight_urls = "https://storage.googleapis.com/niantic-lon-static/research/mickey/assets/mickey_weights.zip"
30
 
31
  # Initialize the line matcher
32
  def _init(self, conf):
33
- model_path = mickey_path / "mickey/mickey_weights" / conf["model_name"]
34
- zip_path = mickey_path / "mickey/mickey_weights.zip"
 
 
 
 
 
35
  config_path = model_path.parent / self.conf["config_path"]
36
- # Download the model.
37
- if not model_path.exists():
38
- model_path.parent.mkdir(exist_ok=True, parents=True)
39
- link = self.weight_urls
40
- if not zip_path.exists():
41
- cmd = ["wget", "--quiet", link, "-O", str(zip_path)]
42
- logger.info(f"Downloading the Mickey model with {cmd}.")
43
- subprocess.run(cmd, check=True)
44
- cmd = ["unzip", "-d", str(model_path.parent.parent), str(zip_path)]
45
- logger.info(f"Running {cmd}.")
46
- subprocess.run(cmd, check=True)
47
-
48
  logger.info("Loading mickey model...")
49
  cfg.merge_from_file(config_path)
50
  self.net = build_model(cfg, checkpoint=model_path)
51
  logger.info("Load Mickey model done.")
52
 
53
  def _forward(self, data):
54
- # data['K_color0'] = torch.from_numpy(K['im0.jpg']).unsqueeze(0).to(device)
55
- # data['K_color1'] = torch.from_numpy(K['im1.jpg']).unsqueeze(0).to(device)
56
  pred = self.net(data)
57
  pred = {
58
  **pred,
 
 
1
  import sys
2
  from pathlib import Path
3
 
4
  import torch
5
 
6
+ from .. import MODEL_REPO_ID, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  mickey_path = Path(__file__).parent / "../../third_party"
 
12
  from mickey.config.default import cfg
13
  from mickey.lib.models.builder import build_model
14
 
 
 
15
 
16
  class Mickey(BaseModel):
17
  default_conf = {
 
23
  "image0",
24
  "image1",
25
  ]
 
26
 
27
  # Initialize the line matcher
28
  def _init(self, conf):
29
+ model_path = self._download_model(
30
+ repo_id=MODEL_REPO_ID,
31
+ filename="{}/{}".format(
32
+ Path(__file__).stem, self.conf["model_name"]
33
+ ),
34
+ )
35
+ # TODO: config path of mickey
36
  config_path = model_path.parent / self.conf["config_path"]
 
 
 
 
 
 
 
 
 
 
 
 
37
  logger.info("Loading mickey model...")
38
  cfg.merge_from_file(config_path)
39
  self.net = build_model(cfg, checkpoint=model_path)
40
  logger.info("Load Mickey model done.")
41
 
42
  def _forward(self, data):
 
 
43
  pred = self.net(data)
44
  pred = {
45
  **pred,
hloc/matchers/omniglue.py CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
5
  import numpy as np
6
  import torch
7
 
8
- from .. import logger
9
  from ..utils.base_model import BaseModel
10
 
11
  thirdparty_path = Path(__file__).parent / "../../third_party"
@@ -27,19 +27,21 @@ class OmniGlue(BaseModel):
27
 
28
  def _init(self, conf):
29
  logger.info("Loading OmniGlue model")
30
- og_model_path = omniglue_path / "models" / "omniglue.onnx"
31
- sp_model_path = omniglue_path / "models" / "sp_v6.onnx"
32
- dino_model_path = (
33
- omniglue_path / "models" / "dinov2_vitb14_pretrain.pth" # ~330MB
34
  )
35
- if not dino_model_path.exists():
36
- link = self.dino_v2_link_dict.get(dino_model_path.name, None)
37
- if link is not None:
38
- cmd = ["wget", "--quiet", link, "-O", str(dino_model_path)]
39
- logger.info(f"Downloading the dinov2 model with `{cmd}`.")
40
- subprocess.run(cmd, check=True)
41
- else:
42
- logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
 
 
 
43
  self.net = omniglue.OmniGlue(
44
  og_export=str(og_model_path),
45
  sp_export=str(sp_model_path),
 
5
  import numpy as np
6
  import torch
7
 
8
+ from .. import MODEL_REPO_ID, logger
9
  from ..utils.base_model import BaseModel
10
 
11
  thirdparty_path = Path(__file__).parent / "../../third_party"
 
27
 
28
  def _init(self, conf):
29
  logger.info("Loading OmniGlue model")
30
+ og_model_path = self._download_model(
31
+ repo_id=MODEL_REPO_ID,
32
+ filename="{}/{}".format(Path(__file__).stem, "omniglue.onnx"),
 
33
  )
34
+ sp_model_path = self._download_model(
35
+ repo_id=MODEL_REPO_ID,
36
+ filename="{}/{}".format(Path(__file__).stem, "sp_v6.onnx"),
37
+ )
38
+ dino_model_path = self._download_model(
39
+ repo_id=MODEL_REPO_ID,
40
+ filename="{}/{}".format(
41
+ Path(__file__).stem, "dinov2_vitb14_pretrain.pth"
42
+ ),
43
+ )
44
+
45
  self.net = omniglue.OmniGlue(
46
  og_export=str(og_model_path),
47
  sp_export=str(sp_model_path),
hloc/matchers/roma.py CHANGED
@@ -1,11 +1,10 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
  import torch
6
  from PIL import Image
7
 
8
- from .. import logger
9
  from ..utils.base_model import BaseModel
10
 
11
  roma_path = Path(__file__).parent / "../../third_party/RoMa"
@@ -26,33 +25,22 @@ class Roma(BaseModel):
26
  "image0",
27
  "image1",
28
  ]
29
- weight_urls = {
30
- "roma": {
31
- "roma_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
32
- "roma_indoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
33
- },
34
- "dinov2_vitl14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
35
- }
36
 
37
  # Initialize the line matcher
38
  def _init(self, conf):
39
- model_path = roma_path / "pretrained" / conf["model_name"]
40
- dinov2_weights = roma_path / "pretrained" / conf["model_utils_name"]
41
-
42
- # Download the model.
43
- if not model_path.exists():
44
- model_path.parent.mkdir(exist_ok=True)
45
- link = self.weight_urls["roma"][conf["model_name"]]
46
- cmd = ["wget", "--quiet", link, "-O", str(model_path)]
47
- logger.info(f"Downloading the Roma model with `{cmd}`.")
48
- subprocess.run(cmd, check=True)
49
 
50
- if not dinov2_weights.exists():
51
- dinov2_weights.parent.mkdir(exist_ok=True)
52
- link = self.weight_urls[conf["model_utils_name"]]
53
- cmd = ["wget", "--quiet", link, "-O", str(dinov2_weights)]
54
- logger.info(f"Downloading the dinov2 model with `{cmd}`.")
55
- subprocess.run(cmd, check=True)
56
 
57
  logger.info("Loading Roma model")
58
  # load the model
 
 
1
  import sys
2
  from pathlib import Path
3
 
4
  import torch
5
  from PIL import Image
6
 
7
+ from .. import MODEL_REPO_ID, logger
8
  from ..utils.base_model import BaseModel
9
 
10
  roma_path = Path(__file__).parent / "../../third_party/RoMa"
 
25
  "image0",
26
  "image1",
27
  ]
 
 
 
 
 
 
 
28
 
29
  # Initialize the line matcher
30
  def _init(self, conf):
31
+ model_path = self._download_model(
32
+ repo_id=MODEL_REPO_ID,
33
+ filename="{}/{}".format(
34
+ Path(__file__).stem, self.conf["model_name"]
35
+ ),
36
+ )
 
 
 
 
37
 
38
+ dinov2_weights = self._download_model(
39
+ repo_id=MODEL_REPO_ID,
40
+ filename="{}/{}".format(
41
+ Path(__file__).stem, self.conf["model_utils_name"]
42
+ ),
43
+ )
44
 
45
  logger.info("Loading Roma model")
46
  # load the model
hloc/matchers/sgmnet.py CHANGED
@@ -1,11 +1,10 @@
1
- import subprocess
2
  import sys
3
  from collections import OrderedDict, namedtuple
4
  from pathlib import Path
5
 
6
  import torch
7
 
8
- from .. import logger
9
  from ..utils.base_model import BaseModel
10
 
11
  sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
@@ -19,7 +18,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  class SGMNet(BaseModel):
20
  default_conf = {
21
  "name": "SGM",
22
- "model_name": "model_best.pth",
23
  "seed_top_k": [256, 256],
24
  "seed_radius_coe": 0.01,
25
  "net_channels": 128,
@@ -37,50 +36,20 @@ class SGMNet(BaseModel):
37
  "image0",
38
  "image1",
39
  ]
40
- weight_urls = {
41
- "model_best.pth": "https://drive.google.com/uc?id=1Ca0WmKSSt2G6P7m8YAOlSAHEFar_TAWb&confirm=t",
42
- }
43
- proxy = "http://localhost:1080"
44
 
45
  # Initialize the line matcher
46
  def _init(self, conf):
47
- sgmnet_weights = sgmnet_path / "weights/sgm/root" / conf["model_name"]
48
-
49
- link = self.weight_urls[conf["model_name"]]
50
- tar_path = sgmnet_path / "weights.tar.gz"
51
- # Download the model.
52
- if not sgmnet_weights.exists():
53
- if not tar_path.exists():
54
- cmd = [
55
- "gdown",
56
- link,
57
- "-O",
58
- str(tar_path),
59
- "--proxy",
60
- self.proxy,
61
- ]
62
- cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
63
- logger.info(
64
- f"Downloading the SGMNet model with `{cmd_wo_proxy}`."
65
- )
66
- try:
67
- subprocess.run(cmd_wo_proxy, check=True)
68
- except subprocess.CalledProcessError as e:
69
- logger.info(f"Downloading failed {e}.")
70
- logger.info(f"Downloading the SGMNet model with `{cmd}`.")
71
- try:
72
- subprocess.run(cmd, check=True)
73
- except subprocess.CalledProcessError as e:
74
- logger.error("Failed to download the SGMNet model.")
75
- raise e
76
- cmd = ["tar", "-xvf", str(tar_path), "-C", str(sgmnet_path)]
77
- logger.info(f"Unzip model file `{cmd}`.")
78
- subprocess.run(cmd, check=True)
79
 
80
  # config
81
  config = namedtuple("config", conf.keys())(*conf.values())
82
  self.net = SGM_Model(config)
83
- checkpoint = torch.load(sgmnet_weights, map_location="cpu")
84
  # for ddp model
85
  if (
86
  list(checkpoint["state_dict"].items())[0][0].split(".")[0]
 
 
1
  import sys
2
  from collections import OrderedDict, namedtuple
3
  from pathlib import Path
4
 
5
  import torch
6
 
7
+ from .. import MODEL_REPO_ID, logger
8
  from ..utils.base_model import BaseModel
9
 
10
  sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
 
18
  class SGMNet(BaseModel):
19
  default_conf = {
20
  "name": "SGM",
21
+ "model_name": "weights/sgm/root/model_best.pth",
22
  "seed_top_k": [256, 256],
23
  "seed_radius_coe": 0.01,
24
  "net_channels": 128,
 
36
  "image0",
37
  "image1",
38
  ]
 
 
 
 
39
 
40
  # Initialize the line matcher
41
  def _init(self, conf):
42
+ model_path = self._download_model(
43
+ repo_id=MODEL_REPO_ID,
44
+ filename="{}/{}".format(
45
+ Path(__file__).stem, self.conf["model_name"]
46
+ ),
47
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # config
50
  config = namedtuple("config", conf.keys())(*conf.values())
51
  self.net = SGM_Model(config)
52
+ checkpoint = torch.load(model_path, map_location="cpu")
53
  # for ddp model
54
  if (
55
  list(checkpoint["state_dict"].items())[0][0].split(".")[0]
hloc/matchers/sold2.py CHANGED
@@ -1,10 +1,9 @@
1
- import subprocess
2
  import sys
3
  from pathlib import Path
4
 
5
  import torch
6
 
7
- from .. import logger
8
  from ..utils.base_model import BaseModel
9
 
10
  sold2_path = Path(__file__).parent / "../../third_party/SOLD2"
@@ -17,7 +16,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
  class SOLD2(BaseModel):
19
  default_conf = {
20
- "weights": "sold2_wireframe.tar",
21
  "match_threshold": 0.2,
22
  "checkpoint_dir": sold2_path / "pretrained",
23
  "detect_thresh": 0.25,
@@ -31,21 +30,15 @@ class SOLD2(BaseModel):
31
  "image1",
32
  ]
33
 
34
- weight_urls = {
35
- "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download",
36
- }
37
-
38
  # Initialize the line matcher
39
  def _init(self, conf):
40
- checkpoint_path = conf["checkpoint_dir"] / conf["weights"]
41
-
42
- # Download the model.
43
- if not checkpoint_path.exists():
44
- checkpoint_path.parent.mkdir(exist_ok=True)
45
- link = self.weight_urls[conf["weights"]]
46
- cmd = ["wget", "--quiet", link, "-O", str(checkpoint_path)]
47
- logger.info(f"Downloading the SOLD2 model with `{cmd}`.")
48
- subprocess.run(cmd, check=True)
49
 
50
  mode = "dynamic" # 'dynamic' or 'static'
51
  match_config = {
@@ -127,7 +120,7 @@ class SOLD2(BaseModel):
127
  }
128
  self.net = LineMatcher(
129
  match_config["model_cfg"],
130
- checkpoint_path,
131
  device,
132
  match_config["line_detector_cfg"],
133
  match_config["line_matcher_cfg"],
 
 
1
  import sys
2
  from pathlib import Path
3
 
4
  import torch
5
 
6
+ from .. import MODEL_REPO_ID, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  sold2_path = Path(__file__).parent / "../../third_party/SOLD2"
 
16
 
17
  class SOLD2(BaseModel):
18
  default_conf = {
19
+ "model_name": "sold2_wireframe.tar",
20
  "match_threshold": 0.2,
21
  "checkpoint_dir": sold2_path / "pretrained",
22
  "detect_thresh": 0.25,
 
30
  "image1",
31
  ]
32
 
 
 
 
 
33
  # Initialize the line matcher
34
  def _init(self, conf):
35
+ model_path = self._download_model(
36
+ repo_id=MODEL_REPO_ID,
37
+ filename="{}/{}".format(
38
+ Path(__file__).stem, self.conf["model_name"]
39
+ ),
40
+ )
41
+ logger.info("Loading SOLD2 model: {}".format(model_path))
 
 
42
 
43
  mode = "dynamic" # 'dynamic' or 'static'
44
  match_config = {
 
120
  }
121
  self.net = LineMatcher(
122
  match_config["model_cfg"],
123
+ model_path,
124
  device,
125
  match_config["line_detector_cfg"],
126
  match_config["line_matcher_cfg"],
hloc/matchers/topicfm.py CHANGED
@@ -3,6 +3,8 @@ from pathlib import Path
3
 
4
  import torch
5
 
 
 
6
  from ..utils.base_model import BaseModel
7
 
8
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
@@ -15,6 +17,7 @@ topicfm_path = Path(__file__).parent / "../../third_party/TopicFM"
15
  class TopicFM(BaseModel):
16
  default_conf = {
17
  "weights": "outdoor",
 
18
  "match_threshold": 0.2,
19
  "n_sampling_topics": 4,
20
  "max_keypoints": -1,
@@ -25,9 +28,14 @@ class TopicFM(BaseModel):
25
  _conf = dict(get_model_cfg())
26
  _conf["match_coarse"]["thr"] = conf["match_threshold"]
27
  _conf["coarse"]["n_samples"] = conf["n_sampling_topics"]
28
- weight_path = topicfm_path / "pretrained/model_best.ckpt"
 
 
 
 
 
29
  self.net = _TopicFM(config=_conf)
30
- ckpt_dict = torch.load(weight_path, map_location="cpu")
31
  self.net.load_state_dict(ckpt_dict["state_dict"])
32
 
33
  def _forward(self, data):
 
3
 
4
  import torch
5
 
6
+ from hloc import MODEL_REPO_ID
7
+
8
  from ..utils.base_model import BaseModel
9
 
10
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
 
17
  class TopicFM(BaseModel):
18
  default_conf = {
19
  "weights": "outdoor",
20
+ "model_name": "model_best.ckpt",
21
  "match_threshold": 0.2,
22
  "n_sampling_topics": 4,
23
  "max_keypoints": -1,
 
28
  _conf = dict(get_model_cfg())
29
  _conf["match_coarse"]["thr"] = conf["match_threshold"]
30
  _conf["coarse"]["n_samples"] = conf["n_sampling_topics"]
31
+ model_path = self._download_model(
32
+ repo_id=MODEL_REPO_ID,
33
+ filename="{}/{}".format(
34
+ Path(__file__).stem, self.conf["model_name"]
35
+ ),
36
+ )
37
  self.net = _TopicFM(config=_conf)
38
+ ckpt_dict = torch.load(model_path, map_location="cpu")
39
  self.net.load_state_dict(ckpt_dict["state_dict"])
40
 
41
  def _forward(self, data):
hloc/utils/base_model.py CHANGED
@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
3
  from torch import nn
4
  from copy import copy
5
  import inspect
 
6
 
7
 
8
  class BaseModel(nn.Module, metaclass=ABCMeta):
@@ -32,7 +33,14 @@ class BaseModel(nn.Module, metaclass=ABCMeta):
32
  def _forward(self, data):
33
  """To be implemented by the child class."""
34
  raise NotImplementedError
35
-
 
 
 
 
 
 
 
36
 
37
  def dynamic_load(root, model):
38
  module_path = f"{root.__name__}.{model}"
 
3
  from torch import nn
4
  from copy import copy
5
  import inspect
6
+ from huggingface_hub import hf_hub_download
7
 
8
 
9
  class BaseModel(nn.Module, metaclass=ABCMeta):
 
33
  def _forward(self, data):
34
  """To be implemented by the child class."""
35
  raise NotImplementedError
36
+
37
+ def _download_model(self, repo_id=None, filename=None, **kwargs):
38
+ """Download model from hf hub and return the path."""
39
+ return hf_hub_download(
40
+ repo_type="model",
41
+ repo_id=repo_id,
42
+ filename=filename,
43
+ )
44
 
45
  def dynamic_load(root, model):
46
  module_path = f"{root.__name__}.{model}"
requirements.txt CHANGED
@@ -2,7 +2,7 @@ e2cnn
2
  einops
3
  easydict
4
  gdown
5
- gradio==5.0.1
6
  h5py
7
  huggingface_hub
8
  imageio
 
2
  einops
3
  easydict
4
  gdown
5
+ # gradio==5.4.0
6
  h5py
7
  huggingface_hub
8
  imageio
third_party/ALIKE/assets/ALIKE_code.zip DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:891e8431c047e7aeed77c9e5f64ffeed262d92389d8ae6235dde0964a9048a08
3
- size 62774
 
 
 
 
third_party/ALIKE/assets/alike.png DELETED

Git LFS Details

  • SHA256: d35e59f8e4d9c34b0e2686ecd5ca5414fe975b81553e4968eccc4bff1535c2d4
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
third_party/ALIKE/assets/kitti.gif DELETED

Git LFS Details

  • SHA256: 0b05e4dc0000b9abf53183a3ebdfc0b95a92513952e235ea24f27f2945389ea1
  • Pointer size: 132 Bytes
  • Size of remote file: 7.03 MB
third_party/ALIKE/assets/kitti/000100.png DELETED

Git LFS Details

  • SHA256: c8d4a81ad91c7945cabd15de286aacf27ab661163b5eee0177128721782d5405
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
third_party/ALIKE/assets/kitti/000101.png DELETED

Git LFS Details

  • SHA256: 539c684432726e903191a2471c8dae8c4b0012b88e1b3af7590de08c24890327
  • Pointer size: 131 Bytes
  • Size of remote file: 272 kB
third_party/ALIKE/assets/kitti/000102.png DELETED

Git LFS Details

  • SHA256: 5bbc9a5b04bd425a5e146f3ba114027041086477a5fa123a50463932ab62617e
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
third_party/ALIKE/assets/kitti/000103.png DELETED

Git LFS Details

  • SHA256: 2041e633aeb85022b1222277cace17132bed09ca19856d1e6787984b05d61339
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB
third_party/ALIKE/assets/kitti/000104.png DELETED

Git LFS Details

  • SHA256: 6ca8a30c0edb7d2c6d6e5c2f5317bdffdae2269157d69e71f9602e0bbf2090ab
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB
third_party/ALIKE/assets/kitti/000105.png DELETED

Git LFS Details

  • SHA256: b8bca67672e8b2181b193f0577a9a3b42b64df9bb57d98608dbdbb54e79925bd
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
third_party/ALIKE/assets/kitti/000106.png DELETED

Git LFS Details

  • SHA256: 2ccc83d57703afdcda4afd746dd99458b425fbc11ce3155583abde25e988e389
  • Pointer size: 131 Bytes
  • Size of remote file: 269 kB
third_party/ALIKE/assets/kitti/000107.png DELETED

Git LFS Details

  • SHA256: 980f4c74ac9117020f954cc75718cf0a09baeb30894aea123db59f9e4555ecef
  • Pointer size: 131 Bytes
  • Size of remote file: 269 kB
third_party/ALIKE/assets/kitti/000108.png DELETED

Git LFS Details

  • SHA256: c7c2234c8ba8c056c452a0d625db6eac09c8963b0c5e8a5d0b1c3af15a4b7516
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB
third_party/ALIKE/assets/kitti/000109.png DELETED

Git LFS Details

  • SHA256: 6a34b9639806e7deefe1cb24ae7b376343d394d2d032f95e763e4b6921cd61c7
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
third_party/ALIKE/assets/kitti/000110.png DELETED

Git LFS Details

  • SHA256: 6af1b3e55b9c1eac208c887c44592f93e8ae7cc0196acaa2639c265f8bf959e3
  • Pointer size: 131 Bytes
  • Size of remote file: 275 kB
third_party/ALIKE/assets/kitti/000111.png DELETED

Git LFS Details

  • SHA256: 215ed5306f4976458110836a620dcf55030d8dd20618e6365d60176988c1cfa6
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
third_party/ALIKE/assets/kitti/000112.png DELETED

Git LFS Details

  • SHA256: 8a265252457871d4dd2f17c42eafa1c0da99df90d103c653c8097aad26073d22
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
third_party/ALIKE/assets/kitti/000113.png DELETED

Git LFS Details

  • SHA256: c83f220b29b5d04ead44c9304f9eccde3a4ff4e60627d7014f8fe424afb873f4
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
third_party/ALIKE/assets/kitti/000114.png DELETED

Git LFS Details

  • SHA256: 1abad021db35c21f2e9ac0ce7e54a5721eec3ff32bc4ce820f5b7091af4d6fac
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
third_party/ALIKE/assets/kitti/000115.png DELETED

Git LFS Details

  • SHA256: 6be815b2b0aa8aa3dc47e314ed6645eeb474996e9a920fab2abe8a35fb3ea089
  • Pointer size: 131 Bytes
  • Size of remote file: 274 kB
third_party/ALIKE/assets/kitti/000116.png DELETED

Git LFS Details

  • SHA256: 96b8df04ee570d877a04e43f1f4c30abc7e7383b24ce70a1a83a82dcbd863293
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB