Spaces:
Running
Running
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> tracker
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   29/02/2024 16:58
=================================================='''
import time | |
import cv2 | |
import numpy as np | |
import torch | |
import pycolmap | |
from localization.frame import Frame | |
from localization.base_model import dynamic_load | |
import localization.matchers as matchers | |
from localization.match_features_batch import confs as matcher_confs | |
from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches | |
from tools.common import resize_img | |
class Tracker:
    """Track the current frame against the previously localized frame.

    Keypoints of the current frame are matched to keypoints of ``last_frame``
    with ``self.matcher``; matches whose last-frame keypoint carries a 3D point
    id (``>= 0``) become 2D-3D correspondences that are passed to
    ``pycolmap.absolute_pose_estimation``.

    Attributes set in ``__init__``:
        locMap: map object; ``run`` reads ``locMap.sub_maps[...]`` for refinement.
        matcher: callable taking a dict of tensors and returning a dict with
            a ``'matches0'`` entry.
        config / loc_config: full config and its ``'localization'`` section.
        lost: tracking flag written at the end of ``run``.
        curr_frame / last_frame: ``run`` writes ``curr_frame`` and only reads
            ``last_frame``; nothing in this class assigns ``last_frame``.
        device: ``'cuda'`` when available, otherwise ``'cpu'``.
        nn_matcher: nearest-neighbour matcher kept as an alternative to
            ``matcher`` (not used by the methods below).
    """

    def __init__(self, locMap, matcher, config):
        self.locMap = locMap
        self.matcher = matcher
        self.config = config
        self.loc_config = config['localization']

        self.lost = True
        self.curr_frame = None
        self.last_frame = None

        # Remember the device so matcher inputs go to the same place as the
        # models instead of unconditionally calling .cuda().
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Model = dynamic_load(matchers, 'nearest_neighbor')
        self.nn_matcher = Model(matcher_confs['NNM']['model']).eval().to(self.device)

    def run(self, frame: 'Frame'):
        """Track ``frame`` against ``self.last_frame``.

        Steps: estimate a pose from last-frame matches, verify it, refine it
        through the sub-map when fewer than 256 inliers were found, and
        optionally display the result.

        :param frame: current frame; stored as ``self.curr_frame``.
        :return: True when the (possibly refined) pose passed verification,
            False otherwise.
        :raises SystemExit: when visualisation is on and the user presses 'q'.
        """
        print('Start tracking...')
        show = self.config['localization']['show']
        self.curr_frame = frame
        ref_img = self.last_frame.image
        curr_img = self.curr_frame.image
        q_kpts = frame.keypoints

        t_start = time.time()
        ret = self.track_last_frame(curr_frame=self.curr_frame, last_frame=self.last_frame)
        self.curr_frame.time_loc = self.curr_frame.time_loc + time.time() - t_start

        q_ref_img_matching = None
        if show:
            curr_matched_kpts = ret['matched_keypoints']
            img_loc_matching = plot_matches(img1=curr_img, img2=ref_img,
                                            pts1=curr_matched_kpts,
                                            pts2=ret['matched_ref_keypoints'],
                                            inliers=np.ones(curr_matched_kpts.shape[0], dtype=bool),
                                            radius=9, line_thickness=3)
            self.curr_frame.image_matching = img_loc_matching
            q_ref_img_matching = resize_img(img_loc_matching, nh=512)

        if not ret['success']:
            # The failure visualisation needs images that exist only when
            # `show` is on; without the guard a failed track with show=False
            # crashed on unbound locals instead of returning False.
            if show:
                failed_kpts = ret['matched_keypoints']
                img_inlier = vis_inlier(img=curr_img, kpts=failed_kpts,
                                        inliers=[False] * failed_kpts.shape[0],
                                        radius=9 + 2, thickness=2)
                q_img_inlier = cv2.putText(img=img_inlier, text='Tracking FAILED!', org=(30, 30),
                                           fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                                           color=(0, 0, 255), thickness=2, lineType=cv2.LINE_AA)
                self._show_loc(q_ref_img_matching, q_img_inlier)
            return False

        ret['matched_scene_name'] = self.last_frame.scene_name
        success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)
        if not success:
            return False

        if ret['num_inliers'] < 256:
            # refinement is necessary for tracking last frame
            t_start = time.time()
            sub_map = self.locMap.sub_maps[self.last_frame.matched_scene_name]
            ret = sub_map.refine_pose(self.curr_frame,
                                      refinement_method=self.loc_config['refinement_method'])
            self.curr_frame.time_ref = self.curr_frame.time_ref + time.time() - t_start
            ret['matched_scene_name'] = self.last_frame.scene_name
            success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)

        if show:
            q_err, t_err = self.curr_frame.compute_pose_error()
            num_matches = ret['matched_keypoints'].shape[0]
            num_inliers = ret['num_inliers']
            show_text = f'Tracking, k/m/i: {q_kpts.shape[0]:d}/{num_matches:d}/{num_inliers:d}'
            q_img_inlier = vis_inlier(img=curr_img, kpts=ret['matched_keypoints'],
                                      inliers=ret['inliers'], radius=9 + 2, thickness=2)
            q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
                                       fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                                       color=(0, 0, 255), thickness=2, lineType=cv2.LINE_AA)
            show_text = f'r_err:{q_err:.2f}, t_err:{t_err:.2f}'
            q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80),
                                       fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                                       color=(0, 0, 255), thickness=2, lineType=cv2.LINE_AA)
            self.curr_frame.image_inlier = q_img_inlier
            self._show_loc(q_ref_img_matching, q_img_inlier)

        # NOTE(review): `lost` is assigned the success flag, which reads as
        # inverted given its name and its initial value of True. Kept as-is
        # because callers are not visible here -- confirm before changing.
        self.lost = success
        return success

    def _show_loc(self, img_matching, img_inlier):
        """Show the two images side by side in the 'loc' window; quit on 'q'."""
        q_img_loc = np.hstack([resize_img(img_matching, nh=512), resize_img(img_inlier, nh=512)])
        cv2.imshow('loc', q_img_loc)
        key = cv2.waitKey(self.loc_config['show_time'])
        if key == ord('q'):
            cv2.destroyAllWindows()
            # Same effect as exit(0), without relying on the `site` builtin.
            raise SystemExit(0)

    def verify_and_update(self, q_frame: 'Frame', ret: dict):
        """Accept or reject a pose estimate based on its inlier count.

        The pose in ``ret`` is written to ``q_frame`` before the check so that
        ``compute_pose_error`` evaluates it. With fewer than
        ``loc_config['min_inliers']`` inliers the frame is marked as not
        tracked and its localization track is cleared; otherwise the inlier
        correspondences are stored on the frame.

        :param q_frame: frame being localized.
        :param ret: estimation result; must provide 'matched_keypoints',
            'num_inliers', 'qvec', 'tvec' plus the keys read by
            ``update_current_frame``.
        :return: True when the estimate was accepted, False otherwise.
        """
        num_matches = ret['matched_keypoints'].shape[0]
        num_inliers = ret['num_inliers']

        q_frame.qvec = ret['qvec']
        q_frame.tvec = ret['tvec']
        q_err, t_err = q_frame.compute_pose_error()

        if num_inliers < self.loc_config['min_inliers']:
            print(f'Failed due to insufficient {num_inliers:d} inliers, '
                  f'q_err: {q_err:.2f}, t_err: {t_err:.2f}')
            q_frame.tracking_status = False
            q_frame.clear_localization_track()
            return False

        print(f'Succeed! Find {num_inliers}/{num_matches} 2D-3D inliers,'
              f'q_err: {q_err:.2f}, t_err: {t_err:.2f}')
        q_frame.tracking_status = True
        self.update_current_frame(curr_frame=q_frame, ret=ret)
        return True

    def update_current_frame(self, curr_frame: 'Frame', ret: dict):
        """Copy the pose and the inlier-only correspondences from ``ret`` onto the frame.

        :param curr_frame: frame to update in place.
        :param ret: estimation result with pose, scene/reference ids, an
            'inliers' mask and the per-match arrays filtered by that mask.
        """
        curr_frame.qvec = ret['qvec']
        curr_frame.tvec = ret['tvec']
        curr_frame.matched_scene_name = ret['matched_scene_name']
        curr_frame.reference_frame_id = ret['reference_frame_id']

        inliers = np.array(ret['inliers'])
        curr_frame.matched_keypoints = ret['matched_keypoints'][inliers]
        curr_frame.matched_xyzs = ret['matched_xyzs'][inliers]
        curr_frame.matched_point3D_ids = ret['matched_point3D_ids'][inliers]
        curr_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inliers]
        curr_frame.matched_sids = ret['matched_sids'][inliers]

    def _match(self, descs0, kpts0, scores0, image_shape0, descs1, kpts1, scores1, image_shape1):
        """Run ``self.matcher`` on two keypoint sets.

        Returns ``matches0`` for the single batch element as a numpy array:
        for every keypoint of set 0 the index of its match in set 1; negative
        entries are treated as "no match" by the callers. When either set is
        empty the matcher is skipped and an all ``-1`` array is returned.
        """
        if kpts0.shape[0] == 0 or kpts1.shape[0] == 0:
            return np.full(kpts0.shape[0], -1, dtype=np.int64)

        def as_batch(array):
            # Add the batch dimension, move to the model device, cast to float32.
            return torch.from_numpy(array)[None].to(self.device).float()

        pred = self.matcher({
            'descriptors0': as_batch(descs0),
            'keypoints0': as_batch(kpts0),
            'scores0': as_batch(scores0),
            'image_shape0': image_shape0,
            'descriptors1': as_batch(descs1),
            'keypoints1': as_batch(kpts1),
            'scores1': as_batch(scores1),
            'image_shape1': image_shape1,
        })
        return pred['matches0'][0].cpu().numpy()

    def _estimate_pose(self, matched_kpts, matched_xyzs, camera):
        """Estimate the absolute pose from 2D-3D matches with pycolmap.

        Returns ``{'success': False}`` when there are no matches or pycolmap
        returns ``None``; otherwise pycolmap's result dict extended with
        'success', 'qvec' and 'tvec'.
        """
        if matched_kpts.shape[0] == 0:
            return {'success': False}

        # Keypoints are shifted by half a pixel before estimation
        # (presumably to match COLMAP's pixel-centre convention -- TODO confirm).
        ret = pycolmap.absolute_pose_estimation(
            matched_kpts + 0.5, matched_xyzs, camera,
            estimation_options={'ransac': {'max_error': self.config['localization']['threshold']}},
            refinement_options={},
        )
        if ret is None:
            return {'success': False}

        ret['success'] = True
        # Reorder the quaternion components from index order [0, 1, 2, 3] to [3, 0, 1, 2].
        ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
        ret['tvec'] = ret['cam_from_world'].translation
        return ret

    def track_last_frame(self, curr_frame: 'Frame', last_frame: 'Frame'):
        """Match all keypoints of both frames and estimate the current pose.

        Matches whose last-frame keypoint has a negative 3D point id are
        discarded after matching.

        :return: dict with 'success' and, regardless of success, the
            'matched_*' arrays, 'reference_frame_id' and 'matched_scene_name';
            on success also the pycolmap result plus 'qvec' and 'tvec'.
        """
        curr_kpts = curr_frame.keypoints[:, :2]
        curr_kpt_ids = np.arange(curr_kpts.shape[0])
        last_kpts = last_frame.keypoints[:, :2]

        # NOTE(review): image_shape is passed as (1, 3, width, height);
        # confirm this is the ordering the matcher expects.
        indices = self._match(
            curr_frame.descriptors, curr_kpts, curr_frame.keypoints[:, 2],
            (1, 3, curr_frame.camera.width, curr_frame.camera.height),
            last_frame.descriptors, last_kpts, last_frame.keypoints[:, 2],
            (1, 3, last_frame.camera.width, last_frame.camera.height),
        )

        valid = (indices >= 0)
        last_idx = indices[valid]
        matched_point3D_ids = last_frame.point3D_ids[last_idx]
        point3D_mask = (matched_point3D_ids >= 0)

        matched_point3D_ids = matched_point3D_ids[point3D_mask]
        matched_sids = last_frame.seg_ids[last_idx][point3D_mask]
        matched_kpts = curr_kpts[valid][point3D_mask]
        matched_kpt_ids = curr_kpt_ids[valid][point3D_mask]
        matched_xyzs = last_frame.xyzs[last_idx][point3D_mask]
        matched_last_kpts = last_kpts[last_idx][point3D_mask]

        print(f'Tracking: {matched_kpts.shape[0]:d} matches from '
              f'{curr_kpts.shape[0]:d}-{last_kpts.shape[0]:d} kpts')

        ret = self._estimate_pose(matched_kpts, matched_xyzs, curr_frame.camera)
        ret['matched_keypoints'] = matched_kpts
        ret['matched_keypoint_ids'] = matched_kpt_ids
        ret['matched_ref_keypoints'] = matched_last_kpts
        ret['matched_xyzs'] = matched_xyzs
        ret['matched_point3D_ids'] = matched_point3D_ids
        ret['matched_sids'] = matched_sids
        ret['reference_frame_id'] = last_frame.reference_frame_id
        ret['matched_scene_name'] = last_frame.matched_scene_name
        return ret

    def track_last_frame_fast(self, curr_frame: 'Frame', last_frame: 'Frame'):
        """Variant of ``track_last_frame`` that shrinks both sets before matching.

        Last-frame keypoints without a 3D point id are dropped up front, and
        current keypoints outside the bounding box of the remaining last-frame
        keypoints are dropped as well.

        :return: dict with the same keys as ``track_last_frame``.
        """
        curr_kpts = curr_frame.keypoints[:, :2]
        curr_scores = curr_frame.keypoints[:, 2]
        curr_descs = curr_frame.descriptors
        curr_kpt_ids = np.arange(curr_kpts.shape[0])

        point3D_mask = (last_frame.point3D_ids >= 0)
        # Every last-frame array, including the 3D ids, must be filtered with
        # the same mask: match indices refer to the filtered ordering.
        last_point3D_ids = last_frame.point3D_ids[point3D_mask]
        last_kpts = last_frame.keypoints[:, :2][point3D_mask]
        last_scores = last_frame.keypoints[:, 2][point3D_mask]
        last_descs = last_frame.descriptors[point3D_mask]
        last_xyzs = last_frame.xyzs[point3D_mask]
        last_sids = last_frame.seg_ids[point3D_mask]

        if last_kpts.shape[0] > 0:
            minx, miny = last_kpts[:, 0].min(), last_kpts[:, 1].min()
            maxx, maxy = last_kpts[:, 0].max(), last_kpts[:, 1].max()
            curr_mask = ((curr_kpts[:, 0] >= minx) & (curr_kpts[:, 0] <= maxx)
                         & (curr_kpts[:, 1] >= miny) & (curr_kpts[:, 1] <= maxy))
        else:
            # No 3D-associated keypoints in the last frame: nothing to match.
            curr_mask = np.zeros(curr_kpts.shape[0], dtype=bool)

        curr_kpts = curr_kpts[curr_mask]
        curr_scores = curr_scores[curr_mask]
        curr_descs = curr_descs[curr_mask]
        curr_kpt_ids = curr_kpt_ids[curr_mask]

        indices = self._match(
            curr_descs, curr_kpts, curr_scores,
            (1, 3, curr_frame.camera.width, curr_frame.camera.height),
            last_descs, last_kpts, last_scores,
            (1, 3, last_frame.camera.width, last_frame.camera.height),
        )

        valid = (indices >= 0)
        last_idx = indices[valid]
        matched_point3D_ids = last_point3D_ids[last_idx]
        matched_sids = last_sids[last_idx]
        matched_kpts = curr_kpts[valid]
        matched_kpt_ids = curr_kpt_ids[valid]
        matched_xyzs = last_xyzs[last_idx]
        matched_last_kpts = last_kpts[last_idx]

        print(f'Tracking: {matched_kpts.shape[0]:d} matches from '
              f'{curr_kpts.shape[0]:d}-{last_kpts.shape[0]:d} kpts')

        # Same estimation call and result normalisation as track_last_frame.
        ret = self._estimate_pose(matched_kpts, matched_xyzs, curr_frame.camera)
        ret['matched_keypoints'] = matched_kpts
        ret['matched_keypoint_ids'] = matched_kpt_ids
        ret['matched_ref_keypoints'] = matched_last_kpts
        ret['matched_xyzs'] = matched_xyzs
        ret['matched_point3D_ids'] = matched_point3D_ids
        ret['matched_sids'] = matched_sids
        ret['reference_frame_id'] = last_frame.reference_frame_id
        ret['matched_scene_name'] = last_frame.matched_scene_name
        return ret

    def match_frame(self, frame: 'Frame', reference_frame: 'Frame'):
        """Match ``frame`` to ``reference_frame`` and keep matches that have 3D points.

        :return: ``(ids1, ids2)`` -- keypoint indices in ``frame`` and the
            corresponding indices in ``reference_frame``, restricted to
            reference keypoints flagged in ``reference_frame.points3d_mask``.
        """
        print('match: ', frame.keypoints.shape, reference_frame.keypoints.shape)
        matches = self._match(
            frame.descriptors, frame.keypoints[:, :2], frame.keypoints[:, 2],
            (1, 3, frame.image_size[0], frame.image_size[1]),
            reference_frame.descriptors, reference_frame.keypoints[:, :2],
            reference_frame.keypoints[:, 2],
            (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]),
        )

        ids1 = np.flatnonzero(matches >= 0)
        ids2 = matches[ids1]

        mask_p3ds = reference_frame.points3d_mask[ids2]
        return ids1[mask_p3ds], ids2[mask_p3ds]