# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> map3d
@IDE     PyCharm
@Author  fx221@cam.ac.uk
@Date    04/03/2024 10:25
=================================================='''
import numpy as np
from collections import defaultdict
import os.path as osp
import pycolmap
import logging
import time
import torch

from localization.refframe import RefFrame
from localization.frame import Frame
from localization.point3d import Point3D
from colmap_utils.read_write_model import qvec2rotmat, read_model, read_compressed_model
from localization.utils import read_gt_pose


class SingleMap3D:
    def __init__(self, config, matcher, with_compress=False, start_sid: int = 0):
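        """Build a single-scene 3D map for localization.

        Loads the (optionally compressed) COLMAP model, per-point descriptors,
        3D point cluster (segment) labels and virtual reference frames (VRFs),
        then constructs Point3D / RefFrame objects and the covisibility graph.
        """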
        self.config = config
        self.matcher = matcher
        self.image_path_prefix = self.config['image_path_prefix']
        self.start_sid = start_sid  # for a dataset with multiple scenes

        if not with_compress:
            cameras, images, p3ds = read_model(
                path=osp.join(config['landmark_path'], 'model'), ext='.bin')
            p3d_descs = np.load(osp.join(config['landmark_path'], 'point3D_desc.npy'),
                                allow_pickle=True)[()]
        else:
            cameras, images, p3ds = read_compressed_model(
                path=osp.join(config['landmark_path'], 'compress_model_{:s}'.format(config['cluster_method'])),
                ext='.bin')
            p3d_descs = np.load(osp.join(config['landmark_path'], 'compress_model_{:s}/point3D_desc.npy'.format(
                config['cluster_method'])), allow_pickle=True)[()]

        print('Load {} cameras {} images {} 3D points'.format(len(cameras), len(images), len(p3d_descs)))

        seg_data = np.load(
            osp.join(config['landmark_path'], 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'],
                                                                                           config['cluster_mode'],
                                                                                           config['cluster_method'])),
            allow_pickle=True)[()]
        p3d_id = seg_data['id']
        seg_id = seg_data['label']
        p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
        seg_p3d = {}
        for k in p3d_seg.keys():
            sid = p3d_seg[k]
            if sid in seg_p3d.keys():
                seg_p3d[sid].append(k)
            else:
                seg_p3d[sid] = [k]

        print('Load {} segments and {} 3d points'.format(len(seg_p3d.keys()), len(p3d_seg.keys())))

        seg_vrf = np.load(
            osp.join(config['landmark_path'], 'point3D_vrf_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'],
                                                                                       config['cluster_mode'],
                                                                                       config['cluster_method'])),
            allow_pickle=True)[()]

        # construct 3D map
        self.initialize_point3Ds(p3ds=p3ds, p3d_descs=p3d_descs, p3d_seg=p3d_seg)
        self.initialize_ref_frames(cameras=cameras, images=images)

        all_vrf_frame_ids = []
        self.seg_ref_frame_ids = {}
        for sid in seg_vrf.keys():
            self.seg_ref_frame_ids[sid] = []
            for vi in seg_vrf[sid].keys():
                vrf_frame_id = seg_vrf[sid][vi]['image_id']
                self.seg_ref_frame_ids[sid].append(vrf_frame_id)
                if with_compress and vrf_frame_id in self.reference_frames.keys():
                    self.reference_frames[vrf_frame_id].point3D_ids = seg_vrf[sid][vi]['original_points3d']
            all_vrf_frame_ids.extend(self.seg_ref_frame_ids[sid])

        if with_compress:
            all_ref_ids = list(self.reference_frames.keys())
            for fid in all_ref_ids:
                valid = self.reference_frames[fid].associate_keypoints_with_point3Ds(point3Ds=self.point3Ds)
                if not valid:
                    del self.reference_frames[fid]

        all_vrf_frame_ids = np.unique(all_vrf_frame_ids)
        all_vrf_frame_ids = [v for v in all_vrf_frame_ids if v in self.reference_frames.keys()]
        self.build_covisibility_graph(frame_ids=all_vrf_frame_ids, n_frame=config['localization'][
            'covisibility_frame'])  # build covisible frames for vrf frames only

        logging.info(
            f'Construct {len(self.reference_frames.keys())} ref frames and {len(self.point3Ds.keys())} 3d points')

        self.gt_poses = {}
        if config['gt_pose_path'] is not None:
            gt_pose_path = osp.join(config['dataset_path'], config['gt_pose_path'])
            self.read_gt_pose(path=gt_pose_path)

    def read_gt_pose(self, path, prefix=''):
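        """Load ground-truth query poses from the file at `path`."""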
        self.gt_poses = read_gt_pose(path=path)
        print('Load {} gt poses'.format(len(self.gt_poses.keys())))

    def initialize_point3Ds(self, p3ds, p3d_descs, p3d_seg):
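        """Wrap COLMAP 3D points that carry a cluster (segment) label into Point3D objects."""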
        self.point3Ds = {}
        for pid in p3ds.keys():
            if pid not in p3d_seg.keys():
                continue
            self.point3Ds[pid] = Point3D(id=pid, xyz=p3ds[pid].xyz, error=p3ds[pid].error,
                                         refframe_id=-1, rgb=p3ds[pid].rgb,
                                         descriptor=p3d_descs[pid], seg_id=p3d_seg[pid],
                                         frame_ids=p3ds[pid].image_ids)

    def initialize_ref_frames(self, cameras, images):
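        """Wrap COLMAP images into RefFrame objects keyed by image id."""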
        self.reference_frames = {}
        for img_id in images.keys():
            im = images[img_id]
            cam = cameras[im.camera_id]
            self.reference_frames[img_id] = RefFrame(camera=cam, id=img_id, qvec=im.qvec, tvec=im.tvec,
                                                     point3D_ids=im.point3D_ids,
                                                     keypoints=im.xys, name=im.name)

    def localize_with_ref_frame(self, q_frame: Frame, q_kpt_ids: np.ndarray, sid, semantic_matching=False):
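        """Estimate the pose of `q_frame` against the first virtual reference frame of segment `sid`.

        Matches the selected query keypoints to the reference keypoints (optionally restricted
        to the same segment), then runs PnP + RANSAC via pycolmap.
        """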
        ref_frame_id = self.seg_ref_frame_ids[sid][0]
        ref_frame = self.reference_frames[ref_frame_id]
        if semantic_matching and sid > 0:
            ref_data = ref_frame.get_keypoints_by_sid(sid=sid)
        else:
            ref_data = ref_frame.get_keypoints()

        q_descs = q_frame.descriptors[q_kpt_ids]
        q_kpts = q_frame.keypoints[q_kpt_ids, :2]
        q_scores = q_frame.keypoints[q_kpt_ids, 2]

        xyzs = ref_data['xyzs']
        point3D_ids = ref_data['point3D_ids']
        ref_sids = np.array([self.point3Ds[v].seg_id for v in point3D_ids])

        with torch.no_grad():
            indices0 = self.matcher({
                'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(),
                'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(),
                'scores0': torch.from_numpy(q_scores)[None].cuda().float(),
                'image_shape0': (1, 3, q_frame.camera.width, q_frame.camera.height),

                'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(),
                'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(),
                'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(),
                'image_shape1': (1, 3, ref_frame.camera.width, ref_frame.camera.height),
            })['matches0'][0].cpu().numpy()

        valid = indices0 >= 0
        mkpts = q_kpts[valid]
        mkpt_ids = q_kpt_ids[valid]
        mxyzs = xyzs[indices0[valid]]
        mpoint3D_ids = point3D_ids[indices0[valid]]
        matched_sids = ref_sids[indices0[valid]]
        matched_ref_keypoints = ref_data['keypoints'][indices0[valid]]
        # print('mkpts: ', mkpts.shape, mxyzs.shape, np.sum(indices0 >= 0))

        # cfg = q_frame.camera._asdict()
        # q_cam = pycolmap.Camera(model=q_frame.camera.model, )
        # config = {"estimation": {"ransac": {"max_error": ransac_thresh}}, **(config or {})}
        ret = pycolmap.absolute_pose_estimation(mkpts + 0.5,
                                                 mxyzs,
                                                 q_frame.camera,
                                                 estimation_options={
                                                     "ransac": {"max_error": self.config['localization']['threshold']}},
                                                 refinement_options={},
                                                 # max_error_px=self.config['localization']['threshold']
                                                 )
        if ret is None:
            ret = {'success': False, }
        else:
            ret['success'] = True
            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
            ret['tvec'] = ret['cam_from_world'].translation

        ret['matched_keypoints'] = mkpts
        ret['matched_keypoint_ids'] = mkpt_ids
        ret['matched_xyzs'] = mxyzs
        ret['reference_frame_id'] = ref_frame_id
        ret['matched_point3D_ids'] = mpoint3D_ids
        ret['matched_sids'] = matched_sids
        ret['matched_ref_keypoints'] = matched_ref_keypoints
        if not ret['success']:
            ret['num_inliers'] = 0
            ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool)
        return ret

    def match(self, query_data, ref_data):
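        """Match query keypoints against one reference frame and return the 2D-3D correspondences."""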
        q_descs = query_data['descriptors']
        q_kpts = query_data['keypoints']
        q_scores = query_data['scores']
        xyzs = ref_data['xyzs']
        points3D_ids = ref_data['point3D_ids']

        with torch.no_grad():
            indices0 = self.matcher({
                'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(),
                'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(),
                'scores0': torch.from_numpy(q_scores)[None].cuda().float(),
                'image_shape0': (1, 3, query_data['camera'].width, query_data['camera'].height),

                'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(),
                'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(),
                'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(),
                'image_shape1': (1, 3, ref_data['camera'].width, ref_data['camera'].height),
            })['matches0'][0].cpu().numpy()

        valid = indices0 >= 0
        mkpts = q_kpts[valid]
        mkpt_ids = np.where(valid)[0]
        mxyzs = xyzs[indices0[valid]]
        mpoints3D_ids = points3D_ids[indices0[valid]]

        return {
            'matched_keypoints': mkpts,
            'matched_xyzs': mxyzs,
            'matched_point3D_ids': mpoints3D_ids,
            'matched_keypoint_ids': mkpt_ids,
        }

    def build_covisibility_graph(self, frame_ids: list = None, n_frame: int = 20):
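        """For each frame id, keep the `n_frame` reference frames sharing the most observed 3D points."""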
        def find_covisible_frames(frame_id):
            observed = self.reference_frames[frame_id].point3D_ids
            covis = defaultdict(int)
            for pid in observed:
                if pid == -1:
                    continue
                if pid not in self.point3Ds.keys():
                    continue
                for img_id in self.point3Ds[pid].frame_ids:
                    covis[img_id] += 1

            covis_ids = np.array(list(covis.keys()))
            covis_num = np.array([covis[i] for i in covis_ids])

            if len(covis_ids) <= n_frame:
                sel_covis_ids = covis_ids[np.argsort(-covis_num)]
            else:
                ind_top = np.argpartition(covis_num, -n_frame)
                ind_top = ind_top[-n_frame:]  # unsorted top k
                ind_top = ind_top[np.argsort(-covis_num[ind_top])]
                sel_covis_ids = [covis_ids[i] for i in ind_top]
            return sel_covis_ids

        if frame_ids is None:
            frame_ids = list(self.reference_frames.keys())

        self.covisible_graph = defaultdict()
        for frame_id in frame_ids:
            self.covisible_graph[frame_id] = find_covisible_frames(frame_id=frame_id)

    def refine_pose(self, q_frame: Frame, refinement_method='matching'):
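        """Refine the pose of `q_frame` either by re-matching against covisible frames or by 3D-2D projection."""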
        if refinement_method == 'matching':
            return self.refine_pose_by_matching(q_frame=q_frame)
        elif refinement_method == 'projection':
            return self.refine_pose_by_projection(q_frame=q_frame)
        else:
            raise NotImplementedError

    def refine_pose_by_matching(self, q_frame):
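        """Collect matches from the covisible frames of the current reference frame (plus the initial
        matches if tracking succeeded) and re-estimate the pose with PnP + RANSAC.
        """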
        ref_frame_id = q_frame.reference_frame_id
        db_ids = self.covisible_graph[ref_frame_id]
        print('Find {} covisible frames'.format(len(db_ids)))

        loc_success = q_frame.tracking_status
        if loc_success and ref_frame_id in db_ids:
            init_kpts = q_frame.matched_keypoints
            init_kpt_ids = q_frame.matched_keypoint_ids
            init_point3D_ids = q_frame.matched_point3D_ids
            init_xyzs = np.array([self.point3Ds[v].xyz for v in init_point3D_ids]).reshape(-1, 3)
            # skip the reference frame itself; its matches are already kept as the initial set
            db_ids = [v for v in db_ids if v != ref_frame_id]
        else:
            init_kpts = None
            init_xyzs = None
            init_point3D_ids = None
        matched_xyzs = []
        matched_kpts = []
        matched_point3D_ids = []
        matched_kpt_ids = []
        for idx, frame_id in enumerate(db_ids):
            ref_data = self.reference_frames[frame_id].get_keypoints()
            match_out = self.match(query_data={
                'keypoints': q_frame.keypoints[:, :2],
                'scores': q_frame.keypoints[:, 2],
                'descriptors': q_frame.descriptors,
                'camera': q_frame.camera, },
                ref_data=ref_data)

            if match_out['matched_keypoints'].shape[0] > 0:
                matched_kpts.append(match_out['matched_keypoints'])
                matched_xyzs.append(match_out['matched_xyzs'])
                matched_point3D_ids.append(match_out['matched_point3D_ids'])
                matched_kpt_ids.append(match_out['matched_keypoint_ids'])

        if len(matched_kpts) > 1:
            matched_kpts = np.vstack(matched_kpts)
            matched_xyzs = np.vstack(matched_xyzs).reshape(-1, 3)
            matched_point3D_ids = np.hstack(matched_point3D_ids)
            matched_kpt_ids = np.hstack(matched_kpt_ids)
        else:
            matched_kpts = matched_kpts[0]
            matched_xyzs = matched_xyzs[0]
            matched_point3D_ids = matched_point3D_ids[0]
            matched_kpt_ids = matched_kpt_ids[0]

        if init_kpts is not None and init_kpts.shape[0] > 0:
            matched_kpts = np.vstack([matched_kpts, init_kpts])
            matched_xyzs = np.vstack([matched_xyzs, init_xyzs])
            matched_point3D_ids = np.hstack([matched_point3D_ids, init_point3D_ids])
            matched_kpt_ids = np.hstack([matched_kpt_ids, init_kpt_ids])
        matched_sids = np.array([self.point3Ds[v].seg_id for v in matched_point3D_ids])

        print_text = 'Refinement by matching. Get {:d} covisible frames with {:d} matches for optimization'.format(
            len(db_ids), matched_xyzs.shape[0])
        print(print_text)
        t_start = time.time()
        ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5,
                                                 matched_xyzs,
                                                 q_frame.camera,
                                                 estimation_options={
                                                     'ransac': {
                                                         'max_error': self.config['localization']['threshold'],
                                                         'min_num_trials': 1000,
                                                         'max_num_trials': 10000,
                                                         'confidence': 0.995,
                                                     }},
                                                 refinement_options={},
                                                 # max_error_px=self.config['localization']['threshold'],
                                                 # min_num_trials=1000, max_num_trials=10000, confidence=0.995)
                                                 )
        print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start))

        if ret is None:
            ret = {'success': False, }
        else:
            ret['success'] = True
            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
            ret['tvec'] = ret['cam_from_world'].translation

        ret['matched_keypoints'] = matched_kpts
        ret['matched_keypoint_ids'] = matched_kpt_ids
        ret['matched_xyzs'] = matched_xyzs
        ret['matched_point3D_ids'] = matched_point3D_ids
        ret['matched_sids'] = matched_sids

        if ret['success']:
            inlier_mask = np.array(ret['inliers'])
            best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids[inlier_mask],
                                                                  candidate_frame_ids=self.covisible_graph.keys())
        else:
            best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids,
                                                                  candidate_frame_ids=self.covisible_graph.keys())
        ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][
            'covisibility_frame']]
        ret['reference_frame_id'] = best_reference_frame_ids[0]
        return ret

    def refine_pose_by_projection(self, q_frame):
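        """Project candidate 3D points of the covisible frames into the query image with the current
        pose estimate, match them to query keypoints by descriptor distance with a ratio test,
        and re-estimate the pose with PnP + RANSAC.
        """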
        q_Rcw = qvec2rotmat(q_frame.qvec)
        q_tcw = q_frame.tvec
        q_Tcw = np.eye(4, dtype=float)  # [4 4]
        q_Tcw[:3, :3] = q_Rcw
        q_Tcw[:3, 3] = q_tcw

        cam = q_frame.camera
        imw = cam.width
        imh = cam.height
        K = q_frame.get_intrinsics()  # [3, 3]

        reference_frame_id = q_frame.reference_frame_id
        covis_frame_ids = list(self.covisible_graph[reference_frame_id])
        if reference_frame_id not in covis_frame_ids:
            covis_frame_ids.append(reference_frame_id)

        all_point3D_ids = []
        for frame_id in covis_frame_ids:
            all_point3D_ids.extend(list(self.reference_frames[frame_id].point3D_ids))
        all_point3D_ids = np.unique(all_point3D_ids)

        all_xyzs = []
        all_descs = []
        all_sids = []
        for pid in all_point3D_ids:
            all_xyzs.append(self.point3Ds[pid].xyz)
            all_descs.append(self.point3Ds[pid].descriptor)
            all_sids.append(self.point3Ds[pid].seg_id)

        all_xyzs = np.array(all_xyzs)  # [N 3]
        all_descs = np.array(all_descs)  # [N D]
        all_point3D_ids = np.array(all_point3D_ids)
        all_sids = np.array(all_sids)
        # move to gpu (distortion is not included)
        # proj_uv = pycolmap.camera.img_from_cam(
        #     np.array([1, 1, 1]).reshape(1, 3),
        # )
        all_xyzs_cuda = torch.from_numpy(all_xyzs).cuda()
        ones = torch.ones(size=(all_xyzs_cuda.shape[0], 1), dtype=all_xyzs_cuda.dtype).cuda()
        all_xyzs_cuda_homo = torch.cat([all_xyzs_cuda, ones], dim=1)  # [N 4]
        K_cuda = torch.from_numpy(K).cuda()
        proj_uvs = K_cuda @ (torch.from_numpy(q_Tcw).cuda() @ all_xyzs_cuda_homo.t())[:3, :]  # [3, N]
        proj_uvs[0] /= proj_uvs[2]
        proj_uvs[1] /= proj_uvs[2]

        # keep points with depth in (0, 100) that project inside the image
        mask = (proj_uvs[2] > 0) * (proj_uvs[2] < 100) * (proj_uvs[0] >= 0) * (proj_uvs[0] < imw) * (
                proj_uvs[1] >= 0) * (proj_uvs[1] < imh)
        proj_uvs = proj_uvs[:, mask]
        print('Projection: {:d}/{:d} points fall inside the image'.format(proj_uvs.shape[1], all_xyzs_cuda.shape[0]))

        mxyzs = all_xyzs[mask.cpu().numpy()]
        mpoint3D_ids = all_point3D_ids[mask.cpu().numpy()]
        msids = all_sids[mask.cpu().numpy()]
        q_kpts_cuda = torch.from_numpy(q_frame.keypoints[:, :2]).cuda()
        proj_error = q_kpts_cuda[..., None] - proj_uvs[:2][None]
        proj_error = torch.sqrt(torch.sum(proj_error ** 2, dim=1))  # [M N]
        out_of_range_mask = (proj_error >= 2 * self.config['localization']['threshold'])

        q_descs_cuda = torch.from_numpy(q_frame.descriptors).cuda().float()  # [M D]
        all_descs_cuda = torch.from_numpy(all_descs).cuda().float()[mask]  # [N D]
        desc_dist = torch.sqrt(2 - 2 * q_descs_cuda @ all_descs_cuda.t() + 1e-6)
        desc_dist[out_of_range_mask] = desc_dist[out_of_range_mask] + 100
        dists, ids = torch.topk(desc_dist, k=2, largest=False, dim=1)

        # apply nn ratio
        ratios = dists[:, 0] / dists[:, 1]  # smaller is better
        ratio_mask = (ratios <= 0.995) * (dists[:, 0] < 100)
        ratio_mask = ratio_mask.cpu().numpy()
        ids = ids.cpu().numpy()[ratio_mask, 0]

        ratio_num = torch.sum(ratios <= 0.995)
        proj_num = torch.sum(dists[:, 0] < 100)
        print('Projection: kept {:d}/{:d} keypoints after ratio test, ratio {:d}, proj {:d}'.format(
            np.sum(ratio_mask), q_kpts_cuda.shape[0], int(ratio_num), int(proj_num)))

        mkpts = q_frame.keypoints[ratio_mask]
        mkpt_ids = np.where(ratio_mask)[0]
        mxyzs = mxyzs[ids]
        mpoint3D_ids = mpoint3D_ids[ids]
        msids = msids[ids]
        print('projection: ', mkpts.shape, mkpt_ids.shape, mxyzs.shape, mpoint3D_ids.shape, msids.shape)
        t_start = time.time()
        ret = pycolmap.absolute_pose_estimation(mkpts[:, :2] + 0.5, mxyzs, q_frame.camera,
                                                 estimation_options={
                                                     "ransac": {"max_error": self.config['localization']['threshold']}},
                                                 refinement_options={},
                                                 # max_error_px=self.config['localization']['threshold']
                                                 )
        if ret is None:
            ret = {'success': False, }
        else:
            ret['success'] = True
            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
            ret['tvec'] = ret['cam_from_world'].translation
            # inlier_mask = np.ones(shape=(mkpts.shape[0],), dtype=bool).tolist()
            # ret = pycolmap.pose_refinement(q_frame.tvec, q_frame.qvec, mkpts[:, :2] + 0.5, mxyzs, inlier_mask, cfg)
            # ret['num_inliers'] = np.sum(inlier_mask).astype(int)
            # ret['inliers'] = np.array(inlier_mask)
            print_text = 'Refinement by projection. Get {:d} inliers of {:d} matches for optimization'.format(
                ret['num_inliers'], mxyzs.shape[0])
            print(print_text)
        print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start))
        ret['matched_keypoints'] = mkpts
        ret['matched_xyzs'] = mxyzs
        ret['matched_point3D_ids'] = mpoint3D_ids
        ret['matched_sids'] = msids
        ret['matched_keypoint_ids'] = mkpt_ids

        if ret['success']:
            inlier_mask = np.array(ret['inliers'])
            best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids[inlier_mask],
                                                                  candidate_frame_ids=self.covisible_graph.keys())
        else:
            best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids,
                                                                  candidate_frame_ids=self.covisible_graph.keys())
        ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][
            'covisibility_frame']]
        ret['reference_frame_id'] = best_reference_frame_ids[0]

        if not ret['success']:
            ret['num_inliers'] = 0
            ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool)
        return ret

    def find_reference_frames(self, matched_point3D_ids, candidate_frame_ids=None):
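        """Rank candidate reference frames by how many of the matched 3D points they observe (most first)."""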
        covis_frames = defaultdict(int)
        for pid in matched_point3D_ids:
            for im_id in self.point3Ds[pid].frame_ids:
                if candidate_frame_ids is not None and im_id in candidate_frame_ids:
                    covis_frames[im_id] += 1

        covis_ids = np.array(list(covis_frames.keys()))
        covis_num = np.array([covis_frames[i] for i in covis_ids])
        sorted_idxes = np.argsort(covis_num)[::-1]  # largest to smallest
        sorted_frame_ids = covis_ids[sorted_idxes]
        return sorted_frame_ids

    def check_semantic_consistency(self, q_frame: Frame, sid, overlap_ratio=0.5):
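        """Check whether the query frame and the VRF of segment `sid` share enough segment labels.

        Returns True if, over the overlapping segment ids, the covered fraction of both the query
        labels and the reference labels is at least `overlap_ratio`.
        """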
        ref_frame_id = self.seg_ref_frame_ids[sid][0]
        ref_frame = self.reference_frames[ref_frame_id]
        q_sids = q_frame.seg_ids
        ref_sids = np.array([self.point3Ds[v].seg_id for v in ref_frame.point3D_ids]) + self.start_sid
        overlap_sids = np.intersect1d(q_sids, ref_sids)

        overlap_num1 = 0
        overlap_num2 = 0
        for osid in overlap_sids:
            overlap_num1 += np.sum(q_sids == osid)
            overlap_num2 += np.sum(ref_sids == osid)
        ratio1 = overlap_num1 / q_sids.shape[0]
        ratio2 = overlap_num2 / ref_sids.shape[0]
        # print('semantic_check: ', overlap_sids, overlap_num1, ratio1, overlap_num2, ratio2)
        return min(ratio1, ratio2) >= overlap_ratio