# Metadata captured alongside this file (not part of the code):
# Realcat | fix: eloftr | 63f3cf2 | 7.67 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> frame
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 01/03/2024 10:08
=================================================='''
from collections import defaultdict
import numpy as np
import torch
import pycolmap
from localization.camera import Camera
from localization.utils import compute_pose_error
class Frame:
    '''
    Per-image state used by the localization pipeline: the image and its camera,
    local features (keypoints / descriptors), segmentation results, the current
    2D-3D matches, pose estimates and timing counters.
    '''

    def __init__(self, image: np.ndarray, camera: 'pycolmap.Camera', id: int, name: str = None, qvec=None,
                 tvec=None,
                 scene_name=None,
                 reference_frame_id=None):
        '''
        :param image: image array of this frame
        :param camera: camera of this frame; must expose `height` and `width`
                       (and `model.name` / `params` for get_intrinsics)
        :param id: frame identifier (shadows the builtin name; kept for API compatibility)
        :param name: optional image name
        :param qvec: optional pose rotation
        :param tvec: optional pose translation
        :param scene_name: optional scene name
        :param reference_frame_id: optional id of a reference frame
        '''
        # The camera annotation is a string so defining this class never has to
        # evaluate `pycolmap.Camera`.
        self.image = image
        self.camera = camera
        self.id = id
        self.name = name
        self.image_size = np.array([camera.height, camera.width])
        self.qvec = qvec
        self.tvec = tvec
        self.scene_name = scene_name
        self.reference_frame_id = reference_frame_id

        # local features and per-keypoint localization buffers
        self.keypoints = None  # [N, 3]
        self.descriptors = None  # [N, D]
        self.segmentations = None  # [N C]
        self.seg_scores = None  # [N C]
        self.seg_ids = None  # [N, 1]
        self.point3D_ids = None  # [N, 1]
        self.xyzs = None

        # ground-truth pose (used by compute_pose_error)
        self.gt_qvec = None
        self.gt_tvec = None

        # current 2D-3D matching track
        self.matched_scene_name = None
        self.matched_keypoints = None
        self.matched_keypoint_ids = None
        self.matched_xyzs = None
        self.matched_point3D_ids = None
        self.matched_inliers = None
        self.matched_sids = None
        self.matched_order = None
        self.refinement_reference_frame_ids = None

        # visualization / bookkeeping
        self.image_rec = None
        self.image_matching = None
        self.image_inlier = None
        self.reference_frame_name = None
        self.image_matching_tmp = None
        self.image_inlier_tmp = None
        self.reference_frame_name_tmp = None
        self.tracking_status = None

        # timing counters
        self.time_feat = 0
        self.time_rec = 0
        self.time_loc = 0
        self.time_ref = 0

    def update_point3ds_old(self):
        '''
        Legacy nearest-neighbour transfer of matched 3D info onto keypoints.

        For every keypoint, find the closest matched keypoint, using the Euclidean
        distance on the first two coordinates. When that match lies within 1 pixel,
        copy its 3D point id, xyz and segment id onto the keypoint.

        self.point3D_ids is rebuilt: keypoints without a close match get -1.
        self.xyzs and self.seg_ids are only written for matched keypoints, so they
        must already be allocated (e.g. by initialize_localization_variables).
        '''
        n_kpts = self.keypoints.shape[0]
        self.point3D_ids = np.full((n_kpts,), -1, dtype=int)
        # Nothing to match against: torch.topk(k=1) would fail on an empty dimension.
        if self.matched_keypoints is None or self.matched_keypoints.shape[0] == 0:
            return

        pt = torch.from_numpy(self.keypoints[:, :2]).unsqueeze(-1)  # [M 2 1]
        mpt = torch.from_numpy(self.matched_keypoints[:, :2].transpose()).unsqueeze(0)  # [1 2 N]
        dist = torch.sqrt(torch.sum((pt - mpt) ** 2, dim=1))  # [M N]
        values, ids = torch.topk(dist, dim=1, k=1, largest=False)
        values = values[:, 0].numpy()
        ids = ids[:, 0].numpy()
        mask = values < 1  # 1 pixel error
        nearest = ids[mask]
        self.point3D_ids[mask] = self.matched_point3D_ids[nearest]
        self.xyzs[mask] = self.matched_xyzs[nearest]
        self.seg_ids[mask] = self.matched_sids[nearest]

    def update_point3ds(self):
        '''
        Scatter the current matches onto the per-keypoint buffers.

        For each index in self.matched_keypoint_ids, write the corresponding entry
        of self.matched_xyzs, self.matched_sids and self.matched_point3D_ids into
        self.xyzs, self.seg_ids and self.point3D_ids respectively.
        '''
        self.xyzs[self.matched_keypoint_ids] = self.matched_xyzs
        self.seg_ids[self.matched_keypoint_ids] = self.matched_sids
        self.point3D_ids[self.matched_keypoint_ids] = self.matched_point3D_ids

    def add_keypoints(self, keypoints: np.ndarray, descriptors: np.ndarray):
        '''
        Set keypoints/descriptors and (re)allocate the per-keypoint localization buffers.

        :param keypoints: [N, 3]
        :param descriptors: [N, D]
        '''
        self.keypoints = keypoints
        self.descriptors = descriptors
        self.initialize_localization_variables()

    def add_segmentations(self, segmentations: torch.Tensor, filtering_threshold: float):
        '''
        Attach per-keypoint segmentation results, optionally dropping background keypoints.

        Column 0 of the softmax-normalised scores is treated as background.

        When `filtering_threshold > 0`, keypoints whose background score is >= the
        threshold are removed, together with their descriptors and segmentation rows.
        This only happens if at least 40% of the keypoints survive; otherwise nothing
        is filtered. Removing keypoints re-initialises the localization buffers.

        :param segmentations: [number_points number_labels] raw scores; softmax is applied here
        :param filtering_threshold: background-score threshold; <= 0 disables filtering
        :return: None. Sets self.segmentations, self.seg_scores and self.seg_ids.
                 seg_ids is the argmax label minus 1, so label 0 maps to -1.
        '''
        seg_scores = torch.softmax(segmentations, dim=-1)
        if filtering_threshold > 0:
            scores_background = seg_scores[:, 0]
            non_bg_mask = (scores_background < filtering_threshold)
            print('pre filtering before: ', self.keypoints.shape)
            if torch.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]:
                keep = non_bg_mask.cpu().numpy()  # numpy mask for the numpy feature arrays
                self.keypoints = self.keypoints[keep]
                self.descriptors = self.descriptors[keep]
                # the number of keypoints changed: resize the per-keypoint buffers
                self.initialize_localization_variables()
                # keep the segmentation rows aligned with the surviving keypoints
                segmentations = segmentations[non_bg_mask]
                seg_scores = seg_scores[non_bg_mask]
            print('pre filtering after: ', self.keypoints.shape)

        # extract initial segmentation info
        self.segmentations = segmentations.cpu().numpy()
        self.seg_scores = seg_scores.cpu().numpy()
        self.seg_ids = segmentations.max(dim=-1)[1].cpu().numpy() - 1  # should start from 0

    def filter_keypoints(self, seg_scores: np.ndarray, filtering_threshold: float):
        '''
        Drop keypoints whose background score (column 0) is >= `filtering_threshold`.

        Filtering is applied only if at least 40% of the keypoints survive.

        :param seg_scores: [number_points number_labels] per-keypoint scores
        :param filtering_threshold: background-score threshold
        :return: the boolean keep-mask if filtering was applied, otherwise None
        '''
        scores_background = seg_scores[:, 0]
        non_bg_mask = (scores_background < filtering_threshold)
        print('pre filtering before: ', self.keypoints.shape)
        if np.sum(non_bg_mask) < 0.4 * seg_scores.shape[0]:
            print('pre filtering after: ', self.keypoints.shape)
            return None

        self.keypoints = self.keypoints[non_bg_mask]
        self.descriptors = self.descriptors[non_bg_mask]
        print('pre filtering after: ', self.keypoints.shape)
        # update localization variables
        self.initialize_localization_variables()
        return non_bg_mask

    def compute_pose_error(self, pred_qvec=None, pred_tvec=None):
        '''
        Compare a pose against the ground truth (self.gt_qvec / self.gt_tvec).

        If both pred_qvec and pred_tvec are given, they are evaluated. Otherwise the
        frame's own self.qvec / self.tvec are evaluated.

        :return: the (rotation, translation) error pair from
                 localization.utils.compute_pose_error, or the sentinel (100, 100)
                 when the pose or the ground truth is missing.
        '''
        if pred_qvec is not None and pred_tvec is not None:
            if self.gt_qvec is None or self.gt_tvec is None:
                return 100, 100
            # resolves to the module-level helper, not this method
            return compute_pose_error(pred_qcw=pred_qvec, pred_tcw=pred_tvec,
                                      gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)

        if self.qvec is None or self.tvec is None or self.gt_qvec is None or self.gt_tvec is None:
            return 100, 100
        err_q, err_t = compute_pose_error(pred_qcw=self.qvec, pred_tcw=self.tvec,
                                          gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)
        return err_q, err_t

    def get_intrinsics(self) -> np.ndarray:
        '''
        Build the 3x3 pinhole intrinsics matrix K from the camera parameters.

        Only the focal length(s) and principal point are used; any distortion
        parameters of the camera model are ignored.

        :return: K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
        :raises ValueError: for camera models whose parameter layout is not handled
        '''
        camera_model = self.camera.model.name
        params = self.camera.params
        if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL",
                            "SIMPLE_RADIAL_FISHEYE", "RADIAL_FISHEYE"):
            # parameter layout: f, cx, cy, ...
            fx = fy = params[0]
            cx = params[1]
            cy = params[2]
        elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV",
                              "FOV", "THIN_PRISM_FISHEYE", "RAD_TAN_THIN_PRISM_FISHEYE"):
            # parameter layout: fx, fy, cx, cy, ...
            fx = params[0]
            fy = params[1]
            cx = params[2]
            cy = params[3]
        else:
            raise ValueError(f"Camera model not supported: {camera_model}")

        # intrinsics
        K = np.identity(3)
        K[0, 0] = fx
        K[1, 1] = fy
        K[0, 2] = cx
        K[1, 2] = cy
        return K

    def get_dominate_seg_id(self):
        '''
        Return the most frequent segment id among keypoints whose seg_id is > 0.

        :raises ValueError: from np.argmax when no keypoint has seg_id > 0
        '''
        # NOTE(review): add_segmentations makes ids "start from 0" (with -1 for label 0 /
        # unassigned), yet id 0 is excluded here by `> 0`; confirm this is intended.
        counts = np.bincount(self.seg_ids[self.seg_ids > 0])
        return np.argmax(counts)

    def clear_localization_track(self):
        '''Reset the current 2D-3D matching track.'''
        self.matched_scene_name = None
        self.matched_keypoints = None
        # matched_keypoint_ids indexes the arrays cleared below (see update_point3ds);
        # keeping it would leave a stale, inconsistent track.
        self.matched_keypoint_ids = None
        self.matched_xyzs = None
        self.matched_point3D_ids = None
        self.matched_inliers = None
        self.matched_sids = None
        self.refinement_reference_frame_ids = None

    def initialize_localization_variables(self):
        '''
        (Re)allocate the per-keypoint buffers sized to the current keypoints.

        seg_ids and point3D_ids are filled with -1 (unassigned); xyzs are zeros of shape [N, 3].
        '''
        nkpt = self.keypoints.shape[0]
        self.seg_ids = np.full((nkpt,), -1, dtype=int)
        self.point3D_ids = np.full((nkpt,), -1, dtype=int)
        self.xyzs = np.zeros(shape=(nkpt, 3), dtype=float)