|
import torch |
|
import numpy as np |
|
import os |
|
from collections import OrderedDict,namedtuple |
|
import sys |
|
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
|
sys.path.insert(0, ROOT_DIR) |
|
|
|
from sgmnet import matcher as SGM_Model |
|
from superglue import matcher as SG_Model |
|
from utils import evaluation_utils |
|
|
|
class GNN_Matcher(object):
    """Keypoint matcher driven by a pretrained graph network (SGMNet or SuperGlue).

    The network is built from ``config``, moved to the GPU, switched to eval
    mode and initialised from ``<config.model_dir>/model_best.pth``.
    """

    # Prefix that ``torch.nn.DataParallel`` puts in front of every parameter name.
    _PARALLEL_PREFIX = 'module.'

    def __init__(self, config, model_name):
        """Build the model and restore its weights.

        Args:
            config: Mapping of option names to values. It is converted to a
                namedtuple (so every key must be a valid identifier) and passed
                to the model constructor. This class itself reads ``p_th`` and
                ``model_dir``.
            model_name: ``'SGM'`` selects ``SGM_Model``, ``'SG'`` selects
                ``SG_Model``.

        Raises:
            AssertionError: If ``model_name`` is neither ``'SGM'`` nor ``'SG'``.
        """
        assert model_name == 'SGM' or model_name == 'SG', \
            "model_name must be 'SGM' or 'SG', got {!r}".format(model_name)

        config = namedtuple('config', config.keys())(*config.values())
        self.p_th = config.p_th
        self.model = SGM_Model(config) if model_name == 'SGM' else SG_Model(config)
        self.model.cuda()
        self.model.eval()

        # Load onto the CPU first: a checkpoint saved from e.g. ``cuda:3`` would
        # otherwise fail to deserialize on a machine with fewer GPUs, and
        # ``load_state_dict`` copies the values into the CUDA parameters anyway.
        # NOTE(review): ``torch.load`` unpickles the file - only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(os.path.join(config.model_dir, 'model_best.pth'),
                                map_location='cpu')
        state_dict = checkpoint['state_dict']

        # Checkpoints written from a DataParallel-wrapped model prefix each key
        # with ``module.``; detect that on the first key (as before) and strip
        # the prefix only from keys that actually carry it.
        first_key = next(iter(state_dict), '')
        if first_key.split('.')[0] == 'module':
            prefix = self._PARALLEL_PREFIX
            state_dict = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in state_dict.items()
            )
        self.model.load_state_dict(state_dict)

    @staticmethod
    def _to_batch_tensor(array):
        """Return ``array`` as a float CUDA tensor with a leading batch axis of 1."""
        return torch.from_numpy(array[np.newaxis]).cuda().float()

    def run(self, test_data):
        """Match the two keypoint sets in ``test_data``.

        Args:
            test_data: Dict with keys ``'x1'``, ``'x2'`` (keypoint arrays whose
                first two columns are coordinates and whose third column is
                forwarded to the network unchanged - presumably a detection
                score, TODO confirm), ``'size1'``, ``'size2'`` (passed to
                ``evaluation_utils.normalize_size``) and ``'desc1'``,
                ``'desc2'`` (descriptor arrays).

        Returns:
            Tuple ``(corr1, corr2)`` of ``(K, 2)`` arrays holding the original
            (un-normalised) coordinates of the K matched keypoints.
        """
        kpts1, kpts2 = test_data['x1'], test_data['x2']
        norm_x1 = evaluation_utils.normalize_size(kpts1[:, :2], test_data['size1'])
        norm_x2 = evaluation_utils.normalize_size(kpts2[:, :2], test_data['size2'])
        # Re-attach the third keypoint column to the normalised coordinates.
        x1 = np.concatenate([norm_x1, kpts1[:, 2, np.newaxis]], axis=-1)
        x2 = np.concatenate([norm_x2, kpts2[:, 2, np.newaxis]], axis=-1)

        feed_data = {
            'x1': self._to_batch_tensor(x1),
            'x2': self._to_batch_tensor(x2),
            'desc1': self._to_batch_tensor(test_data['desc1']),
            'desc2': self._to_batch_tensor(test_data['desc2']),
        }
        with torch.no_grad():
            res = self.model(feed_data, test_mode=True)
            p = res['p']

        # Drop the last row and column of the score matrix before matching
        # (presumably the unmatched/"dustbin" slots - TODO confirm).
        index1, index2 = self.match_p(p[0, :-1, :-1])

        # Index with real ndarrays: a one-element torch tensor can be treated as
        # a scalar index by NumPy and silently drop a dimension, whereas an
        # integer ndarray always yields a (K, 2) result.
        index1, index2 = index1.cpu().numpy(), index2.cpu().numpy()
        corr1, corr2 = kpts1[:, :2][index1], kpts2[:, :2][index2]
        return corr1, corr2

    def match_p(self, p):
        """Select mutual best matches whose score exceeds ``self.p_th``.

        Args:
            p: 2-D score tensor; rows index the first keypoint set and columns
                the second.

        Returns:
            Tuple ``(index1, index2)`` of 1-D integer tensors on ``p``'s device:
            row ``index1[k]`` is matched with column ``index2[k]``.
        """
        if p.shape[0] == 0 or p.shape[1] == 0:
            # ``topk`` with k=1 cannot run on an empty axis; nothing to match.
            empty = torch.empty(0, dtype=torch.long, device=p.device)
            return empty, empty.clone()

        score, index = torch.topk(p, k=1, dim=-1)      # best column for each row
        _, index2 = torch.topk(p, k=1, dim=-2)         # best row for each column
        score, index, index2 = score[:, 0], index[:, 0], index2.squeeze(0)

        mask_th = score > self.p_th
        # Mutual check: the best column of row i must point back at row i.
        # The range is created directly on p's device, so no device is hard-coded.
        mask_mc = index2[index] == torch.arange(len(p), device=p.device)
        mask = mask_th & mask_mc

        index1 = torch.nonzero(mask).squeeze(1)
        return index1, index[mask]
