Spaces:
Running
Running
Realcat
commited on
Commit
·
10624da
1
Parent(s):
4f5bf21
fix: lanet and r2d2
Browse files- hloc/matchers/imp.py +2 -2
- third_party/lanet/augmentations.py +1 -1
- third_party/lanet/evaluation/descriptor_evaluation.py +1 -1
- third_party/lanet/evaluation/detector_evaluation.py +1 -1
- third_party/lanet/network_v0/modules.py +1 -1
- third_party/lanet/network_v1/modules.py +1 -1
- third_party/r2d2/extract.py +3 -3
hloc/matchers/imp.py
CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
|
|
4 |
|
5 |
import torch
|
6 |
|
7 |
-
from .. import
|
8 |
from ..utils.base_model import BaseModel
|
9 |
|
10 |
pram_path = Path(__file__).parent / "../../third_party/pram"
|
@@ -34,7 +34,7 @@ class IMP(BaseModel):
|
|
34 |
def _init(self, conf):
|
35 |
self.conf = {**self.default_conf, **conf}
|
36 |
weight_path = pram_path / "weights" / self.conf["model_name"]
|
37 |
-
self.net = GML(self.conf).eval().to(
|
38 |
self.net.load_state_dict(
|
39 |
torch.load(weight_path, map_location="cpu")["model"], strict=True
|
40 |
)
|
|
|
4 |
|
5 |
import torch
|
6 |
|
7 |
+
from .. import DEVICE, logger
|
8 |
from ..utils.base_model import BaseModel
|
9 |
|
10 |
pram_path = Path(__file__).parent / "../../third_party/pram"
|
|
|
34 |
def _init(self, conf):
|
35 |
self.conf = {**self.default_conf, **conf}
|
36 |
weight_path = pram_path / "weights" / self.conf["model_name"]
|
37 |
+
self.net = GML(self.conf).eval().to(DEVICE)
|
38 |
self.net.load_state_dict(
|
39 |
torch.load(weight_path, map_location="cpu")["model"], strict=True
|
40 |
)
|
third_party/lanet/augmentations.py
CHANGED
@@ -12,7 +12,7 @@ import torchvision
|
|
12 |
import torchvision.transforms as transforms
|
13 |
from PIL import Image
|
14 |
|
15 |
-
from lanet_utils import image_grid
|
16 |
|
17 |
|
18 |
def filter_dict(dict, keywords):
|
|
|
12 |
import torchvision.transforms as transforms
|
13 |
from PIL import Image
|
14 |
|
15 |
+
from ..lanet_utils import image_grid
|
16 |
|
17 |
|
18 |
def filter_dict(dict, keywords):
|
third_party/lanet/evaluation/descriptor_evaluation.py
CHANGED
@@ -8,7 +8,7 @@ from os import path as osp
|
|
8 |
import cv2
|
9 |
import numpy as np
|
10 |
|
11 |
-
from lanet_utils import warp_keypoints
|
12 |
|
13 |
|
14 |
def select_k_best(points, descriptors, k):
|
|
|
8 |
import cv2
|
9 |
import numpy as np
|
10 |
|
11 |
+
from ..lanet_utils import warp_keypoints
|
12 |
|
13 |
|
14 |
def select_k_best(points, descriptors, k):
|
third_party/lanet/evaluation/detector_evaluation.py
CHANGED
@@ -8,7 +8,7 @@ from os import path as osp
|
|
8 |
import cv2
|
9 |
import numpy as np
|
10 |
|
11 |
-
from lanet_utils import warp_keypoints
|
12 |
|
13 |
|
14 |
def compute_repeatability(data, keep_k_points=300, distance_thresh=3):
|
|
|
8 |
import cv2
|
9 |
import numpy as np
|
10 |
|
11 |
+
from ..lanet_utils import warp_keypoints
|
12 |
|
13 |
|
14 |
def compute_repeatability(data, keep_k_points=300, distance_thresh=3):
|
third_party/lanet/network_v0/modules.py
CHANGED
@@ -2,7 +2,7 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
-
from lanet_utils import image_grid
|
6 |
|
7 |
|
8 |
class ConvBlock(nn.Module):
|
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
+
from ..lanet_utils import image_grid
|
6 |
|
7 |
|
8 |
class ConvBlock(nn.Module):
|
third_party/lanet/network_v1/modules.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
|
|
4 |
import torch.nn.functional as F
|
5 |
|
6 |
from torchvision import models
|
7 |
-
from lanet_utils import image_grid
|
8 |
|
9 |
|
10 |
class ConvBlock(nn.Module):
|
|
|
4 |
import torch.nn.functional as F
|
5 |
|
6 |
from torchvision import models
|
7 |
+
from ..lanet_utils import image_grid
|
8 |
|
9 |
|
10 |
class ConvBlock(nn.Module):
|
third_party/r2d2/extract.py
CHANGED
@@ -8,9 +8,9 @@ from PIL import Image
|
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
|
11 |
-
from
|
12 |
-
from
|
13 |
-
from
|
14 |
|
15 |
|
16 |
def load_network(model_fn):
|
|
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
|
11 |
+
from tools import common
|
12 |
+
from tools.dataloader import norm_RGB
|
13 |
+
from nets.patchnet import *
|
14 |
|
15 |
|
16 |
def load_network(model_fn):
|