# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> frame
@IDE     PyCharm
@Author  fx221@cam.ac.uk
@Date    01/03/2024 10:08
=================================================='''
from collections import defaultdict | |
import numpy as np | |
import torch | |
import pycolmap | |
from localization.camera import Camera | |
from localization.utils import compute_pose_error | |
class Frame:
    """One query image plus everything needed to localize it.

    Holds the raw image and camera, the extracted keypoints/descriptors,
    per-keypoint segmentation labels, the 2D-3D correspondences recovered
    during localization, visualization buffers and per-stage timings.
    """

    def __init__(self, image: np.ndarray, camera: pycolmap.Camera, id: int, name: str = None, qvec=None, tvec=None,
                 scene_name=None,
                 reference_frame_id=None):
        """
        :param image: image pixels as a numpy array
        :param camera: pycolmap camera carrying intrinsics (model, params, width/height)
        :param id: numeric frame id (shadows builtin ``id``; kept for API compatibility)
        :param name: optional image name
        :param qvec: rotation quaternion of the pose, or None if not available
        :param tvec: translation vector of the pose, or None if not available
        :param scene_name: name of the scene/sub-map this frame belongs to
        :param reference_frame_id: id of the reference frame used for localization
        """
        self.image = image
        self.camera = camera
        self.id = id
        self.name = name
        self.image_size = np.array([camera.height, camera.width])  # stored as [H, W]
        self.qvec = qvec
        self.tvec = tvec
        self.scene_name = scene_name
        self.reference_frame_id = reference_frame_id

        # Per-keypoint data, filled by add_keypoints()/add_segmentations().
        self.keypoints = None  # [N, 3]
        self.descriptors = None  # [N, D]
        self.segmentations = None  # [N, C] raw logits
        self.seg_scores = None  # [N, C] softmax scores
        self.seg_ids = None  # [N] segment id per keypoint, -1 = unassigned
        self.point3D_ids = None  # [N] 3D point id per keypoint, -1 = unassigned
        self.xyzs = None  # [N, 3] 3D point per keypoint

        # Ground-truth pose (if available), used by compute_pose_error().
        self.gt_qvec = None
        self.gt_tvec = None

        # Results of the last matching/localization pass.
        self.matched_scene_name = None
        self.matched_keypoints = None
        self.matched_keypoint_ids = None
        self.matched_xyzs = None
        self.matched_point3D_ids = None
        self.matched_inliers = None
        self.matched_sids = None
        self.matched_order = None
        self.refinement_reference_frame_ids = None

        # Visualization / debugging buffers.
        self.image_rec = None
        self.image_matching = None
        self.image_inlier = None
        self.reference_frame_name = None
        self.image_matching_tmp = None
        self.image_inlier_tmp = None
        self.reference_frame_name_tmp = None

        self.tracking_status = None

        # Per-stage timing counters (feature extraction, recognition,
        # localization, refinement).
        self.time_feat = 0
        self.time_rec = 0
        self.time_loc = 0
        self.time_ref = 0

    def update_point3ds_old(self):
        """Propagate 3D info from matched keypoints onto all keypoints by
        nearest-neighbour search in image space (tolerance: 1 pixel).

        Writes point3D_ids, xyzs and seg_ids in place; keypoints with no
        matched keypoint within 1 px keep point3D_id -1.
        """
        pt = torch.from_numpy(self.keypoints[:, :2]).unsqueeze(-1)  # [M 2 1]
        mpt = torch.from_numpy(self.matched_keypoints[:, :2].transpose()).unsqueeze(0)  # [1 2 N]
        dist = torch.sqrt(torch.sum((pt - mpt) ** 2, dim=1))  # [M N] pairwise pixel distances
        values, ids = torch.topk(dist, dim=1, k=1, largest=False)  # nearest match per keypoint
        values = values[:, 0].numpy()
        ids = ids[:, 0].numpy()
        mask = (values < 1)  # 1 pixel error
        self.point3D_ids = np.zeros(shape=(self.keypoints.shape[0],), dtype=int) - 1  # -1 = unassigned
        self.point3D_ids[mask] = self.matched_point3D_ids[ids[mask]]
        # fix: removed unused local `inlier_mask = self.matched_inliers`
        self.xyzs[mask] = self.matched_xyzs[ids[mask]]
        self.seg_ids[mask] = self.matched_sids[ids[mask]]

    def update_point3ds(self):
        """Scatter matched 3D info onto the frame's keypoints using the
        explicit keypoint indices stored in matched_keypoint_ids."""
        self.xyzs[self.matched_keypoint_ids] = self.matched_xyzs
        self.seg_ids[self.matched_keypoint_ids] = self.matched_sids
        self.point3D_ids[self.matched_keypoint_ids] = self.matched_point3D_ids

    def add_keypoints(self, keypoints: np.ndarray, descriptors: np.ndarray):
        """Attach detected keypoints/descriptors and (re)allocate the
        per-keypoint localization buffers.

        :param keypoints: [N, 3] keypoint array
        :param descriptors: [N, D] descriptor array
        """
        self.keypoints = keypoints
        self.descriptors = descriptors
        self.initialize_localization_variables()

    def add_segmentations(self, segmentations: torch.Tensor, filtering_threshold: float):
        """Attach per-keypoint segmentation predictions, optionally dropping
        likely-background keypoints.

        :param segmentations: [number_points number_labels] logits; label 0
            is treated as background
        :param filtering_threshold: background-score threshold; <= 0 disables
            filtering
        """
        seg_scores = torch.softmax(segmentations, dim=-1)
        if filtering_threshold > 0:
            scores_background = seg_scores[:, 0]
            non_bg_mask = (scores_background < filtering_threshold)
            print('pre filtering before: ', self.keypoints.shape)
            # Only filter when enough (>= 40%) keypoints survive, so a bad
            # segmentation cannot wipe out the feature set.
            if torch.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]:
                self.keypoints = self.keypoints[non_bg_mask.cpu().numpy()]
                self.descriptors = self.descriptors[non_bg_mask.cpu().numpy()]
                # update localization variables to the reduced keypoint count
                self.initialize_localization_variables()
                segmentations = segmentations[non_bg_mask]
                seg_scores = seg_scores[non_bg_mask]
            print('pre filtering after: ', self.keypoints.shape)

        # extract initial segmentation info; label 0 is background, so the
        # stored seg id is argmax - 1 and should start from 0
        self.segmentations = segmentations.cpu().numpy()
        self.seg_scores = seg_scores.cpu().numpy()
        self.seg_ids = segmentations.max(dim=-1)[1].cpu().numpy() - 1

    def filter_keypoints(self, seg_scores: np.ndarray, filtering_threshold: float):
        """Numpy counterpart of the background filtering in add_segmentations.

        :param seg_scores: [N, C] scores, column 0 = background
        :param filtering_threshold: background-score threshold
        :return: boolean keep-mask if filtering was applied, else None
        """
        scores_background = seg_scores[:, 0]
        non_bg_mask = (scores_background < filtering_threshold)
        print('pre filtering before: ', self.keypoints.shape)
        # Same 40%-survivor guard as add_segmentations.
        if np.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]:
            self.keypoints = self.keypoints[non_bg_mask]
            self.descriptors = self.descriptors[non_bg_mask]
            print('pre filtering after: ', self.keypoints.shape)
            # update localization variables to the reduced keypoint count
            self.initialize_localization_variables()
            return non_bg_mask
        else:
            print('pre filtering after: ', self.keypoints.shape)
            return None

    def compute_pose_error(self, pred_qvec=None, pred_tvec=None):
        """Compute pose error against the ground-truth pose.

        Uses (pred_qvec, pred_tvec) when both are given, otherwise falls back
        to the frame's own (qvec, tvec).

        :return: (rotation error, translation error); (100, 100) is a
            sentinel "large error" when any required pose is missing
        """
        if pred_qvec is not None and pred_tvec is not None:
            if self.gt_qvec is not None and self.gt_tvec is not None:
                return compute_pose_error(pred_qcw=pred_qvec, pred_tcw=pred_tvec,
                                          gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)
            else:
                return 100, 100
        if self.qvec is None or self.tvec is None or self.gt_qvec is None or self.gt_tvec is None:
            return 100, 100
        else:
            err_q, err_t = compute_pose_error(pred_qcw=self.qvec, pred_tcw=self.tvec,
                                              gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)
            return err_q, err_t

    def get_intrinsics(self) -> np.ndarray:
        """Build the 3x3 intrinsic matrix K from the pycolmap camera params.

        :raises Exception: if the camera model is not supported
        """
        camera_model = self.camera.model.name
        params = self.camera.params
        if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
            # single-focal models: params = [f, cx, cy, ...]
            fx = fy = params[0]
            cx = params[1]
            cy = params[2]
        elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
            # two-focal models: params = [fx, fy, cx, cy, ...]
            fx = params[0]
            fy = params[1]
            cx = params[2]
            cy = params[3]
        else:
            raise Exception("Camera model not supported")

        # assemble intrinsics
        K = np.identity(3)
        K[0, 0] = fx
        K[1, 1] = fy
        K[0, 2] = cx
        K[1, 2] = cy
        return K

    def get_dominate_seg_id(self):
        """Return the most frequent segment id among the frame's keypoints.

        NOTE(review): ``> 0`` excludes id 0 even though seg ids start from 0
        (see add_segmentations); confirm whether ``>= 0`` was intended.
        Behavior kept unchanged here.
        """
        counts = np.bincount(self.seg_ids[self.seg_ids > 0])
        return np.argmax(counts)

    def clear_localization_track(self):
        """Drop all match/localization results from the last pass."""
        self.matched_scene_name = None
        self.matched_keypoints = None
        self.matched_xyzs = None
        self.matched_point3D_ids = None
        self.matched_inliers = None
        self.matched_sids = None
        self.refinement_reference_frame_ids = None

    def initialize_localization_variables(self):
        """(Re)allocate per-keypoint buffers sized to the current keypoints:
        seg_ids and point3D_ids filled with -1 (unassigned), xyzs with zeros."""
        nkpt = self.keypoints.shape[0]
        self.seg_ids = np.zeros(shape=(nkpt,), dtype=int) - 1
        self.point3D_ids = np.zeros(shape=(nkpt,), dtype=int) - 1
        self.xyzs = np.zeros(shape=(nkpt, 3), dtype=float)