Realcat commited on
Commit
8c0ddef
·
1 Parent(s): b864970

update: d2net lib

Browse files
third_party/d2net/lib/dataset.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+
3
+ import numpy as np
4
+
5
+ from PIL import Image
6
+
7
+ import os
8
+
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+
12
+ import time
13
+
14
+ from tqdm import tqdm
15
+
16
+ from lib.utils import preprocess_image
17
+
18
+
19
+ class MegaDepthDataset(Dataset):
20
+ def __init__(
21
+ self,
22
+ scene_list_path='megadepth_utils/train_scenes.txt',
23
+ scene_info_path='/local/dataset/megadepth/scene_info',
24
+ base_path='/local/dataset/megadepth',
25
+ train=True,
26
+ preprocessing=None,
27
+ min_overlap_ratio=.5,
28
+ max_overlap_ratio=1,
29
+ max_scale_ratio=np.inf,
30
+ pairs_per_scene=100,
31
+ image_size=256
32
+ ):
33
+ self.scenes = []
34
+ with open(scene_list_path, 'r') as f:
35
+ lines = f.readlines()
36
+ for line in lines:
37
+ self.scenes.append(line.strip('\n'))
38
+
39
+ self.scene_info_path = scene_info_path
40
+ self.base_path = base_path
41
+
42
+ self.train = train
43
+
44
+ self.preprocessing = preprocessing
45
+
46
+ self.min_overlap_ratio = min_overlap_ratio
47
+ self.max_overlap_ratio = max_overlap_ratio
48
+ self.max_scale_ratio = max_scale_ratio
49
+
50
+ self.pairs_per_scene = pairs_per_scene
51
+
52
+ self.image_size = image_size
53
+
54
+ self.dataset = []
55
+
56
+ def build_dataset(self):
57
+ self.dataset = []
58
+ if not self.train:
59
+ np_random_state = np.random.get_state()
60
+ np.random.seed(42)
61
+ print('Building the validation dataset...')
62
+ else:
63
+ print('Building a new training dataset...')
64
+ for scene in tqdm(self.scenes, total=len(self.scenes)):
65
+ scene_info_path = os.path.join(
66
+ self.scene_info_path, '%s.npz' % scene
67
+ )
68
+ if not os.path.exists(scene_info_path):
69
+ continue
70
+ scene_info = np.load(scene_info_path, allow_pickle=True)
71
+ overlap_matrix = scene_info['overlap_matrix']
72
+ scale_ratio_matrix = scene_info['scale_ratio_matrix']
73
+
74
+ valid = np.logical_and(
75
+ np.logical_and(
76
+ overlap_matrix >= self.min_overlap_ratio,
77
+ overlap_matrix <= self.max_overlap_ratio
78
+ ),
79
+ scale_ratio_matrix <= self.max_scale_ratio
80
+ )
81
+
82
+ pairs = np.vstack(np.where(valid))
83
+ try:
84
+ selected_ids = np.random.choice(
85
+ pairs.shape[1], self.pairs_per_scene
86
+ )
87
+ except:
88
+ continue
89
+
90
+ image_paths = scene_info['image_paths']
91
+ depth_paths = scene_info['depth_paths']
92
+ points3D_id_to_2D = scene_info['points3D_id_to_2D']
93
+ points3D_id_to_ndepth = scene_info['points3D_id_to_ndepth']
94
+ intrinsics = scene_info['intrinsics']
95
+ poses = scene_info['poses']
96
+
97
+ for pair_idx in selected_ids:
98
+ idx1 = pairs[0, pair_idx]
99
+ idx2 = pairs[1, pair_idx]
100
+ matches = np.array(list(
101
+ points3D_id_to_2D[idx1].keys() &
102
+ points3D_id_to_2D[idx2].keys()
103
+ ))
104
+
105
+ # Scale filtering
106
+ matches_nd1 = np.array([points3D_id_to_ndepth[idx1][match] for match in matches])
107
+ matches_nd2 = np.array([points3D_id_to_ndepth[idx2][match] for match in matches])
108
+ scale_ratio = np.maximum(matches_nd1 / matches_nd2, matches_nd2 / matches_nd1)
109
+ matches = matches[np.where(scale_ratio <= self.max_scale_ratio)[0]]
110
+
111
+ point3D_id = np.random.choice(matches)
112
+ point2D1 = points3D_id_to_2D[idx1][point3D_id]
113
+ point2D2 = points3D_id_to_2D[idx2][point3D_id]
114
+ nd1 = points3D_id_to_ndepth[idx1][point3D_id]
115
+ nd2 = points3D_id_to_ndepth[idx2][point3D_id]
116
+ central_match = np.array([
117
+ point2D1[1], point2D1[0],
118
+ point2D2[1], point2D2[0]
119
+ ])
120
+ self.dataset.append({
121
+ 'image_path1': image_paths[idx1],
122
+ 'depth_path1': depth_paths[idx1],
123
+ 'intrinsics1': intrinsics[idx1],
124
+ 'pose1': poses[idx1],
125
+ 'image_path2': image_paths[idx2],
126
+ 'depth_path2': depth_paths[idx2],
127
+ 'intrinsics2': intrinsics[idx2],
128
+ 'pose2': poses[idx2],
129
+ 'central_match': central_match,
130
+ 'scale_ratio': max(nd1 / nd2, nd2 / nd1)
131
+ })
132
+ np.random.shuffle(self.dataset)
133
+ if not self.train:
134
+ np.random.set_state(np_random_state)
135
+
136
+ def __len__(self):
137
+ return len(self.dataset)
138
+
139
+ def recover_pair(self, pair_metadata):
140
+ depth_path1 = os.path.join(
141
+ self.base_path, pair_metadata['depth_path1']
142
+ )
143
+ with h5py.File(depth_path1, 'r') as hdf5_file:
144
+ depth1 = np.array(hdf5_file['/depth'])
145
+ assert(np.min(depth1) >= 0)
146
+ image_path1 = os.path.join(
147
+ self.base_path, pair_metadata['image_path1']
148
+ )
149
+ image1 = Image.open(image_path1)
150
+ if image1.mode != 'RGB':
151
+ image1 = image1.convert('RGB')
152
+ image1 = np.array(image1)
153
+ assert(image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1])
154
+ intrinsics1 = pair_metadata['intrinsics1']
155
+ pose1 = pair_metadata['pose1']
156
+
157
+ depth_path2 = os.path.join(
158
+ self.base_path, pair_metadata['depth_path2']
159
+ )
160
+ with h5py.File(depth_path2, 'r') as hdf5_file:
161
+ depth2 = np.array(hdf5_file['/depth'])
162
+ assert(np.min(depth2) >= 0)
163
+ image_path2 = os.path.join(
164
+ self.base_path, pair_metadata['image_path2']
165
+ )
166
+ image2 = Image.open(image_path2)
167
+ if image2.mode != 'RGB':
168
+ image2 = image2.convert('RGB')
169
+ image2 = np.array(image2)
170
+ assert(image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1])
171
+ intrinsics2 = pair_metadata['intrinsics2']
172
+ pose2 = pair_metadata['pose2']
173
+
174
+ central_match = pair_metadata['central_match']
175
+ image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match)
176
+
177
+ depth1 = depth1[
178
+ bbox1[0] : bbox1[0] + self.image_size,
179
+ bbox1[1] : bbox1[1] + self.image_size
180
+ ]
181
+ depth2 = depth2[
182
+ bbox2[0] : bbox2[0] + self.image_size,
183
+ bbox2[1] : bbox2[1] + self.image_size
184
+ ]
185
+
186
+ return (
187
+ image1, depth1, intrinsics1, pose1, bbox1,
188
+ image2, depth2, intrinsics2, pose2, bbox2
189
+ )
190
+
191
+ def crop(self, image1, image2, central_match):
192
+ bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0)
193
+ if bbox1_i + self.image_size >= image1.shape[0]:
194
+ bbox1_i = image1.shape[0] - self.image_size
195
+ bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0)
196
+ if bbox1_j + self.image_size >= image1.shape[1]:
197
+ bbox1_j = image1.shape[1] - self.image_size
198
+
199
+ bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0)
200
+ if bbox2_i + self.image_size >= image2.shape[0]:
201
+ bbox2_i = image2.shape[0] - self.image_size
202
+ bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0)
203
+ if bbox2_j + self.image_size >= image2.shape[1]:
204
+ bbox2_j = image2.shape[1] - self.image_size
205
+
206
+ return (
207
+ image1[
208
+ bbox1_i : bbox1_i + self.image_size,
209
+ bbox1_j : bbox1_j + self.image_size
210
+ ],
211
+ np.array([bbox1_i, bbox1_j]),
212
+ image2[
213
+ bbox2_i : bbox2_i + self.image_size,
214
+ bbox2_j : bbox2_j + self.image_size
215
+ ],
216
+ np.array([bbox2_i, bbox2_j])
217
+ )
218
+
219
+ def __getitem__(self, idx):
220
+ (
221
+ image1, depth1, intrinsics1, pose1, bbox1,
222
+ image2, depth2, intrinsics2, pose2, bbox2
223
+ ) = self.recover_pair(self.dataset[idx])
224
+
225
+ image1 = preprocess_image(image1, preprocessing=self.preprocessing)
226
+ image2 = preprocess_image(image2, preprocessing=self.preprocessing)
227
+
228
+ return {
229
+ 'image1': torch.from_numpy(image1.astype(np.float32)),
230
+ 'depth1': torch.from_numpy(depth1.astype(np.float32)),
231
+ 'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32)),
232
+ 'pose1': torch.from_numpy(pose1.astype(np.float32)),
233
+ 'bbox1': torch.from_numpy(bbox1.astype(np.float32)),
234
+ 'image2': torch.from_numpy(image2.astype(np.float32)),
235
+ 'depth2': torch.from_numpy(depth2.astype(np.float32)),
236
+ 'intrinsics2': torch.from_numpy(intrinsics2.astype(np.float32)),
237
+ 'pose2': torch.from_numpy(pose2.astype(np.float32)),
238
+ 'bbox2': torch.from_numpy(bbox2.astype(np.float32))
239
+ }
third_party/d2net/lib/exceptions.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ class EmptyTensorError(Exception):
2
+ pass
3
+
4
+
5
+ class NoGradientError(Exception):
6
+ pass
third_party/d2net/lib/loss.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import matplotlib.pyplot as plt
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from lib.utils import (
10
+ grid_positions,
11
+ upscale_positions,
12
+ downscale_positions,
13
+ savefig,
14
+ imshow_image
15
+ )
16
+ from lib.exceptions import NoGradientError, EmptyTensorError
17
+
18
+ matplotlib.use('Agg')
19
+
20
+
21
+ def loss_function(
22
+ model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False
23
+ ):
24
+ output = model({
25
+ 'image1': batch['image1'].to(device),
26
+ 'image2': batch['image2'].to(device)
27
+ })
28
+
29
+ loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
30
+ has_grad = False
31
+
32
+ n_valid_samples = 0
33
+ for idx_in_batch in range(batch['image1'].size(0)):
34
+ # Annotations
35
+ depth1 = batch['depth1'][idx_in_batch].to(device) # [h1, w1]
36
+ intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device) # [3, 3]
37
+ pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device) # [4, 4]
38
+ bbox1 = batch['bbox1'][idx_in_batch].to(device) # [2]
39
+
40
+ depth2 = batch['depth2'][idx_in_batch].to(device)
41
+ intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
42
+ pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
43
+ bbox2 = batch['bbox2'][idx_in_batch].to(device)
44
+
45
+ # Network output
46
+ dense_features1 = output['dense_features1'][idx_in_batch]
47
+ c, h1, w1 = dense_features1.size()
48
+ scores1 = output['scores1'][idx_in_batch].view(-1)
49
+
50
+ dense_features2 = output['dense_features2'][idx_in_batch]
51
+ _, h2, w2 = dense_features2.size()
52
+ scores2 = output['scores2'][idx_in_batch]
53
+
54
+ all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
55
+ descriptors1 = all_descriptors1
56
+
57
+ all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)
58
+
59
+ # Warp the positions from image 1 to image 2
60
+ fmap_pos1 = grid_positions(h1, w1, device)
61
+ pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
62
+ try:
63
+ pos1, pos2, ids = warp(
64
+ pos1,
65
+ depth1, intrinsics1, pose1, bbox1,
66
+ depth2, intrinsics2, pose2, bbox2
67
+ )
68
+ except EmptyTensorError:
69
+ continue
70
+ fmap_pos1 = fmap_pos1[:, ids]
71
+ descriptors1 = descriptors1[:, ids]
72
+ scores1 = scores1[ids]
73
+
74
+ # Skip the pair if not enough GT correspondences are available
75
+ if ids.size(0) < 128:
76
+ continue
77
+
78
+ # Descriptors at the corresponding positions
79
+ fmap_pos2 = torch.round(
80
+ downscale_positions(pos2, scaling_steps=scaling_steps)
81
+ ).long()
82
+ descriptors2 = F.normalize(
83
+ dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
84
+ dim=0
85
+ )
86
+ positive_distance = 2 - 2 * (
87
+ descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
88
+ ).squeeze()
89
+
90
+ all_fmap_pos2 = grid_positions(h2, w2, device)
91
+ position_distance = torch.max(
92
+ torch.abs(
93
+ fmap_pos2.unsqueeze(2).float() -
94
+ all_fmap_pos2.unsqueeze(1)
95
+ ),
96
+ dim=0
97
+ )[0]
98
+ is_out_of_safe_radius = position_distance > safe_radius
99
+ distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
100
+ negative_distance2 = torch.min(
101
+ distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
102
+ dim=1
103
+ )[0]
104
+
105
+ all_fmap_pos1 = grid_positions(h1, w1, device)
106
+ position_distance = torch.max(
107
+ torch.abs(
108
+ fmap_pos1.unsqueeze(2).float() -
109
+ all_fmap_pos1.unsqueeze(1)
110
+ ),
111
+ dim=0
112
+ )[0]
113
+ is_out_of_safe_radius = position_distance > safe_radius
114
+ distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
115
+ negative_distance1 = torch.min(
116
+ distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
117
+ dim=1
118
+ )[0]
119
+
120
+ diff = positive_distance - torch.min(
121
+ negative_distance1, negative_distance2
122
+ )
123
+
124
+ scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]
125
+
126
+ loss = loss + (
127
+ torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
128
+ torch.sum(scores1 * scores2)
129
+ )
130
+
131
+ has_grad = True
132
+ n_valid_samples += 1
133
+
134
+ if plot and batch['batch_idx'] % batch['log_interval'] == 0:
135
+ pos1_aux = pos1.cpu().numpy()
136
+ pos2_aux = pos2.cpu().numpy()
137
+ k = pos1_aux.shape[1]
138
+ col = np.random.rand(k, 3)
139
+ n_sp = 4
140
+ plt.figure()
141
+ plt.subplot(1, n_sp, 1)
142
+ im1 = imshow_image(
143
+ batch['image1'][idx_in_batch].cpu().numpy(),
144
+ preprocessing=batch['preprocessing']
145
+ )
146
+ plt.imshow(im1)
147
+ plt.scatter(
148
+ pos1_aux[1, :], pos1_aux[0, :],
149
+ s=0.25**2, c=col, marker=',', alpha=0.5
150
+ )
151
+ plt.axis('off')
152
+ plt.subplot(1, n_sp, 2)
153
+ plt.imshow(
154
+ output['scores1'][idx_in_batch].data.cpu().numpy(),
155
+ cmap='Reds'
156
+ )
157
+ plt.axis('off')
158
+ plt.subplot(1, n_sp, 3)
159
+ im2 = imshow_image(
160
+ batch['image2'][idx_in_batch].cpu().numpy(),
161
+ preprocessing=batch['preprocessing']
162
+ )
163
+ plt.imshow(im2)
164
+ plt.scatter(
165
+ pos2_aux[1, :], pos2_aux[0, :],
166
+ s=0.25**2, c=col, marker=',', alpha=0.5
167
+ )
168
+ plt.axis('off')
169
+ plt.subplot(1, n_sp, 4)
170
+ plt.imshow(
171
+ output['scores2'][idx_in_batch].data.cpu().numpy(),
172
+ cmap='Reds'
173
+ )
174
+ plt.axis('off')
175
+ savefig('train_vis/%s.%02d.%02d.%d.png' % (
176
+ 'train' if batch['train'] else 'valid',
177
+ batch['epoch_idx'],
178
+ batch['batch_idx'] // batch['log_interval'],
179
+ idx_in_batch
180
+ ), dpi=300)
181
+ plt.close()
182
+
183
+ if not has_grad:
184
+ raise NoGradientError
185
+
186
+ loss = loss / n_valid_samples
187
+
188
+ return loss
189
+
190
+
191
+ def interpolate_depth(pos, depth):
192
+ device = pos.device
193
+
194
+ ids = torch.arange(0, pos.size(1), device=device)
195
+
196
+ h, w = depth.size()
197
+
198
+ i = pos[0, :]
199
+ j = pos[1, :]
200
+
201
+ # Valid corners
202
+ i_top_left = torch.floor(i).long()
203
+ j_top_left = torch.floor(j).long()
204
+ valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
205
+
206
+ i_top_right = torch.floor(i).long()
207
+ j_top_right = torch.ceil(j).long()
208
+ valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
209
+
210
+ i_bottom_left = torch.ceil(i).long()
211
+ j_bottom_left = torch.floor(j).long()
212
+ valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
213
+
214
+ i_bottom_right = torch.ceil(i).long()
215
+ j_bottom_right = torch.ceil(j).long()
216
+ valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
217
+
218
+ valid_corners = torch.min(
219
+ torch.min(valid_top_left, valid_top_right),
220
+ torch.min(valid_bottom_left, valid_bottom_right)
221
+ )
222
+
223
+ i_top_left = i_top_left[valid_corners]
224
+ j_top_left = j_top_left[valid_corners]
225
+
226
+ i_top_right = i_top_right[valid_corners]
227
+ j_top_right = j_top_right[valid_corners]
228
+
229
+ i_bottom_left = i_bottom_left[valid_corners]
230
+ j_bottom_left = j_bottom_left[valid_corners]
231
+
232
+ i_bottom_right = i_bottom_right[valid_corners]
233
+ j_bottom_right = j_bottom_right[valid_corners]
234
+
235
+ ids = ids[valid_corners]
236
+ if ids.size(0) == 0:
237
+ raise EmptyTensorError
238
+
239
+ # Valid depth
240
+ valid_depth = torch.min(
241
+ torch.min(
242
+ depth[i_top_left, j_top_left] > 0,
243
+ depth[i_top_right, j_top_right] > 0
244
+ ),
245
+ torch.min(
246
+ depth[i_bottom_left, j_bottom_left] > 0,
247
+ depth[i_bottom_right, j_bottom_right] > 0
248
+ )
249
+ )
250
+
251
+ i_top_left = i_top_left[valid_depth]
252
+ j_top_left = j_top_left[valid_depth]
253
+
254
+ i_top_right = i_top_right[valid_depth]
255
+ j_top_right = j_top_right[valid_depth]
256
+
257
+ i_bottom_left = i_bottom_left[valid_depth]
258
+ j_bottom_left = j_bottom_left[valid_depth]
259
+
260
+ i_bottom_right = i_bottom_right[valid_depth]
261
+ j_bottom_right = j_bottom_right[valid_depth]
262
+
263
+ ids = ids[valid_depth]
264
+ if ids.size(0) == 0:
265
+ raise EmptyTensorError
266
+
267
+ # Interpolation
268
+ i = i[ids]
269
+ j = j[ids]
270
+ dist_i_top_left = i - i_top_left.float()
271
+ dist_j_top_left = j - j_top_left.float()
272
+ w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
273
+ w_top_right = (1 - dist_i_top_left) * dist_j_top_left
274
+ w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
275
+ w_bottom_right = dist_i_top_left * dist_j_top_left
276
+
277
+ interpolated_depth = (
278
+ w_top_left * depth[i_top_left, j_top_left] +
279
+ w_top_right * depth[i_top_right, j_top_right] +
280
+ w_bottom_left * depth[i_bottom_left, j_bottom_left] +
281
+ w_bottom_right * depth[i_bottom_right, j_bottom_right]
282
+ )
283
+
284
+ pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
285
+
286
+ return [interpolated_depth, pos, ids]
287
+
288
+
289
+ def uv_to_pos(uv):
290
+ return torch.cat([uv[1, :].view(1, -1), uv[0, :].view(1, -1)], dim=0)
291
+
292
+
293
+ def warp(
294
+ pos1,
295
+ depth1, intrinsics1, pose1, bbox1,
296
+ depth2, intrinsics2, pose2, bbox2
297
+ ):
298
+ device = pos1.device
299
+
300
+ Z1, pos1, ids = interpolate_depth(pos1, depth1)
301
+
302
+ # COLMAP convention
303
+ u1 = pos1[1, :] + bbox1[1] + .5
304
+ v1 = pos1[0, :] + bbox1[0] + .5
305
+
306
+ X1 = (u1 - intrinsics1[0, 2]) * (Z1 / intrinsics1[0, 0])
307
+ Y1 = (v1 - intrinsics1[1, 2]) * (Z1 / intrinsics1[1, 1])
308
+
309
+ XYZ1_hom = torch.cat([
310
+ X1.view(1, -1),
311
+ Y1.view(1, -1),
312
+ Z1.view(1, -1),
313
+ torch.ones(1, Z1.size(0), device=device)
314
+ ], dim=0)
315
+ XYZ2_hom = torch.chain_matmul(pose2, torch.inverse(pose1), XYZ1_hom)
316
+ XYZ2 = XYZ2_hom[: -1, :] / XYZ2_hom[-1, :].view(1, -1)
317
+
318
+ uv2_hom = torch.matmul(intrinsics2, XYZ2)
319
+ uv2 = uv2_hom[: -1, :] / uv2_hom[-1, :].view(1, -1)
320
+
321
+ u2 = uv2[0, :] - bbox2[1] - .5
322
+ v2 = uv2[1, :] - bbox2[0] - .5
323
+ uv2 = torch.cat([u2.view(1, -1), v2.view(1, -1)], dim=0)
324
+
325
+ annotated_depth, pos2, new_ids = interpolate_depth(uv_to_pos(uv2), depth2)
326
+
327
+ ids = ids[new_ids]
328
+ pos1 = pos1[:, new_ids]
329
+ estimated_depth = XYZ2[2, new_ids]
330
+
331
+ inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05
332
+
333
+ ids = ids[inlier_mask]
334
+ if ids.size(0) == 0:
335
+ raise EmptyTensorError
336
+
337
+ pos2 = pos2[:, inlier_mask]
338
+ pos1 = pos1[:, inlier_mask]
339
+
340
+ return pos1, pos2, ids
third_party/d2net/lib/model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import torchvision.models as models
6
+
7
+
8
+ class DenseFeatureExtractionModule(nn.Module):
9
+ def __init__(self, finetune_feature_extraction=False, use_cuda=True):
10
+ super(DenseFeatureExtractionModule, self).__init__()
11
+
12
+ model = models.vgg16()
13
+ vgg16_layers = [
14
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
15
+ 'pool1',
16
+ 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
17
+ 'pool2',
18
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
19
+ 'pool3',
20
+ 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
21
+ 'pool4',
22
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
23
+ 'pool5'
24
+ ]
25
+ conv4_3_idx = vgg16_layers.index('conv4_3')
26
+
27
+ self.model = nn.Sequential(
28
+ *list(model.features.children())[: conv4_3_idx + 1]
29
+ )
30
+
31
+ self.num_channels = 512
32
+
33
+ # Fix forward parameters
34
+ for param in self.model.parameters():
35
+ param.requires_grad = False
36
+ if finetune_feature_extraction:
37
+ # Unlock conv4_3
38
+ for param in list(self.model.parameters())[-2 :]:
39
+ param.requires_grad = True
40
+
41
+ if use_cuda:
42
+ self.model = self.model.cuda()
43
+
44
+ def forward(self, batch):
45
+ output = self.model(batch)
46
+ return output
47
+
48
+
49
+ class SoftDetectionModule(nn.Module):
50
+ def __init__(self, soft_local_max_size=3):
51
+ super(SoftDetectionModule, self).__init__()
52
+
53
+ self.soft_local_max_size = soft_local_max_size
54
+
55
+ self.pad = self.soft_local_max_size // 2
56
+
57
+ def forward(self, batch):
58
+ b = batch.size(0)
59
+
60
+ batch = F.relu(batch)
61
+
62
+ max_per_sample = torch.max(batch.view(b, -1), dim=1)[0]
63
+ exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1))
64
+ sum_exp = (
65
+ self.soft_local_max_size ** 2 *
66
+ F.avg_pool2d(
67
+ F.pad(exp, [self.pad] * 4, mode='constant', value=1.),
68
+ self.soft_local_max_size, stride=1
69
+ )
70
+ )
71
+ local_max_score = exp / sum_exp
72
+
73
+ depth_wise_max = torch.max(batch, dim=1)[0]
74
+ depth_wise_max_score = batch / depth_wise_max.unsqueeze(1)
75
+
76
+ all_scores = local_max_score * depth_wise_max_score
77
+ score = torch.max(all_scores, dim=1)[0]
78
+
79
+ score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1)
80
+
81
+ return score
82
+
83
+
84
+ class D2Net(nn.Module):
85
+ def __init__(self, model_file=None, use_cuda=True):
86
+ super(D2Net, self).__init__()
87
+
88
+ self.dense_feature_extraction = DenseFeatureExtractionModule(
89
+ finetune_feature_extraction=True,
90
+ use_cuda=use_cuda
91
+ )
92
+
93
+ self.detection = SoftDetectionModule()
94
+
95
+ if model_file is not None:
96
+ if use_cuda:
97
+ self.load_state_dict(torch.load(model_file)['model'])
98
+ else:
99
+ self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
100
+
101
+ def forward(self, batch):
102
+ b = batch['image1'].size(0)
103
+
104
+ dense_features = self.dense_feature_extraction(
105
+ torch.cat([batch['image1'], batch['image2']], dim=0)
106
+ )
107
+
108
+ scores = self.detection(dense_features)
109
+
110
+ dense_features1 = dense_features[: b, :, :, :]
111
+ dense_features2 = dense_features[b :, :, :, :]
112
+
113
+ scores1 = scores[: b, :, :]
114
+ scores2 = scores[b :, :, :]
115
+
116
+ return {
117
+ 'dense_features1': dense_features1,
118
+ 'scores1': scores1,
119
+ 'dense_features2': dense_features2,
120
+ 'scores2': scores2
121
+ }
third_party/d2net/lib/model_test.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DenseFeatureExtractionModule(nn.Module):
7
+ def __init__(self, use_relu=True, use_cuda=True):
8
+ super(DenseFeatureExtractionModule, self).__init__()
9
+
10
+ self.model = nn.Sequential(
11
+ nn.Conv2d(3, 64, 3, padding=1),
12
+ nn.ReLU(inplace=True),
13
+ nn.Conv2d(64, 64, 3, padding=1),
14
+ nn.ReLU(inplace=True),
15
+ nn.MaxPool2d(2, stride=2),
16
+ nn.Conv2d(64, 128, 3, padding=1),
17
+ nn.ReLU(inplace=True),
18
+ nn.Conv2d(128, 128, 3, padding=1),
19
+ nn.ReLU(inplace=True),
20
+ nn.MaxPool2d(2, stride=2),
21
+ nn.Conv2d(128, 256, 3, padding=1),
22
+ nn.ReLU(inplace=True),
23
+ nn.Conv2d(256, 256, 3, padding=1),
24
+ nn.ReLU(inplace=True),
25
+ nn.Conv2d(256, 256, 3, padding=1),
26
+ nn.ReLU(inplace=True),
27
+ nn.AvgPool2d(2, stride=1),
28
+ nn.Conv2d(256, 512, 3, padding=2, dilation=2),
29
+ nn.ReLU(inplace=True),
30
+ nn.Conv2d(512, 512, 3, padding=2, dilation=2),
31
+ nn.ReLU(inplace=True),
32
+ nn.Conv2d(512, 512, 3, padding=2, dilation=2),
33
+ )
34
+ self.num_channels = 512
35
+
36
+ self.use_relu = use_relu
37
+
38
+ if use_cuda:
39
+ self.model = self.model.cuda()
40
+
41
+ def forward(self, batch):
42
+ output = self.model(batch)
43
+ if self.use_relu:
44
+ output = F.relu(output)
45
+ return output
46
+
47
+
48
+ class D2Net(nn.Module):
49
+ def __init__(self, model_file=None, use_relu=True, use_cuda=True):
50
+ super(D2Net, self).__init__()
51
+
52
+ self.dense_feature_extraction = DenseFeatureExtractionModule(
53
+ use_relu=use_relu, use_cuda=use_cuda
54
+ )
55
+
56
+ self.detection = HardDetectionModule()
57
+
58
+ self.localization = HandcraftedLocalizationModule()
59
+
60
+ if model_file is not None:
61
+ if use_cuda:
62
+ self.load_state_dict(torch.load(model_file)['model'])
63
+ else:
64
+ self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
65
+
66
+ def forward(self, batch):
67
+ _, _, h, w = batch.size()
68
+ dense_features = self.dense_feature_extraction(batch)
69
+
70
+ detections = self.detection(dense_features)
71
+
72
+ displacements = self.localization(dense_features)
73
+
74
+ return {
75
+ 'dense_features': dense_features,
76
+ 'detections': detections,
77
+ 'displacements': displacements
78
+ }
79
+
80
+
81
+ class HardDetectionModule(nn.Module):
82
+ def __init__(self, edge_threshold=5):
83
+ super(HardDetectionModule, self).__init__()
84
+
85
+ self.edge_threshold = edge_threshold
86
+
87
+ self.dii_filter = torch.tensor(
88
+ [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
89
+ ).view(1, 1, 3, 3)
90
+ self.dij_filter = 0.25 * torch.tensor(
91
+ [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
92
+ ).view(1, 1, 3, 3)
93
+ self.djj_filter = torch.tensor(
94
+ [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
95
+ ).view(1, 1, 3, 3)
96
+
97
+ def forward(self, batch):
98
+ b, c, h, w = batch.size()
99
+ device = batch.device
100
+
101
+ depth_wise_max = torch.max(batch, dim=1)[0]
102
+ is_depth_wise_max = (batch == depth_wise_max)
103
+ del depth_wise_max
104
+
105
+ local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
106
+ is_local_max = (batch == local_max)
107
+ del local_max
108
+
109
+ dii = F.conv2d(
110
+ batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
111
+ ).view(b, c, h, w)
112
+ dij = F.conv2d(
113
+ batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
114
+ ).view(b, c, h, w)
115
+ djj = F.conv2d(
116
+ batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
117
+ ).view(b, c, h, w)
118
+
119
+ det = dii * djj - dij * dij
120
+ tr = dii + djj
121
+ del dii, dij, djj
122
+
123
+ threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
124
+ is_not_edge = torch.min(tr * tr / det <= threshold, det > 0)
125
+
126
+ detected = torch.min(
127
+ is_depth_wise_max,
128
+ torch.min(is_local_max, is_not_edge)
129
+ )
130
+ del is_depth_wise_max, is_local_max, is_not_edge
131
+
132
+ return detected
133
+
134
+
135
+ class HandcraftedLocalizationModule(nn.Module):
136
+ def __init__(self):
137
+ super(HandcraftedLocalizationModule, self).__init__()
138
+
139
+ self.di_filter = torch.tensor(
140
+ [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
141
+ ).view(1, 1, 3, 3)
142
+ self.dj_filter = torch.tensor(
143
+ [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
144
+ ).view(1, 1, 3, 3)
145
+
146
+ self.dii_filter = torch.tensor(
147
+ [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
148
+ ).view(1, 1, 3, 3)
149
+ self.dij_filter = 0.25 * torch.tensor(
150
+ [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
151
+ ).view(1, 1, 3, 3)
152
+ self.djj_filter = torch.tensor(
153
+ [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
154
+ ).view(1, 1, 3, 3)
155
+
156
+ def forward(self, batch):
157
+ b, c, h, w = batch.size()
158
+ device = batch.device
159
+
160
+ dii = F.conv2d(
161
+ batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
162
+ ).view(b, c, h, w)
163
+ dij = F.conv2d(
164
+ batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
165
+ ).view(b, c, h, w)
166
+ djj = F.conv2d(
167
+ batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
168
+ ).view(b, c, h, w)
169
+ det = dii * djj - dij * dij
170
+
171
+ inv_hess_00 = djj / det
172
+ inv_hess_01 = -dij / det
173
+ inv_hess_11 = dii / det
174
+ del dii, dij, djj, det
175
+
176
+ di = F.conv2d(
177
+ batch.view(-1, 1, h, w), self.di_filter.to(device), padding=1
178
+ ).view(b, c, h, w)
179
+ dj = F.conv2d(
180
+ batch.view(-1, 1, h, w), self.dj_filter.to(device), padding=1
181
+ ).view(b, c, h, w)
182
+
183
+ step_i = -(inv_hess_00 * di + inv_hess_01 * dj)
184
+ step_j = -(inv_hess_01 * di + inv_hess_11 * dj)
185
+ del inv_hess_00, inv_hess_01, inv_hess_11, di, dj
186
+
187
+ return torch.stack([step_i, step_j], dim=1)
third_party/d2net/lib/pyramid.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from lib.exceptions import EmptyTensorError
6
+ from lib.utils import interpolate_dense_features, upscale_positions
7
+
8
+
9
+ def process_multiscale(image, model, scales=[.5, 1, 2]):
10
+ b, _, h_init, w_init = image.size()
11
+ device = image.device
12
+ assert(b == 1)
13
+
14
+ all_keypoints = torch.zeros([3, 0])
15
+ all_descriptors = torch.zeros([
16
+ model.dense_feature_extraction.num_channels, 0
17
+ ])
18
+ all_scores = torch.zeros(0)
19
+
20
+ previous_dense_features = None
21
+ banned = None
22
+ for idx, scale in enumerate(scales):
23
+ current_image = F.interpolate(
24
+ image, scale_factor=scale,
25
+ mode='bilinear', align_corners=True
26
+ )
27
+ _, _, h_level, w_level = current_image.size()
28
+
29
+ dense_features = model.dense_feature_extraction(current_image)
30
+ del current_image
31
+
32
+ _, _, h, w = dense_features.size()
33
+
34
+ # Sum the feature maps.
35
+ if previous_dense_features is not None:
36
+ dense_features += F.interpolate(
37
+ previous_dense_features, size=[h, w],
38
+ mode='bilinear', align_corners=True
39
+ )
40
+ del previous_dense_features
41
+
42
+ # Recover detections.
43
+ detections = model.detection(dense_features)
44
+ if banned is not None:
45
+ banned = F.interpolate(banned.float(), size=[h, w]).bool()
46
+ detections = torch.min(detections, ~banned)
47
+ banned = torch.max(
48
+ torch.max(detections, dim=1)[0].unsqueeze(1), banned
49
+ )
50
+ else:
51
+ banned = torch.max(detections, dim=1)[0].unsqueeze(1)
52
+ fmap_pos = torch.nonzero(detections[0].cpu()).t()
53
+ del detections
54
+
55
+ # Recover displacements.
56
+ displacements = model.localization(dense_features)[0].cpu()
57
+ displacements_i = displacements[
58
+ 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
59
+ ]
60
+ displacements_j = displacements[
61
+ 1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
62
+ ]
63
+ del displacements
64
+
65
+ mask = torch.min(
66
+ torch.abs(displacements_i) < 0.5,
67
+ torch.abs(displacements_j) < 0.5
68
+ )
69
+ fmap_pos = fmap_pos[:, mask]
70
+ valid_displacements = torch.stack([
71
+ displacements_i[mask],
72
+ displacements_j[mask]
73
+ ], dim=0)
74
+ del mask, displacements_i, displacements_j
75
+
76
+ fmap_keypoints = fmap_pos[1 :, :].float() + valid_displacements
77
+ del valid_displacements
78
+
79
+ try:
80
+ raw_descriptors, _, ids = interpolate_dense_features(
81
+ fmap_keypoints.to(device),
82
+ dense_features[0]
83
+ )
84
+ except EmptyTensorError:
85
+ continue
86
+ fmap_pos = fmap_pos.to(device)
87
+ fmap_keypoints = fmap_keypoints.to(device)
88
+ fmap_pos = fmap_pos[:, ids]
89
+ fmap_keypoints = fmap_keypoints[:, ids]
90
+ del ids
91
+
92
+ keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
93
+ del fmap_keypoints
94
+
95
+ descriptors = F.normalize(raw_descriptors, dim=0).cpu()
96
+ del raw_descriptors
97
+
98
+ keypoints[0, :] *= h_init / h_level
99
+ keypoints[1, :] *= w_init / w_level
100
+
101
+ fmap_pos = fmap_pos.cpu()
102
+ keypoints = keypoints.cpu()
103
+
104
+ keypoints = torch.cat([
105
+ keypoints,
106
+ torch.ones([1, keypoints.size(1)]) * 1 / scale,
107
+ ], dim=0)
108
+
109
+ scores = dense_features[
110
+ 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
111
+ ].cpu() / (idx + 1)
112
+ del fmap_pos
113
+
114
+ all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
115
+ all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
116
+ all_scores = torch.cat([all_scores, scores], dim=0)
117
+ del keypoints, descriptors
118
+
119
+ previous_dense_features = dense_features
120
+ del dense_features
121
+ del previous_dense_features, banned
122
+
123
+ keypoints = all_keypoints.t().detach().numpy()
124
+ del all_keypoints
125
+ scores = all_scores.detach().numpy()
126
+ del all_scores
127
+ descriptors = all_descriptors.t().detach().numpy()
128
+ del all_descriptors
129
+ return keypoints, scores, descriptors
third_party/d2net/lib/utils.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+
7
+ from lib.exceptions import EmptyTensorError
8
+
9
+
10
+ def preprocess_image(image, preprocessing=None):
11
+ image = image.astype(np.float32)
12
+ image = np.transpose(image, [2, 0, 1])
13
+ if preprocessing is None:
14
+ pass
15
+ elif preprocessing == 'caffe':
16
+ # RGB -> BGR
17
+ image = image[:: -1, :, :]
18
+ # Zero-center by mean pixel
19
+ mean = np.array([103.939, 116.779, 123.68])
20
+ image = image - mean.reshape([3, 1, 1])
21
+ elif preprocessing == 'torch':
22
+ image /= 255.0
23
+ mean = np.array([0.485, 0.456, 0.406])
24
+ std = np.array([0.229, 0.224, 0.225])
25
+ image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1])
26
+ else:
27
+ raise ValueError('Unknown preprocessing parameter.')
28
+ return image
29
+
30
+
31
+ def imshow_image(image, preprocessing=None):
32
+ if preprocessing is None:
33
+ pass
34
+ elif preprocessing == 'caffe':
35
+ mean = np.array([103.939, 116.779, 123.68])
36
+ image = image + mean.reshape([3, 1, 1])
37
+ # RGB -> BGR
38
+ image = image[:: -1, :, :]
39
+ elif preprocessing == 'torch':
40
+ mean = np.array([0.485, 0.456, 0.406])
41
+ std = np.array([0.229, 0.224, 0.225])
42
+ image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1])
43
+ image *= 255.0
44
+ else:
45
+ raise ValueError('Unknown preprocessing parameter.')
46
+ image = np.transpose(image, [1, 2, 0])
47
+ image = np.round(image).astype(np.uint8)
48
+ return image
49
+
50
+
51
+ def grid_positions(h, w, device, matrix=False):
52
+ lines = torch.arange(
53
+ 0, h, device=device
54
+ ).view(-1, 1).float().repeat(1, w)
55
+ columns = torch.arange(
56
+ 0, w, device=device
57
+ ).view(1, -1).float().repeat(h, 1)
58
+ if matrix:
59
+ return torch.stack([lines, columns], dim=0)
60
+ else:
61
+ return torch.cat([lines.view(1, -1), columns.view(1, -1)], dim=0)
62
+
63
+
64
+ def upscale_positions(pos, scaling_steps=0):
65
+ for _ in range(scaling_steps):
66
+ pos = pos * 2 + 0.5
67
+ return pos
68
+
69
+
70
+ def downscale_positions(pos, scaling_steps=0):
71
+ for _ in range(scaling_steps):
72
+ pos = (pos - 0.5) / 2
73
+ return pos
74
+
75
+
76
+ def interpolate_dense_features(pos, dense_features, return_corners=False):
77
+ device = pos.device
78
+
79
+ ids = torch.arange(0, pos.size(1), device=device)
80
+
81
+ _, h, w = dense_features.size()
82
+
83
+ i = pos[0, :]
84
+ j = pos[1, :]
85
+
86
+ # Valid corners
87
+ i_top_left = torch.floor(i).long()
88
+ j_top_left = torch.floor(j).long()
89
+ valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
90
+
91
+ i_top_right = torch.floor(i).long()
92
+ j_top_right = torch.ceil(j).long()
93
+ valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
94
+
95
+ i_bottom_left = torch.ceil(i).long()
96
+ j_bottom_left = torch.floor(j).long()
97
+ valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
98
+
99
+ i_bottom_right = torch.ceil(i).long()
100
+ j_bottom_right = torch.ceil(j).long()
101
+ valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
102
+
103
+ valid_corners = torch.min(
104
+ torch.min(valid_top_left, valid_top_right),
105
+ torch.min(valid_bottom_left, valid_bottom_right)
106
+ )
107
+
108
+ i_top_left = i_top_left[valid_corners]
109
+ j_top_left = j_top_left[valid_corners]
110
+
111
+ i_top_right = i_top_right[valid_corners]
112
+ j_top_right = j_top_right[valid_corners]
113
+
114
+ i_bottom_left = i_bottom_left[valid_corners]
115
+ j_bottom_left = j_bottom_left[valid_corners]
116
+
117
+ i_bottom_right = i_bottom_right[valid_corners]
118
+ j_bottom_right = j_bottom_right[valid_corners]
119
+
120
+ ids = ids[valid_corners]
121
+ if ids.size(0) == 0:
122
+ raise EmptyTensorError
123
+
124
+ # Interpolation
125
+ i = i[ids]
126
+ j = j[ids]
127
+ dist_i_top_left = i - i_top_left.float()
128
+ dist_j_top_left = j - j_top_left.float()
129
+ w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
130
+ w_top_right = (1 - dist_i_top_left) * dist_j_top_left
131
+ w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
132
+ w_bottom_right = dist_i_top_left * dist_j_top_left
133
+
134
+ descriptors = (
135
+ w_top_left * dense_features[:, i_top_left, j_top_left] +
136
+ w_top_right * dense_features[:, i_top_right, j_top_right] +
137
+ w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] +
138
+ w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right]
139
+ )
140
+
141
+ pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
142
+
143
+ if not return_corners:
144
+ return [descriptors, pos, ids]
145
+ else:
146
+ corners = torch.stack([
147
+ torch.stack([i_top_left, j_top_left], dim=0),
148
+ torch.stack([i_top_right, j_top_right], dim=0),
149
+ torch.stack([i_bottom_left, j_bottom_left], dim=0),
150
+ torch.stack([i_bottom_right, j_bottom_right], dim=0)
151
+ ], dim=0)
152
+ return [descriptors, pos, ids, corners]
153
+
154
+
155
+ def savefig(filepath, fig=None, dpi=None):
156
+ # TomNorway - https://stackoverflow.com/a/53516034
157
+ if not fig:
158
+ fig = plt.gcf()
159
+
160
+ plt.subplots_adjust(0, 0, 1, 1, 0, 0)
161
+ for ax in fig.axes:
162
+ ax.axis('off')
163
+ ax.margins(0, 0)
164
+ ax.xaxis.set_major_locator(plt.NullLocator())
165
+ ax.yaxis.set_major_locator(plt.NullLocator())
166
+
167
+ fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi)