Realcat
fix: eloftr
63f3cf2
raw
history blame
5.15 kB
import argparse
import torch
from pathlib import Path
import h5py
import logging
from tqdm import tqdm
import pprint
import localization.matchers as matchers
from localization.base_model import dynamic_load
from colmap_utils.parsers import names_to_pair
# Matcher configurations, selectable by key from the command line.
# For each entry, `output` becomes part of the exported match file name and
# `model` is handed as-is to the matcher class resolved from `model['name']`.
confs = {
    "gm": {
        "output": "gm",
        "model": {
            "name": "gm",
            "weight_path": "weights/imp_gm.900.pth",
            "sinkhorn_iterations": 20,
        },
    },
    "gml": {
        "output": "gml",
        "model": {
            "name": "gml",
            "weight_path": "weights/imp_gml.920.pth",
            "sinkhorn_iterations": 20,
        },
    },
    "adagml": {
        "output": "adagml",
        "model": {
            "name": "adagml",
            "weight_path": "weights/imp_adagml.80.pth",
            "sinkhorn_iterations": 20,
        },
    },
    "superglue": {
        "output": "superglue",
        "model": {
            "name": "superglue",
            "weights": "outdoor",
            "sinkhorn_iterations": 20,
            "weight_path": "weights/superglue_outdoor.pth",
        },
    },
    # The only entry without a weight file; it is configured purely by flags.
    "NNM": {
        "output": "NNM",
        "model": {
            "name": "nearest_neighbor",
            "do_mutual_check": True,
            "distance_threshold": None,
        },
    },
}
def _load_pair_list(pairs, feature_file, exhaustive):
    """Return the image pairs to match as a list of ``'name0 name1'`` strings.

    When ``exhaustive`` is false, the pairs are read from the existing text
    file ``pairs`` (one pair per line). Otherwise every pair of images found
    in ``feature_file`` is generated and written to ``pairs``, which must not
    exist yet.
    """
    if not exhaustive:
        assert pairs.exists(), pairs
        with open(pairs, 'r') as f:
            # Drop blank lines: an empty file or a stray empty line would
            # otherwise yield '' entries that cannot be split into two names.
            return [line for line in f.read().split('\n') if line.strip()]

    logging.info(f'Writing exhaustive match pairs to {pairs}.')
    assert not pairs.exists(), pairs

    # Get the images from the feature file: the parent group of every
    # dataset is taken as one image name.
    images = set()

    def collect_image(_name, obj):
        # Implicitly returns None, which lets visititems keep traversing.
        if isinstance(obj, h5py.Dataset):
            images.add(obj.parent.name.strip('/'))

    feature_file.visititems(collect_image)
    # Set iteration order is arbitrary; sort so the pair file is reproducible.
    images = sorted(images)
    pair_list = [' '.join((images[i], images[j]))
                 for i in range(len(images)) for j in range(i)]
    with open(str(pairs), 'w') as f:
        f.write('\n'.join(pair_list))
    return pair_list


def _read_features(feats, suffix):
    """Load all datasets of one image's feature group as numpy arrays.

    Keys are the dataset names with ``suffix`` ('0' or '1') appended.
    """
    data = {}
    for k in feats.keys():
        value = feats[k][()]
        if k == 'descriptors':
            value = value.transpose()  # [N D]
        data[k + suffix] = value
    return data


@torch.no_grad()
def main(conf, pairs, features, export_dir, exhaustive=False):
    """Match local features for a list of image pairs and export the matches.

    Args:
        conf: configuration dict; ``conf['model']`` is passed to the matcher
            selected by ``conf['model']['name']`` and ``conf['output']`` is
            used in the name of the exported file.
        pairs: ``Path`` of the text file listing the pairs, one
            ``'name0 name1'`` per line. With ``exhaustive=True`` this file
            must not exist; it is created with all image pairs.
        features: name (without ``.h5``) of the feature file in ``export_dir``.
        export_dir: directory holding the feature file and receiving the
            match file.
        exhaustive: if true, match every pair of images in the feature file.

    Returns:
        Path of the HDF5 file the matches were written to.

    Raises:
        AssertionError: if the feature file is missing, if the pairs file is
            missing (non-exhaustive), or if it already exists (exhaustive).
    """
    logging.info('Matching local features with configuration:'
                 f'\n{pprint.pformat(conf)}')

    feature_path = Path(export_dir, features + '.h5')
    assert feature_path.exists(), feature_path

    pairs_name = pairs.stem
    match_name = f'{features}-{conf["output"]}-{pairs_name}'
    match_path = Path(export_dir, match_name + '.h5')

    # Context managers guarantee both HDF5 files are closed, even on errors.
    with h5py.File(str(feature_path), 'r') as feature_file:
        pair_list = _load_pair_list(pairs, feature_file, exhaustive)

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Model = dynamic_load(matchers, conf['model']['name'])
        model = Model(conf['model']).eval().to(device)

        with h5py.File(str(match_path), 'a') as match_file:
            matched = set()
            for pair in tqdm(pair_list, smoothing=.1):
                name0, name1 = pair.split(' ')
                pair = names_to_pair(name0, name1)

                # Avoid recomputing duplicates to save time. `matched` always
                # holds both orderings, so checking one of them is enough.
                if (name0, name1) in matched or pair in match_file:
                    continue

                feats0, feats1 = feature_file[name0], feature_file[name1]
                # Each image contributes its own keys (the first image used
                # to be read through the second image's key list).
                data = _read_features(feats0, '0')
                data.update(_read_features(feats1, '1'))
                data = {k: torch.from_numpy(v)[None].float().to(device)
                        for k, v in data.items()}

                # some matchers might expect an image but only use its size
                data['image0'] = torch.empty(
                    (1, 1,) + tuple(feats0['image_size'])[::-1])
                data['image1'] = torch.empty(
                    (1, 1,) + tuple(feats1['image_size'])[::-1])

                pred = model(data)
                grp = match_file.create_group(pair)
                # NOTE(review): matches are stored as int16, so index values
                # above 32767 cannot be represented -- confirm keypoint counts
                # stay below that.
                matches = pred['matches0'][0].cpu().short().numpy()
                grp.create_dataset('matches0', data=matches)

                if 'matching_scores0' in pred:
                    scores = pred['matching_scores0'][0].cpu().half().numpy()
                    grp.create_dataset('matching_scores0', data=scores)

                matched |= {(name0, name1), (name1, name0)}

    logging.info('Finished exporting matches.')
    return match_path
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the matching.
    cli = argparse.ArgumentParser()
    cli.add_argument('--export_dir', type=Path, required=True)
    cli.add_argument('--features', type=str, required=True)
    cli.add_argument('--pairs', type=Path, required=True)
    # Only configurations declared in `confs` are accepted.
    cli.add_argument('--conf', type=str, required=True, choices=list(confs))
    cli.add_argument('--exhaustive', action='store_true')
    opts = cli.parse_args()

    main(
        conf=confs[opts.conf],
        pairs=opts.pairs,
        features=opts.features,
        export_dir=opts.export_dir,
        exhaustive=opts.exhaustive,
    )