Realcat commited on
Commit
10624da
·
1 Parent(s): 4f5bf21

fix: lanet and r2d2

Browse files
hloc/matchers/imp.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
4
 
5
  import torch
6
 
7
- from .. import device, logger
8
  from ..utils.base_model import BaseModel
9
 
10
  pram_path = Path(__file__).parent / "../../third_party/pram"
@@ -34,7 +34,7 @@ class IMP(BaseModel):
34
  def _init(self, conf):
35
  self.conf = {**self.default_conf, **conf}
36
  weight_path = pram_path / "weights" / self.conf["model_name"]
37
- self.net = GML(self.conf).eval().to(device)
38
  self.net.load_state_dict(
39
  torch.load(weight_path, map_location="cpu")["model"], strict=True
40
  )
 
4
 
5
  import torch
6
 
7
+ from .. import DEVICE, logger
8
  from ..utils.base_model import BaseModel
9
 
10
  pram_path = Path(__file__).parent / "../../third_party/pram"
 
34
  def _init(self, conf):
35
  self.conf = {**self.default_conf, **conf}
36
  weight_path = pram_path / "weights" / self.conf["model_name"]
37
+ self.net = GML(self.conf).eval().to(DEVICE)
38
  self.net.load_state_dict(
39
  torch.load(weight_path, map_location="cpu")["model"], strict=True
40
  )
third_party/lanet/augmentations.py CHANGED
@@ -12,7 +12,7 @@ import torchvision
12
  import torchvision.transforms as transforms
13
  from PIL import Image
14
 
15
- from lanet_utils import image_grid
16
 
17
 
18
  def filter_dict(dict, keywords):
 
12
  import torchvision.transforms as transforms
13
  from PIL import Image
14
 
15
+ from ..lanet_utils import image_grid
16
 
17
 
18
  def filter_dict(dict, keywords):
third_party/lanet/evaluation/descriptor_evaluation.py CHANGED
@@ -8,7 +8,7 @@ from os import path as osp
8
  import cv2
9
  import numpy as np
10
 
11
- from lanet_utils import warp_keypoints
12
 
13
 
14
  def select_k_best(points, descriptors, k):
 
8
  import cv2
9
  import numpy as np
10
 
11
+ from ..lanet_utils import warp_keypoints
12
 
13
 
14
  def select_k_best(points, descriptors, k):
third_party/lanet/evaluation/detector_evaluation.py CHANGED
@@ -8,7 +8,7 @@ from os import path as osp
8
  import cv2
9
  import numpy as np
10
 
11
- from lanet_utils import warp_keypoints
12
 
13
 
14
  def compute_repeatability(data, keep_k_points=300, distance_thresh=3):
 
8
  import cv2
9
  import numpy as np
10
 
11
+ from ..lanet_utils import warp_keypoints
12
 
13
 
14
  def compute_repeatability(data, keep_k_points=300, distance_thresh=3):
third_party/lanet/network_v0/modules.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
- from lanet_utils import image_grid
6
 
7
 
8
  class ConvBlock(nn.Module):
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
+ from ..lanet_utils import image_grid
6
 
7
 
8
  class ConvBlock(nn.Module):
third_party/lanet/network_v1/modules.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
 
6
  from torchvision import models
7
- from lanet_utils import image_grid
8
 
9
 
10
  class ConvBlock(nn.Module):
 
4
  import torch.nn.functional as F
5
 
6
  from torchvision import models
7
+ from ..lanet_utils import image_grid
8
 
9
 
10
  class ConvBlock(nn.Module):
third_party/r2d2/extract.py CHANGED
@@ -8,9 +8,9 @@ from PIL import Image
8
  import numpy as np
9
  import torch
10
 
11
- from .tools import common
12
- from .tools.dataloader import norm_RGB
13
- from .nets.patchnet import *
14
 
15
 
16
  def load_network(model_fn):
 
8
  import numpy as np
9
  import torch
10
 
11
+ from tools import common
12
+ from tools.dataloader import norm_RGB
13
+ from nets.patchnet import *
14
 
15
 
16
  def load_network(model_fn):