Vincentqyw commited on
Commit
4ede021
Β·
1 Parent(s): e8fe67e

fix: model path

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. hloc/extractors/alike.py +1 -1
  2. hloc/extractors/darkfeat.py +1 -1
  3. hloc/extractors/lanet.py +0 -4
  4. hloc/extractors/sfd2.py +1 -1
  5. hloc/match_dense.py +2 -1
  6. hloc/matchers/gim.py +89 -18
  7. hloc/matchers/imp.py +1 -1
  8. hloc/matchers/lightglue.py +1 -0
  9. third_party/gim/gim/__init__.py +2 -0
  10. third_party/gim/{dkm β†’ gim/dkm}/__init__.py +0 -0
  11. third_party/gim/{dkm β†’ gim/dkm}/benchmarks/__init__.py +0 -0
  12. third_party/gim/{dkm β†’ gim/dkm}/benchmarks/hpatches_sequences_homog_benchmark.py +0 -0
  13. third_party/gim/{dkm β†’ gim/dkm}/benchmarks/megadepth1500_benchmark.py +0 -0
  14. third_party/gim/{dkm β†’ gim/dkm}/benchmarks/megadepth_dense_benchmark.py +0 -0
  15. third_party/gim/{dkm β†’ gim/dkm}/benchmarks/scannet_benchmark.py +0 -0
  16. third_party/gim/{dkm β†’ gim/dkm}/checkpointing/__init__.py +0 -0
  17. third_party/gim/{dkm β†’ gim/dkm}/checkpointing/checkpoint.py +0 -0
  18. third_party/gim/{dkm β†’ gim/dkm}/datasets/__init__.py +0 -0
  19. third_party/gim/{dkm β†’ gim/dkm}/datasets/megadepth.py +0 -0
  20. third_party/gim/{dkm β†’ gim/dkm}/datasets/scannet.py +0 -0
  21. third_party/gim/{dkm β†’ gim/dkm}/losses/__init__.py +0 -0
  22. third_party/gim/{dkm β†’ gim/dkm}/losses/depth_match_regression_loss.py +0 -0
  23. third_party/gim/{dkm β†’ gim/dkm}/models/__init__.py +0 -0
  24. third_party/gim/{dkm β†’ gim/dkm}/models/dkm.py +3 -3
  25. third_party/gim/{dkm β†’ gim/dkm}/models/encoders.py +0 -0
  26. third_party/gim/{dkm β†’ gim/dkm}/models/model_zoo/DKMv3.py +2 -2
  27. third_party/gim/{dkm β†’ gim/dkm}/models/model_zoo/__init__.py +0 -0
  28. third_party/gim/{dkm β†’ gim/dkm}/train/__init__.py +0 -0
  29. third_party/gim/{dkm β†’ gim/dkm}/train/train.py +0 -0
  30. third_party/gim/{dkm β†’ gim/dkm}/utils/__init__.py +0 -0
  31. third_party/gim/{dkm β†’ gim/dkm}/utils/kde.py +0 -0
  32. third_party/gim/{dkm β†’ gim/dkm}/utils/local_correlation.py +0 -0
  33. third_party/gim/{dkm β†’ gim/dkm}/utils/transforms.py +0 -0
  34. third_party/gim/{dkm β†’ gim/dkm}/utils/utils.py +0 -0
  35. third_party/gim/{gluefactory β†’ gim/gluefactory}/__init__.py +0 -0
  36. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+NN.yaml +0 -0
  37. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+lightglue-official.yaml +0 -0
  38. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+lightglue_homography.yaml +0 -0
  39. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+lightglue_megadepth.yaml +0 -0
  40. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+NN.yaml +0 -0
  41. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+lightglue-official.yaml +0 -0
  42. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+lightglue_homography.yaml +0 -0
  43. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+lightglue_megadepth.yaml +0 -0
  44. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+NN.yaml +0 -0
  45. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+lightglue-official.yaml +0 -0
  46. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+lightglue_homography.yaml +0 -0
  47. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+lightglue_megadepth.yaml +0 -0
  48. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/superpoint+NN.yaml +0 -0
  49. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/superpoint+lightglue-official.yaml +0 -0
  50. third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/superpoint+lightglue_homography.yaml +0 -0
