Spaces:
Running
Running
import argparse | |
import pickle | |
from collections import defaultdict | |
from pathlib import Path | |
from typing import Dict, List, Union | |
import numpy as np | |
import pycolmap | |
from tqdm import tqdm | |
from . import logger | |
from .utils.io import get_keypoints, get_matches | |
from .utils.parsers import parse_image_lists, parse_retrieval | |
def do_covisibility_clustering(
    frame_ids: List[int], reconstruction: "pycolmap.Reconstruction"
) -> List[List[int]]:
    """Group images into connected components of the covisibility graph.

    Two images are connected when they observe at least one common 3D point
    of ``reconstruction``. The graph is restricted to ``frame_ids``: a path
    that passes through an image outside ``frame_ids`` does not join two
    clusters.

    Args:
        frame_ids: Image ids (keys of ``reconstruction.images``) to cluster.
            Duplicate ids are ignored.
        reconstruction: Model providing each image's ``points2D`` and each 3D
            point's ``track.elements``.

    Returns:
        Clusters (lists of image ids) sorted by decreasing size. Clusters of
        equal size keep the order in which their first member appears in
        ``frame_ids``.
    """
    # Built once; previously this set was rebuilt for every explored frame.
    candidates = set(frame_ids)
    clusters = []
    visited = set()
    for frame_id in frame_ids:
        if frame_id in visited:  # already assigned to a cluster
            continue
        # Flood-fill a new connected component starting from frame_id.
        cluster = []
        queue = {frame_id}
        while queue:
            current = queue.pop()
            if current in visited:
                continue
            visited.add(current)
            cluster.append(current)
            observed = reconstruction.images[current].points2D
            # Every image that sees any 3D point observed by `current`.
            connected = {
                obs.image_id
                for p2D in observed
                if p2D.has_point3D()
                for obs in reconstruction.points3D[p2D.point3D_id].track.elements
            }
            queue |= (connected & candidates) - visited
        clusters.append(cluster)
    return sorted(clusters, key=len, reverse=True)
class QueryLocalizer:
    """Estimate a query camera pose from 2D-3D correspondences with pycolmap.

    Args:
        reconstruction: Reference model; ``reconstruction.points3D[id].xyz``
            provides the 3D side of each correspondence.
        config: Optional dict. Its ``"estimation"`` and ``"refinement"``
            entries are passed to ``pycolmap.absolute_pose_estimation`` as
            ``estimation_options`` and ``refinement_options``. Missing entries
            default to ``{}``.
    """

    def __init__(self, reconstruction, config=None):
        self.reconstruction = reconstruction
        # Falsy configs (None or {}) become a fresh empty dict.
        self.config = config or {}

    def localize(
        self, points2D_all, points2D_idxs, points3D_id, query_camera, **kwargs
    ):
        """Run absolute pose estimation for one query image.

        Args:
            points2D_all: All query keypoints, indexed with ``points2D_idxs``
                (assumes a NumPy array so list indexing works — TODO confirm).
            points2D_idxs: Indices into ``points2D_all``, one per correspondence.
            points3D_id: 3D point ids of ``self.reconstruction``, aligned with
                ``points2D_idxs``.
            query_camera: Camera of the query image.
            **kwargs: Extra keyword arguments for
                ``pycolmap.absolute_pose_estimation``. They override the
                configured ``estimation_options``/``refinement_options``.
                Accepting them keeps callers that forward their own
                ``**kwargs`` (such as ``pose_from_cluster``) from failing with
                ``TypeError``.

        Returns:
            The result of ``pycolmap.absolute_pose_estimation`` unchanged.
        """
        points2D = points2D_all[points2D_idxs]
        points3D = [self.reconstruction.points3D[j].xyz for j in points3D_id]
        options = {
            "estimation_options": self.config.get("estimation", {}),
            "refinement_options": self.config.get("refinement", {}),
            # Per-call overrides or extra arguments accepted by the installed
            # pycolmap version.
            **kwargs,
        }
        return pycolmap.absolute_pose_estimation(
            points2D, points3D, query_camera, **options
        )
