Spaces:
Running
Running
# Copyright 2019-present NAVER Corp. | |
# CC BY-NC-SA 3.0 | |
# Available only for non-commercial use | |
from PIL import Image | |
from tools import common | |
from tools.dataloader import norm_RGB | |
from nets.patchnet import * | |
from os import path | |
from extract import load_network, NonMaxSuppression, extract_multiscale | |
# Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure from Motion)
# and, more generally, sensor-acquired data.
# It can be installed with:
#   pip install kapture
# For more information, see https://github.com/naver/kapture
import kapture | |
from kapture.io.records import get_image_fullpath | |
from kapture.io.csv import kapture_from_dir | |
from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file | |
from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file | |
from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file | |
from kapture.io.csv import get_all_tar_handlers | |
def _ensure_feature_types(kdata, args, xys, desc):
    """
    Register the requested keypoints/descriptors types in kdata if missing, then
    check that both registered types agree with the arrays of the current image.

    A type that is not registered yet is created from the given arrays (dtype and
    per-row size); its description file is written to the path returned by
    get_feature_csv_fullpath. A type that is already registered (loaded from disk,
    or created for an earlier image) is kept as-is and must be compatible.

    kdata: kapture data whose .keypoints and .descriptors are dicts (not None).
    args: parsed arguments; reads keypoints_type, descriptors_type and kapture_root.
    xys: keypoints array of one image, one row per keypoint.
    desc: descriptors array of one image, one row per descriptor.
    """
    if args.keypoints_type not in kdata.keypoints:
        kdata.keypoints[args.keypoints_type] = kapture.Keypoints('r2d2', xys.dtype, xys.shape[1])
        keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints,
                                                                  args.keypoints_type,
                                                                  args.kapture_root)
        keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type])
    if args.descriptors_type not in kdata.descriptors:
        # descriptors are linked to the keypoints type and compared with the L2 metric
        kdata.descriptors[args.descriptors_type] = kapture.Descriptors('r2d2', desc.dtype,
                                                                       desc.shape[1],
                                                                       args.keypoints_type, 'L2')
        descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors,
                                                                    args.descriptors_type,
                                                                    args.kapture_root)
        descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type])

    # Validate both types, including ones that pre-existed on disk: writing r2d2
    # features under an incompatible registered type would corrupt the dataset.
    keypoints_info = kdata.keypoints[args.keypoints_type]
    descriptors_info = kdata.descriptors[args.descriptors_type]
    assert keypoints_info.dtype == xys.dtype
    assert descriptors_info.dtype == desc.dtype
    assert keypoints_info.dsize == xys.shape[1]
    assert descriptors_info.dsize == desc.shape[1]
    assert descriptors_info.keypoints_type == args.keypoints_type
    assert descriptors_info.metric_type == 'L2'


