# File size: 4,045 Bytes
# 404d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import os
import numpy as np
import h5py
import cv2
from numpy.core.numeric import indices
import pyxis as px
from tqdm import trange

import sys
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from utils import evaluation_utils,train_utils

# Command-line interface. Every option is a plain string; parsing runs at
# import time so that the module-level `args` is ready for the script body.
parser = argparse.ArgumentParser(description='checking training data.')
_STRING_OPTIONS = (
    ('--meta_dir', 'dataset/valid'),       # folder that holds pair_num.txt
    ('--dataset_dir', 'dataset'),          # root of the per-sequence info.h5py files
    ('--desc_dir', 'desc'),                # root of the per-image feature files
    ('--raw_dir', 'raw_data'),             # prefix joined with the stored image paths
    ('--desc_suffix', '_root_1000.hdf5'),  # appended to an image name to find its feature file
    ('--vis_folder', None),                # if given, visualizations are written here
)
for _flag, _default in _STRING_OPTIONS:
    parser.add_argument(_flag, type=str, default=_default)
args = parser.parse_args()



def mutual_nn_inlier_mask(desc1, desc2, corr1, corr2):
    """Flag the ground-truth correspondences recovered by mutual NN matching.

    Similarity is the dot product ``desc1 @ desc2.T``. Correspondence ``k``
    counts as recovered when the nearest neighbour of keypoint ``corr1[k]``
    in image 2 is ``corr2[k]`` AND the nearest neighbour of that image-2
    keypoint back in image 1 is ``corr1[k]`` (mutual check).

    Args:
        desc1: 2-D array of image-1 descriptors, one row per keypoint.
        desc2: 2-D array of image-2 descriptors, one row per keypoint.
        corr1: integer indices into ``desc1`` of the ground-truth matches.
        corr2: integer indices into ``desc2``, aligned element-wise with ``corr1``.

    Returns:
        Boolean array with one entry per correspondence (empty if ``corr1`` is).
    """
    corr1 = np.asarray(corr1)
    corr2 = np.asarray(corr2)
    sim_mat = desc1 @ desc2.T
    nn_index1 = np.argmax(sim_mat, axis=1)  # best image-2 match of each image-1 keypoint
    nn_index2 = np.argmax(sim_mat, axis=0)  # best image-1 match of each image-2 keypoint
    pred2 = nn_index1[corr1]
    mask_inlier = pred2 == corr2
    # Only the predicted matches need the backward check; this is the same as
    # building the full mutual mask over all keypoints and indexing it by corr1.
    mask_mutual = nn_index2[pred2] == corr1
    return np.logical_and(mask_mutual, mask_inlier)


def _read_image(path):
    """Load an image with OpenCV, raising instead of returning None on failure."""
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals a missing/undecodable file by returning None;
        # fail here rather than inside the drawing helpers.
        raise FileNotFoundError(f'cannot read image: {path}')
    return img


def _write_image(path, img):
    """Write an image with OpenCV, raising if cv2.imwrite reports failure."""
    if not cv2.imwrite(path, img):
        raise OSError(f'cannot write image: {path}')


def _load_features(path):
    """Return ``(descriptors, keypoints[:, :2])`` read from a feature HDF5 file."""
    with h5py.File(path, 'r') as fea:
        # only the first two keypoint columns are kept (presumably x, y -- confirm)
        return fea['descriptors'][()], fea['keypoints'][()][:, :2]


def _dump_visualization(vis_folder, raw_dir, index, img_path1, img_path2,
                        kpt1, kpt2, corr1, corr2, incorr1, incorr2):
    """Write the visualizations of one pair into ``vis_folder``.

    ``<index>.png`` shows the ground-truth correspondences drawn by
    ``evaluation_utils.draw_match``. ``<index>_incorr1.png`` and
    ``<index>_incorr2.png`` show the ``incorr`` keypoints of each image drawn
    by ``evaluation_utils.draw_points``.
    """
    img1 = _read_image(os.path.join(raw_dir, img_path1))
    img2 = _read_image(os.path.join(raw_dir, img_path2))
    # kpt[idx] picks one keypoint row per index (same as take_along_axis on axis 0)
    dis_corr = evaluation_utils.draw_match(img1, img2, kpt1[corr1], kpt2[corr2])
    _write_image(os.path.join(vis_folder, str(index) + '.png'), dis_corr)
    dis_incorr1 = evaluation_utils.draw_points(img1, kpt1[incorr1])
    dis_incorr2 = evaluation_utils.draw_points(img2, kpt2[incorr2])
    _write_image(os.path.join(vis_folder, str(index) + '_incorr1.png'), dis_incorr1)
    _write_image(os.path.join(vis_folder, str(index) + '_incorr2.png'), dis_incorr2)


if __name__ == '__main__':
    if args.vis_folder is not None:
        # makedirs also creates missing parent folders and tolerates an existing one
        os.makedirs(args.vis_folder, exist_ok=True)

    pair_num_list = np.loadtxt(os.path.join(args.meta_dir, 'pair_num.txt'), dtype=str)
    pair_seq_list, accu_pair_list = train_utils.parse_pair_seq(pair_num_list)
    # the total number of pairs is read from row 0, column 1 of pair_num.txt
    total_pair = int(pair_num_list[0, 1])
    total_inlier_rate, total_corr_num, total_incorr_num = [], [], []

    for index in trange(total_pair):
        seq = pair_seq_list[index]
        # per-sequence HDF5 groups are keyed by the pair's index within its sequence
        key = str(index - accu_pair_list[seq])
        with h5py.File(os.path.join(args.dataset_dir, seq, 'info.h5py'), 'r') as data:
            corr = data['corr'][key][()]
            incorr1, incorr2 = data['incorr1'][key][()], data['incorr2'][key][()]
            img_path1 = data['img_path1'][key][()][0].decode()
            img_path2 = data['img_path2'][key][()][0].decode()
        corr1, corr2 = corr[:, 0], corr[:, 1]

        img_name1, img_name2 = img_path1.split('/')[-1], img_path2.split('/')[-1]
        desc1, kpt1 = _load_features(os.path.join(args.desc_dir, seq, img_name1 + args.desc_suffix))
        desc2, kpt2 = _load_features(os.path.join(args.desc_dir, seq, img_name2 + args.desc_suffix))

        mask_nn_correct = mutual_nn_inlier_mask(desc1, desc2, corr1, corr2)
        # statistics: a pair without correspondences has no defined accuracy, so
        # record NaN explicitly (skipped by nanmean below) instead of letting an
        # empty-slice mean warn and turn the overall average into NaN
        total_inlier_rate.append(mask_nn_correct.mean() if mask_nn_correct.size else np.nan)
        total_corr_num.append(len(corr1))
        total_incorr_num.append((len(incorr1) + len(incorr2)) / 2)

        if args.vis_folder is not None:
            _dump_visualization(args.vis_folder, args.raw_dir, index, img_path1, img_path2,
                                kpt1, kpt2, corr1, corr2, incorr1, incorr2)

    print('NN matching accuracy: ', np.nanmean(total_inlier_rate))
    print('mean corr number: ', np.asarray(total_corr_num).mean())
    print('mean incorr number: ', np.asarray(total_incorr_num).mean())