Vincentqyw
update: features and matchers
a80d6bb
raw
history blame
6.22 kB
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import os, pdb
from PIL import Image
import numpy as np
import torch
from tools import common
from tools.dataloader import norm_RGB
from nets.patchnet import *
def load_network(model_fn):
    """Build a network from a checkpoint file and load its trained weights.

    The checkpoint must be a dict with a ``'net'`` entry (a string holding the
    Python expression that constructs the network) and a ``'state_dict'``
    entry holding the weights.

    Args:
        model_fn: checkpoint path (or file-like object) accepted by
            ``torch.load``.

    Returns:
        The constructed network with weights loaded, in evaluation mode.
        It is left on the CPU; moving it to another device is up to the caller.
    """
    # Deserialize onto the CPU so that a checkpoint saved from a GPU can also
    # be loaded on a machine without CUDA. The freshly built network lives on
    # the CPU anyway, so the result is the same as before on GPU machines.
    checkpoint = torch.load(model_fn, map_location='cpu')
    print("\n>> Creating net = " + checkpoint['net'])

    # SECURITY NOTE: eval() runs whatever expression is stored in the
    # checkpoint. Only load checkpoints that come from a trusted source.
    net = eval(checkpoint['net'])
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # initialization
    # Drop the 'module.' prefix from parameter names (presumably added when
    # the model was saved wrapped in DataParallel -- TODO confirm).
    weights = checkpoint['state_dict']
    net.load_state_dict({k.replace('module.', ''): v for k, v in weights.items()})
    return net.eval()
class NonMaxSuppression(torch.nn.Module):
    """Select keypoint locations as thresholded local maxima.

    A location is kept when its repeatability equals the maximum of its 3x3
    neighbourhood and both its repeatability and reliability reach the
    configured thresholds.

    Args:
        rel_thr: minimum reliability value for a location to be kept.
        rep_thr: minimum repeatability value for a location to be kept.
    """

    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        # super() avoids depending on an `nn` name that this file never
        # imports explicitly.
        super().__init__()
        # 3x3 max filter with stride 1 and padding 1: output keeps the
        # spatial size of the input.
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        """Return the coordinates of the detected maxima.

        Args:
            reliability: sequence containing exactly one reliability map.
            repeatability: sequence containing exactly one repeatability map,
                same shape as the reliability map.
            **kw: ignored; lets callers pass a whole network output dict.

        Returns:
            Rows 2 and 3 of the transposed non-zero index matrix, i.e. for
            4-D maps a ``(2, N)`` tensor of (row, column) indices.
        """
        assert len(reliability) == len(repeatability) == 1
        reliability, repeatability = reliability[0], repeatability[0]

        # local maxima
        maxima = repeatability == self.max_filter(repeatability)

        # remove low peaks
        maxima &= repeatability >= self.rep_thr
        maxima &= reliability >= self.rel_thr

        return maxima.nonzero().t()[2:4]
