import argparse
import pickle
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import pycolmap
from tqdm import tqdm

from . import logger
from .utils.io import get_keypoints, get_matches
from .utils.parsers import parse_image_lists, parse_retrieval


def do_covisibility_clustering(
    frame_ids: List[int], reconstruction: pycolmap.Reconstruction
):
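    """Group the retrieved database frames into connected components.

    Two frames belong to the same cluster if they are linked by a chain of
    covisible frames, i.e. frames observing at least one common 3D point.
    The clusters are returned sorted by decreasing size.
    """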
    clusters = []
    visited = set()
    for frame_id in frame_ids:
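        # Skip frames that already belong to a cluster.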
        if frame_id in visited:
            continue

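        # Start a new cluster and grow it by traversing covisible frames.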
        clusters.append([])
        queue = {frame_id}
        while len(queue):
            exploration_frame = queue.pop()

            if exploration_frame in visited:
                continue
            visited.add(exploration_frame)
            clusters[-1].append(exploration_frame)

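            # Frames observing at least one 3D point seen by the current frame.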
            observed = reconstruction.images[exploration_frame].points2D
            connected_frames = {
                obs.image_id
                for p2D in observed
                if p2D.has_point3D()
                for obs in reconstruction.points3D[
                    p2D.point3D_id
                ].track.elements
            }
            connected_frames &= set(frame_ids)
            connected_frames -= visited
            queue |= connected_frames

    clusters = sorted(clusters, key=len, reverse=True)
    return clusters


class QueryLocalizer:
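    """Absolute pose estimation of query cameras against a reconstruction."""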
    def __init__(self, reconstruction, config=None):
        self.reconstruction = reconstruction
        self.config = config or {}

    def localize(self, points2D_all, points2D_idxs, points3D_id, query_camera):
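        """Estimate the query pose from the selected 2D-3D correspondences.

        Runs RANSAC-based PnP with refinement via
        pycolmap.absolute_pose_estimation and returns its result,
        or None if no pose could be estimated.
        """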
        points2D = points2D_all[points2D_idxs]
        points3D = [self.reconstruction.points3D[j].xyz for j in points3D_id]
        ret = pycolmap.absolute_pose_estimation(
            points2D,
            points3D,
            query_camera,
            estimation_options=self.config.get("estimation", {}),
            refinement_options=self.config.get("refinement", {}),
        )
        return ret


def pose_from_cluster(
    localizer: QueryLocalizer,
    qname: str,
    query_camera: pycolmap.Camera,
    db_ids: List[int],
    features_path: Path,
    matches_path: Path,
    **kwargs,
):
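    """Localize a query image against a set of database images.

    The 2D-2D matches between the query and each database image are lifted
    to 2D-3D correspondences through the triangulated points of the
    reconstruction, and the aggregated correspondences are passed to the
    localizer. Returns the pose estimation result and a log dictionary.
    """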
    kpq = get_keypoints(features_path, qname)
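    # COLMAP convention: the center of the top-left pixel is (0.5, 0.5).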
    kpq += 0.5

    kp_idx_to_3D = defaultdict(list)
    kp_idx_to_3D_to_db = defaultdict(lambda: defaultdict(list))
    num_matches = 0
    for i, db_id in enumerate(db_ids):
        image = localizer.reconstruction.images[db_id]
        if image.num_points3D == 0:
            logger.debug(f"No 3D points found for {image.name}.")
            continue
        points3D_ids = np.array(
            [p.point3D_id if p.has_point3D() else -1 for p in image.points2D]
        )

        matches, _ = get_matches(matches_path, qname, image.name)
        matches = matches[points3D_ids[matches[:, 1]] != -1]
        num_matches += len(matches)
        for idx, m in matches:
            id_3D = points3D_ids[m]
            kp_idx_to_3D_to_db[idx][id_3D].append(i)
            if id_3D not in kp_idx_to_3D[idx]:
                kp_idx_to_3D[idx].append(id_3D)

    idxs = list(kp_idx_to_3D.keys())
    mkp_idxs = [i for i in idxs for _ in kp_idx_to_3D[i]]
    mp3d_ids = [j for i in idxs for j in kp_idx_to_3D[i]]
    ret = localizer.localize(kpq, mkp_idxs, mp3d_ids, query_camera, **kwargs)
    if ret is not None:
        ret["camera"] = query_camera

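    # Record which database images contributed each 2D-3D match (log only).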
    mkp_to_3D_to_db = [
        (j, kp_idx_to_3D_to_db[i][j]) for i in idxs for j in kp_idx_to_3D[i]
    ]
    log = {
        "db": db_ids,
        "PnP_ret": ret,
        "keypoints_query": kpq[mkp_idxs],
        "points3D_ids": mp3d_ids,
        "points3D_xyz": None,
        "num_matches": num_matches,
        "keypoint_index_to_db": (mkp_idxs, mkp_to_3D_to_db),
    }
    return ret, log


def main(
    reference_sfm: Union[Path, pycolmap.Reconstruction],
    queries: Path,
    retrieval: Path,
    features: Path,
    matches: Path,
    results: Path,
    ransac_thresh: float = 12.0,
    covisibility_clustering: bool = False,
    prepend_camera_name: bool = False,
    config: Dict = None,
):
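    """Localize all query images against the reference SfM model.

    Writes one line per localized query to `results`, in the format
    `name qw qx qy qz tx ty tz`, and a pickled log file next to it.

    Example invocation (paths are illustrative; assumes this module is
    importable as hloc.localize_sfm):
        python -m hloc.localize_sfm
            --reference_sfm outputs/sfm
            --queries queries_with_intrinsics.txt
            --retrieval pairs-retrieval.txt
            --features features.h5
            --matches matches.h5
            --results poses.txt
    """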
    assert retrieval.exists(), retrieval
    assert features.exists(), features
    assert matches.exists(), matches

    queries = parse_image_lists(queries, with_intrinsics=True)
    retrieval_dict = parse_retrieval(retrieval)

    logger.info("Reading the 3D model...")
    if not isinstance(reference_sfm, pycolmap.Reconstruction):
        reference_sfm = pycolmap.Reconstruction(reference_sfm)
    db_name_to_id = {img.name: i for i, img in reference_sfm.images.items()}

    config = {
        "estimation": {"ransac": {"max_error": ransac_thresh}},
        **(config or {}),
    }
    localizer = QueryLocalizer(reference_sfm, config)

    cam_from_world = {}
    logs = {
        "features": features,
        "matches": matches,
        "retrieval": retrieval,
        "loc": {},
    }
    logger.info("Starting localization...")
    for qname, qcam in tqdm(queries):
        if qname not in retrieval_dict:
            logger.warning(
                f"No images retrieved for query image {qname}. Skipping..."
            )
            continue
        db_names = retrieval_dict[qname]
        db_ids = []
        for n in db_names:
            if n not in db_name_to_id:
                logger.warning(f"Image {n} was retrieved but not in database")
                continue
            db_ids.append(db_name_to_id[n])

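        # Either localize with each covisibility cluster separately and keep
        # the pose with the most inliers, or use all retrieved frames at once.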
        if covisibility_clustering:
            clusters = do_covisibility_clustering(db_ids, reference_sfm)
            best_inliers = 0
            best_cluster = None
            logs_clusters = []
            for i, cluster_ids in enumerate(clusters):
                ret, log = pose_from_cluster(
                    localizer, qname, qcam, cluster_ids, features, matches
                )
                if ret is not None and ret["num_inliers"] > best_inliers:
                    best_cluster = i
                    best_inliers = ret["num_inliers"]
                logs_clusters.append(log)
            if best_cluster is not None:
                ret = logs_clusters[best_cluster]["PnP_ret"]
                cam_from_world[qname] = ret["cam_from_world"]
            logs["loc"][qname] = {
                "db": db_ids,
                "best_cluster": best_cluster,
                "log_clusters": logs_clusters,
                "covisibility_clustering": covisibility_clustering,
            }
        else:
            ret, log = pose_from_cluster(
                localizer, qname, qcam, db_ids, features, matches
            )
            if ret is not None:
                cam_from_world[qname] = ret["cam_from_world"]
            else:
                closest = reference_sfm.images[db_ids[0]]
                cam_from_world[qname] = closest.cam_from_world
            log["covisibility_clustering"] = covisibility_clustering
            logs["loc"][qname] = log

    logger.info(f"Localized {len(cam_from_world)} / {len(queries)} images.")
    logger.info(f"Writing poses to {results}...")
    with open(results, "w") as f:
        for query, t in cam_from_world.items():
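            # pycolmap's quat is (x, y, z, w); reorder to COLMAP's (w, x, y, z).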
            qvec = " ".join(map(str, t.rotation.quat[[3, 0, 1, 2]]))
            tvec = " ".join(map(str, t.translation))
            name = query.split("/")[-1]
            if prepend_camera_name:
                name = query.split("/")[-2] + "/" + name
            f.write(f"{name} {qvec} {tvec}\n")

    logs_path = f"{results}_logs.pkl"
    logger.info(f"Writing logs to {logs_path}...")

    with open(logs_path, "wb") as f:
        pickle.dump(logs, f)
    logger.info("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--reference_sfm", type=Path, required=True)
    parser.add_argument("--queries", type=Path, required=True)
    parser.add_argument("--features", type=Path, required=True)
    parser.add_argument("--matches", type=Path, required=True)
    parser.add_argument("--retrieval", type=Path, required=True)
    parser.add_argument("--results", type=Path, required=True)
    parser.add_argument("--ransac_thresh", type=float, default=12.0)
    parser.add_argument("--covisibility_clustering", action="store_true")
    parser.add_argument("--prepend_camera_name", action="store_true")
    args = parser.parse_args()
    main(**args.__dict__)