hloc/extractors/alike.py CHANGED
@@ -36,13 +36,13 @@ class Alike(BaseModel):
36
  ),
37
  )
38
  logger.info("Loaded Alike model from {}".format(model_path))
 
39
  self.net = Alike_(
40
  **configs[conf["model_name"]],
41
  device=device,
42
  top_k=conf["top_k"],
43
  scores_th=conf["detection_threshold"],
44
  n_limit=conf["max_keypoints"],
45
- model_path=model_path,
46
  )
47
  logger.info("Load Alike model done.")
48
 
 
36
  ),
37
  )
38
  logger.info("Loaded Alike model from {}".format(model_path))
39
+ configs[conf["model_name"]]["model_path"] = model_path
40
  self.net = Alike_(
41
  **configs[conf["model_name"]],
42
  device=device,
43
  top_k=conf["top_k"],
44
  scores_th=conf["detection_threshold"],
45
  n_limit=conf["max_keypoints"],
 
46
  )
47
  logger.info("Load Alike model done.")
48
 
hloc/extractors/darkfeat.py CHANGED
@@ -23,7 +23,7 @@ class DarkFeat(BaseModel):
23
  def _init(self, conf):
24
  model_path = self._download_model(
25
  repo_id=MODEL_REPO_ID,
26
- filename="{}/{}.pth".format(
27
  Path(__file__).stem, self.conf["model_name"]
28
  ),
29
  )
 
23
  def _init(self, conf):
24
  model_path = self._download_model(
25
  repo_id=MODEL_REPO_ID,
26
+ filename="{}/{}".format(
27
  Path(__file__).stem, self.conf["model_name"]
28
  ),
29
  )
hloc/extractors/lanet.py CHANGED
@@ -33,10 +33,6 @@ class LANet(BaseModel):
33
  Path(__file__).stem, self.conf["model_name"]
34
  ),
35
  )
36
- if not model_path.exists():
37
- logger.warning(
38
- f"No model found at {model_path}, please download it first."
39
- )
40
  self.net = PointModel(is_test=True)
41
  state_dict = torch.load(model_path, map_location="cpu")
42
  self.net.load_state_dict(state_dict["model_state"])
 
33
  Path(__file__).stem, self.conf["model_name"]
34
  ),
35
  )
 
 
 
 
36
  self.net = PointModel(is_test=True)
37
  state_dict = torch.load(model_path, map_location="cpu")
38
  self.net.load_state_dict(state_dict["model_state"])
hloc/extractors/sfd2.py CHANGED
@@ -27,7 +27,7 @@ class SFD2(BaseModel):
27
  model_path = self._download_model(
28
  repo_id=MODEL_REPO_ID,
29
  filename="{}/{}".format(
30
- Path(__file__).stem, self.conf["model_name"]
31
  ),
32
  )
33
  self.net = load_sfd2(weight_path=model_path).eval()
 
27
  model_path = self._download_model(
28
  repo_id=MODEL_REPO_ID,
29
  filename="{}/{}".format(
30
+ "pram", self.conf["model_name"]
31
  ),
32
  )
33
  self.net = load_sfd2(weight_path=model_path).eval()
