Spaces:
Running
Running
# -*- coding: UTF-8 -*- | |
'''================================================= | |
@Project -> File pram -> pose_estimation | |
@IDE PyCharm | |
@Author fx221@cam.ac.uk | |
@Date 08/02/2024 11:01 | |
==================================================''' | |
import torch | |
import numpy as np | |
import pycolmap | |
import cv2 | |
import os | |
import time | |
import os.path as osp | |
from collections import defaultdict | |
def get_covisibility_frames(frame_id, all_images, points3D, covisibility_frame=50):
    """Return the database images sharing the most 3D points with ``frame_id``.

    Every valid 3D point observed by ``frame_id`` votes once for each *other*
    image that also observes it; images are ranked by their vote count.

    Args:
        frame_id: key of the reference image in ``all_images``.
        all_images: mapping image id -> object exposing ``point3D_ids``
            (entries equal to -1 have no 3D point and are skipped).
        points3D: mapping 3D point id -> object exposing ``image_ids``.
        covisibility_frame: maximum number of frames to keep. ``None`` keeps
            every connected frame; values <= 0 keep none.

    Returns:
        list of image ids, most covisible first. Images with equal counts keep
        the order in which they were first encountered (deterministic).
    """
    covis = defaultdict(int)
    for pid in all_images[frame_id].point3D_ids:
        if pid == -1:
            continue
        for img_id in points3D[pid].image_ids:
            if img_id != frame_id:
                covis[img_id] += 1
    print('Find {:d} connected frames'.format(len(covis)))

    # sorted() is stable, so ties are broken by first-seen order; one code path
    # covers both the "fewer than k frames" and the "top-k" cases and always
    # returns a list (previously ndarray in one branch, list in the other).
    ranked = sorted(covis, key=lambda img_id: -covis[img_id])
    if covisibility_frame is not None:
        # Clamp so that k <= 0 selects nothing instead of slicing with [-0:].
        ranked = ranked[:max(covisibility_frame, 0)]
    print('Retain {:d} valid connected frames'.format(len(ranked)))
    return ranked
def feature_matching(query_data, db_data, matcher, device='cuda'):
    """Match query features against one database image with ``matcher``.

    Args:
        query_data: dict of numpy arrays 'keypoints', 'scores', 'descriptors'
            and an 'image_size' sequence.
        db_data: same keys as ``query_data`` plus 'db_3D_ids'. When
            'db_3D_ids' is not None, only database keypoints whose id is not -1
            are offered to the matcher.
        matcher: callable taking the input dict built below and returning a
            dict whose 'matches0' is indexed as ``[0]`` to obtain one db index
            per query keypoint (-1 = unmatched) -- inferred from usage here.
        device: torch device for the keypoint/score/descriptor tensors
            (defaults to 'cuda', the previously hard-coded device).

    Returns:
        np.ndarray with one entry per query keypoint: the index of the matched
        keypoint in the *unfiltered* database keypoint array, or -1.
    """
    db_3D_ids = db_data['db_3D_ids']
    db_keypoints = db_data['keypoints']
    db_scores = db_data['scores']
    db_descriptors = db_data['descriptors']
    valid_ids = None
    if db_3D_ids is not None:
        mask = np.asarray(db_3D_ids) != -1
        valid_ids = np.flatnonzero(mask)
        if valid_ids.size == 0:
            # No db keypoint has a 3D point: every query keypoint is unmatched.
            return np.full(query_data['keypoints'].shape[0], -1, dtype=int)
        db_keypoints = db_keypoints[mask]
        db_scores = db_scores[mask]
        db_descriptors = db_descriptors[mask]

    def to_batch(array):
        # numpy -> float tensor with a leading batch dimension on `device`.
        return torch.from_numpy(array)[None].float().to(device)

    with torch.no_grad():
        match_data = {
            'keypoints0': to_batch(query_data['keypoints']),
            'scores0': to_batch(query_data['scores']),
            'descriptors0': to_batch(query_data['descriptors']),
            # Uninitialised placeholders (left on the CPU, as before); presumably
            # the matcher only reads their shape -- confirm for new matchers.
            'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]),
            'keypoints1': to_batch(db_keypoints),
            'scores1': to_batch(db_scores),
            'descriptors1': to_batch(db_descriptors),  # [B, N, D]
            'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]),
        }
        matches = matcher(match_data)['matches0'][0].cpu().numpy()
        del match_data  # drop references to the device tensors early

    if valid_ids is not None:
        # Map indices into the filtered db keypoints back to the full array.
        matched = matches >= 0
        matches[matched] = valid_ids[matches[matched]]
    return matches
