# Realcat
# add: GIM (https://github.com/xuelunshen/gim)
# c0283b3
# raw
# history blame
# 16.6 kB
# -*- coding: utf-8 -*-
# @Author : xuelun
import cv2
import torch
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from os.path import join
from dkm.models.model_zoo.DKMv3 import DKMv3
from gluefactory.superpoint import SuperPoint
from gluefactory.models.matchers.lightglue import LightGlue
# Defaults used by compute_geom() for robust model estimation.
# compute_geom() requires at least twice DEFAULT_MIN_NUM_MATCHES correspondences.
DEFAULT_MIN_NUM_MATCHES = 4
DEFAULT_RANSAC_MAX_ITER = 10_000
DEFAULT_RANSAC_CONFIDENCE = 0.999
DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
DEFAULT_RANSAC_METHOD = "USAC_MAGSAC"

# Maps a robust-estimator name to the OpenCV flag of the same name.
RANSAC_ZOO = {
    method_name: getattr(cv2, method_name)
    for method_name in (
        "RANSAC",
        "USAC_FAST",
        "USAC_MAGSAC",
        "USAC_PROSAC",
        "USAC_DEFAULT",
        "USAC_FM_8PTS",
        "USAC_ACCURATE",
        "USAC_PARALLEL",
    )
}
def read_image(path, grayscale=False):
    """Load an image from disk with OpenCV.

    Args:
        path: Location of the image file; passed through ``str`` before reading.
        grayscale: When true the file is decoded with ``cv2.IMREAD_GRAYSCALE``,
            otherwise with ``cv2.IMREAD_COLOR``.

    Returns:
        The decoded array. For colour reads that yield a 3-D array the channel
        axis is reversed (BGR to RGB); the result is a reversed view of the
        decoded buffer, not a copy.

    Raises:
        ValueError: If OpenCV returns ``None`` for ``path``.
    """
    flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    loaded = cv2.imread(str(path), flag)
    if loaded is None:
        raise ValueError(f'Cannot read image {path}.')

    needs_channel_flip = (not grayscale) and loaded.ndim == 3
    return loaded[:, :, ::-1] if needs_channel_flip else loaded
def resize_image(image, size, interp):
    """Resize ``image`` to ``size`` with an OpenCV interpolation mode.

    Args:
        image: Array whose first two axes are (height, width).
        size: Target ``(width, height)``, forwarded to ``cv2.resize``.
        interp: Interpolation name of the form ``'cv2_<mode>'``; ``<mode>`` is
            upper-cased and looked up as ``cv2.INTER_<MODE>``.

    Returns:
        The resized array produced by ``cv2.resize``.

    Raises:
        ValueError: If ``interp`` is not a string starting with ``'cv2_'``.
            (Previously an ``assert`` guarded this, which is stripped under
            ``python -O`` and made the ``ValueError`` branch unreachable
            otherwise.)
        AttributeError: If ``cv2`` has no ``INTER_<MODE>`` constant.
    """
    prefix = 'cv2_'
    if not (isinstance(interp, str) and interp.startswith(prefix)):
        raise ValueError(
            f'Unknown interpolation {interp}.')

    flag = getattr(cv2, 'INTER_' + interp[len(prefix):].upper())
    h, w = image.shape[:2]
    # INTER_AREA is only kept when both sides shrink; if either side would
    # grow, fall back to bilinear interpolation.
    if flag == cv2.INTER_AREA and (w < size[0] or h < size[1]):
        flag = cv2.INTER_LINEAR
    return cv2.resize(image, size, interpolation=flag)
