"""Keypoint extraction, mutual-nearest-neighbour matching and homography
verification between two images, optionally after warping each image with a
stored 3x3 homography (loaded from a ``.npy`` file)."""

import argparse
import os
import time
from sys import exit, argv

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pydegensac
import scipy
import scipy.io
import scipy.misc
import torch
from PIL import Image
from skimage.feature import match_descriptors
from skimage.measure import ransac
from skimage.transform import ProjectiveTransform, AffineTransform
from tqdm import tqdm

from lib.model_test import D2Net
from lib.utils import preprocess_image
from lib.pyramid import process_multiscale


# Output size (width, height) handed to cv2.warpPerspective for warped views.
_WARP_SIZE = (400, 400)

# Arguments passed positionally to pydegensac.findHomography after the points.
_RANSAC_ARGS = (8.0, 0.99, 10000)


def _load_homography(h_file):
    """Return the array stored in *h_file*, or None when no file is given."""
    return None if h_file is None else np.load(h_file)


def _or_identity(H):
    """Return *H*, or a 3x3 identity when the image was never warped."""
    return np.eye(3) if H is None else H


def _load_rgb_array(img_path, resize=None):
    """Open *img_path* with PIL, optionally resize, force RGB, return ndarray."""
    # The context manager closes the underlying file once pixels are copied.
    with Image.open(img_path) as img:
        if resize:
            img = img.resize(resize)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return np.array(img)


def _create_sift():
    """Build a SIFT detector from whichever OpenCV namespace provides it."""
    if hasattr(cv2, 'SIFT_create'):
        return cv2.SIFT_create()
    return cv2.xfeatures2d.SIFT_create()


def _ransac_inliers_and_draw(img1, img2, pts_a, pts_b, match_color=None):
    """Keep homography inliers of the correspondences and draw them.

    Returns ``(inlier_pts_a, inlier_pts_b, match_image)``.
    """
    _, inliers = pydegensac.findHomography(pts_a, pts_b, *_RANSAC_ARGS)
    pts_a = pts_a[inliers]
    pts_b = pts_b[inliers]

    kps_left = [cv2.KeyPoint(float(p[0]), float(p[1]), 1) for p in pts_a]
    kps_right = [cv2.KeyPoint(float(p[0]), float(p[1]), 1) for p in pts_b]
    # The i-th left keypoint corresponds to the i-th right keypoint.
    pair_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pts_a))]

    draw_kwargs = {} if match_color is None else {'matchColor': match_color}
    canvas = cv2.drawMatches(img1, kps_left, img2, kps_right,
                             pair_matches, None, **draw_kwargs)
    canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    return pts_a, pts_b, canvas


def extractSingle(image, model, device):
    """Run single-scale feature extraction on one preprocessed image tensor.

    Args:
        image: tensor without a batch axis (one is added via ``unsqueeze(0)``).
        model: network handed to ``process_multiscale``.
        device: device the image is moved to before extraction.

    Returns:
        dict with ``'keypoints'``, ``'scores'`` and ``'descriptors'``. The
        first two keypoint columns are swapped relative to what
        ``process_multiscale`` returns (presumably (row, col) -> (x, y);
        confirm against ``lib.pyramid``).
    """
    with torch.no_grad():
        keypoints, scores, descriptors = process_multiscale(
            image.to(device).unsqueeze(0),
            model,
            scales=[1]
        )
    return {
        'keypoints': keypoints[:, [1, 0, 2]],
        'scores': scores,
        'descriptors': descriptors,
    }