def find_2D_3D_matches(query_data, db_id, points3D, feature_file, db_images, matcher, obs_th=0):
    """Match the query to one database image and lift matches to 2D-3D pairs.

    Args:
        query_data: query feature dict as consumed by ``feature_matching``.
        db_id: id of the database image in ``db_images``.
        points3D: mapping 3D point id -> object exposing ``xyz`` and ``image_ids``.
        feature_file: mapping image name -> stored 'keypoints', 'scores',
            'descriptors' (transposed before matching) and 'image_size'.
        db_images: mapping image id -> object exposing ``name`` and ``point3D_ids``.
        matcher: feature matcher forwarded to ``feature_matching``.
        obs_th: 3D points observed by fewer than ``obs_th`` images are dropped.

    Returns:
        ``(mp3d, mkpq, mp3d_ids, q_ids)``: an (M, 3) float array of 3D points,
        the (M, 2) matched query keypoints shifted by +0.5, the list of 3D point
        ids, and the list of matched query keypoint indices.
    """
    kpq = query_data['keypoints']
    db_name = db_images[db_id].name
    db_feats = feature_file[db_name]
    points3D_ids = db_images[db_id].point3D_ids
    matches = feature_matching(query_data=query_data,
                               db_data={
                                   'keypoints': db_feats['keypoints'][()],
                                   'scores': db_feats['scores'][()],
                                   'descriptors': db_feats['descriptors'][()].transpose(),
                                   'db_3D_ids': points3D_ids,
                                   'image_size': db_feats['image_size'][()],
                               },
                               matcher=matcher)

    mp3d, mkpq, mp3d_ids, q_ids = [], [], [], []
    for idx, db_idx in enumerate(matches):
        if db_idx == -1:
            continue  # query keypoint left unmatched
        id_3D = points3D_ids[db_idx]
        if id_3D == -1:
            continue  # matched db keypoint has no 3D point
        if len(points3D[id_3D].image_ids) < obs_th:
            continue  # reject 3D points without enough observations
        mp3d.append(points3D[id_3D].xyz)
        mp3d_ids.append(id_3D)
        mkpq.append(kpq[idx])
        q_ids.append(idx)

    mp3d = np.array(mp3d, float).reshape(-1, 3)
    # +0.5 presumably converts keypoints to COLMAP's image-coordinate convention
    # (as hloc does before PnP) -- confirm against the feature extractor.
    mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5
    return mp3d, mkpq, mp3d_ids, q_ids