def _render_match_canvas(data, b_id, draw_lines):
    """Build the two-row side-by-side canvas shared by both visualisations.

    The image pair is pasted twice (two rows) onto a white canvas with a
    2-pixel margin; inlier keypoints are drawn on the second row only, and
    optionally connected by lines.

    Args:
        data: Dict read with the keys ``color0``/``color1`` (tensors indexed by
            ``b_id`` and permuted from channel-first to channel-last, then
            multiplied by 255 -- presumably RGB in [0, 1]), ``mkpts0_f``/
            ``mkpts1_f`` (N x 2 tensors of x, y positions), ``inliers`` (index
            or boolean mask into the matches) and ``hw0_i``/``hw1_i``
            ((height, width) of each image).
        b_id: Batch index selecting the image pair.
        draw_lines: Whether to connect each inlier pair with a line.

    Returns:
        ``uint8`` array of shape (H, W, 3).
    """
    def to_uint8(tensor):
        return (tensor.permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)

    color0 = to_uint8(data['color0'][b_id])  # (rH, rW, 3)
    color1 = to_uint8(data['color1'][b_id])  # (rH, rW, 3)
    kpts0 = data['mkpts0_f'].cpu().detach().numpy()
    kpts1 = data['mkpts1_f'].cpu().detach().numpy()
    inliers = data['inliers']

    rows = 2
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h = max(h0, h1)
    H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1

    # White canvas; image 0 starts at x_left, image 1 at x_right.
    out = np.full((H, W, 3), 255, np.uint8)
    x_left = margin
    x_right = margin + w0 + margin
    top = margin
    for row in range(1, rows + 1):
        top = margin * row + h * (row - 1)
        out[top: top + h0, x_left: x_left + w0] = color0
        out[top: top + h1, x_right: x_right + w1] = color1

    # `top` now points at the last row, which receives the annotations.
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    green = (0, 255, 0)
    for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]):
        # Offset by the same origin the images were pasted at. The x offset
        # previously omitted the left margin, shifting every marker `margin`
        # pixels to the left of the true keypoint. Plain ints keep OpenCV's
        # point parsing independent of NumPy scalar types.
        p0 = (int(x0) + x_left, int(y0) + top)
        p1 = (int(x1) + x_right, int(y1) + top)
        if draw_lines:
            cv2.line(out, p0, p1, color=green, thickness=1, lineType=cv2.LINE_AA)
        cv2.circle(out, p0, 3, green, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, p1, 3, green, -1, lineType=cv2.LINE_AA)
    return out


def fast_make_matching_figure(data, b_id):
    """Render the image pair with inlier keypoints marked as green dots.

    Args:
        data: Match dictionary; see ``_render_match_canvas`` for the keys read.
        b_id: Batch index selecting the image pair.

    Returns:
        ``uint8`` (H, W, 3) canvas: first row shows the plain pair, second row
        the pair with inlier keypoints.
    """
    return _render_match_canvas(data, b_id, draw_lines=False)


def fast_make_matching_overlay(data, b_id):
    """Render the image pair with inlier keypoints and connecting lines.

    Same layout as ``fast_make_matching_figure`` so the two canvases can be
    blended; additionally draws a green line between each inlier pair.

    Args:
        data: Match dictionary; see ``_render_match_canvas`` for the keys read.
        b_id: Batch index selecting the image pair.

    Returns:
        ``uint8`` (H, W, 3) canvas.
    """
    return _render_match_canvas(data, b_id, draw_lines=True)
