# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> tracker
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/02/2024 16:58
=================================================='''
import time
import cv2
import numpy as np
import torch
import pycolmap
from localization.frame import Frame
from localization.base_model import dynamic_load
import localization.matchers as matchers
from localization.match_features_batch import confs as matcher_confs
from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches
from tools.common import resize_img
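
# Frame-to-frame tracker: the current frame is matched against the last
# (already localized) frame, 2D-2D matches are lifted to 2D-3D correspondences
# through the last frame's point3D ids, and the pose is estimated with
# pycolmap (RANSAC PnP), optionally refined against the matched sub-map.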
class Tracker:
def __init__(self, locMap, matcher, config):
self.locMap = locMap
self.matcher = matcher
self.config = config
self.loc_config = config['localization']
self.lost = True
self.curr_frame = None
self.last_frame = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # plain nearest-neighbor matcher, kept as a fallback to the learned
        # `matcher` passed in above
        Model = dynamic_load(matchers, 'nearest_neighbor')
        self.nn_matcher = Model(matcher_confs['NNM']['model']).eval().to(device)
def run(self, frame: Frame):
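        """Track `frame` against the last frame and update its pose.

        Verifies the PnP result, refines it against the matched sub-map when
        the inlier count is low, and returns True on success, False otherwise.
        """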
print('Start tracking...')
show = self.config['localization']['show']
self.curr_frame = frame
ref_img = self.last_frame.image
curr_img = self.curr_frame.image
q_kpts = frame.keypoints
t_start = time.time()
ret = self.track_last_frame(curr_frame=self.curr_frame, last_frame=self.last_frame)
        self.curr_frame.time_loc += time.time() - t_start
if show:
curr_matched_kpts = ret['matched_keypoints']
ref_matched_kpts = ret['matched_ref_keypoints']
img_loc_matching = plot_matches(img1=curr_img, img2=ref_img,
pts1=curr_matched_kpts,
pts2=ref_matched_kpts,
                                            inliers=np.ones(curr_matched_kpts.shape[0], dtype=bool),
                                            radius=9, line_thickness=3)
self.curr_frame.image_matching = img_loc_matching
q_ref_img_matching = resize_img(img_loc_matching, nh=512)
if not ret['success']:
show_text = 'Tracking FAILED!'
img_inlier = vis_inlier(img=curr_img, kpts=curr_matched_kpts,
                                    inliers=[False] * curr_matched_kpts.shape[0], radius=9 + 2,
thickness=2)
q_img_inlier = cv2.putText(img=img_inlier, text=show_text, org=(30, 30),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
thickness=2, lineType=cv2.LINE_AA)
q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
cv2.imshow('loc', q_img_loc)
key = cv2.waitKey(self.loc_config['show_time'])
if key == ord('q'):
cv2.destroyAllWindows()
exit(0)
return False
ret['matched_scene_name'] = self.last_frame.scene_name
success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)
if not success:
return False
        if ret['num_inliers'] < 256:
            # frame-to-frame tracking produced few inliers; refine the pose
            # against the matched sub-map before accepting it
t_start = time.time()
            sub_map = self.locMap.sub_maps[self.last_frame.matched_scene_name]
            ret = sub_map.refine_pose(self.curr_frame,
                                      refinement_method=self.loc_config['refinement_method'])
            self.curr_frame.time_ref += time.time() - t_start
ret['matched_scene_name'] = self.last_frame.scene_name
success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)
if show:
q_err, t_err = self.curr_frame.compute_pose_error()
num_matches = ret['matched_keypoints'].shape[0]
num_inliers = ret['num_inliers']
            show_text = 'Tracking kpts/matches/inliers: {:d}/{:d}/{:d}'.format(q_kpts.shape[0], num_matches, num_inliers)
q_img_inlier = vis_inlier(img=curr_img, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
radius=9 + 2, thickness=2)
q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
thickness=2, lineType=cv2.LINE_AA)
show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err)
q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
thickness=2, lineType=cv2.LINE_AA)
self.curr_frame.image_inlier = q_img_inlier
q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
cv2.imshow('loc', q_img_loc)
key = cv2.waitKey(self.loc_config['show_time'])
if key == ord('q'):
cv2.destroyAllWindows()
exit(0)
        # tracking succeeded means we are no longer lost
        self.lost = not success
return success
def verify_and_update(self, q_frame: Frame, ret: dict):
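        """Accept `ret` if it has at least `min_inliers` inliers.

        On success the pose and inlier matches are written to `q_frame`;
        on failure its localization track is cleared.
        """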
num_matches = ret['matched_keypoints'].shape[0]
num_inliers = ret['num_inliers']
q_frame.qvec = ret['qvec']
q_frame.tvec = ret['tvec']
q_err, t_err = q_frame.compute_pose_error()
if num_inliers < self.loc_config['min_inliers']:
            print_text = 'Failed: only {:d} inliers (min: {:d}), q_err: {:.2f}, t_err: {:.2f}'.format(
                num_inliers, self.loc_config['min_inliers'], q_err, t_err)
            print(print_text)
q_frame.tracking_status = False
q_frame.clear_localization_track()
return False
else:
            print_text = 'Succeeded: {}/{} 2D-3D inliers, q_err: {:.2f}, t_err: {:.2f}'.format(
                num_inliers, num_matches, q_err, t_err)
print(print_text)
q_frame.tracking_status = True
self.update_current_frame(curr_frame=q_frame, ret=ret)
return True
def update_current_frame(self, curr_frame: Frame, ret: dict):
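        """Store the pose and the inlier subset of the matches on `curr_frame`."""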
curr_frame.qvec = ret['qvec']
curr_frame.tvec = ret['tvec']
curr_frame.matched_scene_name = ret['matched_scene_name']
curr_frame.reference_frame_id = ret['reference_frame_id']
inliers = np.array(ret['inliers'])
curr_frame.matched_keypoints = ret['matched_keypoints'][inliers]
curr_frame.matched_xyzs = ret['matched_xyzs'][inliers]
curr_frame.matched_point3D_ids = ret['matched_point3D_ids'][inliers]
curr_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inliers]
curr_frame.matched_sids = ret['matched_sids'][inliers]
def track_last_frame(self, curr_frame: Frame, last_frame: Frame):
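        """Match `curr_frame` to `last_frame`, keep matches whose reference
        keypoints carry valid 3D points, and estimate the absolute pose."""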
curr_kpts = curr_frame.keypoints[:, :2]
curr_scores = curr_frame.keypoints[:, 2]
curr_descs = curr_frame.descriptors
curr_kpt_ids = np.arange(curr_kpts.shape[0])
last_kpts = last_frame.keypoints[:, :2]
last_scores = last_frame.keypoints[:, 2]
last_descs = last_frame.descriptors
last_xyzs = last_frame.xyzs
last_point3D_ids = last_frame.point3D_ids
last_sids = last_frame.seg_ids
        # match current-frame to last-frame keypoints with the learned matcher
        # (a plain nearest-neighbor alternative is kept below, commented out)
indices = self.matcher({
'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(),
'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(),
'scores0': torch.from_numpy(curr_scores)[None].cuda().float(),
'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height),
'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(),
'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(),
'scores1': torch.from_numpy(last_scores)[None].cuda().float(),
'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height),
})['matches0'][0].cpu().numpy()
        # indices = self.nn_matcher({
        #     'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None],
        #     'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None],
        # })['matches0'][0].cpu().numpy()
valid = (indices >= 0)
matched_point3D_ids = last_point3D_ids[indices[valid]]
point3D_mask = (matched_point3D_ids >= 0)
matched_point3D_ids = matched_point3D_ids[point3D_mask]
matched_sids = last_sids[indices[valid]][point3D_mask]
matched_kpts = curr_kpts[valid][point3D_mask]
matched_kpt_ids = curr_kpt_ids[valid][point3D_mask]
matched_xyzs = last_xyzs[indices[valid]][point3D_mask]
matched_last_kpts = last_kpts[indices[valid]][point3D_mask]
print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0],
last_kpts.shape[0]))
# print('tracking: ', matched_kpts.shape, matched_xyzs.shape)
        # +0.5 converts keypoints to COLMAP's pixel-center convention
        ret = pycolmap.absolute_pose_estimation(
            matched_kpts + 0.5, matched_xyzs, curr_frame.camera,
            estimation_options={'ransac': {'max_error': self.config['localization']['threshold']}},
            refinement_options={},
        )
        if ret is None:
            ret = {'success': False}
        else:
            ret['success'] = True
            # pycolmap quaternions are (x, y, z, w); reorder to COLMAP's (w, x, y, z)
            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
            ret['tvec'] = ret['cam_from_world'].translation
ret['matched_keypoints'] = matched_kpts
ret['matched_keypoint_ids'] = matched_kpt_ids
ret['matched_ref_keypoints'] = matched_last_kpts
ret['matched_xyzs'] = matched_xyzs
ret['matched_point3D_ids'] = matched_point3D_ids
ret['matched_sids'] = matched_sids
ret['reference_frame_id'] = last_frame.reference_frame_id
ret['matched_scene_name'] = last_frame.matched_scene_name
return ret
def track_last_frame_fast(self, curr_frame: Frame, last_frame: Frame):
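        """Variant of track_last_frame that pre-filters both frames: the last
        frame keeps only keypoints with valid 3D points, and the current frame
        keeps only keypoints inside their bounding box."""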
curr_kpts = curr_frame.keypoints[:, :2]
curr_scores = curr_frame.keypoints[:, 2]
curr_descs = curr_frame.descriptors
curr_kpt_ids = np.arange(curr_kpts.shape[0])
        last_point3D_ids = last_frame.point3D_ids
        point3D_mask = (last_point3D_ids >= 0)
        # mask the ids as well, so that matcher indices below address all
        # last-frame arrays consistently
        last_point3D_ids = last_point3D_ids[point3D_mask]
last_kpts = last_frame.keypoints[:, :2][point3D_mask]
last_scores = last_frame.keypoints[:, 2][point3D_mask]
last_descs = last_frame.descriptors[point3D_mask]
last_xyzs = last_frame.xyzs[point3D_mask]
last_sids = last_frame.seg_ids[point3D_mask]
minx = np.min(last_kpts[:, 0])
maxx = np.max(last_kpts[:, 0])
miny = np.min(last_kpts[:, 1])
maxy = np.max(last_kpts[:, 1])
        curr_mask = (curr_kpts[:, 0] >= minx) & (curr_kpts[:, 0] <= maxx) & \
                    (curr_kpts[:, 1] >= miny) & (curr_kpts[:, 1] <= maxy)
curr_kpts = curr_kpts[curr_mask]
curr_scores = curr_scores[curr_mask]
curr_descs = curr_descs[curr_mask]
curr_kpt_ids = curr_kpt_ids[curr_mask]
        # match within the pre-filtered keypoint sets with the learned matcher
        # (nearest-neighbor alternative kept below, commented out)
indices = self.matcher({
'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(),
'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(),
'scores0': torch.from_numpy(curr_scores)[None].cuda().float(),
'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height),
'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(),
'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(),
'scores1': torch.from_numpy(last_scores)[None].cuda().float(),
'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height),
})['matches0'][0].cpu().numpy()
        # indices = self.nn_matcher({
        #     'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None],
        #     'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None],
        # })['matches0'][0].cpu().numpy()
valid = (indices >= 0)
matched_point3D_ids = last_point3D_ids[indices[valid]]
matched_sids = last_sids[indices[valid]]
matched_kpts = curr_kpts[valid]
matched_kpt_ids = curr_kpt_ids[valid]
matched_xyzs = last_xyzs[indices[valid]]
matched_last_kpts = last_kpts[indices[valid]]
print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0],
last_kpts.shape[0]))
# print('tracking: ', matched_kpts.shape, matched_xyzs.shape)
        # use the same pycolmap API as track_last_frame (camera._asdict() and
        # max_error_px belong to the pre-0.5 signature)
        ret = pycolmap.absolute_pose_estimation(
            matched_kpts + 0.5, matched_xyzs, curr_frame.camera,
            estimation_options={'ransac': {'max_error': self.config['localization']['threshold']}},
            refinement_options={},
        )
        if ret is None:
            ret = {'success': False}
        else:
            ret['success'] = True
            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
            ret['tvec'] = ret['cam_from_world'].translation
ret['matched_keypoints'] = matched_kpts
ret['matched_keypoint_ids'] = matched_kpt_ids
ret['matched_ref_keypoints'] = matched_last_kpts
ret['matched_xyzs'] = matched_xyzs
ret['matched_point3D_ids'] = matched_point3D_ids
ret['matched_sids'] = matched_sids
ret['reference_frame_id'] = last_frame.reference_frame_id
ret['matched_scene_name'] = last_frame.matched_scene_name
return ret
@torch.no_grad()
def match_frame(self, frame: Frame, reference_frame: Frame):
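        """Match `frame` to `reference_frame` and return index pairs
        (ids1 into frame, ids2 into reference_frame), restricted to
        reference keypoints that have valid 3D points."""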
print('match: ', frame.keypoints.shape, reference_frame.keypoints.shape)
matches = self.matcher({
'descriptors0': torch.from_numpy(frame.descriptors)[None].cuda().float(),
'keypoints0': torch.from_numpy(frame.keypoints[:, :2])[None].cuda().float(),
'scores0': torch.from_numpy(frame.keypoints[:, 2])[None].cuda().float(),
'image_shape0': (1, 3, frame.image_size[0], frame.image_size[1]),
'descriptors1': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(),
'keypoints1': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(),
'scores1': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(),
'image_shape1': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]),
})['matches0'][0].cpu().numpy()
        valid = matches >= 0
        ids1 = np.arange(matches.shape[0])[valid]
        ids2 = matches[valid]
mask_p3ds = reference_frame.points3d_mask[ids2]
ids1 = ids1[mask_p3ds]
ids2 = ids2[mask_p3ds]
return ids1, ids2
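
# A minimal usage sketch (hypothetical: `load_map` and `load_matcher` are
# placeholders, not part of this module; the config keys below are the ones
# this class actually reads, the values are illustrative):
#
#   config = {'localization': {'show': False, 'show_time': 1, 'min_inliers': 64,
#                              'threshold': 8.0, 'refinement_method': 'matching'}}
#   tracker = Tracker(locMap=load_map(...), matcher=load_matcher(...), config=config)
#   tracker.last_frame = prev_frame   # a localized Frame with keypoints and 3D ids
#   ok = tracker.run(curr_frame)      # True if the pose was verified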