|
|
|
|
|
class NN_Matcher(object):
    """Nearest-neighbour descriptor matcher with a ratio test and optional mutual check."""

    def __init__(self, config):
        """Read matcher options.

        Args:
            config: Mapping providing ``mutual_check`` (whether matches must be
                mutual nearest neighbours) and ``ratio_th`` (upper bound on the
                nearest / second-nearest distance ratio). It is converted to a
                namedtuple, so every key must be a valid identifier.
        """
        config = namedtuple('config', config.keys())(*config.values())
        self.mutual_check = config.mutual_check
        self.ratio_th = config.ratio_th

    def run(self, test_data):
        """Match descriptors of the first set against the second set.

        Args:
            test_data: Dict with ``'desc1'`` (N, D) and ``'desc2'`` (M, D)
                descriptor arrays and ``'x1'`` / ``'x2'`` keypoint arrays whose
                first two columns are the coordinates that get returned.

        Returns:
            Tuple ``(corr1, corr2)`` of ``(K, 2)`` coordinate arrays for the K
            accepted matches. Both are empty when the first set is empty or the
            second set has fewer than two descriptors, because the ratio test
            needs a second-nearest neighbour.
        """
        desc1, desc2 = test_data['desc1'], test_data['desc2']
        x1, x2 = test_data['x1'], test_data['x2']

        if len(desc1) == 0 or len(desc2) < 2:
            return x1[:0, :2], x2[:0, :2]

        # Pairwise Euclidean distances via |a|^2 + |b|^2 - 2ab; the abs()
        # absorbs tiny negative values caused by rounding before the sqrt.
        sq1 = (desc1 ** 2).sum(-1)[:, np.newaxis]
        sq2 = (desc2 ** 2).sum(-1)[np.newaxis]
        desc_mat = np.sqrt(np.abs(sq1 + sq2 - 2 * desc1 @ desc2.T))

        # kth=(0, 1) places the nearest and second-nearest candidates in
        # columns 0 and 1 and only requires two candidates to exist.
        nn_index = np.argpartition(desc_mat, kth=(0, 1), axis=-1)[:, :2]
        dis_value12 = np.take_along_axis(desc_mat, nn_index, axis=-1)

        # A 0/0 ratio (duplicate descriptors) becomes NaN, which fails the
        # comparison below and is therefore rejected; silence the warning only.
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio_score = dis_value12[:, 0] / dis_value12[:, 1]

        nn_index1 = nn_index[:, 0]
        mask = ratio_score < self.ratio_th
        if self.mutual_check:
            # Keep a match only if the chosen candidate's own nearest
            # neighbour is the query descriptor.
            nn_index2 = np.argmin(desc_mat, axis=0)
            mask = mask & (nn_index2[nn_index1] == np.arange(len(nn_index1)))

        corr1, corr2 = x1[:, :2], x2[:, :2][nn_index1]
        return corr1[mask], corr2[mask]
|
|
|
|
|
|
|
|
|
|