Vincentqyw
fix: roma
358ab8f
raw
history blame
10.3 kB
import numpy as np
import torch
import torch.utils.data as data
import cv2
import os
import h5py
import random
import sys
# Put the parent of this file's directory at the front of ``sys.path`` so the
# ``utils`` import below resolves no matter where the script is launched from
# (presumably ``utils`` is a project package living in that parent directory).
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
sys.path.insert(0, ROOT_DIR)
from utils import train_utils, evaluation_utils
# Select the "file_system" tensor-sharing strategy for torch multiprocessing.
# NOTE(review): presumably chosen to avoid exhausting file descriptors when
# many DataLoader workers pass tensors around -- confirm against the training
# setup before changing.
torch.multiprocessing.set_sharing_strategy("file_system")
class Offline_Dataset(data.Dataset):
    """Dataset of image pairs whose keypoints/descriptors were extracted offline.

    Per-pair metadata is read from ``<dataset_path>/<seq>/info.h5py`` and the
    per-image features from ``<desc_path>/<seq>/<image name><desc_suffix>``.

    The ``config`` object is expected to expose the attributes read below:
    ``dataset_path``, ``rawdata_path``, ``desc_path``, ``desc_suffix``,
    ``num_kpt``, ``input_normalize`` (``"intrinsic"`` or ``"img"``) and
    ``data_aug``.

    Each sample is a dict with ``num_kpt`` rows per image, ordered as
    ``[correspondences | "incorr" keypoints | randomly drawn remaining
    keypoints]``; the first ``num_corr`` rows of image 1 and image 2 are the
    matching pairs listed in the ``corr`` table.
    """

    # Keys copied from every sample into the batch, in output order.
    _SAMPLE_KEYS = (
        "x1",
        "x2",
        "kpt1",
        "kpt2",
        "desc1",
        "desc2",
        "num_corr",
        "num_incorr1",
        "num_incorr2",
        "e_gt",
        "pscore1",
        "pscore2",
        "img_path1",
        "img_path2",
    )
    # Keys stacked into float tensors by ``collate_fn``.
    _FLOAT_KEYS = (
        "x1",
        "x2",
        "kpt1",
        "kpt2",
        "desc1",
        "desc2",
        "e_gt",
        "pscore1",
        "pscore2",
    )
    # Keys stacked into int tensors by ``collate_fn``.
    _INT_KEYS = ("num_corr", "num_incorr1", "num_incorr2")

    def __init__(self, config, mode):
        """Load the pair index of the ``"train"`` or ``"valid"`` split.

        Args:
            config: configuration object (see the class docstring).
            mode: ``"train"`` or ``"valid"``.
        """
        assert mode == "train" or mode == "valid"
        self.config = config
        self.mode = mode
        split = "valid" if mode == "valid" else "train"
        metadir = os.path.join(config.dataset_path, split)
        # ``pair_num.txt`` is loaded as a 2-D string table; entry [0, 1] holds
        # the total number of pairs, the table itself is handed to
        # ``train_utils.parse_pair_seq`` to build the per-index sequence lookup.
        pair_num_list = np.loadtxt(os.path.join(metadir, "pair_num.txt"), dtype=str)
        self.total_pairs = int(pair_num_list[0, 1])
        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(
            pair_num_list
        )

    @staticmethod
    def _warp_points(points, homo_mat):
        """Apply per-batch homographies to 2-D points.

        Args:
            points: float tensor of shape (batch, num_pts, 2).
            homo_mat: tensor broadcastable to (batch, 1, 3, 3).

        Returns:
            Tensor of shape (batch, num_pts, 2) with the de-homogenised result.
        """
        ones = torch.ones([points.shape[0], points.shape[1], 1])
        homo_pts = torch.cat([points, ones], dim=-1).unsqueeze(-1)
        homo_pts = torch.matmul(homo_mat.float(), homo_pts.float()).squeeze(-1)
        return homo_pts[:, :, :2] / homo_pts[:, :, 2].unsqueeze(-1)

    @staticmethod
    def _denormalization_matrix(size):
        """Return the 3x3 matrix mapping ``(p - size / 2) / size`` back to ``p``."""
        return np.asarray(
            [
                [size[0], 0, 0.5 * size[0]],
                [0, size[1], 0.5 * size[1]],
                [0, 0, 1],
            ]
        )

    @staticmethod
    def _unannotated_indices(total, *annotated):
        """Return the indices in ``range(total)`` not listed in any ``annotated`` array."""
        unannotated = np.ones(total).astype(bool)
        for indices in annotated:
            unannotated[indices] = False
        return np.nonzero(unannotated)[0]

    @staticmethod
    def _sample_fill_indices(pool, count):
        """Draw ``count`` entries from ``pool`` to pad a sample to full size.

        When the pool is large enough, ``count`` distinct entries are drawn.
        Otherwise every pool entry is used once and the shortfall is filled
        with random repeats.

        Raises:
            ValueError: if padding is required but ``pool`` is empty.
        """
        pool_size = len(pool)
        if pool_size >= count:
            return pool[np.random.choice(pool_size, count, replace=False)]
        if pool_size == 0:
            # Previously this surfaced as an opaque "high <= 0" error from
            # ``np.random.randint``; keep the exception type, explain the cause.
            raise ValueError(
                "cannot pad sample: %d more keypoints are required but no "
                "unannotated keypoints are available" % count
            )
        repeats = np.random.randint(pool_size, size=count - pool_size)
        return pool[np.concatenate([np.arange(pool_size), repeats])]

    def collate_fn(self, batch):
        """Stack a list of samples into batch tensors.

        Numeric entries become float/int tensors, image paths stay Python
        lists. ``aug_x1``/``aug_x2`` are added: in training mode with
        ``config.data_aug`` set, one side (chosen at random for the whole
        batch) is warped by random homographies; otherwise they alias
        ``x1``/``x2``.
        """
        batch_size = len(batch)
        collated = {
            key: [sample[key] for sample in batch] for key in self._SAMPLE_KEYS
        }
        for key in self._FLOAT_KEYS:
            collated[key] = torch.from_numpy(np.stack(collated[key])).float()
        for key in self._INT_KEYS:
            collated[key] = torch.from_numpy(np.stack(collated[key])).int()

        # kpt augmentation with random homography
        if self.mode == "train" and self.config.data_aug:
            homo_mat = torch.from_numpy(
                train_utils.get_rnd_homography(batch_size)
            ).unsqueeze(1)
            aug_seed = random.random()
            if aug_seed < 0.5:
                collated["aug_x1"] = self._warp_points(collated["x1"], homo_mat)
                collated["aug_x2"] = collated["x2"]
            else:
                collated["aug_x2"] = self._warp_points(collated["x2"], homo_mat)
                collated["aug_x1"] = collated["x1"]
        else:
            collated["aug_x1"], collated["aug_x2"] = collated["x1"], collated["x2"]
        return collated

    def __getitem__(self, index):
        """Load one image pair and assemble a fixed-size keypoint sample.

        Returns:
            dict with normalised coordinates (``x1``/``x2``), raw keypoints
            (``kpt1``/``kpt2``), descriptors, the third keypoint column as a
            column vector (``pscore1``/``pscore2``), the counts ``num_corr``,
            ``num_incorr1``, ``num_incorr2``, the Frobenius-normalised matrix
            ``e_gt`` and the two image paths.

        Raises:
            NotImplementedError: for an unknown ``config.input_normalize``.
            ValueError: if the sample cannot be padded to ``config.num_kpt``.
        """
        cfg = self.config
        seq = self.pair_seq_list[index]
        index_within_seq = index - self.accu_pair_num[seq]
        pair_key = str(index_within_seq)

        with h5py.File(
            os.path.join(cfg.dataset_path, seq, "info.h5py"), "r"
        ) as info:
            R, t = info["dR"][pair_key][()], info["dt"][pair_key][()]
            # e_gt = [t]_x @ R, scaled to unit Frobenius norm.
            t_skew = np.reshape(
                evaluation_utils.np_skew_symmetric(t.astype("float64").reshape(1, 3)),
                (3, 3),
            )
            egt = np.matmul(t_skew, np.reshape(R.astype("float64"), (3, 3)))
            egt = egt / np.linalg.norm(egt)

            K1, K2 = info["K1"][pair_key][()], info["K2"][pair_key][()]
            size1, size2 = info["size1"][pair_key][()], info["size2"][pair_key][()]
            img_path1 = info["img_path1"][pair_key][()][0].decode()
            img_path2 = info["img_path2"][pair_key][()][0].decode()
            img_name1, img_name2 = img_path1.split("/")[-1], img_path2.split("/")[-1]
            img_path1 = os.path.join(cfg.rawdata_path, img_path1)
            img_path2 = os.path.join(cfg.rawdata_path, img_path2)
            fea_path1 = os.path.join(cfg.desc_path, seq, img_name1 + cfg.desc_suffix)
            fea_path2 = os.path.join(cfg.desc_path, seq, img_name2 + cfg.desc_suffix)

            with h5py.File(fea_path1, "r") as fea1, h5py.File(fea_path2, "r") as fea2:
                # Read each keypoint table once; columns 0-1 are coordinates,
                # column 2 is returned as ``pscore``.
                desc1, kpt_table1 = fea1["descriptors"][()], fea1["keypoints"][()]
                desc2, kpt_table2 = fea2["descriptors"][()], fea2["keypoints"][()]
            kpt1, pscore1 = kpt_table1[:, :2], kpt_table1[:, 2]
            kpt2, pscore2 = kpt_table2[:, :2], kpt_table2[:, 2]
            kpt1, kpt2 = kpt1[: cfg.num_kpt], kpt2[: cfg.num_kpt]
            desc1, desc2 = desc1[: cfg.num_kpt], desc2[: cfg.num_kpt]

            # normalize kpt
            if cfg.input_normalize == "intrinsic":
                # x = K^-1 [u, v, 1]^T, keeping the first two components.
                # NOTE(review): no division by the third component -- assumes
                # the last row of K is [0, 0, 1]; confirm.
                homo1 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1)
                homo2 = np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
                x1 = np.matmul(np.linalg.inv(K1), homo1.T).T[:, :2]
                x2 = np.matmul(np.linalg.inv(K2), homo2.T).T[:, :2]
            elif cfg.input_normalize == "img":
                x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2
                # Re-express e_gt in the size-normalised coordinates:
                # M = K^-1 @ S^-1, e_gt <- M2^T @ e_gt @ M1.
                M1 = np.matmul(np.linalg.inv(K1), self._denormalization_matrix(size1))
                M2 = np.matmul(np.linalg.inv(K2), self._denormalization_matrix(size2))
                egt = np.matmul(np.matmul(M2.transpose(), egt), M1)
                egt = egt / np.linalg.norm(egt)
            else:
                raise NotImplementedError

            corr = info["corr"][pair_key][()]
            incorr1 = info["incorr1"][pair_key][()]
            incorr2 = info["incorr2"][pair_key][()]

        # permute kpt
        # Keep only annotations that refer to keypoints surviving the
        # ``num_kpt`` truncation.
        valid_corr = corr[corr.max(axis=-1) < cfg.num_kpt]
        valid_incorr1 = incorr1[incorr1 < cfg.num_kpt]
        valid_incorr2 = incorr2[incorr2 < cfg.num_kpt]
        num_corr = len(valid_corr)
        num_incorr1, num_incorr2 = len(valid_incorr1), len(valid_incorr2)

        unannotated1 = self._unannotated_indices(
            x1.shape[0], valid_corr[:, 0], valid_incorr1
        )
        unannotated2 = self._unannotated_indices(
            x2.shape[0], valid_corr[:, 1], valid_incorr2
        )

        # random sample from point w/o valid annotation
        # (image 1 is drawn before image 2 so seeded runs stay reproducible)
        fill1 = self._sample_fill_indices(
            unannotated1, cfg.num_kpt - num_corr - num_incorr1
        )
        fill2 = self._sample_fill_indices(
            unannotated2, cfg.num_kpt - num_corr - num_incorr2
        )

        per_idx1 = np.concatenate([valid_corr[:, 0], valid_incorr1, fill1])
        per_idx2 = np.concatenate([valid_corr[:, 1], valid_incorr2, fill2])

        return {
            "x1": x1[per_idx1][:, :2],
            "x2": x2[per_idx2][:, :2],
            "kpt1": kpt1[per_idx1],
            "kpt2": kpt2[per_idx2],
            "desc1": desc1[per_idx1],
            "desc2": desc2[per_idx2],
            "num_corr": num_corr,
            "num_incorr1": num_incorr1,
            "num_incorr2": num_incorr2,
            "e_gt": egt,
            "pscore1": pscore1[per_idx1][:, np.newaxis],
            "pscore2": pscore2[per_idx2][:, np.newaxis],
            "img_path1": img_path1,
            "img_path2": img_path2,
        }

    def __len__(self):
        """Return the total number of pairs read from ``pair_num.txt``."""
        return self.total_pairs