# File size: 1,919 Bytes
# 404d2af
import torch
import yaml
import time
from collections import OrderedDict,namedtuple
import os
import sys
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
from sgmnet import matcher as SGM_Model
from superglue import matcher as SG_Model
import argparse
# Command-line options for the matcher cost benchmark.
parser = argparse.ArgumentParser(
    description='Measure average forward time and peak CUDA memory of a matcher.')
# Restricting the choices makes argparse reject an unknown matcher up front,
# instead of failing later when no model gets constructed.
parser.add_argument('--matcher_name', type=str, default='SGM', choices=['SGM', 'SG'],
                    help="matcher to benchmark: 'SGM' (sgmnet) or 'SG' (superglue), default: SGM")
parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml',
                    help='path to the YAML matcher config, default: configs/cost/sgm_cost.yaml')
parser.add_argument('--num_kpt', type=int, default=4000,
                    help='number of random keypoints per image, default: 4000')
parser.add_argument('--iter_num', type=int, default=100,
                    help='number of timed forward passes, default: 100')
def test_cost(test_data, model, iter_num=None):
    """Time ``model(test_data)`` and print the average run time and peak CUDA memory.

    One untimed warm-up call is made first, then ``iter_num`` timed calls.
    All calls run under ``torch.no_grad()``.

    Args:
        test_data: object passed unchanged to ``model`` on every call.
        model: callable to benchmark, invoked as ``model(test_data)``.
        iter_num: number of timed calls. When ``None`` (the default), the
            module-level ``args.iter_num`` parsed in ``__main__`` is used,
            which matches the original script behaviour.

    Returns:
        float: average wall-clock time per timed call, in milliseconds.

    Raises:
        ValueError: if ``iter_num`` is smaller than 1.
    """
    if iter_num is None:
        # Backward-compatible fallback to the CLI value parsed in __main__.
        iter_num = args.iter_num
    iter_num = int(iter_num)
    if iter_num < 1:
        raise ValueError(f'iter_num must be >= 1, got {iter_num}')

    use_cuda = torch.cuda.is_available()

    def _sync():
        # CUDA kernels are launched asynchronously; wait for them so the
        # wall-clock interval covers the GPU work. No-op without CUDA.
        if use_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        # Warm-up call, excluded from timing so one-time setup costs (if any)
        # do not skew the average.
        model(test_data)
        _sync()
        start = time.perf_counter()  # monotonic, high-resolution clock
        for _ in range(iter_num):
            model(test_data)
        _sync()
        elapsed = time.perf_counter() - start

    avg_ms = elapsed / iter_num * 1e3
    print('Average time per run(ms): ', avg_ms)
    if use_cuda:
        print('Peak memory(MB): ', torch.cuda.max_memory_allocated() / 1e6)
    return avg_ms


# Not a pytest test despite the ``test_`` prefix; keep pytest from collecting it.
test_cost.__test__ = False
if __name__ == '__main__':
    torch.backends.cudnn.benchmark = False
    args = parser.parse_args()

    # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6
    # and is unsafe on older versions; safe_load covers standard YAML tags
    # (assumes the config uses no Python-specific tags — confirm).
    with open(args.config_path, 'r', encoding='utf-8') as f:
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        parser.error(f'config file {args.config_path!r} must contain a YAML mapping')
    # Expose the config entries as attributes (model_config.<key>) for the matcher.
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())

    if args.matcher_name == 'SGM':
        model = SGM_Model(model_config)
    elif args.matcher_name == 'SG':
        model = SG_Model(model_config)
    else:
        parser.error(f"unknown --matcher_name {args.matcher_name!r}; expected 'SGM' or 'SG'")
    model.cuda()
    model.eval()

    # Random inputs: 'x1'/'x2' have shape (1, num_kpt, 2) with values in
    # [-0.5, 0.5) (presumably normalized keypoint coordinates); 'desc1'/'desc2'
    # have shape (1, num_kpt, 128) with values in [0, 1).
    test_data = {
        'x1': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'x2': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'desc1': torch.rand(1, args.num_kpt, 128).cuda(),
        'desc2': torch.rand(1, args.num_kpt, 128).cuda(),
    }
    test_cost(test_data, model)