def pose_from_cluster(
    localizer: QueryLocalizer,
    qname: str,
    query_camera: pycolmap.Camera,
    db_ids: List[int],
    features_path: Path,
    matches_path: Path,
    **kwargs,
):
    """Estimate the pose of query ``qname`` from its matches to database images.

    Each query keypoint matched to a database keypoint that has a 3D point in
    ``localizer.reconstruction`` yields a 2D-3D correspondence. A query
    keypoint that reaches the same 3D point through several database images
    contributes that correspondence only once.

    Args:
        localizer: Provides the reference reconstruction and the pose solver.
        qname: Query image name used to read keypoints and matches.
        query_camera: Camera of the query image, passed to the solver.
        db_ids: Database image ids to collect correspondences from.
        features_path: File read by ``get_keypoints``.
        matches_path: File read by ``get_matches``.
        **kwargs: Forwarded to ``localizer.localize``.

    Returns:
        ``(ret, log)``. ``ret`` is the solver output; when it is not ``None``,
        a ``"camera"`` entry is added to it. ``log`` is a dict of diagnostics.
    """
    query_keypoints = get_keypoints(features_path, qname)
    query_keypoints += 0.5  # COLMAP coordinates

    # query keypoint index -> distinct 3D point ids, in first-seen order
    points_per_kp = defaultdict(list)
    # query keypoint index -> 3D point id -> positions in db_ids that produced it
    dbs_per_kp_point = defaultdict(lambda: defaultdict(list))
    match_count = 0

    for db_pos, db_id in enumerate(db_ids):
        db_image = localizer.reconstruction.images[db_id]
        if db_image.num_points3D == 0:
            logger.debug(f"No 3D points found for {db_image.name}.")
            continue
        # -1 marks database keypoints without a 3D point.
        db_point_ids = np.array(
            [kp.point3D_id if kp.has_point3D() else -1 for kp in db_image.points2D]
        )
        pairs, _ = get_matches(matches_path, qname, db_image.name)
        keep = db_point_ids[pairs[:, 1]] != -1
        pairs = pairs[keep]
        match_count += len(pairs)
        for q_kp, db_kp in pairs:
            point_id = db_point_ids[db_kp]
            dbs_per_kp_point[q_kp][point_id].append(db_pos)
            known = points_per_kp[q_kp]
            if point_id not in known:  # avoid duplicate observations
                known.append(point_id)

    # Flatten into aligned (query keypoint, 3D point) lists.
    kp_indices, point_ids, point_to_dbs = [], [], []
    for q_kp, pts in points_per_kp.items():
        for point_id in pts:
            kp_indices.append(q_kp)
            point_ids.append(point_id)
            point_to_dbs.append((point_id, dbs_per_kp_point[q_kp][point_id]))

    ret = localizer.localize(
        query_keypoints, kp_indices, point_ids, query_camera, **kwargs
    )
    if ret is not None:
        ret["camera"] = query_camera

    # mostly for logging and post-processing
    log = {
        "db": db_ids,
        "PnP_ret": ret,
        "keypoints_query": query_keypoints[kp_indices],
        "points3D_ids": point_ids,
        "points3D_xyz": None,  # we don't log xyz anymore because of file size
        "num_matches": match_count,
        "keypoint_index_to_db": (kp_indices, point_to_dbs),
    }
    return ret, log
def _write_poses(results, cam_from_world, prepend_camera_name):
    """Write one ``<name> <quaternion> <translation>`` line per query to ``results``."""
    with open(results, "w") as f:
        for query, pose in cam_from_world.items():
            # Indices [3, 0, 1, 2] move the last quaternion component to the
            # front; presumably pycolmap stores (x, y, z, w), so this writes
            # (w, x, y, z) — confirm against the pycolmap version in use.
            qvec = " ".join(map(str, pose.rotation.quat[[3, 0, 1, 2]]))
            tvec = " ".join(map(str, pose.translation))
            name = query.split("/")[-1]
            if prepend_camera_name:
                # Assumes query names contain at least one "/" here.
                name = query.split("/")[-2] + "/" + name
            f.write(f"{name} {qvec} {tvec}\n")


