import glob
import math
import os
import re
import sys

import h5py
import numpy as np
import pyxis as px
from torch.multiprocessing import Pool
from tqdm import trange

from .base_dumper import BaseDumper

# make the repository root importable so that `utils` resolves
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
sys.path.insert(0, ROOT_DIR)
from utils import transformations, data_utils
class gl3d_train(BaseDumper):
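    """Dumper for the GL3D dataset: lists train/valid sequences, loads
    per-image geometry and PFM depth maps, builds correspondence labels via
    data_utils.make_corr, and writes per-sequence pair info to HDF5."""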
def get_seqs(self):
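        """Collect train/valid sequence names (minus excluded ones) and the
        matching lists of image paths and feature-dump paths."""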
data_dir = os.path.join(self.config["rawdata_dir"], "data")
seq_train = np.loadtxt(
os.path.join(
self.config["rawdata_dir"], "list", "comb", "imageset_train.txt"
),
dtype=str,
)
seq_valid = np.loadtxt(
os.path.join(
self.config["rawdata_dir"], "list", "comb", "imageset_test.txt"
),
dtype=str,
)
        # drop excluded sequences
        self.seq_list = []
        self.train_list = [
            seq for seq in seq_train if seq not in self.config["exclude_seq"]
        ]
        self.valid_list = [
            seq for seq in seq_valid if seq not in self.config["exclude_seq"]
        ]
seq_list = []
if self.config["dump_train"]:
seq_list.append(self.train_list)
if self.config["dump_valid"]:
seq_list.append(self.valid_list)
self.seq_list = np.concatenate(seq_list, axis=0)
for seq in self.seq_list:
dump_dir = os.path.join(self.config["feature_dump_dir"], seq)
cur_img_seq = glob.glob(
os.path.join(data_dir, seq, "undist_images", "*.jpg")
)
            cur_dump_seq = [
                os.path.join(
                    dump_dir,
                    "{}_{}_{}.hdf5".format(
                        os.path.basename(path),
                        self.config["extractor"]["name"],
                        self.config["extractor"]["num_kpt"],
                    ),
                )
                for path in cur_img_seq
            ]
self.img_seq += cur_img_seq
self.dump_seq += cur_dump_seq
def format_dump_folder(self):
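        """Create the feature and dataset dump directories."""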
        os.makedirs(self.config["feature_dump_dir"], exist_ok=True)
        for seq in self.seq_list:
            os.makedirs(
                os.path.join(self.config["feature_dump_dir"], seq), exist_ok=True
            )
        os.makedirs(self.config["dataset_dump_dir"], exist_ok=True)
def load_geom(self, seq):
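        """Parse geolabel/cameras.txt into one entry per basename (a dict with
        K, R, T, size, or None if the image has no camera row), with K
        rescaled to the 1000x1000 resized images."""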
# load geometry file
geom_file = os.path.join(
self.config["rawdata_dir"], "data", seq, "geolabel", "cameras.txt"
)
basename_list = np.loadtxt(
os.path.join(self.config["rawdata_dir"], "data", seq, "basenames.txt"),
dtype=str,
)
geom_dict = []
cameras = np.loadtxt(geom_file)
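        # camera row layout, as consumed below:
        # [img_index, fx, fy, cx, cy, skew, T (3), R (9, row-major), width, height]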
camera_index = 0
for base_index in range(len(basename_list)):
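            # cameras.txt may skip images; append None so geom_dict stays
            # aligned with basename_list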
if base_index < cameras[camera_index][0]:
geom_dict.append(None)
continue
cur_geom = {}
ori_img_size = [cameras[camera_index][-2], cameras[camera_index][-1]]
scale_factor = [1000.0 / ori_img_size[0], 1000.0 / ori_img_size[1]]
K = np.asarray(
[
[
cameras[camera_index][1],
cameras[camera_index][5],
cameras[camera_index][3],
],
[0, cameras[camera_index][2], cameras[camera_index][4]],
[0, 0, 1],
]
)
            # Rescale calibration according to previous resizing
S = np.asarray(
[[scale_factor[0], 0, 0], [0, scale_factor[1], 0], [0, 0, 1]]
)
K = np.dot(S, K)
cur_geom["K"] = K
cur_geom["R"] = cameras[camera_index][9:18].reshape([3, 3])
cur_geom["T"] = cameras[camera_index][6:9]
cur_geom["size"] = np.asarray([1000, 1000])
geom_dict.append(cur_geom)
camera_index += 1
return geom_dict
def load_depth(self, file_path):
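        """Read a PFM depth map and return it as a float array of shape
        (H, W) (or (H, W, 3) for color PFM), flipped to top-to-bottom row
        order."""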
with open(os.path.join(file_path), "rb") as fin:
color = None
width = None
height = None
scale = None
data_type = None
header = str(fin.readline().decode("UTF-8")).rstrip()
if header == "PF":
color = True
elif header == "Pf":
color = False
else:
raise Exception("Not a PFM file.")
dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception("Malformed PFM header.")
scale = float((fin.readline().decode("UTF-8")).rstrip())
if scale < 0: # little-endian
data_type = "<f"
else:
data_type = ">f" # big-endian
            data_string = fin.read()
            # np.fromstring is deprecated for binary input; use frombuffer
            data = np.frombuffer(data_string, dtype=data_type)
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flip(data, 0)
return data
def dump_info(self, seq, info):
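        """Write the sampled pair labels of one sequence to info.h5py and the
        pair count to pair_num.txt."""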
pair_type = [
"dR",
"dt",
"K1",
"K2",
"size1",
"size2",
"corr",
"incorr1",
"incorr2",
]
num_pairs = len(info["dR"])
        os.makedirs(os.path.join(self.config["dataset_dump_dir"], seq), exist_ok=True)
with h5py.File(
os.path.join(self.config["dataset_dump_dir"], seq, "info.h5py"), "w"
) as f:
            for key in pair_type:
                dg = f.create_group(key)
                for idx in range(num_pairs):
                    data_item = np.asarray(info[key][idx])
                    dg.create_dataset(
                        str(idx), data_item.shape, data_item.dtype, data=data_item
                    )
            for key in ["img_path1", "img_path2"]:
                dg = f.create_group(key)
                for idx in range(num_pairs):
                    dg.create_dataset(
                        str(idx),
                        [1],
                        h5py.string_dtype(encoding="ascii"),
                        data=info[key][idx].encode("ascii"),
                    )
with open(
os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt"), "w"
) as f:
f.write(str(info["pair_num"]))
def format_seq(self, index):
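        """Build and dump labelled pairs for one sequence: filter candidate
        pairs by relative rotation angle and overlap score, then keep up to
        pairs_per_seq pairs that have enough correct/incorrect matches."""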
seq = self.seq_list[index]
        seq_dir = os.path.join(self.config["rawdata_dir"], "data", seq)
basename_list = np.loadtxt(os.path.join(seq_dir, "basenames.txt"), dtype=str)
        # skip sequences whose info has already been dumped
        if os.path.exists(
            os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt")
        ):
            return
        # common_track.txt rows: [index1, index2, overlap_score, ...]
        common_track = np.loadtxt(
            os.path.join(seq_dir, "geolabel", "common_track.txt"), dtype=float
        )
        pair_list = common_track[:, :2].astype(int)
        overlap_score = common_track[:, 2]
        geom_dict = self.load_geom(seq)
angle_list = []
# filtering pairs
for cur_pair in pair_list:
pair_index1, pair_index2 = cur_pair[0], cur_pair[1]
geo1, geo2 = geom_dict[pair_index1], geom_dict[pair_index2]
dR = np.dot(geo2["R"], geo1["R"].T)
            q = transformations.quaternion_from_matrix(dR)
            # relative rotation angle in degrees: theta = 2 * acos(q_w);
            # clip guards acos against numerical noise outside [-1, 1]
            angle_list.append(
                math.acos(np.clip(q[0], -1.0, 1.0)) * 2 * 180 / math.pi
            )
angle_list = np.asarray(angle_list)
mask_survive = np.logical_and(
np.logical_and(
angle_list > self.config["angle_th"][0],
angle_list < self.config["angle_th"][1],
),
np.logical_and(
overlap_score > self.config["overlap_th"][0],
overlap_score < self.config["overlap_th"][1],
),
)
pair_list = pair_list[mask_survive]
        if len(pair_list) < 100:
            print(
                "warning: {} has only {} pairs after filtering".format(
                    seq, len(pair_list)
                )
            )
# sample pairs
shuffled_pair_list = np.random.permutation(pair_list)
sample_target = min(self.config["pairs_per_seq"], len(shuffled_pair_list))
sample_number = 0
info = {
"dR": [],
"dt": [],
"K1": [],
"K2": [],
"img_path1": [],
"img_path2": [],
"fea_path1": [],
"fea_path2": [],
"size1": [],
"size2": [],
"corr": [],
"incorr1": [],
"incorr2": [],
"pair_num": [],
}
for cur_pair in shuffled_pair_list:
pair_index1, pair_index2 = cur_pair[0], cur_pair[1]
geo1, geo2 = geom_dict[pair_index1], geom_dict[pair_index2]
dR = np.dot(geo2["R"], geo1["R"].T)
t1, t2 = geo1["T"].reshape([3, 1]), geo2["T"].reshape([3, 1])
dt = t2 - np.dot(dR, t1)
K1, K2 = geo1["K"], geo2["K"]
size1, size2 = geo1["size"], geo2["size"]
            basename1 = basename_list[pair_index1]
            basename2 = basename_list[pair_index2]
            img_path1 = os.path.join(seq, "undist_images", basename1 + ".jpg")
            img_path2 = os.path.join(seq, "undist_images", basename2 + ".jpg")
            fea_suffix = "_{}_{}.hdf5".format(
                self.config["extractor"]["name"],
                self.config["extractor"]["num_kpt"],
            )
            fea_path1 = os.path.join(seq, basename1 + ".jpg" + fea_suffix)
            fea_path2 = os.path.join(seq, basename2 + ".jpg" + fea_suffix)
with h5py.File(
os.path.join(self.config["feature_dump_dir"], fea_path1), "r"
) as fea1, h5py.File(
os.path.join(self.config["feature_dump_dir"], fea_path2), "r"
) as fea2:
desc1, desc2 = fea1["descriptors"][()], fea2["descriptors"][()]
kpt1, kpt2 = fea1["keypoints"][()], fea2["keypoints"][()]
                depth_dir = os.path.join(
                    self.config["rawdata_dir"], "data", seq, "depths"
                )
                depth1 = self.load_depth(os.path.join(depth_dir, basename1 + ".pfm"))
                depth2 = self.load_depth(os.path.join(depth_dir, basename2 + ".pfm"))
corr_index, incorr_index1, incorr_index2 = data_utils.make_corr(
kpt1[:, :2],
kpt2[:, :2],
desc1,
desc2,
depth1,
depth2,
K1,
K2,
dR,
dt,
size1,
size2,
self.config["corr_th"],
self.config["incorr_th"],
self.config["check_desc"],
)
if (
len(corr_index) > self.config["min_corr"]
and len(incorr_index1) > self.config["min_incorr"]
and len(incorr_index2) > self.config["min_incorr"]
):
info["corr"].append(corr_index), info["incorr1"].append(
incorr_index1
), info["incorr2"].append(incorr_index2)
info["dR"].append(dR), info["dt"].append(dt), info["K1"].append(
K1
), info["K2"].append(K2), info["img_path1"].append(img_path1), info[
"img_path2"
].append(
img_path2
)
info["fea_path1"].append(fea_path1), info["fea_path2"].append(
fea_path2
), info["size1"].append(size1), info["size2"].append(size2)
sample_number += 1
if sample_number == sample_target:
break
info["pair_num"] = sample_number
# dump info
self.dump_info(seq, info)
def collect_meta(self):
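        """Aggregate per-sequence pair counts into train/valid pair_num.txt
        files, prefixed with a 'total' row."""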
print("collecting meta info...")
dump_path, seq_list = [], []
if self.config["dump_train"]:
dump_path.append(os.path.join(self.config["dataset_dump_dir"], "train"))
seq_list.append(self.train_list)
if self.config["dump_valid"]:
dump_path.append(os.path.join(self.config["dataset_dump_dir"], "valid"))
seq_list.append(self.valid_list)
for pth, seqs in zip(dump_path, seq_list):
if not os.path.exists(pth):
os.mkdir(pth)
pair_num_list, total_pair = [], 0
            for seq in seqs:
pair_num = np.loadtxt(
os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt"),
dtype=int,
)
pair_num_list.append(str(pair_num))
total_pair += pair_num
pair_num_list = np.stack(
[np.asarray(seqs, dtype=str), np.asarray(pair_num_list, dtype=str)],
axis=1,
)
pair_num_list = np.concatenate(
[np.asarray([["total", str(total_pair)]]), pair_num_list], axis=0
)
np.savetxt(os.path.join(pth, "pair_num.txt"), pair_num_list, fmt="%s")
def format_dump_data(self):
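        """Dump all sequences with a multiprocessing pool; the sequence list
        is processed in chunks of num_process so trange reports progress."""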
print("Formatting data...")
iteration_num = len(self.seq_list) // self.config["num_process"]
if len(self.seq_list) % self.config["num_process"] != 0:
iteration_num += 1
pool = Pool(self.config["num_process"])
for index in trange(iteration_num):
indices = range(
index * self.config["num_process"],
min((index + 1) * self.config["num_process"], len(self.seq_list)),
)
pool.map(self.format_seq, indices)
pool.close()
pool.join()
self.collect_meta()
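

# Usage sketch (illustrative only): gl3d_train is driven by a config dict whose
# keys are the ones read above; the values below are placeholders, and the
# constructor signature is assumed to come from BaseDumper.
#
#   config = {
#       "rawdata_dir": "/path/to/GL3D",
#       "feature_dump_dir": "/path/to/features",
#       "dataset_dump_dir": "/path/to/dataset",
#       "extractor": {"name": "superpoint", "num_kpt": 1000},  # hypothetical
#       "exclude_seq": [],
#       "dump_train": True,
#       "dump_valid": True,
#       "pairs_per_seq": 100,
#       "angle_th": [5.0, 45.0],
#       "overlap_th": [0.1, 1.0],
#       "corr_th": 3,
#       "incorr_th": 10,
#       "check_desc": False,
#       "min_corr": 50,
#       "min_incorr": 10,
#       "num_process": 4,
#   }
#   dumper = gl3d_train(config)
#   dumper.get_seqs()
#   dumper.format_dump_folder()
#   # extract features into feature_dump_dir first (see BaseDumper), then:
#   dumper.format_dump_data()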