File size: 4,228 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import argparse
import os
from os.path import join
import cv2
import torch
from matplotlib import pyplot as plt
from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
from .drawing import (
plot_images,
plot_lines,
plot_color_line_matches,
plot_keypoints,
plot_matches,
)
from .models.two_view_pipeline import TwoViewPipeline
def _read_gray(path):
    """Load the image at *path* as a single-channel grayscale array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, so the
            check is done here to fail early with the offending path.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def _select_matched(features0, features1, matches):
    """Return the matched rows of two feature arrays.

    ``matches[i]`` is used as the index into ``features1`` paired with
    ``features0[i]``; entries equal to ``-1`` are treated as unmatched and
    dropped.
    """
    valid = matches != -1
    return features0[valid], features1[matches[valid]]


def _open_figure(images, label):
    """Start a new side-by-side figure of *images* titled with *label*."""
    plot_images(
        images,
        [f"Image 1 - {label}", f"Image 2 - {label}"],
        dpi=200,
        pad=2.0,
    )


def _finish_figure(window_title, filename):
    """Set the current figure's window title and save it to *filename*."""
    plt.gcf().canvas.manager.set_window_title(window_title)
    plt.savefig(filename)


def main():
    """Run the GlueStick demo on a pair of images.

    Parses the command line, builds the two-view pipeline, runs it on the two
    input images and writes four figures to the current working directory:
    ``detected_lines.png``, ``detected_points.png``, ``line_matches.png`` and
    ``point_matches.png``. Unless ``--skip-imshow`` is given, the figures are
    also shown interactively.

    Raises:
        FileNotFoundError: if either input image cannot be read.
    """
    # Parse input parameters
    parser = argparse.ArgumentParser(
        prog="GlueStick Demo",
        description="Demo app to show the point and line matches obtained by GlueStick",
    )
    parser.add_argument("-img1", default=join("resources", "img1.jpg"))
    parser.add_argument("-img2", default=join("resources", "img2.jpg"))
    parser.add_argument("--max_pts", type=int, default=1000)
    parser.add_argument("--max_lines", type=int, default=300)
    parser.add_argument("--skip-imshow", default=False, action="store_true")
    args = parser.parse_args()

    # Evaluation config
    conf = {
        "name": "two_view_pipeline",
        "use_lines": True,
        "extractor": {
            "name": "wireframe",
            "sp_params": {
                "force_num_keypoints": False,
                "max_num_keypoints": args.max_pts,
            },
            "wireframe_params": {
                "merge_points": True,
                "merge_line_endpoints": True,
            },
            "max_n_lines": args.max_lines,
        },
        "matcher": {
            "name": "gluestick",
            "weights": str(
                GLUESTICK_ROOT / "resources" / "weights" / "checkpoint_GlueStick_MD.tar"
            ),
            "trainable": False,
        },
        "ground_truth": {
            "from_pose_depth": False,
        },
    }

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline_model = TwoViewPipeline(conf).to(device).eval()

    gray0 = _read_gray(args.img1)
    gray1 = _read_gray(args.img2)

    # [None] prepends a batch dimension of size one to each image tensor.
    torch_gray0 = numpy_image_to_torch(gray0).to(device)[None]
    torch_gray1 = numpy_image_to_torch(gray1).to(device)[None]

    pred = pipeline_model({"image0": torch_gray0, "image1": torch_gray1})
    pred = batch_to_np(pred)

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]

    matched_kps0, matched_kps1 = _select_matched(kp0, kp1, pred["matches0"])
    matched_lines0, matched_lines1 = _select_matched(
        line_seg0, line_seg1, pred["line_matches0"]
    )

    # Plot the matches
    images = [
        cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR),
        cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR),
    ]

    _open_figure(images, "detected lines")
    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
    _finish_figure("Detected Lines", "detected_lines.png")

    _open_figure(images, "detected points")
    plot_keypoints([kp0, kp1], colors="c")
    _finish_figure("Detected Points", "detected_points.png")

    _open_figure(images, "line matches")
    plot_color_line_matches([matched_lines0, matched_lines1], lw=2)
    _finish_figure("Line Matches", "line_matches.png")

    _open_figure(images, "point matches")
    plot_matches(matched_kps0, matched_kps1, "green", lw=1, ps=0)
    # Saved under its own name; reusing "detected_points.png" here would
    # overwrite the detected-points figure written above.
    _finish_figure("Point Matches", "point_matches.png")

    if not args.skip_imshow:
        plt.show()
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|