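# NOTE: the following lines appear to be web-page residue from the file's
# hosting page; they are kept as comments so the module stays importable.
# Vincentqyw
# update: features and matchers
# 437b5f6
# raw
# history blame
# 3.45 kB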
import torch
import numpy as np
import cv2
import os
from loss import batch_episym
from tqdm import tqdm
import sys
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
from utils import evaluation_utils,train_utils
def valid(valid_loader, model, match_loss, config, model_config):
    """Run one validation pass over ``valid_loader`` and return its metrics.

    The model is switched to eval mode for the pass and switched back to
    train mode before returning. The switch back happens in a ``finally``
    clause, so a failing validation pass cannot leave the caller training a
    model that is still in eval mode.

    Args:
        valid_loader: iterable of validation batches.
        model: network, invoked as ``model(batch)``.
        match_loss: object whose ``run(batch, res)`` returns a dict holding
            ``'acc_corr'``, ``'acc_incorr'`` and ``'total_loss'``, plus
            ``'mid_acc_corr'``, ``'pre_seed_conf'`` and
            ``'recall_seed_conf'`` when ``config.model_name == 'SGM'``.
        config: must provide ``local_rank`` and ``model_name``.
        model_config: must provide ``layer_num`` and ``seedlayer``.

    Returns:
        The tuple ``(total_loss, total_acc_corr, total_acc_incorr,
        total_precision, total_recall, total_acc_mid)``. ``total_loss`` is
        the sum over batches passed through ``train_utils.reduce_tensor``
        with mode ``'sum'``; the other five are divided by the number of
        batches and passed through ``reduce_tensor`` with mode ``'mean'``
        (presumably a reduction across distributed ranks -- confirm against
        ``train_utils``).

    Raises:
        ZeroDivisionError: if ``valid_loader`` yields no batches.
    """
    model.eval()
    try:
        loader_iter = iter(valid_loader)
        num_pair = 0
        total_loss, total_acc_corr, total_acc_incorr = 0, 0, 0
        total_precision = torch.zeros(model_config.layer_num, device='cuda')
        total_recall = torch.zeros(model_config.layer_num, device='cuda')
        total_acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device='cuda')

        with torch.no_grad():
            # Only rank 0 shows a progress bar and the banner.
            if config.local_rank == 0:
                loader_iter = tqdm(loader_iter)
                print('validating...')

            for test_data in loader_iter:
                num_pair += 1
                test_data = train_utils.tocuda(test_data)
                res = model(test_data)
                loss_res = match_loss.run(test_data, res)

                total_acc_corr += loss_res['acc_corr']
                total_acc_incorr += loss_res['acc_incorr']
                total_loss += loss_res['total_loss']

                # These statistics are only read for the SGM model.
                if config.model_name == 'SGM':
                    total_acc_mid += loss_res['mid_acc_corr']
                    total_precision = total_precision + loss_res['pre_seed_conf']
                    total_recall = total_recall + loss_res['recall_seed_conf']

            # Turn the accumulated sums into per-batch means (loss stays a sum).
            total_acc_corr /= num_pair
            total_acc_incorr /= num_pair
            total_precision /= num_pair
            total_recall /= num_pair
            total_acc_mid /= num_pair

            # Apply tensor reduction. The call order is kept fixed so that
            # every process issues the reductions in the same sequence.
            total_loss = train_utils.reduce_tensor(total_loss, 'sum')
            total_acc_corr = train_utils.reduce_tensor(total_acc_corr, 'mean')
            total_acc_incorr = train_utils.reduce_tensor(total_acc_incorr, 'mean')
            total_precision = train_utils.reduce_tensor(total_precision, 'mean')
            total_recall = train_utils.reduce_tensor(total_recall, 'mean')
            total_acc_mid = train_utils.reduce_tensor(total_acc_mid, 'mean')
    finally:
        # Always hand the model back in train mode, even if validation failed.
        model.train()
    return total_loss, total_acc_corr, total_acc_incorr, total_precision, total_recall, total_acc_mid
def dump_train_vis(res, data, step, config):
    """Write a match visualization image for every pair in a training batch.

    Matches are taken from ``res['p']`` (with its last row and column
    removed): a point is kept when its best score exceeds 0.2 and the
    row-wise / column-wise argmax agree (mutual check). Each kept match is
    flagged as inlier when its ``batch_episym`` distance under
    ``data['e_gt']`` is below ``config.inlier_th``.

    Args:
        res: model output dict; ``res['p']`` is indexed as a 3-D tensor.
        data: batch dict providing ``'x1'``, ``'x2'``, ``'kpt1'``,
            ``'kpt2'``, ``'e_gt'``, ``'img_path1'`` and ``'img_path2'``.
        step: training step; used as the name of the output sub-folder.
        config: must provide ``inlier_th``, ``train_vis_folder`` and
            ``log_base``.

    Returns:
        None. Images are written to
        ``<train_vis_folder>/train_vis/<log_base>/<step>/<name>.png``.
    """
    # batch matching
    p = res['p'][:, :-1, :-1]
    score, index1 = torch.max(p, dim=-1)
    _, index2 = torch.max(p, dim=-2)
    mask_th = score > 0.2
    # Mutual check: the best partner of i must pick i back. The index is
    # built on the tensors' own device instead of a hard-coded .cuda().
    self_index = torch.arange(p.shape[1], device=index1.device)[None]
    mask_mc = index2.gather(index=index1, dim=1) == self_index
    mask_p = mask_th & mask_mc  # B*N

    gather_index = index1[:, :, None].expand(-1, -1, 2)
    corr1, corr2 = data['x1'], data['x2'].gather(index=gather_index, dim=1)
    corr1_kpt, corr2_kpt = data['kpt1'], data['kpt2'].gather(index=gather_index, dim=1)
    epi_dis = batch_episym(corr1, corr2, data['e_gt'])
    mask_inlier = epi_dis < config.inlier_th  # B*N

    # dump vis
    out_dir = os.path.join(config.train_vis_folder, 'train_vis', config.log_base, str(step))
    # cv2.imwrite does not create missing folders; it just reports failure.
    os.makedirs(out_dir, exist_ok=True)

    batch = zip(mask_p, mask_inlier, corr1_kpt, corr2_kpt, data['img_path1'], data['img_path2'])
    for cur_mask_p, cur_mask_inlier, cur_corr1, cur_corr2, img_path1, img_path2 in batch:
        img1, img2 = cv2.imread(img_path1), cv2.imread(img_path2)
        # The inlier flags are filtered with the same mask as the points so
        # that flag i describes drawn match i.
        # NOTE(review): assumes draw_match expects one flag per drawn match -- confirm.
        dis_play = evaluation_utils.draw_match(
            img1,
            img2,
            cur_corr1[cur_mask_p].cpu().numpy(),
            cur_corr2[cur_mask_p].cpu().numpy(),
            inlier=cur_mask_inlier[cur_mask_p],
        )
        # File name: <image1 name>_<image2 name>_<parent folder of image1>.
        base_name_seq = img_path1.split('/')[-1] + '_' + img_path2.split('/')[-1] + '_' + img_path1.split('/')[-2]
        save_path = os.path.join(out_dir, base_name_seq + '.png')
        cv2.imwrite(save_path, dis_play)