def extract_multiscale(net, img, detector, scale_f=2**0.25,
                       min_scale=0.0, max_scale=1,
                       min_size=256, max_size=1024,
                       verbose=False):
    """Detect keypoints and extract descriptors over an image pyramid.

    Starting at full resolution, the image is repeatedly shrunk by
    ``scale_f``. At every scale inside the allowed range the network is run,
    the detector picks keypoint locations, and the results are accumulated
    with coordinates mapped back to the input resolution.

    Args:
        net: callable invoked as ``net(imgs=[img])`` returning a dict with
            ``'descriptors'``, ``'reliability'`` and ``'repeatability'``
            entries, each a sequence whose first item is the map for ``img``.
        img: tensor of shape ``(1, 3, H, W)``.
        detector: callable invoked as ``detector(**res)`` on the network
            output, returning the ``(y, x)`` index tensors of the keypoints.
        scale_f: factor by which the scale is divided between two levels.
        min_scale: smallest scale factor to process.
        max_scale: largest scale factor to process (must be <= 1).
        min_size: stop once the longest image side would drop below this.
        max_size: skip scales where the longest side exceeds this.
        verbose: print the resolution of every processed scale.

    Returns:
        Tuple ``(XYS, D, scores)``:
        ``XYS`` is an ``(N, 3)`` tensor of x, y (in input-image coordinates)
        and a per-keypoint scale value ``32 / s``; ``D`` holds one descriptor
        row per keypoint; ``scores`` is reliability * repeatability.

    Raises:
        AssertionError: if ``img`` is not a single RGB image or
            ``max_scale > 1``.

    Note:
        If no scale falls inside the allowed range (e.g. the longest side is
        smaller than ``min_size``) nothing is accumulated and the final
        ``torch.cat`` calls receive empty lists.
    """
    old_bm = torch.backends.cudnn.benchmark
    # The input resolution changes at every pyramid level; benchmark mode is
    # switched off for the duration of the extraction (marked "speedup" by
    # the original authors).
    torch.backends.cudnn.benchmark = False

    try:
        # extract keypoints at multiple scales
        B, three, H, W = img.shape
        assert B == 1 and three == 3, "should be a batch with a single RGB image"
        assert max_scale <= 1
        s = 1.0  # current scale factor

        X, Y, S, C, Q, D = [], [], [], [], [], []
        # The +/- 0.001 margins make the range tests tolerant to float
        # rounding of the repeatedly divided scale factor.
        while s + 0.001 >= max(min_scale, min_size / max(H, W)):
            if s - 0.001 <= min(max_scale, max_size / max(H, W)):
                nh, nw = img.shape[2:]
                if verbose: print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
                # extract descriptors
                with torch.no_grad():
                    res = net(imgs=[img])

                # get output and reliability map
                descriptors = res['descriptors'][0]
                reliability = res['reliability'][0]
                repeatability = res['repeatability'][0]

                # extract maxima and descs
                y, x = detector(**res)  # nms
                c = reliability[0, 0, y, x]
                q = repeatability[0, 0, y, x]
                d = descriptors[0, :, y, x].t()
                n = d.shape[0]

                # accumulate multiple scales, mapping coordinates back to
                # the resolution of the input image
                X.append(x.float() * W/nw)
                Y.append(y.float() * H/nh)
                S.append((32/s) * torch.ones(n, dtype=torch.float32, device=d.device))
                C.append(c)
                Q.append(q)
                D.append(d)
            s /= scale_f

            # down-scale the image for next iteration
            nh, nw = round(H*s), round(W*s)
            img = torch.nn.functional.interpolate(
                img, (nh, nw), mode='bilinear', align_corners=False)
    finally:
        # restore value, even when the network or the detector raised
        torch.backends.cudnn.benchmark = old_bm

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S)  # scale
    scores = torch.cat(C) * torch.cat(Q)  # scores = reliability * repeatability
    XYS = torch.stack([X, Y, S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores
def extract_keypoints(args):
    """Extract keypoints for every image in ``args.images`` and save them.

    For each image, the result is written next to it as
    ``<image path>.<args.tag>`` in NumPy ``.npz`` format, with the entries
    ``imsize`` (width, height), ``keypoints``, ``descriptors`` and ``scores``.

    Args:
        args: namespace providing ``gpu``, ``model``, ``images``, ``tag``,
            ``top_k``, ``scale_f``, ``min_scale``, ``max_scale``,
            ``min_size``, ``max_size``, ``reliability_thr`` and
            ``repeatability_thr``. ``args.images`` is consumed in place;
            entries ending in ``.txt`` are read as lists of paths (one per
            line) and expanded at the front of the queue.
    """
    iscuda = common.torch_set_gpu(args.gpu)

    # load the network...
    net = load_network(args.model)
    if iscuda: net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr,
        rep_thr=args.repeatability_thr)

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith('.txt'):
            # A text file is a list of further inputs: splice its lines in
            # front of the remaining queue. The file is closed right away.
            with open(img_path) as list_file:
                listed = list_file.read().splitlines()
            args.images = listed + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        img = Image.open(img_path).convert('RGB')
        W, H = img.size
        img = norm_RGB(img)[None]  # add the batch dimension
        if iscuda: img = img.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(
            net, img, detector,
            scale_f=args.scale_f,
            min_scale=args.min_scale,
            max_scale=args.max_scale,
            min_size=args.min_size,
            max_size=args.max_size,
            verbose=True)

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        # Keep the top_k highest scores (ascending order); a top_k of 0 turns
        # the slice start into None, which keeps every keypoint.
        idxs = scores.argsort()[-args.top_k or None:]

        outpath = img_path + '.' + args.tag
        print(f"Saving {len(idxs)} keypoints to {outpath}")
        # A file object (not a path) is handed to np.savez so that the output
        # name is used verbatim; the context manager guarantees it is closed.
        with open(outpath, 'wb') as out_file:
            np.savez(out_file,
                     imsize=(W, H),
                     keypoints=xys[idxs],
                     descriptors=desc[idxs],
                     scores=scores[idxs])
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the extraction.
    import argparse

    cli = argparse.ArgumentParser("Extract keypoints for a given image")

    # inputs / outputs
    cli.add_argument("--model", type=str, required=True, help='model path')
    cli.add_argument("--images", type=str, required=True, nargs='+', help='images / list')
    cli.add_argument("--tag", type=str, default='r2d2', help='output file tag')
    cli.add_argument("--top-k", type=int, default=5000, help='number of keypoints')

    # multi-scale pyramid and detection thresholds: (flag, type, default)
    tuning_options = (
        ("--scale-f", float, 2**0.25),
        ("--min-size", int, 256),
        ("--max-size", int, 1024),
        ("--min-scale", float, 0),
        ("--max-scale", float, 1),
        ("--reliability-thr", float, 0.7),
        ("--repeatability-thr", float, 0.7),
    )
    for flag, value_type, default_value in tuning_options:
        cli.add_argument(flag, type=value_type, default=default_value)

    # device selection
    cli.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU')

    args = cli.parse_args()
    extract_keypoints(args)