File size: 28,471 Bytes
fb98d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# visloc script with support for coarse to fine
# --------------------------------------------------------
import os
import numpy as np
import random
import torch
import torchvision.transforms as tvf
import argparse
from tqdm import tqdm
from PIL import Image
import math

from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs
from mast3r.utils.coarse_to_fine import select_pairs_of_crops, crop_slice
from mast3r.utils.collate import cat_collate, cat_collate_fn_map
from mast3r.utils.misc import mkdir_for
from mast3r.datasets.utils.cropping import crop_to_homography

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.inference import inference, loss_of_one_batch
from dust3r.utils.geometry import geotrf, colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
from dust3r.datasets.utils.transforms import ImgNorm
from dust3r_visloc.datasets import *
from dust3r_visloc.localization import run_pnp
from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results
from dust3r_visloc.datasets.utils import get_HW_resolution, rescale_points3d


def get_args_parser():
    """Build the command-line parser for the visloc evaluation script.

    Exactly one of ``--weights`` / ``--model_name`` must be given, and at most one
    of ``--reprojection_error`` / ``--reprojection_error_diag_ratio``.

    Returns:
        argparse.ArgumentParser: parser for dataset selection, model weights,
        (coarse-to-fine) matching, PnP and output options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True,
                        help="visloc dataset to eval (evaluated as a Python expression)")
    parser_weights = parser.add_mutually_exclusive_group(required=True)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])

    # The evaluation keeps matches with conf >= threshold, i.e. lower values are dropped.
    parser.add_argument("--confidence_threshold", type=float, default=1.001,
                        help="matches with a confidence lower than this threshold are discarded")
    parser.add_argument('--pixel_tol', default=5, type=int,
                        help="pixel_tol passed to fast_reciprocal_NNs when matching; must be > 0")

    parser.add_argument("--coarse_to_fine", action='store_true', default=False,
                        help="do the matching from coarse to fine")
    parser.add_argument("--max_image_size", type=int, default=None,
                        help="max image size for the fine resolution")
    parser.add_argument("--c2f_crop_with_homography", action='store_true', default=False,
                        help="when using coarse to fine, crop with homographies to keep cx, cy centered")

    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'],
                        help="pnp lib to use")
    parser_reproj = parser.add_mutually_exclusive_group()
    parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error")
    parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None,
                               help="pnp reprojection error as a ratio of the diagonal of the image")

    parser.add_argument("--max_batch_size", type=int, default=48,
                        help="max batch size for inference on crops when using coarse to fine")
    parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept")
    parser.add_argument("--viz_matches", type=int, default=0,
                        help="debug: number of matches to plot per image pair (0 disables)")

    parser.add_argument("--output_dir", type=str, default=None,
                        help="output directory for exported results and the match cache")
    parser.add_argument("--output_label", type=str, default='', help="prefix for results files")
    return parser


@torch.no_grad()
def coarse_matching(query_view, map_view, model, device, pixel_tol, fast_nn_params):
    """Match a query view against a map view at the model's working resolution.

    Args:
        query_view, map_view: view dicts providing 'rgb_rescaled' and 'to_orig';
            when ``pixel_tol > 0`` the map view must also provide 'valid_rescaled'
            and 'pts3d_rescaled'.
        model, device: network and device forwarded to ``inference``.
        pixel_tol: 0 -> subsampled reciprocal matching over the whole image, with
            matches closer than 3 px to a border dropped and no 3D lookup;
            > 0 -> matching restricted to valid map pixels, with 3D points and
            per-match confidences.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs). The two match
        arrays are in original-image pixel coordinates (cv2 convention).
        ``valid_pts3d`` and ``matches_confs`` are empty lists when ``pixel_tol == 0``.
        All four are empty lists if a descriptor map is empty.
    """
    img_q = query_view['rgb_rescaled']
    img_m = map_view['rgb_rescaled']
    batch_views = tuple(
        dict(img=im.unsqueeze(0), true_shape=np.int32([im.shape[1:]]), idx=i, instance=str(i))
        for i, im in enumerate((img_q, img_m))
    )
    out = inference([batch_views], model, device, batch_size=1, verbose=False)
    pred_q, pred_m = out['pred1'], out['pred2']
    conf_q = pred_q['desc_conf'].squeeze(0).cpu().numpy()
    conf_m = pred_m['desc_conf'].squeeze(0).cpu().numpy()
    desc_q = pred_q['desc'].squeeze(0).detach()
    desc_m = pred_m['desc'].squeeze(0).detach()

    if len(desc_q) == 0 or len(desc_m) == 0:
        return [], [], [], []

    if pixel_tol == 0:
        pts_map, pts_query = fast_reciprocal_NNs(desc_m, desc_q, subsample_or_initxy1=8, **fast_nn_params)
        HM, WM = img_m.shape[1:]
        HQ, WQ = img_q.shape[1:]

        def inside(pts, W, H, margin=3):
            # True for matches at least `margin` pixels away from every image border
            return ((pts[:, 0] >= margin) & (pts[:, 0] < W - margin)
                    & (pts[:, 1] >= margin) & (pts[:, 1] < H - margin))

        keep = inside(pts_map, WM, HM) & inside(pts_query, WQ, HQ)
        pts_map = pts_map[keep]
        pts_query = pts_query[keep]
        valid_pts3d, matches_confs = [], []
    else:
        ys, xs = torch.where(map_view['valid_rescaled'])
        pts_map, pts_query = fast_reciprocal_NNs(desc_m, desc_q, (xs, ys), pixel_tol=pixel_tol, **fast_nn_params)
        u_m, v_m = pts_map[:, 0], pts_map[:, 1]
        u_q, v_q = pts_query[:, 0], pts_query[:, 1]
        valid_pts3d = map_view['pts3d_rescaled'].cpu().numpy()[v_m, u_m]
        # a match is only as confident as its weaker endpoint
        matches_confs = np.minimum(conf_m[v_m, u_m], conf_q[v_q, u_q])

    def to_original(pts, to_orig):
        # cv2 (pixel centers) -> colmap (pixel corners), rescale, then back to cv2
        pts = pts.astype(np.float64)
        pts[:, 0] += 0.5
        pts[:, 1] += 0.5
        pts = geotrf(to_orig, pts, norm=True)
        pts[:, 0] -= 0.5
        pts[:, 1] -= 0.5
        return pts

    matches_im_query = to_original(pts_query, query_view['to_orig'])
    matches_im_map = to_original(pts_map, map_view['to_orig'])
    return valid_pts3d, matches_im_query, matches_im_map, matches_confs


@torch.no_grad()
def crops_inference(pairs, model, device, batch_size=48, verbose=True):
    """Run the model on a batch of crop pairs, in chunks of at most ``batch_size``.

    Args:
        pairs: sequence of exactly two view dicts (first view, second view). Each
            value is sliceable along the batch dimension, and ``pairs[0]['img']``
            gives the batch size. Entries whose value is None are dropped from chunks.
        model, device: forwarded to ``loss_of_one_batch``.
        batch_size: maximum number of pairs per forward pass.
        verbose: accepted for interface compatibility; unused.

    Returns:
        The ``loss_of_one_batch`` output for a small batch. Otherwise, the chunk
        outputs merged with ``cat_collate``.
    """
    assert len(pairs) == 2, "Error, data should be a tuple of dicts containing the batch of image pairs"
    n_pairs = pairs[0]['img'].shape[0]
    if n_pairs < batch_size:
        # fits in a single forward pass
        return loss_of_one_batch(pairs, model, None, device=device, symmetrize_batch=False)

    def take(view, sel):
        # copy of the non-None entries of one view, restricted to the current chunk
        return {key: value[sel] for key, value in view.items() if value is not None}

    chunk_outputs = []
    for start in range(0, n_pairs, batch_size):
        sel = slice(start, start + min(n_pairs - start, batch_size))
        chunk = [take(pairs[0], sel), take(pairs[1], sel)]
        chunk_outputs.append(loss_of_one_batch(chunk, model, None, device=device, symmetrize_batch=False))
    return cat_collate(chunk_outputs, collate_fn_map=cat_collate_fn_map)


@torch.no_grad()
def fine_matching(query_views, map_views, model, device, max_batch_size, pixel_tol, fast_nn_params):
    """Match every (query crop, map crop) pair and map the matches out of the crops.

    Args:
        query_views: batched crop dict with at least 'img', 'instance' and 'to_orig'.
        map_views: batched crop dict that additionally carries 'valid' (per-pixel
            validity mask) and 'pts3d' (per-pixel 3D points).
        model, device, max_batch_size: forwarded to ``crops_inference``.
        pixel_tol: forwarded to ``fast_reciprocal_NNs``; must be > 0.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs), concatenated
        over all crop pairs. The 2D matches are expressed in the uncropped frame via
        each crop's 'to_orig'. Returns four empty lists when there are no crop pairs.
    """
    assert pixel_tol > 0
    out = crops_inference([query_views, map_views],
                          model, device, batch_size=max_batch_size, verbose=False)
    pred_q, pred_m = out['pred1'], out['pred2']

    per_crop = []
    per_pair = zip(pred_q['desc'], pred_m['desc'], pred_q['desc_conf'], pred_m['desc_conf'])
    for i, (desc_q, desc_m, conf_q, conf_m) in enumerate(per_pair):
        # only map pixels with a valid 3D point are eligible for matching
        ys, xs = torch.where(map_views['valid'][i])
        pts_map, pts_query = fast_reciprocal_NNs(desc_m, desc_q, (xs, ys),
                                                 pixel_tol=pixel_tol, **fast_nn_params)
        u_m, v_m = pts_map[:, 0], pts_map[:, 1]
        u_q, v_q = pts_query[:, 0], pts_query[:, 1]

        pts3d = map_views['pts3d'][i].cpu().numpy()[v_m, u_m]
        confs = np.minimum(conf_m.cpu().numpy()[v_m, u_m], conf_q.cpu().numpy()[v_q, u_q])
        # undo the crop: express pixel coordinates in the uncropped frame
        pts_map = geotrf(map_views['to_orig'][i].cpu().numpy(), pts_map.copy(), norm=True)
        pts_query = geotrf(query_views['to_orig'][i].cpu().numpy(), pts_query.copy(), norm=True)
        per_crop.append((pts3d, pts_query, pts_map, confs))

    if not per_crop:
        return [], [], [], []

    valid_pts3d, matches_im_query, matches_im_map, matches_confs = (
        np.concatenate(parts, axis=0) for parts in zip(*per_crop))
    return valid_pts3d, matches_im_query, matches_im_map, matches_confs


def crop(img, mask, pts3d, crop, intrinsics=None):
    """Cut one crop out of an image and, optionally, out of its mask and 3D-point map.

    Without ``intrinsics`` this is a plain window given by ``crop_slice(crop)``.
    With ``intrinsics``, ``crop_to_homography`` provides an output size and a
    homography, and every input is warped with it (bicubic for the image,
    nearest for the mask and the 3D points).

    Args:
        img: image tensor, presumably (H, W, 3) channels-last with values in
            [-1, 1] as produced by ``resize_image_to_max`` -- TODO confirm.
            The homography path converts it with ``.numpy()``, so it must be on CPU.
        mask: optional per-pixel mask tensor (same H, W), or None.
        pts3d: optional per-pixel 3D-point tensor indexed as [:, :, 0..2], or None.
        crop: crop description understood by ``crop_slice`` / ``crop_to_homography``;
            its first two entries are used as the translation of ``to_orig``.
        intrinsics: optional camera matrix enabling the homography-based crop.

    Returns:
        (cropped_img, cropped_mask, cropped_pts3d, to_orig). ``to_orig`` is the
        3x3 transform from crop pixels back to input-image pixels. The mask and
        3D points are None when not given.
    """
    cropped_img = img.clone()
    cropped_mask = None if mask is None else mask.clone()
    cropped_pts3d = None if pts3d is None else pts3d.clone()
    to_orig = torch.eye(3, device=img.device)

    if intrinsics is None:
        # axis-aligned window; going back to the original is a pure translation
        window = crop_slice(crop)
        cropped_img = img[window]
        if cropped_mask is not None:
            cropped_mask = cropped_mask[window]
        if cropped_pts3d is not None:
            cropped_pts3d = cropped_pts3d[window]
        to_orig[:2, -1] = torch.tensor(crop[:2])
        return cropped_img, cropped_mask, cropped_pts3d, to_orig

    imsize, _, _, H = crop_to_homography(intrinsics, crop)
    H /= H[2, 2]
    homo8 = H.ravel().tolist()[:8]  # PIL expects the first 8 coefficients of the normalized matrix

    def warp(pil_image, resample):
        return pil_image.transform(imsize, Image.Transform.PERSPECTIVE, homo8, resample=resample)

    # float tensor -> uint8 PIL image -> warped -> back to a float tensor of img's dtype/device
    img_u8 = (255 * (img + 1.) / 2).to(torch.uint8).numpy()
    warped_img = warp(Image.fromarray(img_u8), Image.Resampling.BICUBIC)
    cropped_img = 2. * torch.tensor(np.array(warped_img)).to(img) / 255. - 1.

    if cropped_mask is not None:
        mask_u8 = (255 * cropped_mask).to(torch.uint8).numpy()
        warped_mask = np.array(warp(Image.fromarray(mask_u8), Image.Resampling.NEAREST))
        cropped_mask = torch.from_numpy(warped_mask > 0).to(cropped_mask.dtype)

    if cropped_pts3d is not None:
        pts_np = cropped_pts3d.numpy()
        # PIL handles single-channel float images only, so warp X, Y and Z separately
        channels = [np.array(warp(Image.fromarray(pts_np[:, :, c]), Image.Resampling.NEAREST))
                    for c in range(3)]
        cropped_pts3d = torch.from_numpy(np.stack(channels, axis=-1))

    to_orig = torch.tensor(H, device=img.device)
    return cropped_img, cropped_mask, cropped_pts3d, to_orig


def resize_image_to_max(max_image_size, rgb, K):
    """Normalize a PIL image and downscale it so its longest side is at most ``max_image_size``.

    Args:
        max_image_size: size limit for the longest side; a falsy value disables resizing.
        rgb: PIL image (``rgb.size`` is (W, H)).
        K: camera intrinsics in the opencv convention.

    Returns:
        rgb_tensor: ``ImgNorm`` output permuted to channels-last.
        new_K: intrinsics of the (possibly) resized image; ``K`` itself when not resized.
        to_orig_max: 3x3 scaling from resized to original pixel coordinates.
        to_resize_max: 3x3 scaling from original to resized pixel coordinates.
        (HMax, WMax): resulting height and width.
    """
    W, H = rgb.size
    if not (max_image_size and max(W, H) > max_image_size):
        # already small enough: identity transforms, intrinsics untouched
        return ImgNorm(rgb).permute(1, 2, 0), K, np.eye(3), np.eye(3), (H, W)

    if W >= H:  # landscape or square: the width is the long side
        WMax = max_image_size
        HMax = int(H * (WMax / W))
    else:
        HMax = max_image_size
        WMax = int(W * (HMax / H))

    sx, sy = WMax / W, HMax / H
    rgb_tensor = tvf.Compose([ImgNorm, tvf.Resize(size=[HMax, WMax])])(rgb).permute(1, 2, 0)
    to_orig_max = np.array([[W / WMax, 0, 0],
                            [0, H / HMax, 0],
                            [0, 0, 1]])
    to_resize_max = np.array([[sx, 0, 0],
                              [0, sy, 0],
                              [0, 0, 1]])

    # scale the intrinsics in the pixel-corner (colmap) convention, then convert back
    new_K = opencv_to_colmap_intrinsics(K)
    new_K[0, :] *= sx
    new_K[1, :] *= sy
    new_K = colmap_to_opencv_intrinsics(new_K)
    return rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax)


def _rescale_cv2_pixels(trf, pts):
    """Apply a 3x3 pixel transform to 2D points given in the cv2 (pixel-center) convention.

    The points are moved to the colmap (pixel-corner) convention before the
    transform and moved back afterwards, as ``coarse_matching`` does for
    'to_orig'. This way a pure rescaling (the ``to_*_max`` matrices of
    ``resize_image_to_max``) maps pixel centers onto pixel centers.

    Args:
        trf: 3x3 transform (numpy array).
        pts: (N, 2) array-like of x, y coordinates; an empty input yields an empty (0, 2) array.

    Returns:
        (N, 2) float64 array of transformed cv2-convention coordinates.
    """
    pts = np.asarray(pts, dtype=np.float64).reshape(-1, 2) + 0.5
    pts = geotrf(trf, pts, norm=True)
    return pts - 0.5


def _show_matches(query_rgb, map_rgb, matches_im_query, matches_im_map, n_viz):
    """Debug helper: plot ``n_viz`` evenly spaced matches side by side and block until closed.

    Args:
        query_rgb, map_rgb: images convertible with ``np.array`` (left / right panel).
        matches_im_query, matches_im_map: (N, 2) matched pixel coordinates.
        n_viz: number of matches to draw (> 0).
    """
    num_matches = len(matches_im_map)
    print(f'found {num_matches} matches')
    if num_matches == 0:
        return  # nothing to draw; indexing an empty array below would fail

    from matplotlib import pyplot as pl
    viz_imgs = [np.array(query_rgb), np.array(map_rgb)]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im_query = np.asarray(matches_im_query)[match_idx_to_viz]
    viz_matches_im_map = np.asarray(matches_im_map)[match_idx_to_viz]

    # pad the shorter image at the bottom so both can be concatenated horizontally
    H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    color_denom = max(n_viz - 1, 1)  # avoid division by zero when n_viz == 1
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im_query[i].T, viz_matches_im_map[i].T
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / color_denom), scalex=False, scaley=False)
    pl.show(block=True)


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    conf_thr = args.confidence_threshold
    device = args.device
    pnp_mode = args.pnp_mode
    assert args.pixel_tol > 0
    reprojection_error = args.reprojection_error
    reprojection_error_diag_ratio = args.reprojection_error_diag_ratio
    pnp_max_points = args.pnp_max_points
    viz_matches = args.viz_matches

    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
    fast_nn_params = dict(device=device, dist='dot', block_size=2**13)
    # NOTE(security): --dataset is evaluated as arbitrary Python (e.g. a dataset
    # constructor call using the star-imported dust3r_visloc.datasets); only pass trusted input.
    dataset = eval(args.dataset)
    dataset.set_resolution(model)

    # model input geometry, invariant over the whole evaluation
    maxdim = max(model.patch_embed.img_size)
    patch_size = model.patch_embed.patch_size

    query_names = []
    poses_pred = []
    pose_errors = []
    angular_errors = []
    params_str = f'tol_{args.pixel_tol}' + ("_c2f" if args.coarse_to_fine else '')
    if args.max_image_size is not None:
        params_str = params_str + f'_{args.max_image_size}'
    if args.coarse_to_fine and args.c2f_crop_with_homography:
        params_str = params_str + '_with_homography'
    for idx in tqdm(range(len(dataset))):
        views = dataset[idx]  # views[0] is the query, the remaining views are map views
        query_view = views[0]
        map_views = views[1:]
        query_names.append(query_view['image_name'])

        query_pts2d = []
        query_pts3d = []
        query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max, (HQ, WQ) = resize_image_to_max(
            args.max_image_size, query_view['rgb'], query_view['intrinsics'])

        # pairs of crops have the same resolution
        query_resolution = get_HW_resolution(HQ, WQ, maxdim=maxdim, patchsize=patch_size)
        for map_view in map_views:
            if args.output_dir is not None:
                cache_file = os.path.join(args.output_dir, 'matches', params_str,
                                          query_view['image_name'], map_view['image_name'] + '.npz')
            else:
                cache_file = None

            if cache_file is not None and os.path.isfile(cache_file):
                # cached matches are reused as-is: clear the cache after changing the matching code
                with np.load(cache_file) as matches:
                    valid_pts3d = matches['valid_pts3d']
                    matches_im_query = matches['matches_im_query']
                    matches_im_map = matches['matches_im_map']
                    matches_conf = matches['matches_conf']
            else:
                if args.coarse_to_fine and (maxdim < max(WQ, HQ)):
                    # coarse matching over the whole image (pixel_tol=0), used to select crop pairs
                    _, coarse_matches_im0, coarse_matches_im1, _ = coarse_matching(query_view, map_view, model, device,
                                                                                   0, fast_nn_params)
                    if viz_matches > 0:
                        _show_matches(query_view['rgb'], map_view['rgb'],
                                      coarse_matches_im0, coarse_matches_im1, viz_matches)

                    valid_all = map_view['valid']
                    pts3d = map_view['pts3d']

                    WM_full, HM_full = map_view['rgb'].size
                    map_rgb_tensor, map_K, map_to_orig_max, map_to_resize_max, (HM, WM) = resize_image_to_max(
                        args.max_image_size, map_view['rgb'], map_view['intrinsics'])
                    if WM_full != WM or HM_full != HM:
                        # bring the map's valid 3D points to the resized resolution
                        y_full, x_full = torch.where(valid_all)
                        pos2d_cv2 = torch.stack([x_full, y_full], dim=-1).cpu().numpy().astype(np.float64)
                        sparse_pts3d = pts3d[y_full, x_full].cpu().numpy()
                        _, _, pts3d_max, valid_max = rescale_points3d(
                            pos2d_cv2, sparse_pts3d, map_to_resize_max, HM, WM)
                        pts3d = torch.from_numpy(pts3d_max)
                        valid_all = torch.from_numpy(valid_max)

                    # original -> resized pixel coordinates (pixel-center aware)
                    coarse_matches_im0 = _rescale_cv2_pixels(query_to_resize_max, coarse_matches_im0)
                    coarse_matches_im1 = _rescale_cv2_pixels(map_to_resize_max, coarse_matches_im1)

                    # without --c2f_crop_with_homography, crops are plain axis-aligned windows
                    crop_K_map = map_K if args.c2f_crop_with_homography else None
                    crop_K_query = query_K if args.c2f_crop_with_homography else None

                    crops1, crops2 = [], []
                    crops_v1, crops_p1 = [], []
                    to_orig1, to_orig2 = [], []
                    map_resolution = get_HW_resolution(HM, WM, maxdim=maxdim, patchsize=patch_size)

                    for crop_q, crop_b, _ in select_pairs_of_crops(map_rgb_tensor,
                                                                   query_rgb_tensor,
                                                                   coarse_matches_im1,
                                                                   coarse_matches_im0,
                                                                   maxdim=maxdim,
                                                                   overlap=.5,
                                                                   forced_resolution=[map_resolution,
                                                                                      query_resolution]):
                        # crop_q is applied to the map image, crop_b to the query image
                        c1, v1, p1, trf1 = crop(map_rgb_tensor, valid_all, pts3d, crop_q, crop_K_map)
                        c2, _, _, trf2 = crop(query_rgb_tensor, None, None, crop_b, crop_K_query)
                        crops1.append(c1)
                        crops2.append(c2)
                        crops_v1.append(v1)
                        crops_p1.append(p1)
                        to_orig1.append(trf1)
                        to_orig2.append(trf2)

                    if len(crops1) == 0 or len(crops2) == 0:
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = [], [], [], []
                    else:
                        crops1, crops2 = torch.stack(crops1), torch.stack(crops2)
                        if len(crops1.shape) == 3:
                            crops1, crops2 = crops1[None], crops2[None]
                        crops_v1 = torch.stack(crops_v1)
                        crops_p1 = torch.stack(crops_p1)
                        to_orig1, to_orig2 = torch.stack(to_orig1), torch.stack(to_orig2)
                        map_crop_view = dict(img=crops1.permute(0, 3, 1, 2),
                                             instance=['1' for _ in range(crops1.shape[0])],
                                             valid=crops_v1, pts3d=crops_p1,
                                             to_orig=to_orig1)
                        query_crop_view = dict(img=crops2.permute(0, 3, 1, 2),
                                               instance=['2' for _ in range(crops2.shape[0])],
                                               to_orig=to_orig2)

                        # Inference and Matching
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = fine_matching(
                            query_crop_view, map_crop_view, model, device,
                            args.max_batch_size, args.pixel_tol, fast_nn_params)
                        # resized -> original pixel coordinates (pixel-center aware)
                        matches_im_query = _rescale_cv2_pixels(query_to_orig_max, matches_im_query)
                        matches_im_map = _rescale_cv2_pixels(map_to_orig_max, matches_im_map)
                else:
                    # use only valid 2d points
                    valid_pts3d, matches_im_query, matches_im_map, matches_conf = coarse_matching(
                        query_view, map_view, model, device, args.pixel_tol, fast_nn_params)
                if cache_file is not None:
                    mkdir_for(cache_file)
                    np.savez(cache_file, valid_pts3d=valid_pts3d, matches_im_query=matches_im_query,
                             matches_im_map=matches_im_map, matches_conf=matches_conf)

            # apply conf: keep matches whose confidence reaches the threshold
            if len(matches_conf) > 0:
                mask = matches_conf >= conf_thr
                valid_pts3d = valid_pts3d[mask]
                matches_im_query = matches_im_query[mask]
                matches_im_map = matches_im_map[mask]
                matches_conf = matches_conf[mask]

            if viz_matches > 0:
                _show_matches(query_view['rgb'], map_view['rgb'], matches_im_query, matches_im_map, viz_matches)

            if len(valid_pts3d) > 0:
                query_pts3d.append(valid_pts3d)
                query_pts2d.append(matches_im_query)

        if len(query_pts2d) == 0:
            success = False
            pr_querycam_to_world = None
        else:
            query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32)
            query_pts3d = np.concatenate(query_pts3d, axis=0)
            if len(query_pts2d) > pnp_max_points:
                # random subset to bound the PnP cost
                idxs = random.sample(range(len(query_pts2d)), pnp_max_points)
                query_pts3d = query_pts3d[idxs]
                query_pts2d = query_pts2d[idxs]

            W, H = query_view['rgb'].size
            if reprojection_error_diag_ratio is not None:
                reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2)
            else:
                reprojection_error_img = reprojection_error
            success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d,
                                                    query_view['intrinsics'], query_view['distortion'],
                                                    pnp_mode, reprojection_error_img, img_size=[W, H])

        if not success:
            abs_transl_error = float('inf')
            abs_angular_error = float('inf')
        else:
            abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world'])

        pose_errors.append(abs_transl_error)
        angular_errors.append(abs_angular_error)
        poses_pred.append(pr_querycam_to_world)

    xp_label = params_str + f'_conf_{conf_thr}'
    if args.output_label:
        xp_label = args.output_label + "_" + xp_label
    if reprojection_error_diag_ratio is not None:
        xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}'
    else:
        xp_label = xp_label + f'_reproj_err_{reprojection_error}'
    export_results(args.output_dir, xp_label, query_names, poses_pred)
    out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors)
    print(out_string)