import cv2
import torch
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

from os.path import join

from dkm.models.model_zoo.DKMv3 import DKMv3
from gluefactory.superpoint import SuperPoint
from gluefactory.models.matchers.lightglue import LightGlue
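
# Defaults for robust geometry estimation; the reprojection threshold is in
# pixels. USAC_MAGSAC (MAGSAC++) is the default method since it is generally
# less sensitive to the threshold choice than classic RANSAC.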
DEFAULT_MIN_NUM_MATCHES = 4
DEFAULT_RANSAC_MAX_ITER = 10000
DEFAULT_RANSAC_CONFIDENCE = 0.999
DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
DEFAULT_RANSAC_METHOD = "USAC_MAGSAC"

RANSAC_ZOO = {
    "RANSAC": cv2.RANSAC,
    "USAC_FAST": cv2.USAC_FAST,
    "USAC_MAGSAC": cv2.USAC_MAGSAC,
    "USAC_PROSAC": cv2.USAC_PROSAC,
    "USAC_DEFAULT": cv2.USAC_DEFAULT,
    "USAC_FM_8PTS": cv2.USAC_FM_8PTS,
    "USAC_ACCURATE": cv2.USAC_ACCURATE,
    "USAC_PARALLEL": cv2.USAC_PARALLEL,
}


def read_image(path, grayscale=False):
    if grayscale:
        mode = cv2.IMREAD_GRAYSCALE
    else:
        mode = cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise ValueError(f'Cannot read image {path}.')
    if not grayscale and len(image.shape) == 3:
        image = image[:, :, ::-1]  # BGR -> RGB
    return image


def resize_image(image, size, interp):
    if interp.startswith('cv2_'):
        interp = getattr(cv2, 'INTER_' + interp[len('cv2_'):].upper())
        h, w = image.shape[:2]
        # INTER_AREA is only meant for shrinking; fall back to bilinear
        # when upsampling along either axis.
        if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
            interp = cv2.INTER_LINEAR
        resized = cv2.resize(image, size, interpolation=interp)
    else:
        raise ValueError(f'Unknown interpolation {interp}.')
    return resized


def fast_make_matching_figure(data, b_id):
    color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)
    color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)
    kpts0 = data['mkpts0_f'].cpu().detach().numpy()
    kpts1 = data['mkpts1_f'].cpu().detach().numpy()
    inliers = data['inliers']

    rows = 2
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h = max(h0, h1)
    H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1

    out = 255 * np.ones((H, W), np.uint8)

    wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)
    out = np.stack([out] * 3, -1)

    # Row 1: the two images side by side, unannotated.
    sh = hx(row=1)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    # Row 2: the same pair with inlier keypoints drawn in green.
    sh = hx(row=2)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]):
        c = (0, 255, 0)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    return out
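

# The overlay variant below repeats the two-row layout above but also draws a
# line per inlier match; __main__ blends the two renderings with
# cv2.addWeighted so the match lines appear semi-transparent.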
def fast_make_matching_overlay(data, b_id):
    color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)
    color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)
    kpts0 = data['mkpts0_f'].cpu().detach().numpy()
    kpts1 = data['mkpts1_f'].cpu().detach().numpy()
    inliers = data['inliers']

    rows = 2
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h = max(h0, h1)
    H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1

    out = 255 * np.ones((H, W), np.uint8)

    wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)
    out = np.stack([out] * 3, -1)

    sh = hx(row=1)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    sh = hx(row=2)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]):
        c = (0, 255, 0)
        cv2.line(out, (x0, y0 + sh), (x1 + margin + w0, y1 + sh), color=c, thickness=1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    return out


def preprocess(image: np.ndarray, grayscale: bool = False, resize_max: int = None,
               dfactor: int = 8):
    image = image.astype(np.float32, copy=False)
    size = image.shape[:2][::-1]
    scale = np.array([1.0, 1.0])

    if resize_max:
        scale = resize_max / max(size)
        if scale < 1.0:
            size_new = tuple(int(round(x * scale)) for x in size)
            image = resize_image(image, size_new, 'cv2_area')
            scale = np.array(size) / np.array(size_new)

    if grayscale:
        assert image.ndim == 2, image.shape
        image = image[None]
    else:
        image = image.transpose((2, 0, 1))  # HxWxC -> CxHxW
    image = torch.from_numpy(image / 255.0).float()

    # Make both sides divisible by dfactor and update the cumulative scale
    # back to the original resolution.
    size_new = tuple(map(
        lambda x: int(x // dfactor * dfactor),
        image.shape[-2:]))
    image = F.resize(image, size=size_new)
    scale = np.array(size) / np.array(size_new)[::-1]
    return image, scale
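

# Hypothetical usage sketch:
#   image, scale = preprocess(read_image('assets/demo/a1.png'))
# yields a (3, H', W') float tensor in [0, 1] whose sides are multiples of
# `dfactor`, plus the (x, y) factors that map resized coordinates back to the
# original image (original_xy = resized_xy * scale).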


def compute_geom(data,
                 ransac_method=DEFAULT_RANSAC_METHOD,
                 ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
                 ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
                 ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
                 ) -> dict:
    mkpts0 = data["mkpts0_f"].cpu().detach().numpy()
    mkpts1 = data["mkpts1_f"].cpu().detach().numpy()

    if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
        return {}

    h, w = data["hw0_i"]

    geo_info = {}

    # F here is the fundamental matrix, shadowing torchvision's F only
    # inside this function.
    F, _ = cv2.findFundamentalMat(
        mkpts0,
        mkpts1,
        method=RANSAC_ZOO[ransac_method],
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if F is not None:
        geo_info["Fundamental"] = F.tolist()

    H, _ = cv2.findHomography(
        mkpts1,
        mkpts0,
        method=RANSAC_ZOO[ransac_method],
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if H is not None:
        geo_info["Homography"] = H.tolist()
        if F is not None:
            _, H1, H2 = cv2.stereoRectifyUncalibrated(
                mkpts0.reshape(-1, 2),
                mkpts1.reshape(-1, 2),
                F,
                imgSize=(w, h),
            )
            geo_info["H1"] = H1.tolist()
            geo_info["H2"] = H2.tolist()

    return geo_info
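

# Every entry in geo_info is a plain Python list (JSON-serializable);
# consumers such as wrap_images below convert them back with np.array.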


def wrap_images(img0, img1, geo_info, geom_type):
    img0 = img0[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]
    img1 = img1[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]

    h1, w1, _ = img0.shape
    h2, w2, _ = img1.shape

    rectified_image0 = img0
    rectified_image1 = None

    if geom_type == "Homography":
        H = np.array(geo_info["Homography"])
        rectified_image1 = cv2.warpPerspective(
            img1, H, (img0.shape[1], img0.shape[0])
        )
        title = ["Image 0", "Image 1 - warped"]
    elif geom_type == "Fundamental":
        H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
        rectified_image0 = cv2.warpPerspective(img0, H1, (w1, h1))
        rectified_image1 = cv2.warpPerspective(img1, H2, (w2, h2))
        title = ["Image 0 - warped", "Image 1 - warped"]
    else:
        raise ValueError(f"Unknown geometry type {geom_type}")

    fig = plot_images(
        [rectified_image0.squeeze(), rectified_image1.squeeze()],
        title,
        dpi=300,
    )
    img = fig2im(fig)
    plt.close(fig)
    return img


def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
    """Plot a set of images horizontally.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormap(s) for monochrome images.
        dpi: resolution of the figure.
        size: width of each image panel in inches; None lets Matplotlib decide.
        pad: padding passed to fig.tight_layout.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    figsize = (size * n, size * 6 / 5) if size is not None else None
    fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)

    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])

    fig.tight_layout(pad=pad)
    return fig


def fig2im(fig):
    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
    im = buf_ndarray.reshape(h, w, 3)
    return im
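
# Note: FigureCanvas.tostring_rgb is deprecated in newer Matplotlib releases;
# there one would read the canvas via buffer_rgba() and drop the alpha channel.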


if __name__ == '__main__':
    model_zoo = ['gim_dkm', 'gim_lightglue']

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='gim_dkm', choices=model_zoo)
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ckpt = None
    model = None
    detector = None
    if args.model == 'gim_dkm':
        ckpt = 'gim_dkm_100h.ckpt'
        model = DKMv3(weights=None, h=672, w=896)
    elif args.model == 'gim_lightglue':
        ckpt = 'gim_lightglue_100h.ckpt'
        detector = SuperPoint({
            'max_num_keypoints': 2048,
            'force_num_keypoints': True,
            'detection_threshold': 0.0,
            'nms_radius': 3,
            'trainable': False,
        })
        model = LightGlue({
            'filter_threshold': 0.1,
            'flash': False,
            'checkpointed': True,
        })

    checkpoints_path = join('weights', ckpt)

    if args.model == 'gim_dkm':
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Strip the training-time 'model.' prefix in one pass, then drop the
        # unused encoder classification head in a second pass, so the second
        # test never probes a key the first pass has already renamed.
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        for k in list(state_dict.keys()):
            if 'encoder.net.fc' in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)

    elif args.model == 'gim_lightglue':
        # The checkpoint stores both the SuperPoint detector (under the
        # 'superpoint.' prefix) and the LightGlue matcher (under 'model.');
        # it is read twice, once per sub-network.
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict.pop(k)
            if k.startswith('superpoint.'):
                state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
        detector.load_state_dict(state_dict)

        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('superpoint.'):
                state_dict.pop(k)
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        model.load_state_dict(state_dict)

    if detector is not None:
        detector = detector.eval().to(device)
    model = model.eval().to(device)

    name0 = 'a1'
    name1 = 'a2'
    postfix = '.png'
    image_dir = join('assets', 'demo')
    img_path0 = join(image_dir, name0 + postfix)
    img_path1 = join(image_dir, name1 + postfix)

    image0 = read_image(img_path0)
    image1 = read_image(img_path1)
    image0, scale0 = preprocess(image0)
    image1, scale1 = preprocess(image1)

    image0 = image0.to(device)[None]
    image1 = image1.to(device)[None]

    data = dict(color0=image0, color1=image1, image0=image0, image1=image1)
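
    # The same tensors are registered twice: 'image0'/'image1' feed the
    # matcher, while 'color0'/'color1' are kept for the visualizations.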

    if args.model == 'gim_dkm':
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            dense_matches, dense_certainty = model.match(image0, image1)
            sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000)

        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]

        # DKM returns matches in [-1, 1] normalized coordinates; map them
        # back to pixel coordinates.
        kpts0 = sparse_matches[:, :2]
        kpts0 = torch.stack((
            width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1)
        kpts1 = sparse_matches[:, 2:]
        kpts1 = torch.stack((
            width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1)
        b_ids = torch.where(mconf[None])[0]
    elif args.model == 'gim_lightglue':
        gray0 = read_image(img_path0, grayscale=True)
        gray1 = read_image(img_path1, grayscale=True)
        gray0 = preprocess(gray0, grayscale=True)[0]
        gray1 = preprocess(gray1, grayscale=True)[0]

        gray0 = gray0.to(device)[None]
        gray1 = gray1.to(device)[None]
        scale0 = torch.tensor(scale0).to(device)[None]
        scale1 = torch.tensor(scale1).to(device)[None]

        data.update(dict(gray0=gray0, gray1=gray1))

        size0 = torch.tensor(data["gray0"].shape[-2:][::-1])[None]
        size1 = torch.tensor(data["gray1"].shape[-2:][::-1])[None]

        data.update(dict(size0=size0, size1=size1))
        data.update(dict(scale0=scale0, scale1=scale1))

        # Detect keypoints independently in both images, then match them.
        pred = {}
        pred.update({k + '0': v for k, v in detector({
            "image": data["gray0"],
            "image_size": data["size0"],
        }).items()})
        pred.update({k + '1': v for k, v in detector({
            "image": data["gray1"],
            "image_size": data["size1"],
        }).items()})
        pred.update(model({**pred, **data,
                           **{'resize0': data['size0'], 'resize1': data['size1']}}))

        # Scale keypoints back to original-image coordinates and gather the
        # matched pairs per batch element.
        kpts0 = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])])
        kpts1 = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])])
        m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0]
        matches = pred['matches']
        bs = data['image0'].size(0)
        kpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        kpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
        b_ids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mconf = torch.cat(pred['scores'])

    # Epipolar-geometry verification of the putative matches.
    _, mask = cv2.findFundamentalMat(kpts0.cpu().detach().numpy(),
                                     kpts1.cpu().detach().numpy(),
                                     cv2.USAC_MAGSAC, ransacReprojThreshold=1.0,
                                     confidence=0.999999, maxIters=10000)
    mask = mask.ravel() > 0

    data.update({
        'hw0_i': image0.shape[-2:],
        'hw1_i': image1.shape[-2:],
        'mkpts0_f': kpts0,
        'mkpts1_f': kpts1,
        'm_bids': b_ids,
        'mconf': mconf,
        'inliers': mask,
    })
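
    # The inlier mask above feeds the visualizations; compute_geom below
    # re-estimates the geometry with the looser module-level RANSAC defaults.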

    alpha = 0.5
    out = fast_make_matching_figure(data, b_id=0)
    overlay = fast_make_matching_overlay(data, b_id=0)
    out = cv2.addWeighted(out, 1 - alpha, overlay, alpha, 0)
    # The figure is RGB; flip to BGR for cv2.imwrite.
    cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_match.png'), out[..., ::-1])

    geom_info = compute_geom(data)
    wrapped_images = wrap_images(image0, image1, geom_info, "Homography")
    cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_warp.png'), wrapped_images)
|