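# NOTE: the lines below are residue from a web file viewer (author, commit
# message, commit hash, viewer labels, file size); kept as comments so the
# script stays valid Python.
# Vincentqyw
# update: features and matchers
# a80d6bb
# raw
# history blame
# 1.92 kB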
import torch
import yaml
import time
from collections import OrderedDict,namedtuple
import os
import sys
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
from sgmnet import matcher as SGM_Model
from superglue import matcher as SG_Model
import argparse
# Command-line options for the matcher cost benchmark.
parser = argparse.ArgumentParser(
    description='Measure average runtime and peak GPU memory of a feature matcher.')
# Restricting choices makes an unknown matcher fail at parse time instead of
# leaving `model` unbound further down.
parser.add_argument('--matcher_name', type=str, default='SGM', choices=['SGM', 'SG'],
                    help='matcher to benchmark: SGM (SGMNet) or SG (SuperGlue), default: SGM')
parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml',
                    help='path to the YAML model config, default: configs/cost/sgm_cost.yaml')
parser.add_argument('--num_kpt', type=int, default=4000,
                    help='number of keypoints per image, default: 4000')
parser.add_argument('--iter_num', type=int, default=100,
                    help='number of timed iterations, default: 100')
def test_cost(test_data, model, iter_num=None):
    """Benchmark ``model`` on ``test_data``; print average latency and peak memory.

    The model is called once as a warm-up, which is excluded from the timing.
    It is then called ``iter_num`` more times, all inside ``torch.no_grad()``.
    When CUDA is available, the device is synchronized before each clock read.
    This ensures that asynchronously queued kernels count toward the measured time.

    Args:
        test_data: input passed unchanged to ``model(test_data)`` on every call.
            In this script it is the dict of random tensors built in ``__main__``.
        model: callable under test.
        iter_num: number of timed calls. Defaults to the ``--iter_num``
            command-line value held in the module-level ``args``.

    Returns:
        A tuple ``(avg_ms, peak_mb)``:
        - ``avg_ms``: average wall time per call, in milliseconds.
        - ``peak_mb``: ``torch.cuda.max_memory_allocated()`` in MB (1e6 bytes),
          or ``None`` when CUDA is unavailable.

    Raises:
        ValueError: if ``iter_num`` is not a positive integer.
    """
    if iter_num is None:
        # Fall back to the CLI value parsed in ``__main__`` (original behaviour).
        iter_num = args.iter_num
    iter_num = int(iter_num)
    if iter_num <= 0:
        raise ValueError(f'iter_num must be a positive integer, got {iter_num}')
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        # Warm-up call so one-off start-up costs are not part of the timing.
        model(test_data)
        if use_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(iter_num):
            model(test_data)
        if use_cuda:
            # CUDA work is queued asynchronously; wait for it before stopping the clock.
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    avg_ms = elapsed / iter_num * 1e3
    print('Average time per run(ms): ', avg_ms)
    peak_mb = torch.cuda.max_memory_allocated() / 1e6 if use_cuda else None
    if peak_mb is not None:
        print('Peak memory(MB): ', peak_mb)
    return avg_ms, peak_mb


# Benchmark helper, not a unit test: stop pytest from collecting it because of its name.
test_cost.__test__ = False
if __name__ == '__main__':
    torch.backends.cudnn.benchmark = False
    args = parser.parse_args()
    with open(args.config_path, 'r', encoding='utf-8') as f:
        # safe_load: yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6, and older full loaders can construct arbitrary objects.
        model_config = yaml.safe_load(f)
    # Expose the config keys as attributes (model_config.<key>).
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())
    if args.matcher_name == 'SGM':
        model = SGM_Model(model_config)
    elif args.matcher_name == 'SG':
        model = SG_Model(model_config)
    else:
        # Without this branch `model` would be unbound and fail later with NameError.
        raise ValueError(
            f"unknown matcher_name {args.matcher_name!r}; expected 'SGM' or 'SG'")
    model.cuda()
    model.eval()
    # Random batch of size 1: coordinates uniform in [-0.5, 0.5), shape
    # (1, num_kpt, 2); descriptors uniform in [0, 1), shape (1, num_kpt, 128).
    test_data = {
        'x1': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'x2': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'desc1': torch.rand(1, args.num_kpt, 128).cuda(),
        'desc2': torch.rand(1, args.num_kpt, 128).cuda(),
    }
    test_cost(test_data, model)