from pathlib import Path
import argparse
import random
import numpy as np
import matplotlib.cm as cm
import torch

from models.matching import Matching
from models.utils import (
    compute_pose_error,
    compute_epipolar_error,
    estimate_pose,
    make_matching_plot,
    error_colormap,
    AverageTimer,
    pose_auc,
    read_image,
    rotate_intrinsics,
    rotate_pose_inplane,
    scale_intrinsics,
)

torch.set_grad_enabled(False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Image pair matching and pose evaluation with SuperGlue",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input_pairs",
        type=str,
        default="assets/scannet_sample_pairs_with_gt.txt",
        help="Path to the list of image pairs",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default="assets/scannet_sample_images/",
        help="Path to the directory that contains the images",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="dump_match_pairs/",
        help="Path to the directory in which the .npz results and, optionally, "
        "the visualization images are written",
    )

    parser.add_argument(
        "--max_length", type=int, default=-1, help="Maximum number of pairs to evaluate"
    )
    parser.add_argument(
        "--resize",
        type=int,
        nargs="+",
        default=[640, 480],
        help="Resize the input image before running inference. If two numbers, "
        "resize to the exact dimensions; if one number, resize the max "
        "dimension; if -1, do not resize",
    )
    parser.add_argument(
        "--resize_float",
        action="store_true",
        help="Resize the image after casting uint8 to float",
    )

    parser.add_argument(
        "--superglue",
        choices={"indoor", "outdoor"},
        default="indoor",
        help="SuperGlue weights",
    )
    parser.add_argument(
        "--max_keypoints",
        type=int,
        default=1024,
        help="Maximum number of keypoints detected by SuperPoint"
        " ('-1' keeps all keypoints)",
    )
    parser.add_argument(
        "--keypoint_threshold",
        type=float,
        default=0.005,
        help="SuperPoint keypoint detector confidence threshold",
    )
    parser.add_argument(
        "--nms_radius",
        type=int,
        default=4,
        help="SuperPoint Non Maximum Suppression (NMS) radius (must be positive)",
    )
    parser.add_argument(
        "--sinkhorn_iterations",
        type=int,
        default=20,
        help="Number of Sinkhorn iterations performed by SuperGlue",
    )
    parser.add_argument(
        "--match_threshold", type=float, default=0.2, help="SuperGlue match threshold"
    )

    parser.add_argument(
        "--viz", action="store_true", help="Visualize the matches and dump the plots"
    )
    parser.add_argument(
        "--eval",
        action="store_true",
        help="Perform the evaluation (requires ground truth pose and intrinsics)",
    )
    parser.add_argument(
        "--fast_viz",
        action="store_true",
        help="Use faster image visualization with OpenCV instead of Matplotlib",
    )
    parser.add_argument(
        "--cache",
        action="store_true",
        help="Skip the pair if output .npz files are already found",
    )
    parser.add_argument(
        "--show_keypoints",
        action="store_true",
        help="Plot the keypoints in addition to the matches",
    )
    parser.add_argument(
        "--viz_extension",
        type=str,
        default="png",
        choices=["png", "pdf"],
        help="Visualization file extension. Use pdf for highest-quality.",
    )
    parser.add_argument(
        "--opencv_display",
        action="store_true",
        help="Visualize via OpenCV before saving output images",
    )
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="Shuffle ordering of pairs before processing",
    )
    parser.add_argument(
        "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode."
    )

    opt = parser.parse_args()
    print(opt)

    assert not (
        opt.opencv_display and not opt.viz
    ), "Must use --viz with --opencv_display"
    assert not (
        opt.opencv_display and not opt.fast_viz
    ), "Cannot use --opencv_display without --fast_viz"
    assert not (opt.fast_viz and not opt.viz), "Must use --viz with --fast_viz"
    assert not (
        opt.fast_viz and opt.viz_extension == "pdf"
    ), "Cannot use pdf extension with --fast_viz"

    if len(opt.resize) == 2 and opt.resize[1] == -1:
        opt.resize = opt.resize[0:1]
    if len(opt.resize) == 2:
        print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1]))
    elif len(opt.resize) == 1 and opt.resize[0] > 0:
        print("Will resize max dimension to {}".format(opt.resize[0]))
    elif len(opt.resize) == 1:
        print("Will not resize images")
    else:
        raise ValueError("Cannot specify more than two integers for --resize")

    with open(opt.input_pairs, "r") as f:
        pairs = [l.split() for l in f.readlines()]

    if opt.max_length > -1:
        pairs = pairs[0 : np.min([len(pairs), opt.max_length])]

    if opt.shuffle:
        random.Random(0).shuffle(pairs)

    if opt.eval:
        if not all([len(p) == 38 for p in pairs]):
            raise ValueError(
                "All pairs should have ground truth info for evaluation. "
                'File "{}" needs 38 valid entries per row'.format(opt.input_pairs)
            )

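    # Load the SuperPoint and SuperGlue models.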
device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu" |
|
print('Running inference on device "{}"'.format(device)) |
|
config = { |
|
"superpoint": { |
|
"nms_radius": opt.nms_radius, |
|
"keypoint_threshold": opt.keypoint_threshold, |
|
"max_keypoints": opt.max_keypoints, |
|
}, |
|
"superglue": { |
|
"weights": opt.superglue, |
|
"sinkhorn_iterations": opt.sinkhorn_iterations, |
|
"match_threshold": opt.match_threshold, |
|
}, |
|
} |
|
matching = Matching(config).eval().to(device) |
|
|
|
|
|
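    # Create the output directories if they do not exist already.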
    input_dir = Path(opt.input_dir)
    print('Looking for data in directory "{}"'.format(input_dir))
    output_dir = Path(opt.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    print('Will write matches to directory "{}"'.format(output_dir))
    if opt.eval:
        print('Will write evaluation results to directory "{}"'.format(output_dir))
    if opt.viz:
        print('Will write visualization images to directory "{}"'.format(output_dir))

    timer = AverageTimer(newline=True)
    for i, pair in enumerate(pairs):
        name0, name1 = pair[:2]
        stem0, stem1 = Path(name0).stem, Path(name1).stem
        matches_path = output_dir / "{}_{}_matches.npz".format(stem0, stem1)
        eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1)
        viz_path = output_dir / "{}_{}_matches.{}".format(
            stem0, stem1, opt.viz_extension
        )
        viz_eval_path = output_dir / "{}_{}_evaluation.{}".format(
            stem0, stem1, opt.viz_extension
        )

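        # Handle --cache logic.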
        do_match = True
        do_eval = opt.eval
        do_viz = opt.viz
        do_viz_eval = opt.eval and opt.viz
        if opt.cache:
            if matches_path.exists():
                try:
                    results = np.load(matches_path)
                except:
                    raise IOError("Cannot load matches .npz file: %s" % matches_path)

                kpts0, kpts1 = results["keypoints0"], results["keypoints1"]
                matches, conf = results["matches"], results["match_confidence"]
                do_match = False
            if opt.eval and eval_path.exists():
                try:
                    results = np.load(eval_path)
                except:
                    raise IOError("Cannot load eval .npz file: %s" % eval_path)
                err_R, err_t = results["error_R"], results["error_t"]
                precision = results["precision"]
                matching_score = results["matching_score"]
                num_correct = results["num_correct"]
                epi_errs = results["epipolar_errors"]
                do_eval = False
            if opt.viz and viz_path.exists():
                do_viz = False
            if opt.viz and opt.eval and viz_eval_path.exists():
                do_viz_eval = False
            timer.update("load_cache")

        if not (do_match or do_eval or do_viz or do_viz_eval):
            timer.print("Finished pair {:5} of {:5}".format(i, len(pairs)))
            continue

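        # If a rotation integer is provided (e.g. from EXIF data), use it: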
        if len(pair) >= 5:
            rot0, rot1 = int(pair[2]), int(pair[3])
        else:
            rot0, rot1 = 0, 0

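        # Load the image pair.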
        image0, inp0, scales0 = read_image(
            input_dir / name0, device, opt.resize, rot0, opt.resize_float
        )
        image1, inp1, scales1 = read_image(
            input_dir / name1, device, opt.resize, rot1, opt.resize_float
        )
        if image0 is None or image1 is None:
            print(
                "Problem reading image pair: {} {}".format(
                    input_dir / name0, input_dir / name1
                )
            )
            exit(1)
        timer.update("load_image")

        if do_match:
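            # Perform the matching.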
pred = matching({"image0": inp0, "image1": inp1}) |
|
pred = {k: v[0].cpu().numpy() for k, v in pred.items()} |
|
kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] |
|
matches, conf = pred["matches0"], pred["matching_scores0"] |
|
timer.update("matcher") |
|
|
|
|
|
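            # Write the matches to disk.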
            out_matches = {
                "keypoints0": kpts0,
                "keypoints1": kpts1,
                "matches": matches,
                "match_confidence": conf,
            }
            np.savez(str(matches_path), **out_matches)

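        # Keep the matching keypoints.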
        valid = matches > -1
        mkpts0 = kpts0[valid]
        mkpts1 = kpts1[matches[valid]]
        mconf = conf[valid]

        if do_eval:
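            # Estimate the pose and compute the pose error.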
            assert len(pair) == 38, "Pair does not have ground truth info"
            K0 = np.array(pair[4:13]).astype(float).reshape(3, 3)
            K1 = np.array(pair[13:22]).astype(float).reshape(3, 3)
            T_0to1 = np.array(pair[22:]).astype(float).reshape(4, 4)

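            # Scale the intrinsics to the resized image.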
            K0 = scale_intrinsics(K0, scales0)
            K1 = scale_intrinsics(K1, scales1)

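            # Update the intrinsics + extrinsics if EXIF rotation was found.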
            if rot0 != 0 or rot1 != 0:
                cam0_T_w = np.eye(4)
                cam1_T_w = T_0to1
                if rot0 != 0:
                    K0 = rotate_intrinsics(K0, image0.shape, rot0)
                    cam0_T_w = rotate_pose_inplane(cam0_T_w, rot0)
                if rot1 != 0:
                    K1 = rotate_intrinsics(K1, image1.shape, rot1)
                    cam1_T_w = rotate_pose_inplane(cam1_T_w, rot1)
                cam1_T_cam0 = cam1_T_w @ np.linalg.inv(cam0_T_w)
                T_0to1 = cam1_T_cam0

            epi_errs = compute_epipolar_error(mkpts0, mkpts1, T_0to1, K0, K1)
            correct = epi_errs < 5e-4
            num_correct = np.sum(correct)
            precision = np.mean(correct) if len(correct) > 0 else 0
            matching_score = num_correct / len(kpts0) if len(kpts0) > 0 else 0

            thresh = 1.0  # In pixels, relative to the resized image size.
            ret = estimate_pose(mkpts0, mkpts1, K0, K1, thresh)
            if ret is None:
                err_t, err_R = np.inf, np.inf
            else:
                R, t, inliers = ret
                err_t, err_R = compute_pose_error(T_0to1, R, t)

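            # Write the evaluation results to disk.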
            out_eval = {
                "error_t": err_t,
                "error_R": err_R,
                "precision": precision,
                "matching_score": matching_score,
                "num_correct": num_correct,
                "epipolar_errors": epi_errs,
            }
            np.savez(str(eval_path), **out_eval)
            timer.update("eval")

        if do_viz:
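            # Visualize the matches.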
            color = cm.jet(mconf)
            text = [
                "SuperGlue",
                "Keypoints: {}:{}".format(len(kpts0), len(kpts1)),
                "Matches: {}".format(len(mkpts0)),
            ]
            if rot0 != 0 or rot1 != 0:
                text.append("Rotation: {}:{}".format(rot0, rot1))

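            # Display extra parameter info.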
            k_thresh = matching.superpoint.config["keypoint_threshold"]
            m_thresh = matching.superglue.config["match_threshold"]
            small_text = [
                "Keypoint Threshold: {:.4f}".format(k_thresh),
                "Match Threshold: {:.2f}".format(m_thresh),
                "Image Pair: {}:{}".format(stem0, stem1),
            ]

            make_matching_plot(
                image0,
                image1,
                kpts0,
                kpts1,
                mkpts0,
                mkpts1,
                color,
                text,
                viz_path,
                opt.show_keypoints,
                opt.fast_viz,
                opt.opencv_display,
                "Matches",
                small_text,
            )

            timer.update("viz_match")

        if do_viz_eval:
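            # Visualize the evaluation results for the image pair.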
            color = np.clip((epi_errs - 0) / (1e-3 - 0), 0, 1)
            color = error_colormap(1 - color)
            deg, delta = " deg", "Delta "
            if not opt.fast_viz:
                deg, delta = "°", "$\\Delta$"
            e_t = "FAIL" if np.isinf(err_t) else "{:.1f}{}".format(err_t, deg)
            e_R = "FAIL" if np.isinf(err_R) else "{:.1f}{}".format(err_R, deg)
            text = [
                "SuperGlue",
                "{}R: {}".format(delta, e_R),
                "{}t: {}".format(delta, e_t),
                "inliers: {}/{}".format(num_correct, (matches > -1).sum()),
            ]
            if rot0 != 0 or rot1 != 0:
                text.append("Rotation: {}:{}".format(rot0, rot1))

            k_thresh = matching.superpoint.config["keypoint_threshold"]
            m_thresh = matching.superglue.config["match_threshold"]
            small_text = [
                "Keypoint Threshold: {:.4f}".format(k_thresh),
                "Match Threshold: {:.2f}".format(m_thresh),
                "Image Pair: {}:{}".format(stem0, stem1),
            ]

            make_matching_plot(
                image0,
                image1,
                kpts0,
                kpts1,
                mkpts0,
                mkpts1,
                color,
                text,
                viz_eval_path,
                opt.show_keypoints,
                opt.fast_viz,
                opt.opencv_display,
                "Relative Pose",
                small_text,
            )

            timer.update("viz_eval")

        timer.print("Finished pair {:5} of {:5}".format(i, len(pairs)))

    if opt.eval:
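        # Collate the results into a final table and print to terminal.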
        pose_errors = []
        precisions = []
        matching_scores = []
        for pair in pairs:
            name0, name1 = pair[:2]
            stem0, stem1 = Path(name0).stem, Path(name1).stem
            eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1)
            results = np.load(eval_path)
            pose_error = np.maximum(results["error_t"], results["error_R"])
            pose_errors.append(pose_error)
            precisions.append(results["precision"])
            matching_scores.append(results["matching_score"])
        thresholds = [5, 10, 20]
        aucs = pose_auc(pose_errors, thresholds)
        aucs = [100.0 * yy for yy in aucs]
        prec = 100.0 * np.mean(precisions)
        ms = 100.0 * np.mean(matching_scores)
        print("Evaluation Results (mean over {} pairs):".format(len(pairs)))
        print("AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t")
        print(
            "{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t".format(
                aucs[0], aucs[1], aucs[2], prec, ms
            )
        )