def main(
    reference_sfm: Union[Path, pycolmap.Reconstruction],
    queries: Path,
    retrieval: Path,
    features: Path,
    matches: Path,
    results: Path,
    ransac_thresh: float = 12,
    covisibility_clustering: bool = False,
    prepend_camera_name: bool = False,
    config: Dict = None,
):
    """Localize query images against a reference SfM model and write their poses.

    For each query, the retrieved database images are used to gather 2D-3D
    correspondences from the ``features`` and ``matches`` files, and an
    absolute pose is estimated with ``QueryLocalizer``.

    Args:
        reference_sfm: A ``pycolmap.Reconstruction`` or a path to load one from.
        queries: Query list with intrinsics, parsed by ``parse_image_lists``.
        retrieval: Retrieval pairs file, parsed by ``parse_retrieval``.
        features: File read by ``get_keypoints``.
        matches: File read by ``get_matches``.
        results: Output text file with one pose line per query. Logs are
            pickled to ``f"{results}_logs.pkl"``.
        ransac_thresh: Stored as ``config["estimation"]["ransac"]["max_error"]``
            unless ``config`` supplies its own ``"estimation"`` entry.
        covisibility_clustering: If True, split the retrieved images into
            covisibility clusters and keep the pose of the cluster with the
            most inliers. If every cluster fails, no pose is written.
            Otherwise all retrieved images are used at once, and a failed
            estimate falls back to a retrieved database image's pose.
        prepend_camera_name: Prefix each written name with the query's parent
            directory name.
        config: Extra localizer options, shallow-merged over the default.

    Raises:
        AssertionError: If ``retrieval``, ``features`` or ``matches`` does not exist.
    """
    assert retrieval.exists(), retrieval
    assert features.exists(), features
    assert matches.exists(), matches

    queries = parse_image_lists(queries, with_intrinsics=True)
    retrieval_dict = parse_retrieval(retrieval)

    logger.info("Reading the 3D model...")
    if not isinstance(reference_sfm, pycolmap.Reconstruction):
        reference_sfm = pycolmap.Reconstruction(reference_sfm)
    db_name_to_id = {img.name: i for i, img in reference_sfm.images.items()}

    # Shallow merge: a user-supplied "estimation" entry replaces the default
    # RANSAC threshold entirely.
    config = {"estimation": {"ransac": {"max_error": ransac_thresh}}, **(config or {})}
    localizer = QueryLocalizer(reference_sfm, config)

    cam_from_world = {}
    num_localized = 0  # successful pose estimates; fallback poses excluded
    logs = {
        "features": features,
        "matches": matches,
        "retrieval": retrieval,
        "loc": {},
    }
    logger.info("Starting localization...")
    for qname, qcam in tqdm(queries):
        if qname not in retrieval_dict:
            logger.warning(f"No images retrieved for query image {qname}. Skipping...")
            continue
        db_ids = []
        for n in retrieval_dict[qname]:
            if n not in db_name_to_id:
                logger.warning(f"Image {n} was retrieved but not in database")
                continue
            db_ids.append(db_name_to_id[n])
        if not db_ids:
            # Nothing to match against; previously the fallback below crashed
            # with IndexError on db_ids[0].
            logger.warning(
                f"None of the images retrieved for query image {qname} "
                "are in the database. Skipping..."
            )
            continue

        if covisibility_clustering:
            clusters = do_covisibility_clustering(db_ids, reference_sfm)
            best_inliers = 0
            best_cluster = None
            logs_clusters = []
            for i, cluster_ids in enumerate(clusters):
                ret, log = pose_from_cluster(
                    localizer, qname, qcam, cluster_ids, features, matches
                )
                if ret is not None and ret["num_inliers"] > best_inliers:
                    best_cluster = i
                    best_inliers = ret["num_inliers"]
                logs_clusters.append(log)
            if best_cluster is not None:
                ret = logs_clusters[best_cluster]["PnP_ret"]
                cam_from_world[qname] = ret["cam_from_world"]
                num_localized += 1
            logs["loc"][qname] = {
                "db": db_ids,
                "best_cluster": best_cluster,
                "log_clusters": logs_clusters,
                "covisibility_clustering": covisibility_clustering,
            }
        else:
            ret, log = pose_from_cluster(
                localizer, qname, qcam, db_ids, features, matches
            )
            if ret is not None:
                cam_from_world[qname] = ret["cam_from_world"]
                num_localized += 1
            else:
                # Fall back to the pose of the first retrieved database image
                # (presumably the best-ranked one).
                closest = reference_sfm.images[db_ids[0]]
                cam_from_world[qname] = closest.cam_from_world
            log["covisibility_clustering"] = covisibility_clustering
            logs["loc"][qname] = log

    logger.info(f"Localized {num_localized} / {len(queries)} images.")
    num_fallback = len(cam_from_world) - num_localized
    if num_fallback:
        logger.info(
            f"Wrote {num_fallback} fallback poses copied from retrieved database images."
        )
    logger.info(f"Writing poses to {results}...")
    _write_poses(results, cam_from_world, prepend_camera_name)

    logs_path = f"{results}_logs.pkl"
    logger.info(f"Writing logs to {logs_path}...")
    # TODO: Resolve pickling issue with pycolmap objects.
    with open(logs_path, "wb") as f:
        pickle.dump(logs, f)
    logger.info("Done!")
if __name__ == "__main__":
    # Command-line entry point: each flag maps one-to-one onto a keyword of main().
    cli = argparse.ArgumentParser()
    for flag in (
        "--reference_sfm",
        "--queries",
        "--features",
        "--matches",
        "--retrieval",
        "--results",
    ):
        cli.add_argument(flag, type=Path, required=True)
    cli.add_argument("--ransac_thresh", type=float, default=12.0)
    cli.add_argument("--covisibility_clustering", action="store_true")
    cli.add_argument("--prepend_camera_name", action="store_true")
    main(**vars(cli.parse_args()))