def siftMatching(img1, img2, HFile1, HFile2, device):
    """Match two images with SIFT + mutual nearest neighbours + RANSAC.

    Args:
        img1, img2: paths of the two images.
        HFile1, HFile2: paths of ``.npy`` homographies used to warp each image
            before detection, or None to use the image as is. Each file is
            loaded independently of the other.
        device: torch device used for descriptor matching.

    Returns:
        ``([], [])`` when descriptors are missing or fewer than 5 matches
        survive. NOTE(review): this 2-tuple is kept for backward compatibility
        although every other path returns 4 values.

        Otherwise ``(src, dst, matchImg, image3)``. With ``HFile1`` None the
        points are the (N, 2) inliers and ``matchImg`` is ``image3``; else the
        points are mapped back through the inverse homographies ((2, N), see
        ``orgKeypoints``) and drawn on the original images.
    """
    H1 = _load_homography(HFile1)
    H2 = _load_homography(HFile2)

    rgbFile1, rgbFile2 = img1, img2

    img1 = _load_rgb_array(rgbFile1)
    if H1 is not None:
        img1 = cv2.warpPerspective(img1, H1, dsize=_WARP_SIZE)

    img2 = _load_rgb_array(rgbFile2)
    if H2 is not None:
        img2 = cv2.warpPerspective(img2, H2, dsize=_WARP_SIZE)

    sift = _create_sift()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)

    # detectAndCompute yields None descriptors when nothing is detected;
    # treat that exactly like the "too few matches" case below.
    if des1 is None or des2 is None:
        return [], []

    matches = mnn_matcher(
        torch.from_numpy(des1).float().to(device=device),
        torch.from_numpy(des2).float().to(device=device)
    )

    src_pts = np.float32([kp1[m[0]].pt for m in matches]).reshape(-1, 2)
    dst_pts = np.float32([kp2[m[1]].pt for m in matches]).reshape(-1, 2)

    if src_pts.shape[0] < 5 or dst_pts.shape[0] < 5:
        return [], []

    src_pts, dst_pts, image3 = _ransac_inliers_and_draw(
        img1, img2, src_pts, dst_pts)

    if H1 is None:
        return src_pts, dst_pts, image3, image3

    # An image that was not warped maps back to itself (identity).
    orgSrc, orgDst = orgKeypoints(src_pts, dst_pts, H1, _or_identity(H2))
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2),
                       orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3


def orgKeypoints(src_pts, dst_pts, H1, H2):
    """Map points from the warped views back through the inverse homographies.

    Args:
        src_pts, dst_pts: (N, 2) arrays of matched points (same N).
        H1, H2: 3x3 homographies that produced the warped views.

    Returns:
        ``(orgSrc, orgDst)``, each of shape (2, N): row 0 and row 1 hold the
        two coordinates after dividing by the homogeneous component.
    """
    ones = np.ones((src_pts.shape[0], 1))
    src_h = np.hstack((src_pts, ones))
    dst_h = np.hstack((dst_pts, ones))

    orgSrc = np.linalg.inv(H1) @ src_h.T
    orgDst = np.linalg.inv(H2) @ dst_h.T

    # Normalise homogeneous coordinates.
    orgSrc = orgSrc / orgSrc[2, :]
    orgDst = orgDst / orgDst[2, :]

    return np.asarray(orgSrc)[0:2, :], np.asarray(orgDst)[0:2, :]


def drawOrg(image1, image2, orgSrc, orgDst):
    """Draw matched points and connecting lines on two side-by-side images.

    Args:
        image1, image2: images as returned by ``cv2.imread``.
        orgSrc, orgDst: (2, N) point arrays as produced by ``orgKeypoints``.

    Returns:
        The horizontally concatenated image with circles at every point and a
        line per correspondence. With zero points the plain concatenation is
        returned.
    """
    # Bind the canvases up front so they exist even when there are no points.
    im1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
    im2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)

    for i in range(orgSrc.shape[1]):
        im1 = cv2.circle(im1, (int(orgSrc[0, i]), int(orgSrc[1, i])),
                         3, (0, 0, 255), 1)
    for i in range(orgDst.shape[1]):
        im2 = cv2.circle(im2, (int(orgDst[0, i]), int(orgDst[1, i])),
                         3, (0, 0, 255), 1)

    im4 = cv2.hconcat([im1, im2])
    x_offset = im1.shape[1]  # right image starts after the left image's width
    for i in range(orgSrc.shape[1]):
        im4 = cv2.line(im4,
                       (int(orgSrc[0, i]), int(orgSrc[1, i])),
                       (int(orgDst[0, i]) + x_offset, int(orgDst[1, i])),
                       (0, 255, 0), 1)

    im4 = cv2.cvtColor(im4, cv2.COLOR_BGR2RGB)
    return im4


