|
import numpy as np |
|
import torch |
|
import torch.utils.data as data |
|
import cv2 |
|
import os |
|
import h5py |
|
import random |
|
|
|
import sys |
|
|
|
# Make the repository root importable so the project-local `utils` package
# resolves regardless of the current working directory.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))

sys.path.insert(0, ROOT_DIR)



# Project-local helpers (resolved via the sys.path insertion above).
from utils import train_utils, evaluation_utils



# Share tensors between DataLoader workers through the filesystem instead of
# file descriptors (avoids hitting the open-file-descriptor limit with many
# workers — see torch.multiprocessing sharing-strategy docs).
torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
|
class Offline_Dataset(data.Dataset):
    """Offline matching dataset over precomputed keypoints/descriptors.

    One item corresponds to one image pair of a sequence.  Pair metadata
    (relative pose, intrinsics, image sizes, correspondence labels) is read
    from a per-sequence ``info.h5py`` file, and per-image features are read
    from HDF5 files under ``config.desc_path``.  Each item is resampled so
    that exactly ``config.num_kpt`` keypoints are returned per image, laid
    out as: correspondences first, then labeled-incorrect keypoints, then
    filler keypoints.
    """

    def __init__(self, config, mode):
        """Index one dataset split.

        Args:
            config: experiment configuration; fields used by this class are
                ``dataset_path``, ``rawdata_path``, ``desc_path``,
                ``desc_suffix``, ``num_kpt``, ``input_normalize`` and
                ``data_aug``.
            mode: ``"train"`` or ``"valid"`` — selects the metadata split.
        """

        assert mode == "train" or mode == "valid"



        self.config = config

        self.mode = mode

        # Split-specific metadata directory holding pair_num.txt.
        metadir = (

            os.path.join(config.dataset_path, "valid")

            if mode == "valid"

            else os.path.join(config.dataset_path, "train")

        )



        # pair_num.txt is read as string rows; row 0, column 1 holds the total
        # pair count for this split.  NOTE(review): the exact file layout is
        # defined by the dataset-generation scripts — confirm there.
        pair_num_list = np.loadtxt(os.path.join(metadir, "pair_num.txt"), dtype=str)

        self.total_pairs = int(pair_num_list[0, 1])

        # pair_seq_list maps a global pair index to its sequence name;
        # accu_pair_num holds the cumulative pair offset per sequence
        # (semantics defined by train_utils.parse_pair_seq).
        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(

            pair_num_list

        )



    def collate_fn(self, batch):
        """Stack a list of ``__getitem__`` samples into one batch dict.

        Assumes every sample carries the same number of keypoints (num_pts is
        taken from the first sample).  Array-valued keys become float tensors,
        the correspondence counts become int tensors, and image paths remain
        Python lists.  In training mode with ``config.data_aug`` enabled, a
        random homography is applied to one side's coordinates, producing
        ``aug_x1``/``aug_x2``; otherwise these alias ``x1``/``x2``.
        """

        batch_size, num_pts = len(batch), batch[0]["x1"].shape[0]



        data = {}

        # Keys gathered from every sample (despite the name, this is a list
        # of dict keys, not a numpy dtype).
        dtype = [

            "x1",

            "x2",

            "kpt1",

            "kpt2",

            "desc1",

            "desc2",

            "num_corr",

            "num_incorr1",

            "num_incorr2",

            "e_gt",

            "pscore1",

            "pscore2",

            "img_path1",

            "img_path2",

        ]

        for key in dtype:

            data[key] = []

        for sample in batch:

            for key in dtype:

                data[key].append(sample[key])



        # Array-valued entries -> stacked float tensors.
        for key in [

            "x1",

            "x2",

            "kpt1",

            "kpt2",

            "desc1",

            "desc2",

            "e_gt",

            "pscore1",

            "pscore2",

        ]:

            data[key] = torch.from_numpy(np.stack(data[key])).float()

        # Scalar counts -> int tensors.
        for key in ["num_corr", "num_incorr1", "num_incorr2"]:

            data[key] = torch.from_numpy(np.stack(data[key])).int()






        if self.mode == "train" and self.config.data_aug:

            # One random homography per batch element; unsqueeze(1) makes it
            # broadcast over the keypoint dimension.  NOTE(review): assumes
            # get_rnd_homography returns shape (batch_size, 3, 3) — confirm.
            homo_mat = torch.from_numpy(

                train_utils.get_rnd_homography(batch_size)

            ).unsqueeze(1)

            # Single coin flip per batch: warp side 1 or side 2, never both.
            aug_seed = random.random()

            if aug_seed < 0.5:

                # Lift x1 to homogeneous coords, apply H, then de-homogenize.
                x1_homo = torch.cat(

                    [data["x1"], torch.ones([batch_size, num_pts, 1])], dim=-1

                ).unsqueeze(-1)

                x1_homo = torch.matmul(homo_mat.float(), x1_homo.float()).squeeze(-1)

                data["aug_x1"] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1)

                data["aug_x2"] = data["x2"]

            else:

                # Same warp applied to side 2 instead.
                x2_homo = torch.cat(

                    [data["x2"], torch.ones([batch_size, num_pts, 1])], dim=-1

                ).unsqueeze(-1)

                x2_homo = torch.matmul(homo_mat.float(), x2_homo.float()).squeeze(-1)

                data["aug_x2"] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1)

                data["aug_x1"] = data["x1"]

        else:

            # No augmentation: aug_* simply alias the unwarped coordinates.
            data["aug_x1"], data["aug_x2"] = data["x1"], data["x2"]

        return data



    def __getitem__(self, index):
        """Load one image pair.

        Returns a dict with per-image keypoints (``kpt*``), descriptors
        (``desc*``), normalized coordinates (``x*``), detection scores
        (``pscore*``), the scale-normalized ground-truth essential matrix
        (``e_gt``), correspondence counts, and raw image paths.
        """

        # Map the global pair index to (sequence, index within sequence).
        seq = self.pair_seq_list[index]

        index_within_seq = index - self.accu_pair_num[seq]



        # Pair metadata lives in a per-sequence HDF5 file; groups are keyed by
        # the pair index as a string.  NOTE(review): sequence dirs are resolved
        # directly under dataset_path, not under the train/valid split dir
        # used in __init__ — confirm the on-disk layout.
        with h5py.File(

            os.path.join(self.config.dataset_path, seq, "info.h5py"), "r"

        ) as data:

            # Relative pose (rotation, translation) of the pair.
            R, t = (

                data["dR"][str(index_within_seq)][()],

                data["dt"][str(index_within_seq)][()],

            )

            # Ground-truth essential matrix E = [t]_x R, assembled in float64
            # and reshaped back to 3x3.
            egt = np.reshape(

                np.matmul(

                    np.reshape(

                        evaluation_utils.np_skew_symmetric(

                            t.astype("float64").reshape(1, 3)

                        ),

                        (3, 3),

                    ),

                    np.reshape(R.astype("float64"), (3, 3)),

                ),

                (3, 3),

            )

            # E is only defined up to scale; normalize to unit Frobenius norm.
            egt = egt / np.linalg.norm(egt)

            # Camera intrinsics and image sizes for both images.
            K1, K2 = (

                data["K1"][str(index_within_seq)][()],

                data["K2"][str(index_within_seq)][()],

            )

            size1, size2 = (

                data["size1"][str(index_within_seq)][()],

                data["size2"][str(index_within_seq)][()],

            )



            # Stored paths are relative; decode bytes -> str, then resolve
            # against the raw-data root.
            img_path1, img_path2 = (

                data["img_path1"][str(index_within_seq)][()][0].decode(),

                data["img_path2"][str(index_within_seq)][()][0].decode(),

            )

            img_name1, img_name2 = img_path1.split("/")[-1], img_path2.split("/")[-1]

            img_path1, img_path2 = os.path.join(

                self.config.rawdata_path, img_path1

            ), os.path.join(self.config.rawdata_path, img_path2)

            # Per-image feature files: <desc_path>/<seq>/<img_name><suffix>.
            fea_path1, fea_path2 = os.path.join(

                self.config.desc_path, seq, img_name1 + self.config.desc_suffix

            ), os.path.join(

                self.config.desc_path, seq, img_name2 + self.config.desc_suffix

            )

            # "keypoints" columns 0-1 are used as coordinates and column 2 as
            # a detection score.
            with h5py.File(fea_path1, "r") as fea1, h5py.File(fea_path2, "r") as fea2:

                desc1, kpt1, pscore1 = (

                    fea1["descriptors"][()],

                    fea1["keypoints"][()][:, :2],

                    fea1["keypoints"][()][:, 2],

                )

                desc2, kpt2, pscore2 = (

                    fea2["descriptors"][()],

                    fea2["keypoints"][()][:, :2],

                    fea2["keypoints"][()][:, 2],

                )

                # Truncate to at most num_kpt keypoints per image.
                kpt1, kpt2, desc1, desc2 = (

                    kpt1[: self.config.num_kpt],

                    kpt2[: self.config.num_kpt],

                    desc1[: self.config.num_kpt],

                    desc2[: self.config.num_kpt],

                )






            if self.config.input_normalize == "intrinsic":

                # Camera-normalized coordinates: x = (K^-1 [u, v, 1]^T)[:2].
                x1, x2 = np.concatenate(

                    [kpt1, np.ones([kpt1.shape[0], 1])], axis=-1

                ), np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)

                x1, x2 = (

                    np.matmul(np.linalg.inv(K1), x1.T).T[:, :2],

                    np.matmul(np.linalg.inv(K2), x2.T).T[:, :2],

                )

            elif self.config.input_normalize == "img":

                # Image-normalized coordinates: centered, divided by size.
                x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2

                # S*_inv maps image-normalized coords back to pixels; then
                # M = K^-1 S converts E into the image-normalized frame via
                # E' = M2^T E M1 (renormalized below) so the epipolar
                # constraint holds on x1/x2 as defined above.
                S1_inv, S2_inv = np.asarray(

                    [

                        [size1[0], 0, 0.5 * size1[0]],

                        [0, size1[1], 0.5 * size1[1]],

                        [0, 0, 1],

                    ]

                ), np.asarray(

                    [

                        [size2[0], 0, 0.5 * size2[0]],

                        [0, size2[1], 0.5 * size2[1]],

                        [0, 0, 1],

                    ]

                )

                M1, M2 = np.matmul(np.linalg.inv(K1), S1_inv), np.matmul(

                    np.linalg.inv(K2), S2_inv

                )

                egt = np.matmul(np.matmul(M2.transpose(), egt), M1)

                egt = egt / np.linalg.norm(egt)

            else:

                raise NotImplementedError



            # Ground-truth labels: corr is (N, 2) index pairs into the two
            # keypoint lists; incorr1/incorr2 are indices of keypoints known
            # to have no match in the other image.
            corr = data["corr"][str(index_within_seq)][()]

            incorr1, incorr2 = (

                data["incorr1"][str(index_within_seq)][()],

                data["incorr2"][str(index_within_seq)][()],

            )






            # Drop labels that reference keypoints cut off by the num_kpt
            # truncation above.
            valid_corr = corr[corr.max(axis=-1) < self.config.num_kpt]

            valid_incorr1, valid_incorr2 = (

                incorr1[incorr1 < self.config.num_kpt],

                incorr2[incorr2 < self.config.num_kpt],

            )

            num_corr, num_incorr1, num_incorr2 = (

                len(valid_corr),

                len(valid_incorr1),

                len(valid_incorr2),

            )

            # Mark keypoints that are neither in a correspondence nor labeled
            # incorrect; these "invalid" ones serve as filler below.
            # (sic: "mask1_invlaid" is an existing misspelling of "invalid".)
            mask1_invlaid, mask2_invalid = np.ones(x1.shape[0]).astype(bool), np.ones(

                x2.shape[0]

            ).astype(bool)

            mask1_invlaid[valid_corr[:, 0]] = False

            mask2_invalid[valid_corr[:, 1]] = False

            mask1_invlaid[valid_incorr1] = False

            mask2_invalid[valid_incorr2] = False

            invalid_index1, invalid_index2 = (

                np.nonzero(mask1_invlaid)[0],

                np.nonzero(mask2_invalid)[0],

            )






            # Number of filler keypoints needed per image to reach num_kpt.
            cur_kpt1 = self.config.num_kpt - num_corr - num_incorr1

            cur_kpt2 = self.config.num_kpt - num_corr - num_incorr2



            # Too few fillers: take them all, then resample with repetition.
            if invalid_index1.shape[0] < cur_kpt1:

                sub_idx1 = np.concatenate(

                    [

                        np.arange(len(invalid_index1)),

                        np.random.randint(

                            len(invalid_index1), size=cur_kpt1 - len(invalid_index1)

                        ),

                    ]

                )

            # Enough fillers: sample without replacement (this acts as the
            # else-branch of the check above; exactly one of the two fires).
            if invalid_index1.shape[0] >= cur_kpt1:

                sub_idx1 = np.random.choice(len(invalid_index1), cur_kpt1, replace=False)

            if invalid_index2.shape[0] < cur_kpt2:

                sub_idx2 = np.concatenate(

                    [

                        np.arange(len(invalid_index2)),

                        np.random.randint(

                            len(invalid_index2), size=cur_kpt2 - len(invalid_index2)

                        ),

                    ]

                )

            if invalid_index2.shape[0] >= cur_kpt2:

                sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2, replace=False)



            # Final per-image permutation: correspondences, then labeled
            # incorrect, then filler.  Both images share the layout, so the
            # first num_corr rows of x1 and x2 are matching pairs.
            per_idx1, per_idx2 = np.concatenate(

                [valid_corr[:, 0], valid_incorr1, invalid_index1[sub_idx1]]

            ), np.concatenate([valid_corr[:, 1], valid_incorr2, invalid_index2[sub_idx2]])



            # Apply the permutation everywhere; scores gain a trailing axis of
            # size 1 for later stacking.
            pscore1, pscore2 = (

                pscore1[per_idx1][:, np.newaxis],

                pscore2[per_idx2][:, np.newaxis],

            )

            x1, x2 = x1[per_idx1][:, :2], x2[per_idx2][:, :2]

            desc1, desc2 = desc1[per_idx1], desc2[per_idx2]

            kpt1, kpt2 = kpt1[per_idx1], kpt2[per_idx2]



            return {

                "x1": x1,

                "x2": x2,

                "kpt1": kpt1,

                "kpt2": kpt2,

                "desc1": desc1,

                "desc2": desc2,

                "num_corr": num_corr,

                "num_incorr1": num_incorr1,

                "num_incorr2": num_incorr2,

                "e_gt": egt,

                "pscore1": pscore1,

                "pscore2": pscore2,

                "img_path1": img_path1,

                "img_path2": img_path2,

            }



    def __len__(self):
        """Total number of pairs in this split (from pair_num.txt)."""

        return self.total_pairs
|
|