# HF-Net style localization (CVPR 2019): the 2D-3D matches from all retrieved
# database images are pooled and a single RANSAC PnP problem is solved.
def pose_estimator_hloc(qname, qinfo, db_ids, db_images, points3D,
                        feature_file,
                        thresh,
                        image_dir,
                        matcher,
                        log_info=None,
                        query_img_prefix='',
                        db_img_prefix=''):
    """Estimate the pose of query ``qname`` from all retrieved images at once.

    Args:
        qname: query image name, used as key into ``feature_file``.
        qinfo: ``(camera_model, width, height, params)`` of the query camera.
        db_ids: retrieved database image ids; ``db_ids[0]`` supplies the
            fallback pose and the reported 'dbname'.
        db_images, points3D: reconstruction images / 3D points keyed by id.
        feature_file: mapping name -> stored 'keypoints', 'scores',
            'descriptors' (transposed here) and 'image_size'.
        thresh: RANSAC max error passed to pycolmap.
        matcher: feature matcher forwarded to ``find_2D_3D_matches``.
        log_info: optional log string; messages are appended when not None.
        image_dir, query_img_prefix, db_img_prefix: unused, kept for API
            compatibility.

    Returns:
        dict with 'qvec', 'tvec', 'log_info', 'qname', 'dbname',
        'num_inliers', 'order' (always -1), the inlier 'keypoints_query' and
        'points3D_ids', and the elapsed 'time'. When no pose is found, the pose
        of ``db_ids[0]`` is returned with ``num_inliers == 0``.
    """
    feats_q = feature_file[qname]
    query_data = {
        'keypoints': feats_q['keypoints'][()],
        'scores': feats_q['scores'][()],
        'descriptors': feats_q['descriptors'][()].transpose(),
        'image_size': feats_q['image_size'][()],
    }
    camera_model, width, height, params = qinfo
    cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params)

    best_db_id = db_ids[0]
    best_db_name = db_images[best_db_id].name
    t_start = time.time()

    def approximate_with_best_db():
        # Report failure and fall back to the top-ranked database image's pose.
        print_text = 'Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, best_db_name)
        print(print_text)
        return {
            'qvec': db_images[best_db_id].qvec,
            'tvec': db_images[best_db_id].tvec,
            'log_info': log_info if log_info is None else log_info + print_text + '\n',
            'qname': qname,
            'dbname': best_db_name,
            'num_inliers': 0,
            'order': -1,
            'keypoints_query': np.array([]),
            'points3D_ids': [],
            'time': time.time() - t_start,
        }

    all_mkpts, all_mp3ds, all_points3D_ids = [], [], []
    for db_id in db_ids:
        mp3d, mkpq, mp3d_ids, _ = find_2D_3D_matches(
            query_data=query_data,
            db_id=db_id,
            points3D=points3D,
            feature_file=feature_file,
            db_images=db_images,
            matcher=matcher,
            obs_th=3)
        if mp3d.shape[0] > 0:
            all_mkpts.append(mkpq)
            all_mp3ds.append(mp3d)
            all_points3D_ids.extend(mp3d_ids)  # extend: no quadratic re-copying
    if not all_mkpts:
        return approximate_with_best_db()

    all_mkpts = np.vstack(all_mkpts)
    all_mp3ds = np.vstack(all_mp3ds)
    ret = pycolmap.absolute_pose_estimation(all_mkpts, all_mp3ds, cam,
                                            estimation_options={
                                                "ransac": {"max_error": thresh}},
                                            refinement_options={},
                                            )
    if ret is None:  # pycolmap returns None when no pose could be estimated
        return approximate_with_best_db()

    print_text = 'qname: {:s} localization success with {:d}/{:d} inliers'.format(qname, ret['num_inliers'],
                                                                                   all_mp3ds.shape[0])
    print(print_text)
    if log_info is not None:
        log_info = log_info + print_text + '\n'
    inliers = ret['inliers']
    return {
        # pycolmap's quaternion is reordered with [3, 0, 1, 2] (scalar part first).
        'qvec': ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]],
        'tvec': ret['cam_from_world'].translation,
        'log_info': log_info,
        'qname': qname,
        'dbname': best_db_name,
        'num_inliers': ret['num_inliers'],
        'order': -1,
        'keypoints_query': np.array([kp for kp, keep in zip(all_mkpts, inliers) if keep]),
        'points3D_ids': [pid for pid, keep in zip(all_points3D_ids, inliers) if keep],
        'time': time.time() - t_start,
    }
