Spaces:
Running
Running
# Copyright 2019-present NAVER Corp. | |
# CC BY-NC-SA 3.0 | |
# Available only for non-commercial use | |
import os, pdb | |
from PIL import Image | |
import numpy as np | |
import torch | |
from tools import common | |
from tools.dataloader import norm_RGB | |
from nets.patchnet import * | |
def load_network(model_fn):
    """Build a network from a checkpoint file and load its weights.

    The checkpoint is a dict holding:
      * ``'net'``: a Python expression string that is ``eval``-ed in this
        module's namespace to construct the model;
      * ``'state_dict'``: the weights to load into that model.

    Args:
        model_fn: path to the checkpoint file.

    Returns:
        The constructed model with its weights loaded, switched to eval mode.
    """
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; callers move the net to the GPU afterwards when needed.
    checkpoint = torch.load(model_fn, map_location='cpu')
    print("\n>> Creating net = " + checkpoint['net'])
    # SECURITY: eval() executes arbitrary code stored in the checkpoint (and
    # torch.load itself unpickles the file) -- only load trusted checkpoints.
    net = eval(checkpoint['net'])
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # initialization
    # Strip the leading 'module.' prefix (presumably added by a DataParallel
    # wrapper at training time). Only the prefix is removed: a blanket
    # str.replace would also corrupt keys like 'submodule.weight'.
    prefix = 'module.'
    weights = checkpoint['state_dict']
    net.load_state_dict({(k[len(prefix):] if k.startswith(prefix) else k): v
                         for k, v in weights.items()})
    return net.eval()
class NonMaxSuppression(torch.nn.Module):
    """Keypoint detector based on non-maximum suppression.

    A pixel is kept when it is a local maximum (3x3 neighbourhood) of the
    repeatability map and both its repeatability and reliability scores
    reach the configured thresholds.

    Args:
        rel_thr: minimum reliability score of a kept pixel.
        rep_thr: minimum repeatability score of a kept pixel.
    """

    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        # super() instead of `nn.Module.__init__(self)`: the bare `nn` name is
        # only reachable through a star import, whereas the base class here is
        # spelled `torch.nn.Module`.
        super().__init__()
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        """Return the ``(y, x)`` coordinates of the detected keypoints.

        Args:
            reliability: list holding exactly one reliability map.
            repeatability: list holding exactly one repeatability map.
            **kw: ignored; lets callers pass the whole network output dict.

        Returns:
            A tensor with two rows, ``[y, x]``. Rows 2 and 3 of the transposed
            ``nonzero()`` result are taken, so the maps are expected to be
            4-D (batch, channel, H, W).
        """
        assert len(reliability) == len(repeatability) == 1
        reliability, repeatability = reliability[0], repeatability[0]

        # local maxima: the pixel equals the max over its 3x3 neighbourhood
        maxima = (repeatability == self.max_filter(repeatability))

        # remove low peaks
        maxima &= (repeatability >= self.rep_thr)
        maxima &= (reliability >= self.rel_thr)

        return maxima.nonzero().t()[2:4]
