File size: 1,919 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import yaml
import time
from collections import OrderedDict,namedtuple
import os
import sys
# Absolute path of the directory one level above this script. It is pushed to
# the front of sys.path so the project-local `sgmnet` / `superglue` imports
# that follow resolve no matter which working directory the script runs from.
ROOT_DIR = os.path.normpath(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), os.pardir)
)
sys.path.insert(0, ROOT_DIR)

from sgmnet import matcher as SGM_Model
from superglue import matcher as SG_Model


import argparse
# Command-line options for the cost benchmark. The help strings previously
# described the wrong options (e.g. "number of processes.") and misstated
# the --num_kpt default.
parser = argparse.ArgumentParser(
    description='Measure average run time and peak GPU memory of a matcher '
                'on random keypoints/descriptors.')
# Restrict to the matchers the script can actually build; any other value
# would otherwise leave `model` undefined and crash later with a NameError.
parser.add_argument('--matcher_name', type=str, default='SGM', choices=['SGM', 'SG'],
                    help='matcher to benchmark: SGM (sgmnet matcher) or SG '
                         '(superglue matcher), default: SGM')
parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml',
                    help='path to the YAML model config file, '
                         'default: configs/cost/sgm_cost.yaml')
parser.add_argument('--num_kpt', type=int, default=4000,
                    help='number of random keypoints per image, default: 4000')
parser.add_argument('--iter_num', type=int, default=100,
                    help='number of timed forward passes, default: 100')


def test_cost(test_data, model, iter_num=None):
    """Benchmark ``model(test_data)`` and print its time and memory cost.

    One untimed warm-up call is made first, then ``iter_num`` timed calls,
    all under ``torch.no_grad()``. Prints the mean time per call in
    milliseconds and, when CUDA is available, the peak CUDA memory allocated
    in MB (bytes / 1e6).

    Args:
        test_data: Object passed unchanged to ``model`` on every call.
        model: Callable to benchmark (in this script, a matcher network that
            has been moved to the GPU and put in eval mode).
        iter_num: Number of timed calls. When ``None`` (the default) the
            module-level ``args.iter_num`` parsed from the command line is
            used; that global only exists when this file runs as a script.

    Raises:
        ValueError: if ``iter_num`` is less than 1 (the mean would be
            undefined).
    """
    if iter_num is None:
        # Backward-compatible fallback to the original global-based usage.
        iter_num = args.iter_num
    iter_num = int(iter_num)
    if iter_num < 1:
        raise ValueError(f'iter_num must be >= 1, got {iter_num}')

    use_cuda = torch.cuda.is_available()

    def _sync():
        # CUDA kernels are launched asynchronously; wait for them so the timer
        # covers finished work, not just kernel launches.
        if use_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        # Warm-up call so one-time setup costs are not part of the timing.
        model(test_data)
        _sync()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(iter_num):
            model(test_data)
        _sync()
        elapsed = time.perf_counter() - start
    print('Average time per run(ms): ', elapsed / iter_num * 1e3)
    if use_cuda:
        # Peak since process start (no reset here), so it also counts the model
        # weights and inputs, not only the activations of a single pass.
        print('Peak memory(MB): ', torch.cuda.max_memory_allocated() / 1e6)


# Despite the `test_` prefix this is a benchmark helper, not a pytest test;
# keep pytest from collecting it (it would fail looking for fixtures).
test_cost.__test__ = False


if __name__ == '__main__':
    # Disable the cuDNN autotuner for this benchmark run.
    torch.backends.cudnn.benchmark = False
    # Module-level on purpose: test_cost() falls back to the global `args`.
    args = parser.parse_args()

    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
    # raises TypeError since 6.0; safe_load parses plain YAML config mappings.
    with open(args.config_path, 'r') as f:
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        parser.error(f'config file {args.config_path} must contain a YAML mapping')
    # Expose the config entries as attributes (model_config.<key>), which the
    # matcher constructors presumably expect — confirm against sgmnet/superglue.
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())

    if args.matcher_name == 'SGM':
        model = SGM_Model(model_config)
    elif args.matcher_name == 'SG':
        model = SG_Model(model_config)
    else:
        # Without this branch `model` would be unbound and the script would
        # die with a NameError on the next line.
        parser.error(f"unknown --matcher_name {args.matcher_name!r} (expected 'SGM' or 'SG')")
    model.cuda()
    model.eval()

    # Random inputs of the requested size: coordinates shifted to [-0.5, 0.5)
    # and 128-dim descriptors in [0, 1), all on the GPU.
    test_data = {
        'x1': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'x2': torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        'desc1': torch.rand(1, args.num_kpt, 128).cuda(),
        'desc2': torch.rand(1, args.num_kpt, 128).cuda(),
    }

    test_cost(test_data, model)