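"""Localize query images against the InLoc database.

Note: this module-level summary is an added description of the code below.
For each query, 2D-2D matches to the retrieved database cutouts are lifted to
3D using the scan depth and the camera pose is estimated with PnP + RANSAC.
"""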
import argparse
import pickle
from pathlib import Path

import cv2
import h5py
import numpy as np
import pycolmap
import torch
from scipy.io import loadmat
from tqdm import tqdm

from . import logger
from .utils.parsers import names_to_pair, parse_retrieval


def interpolate_scan(scan, kp):
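    """Interpolate 3D scan coordinates at 2D keypoint locations.

    `scan` is an (H, W, 3) array of XYZ coordinates (NaN where depth is
    missing) and `kp` an (N, 2) array of pixel coordinates. Returns the
    interpolated (N, 3) points and a boolean mask of keypoints with valid
    depth.
    """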
    h, w, c = scan.shape
    kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1
    assert np.all(kp > -1) and np.all(kp < 1)
    scan = torch.from_numpy(scan).permute(2, 0, 1)[None]
    kp = torch.from_numpy(kp)[None, None]
    grid_sample = torch.nn.functional.grid_sample

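    # Bilinear interpolation first, then fall back to nearest-neighbor where
    # the bilinear result is NaN (next to missing depth) to keep more points.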
    interp_lin = grid_sample(scan, kp, align_corners=True, mode="bilinear")[
        0, :, 0
    ]
    interp_nn = torch.nn.functional.grid_sample(
        scan, kp, align_corners=True, mode="nearest"
    )[0, :, 0]
    interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
    valid = ~torch.any(torch.isnan(interp), 0)

    kp3d = interp.T.numpy()
    valid = valid.numpy()
    return kp3d, valid


def get_scan_pose(dataset_dir, rpath):
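    """Read the alignment of a database scan to the global reference frame.

    Returns the 4x4 transformation stored on lines 8-11 of the corresponding
    alignment file (the pose after GICP refinement), which maps scan
    coordinates to world coordinates.
    """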
    split_image_rpath = rpath.split("/")
    floor_name = split_image_rpath[-3]
    scan_id = split_image_rpath[-2]
    image_name = split_image_rpath[-1]
    building_name = image_name[:3]

    path = Path(
        dataset_dir,
        "database/alignments",
        floor_name,
        f"transformations/{building_name}_trans_{scan_id}.txt",
    )
    with open(path) as f:
        raw_lines = f.readlines()

    P_after_GICP = np.array(
        [
            np.fromstring(raw_lines[7], sep=" "),
            np.fromstring(raw_lines[8], sep=" "),
            np.fromstring(raw_lines[9], sep=" "),
            np.fromstring(raw_lines[10], sep=" "),
        ]
    )

    return P_after_GICP


def pose_from_cluster(
    dataset_dir, q, retrieved, feature_file, match_file, skip=None
):
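    """Estimate the pose of query `q` from its retrieved database images.

    2D-2D matches against each retrieved cutout are lifted to 3D via the scan
    depth, transformed to the world frame, and passed to PnP + RANSAC. The
    hard-coded query intrinsics below correspond to images 4032 pixels wide
    with a 28 mm equivalent focal length (36 mm-wide reference sensor), which
    matches the InLoc query photos.
    """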
    height, width = cv2.imread(str(dataset_dir / q)).shape[:2]
    cx = 0.5 * width
    cy = 0.5 * height
    focal_length = 4032.0 * 28.0 / 36.0

    all_mkpq = []
    all_mkpr = []
    all_mkp3d = []
    all_indices = []
    kpq = feature_file[q]["keypoints"].__array__()
    num_matches = 0

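    # Accumulate 2D-3D correspondences over all retrieved database images.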
    for i, r in enumerate(retrieved):
        kpr = feature_file[r]["keypoints"].__array__()
        pair = names_to_pair(q, r)
        m = match_file[pair]["matches0"].__array__()
        v = m > -1

        if skip and (np.count_nonzero(v) < skip):
            continue

        mkpq, mkpr = kpq[v], kpr[m[v]]
        num_matches += len(mkpq)

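        # Lift the matched database keypoints to 3D using the scan cutout,
        # then transform them from the scan frame to the world frame.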
        scan_r = loadmat(Path(dataset_dir, r + ".mat"))["XYZcut"]
        mkp3d, valid = interpolate_scan(scan_r, mkpr)
        Tr = get_scan_pose(dataset_dir, r)
        mkp3d = (Tr[:3, :3] @ mkp3d.T + Tr[:3, -1:]).T

        all_mkpq.append(mkpq[valid])
        all_mkpr.append(mkpr[valid])
        all_mkp3d.append(mkp3d[valid])
        all_indices.append(np.full(np.count_nonzero(valid), i))

    all_mkpq = np.concatenate(all_mkpq, 0)
    all_mkpr = np.concatenate(all_mkpr, 0)
    all_mkp3d = np.concatenate(all_mkp3d, 0)
    all_indices = np.concatenate(all_indices, 0)

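    # Estimate the absolute query pose with PnP + RANSAC via pycolmap;
    # 48.0 is the reprojection-error threshold in pixels.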
    cfg = {
        "model": "SIMPLE_PINHOLE",
        "width": width,
        "height": height,
        "params": [focal_length, cx, cy],
    }
    ret = pycolmap.absolute_pose_estimation(all_mkpq, all_mkp3d, cfg, 48.00)
    ret["cfg"] = cfg
    return ret, all_mkpq, all_mkpr, all_mkp3d, all_indices, num_matches


def main(dataset_dir, retrieval, features, matches, results, skip_matches=None):
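    """Localize all queries listed in the retrieval file and write the results.

    Each line of the results file is the query image name followed by the
    estimated quaternion and translation (COLMAP convention); per-query
    localization details are pickled to `<results>_logs.pkl`.
    """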
    assert retrieval.exists(), retrieval
    assert features.exists(), features
    assert matches.exists(), matches

    retrieval_dict = parse_retrieval(retrieval)
    queries = list(retrieval_dict.keys())

    feature_file = h5py.File(features, "r", libver="latest")
    match_file = h5py.File(matches, "r", libver="latest")

    poses = {}
    logs = {
        "features": features,
        "matches": matches,
        "retrieval": retrieval,
        "loc": {},
    }
    logger.info("Starting localization...")
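    # Localize each query against its retrieved database images.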
    for q in tqdm(queries):
        db = retrieval_dict[q]
        ret, mkpq, mkpr, mkp3d, indices, num_matches = pose_from_cluster(
            dataset_dir, q, db, feature_file, match_file, skip_matches
        )

        poses[q] = (ret["qvec"], ret["tvec"])
        logs["loc"][q] = {
            "db": db,
            "PnP_ret": ret,
            "keypoints_query": mkpq,
            "keypoints_db": mkpr,
            "3d_points": mkp3d,
            "indices_db": indices,
            "num_matches": num_matches,
        }

logger.info(f"Writing poses to {results}...") |
|
with open(results, "w") as f: |
|
for q in queries: |
|
qvec, tvec = poses[q] |
|
qvec = " ".join(map(str, qvec)) |
|
tvec = " ".join(map(str, tvec)) |
|
name = q.split("/")[-1] |
|
f.write(f"{name} {qvec} {tvec}\n") |
|
|
|
logs_path = f"{results}_logs.pkl" |
|
logger.info(f"Writing logs to {logs_path}...") |
|
with open(logs_path, "wb") as f: |
|
pickle.dump(logs, f) |
|
logger.info("Done!") |
|
|
|
|
|
if __name__ == "__main__": |
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_dir", type=Path, required=True)
    parser.add_argument("--retrieval", type=Path, required=True)
    parser.add_argument("--features", type=Path, required=True)
    parser.add_argument("--matches", type=Path, required=True)
    parser.add_argument("--results", type=Path, required=True)
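    # Optional: skip retrieved images with fewer than this many matches.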
parser.add_argument("--skip_matches", type=int) |
|
args = parser.parse_args() |
|
main(**args.__dict__) |
|
|