import argparse
import torch
from pathlib import Path
import h5py
import logging
from tqdm import tqdm
import pprint
import localization.matchers as matchers
from localization.base_model import dynamic_load
from colmap_utils.parsers import names_to_pair
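
# Matcher configurations selectable via --conf: each entry names the matcher
# model, its weights, and the output tag used in the match file name.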
confs = {
'gm': {
'output': 'gm',
'model': {
'name': 'gm',
'weight_path': 'weights/imp_gm.900.pth',
'sinkhorn_iterations': 20,
},
},
'gml': {
'output': 'gml',
'model': {
'name': 'gml',
'weight_path': 'weights/imp_gml.920.pth',
'sinkhorn_iterations': 20,
},
},
'adagml': {
'output': 'adagml',
'model': {
'name': 'adagml',
'weight_path': 'weights/imp_adagml.80.pth',
'sinkhorn_iterations': 20,
},
},
'superglue': {
'output': 'superglue',
'model': {
'name': 'superglue',
'weights': 'outdoor',
'sinkhorn_iterations': 20,
'weight_path': 'weights/superglue_outdoor.pth',
},
},
'NNM': {
'output': 'NNM',
'model': {
'name': 'nearest_neighbor',
'do_mutual_check': True,
'distance_threshold': None,
},
},
}


@torch.no_grad()
def main(conf, pairs, features, export_dir, exhaustive=False):
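    """Match local features for a list of image pairs.

    Features are read from <export_dir>/<features>.h5. Matches (and matching
    scores, when the model provides them) are written to an HDF5 file in
    export_dir whose path is returned. With exhaustive=True, the pair list is
    generated from all images in the feature file and written to `pairs`.
    """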
logging.info('Matching local features with configuration:'
f'\n{pprint.pformat(conf)}')
feature_path = Path(export_dir, features + '.h5')
assert feature_path.exists(), feature_path
feature_file = h5py.File(str(feature_path), 'r')
pairs_name = pairs.stem
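    # Either read the pair list from disk or, in exhaustive mode, enumerate
    # all image pairs found in the feature file and write them to `pairs`.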
if not exhaustive:
assert pairs.exists(), pairs
with open(pairs, 'r') as f:
pair_list = f.read().rstrip('\n').split('\n')
elif exhaustive:
logging.info(f'Writing exhaustive match pairs to {pairs}.')
assert not pairs.exists(), pairs
# get the list of images from the feature file
images = []
feature_file.visititems(
lambda name, obj: images.append(obj.parent.name.strip('/'))
if isinstance(obj, h5py.Dataset) else None)
images = list(set(images))
pair_list = [' '.join((images[i], images[j]))
for i in range(len(images)) for j in range(i)]
with open(str(pairs), 'w') as f:
f.write('\n'.join(pair_list))
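    # Instantiate the matcher named in the configuration, on GPU if available.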
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Model = dynamic_load(matchers, conf['model']['name'])
model = Model(conf['model']).eval().to(device)
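    # Matches go to <features>-<output>-<pairs>.h5, opened in append mode so
    # pairs matched in a previous run are kept and skipped below.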
match_name = f'{features}-{conf["output"]}-{pairs_name}'
match_path = Path(export_dir, match_name + '.h5')
match_file = h5py.File(str(match_path), 'a')
matched = set()
for pair in tqdm(pair_list, smoothing=.1):
name0, name1 = pair.split(' ')
pair = names_to_pair(name0, name1)
        # Avoid recomputing duplicates to save time
if len({(name0, name1), (name1, name0)} & matched) \
or pair in match_file:
continue
data = {}
feats0, feats1 = feature_file[name0], feature_file[name1]
        for k in feats0.keys():
            if k == 'descriptors':
                # transpose descriptors to [N, D]
                data[k + '0'] = feats0[k][()].transpose()
            else:
                data[k + '0'] = feats0[k][()]
        for k in feats1.keys():
            if k == 'descriptors':
                # transpose descriptors to [N, D]
                data[k + '1'] = feats1[k][()].transpose()
            else:
                data[k + '1'] = feats1[k][()]
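        # Add a batch dimension and move all tensors to the model's device.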
data = {k: torch.from_numpy(v)[None].float().to(device)
for k, v in data.items()}
# some matchers might expect an image but only use its size
data['image0'] = torch.empty((1, 1,) + tuple(feats0['image_size'])[::-1])
data['image1'] = torch.empty((1, 1,) + tuple(feats1['image_size'])[::-1])
pred = model(data)
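        # Save matches as int16 and scores as float16 to keep the file compact.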
grp = match_file.create_group(pair)
matches = pred['matches0'][0].cpu().short().numpy()
grp.create_dataset('matches0', data=matches)
if 'matching_scores0' in pred:
scores = pred['matching_scores0'][0].cpu().half().numpy()
grp.create_dataset('matching_scores0', data=scores)
matched |= {(name0, name1), (name1, name0)}
match_file.close()
logging.info('Finished exporting matches.')
    return match_path


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--export_dir', type=Path, required=True)
parser.add_argument('--features', type=str, required=True)
parser.add_argument('--pairs', type=Path, required=True)
parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
parser.add_argument('--exhaustive', action='store_true')
args = parser.parse_args()
main(confs[args.conf], args.pairs, args.features, args.export_dir,
exhaustive=args.exhaustive)
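
# Example invocation (illustrative paths and feature name; the script
# filename is an assumption):
#   python match_features.py --export_dir outputs/ \
#       --features feats-superpoint --pairs outputs/pairs.txt --conf adagml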