def pose_refinement(query_data,
                    query_cam, feature_file, db_frame_id, db_images, points3D, matcher,
                    covisibility_frame=50,
                    obs_th=3,
                    opt_th=12,
                    qvec=None,
                    tvec=None,
                    log_info='',
                    **kwargs,
                    ):
    """Re-estimate the query pose from the covisibility cluster of one db frame.

    The query is matched against the frames most covisible with
    ``db_frame_id`` (``get_covisibility_frames``); all resulting 2D-3D matches
    are pooled into one RANSAC PnP problem solved with pycolmap.

    Args:
        query_data: query feature dict as consumed by ``feature_matching``.
        query_cam: camera passed to ``pycolmap.absolute_pose_estimation``.
        feature_file: mapping image name -> stored features.
        db_frame_id: id of the database frame whose covisible frames are used.
        db_images, points3D: reconstruction images / 3D points keyed by id.
        matcher: feature matcher forwarded to ``feature_matching``.
        covisibility_frame: maximum number of covisible frames to match.
        obs_th: 3D points observed by fewer images than this are ignored.
        opt_th: RANSAC max error.
        qvec, tvec: pose returned unchanged if the estimation fails.
        log_info: log string to append to (skipped when None).
        **kwargs: extra context passed by callers; ignored here.

    Returns:
        On success, pycolmap's result dict extended with 'success' (True),
        'qvec', 'tvec', 'mkpq', '3D_ids', 'db_ids', 'log_info' and the inlier
        'keypoints_query' / 'points3D_ids'. On failure, a dict with 'success'
        False, the given ``qvec``/``tvec``, 'num_inliers' 0, the pooled matches
        and an always-empty 'score_q'.
    """
    db_ids = get_covisibility_frames(frame_id=db_frame_id, all_images=db_images, points3D=points3D,
                                     covisibility_frame=covisibility_frame)
    kpq = query_data['keypoints']
    mp3d, mkpq, all_3D_ids = [], [], []
    for db_id in db_ids:
        db_name = db_images[db_id].name
        points3D_ids = db_images[db_id].point3D_ids
        if points3D_ids.size == 0:
            print("No 3D points in this db image: ", db_name)
            continue  # checked before reading features, which would be unused
        db_feats = feature_file[db_name]
        matches = feature_matching(query_data=query_data,
                                   db_data={'keypoints': db_feats['keypoints'][()],
                                            'scores': db_feats['scores'][()],
                                            'descriptors': db_feats['descriptors'][()].transpose(),
                                            'image_size': db_feats['image_size'][()],
                                            'db_3D_ids': points3D_ids,
                                            },
                                   matcher=matcher,
                                   )
        # Query keypoints matched to a db keypoint that has a 3D point.
        valid = np.where(matches > -1)[0]
        valid = valid[points3D_ids[matches[valid]] != -1]
        for idx in valid:
            id_3D = points3D_ids[matches[idx]]
            if len(points3D[id_3D].image_ids) < obs_th:
                continue  # too few observations to trust this 3D point
            mp3d.append(points3D[id_3D].xyz)
            mkpq.append(kpq[idx])
            all_3D_ids.append(id_3D)

    mp3d = np.array(mp3d, float).reshape(-1, 3)
    # +0.5 presumably converts keypoints to COLMAP's image-coordinate convention
    # (as hloc does before PnP) -- confirm against the feature extractor.
    mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5
    print_text = 'Get {:d} covisible frames with {:d} matches from cluster optimization'.format(len(db_ids),
                                                                                               mp3d.shape[0])
    print(print_text)
    if log_info is not None:
        log_info += (print_text + '\n')

    ret = pycolmap.absolute_pose_estimation(mkpq, mp3d,
                                            query_cam,
                                            estimation_options={
                                                "ransac": {"max_error": opt_th}},
                                            refinement_options={},
                                            )
    if ret is None:  # pycolmap returns None when no pose could be estimated
        return {
            'success': False,
            'mkpq': mkpq,
            '3D_ids': all_3D_ids,
            'db_ids': db_ids,
            'score_q': [],
            'log_info': log_info,
            'qvec': qvec,
            'tvec': tvec,
            'inliers': [False] * mkpq.shape[0],
            'num_inliers': 0,
            'keypoints_query': np.array([]),
            'points3D_ids': [],
        }

    ret['success'] = True
    # pycolmap's quaternion is reordered with [3, 0, 1, 2] (scalar part first).
    ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
    ret['tvec'] = ret['cam_from_world'].translation
    ret_inliers = ret['inliers']
    ret['mkpq'] = mkpq
    ret['3D_ids'] = all_3D_ids
    ret['db_ids'] = db_ids
    ret['log_info'] = log_info
    ret['keypoints_query'] = np.array([kp for kp, keep in zip(mkpq, ret_inliers) if keep])
    ret['points3D_ids'] = [pid for pid, keep in zip(all_3D_ids, ret_inliers) if keep]
    return ret