def preprocess(image: np.ndarray, grayscale: bool = False, resize_max: int = None,
               dfactor: int = 8):
    """Convert an image array into a network-ready tensor.

    Steps: cast to float32, optionally shrink so the longer side is at most
    ``resize_max``, move channels first, divide by 255, then resize so both
    spatial sides are multiples of ``dfactor``.

    Args:
        image: (H, W) array when ``grayscale`` is true, otherwise (H, W, C).
        grayscale: Whether ``image`` is single-channel; asserted to be 2-D.
        resize_max: Optional cap on the longer side; only ever shrinks.
        dfactor: Both output sides are floored to a multiple of this value.

    Returns:
        ``(tensor, scale)`` where ``tensor`` is a float channel-first tensor
        and ``scale`` is the original (width, height) divided by the final
        (width, height).
    """
    image = image.astype(np.float32, copy=False)
    orig_wh = image.shape[:2][::-1]

    if resize_max:
        ratio = resize_max / max(orig_wh)
        if ratio < 1.0:
            shrunk_wh = tuple(int(round(side * ratio)) for side in orig_wh)
            image = resize_image(image, shrunk_wh, 'cv2_area')

    if grayscale:
        assert image.ndim == 2, image.shape
        chw = image[None]
    else:
        chw = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    tensor = torch.from_numpy(chw / 255.0).float()

    # Floor each spatial side to the nearest multiple of dfactor.
    target_hw = tuple(int(side // dfactor * dfactor) for side in tensor.shape[-2:])
    tensor = F.resize(tensor, size=target_hw)

    # The returned scale relates the original size to the final tensor size,
    # regardless of any intermediate shrink above.
    scale = np.array(orig_wh) / np.array(target_hw)[::-1]
    return tensor, scale
def compute_geom(data,
                 ransac_method=DEFAULT_RANSAC_METHOD,
                 ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
                 ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
                 ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
                 ) -> dict:
    """Estimate two-view geometry from matched keypoints.

    Args:
        data: Dict providing ``mkpts0_f``/``mkpts1_f`` (tensors of matched
            keypoints, converted to NumPy here) and ``hw0_i`` ((height, width)
            of the first image).
        ransac_method: Key into ``RANSAC_ZOO`` selecting the robust estimator.
        ransac_reproj_threshold: Reprojection threshold passed to OpenCV.
        ransac_confidence: Confidence passed to OpenCV.
        ransac_max_iter: Iteration cap passed to OpenCV.

    Returns:
        Dict with any of the keys ``"Fundamental"``, ``"Homography"``, ``"H1"``
        and ``"H2"`` (nested lists). A key is present only when the
        corresponding estimate succeeded; the dict is empty when fewer than
        ``2 * DEFAULT_MIN_NUM_MATCHES`` matches are available.

    Raises:
        KeyError: If ``ransac_method`` is not in ``RANSAC_ZOO``.
    """
    mkpts0 = data["mkpts0_f"].cpu().detach().numpy()
    mkpts1 = data["mkpts1_f"].cpu().detach().numpy()
    if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
        return {}

    h0, w0 = data["hw0_i"]
    method = RANSAC_ZOO[ransac_method]
    geo_info = {}

    # Named `fundamental` rather than `F` so the module-level `F` alias is not
    # shadowed inside this function.
    fundamental, _ = cv2.findFundamentalMat(
        mkpts0,
        mkpts1,
        method=method,
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if fundamental is not None:
        geo_info["Fundamental"] = fundamental.tolist()

    # Note the argument order: the homography maps image 1 onto image 0.
    homography, _ = cv2.findHomography(
        mkpts1,
        mkpts0,
        method=method,
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if homography is not None:
        geo_info["Homography"] = homography.tolist()

    # Rectification needs a single valid 3x3 fundamental matrix; previously it
    # was attempted even when the fundamental estimate had failed.
    can_rectify = (
        homography is not None
        and fundamental is not None
        and fundamental.shape == (3, 3)
    )
    if can_rectify:
        ok, H1, H2 = cv2.stereoRectifyUncalibrated(
            mkpts0.reshape(-1, 2),
            mkpts1.reshape(-1, 2),
            fundamental,
            imgSize=(w0, h0),
        )
        # Only publish the rectifying homographies when OpenCV reports success.
        if ok and H1 is not None and H2 is not None:
            geo_info["H1"] = H1.tolist()
            geo_info["H2"] = H2.tolist()
    return geo_info
def wrap_images(img0, img1, geo_info, geom_type):
    """Warp an image pair with the estimated geometry and render it.

    Args:
        img0: Batched channel-first tensor; only element 0 is used.
        img1: Batched channel-first tensor; only element 0 is used.
        geo_info: Dict of nested lists. ``"Homography"`` is read when
            ``geom_type == "Homography"``; ``"H1"`` and ``"H2"`` are read when
            ``geom_type == "Fundamental"``.
        geom_type: Either ``"Homography"`` or ``"Fundamental"``.

    Returns:
        The rendered side-by-side figure as an (H, W, 3) ``uint8`` array.

    Raises:
        ValueError: If ``geom_type`` is not one of the two supported values
            (previously a message was printed and the function then failed on
            ``None.squeeze()``).
        KeyError: If ``geo_info`` lacks the entries needed for ``geom_type``.
    """
    if geom_type not in ("Homography", "Fundamental"):
        raise ValueError(f"Unknown geometry type: {geom_type!r}")

    # Channel-last NumPy views with the channel order reversed.
    img0 = img0[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]
    img1 = img1[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]
    h1, w1, _ = img0.shape
    h2, w2, _ = img1.shape

    # Look up only the matrices the requested mode needs, so a missing
    # fundamental matrix does not block a homography warp.
    if geom_type == "Homography":
        H = np.array(geo_info["Homography"])
        rectified_image0 = img0
        rectified_image1 = cv2.warpPerspective(img1, H, (w1, h1))
        title = ["Image 0", "Image 1 - warped"]
    else:
        H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
        rectified_image0 = cv2.warpPerspective(img0, H1, (w1, h1))
        rectified_image1 = cv2.warpPerspective(img1, H2, (w2, h2))
        title = ["Image 0 - warped", "Image 1 - warped"]

    fig = plot_images(
        [rectified_image0.squeeze(), rectified_image1.squeeze()],
        title,
        dpi=300,
    )
    try:
        img = fig2im(fig)
    finally:
        # Always release the figure, even if rasterisation fails.
        plt.close(fig)
    return img
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
    """Show several images next to each other in one Matplotlib figure.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images; a single value is reused for
            every image.
        dpi: resolution handed to ``plt.subplots``.
        size: per-image width in inches; the figure height is ``size * 6 / 5``.
            ``None`` leaves the figure size to Matplotlib.
        pad: padding handed to ``fig.tight_layout``.

    Returns:
        The created figure.
    """
    count = len(imgs)
    cmap_list = cmaps if isinstance(cmaps, (list, tuple)) else [cmaps] * count
    figsize = None if size is None else (size * count, size * 6 / 5)

    fig, axes = plt.subplots(1, count, figsize=figsize, dpi=dpi)
    if count == 1:
        # A single subplot comes back as a bare Axes; wrap it for uniform use.
        axes = [axes]

    for idx, img in enumerate(imgs):
        axis = axes[idx]
        axis.imshow(img, cmap=plt.get_cmap(cmap_list[idx]))
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        axis.set_axis_off()
        for spine in axis.spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            axis.set_title(titles[idx])

    fig.tight_layout(pad=pad)
    return fig
def fig2im(fig):
    """Rasterise a Matplotlib figure into an RGB image array.

    Args:
        fig: Figure whose canvas supports ``draw()`` and either
            ``buffer_rgba()`` or the legacy ``tostring_rgb()``.

    Returns:
        ``uint8`` array of shape (height, width, 3).
    """
    canvas = fig.canvas
    canvas.draw()

    # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later removed;
    # prefer ``buffer_rgba`` and drop the alpha channel.
    if hasattr(canvas, "buffer_rgba"):
        rgba = np.asarray(canvas.buffer_rgba())
        # Copy so the result does not alias the canvas' internal buffer.
        return np.ascontiguousarray(rgba[..., :3])

    # Legacy path for canvases that only expose ``tostring_rgb``.
    w, h = canvas.get_width_height()
    buf_ndarray = np.frombuffer(canvas.tostring_rgb(), dtype="u1")
    return buf_ndarray.reshape(h, w, 3)
if __name__ == '__main__':
    # Demo entry point: match the two demo images with the selected GIM model,
    # then write a match visualisation and a homography-warp visualisation
    # next to the inputs.
    model_zoo = ['gim_dkm', 'gim_lightglue']

    # model
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='gim_dkm', choices=model_zoo)
    args = parser.parse_args()

    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load model
    ckpt = None
    model = None
    detector = None
    if args.model == 'gim_dkm':
        ckpt = 'gim_dkm_100h.ckpt'
        model = DKMv3(weights=None, h=672, w=896)
    elif args.model == 'gim_lightglue':
        ckpt = 'gim_lightglue_100h.ckpt'
        detector = SuperPoint({
            'max_num_keypoints': 2048,
            'force_num_keypoints': True,
            'detection_threshold': 0.0,
            'nms_radius': 3,
            'trainable': False,
        })
        model = LightGlue({
            'filter_threshold': 0.1,
            'flash': False,
            'checkpointed': True,
        })

    # weights path
    checkpoints_path = join('weights', ckpt)

    def strip_prefix(key, prefix):
        """Return ``key`` without a leading ``prefix`` (unchanged otherwise)."""
        return key[len(prefix):] if key.startswith(prefix) else key

    # load state dict (read from disk once, even when two modules need it)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(checkpoints_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    # Each remapping builds a fresh dict instead of popping while iterating,
    # which raised KeyError when a key matched both the rename and the drop
    # rule (e.g. 'model.encoder.net.fc.*').
    if args.model == 'gim_dkm':
        model.load_state_dict({
            strip_prefix(k, 'model.'): v
            for k, v in checkpoint.items()
            if 'encoder.net.fc' not in k
        })
    elif args.model == 'gim_lightglue':
        detector.load_state_dict({
            strip_prefix(k, 'superpoint.'): v
            for k, v in checkpoint.items()
            if not k.startswith('model.')
        })
        model.load_state_dict({
            strip_prefix(k, 'model.'): v
            for k, v in checkpoint.items()
            if not k.startswith('superpoint.')
        })

    # eval mode
    if detector is not None:
        detector = detector.eval().to(device)
    model = model.eval().to(device)

    name0 = 'a1'
    name1 = 'a2'
    postfix = '.png'
    image_dir = join('assets', 'demo')
    img_path0 = join(image_dir, name0 + postfix)
    img_path1 = join(image_dir, name1 + postfix)

    image0 = read_image(img_path0)
    image1 = read_image(img_path1)
    image0, scale0 = preprocess(image0)
    image1, scale1 = preprocess(image1)

    image0 = image0.to(device)[None]
    image1 = image1.to(device)[None]

    data = dict(color0=image0, color1=image1, image0=image0, image1=image1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if args.model == 'gim_dkm':
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                dense_matches, dense_certainty = model.match(image0, image1)
                sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000)

            height0, width0 = image0.shape[-2:]
            height1, width1 = image1.shape[-2:]

            # Map coordinates from [-1, 1] to pixel units of each image.
            kpts0 = sparse_matches[:, :2]
            kpts0 = torch.stack((
                width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1,)
            kpts1 = sparse_matches[:, 2:]
            kpts1 = torch.stack((
                width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1,)
            b_ids = torch.where(mconf[None])[0]

        elif args.model == 'gim_lightglue':
            gray0 = read_image(img_path0, grayscale=True)
            gray1 = read_image(img_path1, grayscale=True)
            gray0 = preprocess(gray0, grayscale=True)[0]
            gray1 = preprocess(gray1, grayscale=True)[0]

            gray0 = gray0.to(device)[None]
            gray1 = gray1.to(device)[None]
            scale0 = torch.tensor(scale0).to(device)[None]
            scale1 = torch.tensor(scale1).to(device)[None]

            data.update(dict(gray0=gray0, gray1=gray1))

            size0 = torch.tensor(data["gray0"].shape[-2:][::-1])[None]
            size1 = torch.tensor(data["gray1"].shape[-2:][::-1])[None]

            data.update(dict(size0=size0, size1=size1))
            data.update(dict(scale0=scale0, scale1=scale1))

            # Detector outputs get a '0' / '1' suffix per image.
            pred = {}
            pred.update({k + '0': v for k, v in detector({
                "image": data["gray0"],
                "image_size": data["size0"],
            }).items()})
            pred.update({k + '1': v for k, v in detector({
                "image": data["gray1"],
                "image_size": data["size1"],
            }).items()})
            pred.update(model({**pred, **data,
                               **{'resize0': data['size0'], 'resize1': data['size1']}}))

            kpts0 = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])])
            kpts1 = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])])
            m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0]
            matches = pred['matches']
            bs = data['image0'].size(0)
            kpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
            kpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
            b_ids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
            mconf = torch.cat(pred['scores'])

    # robust fitting
    _, mask = cv2.findFundamentalMat(kpts0.cpu().detach().numpy(),
                                     kpts1.cpu().detach().numpy(),
                                     cv2.USAC_MAGSAC, ransacReprojThreshold=1.0,
                                     confidence=0.999999, maxIters=10000)
    if mask is None:
        # No model found: treat every match as an outlier instead of crashing.
        mask = np.zeros(len(kpts0), dtype=bool)
    else:
        mask = mask.ravel() > 0

    data.update({
        'hw0_i': image0.shape[-2:],
        'hw1_i': image1.shape[-2:],
        'mkpts0_f': kpts0,
        'mkpts1_f': kpts1,
        'm_bids': b_ids,
        'mconf': mconf,
        'inliers': mask,
    })

    # save visualization
    alpha = 0.5
    out = fast_make_matching_figure(data, b_id=0)
    overlay = fast_make_matching_overlay(data, b_id=0)
    out = cv2.addWeighted(out, 1 - alpha, overlay, alpha, 0)
    cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_match.png'), out[..., ::-1])

    # compute_geom omits the entry for any model it could not estimate; only
    # attempt the warp when the geometry is available.
    geom_info = compute_geom(data)
    if 'Homography' in geom_info and 'Fundamental' in geom_info:
        wrapped_images = wrap_images(image0, image1, geom_info,
                                     "Homography")
        cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_warp.png'), wrapped_images)
    else:
        print('Geometry estimation failed; skipping the warp visualization.')