Spaces:
Running
Running
Vincentqyw
commited on
Commit
·
9fb6531
1
Parent(s):
7e4f5b4
add: rord libs
Browse files- .gitignore +1 -3
- third_party/RoRD/lib/__init__.py +0 -0
- third_party/RoRD/lib/dataloaders/datasetPhotoTourism_combined.py +77 -0
- third_party/RoRD/lib/dataloaders/datasetPhotoTourism_ipr.py +170 -0
- third_party/RoRD/lib/dataloaders/datasetPhotoTourism_real.py +258 -0
- third_party/RoRD/lib/exceptions.py +6 -0
- third_party/RoRD/lib/extractMatchTop.py +361 -0
- third_party/RoRD/lib/loss.py +342 -0
- third_party/RoRD/lib/losses/lossPhotoTourism.py +232 -0
- third_party/RoRD/lib/model.py +121 -0
- third_party/RoRD/lib/model_test.py +187 -0
- third_party/RoRD/lib/pyramid.py +129 -0
- third_party/RoRD/lib/utils.py +167 -0
.gitignore
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
build/
|
2 |
-
|
3 |
-
lib/
|
4 |
bin/
|
5 |
-
|
6 |
cmake_modules/
|
7 |
cmake-build-debug/
|
8 |
.idea/
|
|
|
1 |
build/
|
2 |
+
# lib
|
|
|
3 |
bin/
|
|
|
4 |
cmake_modules/
|
5 |
cmake-build-debug/
|
6 |
.idea/
|
third_party/RoRD/lib/__init__.py
ADDED
File without changes
|
third_party/RoRD/lib/dataloaders/datasetPhotoTourism_combined.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import random
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
import joblib
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
|
16 |
+
from lib.utils import preprocess_image
|
17 |
+
from lib.utils import preprocess_image, grid_positions, upscale_positions
|
18 |
+
from lib.dataloaders.datasetPhotoTourism_ipr import PhotoTourismIPR
|
19 |
+
from lib.dataloaders.datasetPhotoTourism_real import PhotoTourism
|
20 |
+
|
21 |
+
from sys import exit, argv
|
22 |
+
import cv2
|
23 |
+
import csv
|
24 |
+
|
25 |
+
np.random.seed(0)
|
26 |
+
|
27 |
+
|
28 |
+
class PhotoTourismCombined(Dataset):
|
29 |
+
def __init__(self, base_path, preprocessing, ipr_pref=0.5, train=True, cropSize=256):
|
30 |
+
self.base_path = base_path
|
31 |
+
self.preprocessing = preprocessing
|
32 |
+
self.cropSize=cropSize
|
33 |
+
|
34 |
+
self.ipr_pref = ipr_pref
|
35 |
+
|
36 |
+
# self.dataset_len = 0
|
37 |
+
# self.dataset_len2 = 0
|
38 |
+
|
39 |
+
print("[INFO] Building Original Dataset")
|
40 |
+
self.PTReal = PhotoTourism(base_path, preprocessing=preprocessing, train=train, image_size=cropSize)
|
41 |
+
self.PTReal.build_dataset()
|
42 |
+
|
43 |
+
# self.dataset_len1 = len(self.PTReal)
|
44 |
+
# print("size 1:",len(self.PTReal))
|
45 |
+
# for _ in self.PTReal:
|
46 |
+
# pass
|
47 |
+
# print("size 2:",len(self.PTReal))
|
48 |
+
self.dataset_len1 = len(self.PTReal)
|
49 |
+
# joblib.dump(self.PTReal.dataset, os.path.join(self.base_path, "orig_PT_2.gz"), 3)
|
50 |
+
|
51 |
+
print("[INFO] Building IPR Dataset")
|
52 |
+
self.PTipr = PhotoTourismIPR(base_path, preprocessing=preprocessing, train=train, cropSize=cropSize)
|
53 |
+
self.PTipr.build_dataset()
|
54 |
+
|
55 |
+
# self.dataset_len2 = len(self.PTipr)
|
56 |
+
# print("size 1:",len(self.PTipr))
|
57 |
+
# for _ in self.PTipr:
|
58 |
+
# pass
|
59 |
+
# print("size 2:",len(self.PTipr))
|
60 |
+
self.dataset_len2 = len(self.PTipr)
|
61 |
+
|
62 |
+
# joblib.dump((self.PTipr.dataset_H, self.PTipr.valid_images), os.path.join(self.base_path, "ipr_PT_2.gz"), 3)
|
63 |
+
|
64 |
+
def __getitem__(self, idx):
|
65 |
+
if random.random()<self.ipr_pref:
|
66 |
+
return (self.PTipr[idx%self.dataset_len1], 1)
|
67 |
+
return (self.PTReal[idx%self.dataset_len2], 0)
|
68 |
+
|
69 |
+
def __len__(self):
|
70 |
+
return self.dataset_len2+self.dataset_len1
|
71 |
+
|
72 |
+
|
73 |
+
if __name__=="__main__":
|
74 |
+
pt = PhotoTourismCombined("/scratch/udit/phototourism/", 'caffe', 256)
|
75 |
+
dl = DataLoader(pt, batch_size=1, num_workers=2)
|
76 |
+
for _ in dl:
|
77 |
+
pass
|
third_party/RoRD/lib/dataloaders/datasetPhotoTourism_ipr.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from sys import exit, argv
|
3 |
+
import csv
|
4 |
+
import random
|
5 |
+
|
6 |
+
import joblib
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
from PIL import Image
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
|
15 |
+
from lib.utils import preprocess_image, grid_positions, upscale_positions
|
16 |
+
|
17 |
+
np.random.seed(0)
|
18 |
+
|
19 |
+
|
20 |
+
class PhotoTourismIPR(Dataset):
|
21 |
+
def __init__(self, base_path, preprocessing, train=True, cropSize=256):
|
22 |
+
self.base_path = base_path
|
23 |
+
self.train = train
|
24 |
+
self.preprocessing = preprocessing
|
25 |
+
self.valid_images = []
|
26 |
+
self.cropSize=cropSize
|
27 |
+
|
28 |
+
def getImageFiles(self):
|
29 |
+
img_files = []
|
30 |
+
img_path = "dense/images"
|
31 |
+
if self.train:
|
32 |
+
print("Inside training!!")
|
33 |
+
|
34 |
+
with open(os.path.join("configs", "train_scenes_small.txt")) as f:
|
35 |
+
scenes = f.read().strip("\n").split("\n")
|
36 |
+
|
37 |
+
print("[INFO]",scenes)
|
38 |
+
for scene in scenes:
|
39 |
+
image_dir = os.path.join(self.base_path, scene, img_path)
|
40 |
+
img_names = os.listdir(image_dir)
|
41 |
+
img_files += [os.path.join(image_dir, img) for img in img_names]
|
42 |
+
return img_files
|
43 |
+
|
44 |
+
def imgCrop(self, img1):
|
45 |
+
w, h = img1.size
|
46 |
+
left = np.random.randint(low = 0, high = w - (self.cropSize))
|
47 |
+
upper = np.random.randint(low = 0, high = h - (self.cropSize))
|
48 |
+
|
49 |
+
cropImg = img1.crop((left, upper, left+self.cropSize, upper+self.cropSize))
|
50 |
+
|
51 |
+
return cropImg
|
52 |
+
|
53 |
+
def getGrid(self, im1, im2, H, scaling_steps=3):
|
54 |
+
h1, w1 = int(im1.shape[0]/(2**scaling_steps)), int(im1.shape[1]/(2**scaling_steps))
|
55 |
+
device = torch.device("cpu")
|
56 |
+
|
57 |
+
fmap_pos1 = grid_positions(h1, w1, device)
|
58 |
+
pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps).data.cpu().numpy()
|
59 |
+
|
60 |
+
pos1[[0, 1]] = pos1[[1, 0]]
|
61 |
+
|
62 |
+
ones = np.ones((1, pos1.shape[1]))
|
63 |
+
pos1Homo = np.vstack((pos1, ones))
|
64 |
+
pos2Homo = np.dot(H, pos1Homo)
|
65 |
+
pos2Homo = pos2Homo/pos2Homo[2, :]
|
66 |
+
pos2 = pos2Homo[0:2, :]
|
67 |
+
|
68 |
+
pos1[[0, 1]] = pos1[[1, 0]]
|
69 |
+
pos2[[0, 1]] = pos2[[1, 0]]
|
70 |
+
pos1 = pos1.astype(np.float32)
|
71 |
+
pos2 = pos2.astype(np.float32)
|
72 |
+
|
73 |
+
ids = []
|
74 |
+
for i in range(pos2.shape[1]):
|
75 |
+
x, y = pos2[:, i]
|
76 |
+
|
77 |
+
if(2 < x < (im1.shape[0]-2) and 2 < y < (im1.shape[1]-2)):
|
78 |
+
ids.append(i)
|
79 |
+
pos1 = pos1[:, ids]
|
80 |
+
pos2 = pos2[:, ids]
|
81 |
+
|
82 |
+
return pos1, pos2
|
83 |
+
|
84 |
+
def imgRotH(self, img1, min=0, max=360):
|
85 |
+
width, height = img1.size
|
86 |
+
theta = np.random.randint(low=min, high=max) * (np.pi / 180)
|
87 |
+
Tx = width / 2
|
88 |
+
Ty = height / 2
|
89 |
+
sx = random.uniform(-1e-2, 1e-2)
|
90 |
+
sy = random.uniform(-1e-2, 1e-2)
|
91 |
+
p1 = random.uniform(-1e-4, 1e-4)
|
92 |
+
p2 = random.uniform(-1e-4, 1e-4)
|
93 |
+
|
94 |
+
alpha = np.cos(theta)
|
95 |
+
beta = np.sin(theta)
|
96 |
+
|
97 |
+
He = np.matrix([[alpha, beta, Tx * (1 - alpha) - Ty * beta], [-beta, alpha, beta * Tx + (1 - alpha) * Ty], [0, 0, 1]])
|
98 |
+
Ha = np.matrix([[1, sy, 0], [sx, 1, 0], [0, 0, 1]])
|
99 |
+
Hp = np.matrix([[1, 0, 0], [0, 1, 0], [p1, p2, 1]])
|
100 |
+
|
101 |
+
H = He @ Ha @ Hp
|
102 |
+
|
103 |
+
return H, theta
|
104 |
+
|
105 |
+
def build_dataset(self):
|
106 |
+
print("Building Dataset.")
|
107 |
+
|
108 |
+
imgFiles = self.getImageFiles()
|
109 |
+
|
110 |
+
for idx in tqdm(range(len(imgFiles))):
|
111 |
+
|
112 |
+
img = imgFiles[idx]
|
113 |
+
img1 = Image.open(img)
|
114 |
+
|
115 |
+
if(img1.mode != 'RGB'):
|
116 |
+
img1 = img1.convert('RGB')
|
117 |
+
if(img1.size[0] < self.cropSize or img1.size[1] < self.cropSize):
|
118 |
+
continue
|
119 |
+
|
120 |
+
self.valid_images.append(img)
|
121 |
+
|
122 |
+
def __len__(self):
|
123 |
+
return len(self.valid_images)
|
124 |
+
|
125 |
+
def __getitem__(self, idx):
|
126 |
+
while 1:
|
127 |
+
try:
|
128 |
+
img = self.valid_images[idx]
|
129 |
+
|
130 |
+
img1 = Image.open(img)
|
131 |
+
img1 = self.imgCrop(img1)
|
132 |
+
width, height = img1.size
|
133 |
+
|
134 |
+
H, theta = self.imgRotH(img1, min=0, max=360)
|
135 |
+
|
136 |
+
img1 = np.array(img1)
|
137 |
+
img2 = cv2.warpPerspective(img1, H, dsize=(width,height))
|
138 |
+
img2 = np.array(img2)
|
139 |
+
|
140 |
+
pos1, pos2 = self.getGrid(img1, img2, H)
|
141 |
+
|
142 |
+
assert (len(pos1) != 0 and len(pos2) != 0)
|
143 |
+
break
|
144 |
+
except IndexError:
|
145 |
+
print("IndexError")
|
146 |
+
exit(1)
|
147 |
+
except:
|
148 |
+
del self.valid_images[idx]
|
149 |
+
|
150 |
+
img1 = preprocess_image(img1, preprocessing=self.preprocessing)
|
151 |
+
img2 = preprocess_image(img2, preprocessing=self.preprocessing)
|
152 |
+
|
153 |
+
return {
|
154 |
+
'image1': torch.from_numpy(img1.astype(np.float32)),
|
155 |
+
'image2': torch.from_numpy(img2.astype(np.float32)),
|
156 |
+
'pos1': torch.from_numpy(pos1.astype(np.float32)),
|
157 |
+
'pos2': torch.from_numpy(pos2.astype(np.float32)),
|
158 |
+
'H': np.array(H),
|
159 |
+
'theta': np.array([theta])
|
160 |
+
}
|
161 |
+
|
162 |
+
|
163 |
+
if __name__ == '__main__':
|
164 |
+
rootDir = argv[1]
|
165 |
+
|
166 |
+
training_dataset = PhotoTourismIPR(rootDir, 'caffe')
|
167 |
+
training_dataset.build_dataset()
|
168 |
+
|
169 |
+
data = training_dataset[0]
|
170 |
+
print(data['image1'].shape, data['image2'].shape, data['pos1'].shape, data['pos2'].shape, len(training_dataset))
|
third_party/RoRD/lib/dataloaders/datasetPhotoTourism_real.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from lib.utils import preprocess_image
|
13 |
+
|
14 |
+
import joblib
|
15 |
+
|
16 |
+
|
17 |
+
class PhotoTourism(Dataset):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
#scene_list_path='megadepth_utils/train_scenes.txt',
|
21 |
+
# scene_info_path='/local/dataset/megadepth/scene_info',
|
22 |
+
base_path='/scratch/udit/phototourism',
|
23 |
+
train=True,
|
24 |
+
preprocessing=None,
|
25 |
+
min_overlap_ratio=.5,
|
26 |
+
max_overlap_ratio=1,
|
27 |
+
max_scale_ratio=np.inf,
|
28 |
+
pairs_per_scene=500,
|
29 |
+
image_size=256
|
30 |
+
):
|
31 |
+
if train:
|
32 |
+
scene_list_path = os.path.join(base_path, "train_scenes.txt.bkp")
|
33 |
+
else:
|
34 |
+
scene_list_path = os.path.join(base_path, "valid_scenes.txt")
|
35 |
+
self.scenes = []
|
36 |
+
with open(scene_list_path, 'r') as f:
|
37 |
+
lines = f.readlines()
|
38 |
+
for line in lines:
|
39 |
+
self.scenes.append(line.strip('\n'))
|
40 |
+
|
41 |
+
# self.scene_info_path = scene_info_path
|
42 |
+
self.base_path = base_path
|
43 |
+
|
44 |
+
self.train = train
|
45 |
+
|
46 |
+
self.preprocessing = preprocessing
|
47 |
+
|
48 |
+
self.min_overlap_ratio = min_overlap_ratio
|
49 |
+
self.max_overlap_ratio = max_overlap_ratio
|
50 |
+
self.max_scale_ratio = max_scale_ratio
|
51 |
+
|
52 |
+
self.pairs_per_scene = pairs_per_scene
|
53 |
+
|
54 |
+
self.image_size = image_size
|
55 |
+
|
56 |
+
self.dataset = []
|
57 |
+
|
58 |
+
def build_dataset(self):
|
59 |
+
cache_path = os.path.join(self.base_path, "orig_PT_2.gz")
|
60 |
+
if os.path.exists(cache_path):
|
61 |
+
self.dataset = joblib.load(cache_path)
|
62 |
+
return
|
63 |
+
|
64 |
+
self.dataset = []
|
65 |
+
if not self.train:
|
66 |
+
np_random_state = np.random.get_state()
|
67 |
+
np.random.seed(42)
|
68 |
+
print('Building the validation dataset...')
|
69 |
+
else:
|
70 |
+
print('Building a new training dataset...')
|
71 |
+
|
72 |
+
for scene in tqdm(self.scenes, total=len(self.scenes)):
|
73 |
+
|
74 |
+
scene_info_path = os.path.join(
|
75 |
+
self.base_path, scene, '%s.npz' % scene
|
76 |
+
)
|
77 |
+
|
78 |
+
if not os.path.exists(scene_info_path):
|
79 |
+
continue
|
80 |
+
|
81 |
+
scene_info = np.load(scene_info_path, allow_pickle=True)
|
82 |
+
overlap_matrix = scene_info['overlap_matrix']
|
83 |
+
scale_ratio_matrix = scene_info['scale_ratio_matrix']
|
84 |
+
|
85 |
+
valid = np.logical_and(
|
86 |
+
np.logical_and(
|
87 |
+
overlap_matrix >= self.min_overlap_ratio,
|
88 |
+
overlap_matrix <= self.max_overlap_ratio
|
89 |
+
),
|
90 |
+
scale_ratio_matrix <= self.max_scale_ratio
|
91 |
+
)
|
92 |
+
|
93 |
+
pairs = np.vstack(np.where(valid))
|
94 |
+
try:
|
95 |
+
selected_ids = np.random.choice(
|
96 |
+
pairs.shape[1], self.pairs_per_scene
|
97 |
+
)
|
98 |
+
except:
|
99 |
+
return
|
100 |
+
|
101 |
+
image_paths = scene_info['image_paths']
|
102 |
+
depth_paths = scene_info['depth_paths']
|
103 |
+
points3D_id_to_2D = scene_info['points3D_id_to_2D']
|
104 |
+
points3D_id_to_ndepth = scene_info['points3D_id_to_ndepth']
|
105 |
+
intrinsics = scene_info['intrinsics']
|
106 |
+
poses = scene_info['poses']
|
107 |
+
|
108 |
+
for pair_idx in selected_ids:
|
109 |
+
idx1 = pairs[0, pair_idx]
|
110 |
+
idx2 = pairs[1, pair_idx]
|
111 |
+
matches = np.array(list(
|
112 |
+
points3D_id_to_2D[idx1].keys() &
|
113 |
+
points3D_id_to_2D[idx2].keys()
|
114 |
+
))
|
115 |
+
|
116 |
+
# Scale filtering
|
117 |
+
matches_nd1 = np.array([points3D_id_to_ndepth[idx1][match] for match in matches])
|
118 |
+
matches_nd2 = np.array([points3D_id_to_ndepth[idx2][match] for match in matches])
|
119 |
+
scale_ratio = np.maximum(matches_nd1 / matches_nd2, matches_nd2 / matches_nd1)
|
120 |
+
matches = matches[np.where(scale_ratio <= self.max_scale_ratio)[0]]
|
121 |
+
|
122 |
+
point3D_id = np.random.choice(matches)
|
123 |
+
point2D1 = points3D_id_to_2D[idx1][point3D_id]
|
124 |
+
point2D2 = points3D_id_to_2D[idx2][point3D_id]
|
125 |
+
nd1 = points3D_id_to_ndepth[idx1][point3D_id]
|
126 |
+
nd2 = points3D_id_to_ndepth[idx2][point3D_id]
|
127 |
+
central_match = np.array([
|
128 |
+
point2D1[1], point2D1[0],
|
129 |
+
point2D2[1], point2D2[0]
|
130 |
+
])
|
131 |
+
self.dataset.append({
|
132 |
+
'image_path1': image_paths[idx1],
|
133 |
+
'depth_path1': depth_paths[idx1],
|
134 |
+
'intrinsics1': intrinsics[idx1],
|
135 |
+
'pose1': poses[idx1],
|
136 |
+
'image_path2': image_paths[idx2],
|
137 |
+
'depth_path2': depth_paths[idx2],
|
138 |
+
'intrinsics2': intrinsics[idx2],
|
139 |
+
'pose2': poses[idx2],
|
140 |
+
'central_match': central_match,
|
141 |
+
'scale_ratio': max(nd1 / nd2, nd2 / nd1)
|
142 |
+
})
|
143 |
+
np.random.shuffle(self.dataset)
|
144 |
+
joblib.dump(self.dataset, cache_path, 3)
|
145 |
+
if not self.train:
|
146 |
+
np.random.set_state(np_random_state)
|
147 |
+
|
148 |
+
def __len__(self):
|
149 |
+
return len(self.dataset)
|
150 |
+
|
151 |
+
def recover_pair(self, pair_metadata):
|
152 |
+
depth_path1 = os.path.join(
|
153 |
+
self.base_path, pair_metadata['depth_path1']
|
154 |
+
)
|
155 |
+
with h5py.File(depth_path1, 'r') as hdf5_file:
|
156 |
+
depth1 = np.array(hdf5_file['/depth'])
|
157 |
+
assert(np.min(depth1) >= 0)
|
158 |
+
image_path1 = os.path.join(
|
159 |
+
self.base_path, pair_metadata['image_path1']
|
160 |
+
)
|
161 |
+
image1 = Image.open(image_path1)
|
162 |
+
if image1.mode != 'RGB':
|
163 |
+
image1 = image1.convert('RGB')
|
164 |
+
image1 = np.array(image1)
|
165 |
+
assert(image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1])
|
166 |
+
intrinsics1 = pair_metadata['intrinsics1']
|
167 |
+
pose1 = pair_metadata['pose1']
|
168 |
+
|
169 |
+
depth_path2 = os.path.join(
|
170 |
+
self.base_path, pair_metadata['depth_path2']
|
171 |
+
)
|
172 |
+
with h5py.File(depth_path2, 'r') as hdf5_file:
|
173 |
+
depth2 = np.array(hdf5_file['/depth'])
|
174 |
+
assert(np.min(depth2) >= 0)
|
175 |
+
image_path2 = os.path.join(
|
176 |
+
self.base_path, pair_metadata['image_path2']
|
177 |
+
)
|
178 |
+
image2 = Image.open(image_path2)
|
179 |
+
if image2.mode != 'RGB':
|
180 |
+
image2 = image2.convert('RGB')
|
181 |
+
image2 = np.array(image2)
|
182 |
+
assert(image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1])
|
183 |
+
intrinsics2 = pair_metadata['intrinsics2']
|
184 |
+
pose2 = pair_metadata['pose2']
|
185 |
+
|
186 |
+
central_match = pair_metadata['central_match']
|
187 |
+
image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match)
|
188 |
+
|
189 |
+
depth1 = depth1[
|
190 |
+
bbox1[0] : bbox1[0] + self.image_size,
|
191 |
+
bbox1[1] : bbox1[1] + self.image_size
|
192 |
+
]
|
193 |
+
depth2 = depth2[
|
194 |
+
bbox2[0] : bbox2[0] + self.image_size,
|
195 |
+
bbox2[1] : bbox2[1] + self.image_size
|
196 |
+
]
|
197 |
+
|
198 |
+
return (
|
199 |
+
image1, depth1, intrinsics1, pose1, bbox1,
|
200 |
+
image2, depth2, intrinsics2, pose2, bbox2
|
201 |
+
)
|
202 |
+
|
203 |
+
def crop(self, image1, image2, central_match):
|
204 |
+
bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0)
|
205 |
+
if bbox1_i + self.image_size >= image1.shape[0]:
|
206 |
+
bbox1_i = image1.shape[0] - self.image_size
|
207 |
+
bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0)
|
208 |
+
if bbox1_j + self.image_size >= image1.shape[1]:
|
209 |
+
bbox1_j = image1.shape[1] - self.image_size
|
210 |
+
|
211 |
+
bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0)
|
212 |
+
if bbox2_i + self.image_size >= image2.shape[0]:
|
213 |
+
bbox2_i = image2.shape[0] - self.image_size
|
214 |
+
bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0)
|
215 |
+
if bbox2_j + self.image_size >= image2.shape[1]:
|
216 |
+
bbox2_j = image2.shape[1] - self.image_size
|
217 |
+
|
218 |
+
return (
|
219 |
+
image1[
|
220 |
+
bbox1_i : bbox1_i + self.image_size,
|
221 |
+
bbox1_j : bbox1_j + self.image_size
|
222 |
+
],
|
223 |
+
np.array([bbox1_i, bbox1_j]),
|
224 |
+
image2[
|
225 |
+
bbox2_i : bbox2_i + self.image_size,
|
226 |
+
bbox2_j : bbox2_j + self.image_size
|
227 |
+
],
|
228 |
+
np.array([bbox2_i, bbox2_j])
|
229 |
+
)
|
230 |
+
|
231 |
+
def __getitem__(self, idx):
|
232 |
+
while 1:
|
233 |
+
try:
|
234 |
+
(
|
235 |
+
image1, depth1, intrinsics1, pose1, bbox1,
|
236 |
+
image2, depth2, intrinsics2, pose2, bbox2
|
237 |
+
) = self.recover_pair(self.dataset[idx])
|
238 |
+
image1 = preprocess_image(image1, preprocessing=self.preprocessing)
|
239 |
+
image2 = preprocess_image(image2, preprocessing=self.preprocessing)
|
240 |
+
assert np.all(image1.shape==image2.shape)
|
241 |
+
break
|
242 |
+
except IndexError:
|
243 |
+
idx-=1
|
244 |
+
except:
|
245 |
+
del self.dataset[idx]
|
246 |
+
|
247 |
+
return {
|
248 |
+
'image1': torch.from_numpy(image1.astype(np.float32)),
|
249 |
+
'depth1': torch.from_numpy(depth1.astype(np.float32)),
|
250 |
+
'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32)),
|
251 |
+
'pose1': torch.from_numpy(pose1.astype(np.float32)),
|
252 |
+
'bbox1': torch.from_numpy(bbox1.astype(np.float32)),
|
253 |
+
'image2': torch.from_numpy(image2.astype(np.float32)),
|
254 |
+
'depth2': torch.from_numpy(depth2.astype(np.float32)),
|
255 |
+
'intrinsics2': torch.from_numpy(intrinsics2.astype(np.float32)),
|
256 |
+
'pose2': torch.from_numpy(pose2.astype(np.float32)),
|
257 |
+
'bbox2': torch.from_numpy(bbox2.astype(np.float32))
|
258 |
+
}
|
third_party/RoRD/lib/exceptions.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class EmptyTensorError(Exception):
|
2 |
+
pass
|
3 |
+
|
4 |
+
|
5 |
+
class NoGradientError(Exception):
|
6 |
+
pass
|
third_party/RoRD/lib/extractMatchTop.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import numpy as np
|
3 |
+
import imageio
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
import time
|
7 |
+
import scipy
|
8 |
+
import scipy.io
|
9 |
+
import scipy.misc
|
10 |
+
|
11 |
+
from lib.model_test import D2Net
|
12 |
+
from lib.utils import preprocess_image
|
13 |
+
from lib.pyramid import process_multiscale
|
14 |
+
|
15 |
+
import cv2
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
import os
|
18 |
+
from sys import exit, argv
|
19 |
+
from PIL import Image
|
20 |
+
from skimage.feature import match_descriptors
|
21 |
+
from skimage.measure import ransac
|
22 |
+
from skimage.transform import ProjectiveTransform, AffineTransform
|
23 |
+
import pydegensac
|
24 |
+
|
25 |
+
|
26 |
+
def extractSingle(image, model, device):
|
27 |
+
|
28 |
+
with torch.no_grad():
|
29 |
+
keypoints, scores, descriptors = process_multiscale(
|
30 |
+
image.to(device).unsqueeze(0),
|
31 |
+
model,
|
32 |
+
scales=[1]
|
33 |
+
)
|
34 |
+
|
35 |
+
keypoints = keypoints[:, [1, 0, 2]]
|
36 |
+
|
37 |
+
feat = {}
|
38 |
+
feat['keypoints'] = keypoints
|
39 |
+
feat['scores'] = scores
|
40 |
+
feat['descriptors'] = descriptors
|
41 |
+
|
42 |
+
return feat
|
43 |
+
|
44 |
+
|
45 |
+
def siftMatching(img1, img2, HFile1, HFile2, device):
|
46 |
+
if HFile1 is not None:
|
47 |
+
H1 = np.load(HFile1)
|
48 |
+
H2 = np.load(HFile2)
|
49 |
+
|
50 |
+
rgbFile1 = img1
|
51 |
+
img1 = Image.open(img1)
|
52 |
+
|
53 |
+
if(img1.mode != 'RGB'):
|
54 |
+
img1 = img1.convert('RGB')
|
55 |
+
img1 = np.array(img1)
|
56 |
+
|
57 |
+
if HFile1 is not None:
|
58 |
+
img1 = cv2.warpPerspective(img1, H1, dsize=(400,400))
|
59 |
+
|
60 |
+
#### Visualization ####
|
61 |
+
# cv2.imshow("Image", cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
|
62 |
+
# cv2.waitKey(0)
|
63 |
+
|
64 |
+
rgbFile2 = img2
|
65 |
+
img2 = Image.open(img2)
|
66 |
+
|
67 |
+
if(img2.mode != 'RGB'):
|
68 |
+
img2 = img2.convert('RGB')
|
69 |
+
img2 = np.array(img2)
|
70 |
+
|
71 |
+
if HFile2 is not None:
|
72 |
+
img2 = cv2.warpPerspective(img2, H2, dsize=(400,400))
|
73 |
+
|
74 |
+
#### Visualization ####
|
75 |
+
# cv2.imshow("Image", cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))
|
76 |
+
# cv2.waitKey(0)
|
77 |
+
|
78 |
+
# surf = cv2.xfeatures2d.SURF_create(100) # SURF
|
79 |
+
surf = cv2.xfeatures2d.SIFT_create()
|
80 |
+
|
81 |
+
kp1, des1 = surf.detectAndCompute(img1, None)
|
82 |
+
kp2, des2 = surf.detectAndCompute(img2, None)
|
83 |
+
|
84 |
+
matches = mnn_matcher(
|
85 |
+
torch.from_numpy(des1).float().to(device=device),
|
86 |
+
torch.from_numpy(des2).float().to(device=device)
|
87 |
+
)
|
88 |
+
|
89 |
+
src_pts = np.float32([ kp1[m[0]].pt for m in matches ]).reshape(-1, 2)
|
90 |
+
dst_pts = np.float32([ kp2[m[1]].pt for m in matches ]).reshape(-1, 2)
|
91 |
+
|
92 |
+
if(src_pts.shape[0] < 5 or dst_pts.shape[0] < 5):
|
93 |
+
return [], []
|
94 |
+
|
95 |
+
H, inliers = pydegensac.findHomography(src_pts, dst_pts, 8.0, 0.99, 10000)
|
96 |
+
|
97 |
+
n_inliers = np.sum(inliers)
|
98 |
+
|
99 |
+
inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in src_pts[inliers]]
|
100 |
+
inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in dst_pts[inliers]]
|
101 |
+
placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]
|
102 |
+
|
103 |
+
#### Visualization ####
|
104 |
+
image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None)
|
105 |
+
image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)
|
106 |
+
# cv2.imshow('Matches', image3)
|
107 |
+
# cv2.waitKey()
|
108 |
+
|
109 |
+
src_pts = np.float32([ inlier_keypoints_left[m.queryIdx].pt for m in placeholder_matches ]).reshape(-1, 2)
|
110 |
+
dst_pts = np.float32([ inlier_keypoints_right[m.trainIdx].pt for m in placeholder_matches ]).reshape(-1, 2)
|
111 |
+
|
112 |
+
if HFile1 is None:
|
113 |
+
return src_pts, dst_pts, image3, image3
|
114 |
+
|
115 |
+
orgSrc, orgDst = orgKeypoints(src_pts, dst_pts, H1, H2)
|
116 |
+
matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)
|
117 |
+
|
118 |
+
return orgSrc, orgDst, matchImg, image3
|
119 |
+
|
120 |
+
|
121 |
+
def orgKeypoints(src_pts, dst_pts, H1, H2):
|
122 |
+
ones = np.ones((src_pts.shape[0], 1))
|
123 |
+
|
124 |
+
src_pts = np.hstack((src_pts, ones))
|
125 |
+
dst_pts = np.hstack((dst_pts, ones))
|
126 |
+
|
127 |
+
orgSrc = np.linalg.inv(H1) @ src_pts.T
|
128 |
+
orgDst = np.linalg.inv(H2) @ dst_pts.T
|
129 |
+
|
130 |
+
orgSrc = orgSrc/orgSrc[2, :]
|
131 |
+
orgDst = orgDst/orgDst[2, :]
|
132 |
+
|
133 |
+
orgSrc = np.asarray(orgSrc)[0:2, :]
|
134 |
+
orgDst = np.asarray(orgDst)[0:2, :]
|
135 |
+
|
136 |
+
return orgSrc, orgDst
|
137 |
+
|
138 |
+
|
139 |
+
def drawOrg(image1, image2, orgSrc, orgDst):
|
140 |
+
img1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
|
141 |
+
img2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
|
142 |
+
|
143 |
+
for i in range(orgSrc.shape[1]):
|
144 |
+
im1 = cv2.circle(img1, (int(orgSrc[0, i]), int(orgSrc[1, i])), 3, (0, 0, 255), 1)
|
145 |
+
for i in range(orgDst.shape[1]):
|
146 |
+
im2 = cv2.circle(img2, (int(orgDst[0, i]), int(orgDst[1, i])), 3, (0, 0, 255), 1)
|
147 |
+
|
148 |
+
im4 = cv2.hconcat([im1, im2])
|
149 |
+
for i in range(orgSrc.shape[1]):
|
150 |
+
im4 = cv2.line(im4, (int(orgSrc[0, i]), int(orgSrc[1, i])), (int(orgDst[0, i]) + im1.shape[1], int(orgDst[1, i])), (0, 255, 0), 1)
|
151 |
+
im4 = cv2.cvtColor(im4, cv2.COLOR_BGR2RGB)
|
152 |
+
# cv2.imshow("Image", im4)
|
153 |
+
# cv2.waitKey(0)
|
154 |
+
|
155 |
+
return im4
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
def getPerspKeypoints(rgbFile1, rgbFile2, HFile1, HFile2, model, device):
|
160 |
+
if HFile1 is None:
|
161 |
+
igp1, img1 = read_and_process_image(rgbFile1, H=None)
|
162 |
+
else:
|
163 |
+
H1 = np.load(HFile1)
|
164 |
+
igp1, img1 = read_and_process_image(rgbFile1, H=H1)
|
165 |
+
|
166 |
+
c,h,w = igp1.shape
|
167 |
+
|
168 |
+
if HFile2 is None:
|
169 |
+
igp2, img2 = read_and_process_image(rgbFile2, H=None)
|
170 |
+
else:
|
171 |
+
H2 = np.load(HFile2)
|
172 |
+
igp2, img2 = read_and_process_image(rgbFile2, H=H2)
|
173 |
+
|
174 |
+
feat1 = extractSingle(igp1, model, device)
|
175 |
+
feat2 = extractSingle(igp2, model, device)
|
176 |
+
|
177 |
+
matches = mnn_matcher(
|
178 |
+
torch.from_numpy(feat1['descriptors']).to(device=device),
|
179 |
+
torch.from_numpy(feat2['descriptors']).to(device=device),
|
180 |
+
)
|
181 |
+
pos_a = feat1["keypoints"][matches[:, 0], : 2]
|
182 |
+
pos_b = feat2["keypoints"][matches[:, 1], : 2]
|
183 |
+
|
184 |
+
H, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
|
185 |
+
pos_a = pos_a[inliers]
|
186 |
+
pos_b = pos_b[inliers]
|
187 |
+
|
188 |
+
inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_a]
|
189 |
+
inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_b]
|
190 |
+
placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]
|
191 |
+
|
192 |
+
image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
|
193 |
+
image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)
|
194 |
+
|
195 |
+
#### Visualization ####
|
196 |
+
# cv2.imshow('Matches', image3)
|
197 |
+
# cv2.waitKey()
|
198 |
+
|
199 |
+
if HFile1 is None:
|
200 |
+
return pos_a, pos_b, image3, image3
|
201 |
+
|
202 |
+
orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
|
203 |
+
matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst) # Reproject matches to perspective View
|
204 |
+
|
205 |
+
return orgSrc, orgDst, matchImg, image3
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
###### Ensemble
|
210 |
+
def read_and_process_image(img_path, resize=None, H=None, h=None, w=None, preprocessing='caffe'):
|
211 |
+
img1 = Image.open(img_path)
|
212 |
+
if resize:
|
213 |
+
img1 = img1.resize(resize)
|
214 |
+
if(img1.mode != 'RGB'):
|
215 |
+
img1 = img1.convert('RGB')
|
216 |
+
img1 = np.array(img1)
|
217 |
+
if H is not None:
|
218 |
+
img1 = cv2.warpPerspective(img1, H, dsize=(400, 400))
|
219 |
+
# cv2.imshow("Image", cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
|
220 |
+
# cv2.waitKey(0)
|
221 |
+
igp1 = torch.from_numpy(preprocess_image(img1, preprocessing=preprocessing).astype(np.float32))
|
222 |
+
return igp1, img1
|
223 |
+
|
224 |
+
def mnn_matcher_scorer(descriptors_a, descriptors_b, k=np.inf):
|
225 |
+
device = descriptors_a.device
|
226 |
+
sim = descriptors_a @ descriptors_b.t()
|
227 |
+
val1, nn12 = torch.max(sim, dim=1)
|
228 |
+
val2, nn21 = torch.max(sim, dim=0)
|
229 |
+
ids1 = torch.arange(0, sim.shape[0], device=device)
|
230 |
+
mask = (ids1 == nn21[nn12])
|
231 |
+
matches = torch.stack([ids1[mask], nn12[mask]]).t()
|
232 |
+
remaining_matches_dist = val1[mask]
|
233 |
+
return matches, remaining_matches_dist
|
234 |
+
|
235 |
+
def mnn_matcher(descriptors_a, descriptors_b):
|
236 |
+
device = descriptors_a.device
|
237 |
+
sim = descriptors_a @ descriptors_b.t()
|
238 |
+
nn12 = torch.max(sim, dim=1)[1]
|
239 |
+
nn21 = torch.max(sim, dim=0)[1]
|
240 |
+
ids1 = torch.arange(0, sim.shape[0], device=device)
|
241 |
+
mask = (ids1 == nn21[nn12])
|
242 |
+
matches = torch.stack([ids1[mask], nn12[mask]])
|
243 |
+
return matches.t().data.cpu().numpy()
|
244 |
+
|
245 |
+
|
246 |
+
def getPerspKeypointsEnsemble(model1, model2, rgbFile1, rgbFile2, HFile1, HFile2, device):
|
247 |
+
if HFile1 is None:
|
248 |
+
igp1, img1 = read_and_process_image(rgbFile1, H=None)
|
249 |
+
else:
|
250 |
+
H1 = np.load(HFile1)
|
251 |
+
igp1, img1 = read_and_process_image(rgbFile1, H=H1)
|
252 |
+
|
253 |
+
c,h,w = igp1.shape
|
254 |
+
|
255 |
+
if HFile2 is None:
|
256 |
+
igp2, img2 = read_and_process_image(rgbFile2, H=None)
|
257 |
+
else:
|
258 |
+
H2 = np.load(HFile2)
|
259 |
+
igp2, img2 = read_and_process_image(rgbFile2, H=H2)
|
260 |
+
|
261 |
+
with torch.no_grad():
|
262 |
+
keypoints_a1, scores_a1, descriptors_a1 = process_multiscale(
|
263 |
+
igp1.to(device).unsqueeze(0),
|
264 |
+
model1,
|
265 |
+
scales=[1]
|
266 |
+
)
|
267 |
+
keypoints_a1 = keypoints_a1[:, [1, 0, 2]]
|
268 |
+
|
269 |
+
keypoints_a2, scores_a2, descriptors_a2 = process_multiscale(
|
270 |
+
igp1.to(device).unsqueeze(0),
|
271 |
+
model2,
|
272 |
+
scales=[1]
|
273 |
+
)
|
274 |
+
keypoints_a2 = keypoints_a2[:, [1, 0, 2]]
|
275 |
+
|
276 |
+
keypoints_b1, scores_b1, descriptors_b1 = process_multiscale(
|
277 |
+
igp2.to(device).unsqueeze(0),
|
278 |
+
model1,
|
279 |
+
scales=[1]
|
280 |
+
)
|
281 |
+
keypoints_b1 = keypoints_b1[:, [1, 0, 2]]
|
282 |
+
|
283 |
+
keypoints_b2, scores_b2, descriptors_b2 = process_multiscale(
|
284 |
+
igp2.to(device).unsqueeze(0),
|
285 |
+
model2,
|
286 |
+
scales=[1]
|
287 |
+
)
|
288 |
+
keypoints_b2 = keypoints_b2[:, [1, 0, 2]]
|
289 |
+
|
290 |
+
# calculating matches for both models
|
291 |
+
matches1, dist_1 = mnn_matcher_scorer(
|
292 |
+
torch.from_numpy(descriptors_a1).to(device=device),
|
293 |
+
torch.from_numpy(descriptors_b1).to(device=device),
|
294 |
+
# len(matches1)
|
295 |
+
)
|
296 |
+
matches2, dist_2 = mnn_matcher_scorer(
|
297 |
+
torch.from_numpy(descriptors_a2).to(device=device),
|
298 |
+
torch.from_numpy(descriptors_b2).to(device=device),
|
299 |
+
# len(matches1)
|
300 |
+
)
|
301 |
+
|
302 |
+
full_matches = torch.cat([matches1, matches2])
|
303 |
+
full_dist = torch.cat([dist_1, dist_2])
|
304 |
+
assert len(full_dist)==(len(dist_1)+len(dist_2)), "something wrong"
|
305 |
+
|
306 |
+
k_final = len(full_dist)//2
|
307 |
+
# k_final = len(full_dist)
|
308 |
+
# k_final = max(len(dist_1), len(dist_2))
|
309 |
+
top_k_mask = torch.topk(full_dist, k=k_final)[1]
|
310 |
+
first = []
|
311 |
+
second = []
|
312 |
+
|
313 |
+
for valid_id in top_k_mask:
|
314 |
+
if valid_id<len(dist_1):
|
315 |
+
first.append(valid_id)
|
316 |
+
else:
|
317 |
+
second.append(valid_id-len(dist_1))
|
318 |
+
# final_matches = full_matches[top_k_mask]
|
319 |
+
|
320 |
+
matches1 = matches1[torch.tensor(first, device=device).long()].data.cpu().numpy()
|
321 |
+
matches2 = matches2[torch.tensor(second, device=device).long()].data.cpu().numpy()
|
322 |
+
|
323 |
+
pos_a1 = keypoints_a1[matches1[:, 0], : 2]
|
324 |
+
pos_b1 = keypoints_b1[matches1[:, 1], : 2]
|
325 |
+
|
326 |
+
pos_a2 = keypoints_a2[matches2[:, 0], : 2]
|
327 |
+
pos_b2 = keypoints_b2[matches2[:, 1], : 2]
|
328 |
+
|
329 |
+
pos_a = np.concatenate([pos_a1, pos_a2], 0)
|
330 |
+
pos_b = np.concatenate([pos_b1, pos_b2], 0)
|
331 |
+
|
332 |
+
# pos_a, pos_b, inliers = apply_ransac(pos_a, pos_b)
|
333 |
+
H, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
|
334 |
+
pos_a = pos_a[inliers]
|
335 |
+
pos_b = pos_b[inliers]
|
336 |
+
|
337 |
+
inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_a]
|
338 |
+
inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_b]
|
339 |
+
placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]
|
340 |
+
|
341 |
+
image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
|
342 |
+
image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)
|
343 |
+
# cv2.imshow('Matches', image3)
|
344 |
+
# cv2.waitKey()
|
345 |
+
|
346 |
+
|
347 |
+
orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
|
348 |
+
matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)
|
349 |
+
|
350 |
+
return orgSrc, orgDst, matchImg, image3
|
351 |
+
|
352 |
+
|
353 |
+
if __name__ == '__main__':
|
354 |
+
WEIGHTS = '../models/rord.pth'
|
355 |
+
|
356 |
+
srcR = argv[1]
|
357 |
+
trgR = argv[2]
|
358 |
+
srcH = argv[3]
|
359 |
+
trgH = argv[4]
|
360 |
+
|
361 |
+
orgSrc, orgDst = getPerspKeypoints(srcR, trgR, srcH, trgH, WEIGHTS, ('gpu'))
|
third_party/RoRD/lib/loss.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from lib.utils import (
|
10 |
+
grid_positions,
|
11 |
+
upscale_positions,
|
12 |
+
downscale_positions,
|
13 |
+
savefig,
|
14 |
+
imshow_image
|
15 |
+
)
|
16 |
+
from lib.exceptions import NoGradientError, EmptyTensorError
|
17 |
+
|
18 |
+
matplotlib.use('Agg')
|
19 |
+
|
20 |
+
|
21 |
+
def loss_function(
|
22 |
+
model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False, plot_path=None
|
23 |
+
):
|
24 |
+
output = model({
|
25 |
+
'image1': batch['image1'].to(device),
|
26 |
+
'image2': batch['image2'].to(device)
|
27 |
+
})
|
28 |
+
|
29 |
+
loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
|
30 |
+
has_grad = False
|
31 |
+
|
32 |
+
n_valid_samples = 0
|
33 |
+
for idx_in_batch in range(batch['image1'].size(0)):
|
34 |
+
# Annotations
|
35 |
+
depth1 = batch['depth1'][idx_in_batch].to(device) # [h1, w1]
|
36 |
+
intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device) # [3, 3]
|
37 |
+
pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device) # [4, 4]
|
38 |
+
bbox1 = batch['bbox1'][idx_in_batch].to(device) # [2]
|
39 |
+
|
40 |
+
depth2 = batch['depth2'][idx_in_batch].to(device)
|
41 |
+
intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
|
42 |
+
pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
|
43 |
+
bbox2 = batch['bbox2'][idx_in_batch].to(device)
|
44 |
+
|
45 |
+
# Network output
|
46 |
+
dense_features1 = output['dense_features1'][idx_in_batch]
|
47 |
+
c, h1, w1 = dense_features1.size()
|
48 |
+
scores1 = output['scores1'][idx_in_batch].view(-1)
|
49 |
+
|
50 |
+
dense_features2 = output['dense_features2'][idx_in_batch]
|
51 |
+
_, h2, w2 = dense_features2.size()
|
52 |
+
scores2 = output['scores2'][idx_in_batch]
|
53 |
+
|
54 |
+
all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
|
55 |
+
descriptors1 = all_descriptors1
|
56 |
+
|
57 |
+
all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)
|
58 |
+
|
59 |
+
# Warp the positions from image 1 to image 2
|
60 |
+
fmap_pos1 = grid_positions(h1, w1, device)
|
61 |
+
pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
|
62 |
+
try:
|
63 |
+
pos1, pos2, ids = warp(
|
64 |
+
pos1,
|
65 |
+
depth1, intrinsics1, pose1, bbox1,
|
66 |
+
depth2, intrinsics2, pose2, bbox2
|
67 |
+
)
|
68 |
+
except EmptyTensorError:
|
69 |
+
continue
|
70 |
+
fmap_pos1 = fmap_pos1[:, ids]
|
71 |
+
descriptors1 = descriptors1[:, ids]
|
72 |
+
scores1 = scores1[ids]
|
73 |
+
|
74 |
+
# Skip the pair if not enough GT correspondences are available
|
75 |
+
if ids.size(0) < 128:
|
76 |
+
continue
|
77 |
+
|
78 |
+
# Descriptors at the corresponding positions
|
79 |
+
fmap_pos2 = torch.round(
|
80 |
+
downscale_positions(pos2, scaling_steps=scaling_steps)
|
81 |
+
).long()
|
82 |
+
descriptors2 = F.normalize(
|
83 |
+
dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
|
84 |
+
dim=0
|
85 |
+
)
|
86 |
+
positive_distance = 2 - 2 * (
|
87 |
+
descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
|
88 |
+
).squeeze()
|
89 |
+
|
90 |
+
all_fmap_pos2 = grid_positions(h2, w2, device)
|
91 |
+
position_distance = torch.max(
|
92 |
+
torch.abs(
|
93 |
+
fmap_pos2.unsqueeze(2).float() -
|
94 |
+
all_fmap_pos2.unsqueeze(1)
|
95 |
+
),
|
96 |
+
dim=0
|
97 |
+
)[0]
|
98 |
+
is_out_of_safe_radius = position_distance > safe_radius
|
99 |
+
distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
|
100 |
+
negative_distance2 = torch.min(
|
101 |
+
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
|
102 |
+
dim=1
|
103 |
+
)[0]
|
104 |
+
|
105 |
+
all_fmap_pos1 = grid_positions(h1, w1, device)
|
106 |
+
position_distance = torch.max(
|
107 |
+
torch.abs(
|
108 |
+
fmap_pos1.unsqueeze(2).float() -
|
109 |
+
all_fmap_pos1.unsqueeze(1)
|
110 |
+
),
|
111 |
+
dim=0
|
112 |
+
)[0]
|
113 |
+
is_out_of_safe_radius = position_distance > safe_radius
|
114 |
+
distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
|
115 |
+
negative_distance1 = torch.min(
|
116 |
+
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
|
117 |
+
dim=1
|
118 |
+
)[0]
|
119 |
+
|
120 |
+
diff = positive_distance - torch.min(
|
121 |
+
negative_distance1, negative_distance2
|
122 |
+
)
|
123 |
+
|
124 |
+
scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]
|
125 |
+
|
126 |
+
loss = loss + (
|
127 |
+
torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
|
128 |
+
torch.sum(scores1 * scores2)
|
129 |
+
)
|
130 |
+
|
131 |
+
has_grad = True
|
132 |
+
n_valid_samples += 1
|
133 |
+
|
134 |
+
# print(plot, batch['batch_idx'],batch['log_interval'])
|
135 |
+
if plot and batch['batch_idx'] % batch['log_interval'] == 0:
|
136 |
+
# print("should plot")
|
137 |
+
pos1_aux = pos1.cpu().numpy()
|
138 |
+
pos2_aux = pos2.cpu().numpy()
|
139 |
+
k = pos1_aux.shape[1]
|
140 |
+
col = np.random.rand(k, 3)
|
141 |
+
n_sp = 4
|
142 |
+
plt.figure()
|
143 |
+
plt.subplot(1, n_sp, 1)
|
144 |
+
im1 = imshow_image(
|
145 |
+
batch['image1'][idx_in_batch].cpu().numpy(),
|
146 |
+
preprocessing=batch['preprocessing']
|
147 |
+
)
|
148 |
+
plt.imshow(im1)
|
149 |
+
plt.scatter(
|
150 |
+
pos1_aux[1, :], pos1_aux[0, :],
|
151 |
+
s=0.25**2, c=col, marker=',', alpha=0.5
|
152 |
+
)
|
153 |
+
plt.axis('off')
|
154 |
+
plt.subplot(1, n_sp, 2)
|
155 |
+
plt.imshow(
|
156 |
+
output['scores1'][idx_in_batch].data.cpu().numpy(),
|
157 |
+
cmap='Reds'
|
158 |
+
)
|
159 |
+
plt.axis('off')
|
160 |
+
plt.subplot(1, n_sp, 3)
|
161 |
+
im2 = imshow_image(
|
162 |
+
batch['image2'][idx_in_batch].cpu().numpy(),
|
163 |
+
preprocessing=batch['preprocessing']
|
164 |
+
)
|
165 |
+
plt.imshow(im2)
|
166 |
+
plt.scatter(
|
167 |
+
pos2_aux[1, :], pos2_aux[0, :],
|
168 |
+
s=0.25**2, c=col, marker=',', alpha=0.5
|
169 |
+
)
|
170 |
+
plt.axis('off')
|
171 |
+
plt.subplot(1, n_sp, 4)
|
172 |
+
plt.imshow(
|
173 |
+
output['scores2'][idx_in_batch].data.cpu().numpy(),
|
174 |
+
cmap='Reds'
|
175 |
+
)
|
176 |
+
plt.axis('off')
|
177 |
+
savefig(os.path.join(plot_path, '%s.%02d.%02d.%d.png' % (
|
178 |
+
'train' if batch['train'] else 'valid',
|
179 |
+
batch['epoch_idx'],
|
180 |
+
batch['batch_idx'] // batch['log_interval'],
|
181 |
+
idx_in_batch
|
182 |
+
)), dpi=300)
|
183 |
+
plt.close()
|
184 |
+
|
185 |
+
if not has_grad:
|
186 |
+
raise NoGradientError
|
187 |
+
|
188 |
+
loss = loss / n_valid_samples
|
189 |
+
|
190 |
+
return loss
|
191 |
+
|
192 |
+
|
193 |
+
def interpolate_depth(pos, depth):
|
194 |
+
device = pos.device
|
195 |
+
|
196 |
+
ids = torch.arange(0, pos.size(1), device=device)
|
197 |
+
|
198 |
+
h, w = depth.size()
|
199 |
+
|
200 |
+
i = pos[0, :]
|
201 |
+
j = pos[1, :]
|
202 |
+
|
203 |
+
# Valid corners
|
204 |
+
i_top_left = torch.floor(i).long()
|
205 |
+
j_top_left = torch.floor(j).long()
|
206 |
+
valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
|
207 |
+
|
208 |
+
i_top_right = torch.floor(i).long()
|
209 |
+
j_top_right = torch.ceil(j).long()
|
210 |
+
valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
|
211 |
+
|
212 |
+
i_bottom_left = torch.ceil(i).long()
|
213 |
+
j_bottom_left = torch.floor(j).long()
|
214 |
+
valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
|
215 |
+
|
216 |
+
i_bottom_right = torch.ceil(i).long()
|
217 |
+
j_bottom_right = torch.ceil(j).long()
|
218 |
+
valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
|
219 |
+
|
220 |
+
valid_corners = torch.min(
|
221 |
+
torch.min(valid_top_left, valid_top_right),
|
222 |
+
torch.min(valid_bottom_left, valid_bottom_right)
|
223 |
+
)
|
224 |
+
|
225 |
+
i_top_left = i_top_left[valid_corners]
|
226 |
+
j_top_left = j_top_left[valid_corners]
|
227 |
+
|
228 |
+
i_top_right = i_top_right[valid_corners]
|
229 |
+
j_top_right = j_top_right[valid_corners]
|
230 |
+
|
231 |
+
i_bottom_left = i_bottom_left[valid_corners]
|
232 |
+
j_bottom_left = j_bottom_left[valid_corners]
|
233 |
+
|
234 |
+
i_bottom_right = i_bottom_right[valid_corners]
|
235 |
+
j_bottom_right = j_bottom_right[valid_corners]
|
236 |
+
|
237 |
+
ids = ids[valid_corners]
|
238 |
+
if ids.size(0) == 0:
|
239 |
+
raise EmptyTensorError
|
240 |
+
|
241 |
+
# Valid depth
|
242 |
+
valid_depth = torch.min(
|
243 |
+
torch.min(
|
244 |
+
depth[i_top_left, j_top_left] > 0,
|
245 |
+
depth[i_top_right, j_top_right] > 0
|
246 |
+
),
|
247 |
+
torch.min(
|
248 |
+
depth[i_bottom_left, j_bottom_left] > 0,
|
249 |
+
depth[i_bottom_right, j_bottom_right] > 0
|
250 |
+
)
|
251 |
+
)
|
252 |
+
|
253 |
+
i_top_left = i_top_left[valid_depth]
|
254 |
+
j_top_left = j_top_left[valid_depth]
|
255 |
+
|
256 |
+
i_top_right = i_top_right[valid_depth]
|
257 |
+
j_top_right = j_top_right[valid_depth]
|
258 |
+
|
259 |
+
i_bottom_left = i_bottom_left[valid_depth]
|
260 |
+
j_bottom_left = j_bottom_left[valid_depth]
|
261 |
+
|
262 |
+
i_bottom_right = i_bottom_right[valid_depth]
|
263 |
+
j_bottom_right = j_bottom_right[valid_depth]
|
264 |
+
|
265 |
+
ids = ids[valid_depth]
|
266 |
+
if ids.size(0) == 0:
|
267 |
+
raise EmptyTensorError
|
268 |
+
|
269 |
+
# Interpolation
|
270 |
+
i = i[ids]
|
271 |
+
j = j[ids]
|
272 |
+
dist_i_top_left = i - i_top_left.float()
|
273 |
+
dist_j_top_left = j - j_top_left.float()
|
274 |
+
w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
|
275 |
+
w_top_right = (1 - dist_i_top_left) * dist_j_top_left
|
276 |
+
w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
|
277 |
+
w_bottom_right = dist_i_top_left * dist_j_top_left
|
278 |
+
|
279 |
+
interpolated_depth = (
|
280 |
+
w_top_left * depth[i_top_left, j_top_left] +
|
281 |
+
w_top_right * depth[i_top_right, j_top_right] +
|
282 |
+
w_bottom_left * depth[i_bottom_left, j_bottom_left] +
|
283 |
+
w_bottom_right * depth[i_bottom_right, j_bottom_right]
|
284 |
+
)
|
285 |
+
|
286 |
+
pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
|
287 |
+
|
288 |
+
return [interpolated_depth, pos, ids]
|
289 |
+
|
290 |
+
|
291 |
+
def uv_to_pos(uv):
|
292 |
+
return torch.cat([uv[1, :].view(1, -1), uv[0, :].view(1, -1)], dim=0)
|
293 |
+
|
294 |
+
|
295 |
+
def warp(
|
296 |
+
pos1,
|
297 |
+
depth1, intrinsics1, pose1, bbox1,
|
298 |
+
depth2, intrinsics2, pose2, bbox2
|
299 |
+
):
|
300 |
+
device = pos1.device
|
301 |
+
|
302 |
+
Z1, pos1, ids = interpolate_depth(pos1, depth1)
|
303 |
+
|
304 |
+
# COLMAP convention
|
305 |
+
u1 = pos1[1, :] + bbox1[1] + .5
|
306 |
+
v1 = pos1[0, :] + bbox1[0] + .5
|
307 |
+
|
308 |
+
X1 = (u1 - intrinsics1[0, 2]) * (Z1 / intrinsics1[0, 0])
|
309 |
+
Y1 = (v1 - intrinsics1[1, 2]) * (Z1 / intrinsics1[1, 1])
|
310 |
+
|
311 |
+
XYZ1_hom = torch.cat([
|
312 |
+
X1.view(1, -1),
|
313 |
+
Y1.view(1, -1),
|
314 |
+
Z1.view(1, -1),
|
315 |
+
torch.ones(1, Z1.size(0), device=device)
|
316 |
+
], dim=0)
|
317 |
+
XYZ2_hom = torch.chain_matmul(pose2, torch.inverse(pose1), XYZ1_hom)
|
318 |
+
XYZ2 = XYZ2_hom[: -1, :] / XYZ2_hom[-1, :].view(1, -1)
|
319 |
+
|
320 |
+
uv2_hom = torch.matmul(intrinsics2, XYZ2)
|
321 |
+
uv2 = uv2_hom[: -1, :] / uv2_hom[-1, :].view(1, -1)
|
322 |
+
|
323 |
+
u2 = uv2[0, :] - bbox2[1] - .5
|
324 |
+
v2 = uv2[1, :] - bbox2[0] - .5
|
325 |
+
uv2 = torch.cat([u2.view(1, -1), v2.view(1, -1)], dim=0)
|
326 |
+
|
327 |
+
annotated_depth, pos2, new_ids = interpolate_depth(uv_to_pos(uv2), depth2)
|
328 |
+
|
329 |
+
ids = ids[new_ids]
|
330 |
+
pos1 = pos1[:, new_ids]
|
331 |
+
estimated_depth = XYZ2[2, new_ids]
|
332 |
+
|
333 |
+
inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05
|
334 |
+
|
335 |
+
ids = ids[inlier_mask]
|
336 |
+
if ids.size(0) == 0:
|
337 |
+
raise EmptyTensorError
|
338 |
+
|
339 |
+
pos2 = pos2[:, inlier_mask]
|
340 |
+
pos1 = pos1[:, inlier_mask]
|
341 |
+
|
342 |
+
return pos1, pos2, ids
|
third_party/RoRD/lib/losses/lossPhotoTourism.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
from sys import exit
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from lib.utils import (
|
12 |
+
grid_positions,
|
13 |
+
upscale_positions,
|
14 |
+
downscale_positions,
|
15 |
+
savefig,
|
16 |
+
imshow_image
|
17 |
+
)
|
18 |
+
from lib.exceptions import NoGradientError, EmptyTensorError
|
19 |
+
|
20 |
+
matplotlib.use('Agg')
|
21 |
+
|
22 |
+
|
23 |
+
def loss_function(
|
24 |
+
model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False, plot_path=None
|
25 |
+
):
|
26 |
+
output = model({
|
27 |
+
'image1': batch['image1'].to(device),
|
28 |
+
'image2': batch['image2'].to(device)
|
29 |
+
})
|
30 |
+
|
31 |
+
|
32 |
+
loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
|
33 |
+
has_grad = False
|
34 |
+
|
35 |
+
n_valid_samples = 0
|
36 |
+
for idx_in_batch in range(batch['image1'].size(0)):
|
37 |
+
# Network output
|
38 |
+
dense_features1 = output['dense_features1'][idx_in_batch]
|
39 |
+
c, h1, w1 = dense_features1.size()
|
40 |
+
scores1 = output['scores1'][idx_in_batch].view(-1)
|
41 |
+
|
42 |
+
dense_features2 = output['dense_features2'][idx_in_batch]
|
43 |
+
_, h2, w2 = dense_features2.size()
|
44 |
+
scores2 = output['scores2'][idx_in_batch]
|
45 |
+
|
46 |
+
all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
|
47 |
+
descriptors1 = all_descriptors1
|
48 |
+
|
49 |
+
all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)
|
50 |
+
|
51 |
+
fmap_pos1 = grid_positions(h1, w1, device)
|
52 |
+
|
53 |
+
pos1 = batch['pos1'][idx_in_batch].to(device)
|
54 |
+
pos2 = batch['pos2'][idx_in_batch].to(device)
|
55 |
+
|
56 |
+
ids = idsAlign(pos1, device, h1, w1)
|
57 |
+
|
58 |
+
fmap_pos1 = fmap_pos1[:, ids]
|
59 |
+
descriptors1 = descriptors1[:, ids]
|
60 |
+
scores1 = scores1[ids]
|
61 |
+
|
62 |
+
# Skip the pair if not enough GT correspondences are available
|
63 |
+
if ids.size(0) < 128:
|
64 |
+
continue
|
65 |
+
|
66 |
+
# Descriptors at the corresponding positions
|
67 |
+
fmap_pos2 = torch.round(
|
68 |
+
downscale_positions(pos2, scaling_steps=scaling_steps)
|
69 |
+
).long()
|
70 |
+
|
71 |
+
descriptors2 = F.normalize(
|
72 |
+
dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
|
73 |
+
dim=0
|
74 |
+
)
|
75 |
+
positive_distance = 2 - 2 * (
|
76 |
+
descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
|
77 |
+
).squeeze()
|
78 |
+
|
79 |
+
all_fmap_pos2 = grid_positions(h2, w2, device)
|
80 |
+
position_distance = torch.max(
|
81 |
+
torch.abs(
|
82 |
+
fmap_pos2.unsqueeze(2).float() -
|
83 |
+
all_fmap_pos2.unsqueeze(1)
|
84 |
+
),
|
85 |
+
dim=0
|
86 |
+
)[0]
|
87 |
+
is_out_of_safe_radius = position_distance > safe_radius
|
88 |
+
|
89 |
+
distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
|
90 |
+
|
91 |
+
negative_distance2 = torch.min(
|
92 |
+
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
|
93 |
+
dim=1
|
94 |
+
)[0]
|
95 |
+
|
96 |
+
all_fmap_pos1 = grid_positions(h1, w1, device)
|
97 |
+
position_distance = torch.max(
|
98 |
+
torch.abs(
|
99 |
+
fmap_pos1.unsqueeze(2).float() -
|
100 |
+
all_fmap_pos1.unsqueeze(1)
|
101 |
+
),
|
102 |
+
dim=0
|
103 |
+
)[0]
|
104 |
+
is_out_of_safe_radius = position_distance > safe_radius
|
105 |
+
|
106 |
+
distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
|
107 |
+
|
108 |
+
negative_distance1 = torch.min(
|
109 |
+
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
|
110 |
+
dim=1
|
111 |
+
)[0]
|
112 |
+
|
113 |
+
diff = positive_distance - torch.min(
|
114 |
+
negative_distance1, negative_distance2
|
115 |
+
)
|
116 |
+
|
117 |
+
scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]
|
118 |
+
|
119 |
+
loss = loss + (
|
120 |
+
torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
|
121 |
+
(torch.sum(scores1 * scores2) )
|
122 |
+
)
|
123 |
+
|
124 |
+
has_grad = True
|
125 |
+
n_valid_samples += 1
|
126 |
+
|
127 |
+
if plot and batch['batch_idx'] % batch['log_interval'] == 0:
|
128 |
+
drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=True, plot_path=plot_path)
|
129 |
+
|
130 |
+
if not has_grad:
|
131 |
+
raise NoGradientError
|
132 |
+
|
133 |
+
loss = loss / (n_valid_samples )
|
134 |
+
|
135 |
+
return loss
|
136 |
+
|
137 |
+
|
138 |
+
def idsAlign(pos1, device, h1, w1):
|
139 |
+
pos1D = downscale_positions(pos1, scaling_steps=3)
|
140 |
+
row = pos1D[0, :]
|
141 |
+
col = pos1D[1, :]
|
142 |
+
|
143 |
+
ids = []
|
144 |
+
|
145 |
+
for i in range(row.shape[0]):
|
146 |
+
|
147 |
+
index = ((w1) * (row[i])) + (col[i])
|
148 |
+
ids.append(index)
|
149 |
+
|
150 |
+
ids = torch.round(torch.Tensor(ids)).long().to(device)
|
151 |
+
|
152 |
+
return ids
|
153 |
+
|
154 |
+
|
155 |
+
def drawTraining(image1, image2, pos1, pos2, batch, idx_in_batch, output, save=False, plot_path="train_viz"):
|
156 |
+
pos1_aux = pos1.cpu().numpy()
|
157 |
+
pos2_aux = pos2.cpu().numpy()
|
158 |
+
|
159 |
+
k = pos1_aux.shape[1]
|
160 |
+
col = np.random.rand(k, 3)
|
161 |
+
n_sp = 4
|
162 |
+
plt.figure()
|
163 |
+
plt.subplot(1, n_sp, 1)
|
164 |
+
im1 = imshow_image(
|
165 |
+
image1[0].cpu().numpy(),
|
166 |
+
preprocessing=batch['preprocessing']
|
167 |
+
)
|
168 |
+
plt.imshow(im1)
|
169 |
+
plt.scatter(
|
170 |
+
pos1_aux[1, :], pos1_aux[0, :],
|
171 |
+
s=0.25**2, c=col, marker=',', alpha=0.5
|
172 |
+
)
|
173 |
+
plt.axis('off')
|
174 |
+
plt.subplot(1, n_sp, 2)
|
175 |
+
plt.imshow(
|
176 |
+
output['scores1'][idx_in_batch].data.cpu().numpy(),
|
177 |
+
cmap='Reds'
|
178 |
+
)
|
179 |
+
plt.axis('off')
|
180 |
+
plt.subplot(1, n_sp, 3)
|
181 |
+
im2 = imshow_image(
|
182 |
+
image2[0].cpu().numpy(),
|
183 |
+
preprocessing=batch['preprocessing']
|
184 |
+
)
|
185 |
+
plt.imshow(im2)
|
186 |
+
plt.scatter(
|
187 |
+
pos2_aux[1, :], pos2_aux[0, :],
|
188 |
+
s=0.25**2, c=col, marker=',', alpha=0.5
|
189 |
+
)
|
190 |
+
plt.axis('off')
|
191 |
+
plt.subplot(1, n_sp, 4)
|
192 |
+
plt.imshow(
|
193 |
+
output['scores2'][idx_in_batch].data.cpu().numpy(),
|
194 |
+
cmap='Reds'
|
195 |
+
)
|
196 |
+
plt.axis('off')
|
197 |
+
|
198 |
+
if(save == True):
|
199 |
+
savefig(plot_path+'/%s.%02d.%02d.%d.png' % (
|
200 |
+
'train' if batch['train'] else 'valid',
|
201 |
+
batch['epoch_idx'],
|
202 |
+
batch['batch_idx'] // batch['log_interval'],
|
203 |
+
idx_in_batch
|
204 |
+
), dpi=300)
|
205 |
+
else:
|
206 |
+
plt.show()
|
207 |
+
|
208 |
+
plt.close()
|
209 |
+
|
210 |
+
im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB)
|
211 |
+
im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB)
|
212 |
+
|
213 |
+
for i in range(0, pos1_aux.shape[1], 5):
|
214 |
+
im1 = cv2.circle(im1, (pos1_aux[1, i], pos1_aux[0, i]), 1, (0, 0, 255), 2)
|
215 |
+
for i in range(0, pos2_aux.shape[1], 5):
|
216 |
+
im2 = cv2.circle(im2, (pos2_aux[1, i], pos2_aux[0, i]), 1, (0, 0, 255), 2)
|
217 |
+
|
218 |
+
im3 = cv2.hconcat([im1, im2])
|
219 |
+
|
220 |
+
for i in range(0, pos1_aux.shape[1], 5):
|
221 |
+
im3 = cv2.line(im3, (int(pos1_aux[1, i]), int(pos1_aux[0, i])), (int(pos2_aux[1, i]) + im1.shape[1], int(pos2_aux[0, i])), (0, 255, 0), 1)
|
222 |
+
|
223 |
+
if(save == True):
|
224 |
+
cv2.imwrite(plot_path+'/%s.%02d.%02d.%d.png' % (
|
225 |
+
'train_corr' if batch['train'] else 'valid',
|
226 |
+
batch['epoch_idx'],
|
227 |
+
batch['batch_idx'] // batch['log_interval'],
|
228 |
+
idx_in_batch
|
229 |
+
), im3)
|
230 |
+
else:
|
231 |
+
cv2.imshow('Image', im3)
|
232 |
+
cv2.waitKey(0)
|
third_party/RoRD/lib/model.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import torchvision.models as models
|
6 |
+
|
7 |
+
|
8 |
+
class DenseFeatureExtractionModule(nn.Module):
|
9 |
+
def __init__(self, finetune_feature_extraction=False, use_cuda=True):
|
10 |
+
super(DenseFeatureExtractionModule, self).__init__()
|
11 |
+
|
12 |
+
model = models.vgg16()
|
13 |
+
vgg16_layers = [
|
14 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
|
15 |
+
'pool1',
|
16 |
+
'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
|
17 |
+
'pool2',
|
18 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
|
19 |
+
'pool3',
|
20 |
+
'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
|
21 |
+
'pool4',
|
22 |
+
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
|
23 |
+
'pool5'
|
24 |
+
]
|
25 |
+
conv4_3_idx = vgg16_layers.index('conv4_3')
|
26 |
+
|
27 |
+
self.model = nn.Sequential(
|
28 |
+
*list(model.features.children())[: conv4_3_idx + 1]
|
29 |
+
)
|
30 |
+
|
31 |
+
self.num_channels = 512
|
32 |
+
|
33 |
+
# Fix forward parameters
|
34 |
+
for param in self.model.parameters():
|
35 |
+
param.requires_grad = False
|
36 |
+
if finetune_feature_extraction:
|
37 |
+
# Unlock conv4_3
|
38 |
+
for param in list(self.model.parameters())[-2 :]:
|
39 |
+
param.requires_grad = True
|
40 |
+
|
41 |
+
if use_cuda:
|
42 |
+
self.model = self.model.cuda()
|
43 |
+
|
44 |
+
def forward(self, batch):
|
45 |
+
output = self.model(batch)
|
46 |
+
return output
|
47 |
+
|
48 |
+
|
49 |
+
class SoftDetectionModule(nn.Module):
|
50 |
+
def __init__(self, soft_local_max_size=3):
|
51 |
+
super(SoftDetectionModule, self).__init__()
|
52 |
+
|
53 |
+
self.soft_local_max_size = soft_local_max_size
|
54 |
+
|
55 |
+
self.pad = self.soft_local_max_size // 2
|
56 |
+
|
57 |
+
def forward(self, batch):
|
58 |
+
b = batch.size(0)
|
59 |
+
|
60 |
+
batch = F.relu(batch)
|
61 |
+
|
62 |
+
max_per_sample = torch.max(batch.view(b, -1), dim=1)[0]
|
63 |
+
exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1))
|
64 |
+
sum_exp = (
|
65 |
+
self.soft_local_max_size ** 2 *
|
66 |
+
F.avg_pool2d(
|
67 |
+
F.pad(exp, [self.pad] * 4, mode='constant', value=1.),
|
68 |
+
self.soft_local_max_size, stride=1
|
69 |
+
)
|
70 |
+
)
|
71 |
+
local_max_score = exp / sum_exp
|
72 |
+
|
73 |
+
depth_wise_max = torch.max(batch, dim=1)[0]
|
74 |
+
depth_wise_max_score = batch / depth_wise_max.unsqueeze(1)
|
75 |
+
|
76 |
+
all_scores = local_max_score * depth_wise_max_score
|
77 |
+
score = torch.max(all_scores, dim=1)[0]
|
78 |
+
|
79 |
+
score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1)
|
80 |
+
|
81 |
+
return score
|
82 |
+
|
83 |
+
|
84 |
+
class D2Net(nn.Module):
|
85 |
+
def __init__(self, model_file=None, use_cuda=True):
|
86 |
+
super(D2Net, self).__init__()
|
87 |
+
|
88 |
+
self.dense_feature_extraction = DenseFeatureExtractionModule(
|
89 |
+
finetune_feature_extraction=True,
|
90 |
+
use_cuda=use_cuda
|
91 |
+
)
|
92 |
+
|
93 |
+
self.detection = SoftDetectionModule()
|
94 |
+
|
95 |
+
if model_file is not None:
|
96 |
+
if use_cuda:
|
97 |
+
self.load_state_dict(torch.load(model_file)['model'])
|
98 |
+
else:
|
99 |
+
self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
|
100 |
+
|
101 |
+
def forward(self, batch):
|
102 |
+
b = batch['image1'].size(0)
|
103 |
+
|
104 |
+
dense_features = self.dense_feature_extraction(
|
105 |
+
torch.cat([batch['image1'], batch['image2']], dim=0)
|
106 |
+
)
|
107 |
+
|
108 |
+
scores = self.detection(dense_features)
|
109 |
+
|
110 |
+
dense_features1 = dense_features[: b, :, :, :]
|
111 |
+
dense_features2 = dense_features[b :, :, :, :]
|
112 |
+
|
113 |
+
scores1 = scores[: b, :, :]
|
114 |
+
scores2 = scores[b :, :, :]
|
115 |
+
|
116 |
+
return {
|
117 |
+
'dense_features1': dense_features1,
|
118 |
+
'scores1': scores1,
|
119 |
+
'dense_features2': dense_features2,
|
120 |
+
'scores2': scores2
|
121 |
+
}
|
third_party/RoRD/lib/model_test.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class DenseFeatureExtractionModule(nn.Module):
|
7 |
+
def __init__(self, use_relu=True, use_cuda=True):
|
8 |
+
super(DenseFeatureExtractionModule, self).__init__()
|
9 |
+
|
10 |
+
self.model = nn.Sequential(
|
11 |
+
nn.Conv2d(3, 64, 3, padding=1),
|
12 |
+
nn.ReLU(inplace=True),
|
13 |
+
nn.Conv2d(64, 64, 3, padding=1),
|
14 |
+
nn.ReLU(inplace=True),
|
15 |
+
nn.MaxPool2d(2, stride=2),
|
16 |
+
nn.Conv2d(64, 128, 3, padding=1),
|
17 |
+
nn.ReLU(inplace=True),
|
18 |
+
nn.Conv2d(128, 128, 3, padding=1),
|
19 |
+
nn.ReLU(inplace=True),
|
20 |
+
nn.MaxPool2d(2, stride=2),
|
21 |
+
nn.Conv2d(128, 256, 3, padding=1),
|
22 |
+
nn.ReLU(inplace=True),
|
23 |
+
nn.Conv2d(256, 256, 3, padding=1),
|
24 |
+
nn.ReLU(inplace=True),
|
25 |
+
nn.Conv2d(256, 256, 3, padding=1),
|
26 |
+
nn.ReLU(inplace=True),
|
27 |
+
nn.AvgPool2d(2, stride=1),
|
28 |
+
nn.Conv2d(256, 512, 3, padding=2, dilation=2),
|
29 |
+
nn.ReLU(inplace=True),
|
30 |
+
nn.Conv2d(512, 512, 3, padding=2, dilation=2),
|
31 |
+
nn.ReLU(inplace=True),
|
32 |
+
nn.Conv2d(512, 512, 3, padding=2, dilation=2),
|
33 |
+
)
|
34 |
+
self.num_channels = 512
|
35 |
+
|
36 |
+
self.use_relu = use_relu
|
37 |
+
|
38 |
+
if use_cuda:
|
39 |
+
self.model = self.model.cuda()
|
40 |
+
|
41 |
+
def forward(self, batch):
|
42 |
+
output = self.model(batch)
|
43 |
+
if self.use_relu:
|
44 |
+
output = F.relu(output)
|
45 |
+
return output
|
46 |
+
|
47 |
+
|
48 |
+
class D2Net(nn.Module):
|
49 |
+
def __init__(self, model_file=None, use_relu=True, use_cuda=False):
|
50 |
+
super(D2Net, self).__init__()
|
51 |
+
|
52 |
+
self.dense_feature_extraction = DenseFeatureExtractionModule(
|
53 |
+
use_relu=use_relu, use_cuda=use_cuda
|
54 |
+
)
|
55 |
+
|
56 |
+
self.detection = HardDetectionModule()
|
57 |
+
|
58 |
+
self.localization = HandcraftedLocalizationModule()
|
59 |
+
|
60 |
+
if model_file is not None:
|
61 |
+
if use_cuda:
|
62 |
+
self.load_state_dict(torch.load(model_file)['model'])
|
63 |
+
else:
|
64 |
+
self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
|
65 |
+
|
66 |
+
def forward(self, batch):
|
67 |
+
_, _, h, w = batch.size()
|
68 |
+
dense_features = self.dense_feature_extraction(batch)
|
69 |
+
|
70 |
+
detections = self.detection(dense_features)
|
71 |
+
|
72 |
+
displacements = self.localization(dense_features)
|
73 |
+
|
74 |
+
return {
|
75 |
+
'dense_features': dense_features,
|
76 |
+
'detections': detections,
|
77 |
+
'displacements': displacements
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
class HardDetectionModule(nn.Module):
|
82 |
+
def __init__(self, edge_threshold=5):
|
83 |
+
super(HardDetectionModule, self).__init__()
|
84 |
+
|
85 |
+
self.edge_threshold = edge_threshold
|
86 |
+
|
87 |
+
self.dii_filter = torch.tensor(
|
88 |
+
[[0, 1., 0], [0, -2., 0], [0, 1., 0]]
|
89 |
+
).view(1, 1, 3, 3)
|
90 |
+
self.dij_filter = 0.25 * torch.tensor(
|
91 |
+
[[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
|
92 |
+
).view(1, 1, 3, 3)
|
93 |
+
self.djj_filter = torch.tensor(
|
94 |
+
[[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
|
95 |
+
).view(1, 1, 3, 3)
|
96 |
+
|
97 |
+
def forward(self, batch):
|
98 |
+
b, c, h, w = batch.size()
|
99 |
+
device = batch.device
|
100 |
+
|
101 |
+
depth_wise_max = torch.max(batch, dim=1)[0]
|
102 |
+
is_depth_wise_max = (batch == depth_wise_max)
|
103 |
+
del depth_wise_max
|
104 |
+
|
105 |
+
local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
|
106 |
+
is_local_max = (batch == local_max)
|
107 |
+
del local_max
|
108 |
+
|
109 |
+
dii = F.conv2d(
|
110 |
+
batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
|
111 |
+
).view(b, c, h, w)
|
112 |
+
dij = F.conv2d(
|
113 |
+
batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
|
114 |
+
).view(b, c, h, w)
|
115 |
+
djj = F.conv2d(
|
116 |
+
batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
|
117 |
+
).view(b, c, h, w)
|
118 |
+
|
119 |
+
det = dii * djj - dij * dij
|
120 |
+
tr = dii + djj
|
121 |
+
del dii, dij, djj
|
122 |
+
|
123 |
+
threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
|
124 |
+
is_not_edge = torch.min(tr * tr / det <= threshold, det > 0)
|
125 |
+
|
126 |
+
detected = torch.min(
|
127 |
+
is_depth_wise_max,
|
128 |
+
torch.min(is_local_max, is_not_edge)
|
129 |
+
)
|
130 |
+
del is_depth_wise_max, is_local_max, is_not_edge
|
131 |
+
|
132 |
+
return detected
|
133 |
+
|
134 |
+
|
135 |
+
class HandcraftedLocalizationModule(nn.Module):
|
136 |
+
def __init__(self):
|
137 |
+
super(HandcraftedLocalizationModule, self).__init__()
|
138 |
+
|
139 |
+
self.di_filter = torch.tensor(
|
140 |
+
[[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
|
141 |
+
).view(1, 1, 3, 3)
|
142 |
+
self.dj_filter = torch.tensor(
|
143 |
+
[[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
|
144 |
+
).view(1, 1, 3, 3)
|
145 |
+
|
146 |
+
self.dii_filter = torch.tensor(
|
147 |
+
[[0, 1., 0], [0, -2., 0], [0, 1., 0]]
|
148 |
+
).view(1, 1, 3, 3)
|
149 |
+
self.dij_filter = 0.25 * torch.tensor(
|
150 |
+
[[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
|
151 |
+
).view(1, 1, 3, 3)
|
152 |
+
self.djj_filter = torch.tensor(
|
153 |
+
[[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
|
154 |
+
).view(1, 1, 3, 3)
|
155 |
+
|
156 |
+
def forward(self, batch):
|
157 |
+
b, c, h, w = batch.size()
|
158 |
+
device = batch.device
|
159 |
+
|
160 |
+
dii = F.conv2d(
|
161 |
+
batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
|
162 |
+
).view(b, c, h, w)
|
163 |
+
dij = F.conv2d(
|
164 |
+
batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
|
165 |
+
).view(b, c, h, w)
|
166 |
+
djj = F.conv2d(
|
167 |
+
batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
|
168 |
+
).view(b, c, h, w)
|
169 |
+
det = dii * djj - dij * dij
|
170 |
+
|
171 |
+
inv_hess_00 = djj / det
|
172 |
+
inv_hess_01 = -dij / det
|
173 |
+
inv_hess_11 = dii / det
|
174 |
+
del dii, dij, djj, det
|
175 |
+
|
176 |
+
di = F.conv2d(
|
177 |
+
batch.view(-1, 1, h, w), self.di_filter.to(device), padding=1
|
178 |
+
).view(b, c, h, w)
|
179 |
+
dj = F.conv2d(
|
180 |
+
batch.view(-1, 1, h, w), self.dj_filter.to(device), padding=1
|
181 |
+
).view(b, c, h, w)
|
182 |
+
|
183 |
+
step_i = -(inv_hess_00 * di + inv_hess_01 * dj)
|
184 |
+
step_j = -(inv_hess_01 * di + inv_hess_11 * dj)
|
185 |
+
del inv_hess_00, inv_hess_01, inv_hess_11, di, dj
|
186 |
+
|
187 |
+
return torch.stack([step_i, step_j], dim=1)
|
third_party/RoRD/lib/pyramid.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from lib.exceptions import EmptyTensorError
|
6 |
+
from lib.utils import interpolate_dense_features, upscale_positions
|
7 |
+
|
8 |
+
|
9 |
+
def process_multiscale(image, model, scales=[.5, 1, 2]):
|
10 |
+
b, _, h_init, w_init = image.size()
|
11 |
+
device = image.device
|
12 |
+
assert(b == 1)
|
13 |
+
|
14 |
+
all_keypoints = torch.zeros([3, 0])
|
15 |
+
all_descriptors = torch.zeros([
|
16 |
+
model.dense_feature_extraction.num_channels, 0
|
17 |
+
])
|
18 |
+
all_scores = torch.zeros(0)
|
19 |
+
|
20 |
+
previous_dense_features = None
|
21 |
+
banned = None
|
22 |
+
for idx, scale in enumerate(scales):
|
23 |
+
current_image = F.interpolate(
|
24 |
+
image, scale_factor=scale,
|
25 |
+
mode='bilinear', align_corners=True
|
26 |
+
)
|
27 |
+
_, _, h_level, w_level = current_image.size()
|
28 |
+
|
29 |
+
dense_features = model.dense_feature_extraction(current_image)
|
30 |
+
del current_image
|
31 |
+
|
32 |
+
_, _, h, w = dense_features.size()
|
33 |
+
|
34 |
+
# Sum the feature maps.
|
35 |
+
if previous_dense_features is not None:
|
36 |
+
dense_features += F.interpolate(
|
37 |
+
previous_dense_features, size=[h, w],
|
38 |
+
mode='bilinear', align_corners=True
|
39 |
+
)
|
40 |
+
del previous_dense_features
|
41 |
+
|
42 |
+
# Recover detections.
|
43 |
+
detections = model.detection(dense_features)
|
44 |
+
if banned is not None:
|
45 |
+
banned = F.interpolate(banned.float(), size=[h, w]).bool()
|
46 |
+
detections = torch.min(detections, ~banned)
|
47 |
+
banned = torch.max(
|
48 |
+
torch.max(detections, dim=1)[0].unsqueeze(1), banned
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
banned = torch.max(detections, dim=1)[0].unsqueeze(1)
|
52 |
+
fmap_pos = torch.nonzero(detections[0].cpu()).t()
|
53 |
+
del detections
|
54 |
+
|
55 |
+
# Recover displacements.
|
56 |
+
displacements = model.localization(dense_features)[0].cpu()
|
57 |
+
displacements_i = displacements[
|
58 |
+
0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
|
59 |
+
]
|
60 |
+
displacements_j = displacements[
|
61 |
+
1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
|
62 |
+
]
|
63 |
+
del displacements
|
64 |
+
|
65 |
+
mask = torch.min(
|
66 |
+
torch.abs(displacements_i) < 0.5,
|
67 |
+
torch.abs(displacements_j) < 0.5
|
68 |
+
)
|
69 |
+
fmap_pos = fmap_pos[:, mask]
|
70 |
+
valid_displacements = torch.stack([
|
71 |
+
displacements_i[mask],
|
72 |
+
displacements_j[mask]
|
73 |
+
], dim=0)
|
74 |
+
del mask, displacements_i, displacements_j
|
75 |
+
|
76 |
+
fmap_keypoints = fmap_pos[1 :, :].float() + valid_displacements
|
77 |
+
del valid_displacements
|
78 |
+
|
79 |
+
try:
|
80 |
+
raw_descriptors, _, ids = interpolate_dense_features(
|
81 |
+
fmap_keypoints.to(device),
|
82 |
+
dense_features[0]
|
83 |
+
)
|
84 |
+
except EmptyTensorError:
|
85 |
+
continue
|
86 |
+
fmap_pos = fmap_pos.to(device)
|
87 |
+
fmap_keypoints = fmap_keypoints.to(device)
|
88 |
+
fmap_pos = fmap_pos[:, ids]
|
89 |
+
fmap_keypoints = fmap_keypoints[:, ids]
|
90 |
+
del ids
|
91 |
+
|
92 |
+
keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
|
93 |
+
del fmap_keypoints
|
94 |
+
|
95 |
+
descriptors = F.normalize(raw_descriptors, dim=0).cpu()
|
96 |
+
del raw_descriptors
|
97 |
+
|
98 |
+
keypoints[0, :] *= h_init / h_level
|
99 |
+
keypoints[1, :] *= w_init / w_level
|
100 |
+
|
101 |
+
fmap_pos = fmap_pos.cpu()
|
102 |
+
keypoints = keypoints.cpu()
|
103 |
+
|
104 |
+
keypoints = torch.cat([
|
105 |
+
keypoints,
|
106 |
+
torch.ones([1, keypoints.size(1)]) * 1 / scale,
|
107 |
+
], dim=0)
|
108 |
+
|
109 |
+
scores = dense_features[
|
110 |
+
0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
|
111 |
+
].cpu() / (idx + 1)
|
112 |
+
del fmap_pos
|
113 |
+
|
114 |
+
all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
|
115 |
+
all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
|
116 |
+
all_scores = torch.cat([all_scores, scores], dim=0)
|
117 |
+
del keypoints, descriptors
|
118 |
+
|
119 |
+
previous_dense_features = dense_features
|
120 |
+
del dense_features
|
121 |
+
del previous_dense_features, banned
|
122 |
+
|
123 |
+
keypoints = all_keypoints.t().detach().numpy()
|
124 |
+
del all_keypoints
|
125 |
+
scores = all_scores.detach().numpy()
|
126 |
+
del all_scores
|
127 |
+
descriptors = all_descriptors.t().detach().numpy()
|
128 |
+
del all_descriptors
|
129 |
+
return keypoints, scores, descriptors
|
third_party/RoRD/lib/utils.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from lib.exceptions import EmptyTensorError
|
8 |
+
|
9 |
+
|
10 |
+
def preprocess_image(image, preprocessing=None):
|
11 |
+
image = image.astype(np.float32)
|
12 |
+
image = np.transpose(image, [2, 0, 1])
|
13 |
+
if preprocessing is None:
|
14 |
+
pass
|
15 |
+
elif preprocessing == 'caffe':
|
16 |
+
# RGB -> BGR
|
17 |
+
image = image[:: -1, :, :]
|
18 |
+
# Zero-center by mean pixel
|
19 |
+
mean = np.array([103.939, 116.779, 123.68])
|
20 |
+
image = image - mean.reshape([3, 1, 1])
|
21 |
+
elif preprocessing == 'torch':
|
22 |
+
image /= 255.0
|
23 |
+
mean = np.array([0.485, 0.456, 0.406])
|
24 |
+
std = np.array([0.229, 0.224, 0.225])
|
25 |
+
image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1])
|
26 |
+
else:
|
27 |
+
raise ValueError('Unknown preprocessing parameter.')
|
28 |
+
return image
|
29 |
+
|
30 |
+
|
31 |
+
def imshow_image(image, preprocessing=None):
|
32 |
+
if preprocessing is None:
|
33 |
+
pass
|
34 |
+
elif preprocessing == 'caffe':
|
35 |
+
mean = np.array([103.939, 116.779, 123.68])
|
36 |
+
image = image + mean.reshape([3, 1, 1])
|
37 |
+
# RGB -> BGR
|
38 |
+
image = image[:: -1, :, :]
|
39 |
+
elif preprocessing == 'torch':
|
40 |
+
mean = np.array([0.485, 0.456, 0.406])
|
41 |
+
std = np.array([0.229, 0.224, 0.225])
|
42 |
+
image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1])
|
43 |
+
image *= 255.0
|
44 |
+
else:
|
45 |
+
raise ValueError('Unknown preprocessing parameter.')
|
46 |
+
image = np.transpose(image, [1, 2, 0])
|
47 |
+
image = np.round(image).astype(np.uint8)
|
48 |
+
return image
|
49 |
+
|
50 |
+
|
51 |
+
def grid_positions(h, w, device, matrix=False):
|
52 |
+
lines = torch.arange(
|
53 |
+
0, h, device=device
|
54 |
+
).view(-1, 1).float().repeat(1, w)
|
55 |
+
columns = torch.arange(
|
56 |
+
0, w, device=device
|
57 |
+
).view(1, -1).float().repeat(h, 1)
|
58 |
+
if matrix:
|
59 |
+
return torch.stack([lines, columns], dim=0)
|
60 |
+
else:
|
61 |
+
return torch.cat([lines.view(1, -1), columns.view(1, -1)], dim=0)
|
62 |
+
|
63 |
+
|
64 |
+
def upscale_positions(pos, scaling_steps=0):
|
65 |
+
for _ in range(scaling_steps):
|
66 |
+
pos = pos * 2 + 0.5
|
67 |
+
return pos
|
68 |
+
|
69 |
+
|
70 |
+
def downscale_positions(pos, scaling_steps=0):
|
71 |
+
for _ in range(scaling_steps):
|
72 |
+
pos = (pos - 0.5) / 2
|
73 |
+
return pos
|
74 |
+
|
75 |
+
|
76 |
+
def interpolate_dense_features(pos, dense_features, return_corners=False):
|
77 |
+
device = pos.device
|
78 |
+
|
79 |
+
ids = torch.arange(0, pos.size(1), device=device)
|
80 |
+
|
81 |
+
_, h, w = dense_features.size()
|
82 |
+
|
83 |
+
i = pos[0, :]
|
84 |
+
j = pos[1, :]
|
85 |
+
|
86 |
+
# Valid corners
|
87 |
+
i_top_left = torch.floor(i).long()
|
88 |
+
j_top_left = torch.floor(j).long()
|
89 |
+
valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
|
90 |
+
|
91 |
+
i_top_right = torch.floor(i).long()
|
92 |
+
j_top_right = torch.ceil(j).long()
|
93 |
+
valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
|
94 |
+
|
95 |
+
i_bottom_left = torch.ceil(i).long()
|
96 |
+
j_bottom_left = torch.floor(j).long()
|
97 |
+
valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
|
98 |
+
|
99 |
+
i_bottom_right = torch.ceil(i).long()
|
100 |
+
j_bottom_right = torch.ceil(j).long()
|
101 |
+
valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
|
102 |
+
|
103 |
+
valid_corners = torch.min(
|
104 |
+
torch.min(valid_top_left, valid_top_right),
|
105 |
+
torch.min(valid_bottom_left, valid_bottom_right)
|
106 |
+
)
|
107 |
+
|
108 |
+
i_top_left = i_top_left[valid_corners]
|
109 |
+
j_top_left = j_top_left[valid_corners]
|
110 |
+
|
111 |
+
i_top_right = i_top_right[valid_corners]
|
112 |
+
j_top_right = j_top_right[valid_corners]
|
113 |
+
|
114 |
+
i_bottom_left = i_bottom_left[valid_corners]
|
115 |
+
j_bottom_left = j_bottom_left[valid_corners]
|
116 |
+
|
117 |
+
i_bottom_right = i_bottom_right[valid_corners]
|
118 |
+
j_bottom_right = j_bottom_right[valid_corners]
|
119 |
+
|
120 |
+
ids = ids[valid_corners]
|
121 |
+
if ids.size(0) == 0:
|
122 |
+
raise EmptyTensorError
|
123 |
+
|
124 |
+
# Interpolation
|
125 |
+
i = i[ids]
|
126 |
+
j = j[ids]
|
127 |
+
dist_i_top_left = i - i_top_left.float()
|
128 |
+
dist_j_top_left = j - j_top_left.float()
|
129 |
+
w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
|
130 |
+
w_top_right = (1 - dist_i_top_left) * dist_j_top_left
|
131 |
+
w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
|
132 |
+
w_bottom_right = dist_i_top_left * dist_j_top_left
|
133 |
+
|
134 |
+
descriptors = (
|
135 |
+
w_top_left * dense_features[:, i_top_left, j_top_left] +
|
136 |
+
w_top_right * dense_features[:, i_top_right, j_top_right] +
|
137 |
+
w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] +
|
138 |
+
w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right]
|
139 |
+
)
|
140 |
+
|
141 |
+
pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
|
142 |
+
|
143 |
+
if not return_corners:
|
144 |
+
return [descriptors, pos, ids]
|
145 |
+
else:
|
146 |
+
corners = torch.stack([
|
147 |
+
torch.stack([i_top_left, j_top_left], dim=0),
|
148 |
+
torch.stack([i_top_right, j_top_right], dim=0),
|
149 |
+
torch.stack([i_bottom_left, j_bottom_left], dim=0),
|
150 |
+
torch.stack([i_bottom_right, j_bottom_right], dim=0)
|
151 |
+
], dim=0)
|
152 |
+
return [descriptors, pos, ids, corners]
|
153 |
+
|
154 |
+
|
155 |
+
def savefig(filepath, fig=None, dpi=None):
|
156 |
+
# TomNorway - https://stackoverflow.com/a/53516034
|
157 |
+
if not fig:
|
158 |
+
fig = plt.gcf()
|
159 |
+
|
160 |
+
plt.subplots_adjust(0, 0, 1, 1, 0, 0)
|
161 |
+
for ax in fig.axes:
|
162 |
+
ax.axis('off')
|
163 |
+
ax.margins(0, 0)
|
164 |
+
ax.xaxis.set_major_locator(plt.NullLocator())
|
165 |
+
ax.yaxis.set_major_locator(plt.NullLocator())
|
166 |
+
|
167 |
+
fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi)
|