"""Benchmark the inference cost of a matcher on random input.

Builds the matcher selected by ``--matcher_name`` from a YAML config, feeds it
randomly generated keypoints and descriptors on the GPU, and prints the
average time per forward pass and the peak CUDA memory allocated.
"""
import torch
import yaml
import time
from collections import OrderedDict, namedtuple
import os
import sys

# Make the repository root importable before pulling in the project packages.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from sgmnet import matcher as SGM_Model
from superglue import matcher as SG_Model

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--matcher_name', type=str, default='SGM',
                    choices=['SGM', 'SG'],
                    help='matcher to benchmark: SGM or SG, default:SGM')
parser.add_argument('--config_path', type=str,
                    default='configs/cost/sgm_cost.yaml',
                    help='path to the matcher YAML config file.')
parser.add_argument('--num_kpt', type=int, default=4000,
                    help='keypoint number, default:4000')
parser.add_argument('--iter_num', type=int, default=100,
                    help='number of timed forward passes, default:100')


def test_cost(test_data, model, iter_num=None):
    """Time repeated forward passes of ``model`` and print the cost.

    Args:
        test_data: Input passed unchanged to ``model`` on every call.
        model: Callable model; it is invoked as ``model(test_data)``.
        iter_num: Number of timed forward passes. When ``None``, falls back
            to the module-level ``args.iter_num`` parsed from the command
            line (the original behaviour).

    Raises:
        ValueError: If the number of iterations is not positive.
    """
    if iter_num is None:
        iter_num = args.iter_num
    iter_num = int(iter_num)
    if iter_num <= 0:
        raise ValueError('iter_num must be a positive integer, got %d' % iter_num)

    with torch.no_grad():
        # Warm-up call so one-time initialisation is excluded from the timing.
        _ = model(test_data)

        # CUDA kernels run asynchronously; synchronise so the timer brackets
        # only the work launched inside the loop.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iter_num):
            _ = model(test_data)
        torch.cuda.synchronize()
        end = time.perf_counter()

    print('Average time per run(ms): ', (end - start) / iter_num * 1e3)
    print('Peak memory(MB): ', torch.cuda.max_memory_allocated() / 1e6)


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = False
    args = parser.parse_args()

    with open(args.config_path, 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        raise ValueError('config file %s must contain a YAML mapping'
                         % args.config_path)
    # Expose the config entries as attributes (config.key) instead of dict keys.
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())

    if args.matcher_name == 'SGM':
        model = SGM_Model(model_config)
    elif args.matcher_name == 'SG':
        model = SG_Model(model_config)
    model.cuda()
    model.eval()

    # Random stand-in inputs: 2-D points shifted into [-0.5, 0.5) and
    # 128-dimensional descriptors, one batch of num_kpt keypoints per image.
    test_data = {
        'x1': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'x2': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'desc1': torch.rand(1, args.num_kpt, 128).cuda(),
        'desc2': torch.rand(1, args.num_kpt, 128).cuda(),
    }

    test_cost(test_data, model, args.iter_num)