def extract_kapture_keypoints(args):
    """
    Extract r2d2 keypoints and descriptors directly into the kapture format.

    The images are the camera records of the kapture dataset at args.kapture_root.
    For every image that does not already have both keypoints and descriptors of
    the requested types, the network loaded from args.model is run through
    extract_multiscale. The args.top_k highest-scoring detections are kept and
    written through the tar handlers opened by get_all_tar_handlers (append mode
    for keypoints and descriptors).

    args: parsed command-line namespace. Side effect: args.keypoints_type and
        args.descriptors_type are set to the model filename (without directory
        and extension) when they are None.
    """
    print('extract_kapture_keypoints...')
    with get_all_tar_handlers(args.kapture_root,
                              mode={kapture.Keypoints: 'a',
                                    kapture.Descriptors: 'a',
                                    kapture.GlobalFeatures: 'r',
                                    kapture.Matches: 'r'}) as tar_handlers:
        kdata = kapture_from_dir(args.kapture_root, None,
                                 skip_list=[kapture.GlobalFeatures,
                                            kapture.Matches,
                                            kapture.Points3d,
                                            kapture.Observations],
                                 tar_handlers=tar_handlers)

        assert kdata.records_camera is not None
        image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)]

        # default feature type names: model filename without directory/extension
        model_name = path.splitext(path.basename(args.model))[0]
        if args.keypoints_type is None:
            args.keypoints_type = model_name
            print(f'keypoints_type set to {args.keypoints_type}')
        if args.descriptors_type is None:
            args.descriptors_type = model_name
            print(f'descriptors_type set to {args.descriptors_type}')

        if kdata.keypoints is not None and args.keypoints_type in kdata.keypoints \
                and kdata.descriptors is not None and args.descriptors_type in kdata.descriptors:
            print('detected already computed features of same keypoints_type/descriptors_type, resuming extraction...')
            # keep only images missing their keypoints or their descriptors
            image_list = [name
                          for name in image_list
                          if name not in kdata.keypoints[args.keypoints_type] or
                          name not in kdata.descriptors[args.descriptors_type]]

        if not image_list:
            print('All features were already extracted')
            return
        print(f'Extracting r2d2 features for {len(image_list)} images')

        iscuda = common.torch_set_gpu(args.gpu)

        # load the network...
        net = load_network(args.model)
        if iscuda:
            net = net.cuda()

        # create the non-maxima detector
        detector = NonMaxSuppression(
            rel_thr=args.reliability_thr,
            rep_thr=args.repeatability_thr)

        if kdata.keypoints is None:
            kdata.keypoints = {}
        if kdata.descriptors is None:
            kdata.descriptors = {}

        for image_name in image_list:
            img_path = get_image_fullpath(args.kapture_root, image_name)
            print(f"\nExtracting features for {img_path}")
            # convert() returns an in-memory copy, so the file can be closed right away
            with Image.open(img_path) as pil_image:
                rgb_image = pil_image.convert('RGB')
            img = norm_RGB(rgb_image)[None]
            if iscuda:
                img = img.cuda()

            # extract keypoints/descriptors for a single image
            xys, desc, scores = extract_multiscale(net, img, detector,
                                                   scale_f=args.scale_f,
                                                   min_scale=args.min_scale,
                                                   max_scale=args.max_scale,
                                                   min_size=args.min_size,
                                                   max_size=args.max_size,
                                                   verbose=True)
            xys = xys.cpu().numpy()
            desc = desc.cpu().numpy()
            scores = scores.cpu().numpy()
            # argsort is ascending: the last top_k indices are the highest scores;
            # top_k == 0 gives the slice [None:], i.e. keep every detection
            idxs = scores.argsort()[-args.top_k or None:]
            xys = xys[idxs]
            desc = desc[idxs]

            _ensure_feature_types(kdata, args, xys, desc)

            keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root,
                                                        image_name, tar_handlers)
            print(f"Saving {xys.shape[0]} keypoints to {keypoints_fullpath}")
            image_keypoints_to_file(keypoints_fullpath, xys)
            kdata.keypoints[args.keypoints_type].add(image_name)

            descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root,
                                                            image_name, tar_handlers)
            print(f"Saving {desc.shape[0]} descriptors to {descriptors_fullpath}")
            image_descriptors_to_file(descriptors_fullpath, desc)
            kdata.descriptors[args.descriptors_type].add(image_name)

        if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type,
                                   args.kapture_root, tar_handlers) or \
                not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type,
                                          args.kapture_root, tar_handlers):
            print('local feature extraction ended successfully but not all files were saved')
if __name__ == '__main__': | |
import argparse | |
parser = argparse.ArgumentParser( | |
"Extract r2d2 local features for all images in a dataset stored in the kapture format") | |
parser.add_argument("--model", type=str, required=True, help='model path') | |
parser.add_argument('--keypoints-type', default=None, help='keypoint type_name, default is filename of model') | |
parser.add_argument('--descriptors-type', default=None, help='descriptors type_name, default is filename of model') | |
parser.add_argument("--kapture-root", type=str, required=True, help='path to kapture root directory') | |
parser.add_argument("--top-k", type=int, default=5000, help='number of keypoints') | |
parser.add_argument("--scale-f", type=float, default=2**0.25) | |
parser.add_argument("--min-size", type=int, default=256) | |
parser.add_argument("--max-size", type=int, default=1024) | |
parser.add_argument("--min-scale", type=float, default=0) | |
parser.add_argument("--max-scale", type=float, default=1) | |
parser.add_argument("--reliability-thr", type=float, default=0.7) | |
parser.add_argument("--repeatability-thr", type=float, default=0.7) | |
parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU') | |
args = parser.parse_args() | |
extract_kapture_keypoints(args) | |