# Realcat | fix: eloftr | 63f3cf2 | raw | history blame | 24 kB  (file-viewer header residue, not code)
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> pose_estimation
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 08/02/2024 11:01
=================================================='''
import torch
import numpy as np
import pycolmap
import cv2
import os
import time
import os.path as osp
from collections import defaultdict
def get_covisibility_frames(frame_id, all_images, points3D, covisibility_frame=50):
    """Rank the frames that share 3D points with ``frame_id``.

    Each valid 3D point (id != -1) observed by ``all_images[frame_id]`` casts
    one vote for every *other* image listed in its ``image_ids``.  Frame ids
    are returned by decreasing vote count, limited to ``covisibility_frame``.

    Args:
        frame_id: key of the reference frame in ``all_images``.
        all_images: mapping id -> image object exposing ``point3D_ids``.
        points3D: mapping id -> point object exposing ``image_ids``.
        covisibility_frame: maximum number of frames to keep.

    Returns:
        A numpy array of frame ids when no truncation is needed, otherwise a
        list holding the ``covisibility_frame`` most covisible frame ids.
    """
    shared_counts = defaultdict(int)
    for point_id in all_images[frame_id].point3D_ids:
        if point_id == -1:
            continue
        for other_id in points3D[point_id].image_ids:
            if other_id != frame_id:
                shared_counts[other_id] += 1

    print('Find {:d} connected frames'.format(len(shared_counts)))

    # Keys and values of a dict iterate in the same (insertion) order.
    frame_ids = np.array(list(shared_counts))
    votes = np.array(list(shared_counts.values()))

    if len(frame_ids) > covisibility_frame:
        # argpartition isolates the k best candidates, which are then sorted.
        top_k = np.argpartition(votes, -covisibility_frame)[-covisibility_frame:]
        top_k = top_k[np.argsort(-votes[top_k])]
        sel_covis_ids = list(frame_ids[top_k])
    else:
        sel_covis_ids = frame_ids[np.argsort(-votes)]

    print('Retain {:d} valid connected frames'.format(len(sel_covis_ids)))
    return sel_covis_ids
def feature_matching(query_data, db_data, matcher):
    """Match query keypoints against database keypoints with ``matcher``.

    When ``db_data['db_3D_ids']`` is given, only database keypoints whose 3D id
    is not -1 are handed to the matcher, and the returned indices are mapped
    back so that they address the *unfiltered* database keypoint arrays.

    Args:
        query_data: dict with numpy arrays ``keypoints``, ``scores``,
            ``descriptors`` and an ``image_size`` sequence.
        db_data: same keys as ``query_data`` plus ``db_3D_ids`` (array of 3D
            point ids per keypoint, or ``None`` to match all keypoints).
        matcher: callable taking the match dict and returning a dict whose
            ``'matches0'`` entry is a batched tensor of indices.

    Returns:
        numpy array with one entry per query keypoint: the index of the
        matched database keypoint, negative values being left untouched.  If
        no database keypoint has a 3D id, an all ``-1`` integer array is
        returned without calling the matcher.
    """
    db_keypoints = db_data['keypoints']
    db_scores = db_data['scores']
    db_descriptors = db_data['descriptors']

    valid_ids = None
    db_3D_ids = db_data['db_3D_ids']
    if db_3D_ids is not None:
        masks = np.asarray(db_3D_ids) != -1
        valid_ids = np.flatnonzero(masks)
        if valid_ids.size == 0:
            return np.full((query_data['keypoints'].shape[0],), -1, dtype=int)
        db_keypoints = db_keypoints[masks]
        db_scores = db_scores[masks]
        db_descriptors = db_descriptors[masks]

    # Use the GPU when there is one (previous behaviour), otherwise stay on CPU
    # instead of failing in ``.cuda()``.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def to_batch(array):
        # numpy array -> float tensor with a leading batch axis on ``device``
        return torch.from_numpy(array)[None].float().to(device)

    with torch.no_grad():
        match_data = {
            'keypoints0': to_batch(query_data['keypoints']),
            'scores0': to_batch(query_data['scores']),
            'descriptors0': to_batch(query_data['descriptors']),
            # uninitialised placeholder; only its shape (reversed image_size) is set here
            'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]),
            'keypoints1': to_batch(db_keypoints),
            'scores1': to_batch(db_scores),
            'descriptors1': to_batch(db_descriptors),  # [B, N, D]
            'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]),
        }
        matches = matcher(match_data)['matches0'][0].cpu().numpy()
        del match_data

    if valid_ids is not None:
        # Translate indices into the filtered keypoints back to original indices.
        # Work on a copy: on CPU ``.numpy()`` shares memory with the matcher output.
        matched = matches >= 0
        remapped = matches.copy()
        remapped[matched] = valid_ids[matches[matched]]
        matches = remapped
    return matches
def find_2D_3D_matches(query_data, db_id, points3D, feature_file, db_images, matcher, obs_th=0):
    """Collect 2D-3D correspondences between the query and one database image.

    Query keypoints are matched against the keypoints of ``db_images[db_id]``;
    a match is kept when the database keypoint has a 3D point (id != -1) that
    is observed in at least ``obs_th`` images.

    Args:
        query_data: dict with ``keypoints``, ``scores``, ``descriptors`` and
            ``image_size`` of the query image.
        db_id: key of the database image in ``db_images``.
        points3D: mapping id -> point exposing ``xyz`` and ``image_ids``.
        feature_file: mapping image name -> group with ``keypoints``,
            ``scores``, ``descriptors`` and ``image_size`` datasets.
        db_images: mapping id -> image exposing ``name`` and ``point3D_ids``.
        matcher: matcher forwarded to ``feature_matching``.
        obs_th: minimum number of observations required for a 3D point.

    Returns:
        Tuple ``(mp3d, mkpq, mp3d_ids, q_ids)``: an ``(M, 3)`` float array of
        3D points, an ``(M, 2)`` float array of query keypoints shifted by
        +0.5, the list of 3D point ids and the list of query keypoint indices.
    """
    kpq = query_data['keypoints']
    db_image = db_images[db_id]
    db_feats = feature_file[db_image.name]
    points3D_ids = db_image.point3D_ids

    matches = feature_matching(query_data=query_data,
                               db_data={
                                   'keypoints': db_feats['keypoints'][()],
                                   'scores': db_feats['scores'][()],
                                   # stored descriptors are transposed before matching
                                   'descriptors': db_feats['descriptors'][()].transpose(),
                                   'db_3D_ids': points3D_ids,
                                   'image_size': db_feats['image_size'][()],
                               },
                               matcher=matcher)

    mp3d = []
    mkpq = []
    mp3d_ids = []
    q_ids = []
    for q_idx, db_idx in enumerate(matches):
        # Any negative index means "unmatched"; never let it wrap around.
        if db_idx < 0:
            continue
        id_3D = points3D_ids[db_idx]
        if id_3D == -1:
            continue
        # reject 3d points without enough observations
        if len(points3D[id_3D].image_ids) < obs_th:
            continue
        mp3d.append(points3D[id_3D].xyz)
        mp3d_ids.append(id_3D)
        mkpq.append(kpq[q_idx])
        q_ids.append(q_idx)

    mp3d = np.array(mp3d, float).reshape(-1, 3)
    # +0.5 offset kept from the original code (presumably a pixel-centre
    # convention expected by the pose solver - confirm).
    mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5
    return mp3d, mkpq, mp3d_ids, q_ids
# HF-Net (CVPR 2019) style pose estimation
def pose_estimator_hloc(qname, qinfo, db_ids, db_images, points3D,
                        feature_file,
                        thresh,
                        image_dir,
                        matcher,
                        log_info=None,
                        query_img_prefix='',
                        db_img_prefix=''):
    """Estimate the query pose from the pooled matches of all retrieved frames.

    2D-3D matches against every image in ``db_ids`` (3D points with fewer than
    3 observations are rejected) are stacked and passed to a single
    ``pycolmap.absolute_pose_estimation`` call.  If no match is found or the
    estimation fails, the pose of ``db_ids[0]`` is returned as approximation.

    Args:
        qname: query image name, used as key into ``feature_file``.
        qinfo: tuple ``(camera_model, width, height, params)``.
        db_ids: ids of the retrieved database images; must not be empty.
        db_images: mapping id -> image exposing ``name``, ``point3D_ids``,
            ``qvec`` and ``tvec``.
        points3D: mapping id -> point exposing ``xyz`` and ``image_ids``.
        feature_file: mapping image name -> local feature datasets.
        thresh: RANSAC ``max_error`` passed to the solver.
        image_dir, query_img_prefix, db_img_prefix: unused here.
        matcher: matcher forwarded to ``find_2D_3D_matches``.
        log_info: optional log string; messages are appended when not None.

    Returns:
        dict with keys ``qvec``, ``tvec``, ``log_info``, ``qname``, ``dbname``,
        ``num_inliers``, ``order``, ``keypoints_query``, ``points3D_ids`` and
        ``time`` (seconds elapsed since matching started).
    """
    query_feats = feature_file[qname]
    query_data = {
        'keypoints': query_feats['keypoints'][()],
        'scores': query_feats['scores'][()],
        'descriptors': query_feats['descriptors'][()].transpose(),
        'image_size': query_feats['image_size'][()],
    }

    camera_model, width, height, params = qinfo
    cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params)

    best_db_id = db_ids[0]
    best_db_name = db_images[best_db_id].name
    t_start = time.time()

    def record(text):
        # Print a message and mirror it into log_info when logging is enabled.
        nonlocal log_info
        print(text)
        if log_info is not None:
            log_info = log_info + text + '\n'

    def fallback_result():
        # Shared failure path: fall back to the pose of the first retrieved frame.
        record('Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, best_db_name))
        return {
            'qvec': db_images[best_db_id].qvec,
            'tvec': db_images[best_db_id].tvec,
            'log_info': log_info,
            'qname': qname,
            'dbname': best_db_name,
            'num_inliers': 0,
            'order': -1,
            'keypoints_query': np.array([]),
            'points3D_ids': [],
            'time': time.time() - t_start,
        }

    all_mkpts = []
    all_mp3ds = []
    all_points3D_ids = []
    for db_id in db_ids:
        mp3d, mkpq, mp3d_ids, _ = find_2D_3D_matches(
            query_data=query_data,
            db_id=db_id,
            points3D=points3D,
            feature_file=feature_file,
            db_images=db_images,
            matcher=matcher,
            obs_th=3)
        if mp3d.shape[0] > 0:
            all_mkpts.append(mkpq)
            all_mp3ds.append(mp3d)
            all_points3D_ids.extend(mp3d_ids)

    if not all_mkpts:
        return fallback_result()

    all_mkpts = np.vstack(all_mkpts)
    all_mp3ds = np.vstack(all_mp3ds)

    ret = pycolmap.absolute_pose_estimation(all_mkpts, all_mp3ds, cam,
                                            estimation_options={
                                                "ransac": {"max_error": thresh}},
                                            refinement_options={},
                                            )
    if ret is None:  # the solver reports failure by returning None
        return fallback_result()

    num_inliers = ret['num_inliers']
    inliers = ret['inliers']
    record('qname: {:s} localization success with {:d}/{:d} inliers'.format(qname, num_inliers,
                                                                            all_mp3ds.shape[0]))
    return {
        # quaternion components reordered with [3, 0, 1, 2]
        # (presumably xyzw -> wxyz - confirm against the pycolmap version in use)
        'qvec': ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]],
        'tvec': ret['cam_from_world'].translation,
        'log_info': log_info,
        'qname': qname,
        'dbname': best_db_name,
        'num_inliers': num_inliers,
        'order': -1,
        'keypoints_query': np.array([all_mkpts[i] for i, ok in enumerate(inliers) if ok]),
        'points3D_ids': [all_points3D_ids[i] for i, ok in enumerate(inliers) if ok],
        'time': time.time() - t_start,
    }
def pose_refinement(query_data,
                    query_cam, feature_file, db_frame_id, db_images, points3D, matcher,
                    covisibility_frame=50,
                    obs_th=3,
                    opt_th=12,
                    qvec=None,
                    tvec=None,
                    log_info='',
                    **kwargs,
                    ):
    """Re-estimate the query pose using the frames covisible with ``db_frame_id``.

    The query is matched against up to ``covisibility_frame`` frames sharing 3D
    points with ``db_frame_id``; all resulting 2D-3D correspondences are fed to
    one ``pycolmap.absolute_pose_estimation`` call.

    Args:
        query_data: dict with ``keypoints``, ``scores``, ``descriptors`` and
            ``image_size`` of the query image.
        query_cam: camera object passed to the solver.
        feature_file: mapping image name -> local feature datasets.
        db_frame_id: id of the reference database frame.
        db_images: mapping id -> image exposing ``name`` and ``point3D_ids``.
        points3D: mapping id -> point exposing ``xyz`` and ``image_ids``.
        matcher: matcher forwarded to ``feature_matching``.
        covisibility_frame: maximum number of covisible frames to use.
        obs_th: minimum number of observations required for a 3D point.
        opt_th: RANSAC ``max_error`` passed to the solver.
        qvec, tvec: initial pose, returned unchanged when estimation fails.
        log_info: log string; messages are appended when not None.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        dict with at least ``success``, ``qvec``, ``tvec``, ``inliers``,
        ``num_inliers``, ``mkpq``, ``3D_ids``, ``db_ids``, ``score_q``,
        ``log_info``, ``keypoints_query`` and ``points3D_ids``.  On failure
        ``num_inliers`` is 0 and ``inliers`` is all False.
    """
    db_ids = get_covisibility_frames(frame_id=db_frame_id, all_images=db_images, points3D=points3D,
                                     covisibility_frame=covisibility_frame)

    mp3d = []
    mkpq = []
    all_3D_ids = []
    all_score_q = []  # never filled here; kept because it is part of the result dict
    kpq = query_data['keypoints']
    for db_id in db_ids:
        db_name = db_images[db_id].name
        points3D_ids = db_images[db_id].point3D_ids
        if points3D_ids.size == 0:
            print("No 3D points in this db image: ", db_name)
            continue

        db_feats = feature_file[db_name]
        matches = feature_matching(query_data=query_data,
                                   db_data={'keypoints': db_feats['keypoints'][()],
                                            'scores': db_feats['scores'][()],
                                            # stored descriptors are transposed before matching
                                            'descriptors': db_feats['descriptors'][()].transpose(),
                                            'image_size': db_feats['image_size'][()],
                                            'db_3D_ids': points3D_ids,
                                            },
                                   matcher=matcher,
                                   )
        # Query indices that are matched to a database keypoint with a 3D point.
        valid = np.where(matches > -1)[0]
        valid = valid[points3D_ids[matches[valid]] != -1]
        for idx in valid:
            id_3D = points3D_ids[matches[idx]]
            if len(points3D[id_3D].image_ids) < obs_th:
                continue
            mp3d.append(points3D[id_3D].xyz)
            mkpq.append(kpq[idx])
            all_3D_ids.append(id_3D)

    mp3d = np.array(mp3d, float).reshape(-1, 3)
    # +0.5 offset kept from the original code (presumably a pixel-centre
    # convention expected by the pose solver - confirm).
    mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5

    print_text = 'Get {:d} covisible frames with {:d} matches from cluster optimization'.format(len(db_ids),
                                                                                                mp3d.shape[0])
    print(print_text)
    if log_info is not None:
        log_info += (print_text + '\n')

    ret = None
    if mp3d.shape[0] > 0:  # nothing to solve without correspondences
        ret = pycolmap.absolute_pose_estimation(mkpq, mp3d,
                                                query_cam,
                                                estimation_options={
                                                    "ransac": {"max_error": opt_th}},
                                                refinement_options={},
                                                )

    if ret is None:
        # Estimation failed: hand the initial pose back with an empty inlier set.
        return {
            'success': False,
            'mkpq': mkpq,
            '3D_ids': all_3D_ids,
            'db_ids': db_ids,
            'score_q': all_score_q,
            'log_info': log_info,
            'qvec': qvec,
            'tvec': tvec,
            'inliers': [False] * mkpq.shape[0],
            'num_inliers': 0,
            'keypoints_query': np.array([]),
            'points3D_ids': [],
        }

    ret['success'] = True
    # quaternion components reordered with [3, 0, 1, 2]
    # (presumably xyzw -> wxyz - confirm against the pycolmap version in use)
    ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
    ret['tvec'] = ret['cam_from_world'].translation

    ret_inliers = ret['inliers']
    ret['mkpq'] = mkpq
    ret['3D_ids'] = all_3D_ids
    ret['db_ids'] = db_ids
    ret['score_q'] = all_score_q  # same key set as the failure result
    ret['log_info'] = log_info
    ret['keypoints_query'] = np.array([mkpq[i] for i, ok in enumerate(ret_inliers) if ok])
    ret['points3D_ids'] = [all_3D_ids[i] for i, ok in enumerate(ret_inliers) if ok]
    return ret
# Proposed in "Efficient Large-Scale Localization by Global Instance Recognition" (CVPR 2022)
def pose_estimator_iterative(qname, qinfo, db_ids, db_images, points3D, feature_file, thresh, image_dir,
                             matcher,
                             inlier_th=50,
                             log_info=None,
                             do_covisibility_opt=False,
                             covisibility_frame=50,
                             vis_dir=None,
                             obs_th=0,
                             opt_th=12,
                             gt_qvec=None,
                             gt_tvec=None,
                             query_img_prefix='',
                             db_img_prefix='',
                             ):
    """Localize the query by trying the retrieved frames one after another.

    For each frame in ``db_ids`` the query is matched and a pose is estimated
    with ``pycolmap.absolute_pose_estimation``.  The first frame yielding at
    least ``inlier_th`` inliers ends the search (optionally followed by
    ``pose_refinement``).  If no frame reaches the threshold, the best attempt
    is used when it has at least 10 inliers; otherwise the pose of
    ``db_ids[0]`` is returned with ``num_inliers`` set to -1.

    Args:
        qname: query image name, used as key into ``feature_file``.
        qinfo: tuple ``(camera_model, width, height, params)``.
        db_ids: ids of the retrieved database images; must not be empty.
        db_images: mapping id -> image exposing ``name``, ``point3D_ids``,
            ``qvec`` and ``tvec``.
        points3D: mapping id -> point exposing ``xyz`` and ``image_ids``.
        feature_file: mapping image name -> local feature datasets.
        thresh: RANSAC ``max_error`` for the per-frame estimation.
        image_dir, vis_dir, gt_qvec, gt_tvec: only forwarded to
            ``pose_refinement``.
        matcher: matcher forwarded to the matching helpers.
        inlier_th: inlier count from which a frame counts as a success.
        log_info: optional log string; messages are appended when not None.
        do_covisibility_opt: run ``pose_refinement`` on the selected frame.
        covisibility_frame, obs_th, opt_th: forwarded to the helpers.
        query_img_prefix, db_img_prefix: unused here.

    Returns:
        The ``best_results`` dict (``qvec``, ``tvec``, ``num_inliers``,
        ``dbname``, ``order``, ``keypoints_query``, ``points3D_ids``,
        ``log_info``, ...).
    """
    print("qname: ", qname)

    query_feats = feature_file[qname]
    query_data = {
        'keypoints': query_feats['keypoints'][()],
        'scores': query_feats['scores'][()],
        'descriptors': query_feats['descriptors'][()].transpose(),  # [N D]
        'image_size': query_feats['image_size'][()],
    }

    camera_model, width, height, params = qinfo
    cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params)

    best_results = {
        'tvec': None,
        'qvec': None,
        'num_inliers': 0,
        'single_num_inliers': 0,
        'db_id': -1,
        'order': -1,
        'qname': qname,
        'optimize': False,
        'dbname': db_images[db_ids[0]].name,
        "ret_source": "",
        "inliers": [],
        'keypoints_query': np.array([]),
        'points3D_ids': [],
    }

    def record(text):
        # Print a message and mirror it into log_info when logging is enabled.
        nonlocal log_info
        print(text)
        if log_info is not None:
            log_info += (text + '\n')

    def refine(db_frame_id, init_qvec, init_tvec):
        # Covisibility-based refinement around db_frame_id, with logging.
        nonlocal log_info
        refined = pose_refinement(query_data=query_data,  # required argument
                                  qname=qname,
                                  query_cam=cam,
                                  feature_file=feature_file,
                                  db_frame_id=db_frame_id,
                                  db_images=db_images,
                                  points3D=points3D,
                                  thresh=thresh,
                                  covisibility_frame=covisibility_frame,
                                  matcher=matcher,
                                  obs_th=obs_th,
                                  opt_th=opt_th,
                                  qvec=init_qvec,
                                  tvec=init_tvec,
                                  log_info='',
                                  image_dir=image_dir,
                                  vis_dir=vis_dir,
                                  gt_qvec=gt_qvec,
                                  gt_tvec=gt_tvec,
                                  )
        if log_info is not None:  # log_info defaults to None: do not concatenate blindly
            log_info += refined['log_info']
        record('Find {:d} inliers after optimization'.format(refined['num_inliers']))
        return refined

    for cluster_idx, db_id in enumerate(db_ids):
        db_name = db_images[db_id].name
        order = cluster_idx + 1
        mp3d, mkpq, mp3d_ids, _ = find_2D_3D_matches(
            query_data=query_data,
            db_id=db_id,
            points3D=points3D,
            feature_file=feature_file,
            db_images=db_images,
            matcher=matcher,
            obs_th=obs_th)

        if mp3d.shape[0] < 8:
            record("qname: {:s} dbname: {:s}({:d}/{:d}) failed because of insufficient 3d points {:d}".format(
                qname,
                db_name,
                order,
                len(db_ids),
                mp3d.shape[0]))
            continue

        ret = pycolmap.absolute_pose_estimation(mkpq, mp3d, cam,
                                                estimation_options={
                                                    "ransac": {"max_error": thresh}},
                                                refinement_options={},
                                                )
        if ret is None:  # the solver reports failure by returning None
            record("qname: {:s} dbname: {:s} ({:d}/{:d}) failed after matching".format(qname, db_name,
                                                                                       order,
                                                                                       len(db_ids)))
            continue

        # quaternion components reordered with [3, 0, 1, 2]
        # (presumably xyzw -> wxyz - confirm against the pycolmap version in use)
        qvec = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
        tvec = ret['cam_from_world'].translation
        inliers = ret['inliers']
        num_inliers = ret['num_inliers']
        loc_keypoints_query = np.array([mkpq[i] for i, ok in enumerate(inliers) if ok])
        loc_points3D_ids = [mp3d_ids[i] for i, ok in enumerate(inliers) if ok]

        if num_inliers > best_results['num_inliers']:
            best_results['qvec'] = qvec
            best_results['tvec'] = tvec
            best_results['inliers'] = inliers
            best_results['inlier'] = inliers  # legacy (misspelled) key kept for existing readers
            best_results['num_inliers'] = num_inliers
            best_results['db_id'] = db_id
            best_results['dbname'] = db_name
            best_results['order'] = order
            best_results['keypoints_query'] = loc_keypoints_query
            best_results['points3D_ids'] = loc_points3D_ids

        if num_inliers < inlier_th:
            record("qname: {:s} dbname: {:s} ({:d}/{:d}) failed insufficient {:d} inliers".format(qname,
                                                                                                  db_name,
                                                                                                  order,
                                                                                                  len(db_ids),
                                                                                                  num_inliers,
                                                                                                  ))
            continue

        record("qname: {:s} dbname: {:s} ({:d}/{:d}) initialization succeed with {:d} inliers".format(
            qname,
            db_name,
            order,
            len(db_ids),
            num_inliers
        ))

        if do_covisibility_opt:
            refined = refine(db_id, qvec, tvec)
            qvec = refined['qvec']
            tvec = refined['tvec']
            num_inliers = refined['num_inliers']
            loc_keypoints_query = refined['keypoints_query']
            loc_points3D_ids = refined['points3D_ids']

        # localization succeed
        best_results['keypoints_query'] = loc_keypoints_query
        best_results['points3D_ids'] = loc_points3D_ids
        best_results['qvec'] = qvec
        best_results['tvec'] = tvec
        best_results['num_inliers'] = num_inliers
        best_results['log_info'] = log_info
        return best_results

    if best_results['num_inliers'] >= 10:  # 20 for aachen
        # No frame reached inlier_th: use the best attempt, which best_results
        # already holds (pose, inlier keypoints and 3D ids of that frame).
        if do_covisibility_opt:
            refined = refine(best_results['db_id'], best_results['qvec'], best_results['tvec'])
            best_results['keypoints_query'] = refined['keypoints_query']
            best_results['points3D_ids'] = refined['points3D_ids']
            best_results['qvec'] = refined['qvec']
            best_results['tvec'] = refined['tvec']
            best_results['num_inliers'] = refined['num_inliers']
        best_results['log_info'] = log_info
        return best_results

    # Nothing usable: approximate with the pose of the first retrieved frame.
    closest = db_images[db_ids[0]]
    record('Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, closest.name))
    best_results['qvec'] = closest.qvec
    best_results['tvec'] = closest.tvec
    best_results['num_inliers'] = -1
    best_results['log_info'] = log_info
    return best_results