def getPerspKeypoints(rgbFile1, rgbFile2, HFile1, HFile2, model, device):
    """Match two images with learned features, verified by a homography.

    Args:
        rgbFile1, rgbFile2: image paths.
        HFile1, HFile2: ``.npy`` homography paths used to warp each image
            before extraction, or None.
        model: feature network passed to ``extractSingle``.
        device: torch device.

    Returns:
        ``(src, dst, matchImg, image3)``. With ``HFile1`` None, ``src``/``dst``
        are the (N, 2) inlier positions and ``matchImg`` is ``image3``.
        Otherwise the points are mapped back to the original images ((2, N))
        and ``matchImg`` shows them on the files read with ``cv2.imread``; a
        missing ``HFile2`` is treated as the identity.
    """
    H1 = _load_homography(HFile1)
    H2 = _load_homography(HFile2)

    igp1, img1 = read_and_process_image(rgbFile1, H=H1)
    igp2, img2 = read_and_process_image(rgbFile2, H=H2)

    feat1 = extractSingle(igp1, model, device)
    feat2 = extractSingle(igp2, model, device)

    matches = mnn_matcher(
        torch.from_numpy(feat1['descriptors']).to(device=device),
        torch.from_numpy(feat2['descriptors']).to(device=device),
    )

    pos_a = feat1["keypoints"][matches[:, 0], :2]
    pos_b = feat2["keypoints"][matches[:, 1], :2]

    pos_a, pos_b, image3 = _ransac_inliers_and_draw(
        img1, img2, pos_a, pos_b, match_color=[0, 255, 0])

    if H1 is None:
        return pos_a, pos_b, image3, image3

    # Reproject matches to the original (perspective) view.
    orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, _or_identity(H2))
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2),
                       orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3


###### Ensemble


def read_and_process_image(img_path, resize=None, H=None, h=None, w=None,
                           preprocessing='caffe'):
    """Load an image, optionally resize/warp it, and preprocess it.

    Args:
        img_path: path of the image file.
        resize: optional size passed to ``PIL.Image.resize``.
        H: optional 3x3 homography; when given the image is warped to a fixed
            400x400 output.
        h, w: accepted for interface compatibility; currently unused.
        preprocessing: mode string forwarded to ``preprocess_image``.

    Returns:
        ``(tensor, array)``: the float32 tensor built from
        ``preprocess_image`` and the (possibly warped) RGB ndarray.
    """
    img1 = _load_rgb_array(img_path, resize=resize)

    if H is not None:
        img1 = cv2.warpPerspective(img1, H, dsize=_WARP_SIZE)

    igp1 = torch.from_numpy(
        preprocess_image(img1, preprocessing=preprocessing).astype(np.float32)
    )
    return igp1, img1


def mnn_matcher_scorer(descriptors_a, descriptors_b, k=np.inf):
    """Mutual nearest-neighbour matching that also returns match similarity.

    Similarity is the dot product of descriptors. ``k`` is accepted for
    interface compatibility and is not used.

    Returns:
        ``(matches, sims)``: an (M, 2) index tensor of (index_a, index_b)
        pairs that are each other's best match, and the M similarities.
    """
    device = descriptors_a.device
    sim = descriptors_a @ descriptors_b.t()
    val1, nn12 = torch.max(sim, dim=1)
    _, nn21 = torch.max(sim, dim=0)
    ids1 = torch.arange(0, sim.shape[0], device=device)
    # Keep a only if the best match of a's best match is a itself.
    mask = (ids1 == nn21[nn12])
    matches = torch.stack([ids1[mask], nn12[mask]]).t()
    return matches, val1[mask]


