# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> loc_by_rec
@IDE               PyCharm
@Author            fx221@cam.ac.uk
@Date              08/02/2024 15:26
=================================================='''
import threading
import time
import os.path as osp

import cv2
import numpy as np
import pycolmap
import torch
import yaml

from localization.frame import Frame
from localization.multimap3d import MultiMap3D
from localization.tracker import Tracker
from localization.utils import read_query_info
from localization.viewer import Viewer
from recognition.vis_seg import vis_seg_point, generate_color_dic
from tools.common import resize_img, puttext_with_background


def loc_by_rec_online(rec_model, config, local_feat, img_transforms=None):
    """Online localization by recognition: for each query image, extract local
    features, predict per-keypoint segmentations with the recognition model,
    then track or relocalize the frame against the multi-scene 3D map."""
    seg_color = generate_color_dic(n_seg=2000)
    dataset_path = config['dataset_path']
    show = config['localization']['show']
    if show:
        cv2.namedWindow('img', cv2.WINDOW_NORMAL)

    locMap = MultiMap3D(config=config, save_dir=None)
    if config['dataset'][0] in ['Aachen']:
        viewer_config = {'scene': 'outdoor', 'image_size_indoor': 4, 'image_line_width_indoor': 8}
    elif config['dataset'][0] in ['C']:
        viewer_config = {'scene': 'outdoor'}
    elif config['dataset'][0] in ['12Scenes', '7Scenes']:
        viewer_config = {'scene': 'indoor'}
    else:
        viewer_config = {'scene': 'outdoor', 'image_size_indoor': 0.4, 'image_line_width_indoor': 2}

    # start the viewer in its own thread
    mViewer = Viewer(locMap=locMap, seg_color=seg_color, config=viewer_config)
    mViewer.refinement = locMap.do_refinement
    # locMap.viewer = mViewer
    viewer_thread = threading.Thread(target=mViewer.run)
    viewer_thread.start()

    # start tracker
    mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config)

    dataset_name = config['dataset'][0]
    all_scene_query_info = {}
    with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f:
        scene_config = yaml.load(f, Loader=yaml.Loader)

    # a single dataset may contain multiple scenes
    err_ths_cnt = [0, 0, 0, 0]
    show_time = -1
    scenes = scene_config['scenes']
    n_total = 0
    for scene in scenes:
        if len(config['localization']['loc_scene_name']) > 0:
            if scene not in config['localization']['loc_scene_name']:
                continue

        query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path'])
        query_info = read_query_info(query_fn=query_path)
        all_scene_query_info[dataset_name + '/' + scene] = query_info
        image_path = osp.join(dataset_path, dataset_name, scene)

        for fn in sorted(query_info.keys()):
            # Alternative query subsets kept for debugging:
            # for fn in sorted(query_info.keys())[880:][::5]:  # darwinRGB-loc-outdoor-aligned
            # for fn in sorted(query_info.keys())[3161:][::5]:  # darwinRGB-loc-indoor-aligned
            # for fn in sorted(query_info.keys())[2840:][::5]:  # darwinRGB-loc-indoor-aligned
            # for fn in sorted(query_info.keys())[2100:][::5]:  # darwinRGB-loc-outdoor
            # for fn in sorted(query_info.keys())[4360:][::5]:  # darwinRGB-loc-indoor
            # for fn in sorted(query_info.keys())[1380:]:  # Cam-Church
            # for fn in sorted(query_info.keys())[::5]:  # ACUED-test2
            # for fn in sorted(query_info.keys())[1260:]:  # jesus aligned
            # for fn in sorted(query_info.keys())[4850:]:
            img = cv2.imread(osp.join(image_path, fn))  # BGR
            camera_model, width, height, params = all_scene_query_info[dataset_name + '/' + scene][fn]
            # camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params)
            camera = pycolmap.Camera(model=camera_model, width=int(width), height=int(height), params=params)
            curr_frame = Frame(image=img, camera=camera, id=0, name=fn,
                               scene_name=dataset_name + '/' + scene)
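            # read_query_info is assumed to provide COLMAP-style intrinsics per
            # query image as (camera_model, width, height, params), e.g.
            # ('SIMPLE_RADIAL', 1024, 768, [f, cx, cy, k]); the model names and
            # params layout follow the COLMAP camera-model conventions that
            # pycolmap.Camera expects.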
            gt_sub_map = locMap.sub_maps[curr_frame.scene_name]
            if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys():
                curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec']
                curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec']

            with torch.no_grad():
                if config['image_dim'] == 1:
                    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                    img_cuda = torch.from_numpy(img_gray / 255)[None].cuda().float()
                else:
                    img_cuda = torch.from_numpy(img / 255).permute(2, 0, 1).cuda().float()
                if img_transforms is not None:
                    img_cuda = img_transforms(img_cuda)[None]
                else:
                    img_cuda = img_cuda[None]

                t_start = time.time()
                encoder_out = local_feat.extract_local_global(data={'image': img_cuda},
                                                              config={'min_keypoints': 128,
                                                                      'max_keypoints': config['eval_max_keypoints'],
                                                                      })
                t_feat = time.time() - t_start
                # global_descriptors_cuda = encoder_out['global_descriptors']
                scores_cuda = encoder_out['scores'][0][None]
                kpts_cuda = encoder_out['keypoints'][0][None]
                descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1)

                curr_frame.add_keypoints(keypoints=np.hstack([kpts_cuda[0].cpu().numpy(),
                                                              scores_cuda[0].cpu().numpy().reshape(-1, 1)]),
                                         descriptors=descriptors_cuda[0].cpu().numpy())
                curr_frame.time_feat = t_feat

                t_start = time.time()
                _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'],
                                                       semi_descs=encoder_out['mid_features'],
                                                       kpts=kpts_cuda[0],
                                                       norm_desc=config['norm_desc'])
                rec_out = rec_model({'scores': scores_cuda,
                                     'seg_descriptors': seg_descriptors[None].permute(0, 2, 1),
                                     'keypoints': kpts_cuda,
                                     'image': img_cuda})
                t_rec = time.time() - t_start
                curr_frame.time_rec = t_rec

            pred = {
                'scores': scores_cuda,
                'keypoints': kpts_cuda,
                'descriptors': descriptors_cuda,
                # 'global_descriptors': global_descriptors_cuda,
                'image_size': np.array([img.shape[1], img.shape[0]])[None],
            }
            pred = {**pred, **rec_out}
            pred_seg = torch.max(pred['prediction'], dim=2)[1]  # prediction: [B, N, C]
            pred_seg = pred_seg[0].cpu().numpy()
            kpts = kpts_cuda[0].cpu().numpy()
            segmentations = pred['prediction'][0]  # .cpu().numpy()  # [N, C]
            curr_frame.add_segmentations(segmentations=segmentations,
                                         filtering_threshold=config['localization']['pre_filtering_th'])

            # visualize the predicted segmentation on the query image
            img_pred_seg = vis_seg_point(img=img, kpts=curr_frame.keypoints,
                                         segs=curr_frame.seg_ids + 1,
                                         seg_color=seg_color, radius=9)
            show_text = 'kpts: {:d}'.format(kpts.shape[0])
            img_pred_seg = cv2.putText(img=img_pred_seg, text=show_text,
                                       org=(50, 30),
                                       fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                       fontScale=1, color=(0, 0, 255),
                                       thickness=2, lineType=cv2.LINE_AA)
            curr_frame.image_rec = img_pred_seg

            if show:
                img_text = puttext_with_background(image=img,
                                                   text='Press C - continue | S - pause | Q - exit',
                                                   org=(30, 50),
                                                   bg_color=(255, 255, 255),
                                                   text_color=(0, 0, 255),
                                                   fontScale=1, thickness=2)
                cv2.imshow('img', img_text)
                key = cv2.waitKey(show_time)
                if key == ord('q'):
                    exit(0)
                elif key == ord('s'):
                    show_time = -1
                elif key == ord('c'):
                    show_time = 1

            # Step 1: try the tracker first; fall back to full relocalization if it fails
            success = not mTracker.lost and mViewer.tracking
            if success:
                success = mTracker.run(frame=curr_frame)
            if success:
                mViewer.update(curr_frame=curr_frame)
            if not success:
                # success = locMap.run(q_frame=curr_frame, q_segs=segmentations)
                success = locMap.run(q_frame=curr_frame)
                if success:
                    mViewer.update(curr_frame=curr_frame)
            if success:
                curr_frame.update_point3ds()
                if mViewer.tracking:
                    mTracker.lost = False
                    mTracker.last_frame = curr_frame
                    time.sleep(50 / 1000)  # brief pause so the viewer can render

            locMap.do_refinement = mViewer.refinement

            n_total = n_total + 1
            q_err, t_err = curr_frame.compute_pose_error()
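            # compute_pose_error is assumed to return (rotation error in degrees,
            # translation error in the map's metric units) against the frame's
            # gt_qvec/gt_tvec; a hedged sketch of the usual COLMAP-style
            # convention is given in _pose_error_sketch after this function.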
            # cumulative counts of queries within several (rotation deg, translation) error thresholds
            if q_err <= 5 and t_err <= 0.05:
                err_ths_cnt[0] += 1
            if q_err <= 2 and t_err <= 0.25:
                err_ths_cnt[1] += 1
            if q_err <= 5 and t_err <= 0.5:
                err_ths_cnt[2] += 1
            if q_err <= 10 and t_err <= 5:
                err_ths_cnt[3] += 1

            time_total = curr_frame.time_feat + curr_frame.time_rec + curr_frame.time_loc + curr_frame.time_ref
            print_text = ('qname: {:s} localization {:b}, q_err: {:.2f}, t_err: {:.2f}, '
                          '{:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
                              scene + '/' + fn, success, q_err, t_err,
                              err_ths_cnt[0], err_ths_cnt[1], err_ths_cnt[2], err_ths_cnt[3], n_total,
                              curr_frame.time_feat, curr_frame.time_rec, curr_frame.time_loc,
                              curr_frame.time_ref, time_total))
            print(print_text)

    mViewer.terminate()
    viewer_thread.join()
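
# A minimal, self-contained sketch of the pose-error convention assumed above.
# This is a hedged illustration, not the project's implementation: the
# authoritative version is Frame.compute_pose_error in localization.frame.
# Note that some evaluations compare camera centers (-R^T t) rather than tvec
# directly; the variant below uses tvec.
def _pose_error_sketch(qvec, tvec, gt_qvec, gt_tvec):
    """Return (rotation error in degrees, translation error) between an
    estimated COLMAP-style pose (qvec, tvec) and its ground truth."""
    q = np.asarray(qvec, dtype=float)
    q = q / np.linalg.norm(q)
    q_gt = np.asarray(gt_qvec, dtype=float)
    q_gt = q_gt / np.linalg.norm(q_gt)
    # angle between unit quaternions; abs() folds the q ~ -q ambiguity
    cos_half_angle = np.clip(np.abs(np.dot(q, q_gt)), 0.0, 1.0)
    q_err = float(np.degrees(2.0 * np.arccos(cos_half_angle)))
    t_err = float(np.linalg.norm(np.asarray(tvec, dtype=float) - np.asarray(gt_tvec, dtype=float)))
    return q_err, t_err
    # e.g. _pose_error_sketch([1, 0, 0, 0], [0, 0, 0], [1, 0, 0, 0], [0, 0, 0]) == (0.0, 0.0)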