# Iterative localization proposed in "Efficient Large-scale Localization by
# Global Instance Recognition" (CVPR 2022).
def pose_estimator_iterative(qname, qinfo, db_ids, db_images, points3D, feature_file, thresh, image_dir,
                             matcher,
                             inlier_th=50,
                             log_info=None,
                             do_covisibility_opt=False,
                             covisibility_frame=50,
                             vis_dir=None,
                             obs_th=0,
                             opt_th=12,
                             gt_qvec=None,
                             gt_tvec=None,
                             query_img_prefix='',
                             db_img_prefix='',
                             min_fallback_inliers=10,
                             ):
    """Localize a query by trying the retrieved database images one by one.

    For each ``db_id`` in ``db_ids`` (in order) the query is matched against
    that image's 3D-point observations and a pose is solved with RANSAC PnP.
    The first candidate with at least ``inlier_th`` inliers is accepted,
    optionally refined with ``pose_refinement`` over its covisible frames, and
    returned immediately. Otherwise the candidate with the most inliers is
    returned (again optionally refined) if it has at least
    ``min_fallback_inliers`` inliers. Failing both, the pose of ``db_ids[0]`` is
    returned with ``num_inliers == -1``.

    Args:
        qname: query image name, used as key into ``feature_file``.
        qinfo: ``(camera_model, width, height, params)`` of the query camera.
        db_ids: retrieved database image ids, tried in order.
        db_images, points3D: reconstruction images / 3D points keyed by id.
        feature_file: mapping name -> stored 'keypoints', 'scores',
            'descriptors' (transposed to [N, D] here) and 'image_size'.
        thresh: RANSAC max error for the per-candidate PnP.
        matcher: feature matcher forwarded to the matching helpers.
        inlier_th: inliers needed to accept a candidate immediately.
        log_info: optional log string; messages are appended when not None.
        do_covisibility_opt: refine the chosen pose with ``pose_refinement``.
        covisibility_frame, opt_th: forwarded to ``pose_refinement``.
        obs_th: minimum observations per 3D point when building matches.
        min_fallback_inliers: inliers the best non-accepted candidate needs to
            be returned. Defaults to 10, the previously hard-coded value (an old
            note suggested 20 for Aachen).
        image_dir, vis_dir, gt_qvec, gt_tvec: only forwarded to
            ``pose_refinement`` (which ignores them via ``**kwargs``).
        query_img_prefix, db_img_prefix: unused, kept for API compatibility.

    Returns:
        dict with 'qvec', 'tvec', 'num_inliers', 'dbname', 'order', 'inliers'
        (also under the legacy key 'inlier' once a candidate is found),
        'keypoints_query', 'points3D_ids', 'log_info', plus the entries
        'single_num_inliers', 'db_id', 'optimize' and 'ret_source', which this
        function initialises but never updates.
    """
    print("qname: ", qname)
    feats_q = feature_file[qname]
    query_data = {
        'keypoints': feats_q['keypoints'][()],
        'scores': feats_q['scores'][()],
        'descriptors': feats_q['descriptors'][()].transpose(),  # [N D]
        'image_size': feats_q['image_size'][()],
    }
    camera_model, width, height, params = qinfo
    cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params)
    best_results = {
        'tvec': None,
        'qvec': None,
        'num_inliers': 0,
        'single_num_inliers': 0,
        'db_id': -1,
        'order': -1,
        'qname': qname,
        'optimize': False,
        'dbname': db_images[db_ids[0]].name,
        "ret_source": "",
        "inliers": [],
        'keypoints_query': np.array([]),
        'points3D_ids': [],
    }
    # Id of the candidate currently stored in best_results (replaces a
    # name -> id map that was rebuilt over all db images for every query).
    best_db_id = None

    def log(text):
        # Print a message and append it to log_info when a log is being kept.
        nonlocal log_info
        print(text)
        if log_info is not None:
            log_info += text + '\n'

    def refine(db_frame_id, init_qvec, init_tvec):
        # Covisibility refinement seeded with (init_qvec, init_tvec); merges its log.
        nonlocal log_info
        ret = pose_refinement(query_data=query_data,  # required positional argument
                              query_cam=cam,
                              feature_file=feature_file,
                              db_frame_id=db_frame_id,
                              db_images=db_images,
                              points3D=points3D,
                              matcher=matcher,
                              covisibility_frame=covisibility_frame,
                              obs_th=obs_th,
                              opt_th=opt_th,
                              qvec=init_qvec,
                              tvec=init_tvec,
                              log_info='',
                              # Extra context, absorbed by pose_refinement's **kwargs.
                              qname=qname,
                              thresh=thresh,
                              image_dir=image_dir,
                              vis_dir=vis_dir,
                              gt_qvec=gt_qvec,
                              gt_tvec=gt_tvec,
                              )
        if log_info is not None:  # log_info may be None: guard the concatenation
            log_info += ret['log_info']
        log('Find {:d} inliers after optimization'.format(ret['num_inliers']))
        return ret

    num_db = len(db_ids)
    for cluster_idx, db_id in enumerate(db_ids):
        db_name = db_images[db_id].name
        order = cluster_idx + 1
        mp3d, mkpq, mp3d_ids, _ = find_2D_3D_matches(
            query_data=query_data,
            db_id=db_id,
            points3D=points3D,
            feature_file=feature_file,
            db_images=db_images,
            matcher=matcher,
            obs_th=obs_th)
        if mp3d.shape[0] < 8:
            log("qname: {:s} dbname: {:s}({:d}/{:d}) failed because of insufficient 3d points {:d}".format(
                qname, db_name, order, num_db, mp3d.shape[0]))
            continue

        ret = pycolmap.absolute_pose_estimation(mkpq, mp3d, cam,
                                                estimation_options={
                                                    "ransac": {"max_error": thresh}},
                                                refinement_options={},
                                                )
        if ret is None:  # pycolmap returns None when no pose could be estimated
            log("qname: {:s} dbname: {:s} ({:d}/{:d}) failed after matching".format(
                qname, db_name, order, num_db))
            continue

        # pycolmap's quaternion is reordered with [3, 0, 1, 2] (scalar part first).
        qvec = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
        tvec = ret['cam_from_world'].translation
        inliers = ret['inliers']
        num_inliers = ret['num_inliers']
        loc_keypoints_query = np.array([kp for kp, keep in zip(mkpq, inliers) if keep])
        loc_points3D_ids = [pid for pid, keep in zip(mp3d_ids, inliers) if keep]

        if num_inliers > best_results['num_inliers']:
            best_db_id = db_id
            best_results['qvec'] = qvec
            best_results['tvec'] = tvec
            best_results['inliers'] = inliers
            best_results['inlier'] = inliers  # legacy misspelled key, kept for existing readers
            best_results['num_inliers'] = num_inliers
            best_results['dbname'] = db_name
            best_results['order'] = order
            best_results['keypoints_query'] = loc_keypoints_query
            best_results['points3D_ids'] = loc_points3D_ids

        if num_inliers < inlier_th:
            log("qname: {:s} dbname: {:s} ({:d}/{:d}) failed insufficient {:d} inliers".format(
                qname, db_name, order, num_db, num_inliers))
            continue

        log("qname: {:s} dbname: {:s} ({:d}/{:d}) initialization succeed with {:d} inliers".format(
            qname, db_name, order, num_db, num_inliers))
        if do_covisibility_opt:
            ret = refine(db_id, qvec, tvec)
            qvec, tvec, num_inliers = ret['qvec'], ret['tvec'], ret['num_inliers']
            loc_keypoints_query = ret['keypoints_query']
            loc_points3D_ids = ret['points3D_ids']
        # localization succeed
        best_results['keypoints_query'] = loc_keypoints_query
        best_results['points3D_ids'] = loc_points3D_ids
        best_results['qvec'] = qvec
        best_results['tvec'] = tvec
        best_results['num_inliers'] = num_inliers
        best_results['log_info'] = log_info
        return best_results

    if best_db_id is not None and best_results['num_inliers'] >= min_fallback_inliers:
        # No candidate reached inlier_th: keep the best one. best_results already
        # holds that candidate's inlier keypoints / 3D ids, so they are not
        # replaced with those of whichever candidate happened to be processed last.
        if do_covisibility_opt:
            ret = refine(best_db_id, best_results['qvec'], best_results['tvec'])
            best_results['qvec'] = ret['qvec']
            best_results['tvec'] = ret['tvec']
            best_results['num_inliers'] = ret['num_inliers']
            best_results['keypoints_query'] = ret['keypoints_query']
            best_results['points3D_ids'] = ret['points3D_ids']
        best_results['log_info'] = log_info
        return best_results

    # db_ids holds scalar image ids, so index it once (db_ids[0][0] raised TypeError).
    closest = db_images[db_ids[0]]
    log('Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, closest.name))
    best_results['qvec'] = closest.qvec
    best_results['tvec'] = closest.tvec
    best_results['num_inliers'] = -1
    best_results['log_info'] = log_info
    return best_results