def mnn_matcher(descriptors_a, descriptors_b):
    """Mutual nearest-neighbour matching; returns an (M, 2) numpy index array."""
    matches, _ = mnn_matcher_scorer(descriptors_a, descriptors_b)
    return matches.detach().cpu().numpy()


def getPerspKeypointsEnsemble(model1, model2, rgbFile1, rgbFile2,
                              HFile1, HFile2, device):
    """Match two images using the pooled best matches of two models.

    Each model's mutual-nearest-neighbour matches are scored by descriptor
    similarity; the top half of the pooled matches (by score) is kept and
    verified with a homography.

    Args:
        model1, model2: feature networks.
        rgbFile1, rgbFile2: image paths.
        HFile1, HFile2: ``.npy`` homography paths used to warp each image, or
            None.
        device: torch device.

    Returns:
        Same contract as ``getPerspKeypoints``: ``(src, dst, matchImg,
        image3)``, with (N, 2) inliers and ``matchImg is image3`` when
        ``HFile1`` is None, otherwise points mapped back to the originals.
    """
    H1 = _load_homography(HFile1)
    H2 = _load_homography(HFile2)

    igp1, img1 = read_and_process_image(rgbFile1, H=H1)
    igp2, img2 = read_and_process_image(rgbFile2, H=H2)

    feat_a1 = extractSingle(igp1, model1, device)
    feat_a2 = extractSingle(igp1, model2, device)
    feat_b1 = extractSingle(igp2, model1, device)
    feat_b2 = extractSingle(igp2, model2, device)

    # Matches (with similarity scores) for each model separately.
    matches1, dist_1 = mnn_matcher_scorer(
        torch.from_numpy(feat_a1['descriptors']).to(device=device),
        torch.from_numpy(feat_b1['descriptors']).to(device=device),
    )
    matches2, dist_2 = mnn_matcher_scorer(
        torch.from_numpy(feat_a2['descriptors']).to(device=device),
        torch.from_numpy(feat_b2['descriptors']).to(device=device),
    )

    full_dist = torch.cat([dist_1, dist_2])
    k_final = len(full_dist) // 2
    top_idx = torch.topk(full_dist, k=k_final)[1]

    # Indices below len(dist_1) belong to model 1, the rest to model 2.
    n_first = len(dist_1)
    from_first = top_idx < n_first
    first = top_idx[from_first]
    second = top_idx[~from_first] - n_first

    matches1 = matches1[first].detach().cpu().numpy()
    matches2 = matches2[second].detach().cpu().numpy()

    pos_a = np.concatenate([
        feat_a1['keypoints'][matches1[:, 0], :2],
        feat_a2['keypoints'][matches2[:, 0], :2],
    ], 0)
    pos_b = np.concatenate([
        feat_b1['keypoints'][matches1[:, 1], :2],
        feat_b2['keypoints'][matches2[:, 1], :2],
    ], 0)

    pos_a, pos_b, image3 = _ransac_inliers_and_draw(
        img1, img2, pos_a, pos_b, match_color=[0, 255, 0])

    if H1 is None:
        return pos_a, pos_b, image3, image3

    orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, _or_identity(H2))
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2),
                       orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3


if __name__ == '__main__':
    WEIGHTS = '../models/rord.pth'

    if len(argv) < 5:
        exit('usage: <script> SRC_IMAGE TRG_IMAGE SRC_H.npy TRG_H.npy')

    srcR = argv[1]
    trgR = argv[2]
    srcH = argv[3]
    trgH = argv[4]

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    # NOTE(review): assumes D2Net accepts model_file / use_cuda keywords as in
    # the upstream D2-Net implementation - confirm against lib.model_test.
    model = D2Net(model_file=WEIGHTS, use_cuda=use_cuda)
    model = model.to(device)
    model.eval()

    orgSrc, orgDst, matchImg, image3 = getPerspKeypoints(
        srcR, trgR, srcH, trgH, model, device)