Spaces:
Running
Running
File size: 4,668 Bytes
10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import argparse
import os
import numpy as np
import h5py
import cv2
from numpy.core.numeric import indices
import pyxis as px
from tqdm import trange
import sys
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
from utils import evaluation_utils, train_utils
def _build_parser():
    """Return the command-line parser for this training-data check script.

    Every option takes a string; ``--vis_folder`` defaults to ``None``,
    in which case no visualizations are written.
    """
    cli = argparse.ArgumentParser(description="checking training data.")
    option_defaults = (
        # folder containing pair_num.txt
        ("--meta_dir", "dataset/valid"),
        # root of the per-sequence info.h5py files
        ("--dataset_dir", "dataset"),
        # root of the per-image descriptor files
        ("--desc_dir", "desc"),
        # root that the image paths stored in info.h5py are joined onto
        ("--raw_dir", "raw_data"),
        # appended to an image file name to form its descriptor file name
        ("--desc_suffix", "_root_1000.hdf5"),
        # output folder for visualizations; None disables them
        ("--vis_folder", None),
    )
    for flag, default in option_defaults:
        cli.add_argument(flag, type=str, default=default)
    return cli


parser = _build_parser()
args = parser.parse_args()
def _mutual_nn_correct(desc1, desc2, corr1, corr2):
    """Flag which ground-truth correspondences mutual-NN matching recovers.

    Descriptors are scored by dot product (``desc1 @ desc2.T``). Keypoint ``i``
    of image 1 and keypoint ``j`` of image 2 form a mutual nearest-neighbour
    pair when each is the other's highest-scoring match.

    Args:
        desc1, desc2: descriptor matrices, one row per keypoint.
        corr1, corr2: integer keypoint indices of the ground-truth
            correspondences in image 1 and image 2 (same length).

    Returns:
        Boolean array with one entry per correspondence: True when the
        nearest neighbour of ``corr1[k]`` is ``corr2[k]`` and the match is
        mutual.
    """
    sim_mat = desc1 @ desc2.T
    nn12 = np.argmax(sim_mat, axis=1)  # best image-2 keypoint per image-1 keypoint
    nn21 = np.argmax(sim_mat, axis=0)  # best image-1 keypoint per image-2 keypoint
    is_mutual = nn21[nn12] == np.arange(len(nn12))
    return np.logical_and(is_mutual[corr1], nn12[corr1] == corr2)


def _inlier_rate(mask):
    """Return the fraction of True entries in *mask*, or NaN when it is empty.

    A pair without correspondences has no defined rate. Returning NaN
    explicitly avoids numpy's "Mean of empty slice" warning and lets the
    caller exclude such pairs with ``np.nanmean``.
    """
    return float(mask.mean()) if mask.size else float("nan")


def _read_image(path):
    """Load *path* with OpenCV, raising instead of silently returning None."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {path}")
    return img


def _dump_visualization(
    vis_folder, index, img1, img2, kpt1, kpt2, corr1, corr2, incorr1, incorr2
):
    """Write the visualizations for one pair into *vis_folder*.

    ``<index>.png`` shows the correspondences drawn by
    ``evaluation_utils.draw_match``; ``<index>_incorr1.png`` and
    ``<index>_incorr2.png`` show the incorr1/incorr2 keypoints drawn by
    ``evaluation_utils.draw_points``.
    """
    # kpt[idx] gathers rows, same as np.take_along_axis(kpt, idx[:, None], axis=0).
    dis_corr = evaluation_utils.draw_match(img1, img2, kpt1[corr1], kpt2[corr2])
    cv2.imwrite(os.path.join(vis_folder, f"{index}.png"), dis_corr)
    dis_incorr1 = evaluation_utils.draw_points(img1, kpt1[incorr1])
    dis_incorr2 = evaluation_utils.draw_points(img2, kpt2[incorr2])
    cv2.imwrite(os.path.join(vis_folder, f"{index}_incorr1.png"), dis_incorr1)
    cv2.imwrite(os.path.join(vis_folder, f"{index}_incorr2.png"), dis_incorr2)


if __name__ == "__main__":
    if args.vis_folder is not None:
        # makedirs also creates missing parent folders and tolerates an
        # existing folder, unlike the exists()+mkdir() pair it replaces.
        os.makedirs(args.vis_folder, exist_ok=True)
    pair_num_list = np.loadtxt(os.path.join(args.meta_dir, "pair_num.txt"), dtype=str)
    pair_seq_list, accu_pair_list = train_utils.parse_pair_seq(pair_num_list)
    # The total number of pairs is read from row 0, column 1 of pair_num.txt.
    total_pair = int(pair_num_list[0, 1])
    total_inlier_rate, total_corr_num, total_incorr_num = [], [], []
    for index in trange(total_pair):
        seq = pair_seq_list[index]
        # HDF5 groups are keyed by the pair's index within its sequence
        # (global index minus accu_pair_list[seq]).
        pair_key = str(index - accu_pair_list[seq])
        with h5py.File(os.path.join(args.dataset_dir, seq, "info.h5py"), "r") as data:
            corr = data["corr"][pair_key][()]
            incorr1 = data["incorr1"][pair_key][()]
            incorr2 = data["incorr2"][pair_key][()]
            img_path1 = data["img_path1"][pair_key][()][0].decode()
            img_path2 = data["img_path2"][pair_key][()][0].decode()
        corr1, corr2 = corr[:, 0], corr[:, 1]

        img_name1, img_name2 = img_path1.split("/")[-1], img_path2.split("/")[-1]
        fea_path1 = os.path.join(args.desc_dir, seq, img_name1 + args.desc_suffix)
        fea_path2 = os.path.join(args.desc_dir, seq, img_name2 + args.desc_suffix)
        with h5py.File(fea_path1, "r") as fea1, h5py.File(fea_path2, "r") as fea2:
            desc1, kpt1 = fea1["descriptors"][()], fea1["keypoints"][()][:, :2]
            desc2, kpt2 = fea2["descriptors"][()], fea2["keypoints"][()][:, :2]

        mask_nn_correct = _mutual_nn_correct(desc1, desc2, corr1, corr2)
        # statistics
        total_inlier_rate.append(_inlier_rate(mask_nn_correct))
        total_corr_num.append(len(corr1))
        total_incorr_num.append((len(incorr1) + len(incorr2)) / 2)

        # dump visualization
        if args.vis_folder is not None:
            img1 = _read_image(os.path.join(args.raw_dir, img_path1))
            img2 = _read_image(os.path.join(args.raw_dir, img_path2))
            _dump_visualization(
                args.vis_folder, index, img1, img2,
                kpt1, kpt2, corr1, corr2, incorr1, incorr2,
            )

    # Pairs without correspondences contribute NaN and are excluded here.
    print("NN matching accuracy: ", np.nanmean(total_inlier_rate))
    print("mean corr number: ", np.asarray(total_corr_num).mean())
    print("mean incorr number: ", np.asarray(total_incorr_num).mean())
|