# -*- coding: UTF-8 -*- '''================================================= @Project -> File pram -> hloc @IDE PyCharm @Author fx221@cam.ac.uk @Date 07/02/2024 16:45 ==================================================''' import os import os.path as osp from tqdm import tqdm import argparse import time import logging import h5py import numpy as np from pathlib import Path from colmap_utils.read_write_model import read_model from colmap_utils.parsers import parse_image_lists_with_intrinsics # localization from localization.match_features_batch import confs from localization.base_model import dynamic_load from localization import matchers from localization.utils import compute_pose_error, read_gt_pose, read_retrieval_results from localization.pose_estimator import pose_estimator_hloc, pose_estimator_iterative def run(args): if args.gt_pose_fn is not None: gt_poses = read_gt_pose(path=args.gt_pose_fn) else: gt_poses = {} retrievals = read_retrieval_results(args.retrieval) save_root = args.save_root # path to save os.makedirs(save_root, exist_ok=True) matcher_name = args.matcher_method # matching method print('matcher: ', confs[args.matcher_method]['model']['name']) Model = dynamic_load(matchers, confs[args.matcher_method]['model']['name']) matcher = Model(confs[args.matcher_method]['model']).eval().cuda() local_feat_name = args.features.as_posix().split("/")[-1].split(".")[0] # name of local features save_fn = '{:s}_{:s}'.format(local_feat_name, matcher_name) if args.use_hloc: save_fn = 'hloc_' + save_fn save_fn = osp.join(save_root, save_fn) queries = parse_image_lists_with_intrinsics(args.queries) _, db_images, points3D = read_model(str(args.reference_sfm), '.bin') db_name_to_id = {image.name: i for i, image in db_images.items()} feature_file = h5py.File(args.features, 'r') tag = '' if args.do_covisible_opt: tag = tag + "_o" + str(int(args.obs_thresh)) + 'op' + str(int(args.covisibility_frame)) tag = tag + "th" + str(int(args.opt_thresh)) if args.iters > 0: tag = tag + "i" + str(int(args.iters)) log_fn = save_fn + tag vis_dir = save_fn + tag results = save_fn + tag full_log_fn = log_fn + '_full.log' loc_log_fn = log_fn + '_loc.npy' results = Path(results + '.txt') vis_dir = Path(vis_dir) if vis_dir is not None: Path(vis_dir).mkdir(exist_ok=True) print("save_fn: ", log_fn) logging.info('Starting localization...') poses = {} failed_cases = [] n_total = 0 n_failed = 0 full_log_info = '' loc_results = {} error_ths = ((0.25, 2), (0.5, 5), (5, 10)) success = [0, 0, 0] total_loc_time = [] for qname, qinfo in tqdm(queries): kpq = feature_file[qname]['keypoints'].__array__() n_total += 1 time_start = time.time() if qname in retrievals.keys(): cans = retrievals[qname] db_ids = [db_name_to_id[v] for v in cans] else: cans = [] db_ids = [] time_coarse = time.time() if args.use_hloc: output = pose_estimator_hloc(qname=qname, qinfo=qinfo, db_ids=db_ids, db_images=db_images, points3D=points3D, feature_file=feature_file, thresh=args.ransac_thresh, image_dir=args.image_dir, matcher=matcher, log_info='', query_img_prefix='', db_img_prefix='') else: # should be faster and more accurate than hloc t_start = time.time() output = pose_estimator_iterative(qname=qname, qinfo=qinfo, matcher=matcher, db_ids=db_ids, db_images=db_images, points3D=points3D, feature_file=feature_file, thresh=args.ransac_thresh, image_dir=args.image_dir, do_covisibility_opt=args.do_covisible_opt, covisibility_frame=args.covisibility_frame, log_info='', inlier_th=args.inlier_thresh, obs_th=args.obs_thresh, opt_th=args.opt_thresh, gt_qvec=gt_poses[qname]['qvec'] if qname in gt_poses.keys() else None, gt_tvec=gt_poses[qname]['tvec'] if qname in gt_poses.keys() else None, query_img_prefix='', db_img_prefix='database', ) time_full = time.time() qvec = output['qvec'] tvec = output['tvec'] loc_time = time_full - time_start total_loc_time.append(loc_time) poses[qname] = (qvec, tvec) print_text = "All {:d}/{:d} failed cases, time[cs/fn]: {:.2f}/{:.2f}".format( n_failed, n_total, time_coarse - time_start, time_full - time_coarse, ) if qname in gt_poses.keys(): gt_qvec = gt_poses[qname]['qvec'] gt_tvec = gt_poses[qname]['tvec'] q_error, t_error = compute_pose_error(pred_qcw=qvec, pred_tcw=tvec, gt_qcw=gt_qvec, gt_tcw=gt_tvec) for error_idx, th in enumerate(error_ths): if t_error <= th[0] and q_error <= th[1]: success[error_idx] += 1 print_text += ( ', q_error:{:.2f} t_error:{:.2f} {:d}/{:d}/{:d}/{:d}, time: {:.2f}, {:d}pts'.format(q_error, t_error, success[0], success[1], success[2], n_total, loc_time, kpq.shape[0])) if output['num_inliers'] == 0: failed_cases.append(qname) loc_results[qname] = { 'keypoints_query': output['keypoints_query'], 'points3D_ids': output['points3D_ids'], } full_log_info = full_log_info + output['log_info'] full_log_info += (print_text + "\n") print(print_text) logs_path = f'{results}.failed' with open(logs_path, 'w') as f: for v in failed_cases: print(v) f.write(v + "\n") logging.info(f'Localized {len(poses)} / {len(queries)} images.') logging.info(f'Writing poses to {results}...') # logging.info(f'Mean loc time: {np.mean(total_loc_time)}...') print('Mean loc time: {:.2f}...'.format(np.mean(total_loc_time))) with open(results, 'w') as f: for q in poses: qvec, tvec = poses[q] qvec = ' '.join(map(str, qvec)) tvec = ' '.join(map(str, tvec)) name = q f.write(f'{name} {qvec} {tvec}\n') with open(full_log_fn, 'w') as f: f.write(full_log_info) np.save(loc_log_fn, loc_results) print('Save logs to ', loc_log_fn) logging.info('Done!') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--image_dir', type=str, required=True) parser.add_argument('--dataset', type=str, required=True) parser.add_argument('--reference_sfm', type=Path, required=True) parser.add_argument('--queries', type=Path, required=True) parser.add_argument('--features', type=Path, required=True) parser.add_argument('--ransac_thresh', type=float, default=12) parser.add_argument('--covisibility_frame', type=int, default=50) parser.add_argument('--do_covisible_opt', action='store_true') parser.add_argument('--use_hloc', action='store_true') parser.add_argument('--matcher_method', type=str, default="NNM") parser.add_argument('--inlier_thresh', type=int, default=50) parser.add_argument('--obs_thresh', type=float, default=3) parser.add_argument('--opt_thresh', type=float, default=12) parser.add_argument('--save_root', type=str, required=True) parser.add_argument('--retrieval', type=Path, default=None) parser.add_argument('--gt_pose_fn', type=str, default=None) args = parser.parse_args() os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' run(args=args)