# Realcat
# fix: eloftr
# 63f3cf2
# raw
# history blame
# 11.3 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> loc_by_rec
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 08/02/2024 15:26
=================================================='''
import torch
import pycolmap
from localization.multimap3d import MultiMap3D
from localization.frame import Frame
import yaml, cv2, time
import numpy as np
import os.path as osp
import threading
from recognition.vis_seg import vis_seg_point, generate_color_dic
from tools.common import resize_img
from localization.viewer import Viewer
from localization.tracker import Tracker
from localization.utils import read_query_info
from tools.common import puttext_with_background
def loc_by_rec_online(rec_model, config, local_feat, img_transforms=None):
    """Localize the query images of a dataset frame by frame with a live viewer.

    For every query image of every selected scene this function extracts local
    features with ``local_feat``, predicts per-keypoint segmentation scores
    with ``rec_model``, then tries the tracker first and falls back to
    ``locMap.run`` when tracking is unavailable or fails. A ``Viewer`` runs in
    a background thread for the whole session, and running accuracy counters
    and timings are printed after each frame.

    Args:
        rec_model: Callable taking a dict with the keys ``'scores'``,
            ``'seg_descriptors'``, ``'keypoints'`` and ``'image'`` and
            returning a mapping that contains ``'prediction'``.
        config: Nested configuration mapping. Keys read here:
            ``'dataset_path'``, ``'dataset'`` (first entry is used),
            ``'config_path'``, ``'image_dim'``, ``'eval_max_keypoints'``,
            ``'norm_desc'`` and, under ``'localization'``, ``'show'``,
            ``'loc_scene_name'`` and ``'pre_filtering_th'``. The whole mapping
            is also forwarded to ``MultiMap3D`` and ``Tracker``.
        local_feat: Feature extractor exposing ``extract_local_global`` and
            ``sample``.
        img_transforms: Optional callable applied to the un-batched image
            tensor before the batch dimension is added.

    Returns:
        None. Results are printed and pushed to the viewer.

    Raises:
        SystemExit: When the user presses ``q`` in the preview window.

    Note:
        Image tensors are moved to the GPU with ``.cuda()``, so a CUDA device
        is required.
    """
    seg_color = generate_color_dic(n_seg=2000)
    dataset_path = config['dataset_path']
    dataset_name = config['dataset'][0]
    loc_config = config['localization']
    show = loc_config['show']
    if show:
        cv2.namedWindow('img', cv2.WINDOW_NORMAL)

    locMap = MultiMap3D(config=config, save_dir=None)

    # Viewer drawing parameters depend on the dataset family.
    if dataset_name == 'Aachen':
        viewer_config = {'scene': 'outdoor',
                         'image_size_indoor': 4,
                         'image_line_width_indoor': 8, }
    elif dataset_name == 'C':
        viewer_config = {'scene': 'outdoor'}
    elif dataset_name in ('12Scenes', '7Scenes'):
        viewer_config = {'scene': 'indoor', }
    else:
        viewer_config = {'scene': 'outdoor',
                         'image_size_indoor': 0.4,
                         'image_line_width_indoor': 2, }

    # start viewer
    mViewer = Viewer(locMap=locMap, seg_color=seg_color, config=viewer_config)
    mViewer.refinement = locMap.do_refinement
    # locMap.viewer = mViewer
    viewer_thread = threading.Thread(target=mViewer.run)
    viewer_thread.start()

    # From here on the viewer thread is alive. It is a non-daemon thread, so
    # it must be stopped on *every* exit path (normal end, 'q' key, exception),
    # otherwise the interpreter cannot shut down.
    try:
        # start tracker
        mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config)

        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # this is only acceptable because the scene config is a trusted,
        # local file. Consider yaml.SafeLoader if that ever changes.
        with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f:
            scene_config = yaml.load(f, Loader=yaml.Loader)

        all_scene_query_info = {}

        # (rotation error, translation error) limits; one counter per pair.
        err_thresholds = ((5, 0.05), (2, 0.25), (5, 0.5), (10, 5))
        err_ths_cnt = [0, 0, 0, 0]

        # waitKey delay: -1 blocks until a key is pressed (paused),
        # 1 lets the sequence play continuously.
        show_time = -1
        n_total = 0

        for scene in scene_config['scenes']:
            # multiple scenes in a single dataset; optionally restrict them
            selected_scenes = loc_config['loc_scene_name']
            if selected_scenes and scene not in selected_scenes:
                continue

            scene_key = dataset_name + '/' + scene
            query_path = osp.join(dataset_path, dataset_name, scene, scene_config[scene]['query_path'])
            query_info = read_query_info(query_fn=query_path)
            all_scene_query_info[scene_key] = query_info
            image_path = osp.join(dataset_path, dataset_name, scene)

            # Alternative query sub-ranges kept for reference:
            # for fn in sorted(query_info.keys())[880:][::5]:  # darwinRGB-loc-outdoor-aligned
            # for fn in sorted(query_info.keys())[3161:][::5]:  # darwinRGB-loc-indoor-aligned
            # for fn in sorted(query_info.keys())[2840:][::5]:  # darwinRGB-loc-indoor-aligned
            # for fn in sorted(query_info.keys())[2100:][::5]:  # darwinRGB-loc-outdoor
            # for fn in sorted(query_info.keys())[4360:][::5]:  # darwinRGB-loc-indoor
            # for fn in sorted(query_info.keys())[1380:]:  # Cam-Church
            # for fn in sorted(query_info.keys())[::5]:  # ACUED-test2
            # for fn in sorted(query_info.keys())[1260:]:  # jesus aligned
            # for fn in sorted(query_info.keys())[4850:]:
            for fn in sorted(query_info.keys()):
                img = cv2.imread(osp.join(image_path, fn))  # BGR
                if img is None:
                    # imread signals a missing/unreadable file by returning
                    # None; skip the frame instead of crashing further down.
                    print('Skip {:s}: failed to read image'.format(osp.join(image_path, fn)))
                    continue

                camera_model, width, height, params = query_info[fn]
                # camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params)
                camera = pycolmap.Camera(model=camera_model, width=int(width), height=int(height), params=params)
                curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=scene_key)

                # Attach the ground-truth pose when the sub-map provides one.
                gt_poses = locMap.sub_maps[curr_frame.scene_name].gt_poses
                if gt_poses is not None and curr_frame.name in gt_poses:
                    curr_frame.gt_qvec = gt_poses[curr_frame.name]['qvec']
                    curr_frame.gt_tvec = gt_poses[curr_frame.name]['tvec']

                with torch.no_grad():
                    if config['image_dim'] == 1:
                        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                        img_cuda = torch.from_numpy(img_gray / 255)[None].cuda().float()
                    else:
                        img_cuda = torch.from_numpy(img / 255).permute(2, 0, 1).cuda().float()
                    if img_transforms is not None:
                        img_cuda = img_transforms(img_cuda)[None]
                    else:
                        img_cuda = img_cuda[None]

                    # --- local feature extraction ---
                    # NOTE(review): wall-clock timing without a CUDA
                    # synchronize may under-report GPU work -- confirm if
                    # these numbers are used for benchmarking.
                    t_start = time.time()
                    encoder_out = local_feat.extract_local_global(
                        data={'image': img_cuda},
                        config={'min_keypoints': 128,
                                'max_keypoints': config['eval_max_keypoints'],
                                })
                    t_feat = time.time() - t_start
                    # global_descriptors_cuda = encoder_out['global_descriptors']
                    scores_cuda = encoder_out['scores'][0][None]
                    kpts_cuda = encoder_out['keypoints'][0][None]
                    descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1)

                    # Host copy of the keypoints, reused for the overlay text.
                    kpts = kpts_cuda[0].cpu().numpy()
                    curr_frame.add_keypoints(
                        keypoints=np.hstack([kpts,
                                             scores_cuda[0].cpu().numpy().reshape(-1, 1)]),
                        descriptors=descriptors_cuda[0].cpu().numpy())
                    curr_frame.time_feat = t_feat

                    # --- recognition (per-keypoint segmentation) ---
                    t_start = time.time()
                    _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'],
                                                           semi_descs=encoder_out['mid_features'],
                                                           kpts=kpts_cuda[0],
                                                           norm_desc=config['norm_desc'])
                    rec_out = rec_model({'scores': scores_cuda,
                                         'seg_descriptors': seg_descriptors[None].permute(0, 2, 1),
                                         'keypoints': kpts_cuda,
                                         'image': img_cuda})
                    t_rec = time.time() - t_start
                    curr_frame.time_rec = t_rec

                    # First batch element of the prediction, kept as a tensor.
                    segmentations = rec_out['prediction'][0]  # .cpu().numpy()  # [N, C]
                    curr_frame.add_segmentations(segmentations=segmentations,
                                                 filtering_threshold=loc_config['pre_filtering_th'])

                    img_pred_seg = vis_seg_point(img=img, kpts=curr_frame.keypoints,
                                                 segs=curr_frame.seg_ids + 1, seg_color=seg_color, radius=9)
                    show_text = 'kpts: {:d}'.format(kpts.shape[0])
                    img_pred_seg = cv2.putText(img=img_pred_seg,
                                               text=show_text,
                                               org=(50, 30),
                                               fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                               fontScale=1, color=(0, 0, 255),
                                               thickness=2, lineType=cv2.LINE_AA)
                    curr_frame.image_rec = img_pred_seg

                    if show:
                        img_text = puttext_with_background(image=img,
                                                           text='Press C - continue | S - pause | Q - exit',
                                                           org=(30, 50),
                                                           bg_color=(255, 255, 255),
                                                           text_color=(0, 0, 255),
                                                           fontScale=1, thickness=2)
                        cv2.imshow('img', img_text)
                        # Mask to the low byte so the comparison also works on
                        # backends that set higher bits in the key code; a
                        # timeout (-1) becomes 255 and matches no key below.
                        key = cv2.waitKey(show_time) & 0xFF
                        if key == ord('q'):
                            # The enclosing finally block stops the viewer
                            # thread before the process exits.
                            raise SystemExit(0)
                        elif key == ord('s'):
                            show_time = -1
                        elif key == ord('c'):
                            show_time = 1

                    # Step1: do tracker first
                    success = not mTracker.lost and mViewer.tracking
                    if success:
                        success = mTracker.run(frame=curr_frame)
                        if success:
                            mViewer.update(curr_frame=curr_frame)

                    # Step2: fall back to localization against the map
                    if not success:
                        # success = locMap.run(q_frame=curr_frame, q_segs=segmentations)
                        success = locMap.run(q_frame=curr_frame)
                        if success:
                            mViewer.update(curr_frame=curr_frame)

                    if success:
                        curr_frame.update_point3ds()
                        if mViewer.tracking:
                            mTracker.lost = False
                            mTracker.last_frame = curr_frame

                    # Short pause per frame (presumably to give the viewer
                    # thread time to draw -- TODO confirm).
                    time.sleep(50 / 1000)
                    # The viewer's refinement flag was seeded from locMap at
                    # start-up; copy it back so viewer-side changes apply.
                    locMap.do_refinement = mViewer.refinement

                    n_total = n_total + 1
                    q_err, t_err = curr_frame.compute_pose_error()
                    for idx, (q_th, t_th) in enumerate(err_thresholds):
                        if q_err <= q_th and t_err <= t_th:
                            err_ths_cnt[idx] = err_ths_cnt[idx] + 1

                    time_total = curr_frame.time_feat + curr_frame.time_rec + curr_frame.time_loc + curr_frame.time_ref
                    print_text = 'qname: {:s} localization {:b}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
                        scene + '/' + fn, success, q_err, t_err,
                        err_ths_cnt[0],
                        err_ths_cnt[1],
                        err_ths_cnt[2],
                        err_ths_cnt[3],
                        n_total,
                        curr_frame.time_feat, curr_frame.time_rec, curr_frame.time_loc, curr_frame.time_ref,
                        time_total
                    )
                    print(print_text)
    finally:
        mViewer.terminate()
        viewer_thread.join()