def extract_multiscale( net, img, detector, scale_f=2**0.25,
                        min_scale=0.0, max_scale=1,
                        min_size=256, max_size=1024,
                        verbose=False):
    """Run ``net`` and ``detector`` over an image pyramid and merge the results.

    Starting at full resolution, the image is repeatedly downscaled by
    ``scale_f``. Keypoints are extracted at every pyramid level whose scale
    factor ``s`` lies (with a 1e-3 tolerance) within
    ``[max(min_scale, min_size / max(H, W)), min(max_scale, max_size / max(H, W))]``.

    Args:
        net: callable taking ``imgs=[img]`` and returning a dict whose
            ``'descriptors'``, ``'reliability'`` and ``'repeatability'``
            entries are lists of maps (only the first entry of each is used).
        img: a batch holding a single RGB image, shape (1, 3, H, W).
        detector: callable receiving the network output dict as keyword
            arguments and returning the ``(y, x)`` keypoint coordinates.
        scale_f: downscaling factor between consecutive pyramid levels.
        min_scale, max_scale: bounds on the scale factor (``max_scale <= 1``).
        min_size, max_size: bounds on the longest side of the scaled image.
        verbose: print the size of each processed level.

    Returns:
        ``(XYS, D, scores)``:
          * XYS: (N, 3) tensor of ``x``, ``y`` in full-resolution pixel
            coordinates and a scale value ``32 / s``;
          * D: (N, C) descriptors;
          * scores: (N,) reliability * repeatability.
        When no pyramid level satisfies the bounds (e.g. an image smaller
        than ``min_size``), empty tensors are returned; D is then (0, 0)
        because the descriptor width is unknown without running the net.
    """
    # Validate before touching global state so a failed check leaves it intact.
    B, three, H, W = img.shape
    assert B == 1 and three == 3, "should be a batch with a single RGB image"
    assert max_scale <= 1

    old_bm = torch.backends.cudnn.benchmark
    # Disabled for speed: presumably because every pyramid level has a new
    # input size, and benchmark mode re-tunes cuDNN for each new shape.
    torch.backends.cudnn.benchmark = False

    s = 1.0 # current scale factor
    X,Y,S,C,Q,D = [],[],[],[],[],[]
    try:
        # extract keypoints at multiple scales
        while s+0.001 >= max(min_scale, min_size / max(H,W)):
            if s-0.001 <= min(max_scale, max_size / max(H,W)):
                nh, nw = img.shape[2:]
                if verbose: print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")

                # extract descriptors
                with torch.no_grad():
                    res = net(imgs=[img])

                # get output and reliability map
                descriptors = res['descriptors'][0]
                reliability = res['reliability'][0]
                repeatability = res['repeatability'][0]

                # extract maxima and descs
                y,x = detector(**res) # nms
                c = reliability[0,0,y,x]
                q = repeatability[0,0,y,x]
                d = descriptors[0,:,y,x].t()
                n = d.shape[0]

                # accumulate multiple scales, mapping coordinates back to
                # the full-resolution image
                X.append(x.float() * W/nw)
                Y.append(y.float() * H/nh)
                S.append((32/s) * torch.ones(n, dtype=torch.float32, device=d.device))
                C.append(c)
                Q.append(q)
                D.append(d)
            s /= scale_f

            # down-scale the image for next iteration
            nh, nw = round(H*s), round(W*s)
            img = torch.nn.functional.interpolate(img, (nh,nw), mode='bilinear', align_corners=False)
    finally:
        # restore value, even when the network or the detector raised
        torch.backends.cudnn.benchmark = old_bm

    if not X:
        # No level was processed: torch.cat would fail on empty lists.
        return (torch.zeros((0, 3), dtype=torch.float32, device=img.device),
                torch.zeros((0, 0), dtype=torch.float32, device=img.device),
                torch.zeros(0, dtype=torch.float32, device=img.device))

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S) # scale
    scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability
    XYS = torch.stack([X,Y,S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores
def extract_keypoints(args):
    """Extract and save keypoints/descriptors for every image in ``args.images``.

    ``args`` is the namespace produced by this script's argument parser.

    Inputs:
        Each entry of ``args.images`` is either an image path or a ``.txt``
        file whose non-blank lines are image paths. List files are expanded
        in place.

    Output:
        For every image, the ``args.top_k`` best-scoring keypoints (all of
        them when ``top_k`` is 0) are written with ``np.savez`` to
        ``<image path>.<tag>``, with the entries ``imsize``, ``keypoints``,
        ``descriptors`` and ``scores``.

    Note: ``args.images`` is consumed (emptied) as images are processed.
    """
    iscuda = common.torch_set_gpu(args.gpu)

    # load the network...
    net = load_network(args.model)
    if iscuda:
        net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr = args.reliability_thr,
        rep_thr = args.repeatability_thr)

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith('.txt'):
            # Expand the list file in front of the remaining inputs. The file
            # is closed right away. Blank lines (e.g. a trailing empty line)
            # are skipped instead of being opened as image paths.
            with open(img_path) as list_file:
                listed = [line for line in list_file.read().splitlines() if line.strip()]
            args.images = listed + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        img = Image.open(img_path).convert('RGB')
        W, H = img.size
        img = norm_RGB(img)[None]
        if iscuda:
            img = img.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(net, img, detector,
            scale_f   = args.scale_f,
            min_scale = args.min_scale,
            max_scale = args.max_scale,
            min_size  = args.min_size,
            max_size  = args.max_size,
            verbose = True)

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        # argsort is ascending, so the best keypoints come last;
        # top_k == 0 gives `-0 or None` -> None, i.e. keep every keypoint.
        idxs = scores.argsort()[-args.top_k or None:]

        outpath = img_path + '.' + args.tag
        print(f"Saving {len(idxs)} keypoints to {outpath}")
        # np.savez leaves a file object it is given open, so close it here.
        with open(outpath, 'wb') as out_file:
            np.savez(out_file,
                imsize = (W,H),
                keypoints = xys[idxs],
                descriptors = desc[idxs],
                scores = scores[idxs])
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the extraction.
    import argparse

    cli = argparse.ArgumentParser("Extract keypoints for a given image")
    option = cli.add_argument

    # model, inputs and outputs
    option("--model", type=str, required=True, help='model path')
    option("--images", type=str, required=True, nargs='+', help='images / list')
    option("--tag", type=str, default='r2d2', help='output file tag')
    option("--top-k", type=int, default=5000, help='number of keypoints')

    # image pyramid (forwarded to extract_multiscale)
    option("--scale-f", type=float, default=2**0.25)
    option("--min-size", type=int, default=256)
    option("--max-size", type=int, default=1024)
    option("--min-scale", type=float, default=0)
    option("--max-scale", type=float, default=1)

    # detector thresholds (forwarded to NonMaxSuppression)
    option("--reliability-thr", type=float, default=0.7)
    option("--repeatability-thr", type=float, default=0.7)

    # device selection
    option("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU')

    extract_keypoints(cli.parse_args())