import argparse
import torch
from pathlib import Path
import h5py
import logging
from tqdm import tqdm
import pprint

import localization.matchers as matchers
from localization.base_model import dynamic_load
from colmap_utils.parsers import names_to_pair

confs = {
    'gm': {
        'output': 'gm',
        'model': {
            'name': 'gm',
            'weight_path': 'weights/imp_gm.900.pth',
            'sinkhorn_iterations': 20,
        },
    },
    'gml': {
        'output': 'gml',
        'model': {
            'name': 'gml',
            'weight_path': 'weights/imp_gml.920.pth',
            'sinkhorn_iterations': 20,
        },
    },
    'adagml': {
        'output': 'adagml',
        'model': {
            'name': 'adagml',
            'weight_path': 'weights/imp_adagml.80.pth',
            'sinkhorn_iterations': 20,
        },
    },
    'superglue': {
        'output': 'superglue',
        'model': {
            'name': 'superglue',
            'weights': 'outdoor',
            'sinkhorn_iterations': 20,
            'weight_path': 'weights/superglue_outdoor.pth',
        },
    },
    'NNM': {
        'output': 'NNM',
        'model': {
            'name': 'nearest_neighbor',
            'do_mutual_check': True,
            'distance_threshold': None,
        },
    },
}
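
# Each entry in confs follows the schema consumed by main(): 'output' sets the
# suffix of the exported match file, and 'model' is handed to the matcher class
# loaded via dynamic_load(matchers, conf['model']['name']). A hypothetical extra
# entry (the 'my_matcher' name and weight path below are placeholders, not files
# shipped with this repository) would follow the same layout:
#
#   confs['my_matcher'] = {
#       'output': 'my_matcher',
#       'model': {
#           'name': 'my_matcher',  # must resolve to a module in localization.matchers
#           'weight_path': 'weights/my_matcher.pth',
#           'sinkhorn_iterations': 20,
#       },
#   }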


def main(conf, pairs, features, export_dir, exhaustive=False):
    logging.info('Matching local features with configuration:'
                 f'\n{pprint.pformat(conf)}')

    feature_path = Path(export_dir, features + '.h5')
    assert feature_path.exists(), feature_path
    feature_file = h5py.File(str(feature_path), 'r')
    pairs_name = pairs.stem

    if not exhaustive:
        assert pairs.exists(), pairs
        with open(pairs, 'r') as f:
            pair_list = f.read().rstrip('\n').split('\n')
    elif exhaustive:
        logging.info(f'Writing exhaustive match pairs to {pairs}.')
        assert not pairs.exists(), pairs

        # get the list of images from the feature file
        images = []
        feature_file.visititems(
            lambda name, obj: images.append(obj.parent.name.strip('/'))
            if isinstance(obj, h5py.Dataset) else None)
        images = list(set(images))
        pair_list = [' '.join((images[i], images[j]))
                     for i in range(len(images)) for j in range(i)]
        with open(str(pairs), 'w') as f:
            f.write('\n'.join(pair_list))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Model = dynamic_load(matchers, conf['model']['name'])
    model = Model(conf['model']).eval().to(device)

    match_name = f'{features}-{conf["output"]}-{pairs_name}'
    match_path = Path(export_dir, match_name + '.h5')
    match_file = h5py.File(str(match_path), 'a')

    matched = set()
    for pair in tqdm(pair_list, smoothing=.1):
        name0, name1 = pair.split(' ')
        pair = names_to_pair(name0, name1)

        # Avoid recomputing duplicate pairs to save time
        if len({(name0, name1), (name1, name0)} & matched) \
                or pair in match_file:
            continue

        # Gather keypoints, scores, descriptors, etc. for both images.
        data = {}
        feats0, feats1 = feature_file[name0], feature_file[name1]
        for k in feats0.keys():
            if k == 'descriptors':
                data[k + '0'] = feats0[k][()].transpose()  # [N, D]
            else:
                data[k + '0'] = feats0[k][()]
        for k in feats1.keys():
            if k == 'descriptors':
                data[k + '1'] = feats1[k][()].transpose()  # [N, D]
            else:
                data[k + '1'] = feats1[k][()]
        data = {k: torch.from_numpy(v)[None].float().to(device)
                for k, v in data.items()}

        # some matchers might expect an image but only use its size
        data['image0'] = torch.empty((1, 1,) + tuple(feats0['image_size'])[::-1])
        data['image1'] = torch.empty((1, 1,) + tuple(feats1['image_size'])[::-1])

        pred = model(data)
        grp = match_file.create_group(pair)
        matches = pred['matches0'][0].cpu().short().numpy()
        grp.create_dataset('matches0', data=matches)

        if 'matching_scores0' in pred:
            scores = pred['matching_scores0'][0].cpu().half().numpy()
            grp.create_dataset('matching_scores0', data=scores)

        matched |= {(name0, name1), (name1, name0)}

    match_file.close()
    logging.info('Finished exporting matches.')
    return match_path
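

# A minimal sketch of how the exported match file can be read back, assuming
# the group/dataset layout written by main() above; the helper name
# read_matches is illustrative and not part of the command-line interface.
def read_matches(match_path, name0, name1):
    pair = names_to_pair(name0, name1)
    with h5py.File(str(match_path), 'r') as f:
        grp = f[pair]
        # matches0[i] is the index of the matching keypoint in the second
        # image, or -1 if keypoint i of the first image is unmatched
        matches = grp['matches0'][()]
        scores = grp['matching_scores0'][()] if 'matching_scores0' in grp else None
    return matches, scores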


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--export_dir', type=Path, required=True)
    parser.add_argument('--features', type=str, required=True)
    parser.add_argument('--pairs', type=Path, required=True)
    parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
    parser.add_argument('--exhaustive', action='store_true')
    args = parser.parse_args()

    main(confs[args.conf], args.pairs, args.features, args.export_dir,
         exhaustive=args.exhaustive)
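
# Illustrative invocation (script name and paths are placeholders; --features
# must name an existing <features>.h5 file inside --export_dir, and --pairs is
# a text file with one "name0 name1" pair per line unless --exhaustive is set):
#   python match_features.py --export_dir outputs/ --features feats \
#       --pairs outputs/pairs.txt --conf adagml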