hloc/match_dense.py CHANGED
@@ -257,6 +257,7 @@ confs = {
257
  "model": {
258
  "name": "roma",
259
  "weights": "outdoor",
 
260
  "max_keypoints": 2000,
261
  "match_threshold": 0.2,
262
  },
@@ -273,7 +274,7 @@ confs = {
273
  "output": "matches-gim",
274
  "model": {
275
  "name": "gim",
276
- "weights": "gim_dkm_100h.ckpt",
277
  "max_keypoints": 2000,
278
  "match_threshold": 0.2,
279
  },
 
257
  "model": {
258
  "name": "roma",
259
  "weights": "outdoor",
260
+ "model_name": "roma_outdoor.pth",
261
  "max_keypoints": 2000,
262
  "match_threshold": 0.2,
263
  },
 
274
  "output": "matches-gim",
275
  "model": {
276
  "name": "gim",
277
+ "model_name": "gim_dkm_100h.ckpt",
278
  "max_keypoints": 2000,
279
  "match_threshold": 0.2,
280
  },
hloc/matchers/gim.py CHANGED
@@ -3,46 +3,116 @@ from pathlib import Path
3
 
4
  import torch
5
 
6
- from .. import MODEL_REPO_ID, logger
7
  from ..utils.base_model import BaseModel
8
 
9
  gim_path = Path(__file__).parent / "../../third_party/gim"
10
  sys.path.append(str(gim_path))
11
 
12
- from dkm.models.model_zoo.DKMv3 import DKMv3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  class GIM(BaseModel):
16
  default_conf = {
17
- "model_name": "gim_lightglue_100h.ckpt",
18
  "match_threshold": 0.2,
19
  "checkpoint_dir": gim_path / "weights",
 
20
  }
21
  required_inputs = [
22
  "image0",
23
  "image1",
24
  ]
 
 
 
 
 
25
 
26
  def _init(self, conf):
 
27
  model_path = self._download_model(
28
  repo_id=MODEL_REPO_ID,
29
- filename="{}/{}".format(
30
- Path(__file__).stem, self.conf["model_name"]
31
- ),
32
  )
33
-
34
  self.aspect_ratio = 896 / 672
35
- model = DKMv3(None, 672, 896, upsample_preds=True)
36
- state_dict = torch.load(str(model_path), map_location="cpu")
37
- if "state_dict" in state_dict.keys():
38
- state_dict = state_dict["state_dict"]
39
- for k in list(state_dict.keys()):
40
- if k.startswith("model."):
41
- state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
42
- if "encoder.net.fc" in k:
43
- state_dict.pop(k)
44
- model.load_state_dict(state_dict)
45
-
46
  self.net = model
47
  logger.info("Loaded GIM model")
48
 
@@ -94,6 +164,7 @@ class GIM(BaseModel):
94
  return mask
95
 
96
  def _forward(self, data):
 
97
  image0, image1 = self.pad_image(
98
  data["image0"], self.aspect_ratio
99
  ), self.pad_image(data["image1"], self.aspect_ratio)
 
3
 
4
  import torch
5
 
6
+ from .. import MODEL_REPO_ID, logger, DEVICE
7
  from ..utils.base_model import BaseModel
8
 
9
  gim_path = Path(__file__).parent / "../../third_party/gim"
10
  sys.path.append(str(gim_path))
11
 
12
+ def load_model(weight_name, checkpoints_path):
13
+ # load model
14
+ model = None
15
+ detector = None
16
+ if weight_name == "gim_dkm":
17
+ from gim.dkm.models.model_zoo.DKMv3 import DKMv3
18
+ model = DKMv3(weights=None, h=672, w=896)
19
+ elif weight_name == "gim_loftr":
20
+ from gim.loftr.loftr import LoFTR
21
+ from gim.loftr.misc import lower_config
22
+ from gim.loftr.config import get_cfg_defaults
23
+
24
+ model = LoFTR(lower_config(get_cfg_defaults())["loftr"])
25
+ elif weight_name == "gim_lightglue":
26
+ from gim.lightglue.superpoint import SuperPoint
27
+ from gim.lightglue.models.matchers.lightglue import LightGlue
28
+
29
+ detector = SuperPoint(
30
+ {
31
+ "max_num_keypoints": 2048,
32
+ "force_num_keypoints": True,
33
+ "detection_threshold": 0.0,
34
+ "nms_radius": 3,
35
+ "trainable": False,
36
+ }
37
+ )
38
+ model = LightGlue(
39
+ {
40
+ "filter_threshold": 0.1,
41
+ "flash": False,
42
+ "checkpointed": True,
43
+ }
44
+ )
45
+
46
+ # load state dict
47
+ if weight_name == "gim_dkm":
48
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
49
+ if "state_dict" in state_dict.keys():
50
+ state_dict = state_dict["state_dict"]
51
+ for k in list(state_dict.keys()):
52
+ if k.startswith("model."):
53
+ state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
54
+ if "encoder.net.fc" in k:
55
+ state_dict.pop(k)
56
+ model.load_state_dict(state_dict)
57
+
58
+ elif weight_name == "gim_loftr":
59
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
60
+ if "state_dict" in state_dict.keys():
61
+ state_dict = state_dict["state_dict"]
62
+ model.load_state_dict(state_dict)
63
+
64
+ elif weight_name == "gim_lightglue":
65
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
66
+ if "state_dict" in state_dict.keys():
67
+ state_dict = state_dict["state_dict"]
68
+ for k in list(state_dict.keys()):
69
+ if k.startswith("model."):
70
+ state_dict.pop(k)
71
+ if k.startswith("superpoint."):
72
+ state_dict[k.replace("superpoint.", "", 1)] = state_dict.pop(k)
73
+ detector.load_state_dict(state_dict)
74
+
75
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
76
+ if "state_dict" in state_dict.keys():
77
+ state_dict = state_dict["state_dict"]
78
+ for k in list(state_dict.keys()):
79
+ if k.startswith("superpoint."):
80
+ state_dict.pop(k)
81
+ if k.startswith("model."):
82
+ state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
83
+ model.load_state_dict(state_dict)
84
+
85
+ # eval mode
86
+ if detector is not None:
87
+ detector = detector.eval().to(DEVICE)
88
+ model = model.eval().to(DEVICE)
89
+ return model
90
 
91
 
92
  class GIM(BaseModel):
93
  default_conf = {
 
94
  "match_threshold": 0.2,
95
  "checkpoint_dir": gim_path / "weights",
96
+ "weights": "gim_dkm",
97
  }
98
  required_inputs = [
99
  "image0",
100
  "image1",
101
  ]
102
+ ckpt_name_dict = {
103
+ "gim_dkm": "gim_dkm_100h.ckpt",
104
+ "gim_loftr": "gim_loftr_50h.ckpt",
105
+ "gim_lightglue": "gim_lightglue_100h.ckpt",
106
+ }
107
 
108
  def _init(self, conf):
109
+ ckpt_name = self.ckpt_name_dict[conf["weights"]]
110
  model_path = self._download_model(
111
  repo_id=MODEL_REPO_ID,
112
+ filename="{}/{}".format(Path(__file__).stem, ckpt_name),
 
 
113
  )
 
114
  self.aspect_ratio = 896 / 672
115
+ model = load_model(conf["weights"], model_path)
 
 
 
 
 
 
 
 
 
 
116
  self.net = model
117
  logger.info("Loaded GIM model")
118
 
 
164
  return mask
165
 
166
  def _forward(self, data):
167
+ # TODO: only support dkm+gim
168
  image0, image1 = self.pad_image(
169
  data["image0"], self.aspect_ratio
170
  ), self.pad_image(data["image1"], self.aspect_ratio)
hloc/matchers/imp.py CHANGED
@@ -34,7 +34,7 @@ class IMP(BaseModel):
34
  model_path = self._download_model(
35
  repo_id=MODEL_REPO_ID,
36
  filename="{}/{}".format(
37
- Path(__file__).stem, self.conf["model_name"]
38
  ),
39
  )
40
 
 
34
  model_path = self._download_model(
35
  repo_id=MODEL_REPO_ID,
36
  filename="{}/{}".format(
37
+ 'pram', self.conf["model_name"]
38
  ),
39
  )
40
 
hloc/matchers/lightglue.py CHANGED
@@ -33,6 +33,7 @@ class LightGlue(BaseModel):
33
  ]
34
 
35
  def _init(self, conf):
 
36
  model_path = self._download_model(
37
  repo_id=MODEL_REPO_ID,
38
  filename="{}/{}".format(
 
33
  ]
34
 
35
  def _init(self, conf):
36
+ logger.info("Loading lightglue model, {}".format(conf["model_name"]))
37
  model_path = self._download_model(
38
  repo_id=MODEL_REPO_ID,
39
  filename="{}/{}".format(
third_party/gim/gim/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : xuelun
third_party/gim/{dkm β†’ gim/dkm}/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/benchmarks/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/benchmarks/hpatches_sequences_homog_benchmark.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/benchmarks/megadepth1500_benchmark.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/benchmarks/megadepth_dense_benchmark.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/benchmarks/scannet_benchmark.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/checkpointing/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/checkpointing/checkpoint.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/datasets/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/datasets/megadepth.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/datasets/scannet.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/losses/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/losses/depth_match_regression_loss.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/models/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/models/dkm.py RENAMED
@@ -5,9 +5,9 @@ from PIL import Image
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
- from dkm.utils import get_tuple_transform_ops
9
  from einops import rearrange
10
- from dkm.utils.local_correlation import local_correlation
11
 
12
 
13
  class ConvRefiner(nn.Module):
@@ -609,7 +609,7 @@ class RegressionMatcher(nn.Module):
609
  if "balanced" not in self.sample_mode:
610
  return good_matches, good_certainty
611
 
612
- from dkm.utils.kde import kde
613
  density = kde(good_matches, std=0.1)
614
  p = 1 / (density+1)
615
  p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
+ from gim.dkm.utils import get_tuple_transform_ops
9
  from einops import rearrange
10
+ from gim.dkm.utils.local_correlation import local_correlation
11
 
12
 
13
  class ConvRefiner(nn.Module):
 
609
  if "balanced" not in self.sample_mode:
610
  return good_matches, good_certainty
611
 
612
+ from gim.dkm.utils.kde import kde
613
  density = kde(good_matches, std=0.1)
614
  p = 1 / (density+1)
615
  p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
third_party/gim/{dkm β†’ gim/dkm}/models/encoders.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/models/model_zoo/DKMv3.py RENAMED
@@ -1,8 +1,8 @@
1
  import torch
2
 
3
  from torch import nn
4
- from dkm.models.dkm import *
5
- from dkm.models.encoders import *
6
 
7
 
8
  def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", **kwargs):
 
1
  import torch
2
 
3
  from torch import nn
4
+ from gim.dkm.models.dkm import *
5
+ from gim.dkm.models.encoders import *
6
 
7
 
8
  def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", **kwargs):
third_party/gim/{dkm β†’ gim/dkm}/models/model_zoo/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/train/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/train/train.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/utils/__init__.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/utils/kde.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/utils/local_correlation.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/utils/transforms.py RENAMED
File without changes
third_party/gim/{dkm β†’ gim/dkm}/utils/utils.py RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/__init__.py RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+NN.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+lightglue-official.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+lightglue_homography.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/aliked+lightglue_megadepth.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+NN.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+lightglue-official.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+lightglue_homography.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/disk+lightglue_megadepth.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+NN.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+lightglue-official.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+lightglue_homography.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/sift+lightglue_megadepth.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/superpoint+NN.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/superpoint+lightglue-official.yaml RENAMED
File without changes
third_party/gim/{gluefactory β†’ gim/gluefactory}/configs/superpoint+lightglue_homography.yaml RENAMED
File without changes