|
import torch |
|
import yaml |
|
import time |
|
from collections import OrderedDict,namedtuple |
|
import os |
|
import sys |
|
# Prepend the repository root (the parent of this script's directory) to the
# module search path so the `sgmnet` / `superglue` imports that follow resolve
# regardless of the working directory the script is launched from.
ROOT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
sys.path.insert(0, ROOT_DIR)
|
|
|
from sgmnet import matcher as SGM_Model |
|
from superglue import matcher as SG_Model |
|
|
|
|
|
import argparse |
|
# Command-line interface of the benchmark. Help texts use %(default)s so the
# printed defaults can never drift from the actual ``default=`` values again.
parser = argparse.ArgumentParser(
    description='Measure the average run time and peak GPU memory of a '
                'matcher on randomly generated keypoints.')

# Restricting to the two supported matchers turns a typo into a clear CLI
# error instead of a NameError further down the script.
parser.add_argument('--matcher_name', type=str, default='SGM',
                    choices=['SGM', 'SG'],
                    help='matcher to benchmark: SGM (sgmnet) or SG (superglue), '
                         'default: %(default)s')

parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml',
                    help='path to the YAML model config, default: %(default)s')

parser.add_argument('--num_kpt', type=int, default=4000,
                    help='keypoint number per image, default: %(default)s')

parser.add_argument('--iter_num', type=int, default=100,
                    help='number of timed forward passes, default: %(default)s')
|
|
|
|
|
def test_cost(test_data, model, iter_num=None):
    """Benchmark ``model`` on ``test_data`` and print timing/memory stats.

    One untimed warm-up call is made, followed by ``iter_num`` timed calls,
    all under ``torch.no_grad()``. When CUDA is available, the device is
    synchronized before each clock read so that asynchronously queued
    kernels are included in the measurement.

    Args:
        test_data: Input passed unchanged to ``model`` on every call.
        model: Callable to benchmark.
        iter_num: Number of timed calls. When omitted, the value is read from
            the module-level ``args`` (the parsed ``--iter_num`` option),
            which matches the original behavior.

    Returns:
        float: Average wall-clock time per timed call, in milliseconds.

    Raises:
        ValueError: If ``iter_num`` is not a positive integer.
    """
    if iter_num is None:
        # Backward-compatible default: use the parsed CLI arguments.
        iter_num = args.iter_num
    iter_num = int(iter_num)
    if iter_num <= 0:
        # Guards the division below against ZeroDivisionError.
        raise ValueError('iter_num must be a positive integer, got %d' % iter_num)

    use_cuda = torch.cuda.is_available()

    def _sync():
        # Wait for pending CUDA work so the clock measures completed kernels;
        # no-op on CPU-only machines, where synchronize() would raise.
        if use_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        # Warm-up call, presumably to keep one-time start-up costs out of
        # the timed loop.
        model(test_data)
        _sync()
        start = time.perf_counter()  # monotonic, high-resolution clock
        for _ in range(iter_num):
            model(test_data)
        _sync()
        elapsed = time.perf_counter() - start

    avg_ms = elapsed / iter_num * 1e3
    print('Average time per run(ms): ', avg_ms)
    if use_cuda:
        print('Peak memory(MB): ', torch.cuda.max_memory_allocated() / 1e6)
    return avg_ms


# The name matches pytest's default ``test_*`` collection pattern, but this is
# a benchmark helper with required arguments, not a test: opt out of collection.
test_cost.__test__ = False
|
|
|
|
|
if __name__ == '__main__':
    # cuDNN autotuning is explicitly disabled, presumably so timings are not
    # affected by algorithm search — confirm intent.
    torch.backends.cudnn.benchmark = False

    # Must stay at module scope: test_cost() reads the global ``args``.
    args = parser.parse_args()

    with open(args.config_path, 'r', encoding='utf-8') as f:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError since PyYAML 6.0; safe_load is the safe,
        # version-independent way to read a plain config file.
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        # An empty file yields None; report it instead of an AttributeError.
        parser.error('config file %s must contain a YAML mapping' % args.config_path)
    # Expose config entries as attributes (model_config.<key>); namedtuple
    # requires every key to be a valid Python identifier.
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())

    if args.matcher_name == 'SGM':
        model = SGM_Model(model_config)
    elif args.matcher_name == 'SG':
        model = SG_Model(model_config)
    else:
        # Without this branch an unknown name surfaced later as a NameError.
        parser.error("unknown --matcher_name %r (expected 'SGM' or 'SG')"
                     % args.matcher_name)

    model.cuda()
    model.eval()

    # Random inputs for one image pair: coordinates uniform in [-0.5, 0.5)
    # and 128-dimensional descriptors uniform in [0, 1).
    test_data = {
        'x1': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'x2': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'desc1': torch.rand(1, args.num_kpt, 128).cuda(),
        'desc2': torch.rand(1, args.num_kpt, 128).cuda(),
    }

    test_cost(test_data, model)
|
|