"""Run a reader -> matcher -> evaluator pipeline over an evaluation dataset.

Three processes are connected by bounded queues: a reader produces one item
per image pair, a matcher process computes correspondences with a worker
pool, and an evaluator process scores the matches (optionally dumping match
visualizations to ``--vis_folder``).
"""
import os
from torch.multiprocessing import Process, Manager, set_start_method, Pool
import functools
import argparse
import yaml
import numpy as np
import sys
import cv2
from tqdm import trange

# Every process (including pool workers) is started with 'spawn'; force=True
# lets this module-level call succeed again when spawned children re-import
# this module.
set_start_method('spawn', force=True)

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Must happen before the project-local imports below.
sys.path.insert(0, ROOT_DIR)

from components import load_component  # noqa: E402
from utils import evaluation_utils, metrics  # noqa: E402

# Parsed at import time so the handlers, which run in spawned child processes
# that re-import this module, can read the module-level `args`.
parser = argparse.ArgumentParser(description='dump eval data.')
parser.add_argument('--config_path', type=str, default='configs/eval/scannet_eval_sgm.yaml')
parser.add_argument('--num_process_match', type=int, default=4)
parser.add_argument('--num_process_eval', type=int, default=4)
parser.add_argument('--vis_folder', type=str, default=None)
args = parser.parse_args()

# End-of-stream sentinel passed through read_que and match_que.
_END = 'over'


def feed_match(info, matcher):
    """Run `matcher` on one image pair and return its correspondences.

    Args:
        info: item produced by the reader; must provide 'x1', 'x2', 'desc1',
            'desc2', 'img1' and 'img2'.
        matcher: matcher component exposing ``run(test_data)``.

    Returns:
        ``[corr1, corr2]`` as returned by ``matcher.run``.
    """
    # shape[:2] is flipped before being passed as the size; assuming images
    # are (H, W[, C]) arrays this yields (W, H) -- TODO confirm with matcher.
    size1 = np.flip(np.asarray(info['img1'].shape[:2]))
    size2 = np.flip(np.asarray(info['img2'].shape[:2]))
    test_data = {
        'x1': info['x1'],
        'x2': info['x2'],
        'desc1': info['desc1'],
        'desc2': info['desc2'],
        'size1': size1,
        'size2': size2,
    }
    corr1, corr2 = matcher.run(test_data)
    return [corr1, corr2]


def reader_handler(config, read_que):
    """Push every item of the configured reader onto `read_que`, then the sentinel."""
    reader = load_component('reader', config['name'], config)
    for index in range(len(reader)):
        read_que.put(reader.run(index))
    read_que.put(_END)


def _match_batch(pool, match_func, batch, match_que):
    """Match `batch` in parallel, attach 'corr1'/'corr2' to each item and forward it."""
    results = pool.map(match_func, batch)
    for item, (corr1, corr2) in zip(batch, results):
        item['corr1'], item['corr2'] = corr1, corr2
        match_que.put(item)


def match_handler(config, read_que, match_que):
    """Consume reader items, compute correspondences in batches, forward them.

    Items are grouped ``args.num_process_match`` at a time and matched with a
    pool of the same size. When the end-of-stream sentinel arrives, any
    partial batch is flushed and the sentinel is forwarded to `match_que`.
    """
    matcher = load_component('matcher', config['name'], config)
    match_func = functools.partial(feed_match, matcher=matcher)
    pool = Pool(args.num_process_match)
    try:
        batch = []
        while True:
            item = read_que.get()
            if item == _END:
                if batch:
                    _match_batch(pool, match_func, batch, match_que)
                match_que.put(_END)
                break
            batch.append(item)
            if len(batch) == args.num_process_match:
                _match_batch(pool, match_func, batch, match_que)
                batch = []
    finally:
        pool.close()
        pool.join()


def _evaluate_batch(pool, evaluator, batch):
    """Score `batch` in parallel and record every result in `evaluator`."""
    for res in pool.map(evaluator.run, batch):
        evaluator.res_inqueue(res)


def _dump_visualization(item, inlier_th):
    """Draw `item`'s matches colored by epipolar inlier mask and save as a PNG in args.vis_folder."""
    corr1_norm = evaluation_utils.normalize_intrinsic(item['corr1'], item['K1'])
    corr2_norm = evaluation_utils.normalize_intrinsic(item['corr2'], item['K2'])
    inlier_mask = metrics.compute_epi_inlier(corr1_norm, corr2_norm, item['e'], inlier_th)
    display = evaluation_utils.draw_match(
        item['img1'], item['img2'], item['corr1'], item['corr2'], inlier_mask)
    cv2.imwrite(os.path.join(args.vis_folder, str(item['index']) + '.png'), display)


def evaluate_handler(config, match_que):
    """Score up to ``config['num_pair']`` matched items, then call ``evaluator.parse()``.

    Items are evaluated ``args.num_process_eval`` at a time with a worker pool;
    the final partial batch is always flushed. If ``args.vis_folder`` is set,
    a match visualization is written for every item.
    """
    evaluator = load_component('evaluator', config['name'], config)
    pool = Pool(args.num_process_eval)
    try:
        batch = []
        for _ in trange(config['num_pair']):
            item = match_que.get()
            if item == _END:
                break
            batch.append(item)
            if len(batch) == args.num_process_eval:
                _evaluate_batch(pool, evaluator, batch)
                batch = []
            if args.vis_folder is not None:
                _dump_visualization(item, config['inlier_th'])
        else:
            # Stopped after num_pair items without seeing the sentinel: drain
            # the rest so the matcher never blocks forever on a full match_que.
            while match_que.get() != _END:
                pass
        # Flush the last partial batch; otherwise items left over when
        # num_pair is not a multiple of num_process_eval are never scored.
        if batch:
            _evaluate_batch(pool, evaluator, batch)
        evaluator.parse()
    finally:
        pool.close()
        pool.join()


if __name__ == '__main__':
    with open(args.config_path, 'r') as f:
        # yaml.load without an explicit Loader raises TypeError on PyYAML >= 6;
        # the config is plain data, so the safe loader suffices.
        config = yaml.safe_load(f)
    if args.vis_folder is not None:
        os.makedirs(args.vis_folder, exist_ok=True)
    # A single manager server hosts both bounded queues.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)
        processes = [
            Process(target=reader_handler, args=(config['reader'], read_que)),
            Process(target=match_handler, args=(config['matcher'], read_que, match_que)),
            Process(target=evaluate_handler, args=(config['evaluator'], match_que)),
        ]
        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()