import argparse
import os
from os.path import join

import cv2
import torch
from matplotlib import pyplot as plt

from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
from .drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
from .models.two_view_pipeline import TwoViewPipeline


def _build_config(max_pts, max_lines):
    """Return the TwoViewPipeline configuration used by the demo.

    Args:
        max_pts: value stored under ``sp_params['max_num_keypoints']``.
        max_lines: value stored under ``extractor['max_n_lines']``.
    """
    return {
        'name': 'two_view_pipeline',
        'use_lines': True,
        'extractor': {
            'name': 'wireframe',
            'sp_params': {
                'force_num_keypoints': False,
                'max_num_keypoints': max_pts,
            },
            'wireframe_params': {
                'merge_points': True,
                'merge_line_endpoints': True,
            },
            'max_n_lines': max_lines,
        },
        'matcher': {
            'name': 'gluestick',
            'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
            'trainable': False,
        },
        'ground_truth': {
            'from_pose_depth': False,
        }
    }


def _load_gray_image(path):
    """Read the image at ``path`` as a single-channel grayscale array.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        ValueError: if the file exists but OpenCV cannot decode it.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f'Image file not found: {path}')
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise ValueError(f'Could not decode image: {path}')
    return image


def _select_matched(items0, items1, matches):
    """Return the matched entries of ``items0`` and ``items1``.

    ``matches`` holds, for each entry of ``items0``, the index of its
    counterpart in ``items1``; entries equal to -1 are dropped as unmatched.
    """
    valid = matches != -1
    return items0[valid], items1[matches[valid]]


def _title_and_save(window_title, filename):
    """Set the current figure's window title and save it to ``filename``."""
    plt.gcf().canvas.manager.set_window_title(window_title)
    plt.savefig(filename)


def main():
    """Run the GlueStick demo on two images and save the result figures.

    Parses the command line, runs the two-view pipeline on the image pair and
    writes ``detected_lines.png``, ``detected_points.png``,
    ``line_matches.png`` and ``point_matches.png`` to the working directory.
    The figures are also shown interactively unless ``--skip-imshow`` is given.
    """
    # Parse input parameters
    parser = argparse.ArgumentParser(
        prog='GlueStick Demo',
        description='Demo app to show the point and line matches obtained by GlueStick')
    parser.add_argument('-img1', default=join('resources', 'img1.jpg'))
    parser.add_argument('-img2', default=join('resources', 'img2.jpg'))
    parser.add_argument('--max_pts', type=int, default=1000)
    parser.add_argument('--max_lines', type=int, default=300)
    parser.add_argument('--skip-imshow', default=False, action='store_true')
    args = parser.parse_args()

    # Fail early with a clear message before the (slow) model is built.
    gray0 = _load_gray_image(args.img1)
    gray1 = _load_gray_image(args.img2)

    # Evaluation config
    conf = _build_config(args.max_pts, args.max_lines)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pipeline_model = TwoViewPipeline(conf).to(device).eval()

    # [None] prepends a batch dimension to each image tensor.
    torch_gray0 = numpy_image_to_torch(gray0).to(device)[None]
    torch_gray1 = numpy_image_to_torch(gray1).to(device)[None]
    x = {'image0': torch_gray0, 'image1': torch_gray1}

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred = pipeline_model(x)
    pred = batch_to_np(pred)

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]

    matched_kps0, matched_kps1 = _select_matched(kp0, kp1, pred["matches0"])
    matched_lines0, matched_lines1 = _select_matched(line_seg0, line_seg1, pred["line_matches0"])

    # Plot the matches
    img0 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
    img1 = cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)

    plot_images([img0, img1], ['Image 1 - detected lines', 'Image 2 - detected lines'], dpi=200, pad=2.0)
    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
    _title_and_save('Detected Lines', 'detected_lines.png')

    plot_images([img0, img1], ['Image 1 - detected points', 'Image 2 - detected points'], dpi=200, pad=2.0)
    plot_keypoints([kp0, kp1], colors='c')
    _title_and_save('Detected Points', 'detected_points.png')

    plot_images([img0, img1], ['Image 1 - line matches', 'Image 2 - line matches'], dpi=200, pad=2.0)
    plot_color_line_matches([matched_lines0, matched_lines1], lw=2)
    _title_and_save('Line Matches', 'line_matches.png')

    plot_images([img0, img1], ['Image 1 - point matches', 'Image 2 - point matches'], dpi=200, pad=2.0)
    plot_matches(matched_kps0, matched_kps1, 'green', lw=1, ps=0)
    # Saved under its own name so it no longer overwrites detected_points.png.
    _title_and_save('Point Matches', 'point_matches.png')

    if not args.skip_imshow:
        plt.show()


if __name__ == '__main__':
    main()