diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py index 9fb76eddd29347be56be162afc346b0ab9bb934a..1bd6188faa8ac8bfa647e6d5bcb3a9dfc07a2f30 100644 --- a/hloc/extractors/sfd2.py +++ b/hloc/extractors/sfd2.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,10 +6,9 @@ import torchvision.transforms as tvf from .. import logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.sfd2 import load_sfd2 +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.sfd2 import load_sfd2 class SFD2(BaseModel): @@ -26,8 +24,8 @@ class SFD2(BaseModel): self.norm_rgb = tvf.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) - model_fn = pram_path / "weights" / self.conf["model_name"] - self.net = load_sfd2(weight_path=model_fn).eval() + model_path = tp_path / "pram" / "weights" / self.conf["model_name"] + self.net = load_sfd2(weight_path=model_path).eval() logger.info("Load SFD2 model done.") diff --git a/hloc/matchers/eloftr.py b/hloc/matchers/eloftr.py index 2c1e6245eb720c5b3545f9e2f5d2a6a5a93cb95b..d22906de8bf7cc912745c21b950458829dee5d19 100644 --- a/hloc/matchers/eloftr.py +++ b/hloc/matchers/eloftr.py @@ -5,18 +5,22 @@ from pathlib import Path import torch -eloftr_path = Path(__file__).parent / "../../third_party/EfficientLoFTR" -sys.path.append(str(eloftr_path)) +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) -from src.loftr import LoFTR as ELoFTR_ -from src.loftr import full_default_cfg, opt_default_cfg, reparameter +from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_ +from EfficientLoFTR.src.loftr import ( + full_default_cfg, + opt_default_cfg, + reparameter, +) from hloc import logger from ..utils.base_model import BaseModel -class LoFTR(BaseModel): +class ELoFTR(BaseModel): default_conf = { "weights": "weights/eloftr_outdoor.ckpt", "match_threshold": 0.2, @@ -40,7 +44,7 @@ class LoFTR(BaseModel): _default_cfg["mp"] = True elif self.conf["precision"] == "fp16": _default_cfg["half"] = True - model_path = eloftr_path / self.conf["weights"] + model_path = tp_path / "EfficientLoFTR" / self.conf["weights"] cfg = _default_cfg cfg["match_coarse"]["thr"] = conf["match_threshold"] # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"] diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py index ca64980ef70c52672806476fdc65bb4d39479f10..05c3cb96b05410985ca97f89d8fe55a4d71be501 100644 --- a/hloc/matchers/imp.py +++ b/hloc/matchers/imp.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,10 +6,9 @@ import torch from .. import DEVICE, logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.gml import GML +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.gml import GML class IMP(BaseModel): @@ -33,7 +31,8 @@ class IMP(BaseModel): def _init(self, conf): self.conf = {**self.default_conf, **conf} - weight_path = pram_path / "weights" / self.conf["model_name"] + weight_path = tp_path / "pram" / "weights" / self.conf["model_name"] + # self.net = nets.gml(self.conf).eval().to(DEVICE) self.net = GML(self.conf).eval().to(DEVICE) self.net.load_state_dict( torch.load(weight_path, map_location="cpu")["model"], strict=True diff --git a/third_party/pram/.gitignore b/third_party/pram/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e76db3ee25df1858b0cec129d3e7c0eb84637c09 --- /dev/null +++ b/third_party/pram/.gitignore @@ -0,0 +1,13 @@ +.idea +__pycache__ +weights/12scenes* +weights/7scenes* +weights/aachen* +weights/cambridgelandmarks* +weights/imp_adagml.80.pth +landmarks +3D-models +log_* +*.log +.nfs* +Pangolin diff --git a/third_party/pram/LICENSE b/third_party/pram/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0bde2a83689b0ae97269181bc848fd581d23e828 --- /dev/null +++ b/third_party/pram/LICENSE @@ -0,0 +1,2 @@ +This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License. +To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/. diff --git a/third_party/pram/README.md b/third_party/pram/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b8ceb745c82fd44f1ef2c1808ab3993fb4d3890c --- /dev/null +++ b/third_party/pram/README.md @@ -0,0 +1,207 @@ +## PRAM: Place Recognition Anywhere Model for Efficient Visual Localization + +
+ +
+ +Humans localize themselves efficiently in known environments by first recognizing landmarks defined on certain objects +and their spatial relationships, and then verifying the location by aligning detailed structures of recognized objects +with those in the memory. Inspired by this, we propose the place recognition anywhere model (PRAM) to perform visual +localization as efficiently as humans do. PRAM consists of two main components - recognition and registration. In +detail, first of all, a self-supervised map-centric landmark definition strategy is adopted, making places in either +indoor or outdoor scenes act as unique landmarks. Then, sparse keypoints extracted from images, are utilized as the +input to a transformer-based deep neural network for landmark recognition; these keypoints enable PRAM to recognize +hundreds of landmarks with high time and memory efficiency. Keypoints along with recognized landmark labels are further +used for registration between query images and the 3D landmark map. Different from previous hierarchical methods, PRAM +discards global and local descriptors, and reduces over 90% storage. Since PRAM utilizes recognition and landmark-wise +verification to replace global reference search and exhaustive matching respectively, it runs 2.4 times faster than +prior state-of-the-art approaches. Moreover, PRAM opens new directions for visual localization including multi-modality +localization, map-centric feature learning, and hierarchical scene coordinate regression. + +* Full paper + PDF: [Place Recognition Anywhere Model for Efficient Visual Localization](https://arxiv.org/pdf/2404.07785.pdf). + +* Authors: *Fei Xue, Ignas Budvytis, Roberto Cipolla* + +* Website: [PRAM](https://feixue94.github.io/pram-project) for videos, slides, recent updates, and datasets. + +## Key Features + +### 1. Self-supervised landmark definition on 3D space + +- No need of segmentations on images +- No inconsistent semantic results from multi-view images +- No limitation to labels of only known objects +- Work in any places with known or unknown objects +- Landmark-wise 3D map sparsification + ++ +
+ +### 2. Efficient landmark-wise coarse and fine localization + +- Recognize landmarks as opposed to do global retrieval +- Local landmark-wise matching as opposed to exhaustive matching +- No global descriptors (e.g. NetVLAD) +- No reference images and their heavy repetative 2D keypoints and descriptors +- Automatic inlier/outlier idetification + ++ +
+ +### 4. Sparse recognition + +- Sparse SFD2 keypoints as tokens +- No uncertainties of points at boundaries +- Flexible to accept multi-modality inputs + +### 5. Relocalization and temporal localization + +- Per frame reclocalization from scratch +- Tracking previous frames for higher efficiency + +### 6. One model one dataset + +- All 7 subscenes in 7Scenes dataset share a model +- All 12 subscenes in 12Scenes dataset share a model +- All 5 subscenes in CambridgeLandmarks share a model + +### 7. Robust to long-term changes + ++ +
+ +## Open problems + +- Adaptive number landmarks determination +- Using SAM + open vocabulary to generate semantic map +- Multi-modality localization with other tokenized signals (e.g. text, language, GPS, Magonemeter) +- More effective solutions to 3D sparsification + +## Preparation + +1. Download the 7Scenes, 12Scenes, CambridgeLandmarks, and Aachen datasets (remove redundant depth images otherwise they + will be found in the sfm process) +2. Environments + +2.1 Create a virtual environment + +``` +conda env create -f environment.yml +(do not activate pram before pangolin is installed) +``` + +2.2 Compile Pangolin for the installed python + +``` +git clone --recursive https://github.com/stevenlovegrove/Pangolin.git +cd Pangolin +git checkout v0.8 + +# Install dependencies +./scripts/install_prerequisites.sh recommended + +# Compile with your python +cmake -DPython_EXECUTABLE=/your path to/anaconda3/envs/pram/bin/python3 -B build +cmake --build build -t pypangolin_pip_install + +conda activate pram +``` + +## Run the localization with online visualization + +1. Download the [3D-models](https://drive.google.com/drive/folders/1DUB073KxAjsc8lxhMpFuxPRf0ZBQS6NS?usp=drive_link), + pretrained [models](https://drive.google.com/drive/folders/1E2QvujCevqnyg_CM9FGAa0AxKkt4KbLD?usp=drive_link) , + and [landmarks](https://drive.google.com/drive/folders/1r9src9bz7k3WYGfaPmKJ9gqxuvdfxZU0?usp=sharing) +2. Put pretrained models in ```weights``` directory +3. Run the demo (e.g. 7Scenes) + +``` +python3 inference.py --config configs/config_train_7scenes_sfd2.yaml --rec_weight_path weights/7scenes_nc113_birch_segnetvit.199.pth --landmark_path /your path to/landmarks --online +``` + +## Train the recognition model (e.g. for 7Scenes) + +### 1. Do SfM with SFD2 including feature extraction (modify the dataset_dir, ref_sfm_dir, output_dir) + +``` +./sfm_scripts/reconstruct_7scenes.sh +``` + +This step will produce the SfM results together with the extracted keypoints + +### 2. Generate 3D landmarks + +``` +python3 -m recognition.recmap --dataset 7Scenes --dataset_dir /your path to/7Scenes --sfm_dir /sfm_path/7Scenes --save_dir /save_path/landmakrs +``` + +This step will generate 3D landmarks, create virtual reference frame, and sparsify the 3D points for each landmark for +all scenes in 7Scenes + +### 3. Train the sparse recognition model (one model one dataset) + +``` +python3 train.py --config configs/config_train_7scenes_sfd2.yaml +``` + +Remember to modify the paths in 'config_train_7scenes_sfd2.yaml' + +## Your own dataset + +1. Run colmap or hloc to obtain the SfM results +2. Do reconstruction with SFD2 keypoints with the sfm from step as refernece sfm +3. Do 3D landmark generation, VRF, map sparsification etc (Add DatasetName.yaml to configs/datasets) +4. Train the recognition model +5. Do evaluation + +## Previous works can be found here + +1. [Efficient large-scale localization by landmark recognition, CVPR 2022](https://github.com/feixue94/lbr) +2. [IMP: Iterative Matching and Pose Estimation with Adaptive Pooling, CVPR 2023](https://github.com/feixue94/imp-release) +3. [SFD2: Semantic-guided Feature Detection and Description, CVPR 2023](https://github.com/feixue94/sfd2) +4. [VRS-NeRF: Visual Relocalization with Sparse Neural Radiance Field, under review](https://github.com/feixue94/vrs-nerf) + +## BibTeX Citation + +If you use any ideas from the paper or code in this repo, please consider citing: + +``` + @article{xue2024pram, + author = {Fei Xue and Ignas Budvytis and Roberto Cipolla}, + title = {PRAM: Place Recognition Anywhere Model for Efficient Visual Localization}, + journal = {arXiv preprint arXiv:2404.07785}, + year = {2024} + } + +@inproceedings{xue2023sfd2, + author = {Fei Xue and Ignas Budvytis and Roberto Cipolla}, + title = {SFD2: Semantic-guided Feature Detection and Description}, + booktitle = {CVPR}, + year = {2023} +} + +@inproceedings{xue2022imp, + author = {Fei Xue and Ignas Budvytis and Roberto Cipolla}, + title = {IMP: Iterative Matching and Pose Estimation with Adaptive Pooling}, + booktitle = {CVPR}, + year = {2023} +} + +@inproceedings{xue2022efficient, + author = {Fei Xue and Ignas Budvytis and Daniel Olmeda Reino and Roberto Cipolla}, + title = {Efficient Large-scale Localization by Global Instance Recognition}, + booktitle = {CVPR}, + year = {2022} +} +``` + +## Acknowledgements + +Part of the code is from previous excellent works +including , [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) +and [hloc](https://github.com/cvg/Hierarchical-Localization). You can find more details from their released +repositories if you are interested in their works. \ No newline at end of file diff --git a/third_party/pram/assets/map_sparsification.gif b/third_party/pram/assets/map_sparsification.gif new file mode 100644 index 0000000000000000000000000000000000000000..63133a4b49805d0311aec8572fc10482f21d97f1 --- /dev/null +++ b/third_party/pram/assets/map_sparsification.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd7bbe3b0bad7c6ae330eaa702b2839533a6f27ad5a0b104c4a37597c0c37aad +size 493481 diff --git a/third_party/pram/assets/multi_recognition.png b/third_party/pram/assets/multi_recognition.png new file mode 100644 index 0000000000000000000000000000000000000000..7b12f484fb23daccd0bc83509db99fdf200fe79b --- /dev/null +++ b/third_party/pram/assets/multi_recognition.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c84e81cb990adedc25ef612b31d1ec53f7cb9f2168ef2246f2f03ca479cca9cf +size 2460085 diff --git a/third_party/pram/assets/overview.png b/third_party/pram/assets/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..e5cc9c60f72a7590dace5db4e29eb848f0676b40 --- /dev/null +++ b/third_party/pram/assets/overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:466b1f2b6a38cb956a389c1fc69c213c1655579c0c944174b6e95e247209eedc +size 662283 diff --git a/third_party/pram/assets/pipeline1.png b/third_party/pram/assets/pipeline1.png new file mode 100644 index 0000000000000000000000000000000000000000..780d9639033cb33aa765b571b486be9b96a44b9b --- /dev/null +++ b/third_party/pram/assets/pipeline1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bd0545bc3f4814d4b9f18893965529a08a73263e80a3978755162935e05d2b3 +size 3990973 diff --git a/third_party/pram/assets/sam_openvoc.png b/third_party/pram/assets/sam_openvoc.png new file mode 100644 index 0000000000000000000000000000000000000000..aabb6e166dce60f09acbb2578e526eb573f7a1e4 --- /dev/null +++ b/third_party/pram/assets/sam_openvoc.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3e0b06b6917402ed010cd4054e2efcf75c04ede84be53f17d147e2dd388d15a +size 1148808 diff --git a/third_party/pram/colmap_utils/camera_intrinsics.py b/third_party/pram/colmap_utils/camera_intrinsics.py new file mode 100644 index 0000000000000000000000000000000000000000..41bdc5055dfb451fa1f4dac3f27931675b68333f --- /dev/null +++ b/third_party/pram/colmap_utils/camera_intrinsics.py @@ -0,0 +1,30 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> camera_intrinsics +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 15/08/2023 12:33 +==================================================''' +import numpy as np + + +def intrinsics_from_camera(camera_model, params): + if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K diff --git a/third_party/pram/colmap_utils/database.py b/third_party/pram/colmap_utils/database.py new file mode 100644 index 0000000000000000000000000000000000000000..37638347834f4b0b1432846adf9a83693b509a7f --- /dev/null +++ b/third_party/pram/colmap_utils/database.py @@ -0,0 +1,352 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +# This script is based on an original implementation by True Price. + +import sys +import sqlite3 +import numpy as np + + +IS_PYTHON3 = sys.version_info[0] >= 3 + +MAX_IMAGE_ID = 2**31 - 1 + +CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( + camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + model INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + params BLOB, + prior_focal_length INTEGER NOT NULL)""" + +CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( + image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE, + camera_id INTEGER NOT NULL, + prior_qw REAL, + prior_qx REAL, + prior_qy REAL, + prior_qz REAL, + prior_tx REAL, + prior_ty REAL, + prior_tz REAL, + CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), + FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) +""".format(MAX_IMAGE_ID) + +CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ +CREATE TABLE IF NOT EXISTS two_view_geometries ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + config INTEGER NOT NULL, + F BLOB, + E BLOB, + H BLOB) +""" + +CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) +""" + +CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB)""" + +CREATE_NAME_INDEX = \ + "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" + +CREATE_ALL = "; ".join([ + CREATE_CAMERAS_TABLE, + CREATE_IMAGES_TABLE, + CREATE_KEYPOINTS_TABLE, + CREATE_DESCRIPTORS_TABLE, + CREATE_MATCHES_TABLE, + CREATE_TWO_VIEW_GEOMETRIES_TABLE, + CREATE_NAME_INDEX +]) + + +def image_ids_to_pair_id(image_id1, image_id2): + if image_id1 > image_id2: + image_id1, image_id2 = image_id2, image_id1 + return image_id1 * MAX_IMAGE_ID + image_id2 + + +def pair_id_to_image_ids(pair_id): + image_id2 = pair_id % MAX_IMAGE_ID + image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID + return image_id1, image_id2 + + +def array_to_blob(array): + if IS_PYTHON3: + return array.tostring() + else: + return np.getbuffer(array) + + +def blob_to_array(blob, dtype, shape=(-1,)): + if IS_PYTHON3: + return np.fromstring(blob, dtype=dtype).reshape(*shape) + else: + return np.frombuffer(blob, dtype=dtype).reshape(*shape) + + +class COLMAPDatabase(sqlite3.Connection): + + @staticmethod + def connect(database_path): + return sqlite3.connect(str(database_path), factory=COLMAPDatabase) + + + def __init__(self, *args, **kwargs): + super(COLMAPDatabase, self).__init__(*args, **kwargs) + + self.create_tables = lambda: self.executescript(CREATE_ALL) + self.create_cameras_table = \ + lambda: self.executescript(CREATE_CAMERAS_TABLE) + self.create_descriptors_table = \ + lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) + self.create_images_table = \ + lambda: self.executescript(CREATE_IMAGES_TABLE) + self.create_two_view_geometries_table = \ + lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) + self.create_keypoints_table = \ + lambda: self.executescript(CREATE_KEYPOINTS_TABLE) + self.create_matches_table = \ + lambda: self.executescript(CREATE_MATCHES_TABLE) + self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) + + def add_camera(self, model, width, height, params, + prior_focal_length=False, camera_id=None): + params = np.asarray(params, np.float64) + cursor = self.execute( + "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", + (camera_id, model, width, height, array_to_blob(params), + prior_focal_length)) + return cursor.lastrowid + + def add_image(self, name, camera_id, + prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None): + cursor = self.execute( + "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], + prior_q[3], prior_t[0], prior_t[1], prior_t[2])) + return cursor.lastrowid + + def add_keypoints(self, image_id, keypoints): + assert(len(keypoints.shape) == 2) + assert(keypoints.shape[1] in [2, 4, 6]) + + keypoints = np.asarray(keypoints, np.float32) + self.execute( + "INSERT INTO keypoints VALUES (?, ?, ?, ?)", + (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) + + def add_descriptors(self, image_id, descriptors): + descriptors = np.ascontiguousarray(descriptors, np.uint8) + self.execute( + "INSERT INTO descriptors VALUES (?, ?, ?, ?)", + (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) + + def add_matches(self, image_id1, image_id2, matches): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + self.execute( + "INSERT INTO matches VALUES (?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches),)) + + def add_two_view_geometry(self, image_id1, image_id2, matches, + F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + F = np.asarray(F, dtype=np.float64) + E = np.asarray(E, dtype=np.float64) + H = np.asarray(H, dtype=np.float64) + self.execute( + "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches), config, + array_to_blob(F), array_to_blob(E), array_to_blob(H))) + + +def example_usage(): + import os + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--database_path", default="database.db") + args = parser.parse_args() + + if os.path.exists(args.database_path): + print("ERROR: database path already exists -- will not modify it.") + return + + # Open the database. + + db = COLMAPDatabase.connect(args.database_path) + + # For convenience, try creating all the tables upfront. + + db.create_tables() + + # Create dummy cameras. + + model1, width1, height1, params1 = \ + 0, 1024, 768, np.array((1024., 512., 384.)) + model2, width2, height2, params2 = \ + 2, 1024, 768, np.array((1024., 512., 384., 0.1)) + + camera_id1 = db.add_camera(model1, width1, height1, params1) + camera_id2 = db.add_camera(model2, width2, height2, params2) + + # Create dummy images. + + image_id1 = db.add_image("image1.png", camera_id1) + image_id2 = db.add_image("image2.png", camera_id1) + image_id3 = db.add_image("image3.png", camera_id2) + image_id4 = db.add_image("image4.png", camera_id2) + + # Create dummy keypoints. + # + # Note that COLMAP supports: + # - 2D keypoints: (x, y) + # - 4D keypoints: (x, y, theta, scale) + # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) + + num_keypoints = 1000 + keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) + keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) + + db.add_keypoints(image_id1, keypoints1) + db.add_keypoints(image_id2, keypoints2) + db.add_keypoints(image_id3, keypoints3) + db.add_keypoints(image_id4, keypoints4) + + # Create dummy matches. + + M = 50 + matches12 = np.random.randint(num_keypoints, size=(M, 2)) + matches23 = np.random.randint(num_keypoints, size=(M, 2)) + matches34 = np.random.randint(num_keypoints, size=(M, 2)) + + db.add_matches(image_id1, image_id2, matches12) + db.add_matches(image_id2, image_id3, matches23) + db.add_matches(image_id3, image_id4, matches34) + + # Commit the data to the file. + + db.commit() + + # Read and check cameras. + + rows = db.execute("SELECT * FROM cameras") + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id1 + assert model == model1 and width == width1 and height == height1 + assert np.allclose(params, params1) + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id2 + assert model == model2 and width == width2 and height == height2 + assert np.allclose(params, params2) + + # Read and check keypoints. + + keypoints = dict( + (image_id, blob_to_array(data, np.float32, (-1, 2))) + for image_id, data in db.execute( + "SELECT image_id, data FROM keypoints")) + + assert np.allclose(keypoints[image_id1], keypoints1) + assert np.allclose(keypoints[image_id2], keypoints2) + assert np.allclose(keypoints[image_id3], keypoints3) + assert np.allclose(keypoints[image_id4], keypoints4) + + # Read and check matches. + + pair_ids = [image_ids_to_pair_id(*pair) for pair in + ((image_id1, image_id2), + (image_id2, image_id3), + (image_id3, image_id4))] + + matches = dict( + (pair_id_to_image_ids(pair_id), + blob_to_array(data, np.uint32, (-1, 2))) + for pair_id, data in db.execute("SELECT pair_id, data FROM matches") + ) + + assert np.all(matches[(image_id1, image_id2)] == matches12) + assert np.all(matches[(image_id2, image_id3)] == matches23) + assert np.all(matches[(image_id3, image_id4)] == matches34) + + # Clean up. + + db.close() + + if os.path.exists(args.database_path): + os.remove(args.database_path) + + +if __name__ == "__main__": + example_usage() \ No newline at end of file diff --git a/third_party/pram/colmap_utils/geometry.py b/third_party/pram/colmap_utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..0d48f0a9545f04300f0f914515e650bb60957296 --- /dev/null +++ b/third_party/pram/colmap_utils/geometry.py @@ -0,0 +1,17 @@ +# -*- coding: UTF-8 -*- +import numpy as np +import pycolmap + + +def to_homogeneous(p): + return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1) + + +def compute_epipolar_errors(j_from_i: pycolmap.Rigid3d, p2d_i, p2d_j): + j_E_i = j_from_i.essential_matrix() + l2d_j = to_homogeneous(p2d_i) @ j_E_i.T + l2d_i = to_homogeneous(p2d_j) @ j_E_i + dist = np.abs(np.sum(to_homogeneous(p2d_i) * l2d_i, axis=1)) + errors_i = dist / np.linalg.norm(l2d_i[:, :2], axis=1) + errors_j = dist / np.linalg.norm(l2d_j[:, :2], axis=1) + return errors_i, errors_j diff --git a/third_party/pram/colmap_utils/io.py b/third_party/pram/colmap_utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad46c685ca2a2fbb166d22884948f3fd6547368 --- /dev/null +++ b/third_party/pram/colmap_utils/io.py @@ -0,0 +1,78 @@ +# -*- coding: UTF-8 -*- +from pathlib import Path +from typing import Tuple + +import cv2 +import h5py +import numpy as np + +from .parsers import names_to_pair, names_to_pair_old + + +def read_image(path, grayscale=False): + if grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise ValueError(f"Cannot read image {path}.") + if not grayscale and len(image.shape) == 3: + image = image[:, :, ::-1] # BGR to RGB + return image + + +def list_h5_names(path): + names = [] + with h5py.File(str(path), "r", libver="latest") as fd: + def visit_fn(_, obj): + if isinstance(obj, h5py.Dataset): + names.append(obj.parent.name.strip("/")) + + fd.visititems(visit_fn) + return list(set(names)) + + +def get_keypoints( + path: Path, name: str, return_uncertainty: bool = False +) -> np.ndarray: + with h5py.File(str(path), "r", libver="latest") as hfile: + dset = hfile[name]["keypoints"] + p = dset.__array__() + uncertainty = dset.attrs.get("uncertainty") + if return_uncertainty: + return p, uncertainty + return p + + +def find_pair(hfile: h5py.File, name0: str, name1: str): + pair = names_to_pair(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair(name1, name0) + if pair in hfile: + return pair, True + # older, less efficient format + pair = names_to_pair_old(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair_old(name1, name0) + if pair in hfile: + return pair, True + raise ValueError( + f"Could not find pair {(name0, name1)}... " + "Maybe you matched with a different list of pairs? " + ) + + +def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]: + with h5py.File(str(path), "r", libver="latest") as hfile: + pair, reverse = find_pair(hfile, name0, name1) + matches = hfile[pair]["matches0"].__array__() + scores = hfile[pair]["matching_scores0"].__array__() + idx = np.where(matches != -1)[0] + matches = np.stack([idx, matches[idx]], -1) + if reverse: + matches = np.flip(matches, -1) + scores = scores[idx] + return matches, scores diff --git a/third_party/pram/colmap_utils/parsers.py b/third_party/pram/colmap_utils/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9087d78cc8cf7f1e81ab8359862227c3882786 --- /dev/null +++ b/third_party/pram/colmap_utils/parsers.py @@ -0,0 +1,73 @@ +# -*- coding: UTF-8 -*- + +from pathlib import Path +import logging +import numpy as np +from collections import defaultdict + + +def parse_image_lists_with_intrinsics(paths): + results = [] + files = list(Path(paths.parent).glob(paths.name)) + assert len(files) > 0 + + for lfile in files: + with open(lfile, 'r') as f: + raw_data = f.readlines() + + logging.info(f'Importing {len(raw_data)} queries in {lfile.name}') + for data in raw_data: + data = data.strip('\n').split(' ') + name, camera_model, width, height = data[:4] + params = np.array(data[4:], float) + info = (camera_model, int(width), int(height), params) + results.append((name, info)) + + assert len(results) > 0 + return results + + +def parse_img_lists_for_extended_cmu_seaons(paths): + Ks = { + "c0": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571", + "c1": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571" + } + + results = [] + files = list(Path(paths.parent).glob(paths.name)) + assert len(files) > 0 + + for lfile in files: + with open(lfile, 'r') as f: + raw_data = f.readlines() + + logging.info(f'Importing {len(raw_data)} queries in {lfile.name}') + for name in raw_data: + name = name.strip('\n') + camera = name.split('_')[2] + K = Ks[camera].split(' ') + camera_model, width, height = K[:3] + params = np.array(K[3:], float) + # print("camera: ", camera_model, width, height, params) + info = (camera_model, int(width), int(height), params) + results.append((name, info)) + + assert len(results) > 0 + return results + + +def parse_retrieval(path): + retrieval = defaultdict(list) + with open(path, 'r') as f: + for p in f.read().rstrip('\n').split('\n'): + q, r = p.split(' ') + retrieval[q].append(r) + return dict(retrieval) + + +def names_to_pair_old(name0, name1): + return '_'.join((name0.replace('/', '-'), name1.replace('/', '-'))) + + +def names_to_pair(name0, name1, separator="/"): + return separator.join((name0.replace("/", "-"), name1.replace("/", "-"))) diff --git a/third_party/pram/colmap_utils/read_write_model.py b/third_party/pram/colmap_utils/read_write_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eddbeb7edd364c27c54029fa81077ea4f75d2700 --- /dev/null +++ b/third_party/pram/colmap_utils/read_write_model.py @@ -0,0 +1,627 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +import os +import sys +import collections +import numpy as np +import struct +import argparse + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for camera_line_index in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8 * num_params, + format_char_sequence="d" * num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = '# Camera list with one line of data per camera:\n' + '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n' + '# Number of cameras: {}\n'.format(len(cameras)) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, + model_id, + cam.width, + cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def write_images_text(images, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images) + HEADER = '# Image list with two lines of data per image:\n' + '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' + '# POINTS2D[] as (X, Y, POINT3D_ID)\n' + '# Number of images: {}, mean observations per image: {}\n'.format(len(images), mean_observations) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8 * track_length, + format_char_sequence="ii" * track_length) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D) + HEADER = '# 3D point list with one line of data per point:\n' + '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n' + '# Number of points: {}, mean track length: {}\n'.format(len(points3D), mean_track_length) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3d_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def read_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3d_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def read_compressed_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + # x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, + # format_char_sequence="ddq" * num_points2D) + # xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + # tuple(map(float, x_y_id_s[1::3]))]) + x_y_id_s = read_next_bytes(fid, num_bytes=8 * num_points2D, + format_char_sequence="q" * num_points2D) + point3D_ids = np.array(x_y_id_s) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=np.array([]), point3D_ids=point3D_ids) + return images + + +def write_compressed_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for p3d_id in img.point3D_ids: + write_next_bytes(fid, p3d_id, "q") + # for xy, p3d_id in zip(img.xys, img.point3D_ids): + # write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_compressed_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=4 * track_length, + format_char_sequence="i" * track_length) + image_ids = np.array(track_elems) + # point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=np.array([])) + return points3D + + +def write_compressed_points3d_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + # for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + # write_next_bytes(fid, [image_id, point2D_id], "ii") + for image_id in pt.image_ids: + write_next_bytes(fid, image_id, "i") + + +def read_compressed_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_compressed_images_binary(os.path.join(path, "images" + ext)) + points3D = read_compressed_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]]) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def intrinsics_from_camera(camera_model, params): + if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K + + +def main(): + parser = argparse.ArgumentParser(description='Read and write COLMAP binary and text models') + parser.add_argument('input_model', help='path to input model folder') + parser.add_argument('input_format', choices=['.bin', '.txt'], + help='input model format') + parser.add_argument('--output_model', metavar='PATH', + help='path to output model folder') + parser.add_argument('--output_format', choices=['.bin', '.txt'], + help='outut model format', default='.txt') + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format) + + +if __name__ == "__main__": + main() diff --git a/third_party/pram/colmap_utils/utils.py b/third_party/pram/colmap_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d98fed2dfc5789b650144caa3a4bac8cfe6a2fb --- /dev/null +++ b/third_party/pram/colmap_utils/utils.py @@ -0,0 +1 @@ +# -*- coding: UTF-8 -*- diff --git a/third_party/pram/configs/config_train_12scenes_sfd2.yaml b/third_party/pram/configs/config_train_12scenes_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e6e7fb7c851edb8bd6e26e8d4806cadeb5977d5 --- /dev/null +++ b/third_party/pram/configs/config_train_12scenes_sfd2.yaml @@ -0,0 +1,102 @@ +dataset: [ '12Scenes' ] + +network_1: "segnet" +network: "segnetvit" + +local_rank: 0 +gpu: [ 0 ] + +feature: "sfd2" +save_path: '/scratches/flyer_2/fx221/exp/pram' +landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 +min_inliers: 32 +max_inliers: 512 +random_inliers: true +max_keypoints: 512 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +optimizer: "adamw" +seg_loss: "cew" +seg_loss_nx: "cei" +cls_loss: "ce" +cls_loss_: "bce" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 +do_eval: false + +use_mid_feature: true +norm_desc: false +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 60000 +epochs: 500 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: '20230719_220620_segnet_L15_T_resnet4x_B32_K1024_relu_bn_od1024_nc193_adamw_cew_md_A_birch/segnet.499.pth' +weight_path_2: '20240202_145337_segnetvit_L15_T_resnet4x_B32_K512_relu_bn_od1024_nc193_adam_cew_md_A_birch/segnetvit.499.pth' + +resume_path: null + +n_class: 193 + +eval_max_keypoints: 1024 + +localization: + loc_scene_name: [ 'apt1/kitchen' ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + seg_k: 20 + threshold: 8 + min_kpts: 128 + min_matches: 4 + min_inliers: 64 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_5: "adagml" + save: false + show: true + show_time: 1 + max_vrf: 1 + with_original: true + with_extra: false + with_compress: true + semantic_matching: true + do_refinement: true + refinement_method_: 'matching' + refinement_method: 'projection' + pre_filtering_th: 0.95 + covisibility_frame: 20 + refinement_radius: 20 + refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/third_party/pram/configs/config_train_7scenes_sfd2.yaml b/third_party/pram/configs/config_train_7scenes_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19b0635c9ad4ebcf0a085a759640e4a149a75009 --- /dev/null +++ b/third_party/pram/configs/config_train_7scenes_sfd2.yaml @@ -0,0 +1,104 @@ +dataset: [ '7Scenes' ] + +network: "segnetvit" + +local_rank: 0 +gpu: [ 0 ] +# when using ddp, set gpu: [0,1,2,3] +with_dist: true + +feature: "sfd2" +save_path_: '/scratches/flyer_2/fx221/exp/pram' +save_path: '/scratches/flyer_2/fx221/publications/test_pram/exp' +landmark_path_: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +landmark_path: "/scratches/flyer_2/fx221/publications/test_pram/landmakrs/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 + +min_inliers: 32 +max_inliers: 256 +random_inliers: 1 +max_keypoints: 512 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +cls_loss: "ce" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 +do_eval: false + +use_mid_feature: true +norm_desc: false +with_cls: false +with_score: false +with_aug: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 80000 +epochs: 200 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: '20230724_203230_segnet_L15_S_resnet4x_B32_K1024_relu_bn_od1024_nc113_adam_cew_md_A_birch/segnet.180.pth' +weight_path_2: '20240202_152519_segnetvit_L15_S_resnet4x_B32_K512_relu_bn_od1024_nc113_adamw_cew_md_A_birch/segnetvit.199.pth' + +# used for resuming training +resume_path: null + +# used for localization +n_class: 113 + +eval_max_keypoints: 1024 + +localization: + loc_scene_name: [ 'chess' ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + + seg_k: 20 + threshold: 8 + min_kpts: 128 + min_matches: 16 + min_inliers: 32 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_4: "adagml" + save: false + show: true + show_time: 1 + with_original: true + max_vrf: 1 + with_compress: true + semantic_matching: true + do_refinement: true + pre_filtering_th: 0.95 + refinement_method_: 'matching' + refinement_method: 'projection' + covisibility_frame: 20 + refinement_radius: 20 + refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/third_party/pram/configs/config_train_aachen_sfd2.yaml b/third_party/pram/configs/config_train_aachen_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e2111377ed9d6cff38efd69bc397487ecfb33fb --- /dev/null +++ b/third_party/pram/configs/config_train_aachen_sfd2.yaml @@ -0,0 +1,104 @@ +dataset: [ 'Aachen' ] + +network_: "segnet" +network: "segnetvit" +local_rank: 0 +gpu: [ 0 ] + +feature: "sfd2" +save_path: '/scratches/flyer_2/fx221/exp/pram' +landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" + +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 + +min_inliers: 32 +max_inliers: 512 +random_inliers: true +max_keypoints: 1024 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +do_eval: true +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +optimizer: "adam" +seg_loss: "cew" +seg_loss_nx: "cei" +cls_loss: "ce" +cls_loss_: "bce" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 + +use_mid_feature: true +norm_desc: false +with_sc: false +with_cls: true +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 80000 +epochs: 800 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: '20230719_221442_segnet_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adamw_cew_md_A_birch/segnet.899.pth' +weight_path_2: '20240211_142623_segnetvit_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adam_cew_md_A_birch/segnetvit.799.pth' +resume_path: null + +n_class: 513 + +eval_max_keypoints: 4096 + +localization: + loc_scene_name: [ ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + seg_k: 10 + threshold: 12 + min_kpts: 256 + min_matches: 8 + min_inliers: 128 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_4: "adagml" + save: false + show: true + show_time: 1 + with_original: true + with_extra: false + max_vrf: 1 + with_compress: true + semantic_matching: true + refinement_method_: 'matching' + refinement_method: 'projection' + pre_filtering_th: 0.95 + do_refinement: true + covisibility_frame: 50 + refinement_radius: 30 + refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/third_party/pram/configs/config_train_cambridge_sfd2.yaml b/third_party/pram/configs/config_train_cambridge_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8cc843ee963dc5c0041954790d7e622e24aefe16 --- /dev/null +++ b/third_party/pram/configs/config_train_cambridge_sfd2.yaml @@ -0,0 +1,103 @@ +dataset: [ 'CambridgeLandmarks' ] + +network_: "segnet" +network: "segnetvit" + +local_rank: 0 +gpu: [ 0 ] + +feature: "sfd2" +save_path: '/scratches/flyer_2/fx221/exp/pram' +landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 + +min_inliers: 32 +max_inliers: 512 +random_inliers: 1 +max_keypoints: 1024 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +do_eval: false +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +epochs: 300 +seg_loss: "cew" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 + +use_mid_feature: true +norm_desc: false +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 60000 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: '20230725_144044_segnet_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adam_cew_md_A_birch/segnet.260.pth' +weight_path_2: '20240204_130323_segnetvit_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adamw_cew_md_A_birch/segnetvit.399.pth' + +resume_path: null + +n_class: 161 + +eval_max_keypoints: 2048 + +localization: + loc_scene_name_1: [ 'GreatCourt' ] + loc_scene_name_2: [ 'KingsCollege' ] + loc_scene_name: [ 'StMarysChurch' ] + loc_scene_name_4: [ 'OldHospital' ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + seg_k: 30 + threshold: 12 + min_kpts: 256 + min_matches: 16 + min_inliers_gm: 128 + min_inliers: 128 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_4: "adagml" + show: true + show_time: 1 + save: false + with_original: true + max_vrf: 1 + with_extra: false + with_compress: true + semantic_matching: true + do_refinement: true + pre_filtering_th: 0.95 + refinement_method_: 'matching' + refinement_method: 'projection' + covisibility_frame: 20 + refinement_radius: 20 + refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/third_party/pram/configs/config_train_multiset_sfd2.yaml b/third_party/pram/configs/config_train_multiset_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..90618e0812c2321ba05fbe3ab9a12d52ec447e99 --- /dev/null +++ b/third_party/pram/configs/config_train_multiset_sfd2.yaml @@ -0,0 +1,100 @@ +dataset: [ 'S', 'T', 'C', 'A' ] + +network: "segnet" +network_: "gsegnet3" + +local_rank: 0 +gpu: [ 4 ] + +feature: "resnet4x" +save_path: '/scratches/flyer_2/fx221/exp/localizer' +landmark_path: "/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gm" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +min_inliers: 32 +max_inliers: 512 +random_inliers: 1 +max_keypoints: 1024 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +do_eval: true +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +optimizer: "adam" +seg_loss: "cew" +seg_loss_nx: "cei" +cls_loss: "ce" +cls_loss_: "bce" +sc_loss: 'l1g' +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 + +use_mid_feature: true +norm_desc: false +with_sc: false +with_cls: true +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 150000 +epochs: 1500 + +cluster_method_: 'kmeans' +cluster_method: 'birch' + +weight_path_: null +weight_path: '20230805_132653_segnet_L15_STCA_resnet4x_B32_K1024_relu_bn_od1024_nc977_adam_cew_md_A_birch/segnet.485.pth' +resume_path: null + +eval: false +#loc: false +loc: true +#n_class: 977 +online: false + +eval_max_keypoints: 4096 + +localization: + loc_scene_name: [ ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + dataset: [ 'T' ] + seg_k: 50 + threshold: 8 # 8 for indoor, 12 for outdoor + min_kpts: 256 + min_matches: 4 + min_inliers: 64 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method: "gm" + save: false + show: true + show_time: 1 + do_refinement: true + with_original: true + with_extra: false + max_vrf: 1 + with_compress: false + covisibility_frame: 20 + observation_threshold: 3 diff --git a/third_party/pram/configs/datasets/12Scenes.yaml b/third_party/pram/configs/datasets/12Scenes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e950aca2ff25526af622fec779e9bb6a07eaea6b --- /dev/null +++ b/third_party/pram/configs/datasets/12Scenes.yaml @@ -0,0 +1,166 @@ +dataset: '12Scenes' +scenes: [ 'apt1/kitchen', + 'apt1/living', + 'apt2/bed', + 'apt2/kitchen', + 'apt2/living', + 'apt2/luke', + 'office1/gates362', + 'office1/gates381', + 'office1/lounge', + 'office1/manolis', + 'office2/5a', + 'office2/5b' +] + +apt1/kitchen: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + image_path_prefix: '' + + +apt1/living: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + +apt2/bed: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +apt2/kitchen: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +apt2/living: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +apt2/luke: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/gates362: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/gates381: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/lounge: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/manolis: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office2/5a: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office2/5b: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' diff --git a/third_party/pram/configs/datasets/7Scenes.yaml b/third_party/pram/configs/datasets/7Scenes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd68181fbc0ed96ccb3e464d94a5346183c1dfe3 --- /dev/null +++ b/third_party/pram/configs/datasets/7Scenes.yaml @@ -0,0 +1,96 @@ +dataset: '7Scenes' +scenes: [ 'chess', 'heads', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin' ] + + +chess: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 2 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + + +heads: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 2 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + +office: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + +fire: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 2 + eval_sample_ratio: 5 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + +stairs: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + +redkitchen: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + + + +pumpkin: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 2 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + diff --git a/third_party/pram/configs/datasets/Aachen.yaml b/third_party/pram/configs/datasets/Aachen.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49477afbe569cb0fc4317b6c1a98c30f261ee7e0 --- /dev/null +++ b/third_party/pram/configs/datasets/Aachen.yaml @@ -0,0 +1,15 @@ +dataset: 'Aachen' + +scenes: [ 'Aachenv11' ] + +Aachenv11: + n_cluster: 512 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: 'images/images_upright' + query_path_: 'queries_with_intrinsics.txt' + query_path: 'queries_with_intrinsics_demo.txt' + gt_pose_path: 'queries_pose_spp_spg.txt' diff --git a/third_party/pram/configs/datasets/CambridgeLandmarks.yaml b/third_party/pram/configs/datasets/CambridgeLandmarks.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3a757898db1e772b593059d2c21ef1eaaa825ea --- /dev/null +++ b/third_party/pram/configs/datasets/CambridgeLandmarks.yaml @@ -0,0 +1,67 @@ +dataset: 'CambridgeLandmarks' +scenes: [ 'GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch' ] + +GreatCourt: + n_cluster: 32 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +KingsCollege: + n_cluster: 32 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +OldHospital: + n_cluster: 32 + cluster_mode: 'xz' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +ShopFacade: + n_cluster: 32 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +StMarysChurch: + n_cluster: 32 + cluster_mode: 'xz' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + + diff --git a/third_party/pram/dataset/aachen.py b/third_party/pram/dataset/aachen.py new file mode 100644 index 0000000000000000000000000000000000000000..d57efd8e4460f943d66b2d8b92e57d7cd7f7f75a --- /dev/null +++ b/third_party/pram/dataset/aachen.py @@ -0,0 +1,119 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> aachen +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:33 +==================================================''' +import os.path as osp +import numpy as np +import cv2 +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class Aachen(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='Aachen', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = 'images/images_upright' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + self.img_fns = [] + if train: + with open(osp.join(self.dataset_path, 'aachen_db_imglist.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip() + if l not in self.name_to_id.keys(): + continue + self.img_fns.append(l) + else: + with open(osp.join(self.dataset_path, 'queries', 'day_time_queries_with_intrinsics.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split()[0] + if l not in self.img_p3d.keys(): + continue + self.img_fns.append(l) + with open(osp.join(self.dataset_path, 'queries', 'night_time_queries_with_intrinsics.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split()[0] + if l not in self.img_p3d.keys(): + continue + self.img_fns.append(l) + + print( + 'Load {} images from {} for {}...'.format(len(self.img_fns), self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + self.mean_xyz = np.array([float(v) for v in l[:3]]) + self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} + + def read_image(self, image_name): + return cv2.imread(osp.join(self.dataset_path, 'images/images_upright/', image_name)) diff --git a/third_party/pram/dataset/basicdataset.py b/third_party/pram/dataset/basicdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c77c32ca010e99d14ddd8643c2ff07789bd75851 --- /dev/null +++ b/third_party/pram/dataset/basicdataset.py @@ -0,0 +1,477 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> basicdataset +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:27 +==================================================''' +import torchvision.transforms.functional as tvf +import torchvision.transforms as tvt +import os.path as osp +import numpy as np +import cv2 +from colmap_utils.read_write_model import qvec2rotmat, read_model +from dataset.utils import normalize_size + + +class BasicDataset: + def __init__(self, + img_list_fn, + feature_dir, + sfm_path, + seg_fn, + dataset_path, + n_class, + dataset, + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=1, + pre_load=False, + query_info_path=None, + sc_mean_scale_fn=None, + ): + self.n_class = n_class + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.dataset_path = dataset_path + self.with_aug = with_aug + self.dataset = dataset + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.image_prefix = '' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + self.img_fns = [] + with open(img_list_fn, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip() + self.img_fns.append(l) + print('Load {} images from {} for {}...'.format(len(self.img_fns), dataset, 'training' if train else 'eval')) + self.feats = {} + if train: + self.cameras, self.images, point3Ds = read_model(path=sfm_path, ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items()} + + data = np.load(seg_fn, allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + self.p3d_xyzs = {} + + for pid in self.p3d_seg.keys(): + p3d = point3Ds[pid] + self.p3d_xyzs[pid] = p3d.xyz + + with open(sc_mean_scale_fn, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + self.mean_xyz = np.array([float(v) for v in l[:3]]) + self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = feature_dir + print('Pre loaded {} feats, mean xyz {}, scale xyz {}'.format(len(self.feats.keys()), self.mean_xyz, + self.scale_xyz)) + + def normalize_p3ds(self, p3ds): + mean_p3ds = np.ceil(np.mean(p3ds, axis=0)) + p3ds_ = p3ds - mean_p3ds + dx = np.max(abs(p3ds_[:, 0])) + dy = np.max(abs(p3ds_[:, 1])) + dz = np.max(abs(p3ds_[:, 2])) + scale_p3ds = np.ceil(np.array([dx, dy, dz], dtype=float).reshape(3, )) + scale_p3ds[scale_p3ds < 1] = 1 + scale_p3ds[scale_p3ds == 0] = 1 + return mean_p3ds, scale_p3ds + + def read_query_info(self, path): + query_info = {} + with open(path, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + image_name = l[0] + cam_model = l[1] + h, w = int(l[2]), int(l[3]) + params = np.array([float(v) for v in l[4:]]) + query_info[image_name] = { + 'width': w, + 'height': h, + 'model': cam_model, + 'params': params, + } + return query_info + + def extract_intrinsic_extrinsic_params(self, image_id): + cam = self.cameras[self.images[image_id].camera_id] + params = cam.params + model = cam.model + if model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + K = np.eye(3, dtype=float) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + + qvec = self.images[image_id].qvec + tvec = self.images[image_id].tvec + R = qvec2rotmat(qvec=qvec) + P = np.eye(4, dtype=float) + P[:3, :3] = R + P[:3, 3] = tvec.reshape(3, ) + + return {'K': K, 'P': P} + + def get_item_train(self, idx): + img_name = self.img_fns[idx] + if img_name in self.feats.keys(): + feat_data = self.feats[img_name] + else: + feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()] + # descs = feat_data['descriptors'] # [N, D] + scores = feat_data['scores'] # [N, 1] + kpts = feat_data['keypoints'] # [N, 2] + image_size = feat_data['image_size'] + + nfeat = kpts.shape[0] + + # print(img_name, self.name_to_id[img_name]) + p3d_ids = self.images[self.name_to_id[img_name]].point3D_ids + p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float) + + seg_ids = np.zeros(shape=(nfeat,), dtype=int) # + self.n_class - 1 + for i in range(nfeat): + p3d = p3d_ids[i] + if p3d in self.p3d_seg.keys(): + seg_ids[i] = self.p3d_seg[p3d] + 1 # 0 for invalid + if seg_ids[i] == -1: + seg_ids[i] = 0 + + if p3d in self.p3d_xyzs.keys(): + p3d_xyzs[i] = self.p3d_xyzs[p3d] + + seg_ids = np.array(seg_ids).reshape(-1, ) + + n_inliers = np.sum(seg_ids > 0) + n_outliers = np.sum(seg_ids == 0) + inlier_ids = np.where(seg_ids > 0)[0] + outlier_ids = np.where(seg_ids == 0)[0] + + if n_inliers <= self.min_inliers: + sel_inliers = n_inliers + sel_outliers = self.nfeatures - sel_inliers + + out_ids = np.arange(n_outliers) + np.random.shuffle(out_ids) + sel_ids = np.hstack([inlier_ids, outlier_ids[out_ids[:self.nfeatures - n_inliers]]]) + else: + sel_inliers = np.random.randint(self.min_inliers, self.max_inliers) + if sel_inliers > n_inliers: + sel_inliers = n_inliers + + if sel_inliers + n_outliers < self.nfeatures: + sel_inliers = self.nfeatures - n_outliers + + sel_outliers = self.nfeatures - sel_inliers + + in_ids = np.arange(n_inliers) + np.random.shuffle(in_ids) + sel_inlier_ids = inlier_ids[in_ids[:sel_inliers]] + + out_ids = np.arange(n_outliers) + np.random.shuffle(out_ids) + sel_outlier_ids = outlier_ids[out_ids[:sel_outliers]] + + sel_ids = np.hstack([sel_inlier_ids, sel_outlier_ids]) + + # sel_descs = descs[sel_ids] + sel_scores = scores[sel_ids] + sel_kpts = kpts[sel_ids] + sel_seg_ids = seg_ids[sel_ids] + sel_xyzs = p3d_xyzs[sel_ids] + + shuffle_ids = np.arange(sel_ids.shape[0]) + np.random.shuffle(shuffle_ids) + # sel_descs = sel_descs[shuffle_ids] + sel_scores = sel_scores[shuffle_ids] + sel_kpts = sel_kpts[shuffle_ids] + sel_seg_ids = sel_seg_ids[shuffle_ids] + sel_xyzs = sel_xyzs[shuffle_ids] + + if sel_kpts.shape[0] < self.nfeatures: + # print(sel_descs.shape, sel_kpts.shape, sel_scores.shape, sel_seg_ids.shape, sel_xyzs.shape) + valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0]) if sel_seg_ids[v] > 0], dtype=int) + # ref_sel_id = np.random.choice(valid_sel_ids, size=1)[0] + if valid_sel_ids.shape[0] == 0: + valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0])], dtype=int) + random_n = self.nfeatures - sel_kpts.shape[0] + random_scores = np.random.random((random_n,)) + random_kpts, random_seg_ids, random_xyzs = self.random_points_from_reference( + n=random_n, + ref_kpts=sel_kpts[valid_sel_ids], + ref_segs=sel_seg_ids[valid_sel_ids], + ref_xyzs=sel_xyzs[valid_sel_ids], + radius=5, + ) + # sel_descs = np.vstack([sel_descs, random_descs]) + sel_scores = np.hstack([sel_scores, random_scores]) + sel_kpts = np.vstack([sel_kpts, random_kpts]) + sel_seg_ids = np.hstack([sel_seg_ids, random_seg_ids]) + sel_xyzs = np.vstack([sel_xyzs, random_xyzs]) + + gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float) + uids = np.unique(sel_seg_ids).tolist() + for uid in uids: + if uid == 0: + continue + gt_cls[uid] = 1 + gt_n_seg[uid] = np.sum(sel_seg_ids == uid) + gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum(seg_ids > 0) # [valid_id / total_valid_id] + + param_out = self.extract_intrinsic_extrinsic_params(image_id=self.name_to_id[img_name]) + + img = self.read_image(image_name=img_name) + image_size = img.shape[:2] + if self.image_dim == 1: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if self.with_aug: + nh = img.shape[0] + nw = img.shape[1] + if self.scale_params is not None: + do_scale = np.random.random() + if do_scale <= 0.25: + p = np.random.randint(0, 11) + s = self.scale_params[0] + (self.scale_params[1] - self.scale_params[0]) / 10 * p + nh = int(img.shape[0] * s) + nw = int(img.shape[1] * s) + sh = nh / img.shape[0] + sw = nw / img.shape[1] + sel_kpts[:, 0] = sel_kpts[:, 0] * sw + sel_kpts[:, 1] = sel_kpts[:, 1] * sh + img = cv2.resize(img, dsize=(nw, nh)) + + brightness = np.random.uniform(-self.jitter_params['brightness'], self.jitter_params['brightness']) * 255 + contrast = 1 + np.random.uniform(-self.jitter_params['contrast'], self.jitter_params['contrast']) + img = cv2.addWeighted(img, contrast, img, 0, brightness) + img = np.clip(img, a_min=0, a_max=255) + if self.image_dim == 1: + img = img[..., None] + img = img.astype(float) / 255. + image_size = np.array([nh, nw], dtype=int) + else: + if self.image_dim == 1: + img = img[..., None].astype(float) / 255. + + output = { + # 'descriptors': sel_descs, # may not be used + 'scores': sel_scores, + 'keypoints': sel_kpts, + 'norm_keypoints': normalize_size(x=sel_kpts, size=image_size), + 'image': [img], + 'gt_seg': sel_seg_ids, + 'gt_cls': gt_cls, + 'gt_cls_dist': gt_cls_dist, + 'gt_n_seg': gt_n_seg, + 'file_name': img_name, + 'prefix_name': self.image_prefix, + # 'mean_xyz': self.mean_xyz, + # 'scale_xyz': self.scale_xyz, + # 'gt_sc': sel_xyzs, + # 'gt_norm_sc': (sel_xyzs - self.mean_xyz) / self.scale_xyz, + 'K': param_out['K'], + 'gt_P': param_out['P'] + } + return output + + def get_item_test(self, idx): + + # evaluation of recognition only + img_name = self.img_fns[idx] + feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()] + descs = feat_data['descriptors'] # [N, D] + scores = feat_data['scores'] # [N, 1] + kpts = feat_data['keypoints'] # [N, 2] + image_size = feat_data['image_size'] + + nfeat = descs.shape[0] + + if img_name in self.img_p3d.keys(): + p3d_ids = self.img_p3d[img_name] + p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float) + seg_ids = np.zeros(shape=(nfeat,), dtype=int) # attention! by default invalid!!! + for i in range(nfeat): + p3d = p3d_ids[i] + if p3d in self.p3d_seg.keys(): + seg_ids[i] = self.p3d_seg[p3d] + 1 + if seg_ids[i] == -1: + seg_ids[i] = 0 # 0 for in valid + + if p3d in self.p3d_xyzs.keys(): + p3d_xyzs[i] = self.p3d_xyzs[p3d] + + seg_ids = np.array(seg_ids).reshape(-1, ) + + if self.nfeatures > 0: + sorted_ids = np.argsort(scores)[::-1][:self.nfeatures] # large to small + descs = descs[sorted_ids] + scores = scores[sorted_ids] + kpts = kpts[sorted_ids] + p3d_xyzs = p3d_xyzs[sorted_ids] + + seg_ids = seg_ids[sorted_ids] + + gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float) + uids = np.unique(seg_ids).tolist() + for uid in uids: + if uid == 0: + continue + gt_cls[uid] = 1 + gt_n_seg[uid] = np.sum(seg_ids == uid) + gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum( + seg_ids < self.n_class - 1) # [valid_id / total_valid_id] + + gt_cls[0] = 0 + + img = self.read_image(image_name=img_name) + if self.image_dim == 1: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img = img[..., None].astype(float) / 255. + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(float) / 255. + return { + 'descriptors': descs, + 'scores': scores, + 'keypoints': kpts, + 'image_size': image_size, + 'norm_keypoints': normalize_size(x=kpts, size=image_size), + 'gt_seg': seg_ids, + 'gt_cls': gt_cls, + 'gt_cls_dist': gt_cls_dist, + 'gt_n_seg': gt_n_seg, + 'file_name': img_name, + 'prefix_name': self.image_prefix, + 'image': [img], + + 'mean_xyz': self.mean_xyz, + 'scale_xyz': self.scale_xyz, + 'gt_sc': p3d_xyzs, + 'gt_norm_sc': (p3d_xyzs - self.mean_xyz) / self.scale_xyz + } + + def __getitem__(self, idx): + if self.train: + return self.get_item_train(idx=idx) + else: + return self.get_item_test(idx=idx) + + def __len__(self): + return len(self.img_fns) + + def read_image(self, image_name): + return cv2.imread(osp.join(self.dataset_path, image_name)) + + def jitter_augmentation(self, img, params): + brightness, contrast, saturation, hue = params + p = np.random.randint(0, 20) / 20 + b = brightness[0] + (brightness[1] - brightness[0]) / 20 * p + img = tvf.adjust_brightness(img=img, brightness_factor=b) + + p = np.random.randint(0, 20) / 20 + c = contrast[0] + (contrast[1] - contrast[0]) / 20 * p + img = tvf.adjust_contrast(img=img, contrast_factor=c) + + p = np.random.randint(0, 20) / 20 + s = saturation[0] + (saturation[1] - saturation[0]) / 20 * p + img = tvf.adjust_saturation(img=img, saturation_factor=s) + + p = np.random.randint(0, 20) / 20 + h = hue[0] + (hue[1] - hue[0]) / 20 * p + img = tvf.adjust_hue(img=img, hue_factor=h) + + return img + + def random_points(self, n, d, h, w): + desc = np.random.random((n, d)) + desc = desc / np.linalg.norm(desc, ord=2, axis=1)[..., None] + xs = np.random.randint(0, w - 1, size=(n, 1)) + ys = np.random.randint(0, h - 1, size=(n, 1)) + kpts = np.hstack([xs, ys]) + return desc, kpts + + def random_points_from_reference(self, n, ref_kpts, ref_segs, ref_xyzs, radius=5): + n_ref = ref_kpts.shape[0] + if n_ref < n: + ref_ids = np.random.choice([i for i in range(n_ref)], size=n).tolist() + else: + ref_ids = [i for i in range(n)] + + new_xs = [] + new_ys = [] + # new_descs = [] + new_segs = [] + new_xyzs = [] + for i in ref_ids: + nx = np.random.randint(-radius, radius) + ref_kpts[i, 0] + ny = np.random.randint(-radius, radius) + ref_kpts[i, 1] + + new_xs.append(nx) + new_ys.append(ny) + # new_descs.append(ref_descs[i]) + new_segs.append(ref_segs[i]) + new_xyzs.append(ref_xyzs[i]) + + new_xs = np.array(new_xs).reshape(n, 1) + new_ys = np.array(new_ys).reshape(n, 1) + new_segs = np.array(new_segs).reshape(n, ) + new_kpts = np.hstack([new_xs, new_ys]) + # new_descs = np.array(new_descs).reshape(n, -1) + new_xyzs = np.array(new_xyzs) + return new_kpts, new_segs, new_xyzs diff --git a/third_party/pram/dataset/cambridge_landmarks.py b/third_party/pram/dataset/cambridge_landmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..03f30f367f4ded9ce1d7c2efbaa407ed26725a69 --- /dev/null +++ b/third_party/pram/dataset/cambridge_landmarks.py @@ -0,0 +1,101 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> cambridge_landmarks +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:41 +==================================================''' +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class CambridgeLandmarks(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='CambridgeLandmarks', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + self.img_fns = [] + with open(osp.join(self.dataset_path, 'dataset_train.txt' if train else 'dataset_test.txt'), 'r') as f: + lines = f.readlines()[3:] # ignore the first 3 lines + for l in lines: + l = l.strip().split()[0] + if train and l not in self.name_to_id.keys(): + continue + if not train and l not in self.img_p3d.keys(): + continue + self.img_fns.append(l) + + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + # lines = f.readlines() + # for l in lines: + # l = l.strip().split() + # self.mean_xyz = np.array([float(v) for v in l[:3]]) + # self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/third_party/pram/dataset/customdataset.py b/third_party/pram/dataset/customdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..41ec99ec1540868f3dfbafe00b5585398062e3f8 --- /dev/null +++ b/third_party/pram/dataset/customdataset.py @@ -0,0 +1,93 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> customdataset.py +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:38 +==================================================''' +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class CustomDataset(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset, + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + if train: + self.img_fns = [self.images[v].name for v in self.images.keys() if + self.images[v].name in self.name_to_id.keys()] + else: + self.img_fns = [] + with open(osp.join(self.dataset_path, 'queries_with_intrinsics.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + self.img_fns.append(l.strip().split()[0]) + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/third_party/pram/dataset/get_dataset.py b/third_party/pram/dataset/get_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe28eaa6238b480aae4c64cd08ffe6cd2379c90 --- /dev/null +++ b/third_party/pram/dataset/get_dataset.py @@ -0,0 +1,89 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> get_dataset +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:40 +==================================================''' +import os.path as osp +import yaml +from dataset.aachen import Aachen +from dataset.twelve_scenes import TwelveScenes +from dataset.seven_scenes import SevenScenes +from dataset.cambridge_landmarks import CambridgeLandmarks +from dataset.customdataset import CustomDataset +from dataset.recdataset import RecDataset + + +def get_dataset(dataset): + if dataset in ['7Scenes', 'S']: + return SevenScenes + elif dataset in ['12Scenes', 'T']: + return TwelveScenes + elif dataset in ['Aachen', 'A']: + return Aachen + elif dataset in ['CambridgeLandmarks', 'C']: + return CambridgeLandmarks + else: + return CustomDataset + + +def compose_datasets(datasets, config, train=True, sample_ratio=None): + sub_sets = [] + for name in datasets: + if name == 'S': + ds_name = '7Scenes' + elif name == 'T': + ds_name = '12Scenes' + elif name == 'A': + ds_name = 'Aachen' + elif name == 'R': + ds_name = 'RobotCar-Seasons' + elif name == 'C': + ds_name = 'CambridgeLandmarks' + else: + ds_name = name + # raise '{} dataset does not exist'.format(name) + landmark_path = osp.join(config['landmark_path'], ds_name) + dataset_path = osp.join(config['dataset_path'], ds_name) + scene_config_path = 'configs/datasets/{:s}.yaml'.format(ds_name) + + with open(scene_config_path, 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + DSet = get_dataset(dataset=ds_name) + + for scene in scene_config['scenes']: + if sample_ratio is None: + scene_sample_ratio = scene_config[scene]['training_sample_ratio'] if train else scene_config[scene][ + 'eval_sample_ratio'] + else: + scene_sample_ratio = sample_ratio + scene_set = DSet(landmark_path=landmark_path, + dataset_path=dataset_path, + scene=scene, + seg_mode=scene_config[scene]['cluster_mode'], + seg_method=scene_config[scene]['cluster_method'], + n_class=scene_config[scene]['n_cluster'] + 1, # including invalid - 0 + dataset=ds_name, + train=train, + nfeatures=config['max_keypoints'] if train else config['eval_max_keypoints'], + min_inliers=config['min_inliers'], + max_inliers=config['max_inliers'], + random_inliers=config['random_inliers'], + with_aug=config['with_aug'], + jitter_params=config['jitter_params'], + scale_params=config['scale_params'], + image_dim=config['image_dim'], + query_p3d_fn=osp.join(config['landmark_path'], ds_name, scene, + 'point3D_query_n{:d}_{:s}_{:s}.npy'.format( + scene_config[scene]['n_cluster'], + scene_config[scene]['cluster_mode'], + scene_config[scene]['cluster_method'])), + query_info_path=osp.join(config['dataset_path'], ds_name, scene, + 'queries_with_intrinsics.txt'), + sample_ratio=scene_sample_ratio, + ) + + sub_sets.append(scene_set) + + return RecDataset(sub_sets=sub_sets) diff --git a/third_party/pram/dataset/recdataset.py b/third_party/pram/dataset/recdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9eebd473018ad269eaa6cd8f1ffaab3f5f316ec6 --- /dev/null +++ b/third_party/pram/dataset/recdataset.py @@ -0,0 +1,95 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> recdataset +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:42 +==================================================''' +import numpy as np +from torch.utils.data import Dataset + + +class RecDataset(Dataset): + def __init__(self, sub_sets=[]): + assert len(sub_sets) >= 1 + + self.sub_sets = sub_sets + self.names = [] + + self.sub_set_index = [] + self.seg_offsets = [] + self.sub_set_item_index = [] + self.dataset_names = [] + self.scene_names = [] + start_index_valid_seg = 1 # start from 1, 0 is for invalid + + total_subset = 0 + for scene_set in sub_sets: # [0, n_class] + name = scene_set.dataset + self.names.append(name) + n_samples = len(scene_set) + + n_class = scene_set.n_class + self.seg_offsets = self.seg_offsets + [start_index_valid_seg for v in range(len(scene_set))] + start_index_valid_seg = start_index_valid_seg + n_class - 1 + + self.sub_set_index = self.sub_set_index + [total_subset for k in range(n_samples)] + self.sub_set_item_index = self.sub_set_item_index + [k for k in range(n_samples)] + + # self.dataset_names = self.dataset_names + [name for k in range(n_samples)] + self.scene_names = self.scene_names + [name for k in range(n_samples)] + total_subset += 1 + + self.n_class = start_index_valid_seg + + print('Load {} images {} segs from {} subsets from {}'.format(len(self.sub_set_item_index), self.n_class, + len(sub_sets), self.names)) + + def __len__(self): + return len(self.sub_set_item_index) + + def __getitem__(self, idx): + subset_idx = self.sub_set_index[idx] + item_idx = self.sub_set_item_index[idx] + scene_name = self.scene_names[idx] + + out = self.sub_sets[subset_idx][item_idx] + + org_gt_seg = out['gt_seg'] + org_gt_cls = out['gt_cls'] + org_gt_cls_dist = out['gt_cls_dist'] + org_gt_n_seg = out['gt_n_seg'] + offset = self.seg_offsets[idx] + org_n_class = self.sub_sets[subset_idx].n_class + + gt_seg = np.zeros(shape=(org_gt_seg.shape[0],), dtype=int) # [0, ..., n_features] + gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float) + + # copy invalid segments + gt_n_seg[0] = org_gt_n_seg[0] + gt_cls[0] = org_gt_cls[0] + gt_cls_dist[0] = org_gt_cls_dist[0] + # print('org: ', org_n_class, org_gt_seg.shape, org_gt_n_seg.shape, org_gt_seg) + + # copy valid segments + gt_seg[org_gt_seg > 0] = org_gt_seg[org_gt_seg > 0] + offset - 1 # [0, ..., 1023] + gt_n_seg[offset:offset + org_n_class - 1] = org_gt_n_seg[1:] # [0...,n_seg] + gt_cls[offset:offset + org_n_class - 1] = org_gt_cls[1:] # [0, ..., n_seg] + gt_cls_dist[offset:offset + org_n_class - 1] = org_gt_cls_dist[1:] # [0, ..., n_seg] + + out['gt_seg'] = gt_seg + out['gt_cls'] = gt_cls + out['gt_cls_dist'] = gt_cls_dist + out['gt_n_seg'] = gt_n_seg + + # print('gt: ', org_n_class, gt_seg.shape, gt_n_seg.shape, gt_seg) + out['scene_name'] = scene_name + + # out['org_gt_seg'] = org_gt_seg + # out['org_gt_n_seg'] = org_gt_n_seg + # out['org_gt_cls'] = org_gt_cls + # out['org_gt_cls_dist'] = org_gt_cls_dist + + return out diff --git a/third_party/pram/dataset/seven_scenes.py b/third_party/pram/dataset/seven_scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc29b29d3b935e45129a35b502117067816433a --- /dev/null +++ b/third_party/pram/dataset/seven_scenes.py @@ -0,0 +1,115 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> seven_scenes +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:36 +==================================================''' +import os +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class SevenScenes(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='7Scenes', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + if self.train: + split_fn = osp.join(self.dataset_path, 'TrainSplit.txt') + else: + split_fn = osp.join(self.dataset_path, 'TestSplit.txt') + + self.img_fns = [] + with open(split_fn, 'r') as f: + lines = f.readlines() + for l in lines: + seq = int(l.strip()[8:]) + fns = os.listdir(osp.join(self.dataset_path, osp.join('seq-{:02d}'.format(seq)))) + fns = sorted(fns) + nf = 0 + for fn in fns: + if fn.find('png') >= 0: + if train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.name_to_id.keys(): + continue + if not train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.img_p3d.keys(): + continue + if nf % sample_ratio == 0: + self.img_fns.append('seq-{:02d}'.format(seq) + '/' + fn) + nf += 1 + + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + # lines = f.readlines() + # for l in lines: + # l = l.strip().split() + # self.mean_xyz = np.array([float(v) for v in l[:3]]) + # self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/third_party/pram/dataset/twelve_scenes.py b/third_party/pram/dataset/twelve_scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..34fcc7f46b6d4315d9ebca69043a262310adc453 --- /dev/null +++ b/third_party/pram/dataset/twelve_scenes.py @@ -0,0 +1,121 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> twelve_scenes +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:37 +==================================================''' +import os +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class TwelveScenes(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='12Scenes', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + with open(osp.join(self.dataset_path, 'split.txt'), 'r') as f: + l = f.readline() + l = l.strip().split(' ') # sequence0 [frames=357] [start=0 ; end=356], first sequence for testing + start_img_id = l[-3].split('=')[-1] + end_img_id = l[-1].split('=')[-1][:-1] + test_start_img_id = int(start_img_id) + test_end_img_id = int(end_img_id) + + self.img_fns = [] + fns = os.listdir(osp.join(self.dataset_path, 'data')) + fns = sorted(fns) + nf = 0 + for fn in fns: + if fn.find('jpg') >= 0: # frame-001098.color.jpg + frame_id = int(fn.split('.')[0].split('-')[-1]) + if not train and frame_id > test_end_img_id: + continue + if train and frame_id <= test_end_img_id: + continue + + if train and 'data' + '/' + fn not in self.name_to_id.keys(): + continue + + if not train and 'data' + '/' + fn not in self.img_p3d.keys(): + continue + if nf % sample_ratio == 0: + self.img_fns.append('data' + '/' + fn) + nf += 1 + + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + # lines = f.readlines() + # for l in lines: + # l = l.strip().split() + # self.mean_xyz = np.array([float(v) for v in l[:3]]) + # self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/third_party/pram/dataset/utils.py b/third_party/pram/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8132662c540ae28de32494a5abff6e679064f5 --- /dev/null +++ b/third_party/pram/dataset/utils.py @@ -0,0 +1,31 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> utils +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:31 +==================================================''' +import torch + + +def normalize_size(x, size, scale=0.7): + size = size.reshape([1, 2]) + norm_fac = size.max() + 0.5 + return (x - size / 2) / (norm_fac * scale) + + +def collect_batch(batch): + out = {} + # if len(batch) == 0: + # return batch + # else: + for k in batch[0].keys(): + tmp = [] + for v in batch: + tmp.append(v[k]) + if isinstance(batch[0][k], str) or isinstance(batch[0][k], list): + out[k] = tmp + else: + out[k] = torch.cat([torch.from_numpy(i)[None] for i in tmp], dim=0) + + return out diff --git a/third_party/pram/environment.yml b/third_party/pram/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..bf1c2111660046500e25c9ff28e66d470c7f68a9 --- /dev/null +++ b/third_party/pram/environment.yml @@ -0,0 +1,173 @@ +name: pram +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - binutils_impl_linux-64=2.38=h2a08ee3_1 + - bzip2=1.0.8=h5eee18b_5 + - ca-certificates=2024.3.11=h06a4308_0 + - gcc=12.1.0=h9ea6d83_10 + - gcc_impl_linux-64=12.1.0=hea43390_17 + - kernel-headers_linux-64=2.6.32=he073ed8_17 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-devel_linux-64=12.1.0=h1ec3361_17 + - libgcc-ng=13.2.0=h807b86a_5 + - libgomp=13.2.0=h807b86a_5 + - libsanitizer=12.1.0=ha89aaad_17 + - libstdcxx-ng=13.2.0=h7e041cc_5 + - libuuid=1.41.5=h5eee18b_0 + - ncurses=6.4=h6a678d5_0 + - openssl=3.2.1=hd590300_1 + - pip=23.3.1=py310h06a4308_0 + - python=3.10.14=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.2.2=py310h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - sysroot_linux-64=2.12=he073ed8_17 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py310h06a4308_0 + - xz=5.4.6=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - addict==2.4.0 + - aiofiles==23.2.1 + - aiohttp==3.9.3 + - aioopenssl==0.6.0 + - aiosasl==0.5.0 + - aiosignal==1.3.1 + - aioxmpp==0.13.3 + - asttokens==2.4.1 + - async-timeout==4.0.3 + - attrs==23.2.0 + - babel==2.14.0 + - benbotasync==3.0.2 + - blinker==1.7.0 + - certifi==2024.2.2 + - cffi==1.16.0 + - charset-normalizer==3.3.2 + - click==8.1.7 + - colorama==0.4.6 + - comm==0.2.2 + - configargparse==1.7 + - contourpy==1.2.1 + - crayons==0.4.0 + - cryptography==42.0.5 + - cycler==0.12.1 + - dash==2.16.1 + - dash-core-components==2.0.0 + - dash-html-components==2.0.0 + - dash-table==5.0.0 + - decorator==5.1.1 + - dnspython==2.6.1 + - einops==0.7.0 + - exceptiongroup==1.2.0 + - executing==2.0.1 + - fastjsonschema==2.19.1 + - filelock==3.13.3 + - flask==3.0.2 + - fonttools==4.50.0 + - fortniteapiasync==0.1.7 + - fortnitepy==3.6.9 + - frozenlist==1.4.1 + - fsspec==2024.3.1 + - h5py==3.10.0 + - html5tagger==1.3.0 + - httptools==0.6.1 + - idna==3.6 + - importlib-metadata==7.1.0 + - ipython==8.23.0 + - ipywidgets==8.1.2 + - itsdangerous==2.1.2 + - jedi==0.19.1 + - jinja2==3.1.3 + - joblib==1.3.2 + - jsonschema==4.21.1 + - jsonschema-specifications==2023.12.1 + - jupyter-core==5.7.2 + - jupyterlab-widgets==3.0.10 + - kiwisolver==1.4.5 + - lxml==4.9.4 + - markupsafe==2.1.5 + - matplotlib==3.8.4 + - matplotlib-inline==0.1.6 + - mpmath==1.3.0 + - multidict==6.0.5 + - nbformat==5.10.4 + - nest-asyncio==1.6.0 + - networkx==3.2.1 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.19.3 + - nvidia-nvjitlink-cu12==12.4.127 + - nvidia-nvtx-cu12==12.1.105 + - open3d==0.18.0 + - opencv-contrib-python==4.5.5.64 + - packaging==24.0 + - pandas==2.2.1 + - parso==0.8.3 + - pexpect==4.9.0 + - pillow==10.3.0 + - platformdirs==4.2.0 + - plotly==5.20.0 + - prompt-toolkit==3.0.43 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyasn1==0.6.0 + - pyasn1-modules==0.4.0 + - pybind11==2.12.0 + - pycolmap==0.6.1 + - pycparser==2.22 + - pygments==2.17.2 + - pyopengl==3.1.7 + - pyopengl-accelerate==3.1.7 + - pyopenssl==24.1.0 + - pyparsing==3.1.2 + - pyquaternion==0.9.9 + - python-dateutil==2.9.0.post0 + - pytz==2024.1 + - pyyaml==6.0.1 + - referencing==0.34.0 + - requests==2.31.0 + - retrying==1.3.4 + - rpds-py==0.18.0 + - sanic==23.12.1 + - sanic-routing==23.12.0 + - scikit-learn==1.4.1.post1 + - scipy==1.13.0 + - six==1.16.0 + - sortedcollections==2.1.0 + - sortedcontainers==2.4.0 + - stack-data==0.6.3 + - sympy==1.12 + - tenacity==8.2.3 + - threadpoolctl==3.4.0 + - torch==2.2.2 + - torchvision==0.17.2 + - tqdm==4.66.2 + - tracerite==1.1.1 + - traitlets==5.14.2 + - triton==2.2.0 + - typing-extensions==4.10.0 + - tzdata==2024.1 + - tzlocal==5.2 + - ujson==5.9.0 + - urllib3==2.2.1 + - uvloop==0.15.2 + - wcwidth==0.2.13 + - websockets==12.0 + - werkzeug==3.0.2 + - widgetsnbextension==4.0.10 + - yaml2==0.0.1 + - yarl==1.9.4 + - zipp==3.18.1 diff --git a/third_party/pram/inference.py b/third_party/pram/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..29ccd76911f0b2ff8dc82fc28c712cf1d19d40be --- /dev/null +++ b/third_party/pram/inference.py @@ -0,0 +1,62 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> inference +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 03/04/2024 16:06 +==================================================''' +import argparse +import torch +import torchvision.transforms.transforms as tvt +import yaml +from nets.load_segnet import load_segnet +from nets.sfd2 import load_sfd2 +from dataset.get_dataset import compose_datasets + +parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--config', type=str, required=True, help='config of specifications') +parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks') +parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth') +parser.add_argument('--rec_weight_path', type=str, required=True, help='recognition weight') +parser.add_argument('--online', action='store_true', help='online visualization with pangolin') + +if __name__ == '__main__': + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = yaml.load(f, Loader=yaml.Loader) + config['landmark_path'] = args.landmark_path + + feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval() + print('Load SFD2 weight from {:s}'.format(args.feat_weight_path)) + + # rec_model = get_model(config=config) + rec_model = load_segnet(network=config['network'], + n_class=config['n_class'], + desc_dim=256 if config['use_mid_feature'] else 128, + n_layers=config['layers'], + output_dim=config['output_dim']) + state_dict = torch.load(args.rec_weight_path, map_location='cpu')['model'] + rec_model.load_state_dict(state_dict, strict=True) + print('Load recognition weight from {:s}'.format(args.rec_weight_path)) + + img_transforms = [] + img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + img_transforms = tvt.Compose(img_transforms) + + dataset = config['dataset'] + if not args.online: + from localization.loc_by_rec_eval import loc_by_rec_eval + + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1) + config['n_class'] = test_set.n_class + + loc_by_rec_eval(rec_model=rec_model.cuda().eval(), + loader=test_set, + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) + else: + from localization.loc_by_rec_online import loc_by_rec_online + + loc_by_rec_online(rec_model=rec_model.cuda().eval(), + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) diff --git a/third_party/pram/localization/base_model.py b/third_party/pram/localization/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..432f49c325d39aa44efb0c3106abf7e376c8244e --- /dev/null +++ b/third_party/pram/localization/base_model.py @@ -0,0 +1,45 @@ +from abc import ABCMeta, abstractmethod +from torch import nn +from copy import copy +import inspect + + +class BaseModel(nn.Module, metaclass=ABCMeta): + default_conf = {} + required_data_keys = [] + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + self.conf = conf = {**self.default_conf, **conf} + self.required_data_keys = copy(self.required_data_keys) + self._init(conf) + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + for key in self.required_data_keys: + assert key in data, 'Missing key {} in data'.format(key) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + +def dynamic_load(root, model): + module_path = f'{root.__name__}.{model}' + module = __import__(module_path, fromlist=['']) + classes = inspect.getmembers(module, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == module_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseModel)] + assert len(classes) == 1, classes + return classes[0][1] + # return getattr(module, 'Model') diff --git a/third_party/pram/localization/camera.py b/third_party/pram/localization/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d77af63bcac68b87acd6f5ddc19d92c7d99d07 --- /dev/null +++ b/third_party/pram/localization/camera.py @@ -0,0 +1,11 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> camera +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 11:27 +==================================================''' +import collections + +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) diff --git a/third_party/pram/localization/extract_features.py b/third_party/pram/localization/extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..cd3f85c53dafd33fe737fdb9e79eeee1bd1c600b --- /dev/null +++ b/third_party/pram/localization/extract_features.py @@ -0,0 +1,256 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> extract_features.py +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 14:49 +==================================================''' +import os +import os.path as osp +import h5py +import numpy as np +import progressbar +import yaml +import torch +import cv2 +import torch.utils.data as Data +from tqdm import tqdm +from types import SimpleNamespace +import logging +import pprint +from pathlib import Path +import argparse +from nets.sfd2 import ResNet4x, extract_sfd2_return +from nets.superpoint import SuperPoint, extract_sp_return + +confs = { + 'superpoint-n4096': { + 'output': 'feats-superpoint-n4096', + 'model': { + 'name': 'superpoint', + 'outdim': 256, + 'use_stability': False, + 'nms_radius': 3, + 'max_keypoints': 4096, + 'conf_th': 0.005, + 'multiscale': False, + 'scales': [1.0], + 'model_fn': osp.join(os.getcwd(), + "weights/superpoint_v1.pth"), + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': False, + }, + }, + + 'resnet4x-20230511-210205-pho-0005': { + 'output': 'feats-resnet4x-20230511-210205-pho-0005', + 'model': { + 'outdim': 128, + 'name': 'resnet4x', + 'use_stability': False, + 'max_keypoints': 4096, + 'conf_th': 0.005, + 'multiscale': False, + 'scales': [1.0], + 'model_fn': osp.join(os.getcwd(), + "weights/sfd2_20230511_210205_resnet4x.79.pth"), + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': False, + }, + 'mask': False, + }, + + 'sfd2': { + 'output': 'feats-sfd2', + 'model': { + 'outdim': 128, + 'name': 'resnet4x', + 'use_stability': False, + 'max_keypoints': 4096, + 'conf_th': 0.005, + 'multiscale': False, + 'scales': [1.0], + 'model_fn': osp.join(os.getcwd(), + "weights/sfd2_20230511_210205_resnet4x.79.pth"), + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': False, + }, + 'mask': False, + }, +} + + +class ImageDataset(Data.Dataset): + default_conf = { + 'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'], + 'grayscale': False, + 'resize_max': None, + 'resize_force': False, + } + + def __init__(self, root, conf, image_list=None, + mask_root=None): + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.root = root + + self.paths = [] + if image_list is None: + for g in conf.globs: + self.paths += list(Path(root).glob('**/' + g)) + if len(self.paths) == 0: + raise ValueError(f'Could not find any image in root: {root}.') + self.paths = [i.relative_to(root) for i in self.paths] + else: + with open(image_list, "r") as f: + lines = f.readlines() + for l in lines: + l = l.strip() + self.paths.append(Path(l)) + + logging.info(f'Found {len(self.paths)} images in root {root}.') + + if mask_root is not None: + self.mask_root = mask_root + else: + self.mask_root = None + + def __getitem__(self, idx): + path = self.paths[idx] + if self.conf.grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(self.root / path), mode) + if not self.conf.grayscale: + image = image[:, :, ::-1] # BGR to RGB + if image is None: + raise ValueError(f'Cannot read image {str(path)}.') + image = image.astype(np.float32) + size = image.shape[:2][::-1] + w, h = size + + if self.conf.resize_max and (self.conf.resize_force + or max(w, h) > self.conf.resize_max): + scale = self.conf.resize_max / max(h, w) + h_new, w_new = int(round(h * scale)), int(round(w * scale)) + image = cv2.resize( + image, (w_new, h_new), interpolation=cv2.INTER_CUBIC) + + if self.conf.grayscale: + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = image / 255. + + data = { + 'name': str(path), + 'image': image, + 'original_size': np.array(size), + } + + if self.mask_root is not None: + mask_path = Path(str(path).replace("jpg", "png")) + if osp.exists(mask_path): + mask = cv2.imread(str(self.mask_root / mask_path)) + mask = cv2.resize(mask, dsize=(image.shape[2], image.shape[1]), interpolation=cv2.INTER_NEAREST) + else: + mask = np.zeros(shape=(image.shape[1], image.shape[2], 3), dtype=np.uint8) + + data['mask'] = mask + + return data + + def __len__(self): + return len(self.paths) + + +def get_model(model_name, weight_path, outdim=128, **kwargs): + if model_name == 'superpoint': + model = SuperPoint(config={ + 'descriptor_dim': 256, + 'nms_radius': 4, + 'keypoint_threshold': 0.005, + 'max_keypoints': -1, + 'remove_borders': 4, + 'weight_path': weight_path, + }).eval() + + extractor = extract_sp_return + + if model_name == 'resnet4x': + model = ResNet4x(outdim=outdim).eval() + model.load_state_dict(torch.load(weight_path)['state_dict'], strict=True) + extractor = extract_sfd2_return + + return model, extractor + + +@torch.no_grad() +def main(conf, image_dir, export_dir): + logging.info('Extracting local features with configuration:' + f'\n{pprint.pformat(conf)}') + model, extractor = get_model(model_name=conf['model']['name'], weight_path=conf["model"]["model_fn"], + use_stability=conf['model']['use_stability'], outdim=conf['model']['outdim']) + model = model.cuda() + loader = ImageDataset(image_dir, + conf['preprocessing'], + image_list=args.image_list, + mask_root=None) + loader = torch.utils.data.DataLoader(loader, num_workers=4) + + os.makedirs(export_dir, exist_ok=True) + feature_path = Path(export_dir, conf['output'] + '.h5') + feature_path.parent.mkdir(exist_ok=True, parents=True) + feature_file = h5py.File(str(feature_path), 'a') + + with tqdm(total=len(loader)) as t: + for idx, data in enumerate(loader): + t.update() + pred = extractor(model, img=data["image"], + topK=conf["model"]["max_keypoints"], + mask=None, + conf_th=conf["model"]["conf_th"], + scales=conf["model"]["scales"], + ) + + # pred = {k: v[0].cpu().numpy() for k, v in pred.items()} + pred['descriptors'] = pred['descriptors'].transpose() + + t.set_postfix(npoints=pred['keypoints'].shape[0]) + # print(pred['keypoints'].shape) + + pred['image_size'] = original_size = data['original_size'][0].numpy() + # pred['descriptors'] = pred['descriptors'].T + if 'keypoints' in pred.keys(): + size = np.array(data['image'].shape[-2:][::-1]) + scales = (original_size / size).astype(np.float32) + pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5 + + grp = feature_file.create_group(data['name'][0]) + for k, v in pred.items(): + # print(k, v.shape) + grp.create_dataset(k, data=v) + + del pred + + feature_file.close() + logging.info('Finished exporting features.') + + return feature_path + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image_dir', type=Path, required=True) + parser.add_argument('--image_list', type=str, default=None) + parser.add_argument('--mask_dir', type=Path, default=None) + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys())) + args = parser.parse_args() + main(confs[args.conf], args.image_dir, args.export_dir) diff --git a/third_party/pram/localization/frame.py b/third_party/pram/localization/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..467a0f31a9c62a19b4435c71add6d08e34b051f3 --- /dev/null +++ b/third_party/pram/localization/frame.py @@ -0,0 +1,195 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> frame +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 01/03/2024 10:08 +==================================================''' +from collections import defaultdict + +import numpy as np +import torch +import pycolmap + +from localization.camera import Camera +from localization.utils import compute_pose_error + + +class Frame: + def __init__(self, image: np.ndarray, camera: pycolmap.Camera, id: int, name: str = None, qvec=None, tvec=None, + scene_name=None, + reference_frame_id=None): + self.image = image + self.camera = camera + self.id = id + self.name = name + self.image_size = np.array([camera.height, camera.width]) + self.qvec = qvec + self.tvec = tvec + self.scene_name = scene_name + self.reference_frame_id = reference_frame_id + + self.keypoints = None # [N, 3] + self.descriptors = None # [N, D] + self.segmentations = None # [N C] + self.seg_scores = None # [N C] + self.seg_ids = None # [N, 1] + self.point3D_ids = None # [N, 1] + self.xyzs = None + + self.gt_qvec = None + self.gt_tvec = None + + self.matched_scene_name = None + self.matched_keypoints = None + self.matched_keypoint_ids = None + self.matched_xyzs = None + self.matched_point3D_ids = None + self.matched_inliers = None + self.matched_sids = None + self.matched_order = None + + self.refinement_reference_frame_ids = None + self.image_rec = None + self.image_matching = None + self.image_inlier = None + self.reference_frame_name = None + self.image_matching_tmp = None + self.image_inlier_tmp = None + self.reference_frame_name_tmp = None + + self.tracking_status = None + + self.time_feat = 0 + self.time_rec = 0 + self.time_loc = 0 + self.time_ref = 0 + + def update_point3ds_old(self): + pt = torch.from_numpy(self.keypoints[:, :2]).unsqueeze(-1) # [M 2 1] + mpt = torch.from_numpy(self.matched_keypoints[:, :2].transpose()).unsqueeze(0) # [1 2 N] + dist = torch.sqrt(torch.sum((pt - mpt) ** 2, dim=1)) + values, ids = torch.topk(dist, dim=1, k=1, largest=False) + values = values[:, 0].numpy() + ids = ids[:, 0].numpy() + mask = (values < 1) # 1 pixel error + self.point3D_ids = np.zeros(shape=(self.keypoints.shape[0],), dtype=int) - 1 + self.point3D_ids[mask] = self.matched_point3D_ids[ids[mask]] + + # self.xyzs = np.zeros(shape=(self.keypoints.shape[0], 3), dtype=float) + inlier_mask = self.matched_inliers + self.xyzs[mask] = self.matched_xyzs[ids[mask]] + self.seg_ids[mask] = self.matched_sids[ids[mask]] + + def update_point3ds(self): + # print('Frame: update_point3ds: ', self.matched_keypoint_ids.shape, self.matched_xyzs.shape, + # self.matched_sids.shape, self.matched_point3D_ids.shape) + self.xyzs[self.matched_keypoint_ids] = self.matched_xyzs + self.seg_ids[self.matched_keypoint_ids] = self.matched_sids + self.point3D_ids[self.matched_keypoint_ids] = self.matched_point3D_ids + + def add_keypoints(self, keypoints: np.ndarray, descriptors: np.ndarray): + self.keypoints = keypoints + self.descriptors = descriptors + self.initialize_localization_variables() + + def add_segmentations(self, segmentations: torch.Tensor, filtering_threshold: float): + ''' + :param segmentations: [number_points number_labels] + :return: + ''' + seg_scores = torch.softmax(segmentations, dim=-1) + if filtering_threshold > 0: + scores_background = seg_scores[:, 0] + non_bg_mask = (scores_background < filtering_threshold) + print('pre filtering before: ', self.keypoints.shape) + if torch.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]: + self.keypoints = self.keypoints[non_bg_mask.cpu().numpy()] + self.descriptors = self.descriptors[non_bg_mask.cpu().numpy()] + # print('pre filtering after: ', self.keypoints.shape) + + # update localization variables + self.initialize_localization_variables() + + segmentations = segmentations[non_bg_mask] + seg_scores = seg_scores[non_bg_mask] + print('pre filtering after: ', self.keypoints.shape) + + # extract initial segmentation info + self.segmentations = segmentations.cpu().numpy() + self.seg_scores = seg_scores.cpu().numpy() + self.seg_ids = segmentations.max(dim=-1)[1].cpu().numpy() - 1 # should start from 0 + + def filter_keypoints(self, seg_scores: np.ndarray, filtering_threshold: float): + scores_background = seg_scores[:, 0] + non_bg_mask = (scores_background < filtering_threshold) + print('pre filtering before: ', self.keypoints.shape) + if np.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]: + self.keypoints = self.keypoints[non_bg_mask] + self.descriptors = self.descriptors[non_bg_mask] + print('pre filtering after: ', self.keypoints.shape) + + # update localization variables + self.initialize_localization_variables() + return non_bg_mask + else: + print('pre filtering after: ', self.keypoints.shape) + return None + + def compute_pose_error(self, pred_qvec=None, pred_tvec=None): + if pred_qvec is not None and pred_tvec is not None: + if self.gt_qvec is not None and self.gt_tvec is not None: + return compute_pose_error(pred_qcw=pred_qvec, pred_tcw=pred_tvec, + gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec) + else: + return 100, 100 + + if self.qvec is None or self.tvec is None or self.gt_qvec is None or self.gt_tvec is None: + return 100, 100 + else: + err_q, err_t = compute_pose_error(pred_qcw=self.qvec, pred_tcw=self.tvec, + gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec) + return err_q, err_t + + def get_intrinsics(self) -> np.ndarray: + camera_model = self.camera.model.name + params = self.camera.params + if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K + + def get_dominate_seg_id(self): + counts = np.bincount(self.seg_ids[self.seg_ids > 0]) + return np.argmax(counts) + + def clear_localization_track(self): + self.matched_scene_name = None + self.matched_keypoints = None + self.matched_xyzs = None + self.matched_point3D_ids = None + self.matched_inliers = None + self.matched_sids = None + + self.refinement_reference_frame_ids = None + + def initialize_localization_variables(self): + nkpt = self.keypoints.shape[0] + self.seg_ids = np.zeros(shape=(nkpt,), dtype=int) - 1 + self.point3D_ids = np.zeros(shape=(nkpt,), dtype=int) - 1 + self.xyzs = np.zeros(shape=(nkpt, 3), dtype=float) diff --git a/third_party/pram/localization/loc_by_rec_eval.py b/third_party/pram/localization/loc_by_rec_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..f69b4ac3fde0547947abe983b1f5a4a4af55f974 --- /dev/null +++ b/third_party/pram/localization/loc_by_rec_eval.py @@ -0,0 +1,299 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> loc_by_rec +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 08/02/2024 15:26 +==================================================''' +import torch +from torch.autograd import Variable +from localization.multimap3d import MultiMap3D +from localization.frame import Frame +import yaml, cv2, time +import numpy as np +import os.path as osp +import threading +import os +from tqdm import tqdm +from recognition.vis_seg import vis_seg_point, generate_color_dic +from tools.metrics import compute_iou, compute_precision +from localization.tracker import Tracker +from localization.utils import read_query_info +from localization.camera import Camera + + +def loc_by_rec_eval(rec_model, loader, config, local_feat, img_transforms=None): + n_epoch = int(config['weight_path'].split('.')[1]) + save_fn = osp.join(config['localization']['save_path'], + config['weight_path'].split('/')[0] + '_{:d}'.format(n_epoch) + '_{:d}'.format( + config['feat_dim'])) + tag = 'k{:d}_th{:d}_mm{:d}_mi{:d}'.format(config['localization']['seg_k'], config['localization']['threshold'], + config['localization']['min_matches'], + config['localization']['min_inliers']) + if config['localization']['do_refinement']: + tag += '_op{:d}'.format(config['localization']['covisibility_frame']) + if config['localization']['with_compress']: + tag += '_comp' + + save_fn = save_fn + '_' + tag + + save = config['localization']['save'] + save = config['localization']['save'] + if save: + save_dir = save_fn + os.makedirs(save_dir, exist_ok=True) + else: + save_dir = None + + seg_color = generate_color_dic(n_seg=2000) + dataset_path = config['dataset_path'] + show = config['localization']['show'] + if show: + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + + locMap = MultiMap3D(config=config, save_dir=None) + # start tracker + mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config) + + dataset_name = config['dataset'][0] + all_scene_query_info = {} + with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + scenes = scene_config['scenes'] + for scene in scenes: + query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path']) + query_info = read_query_info(query_fn=query_path) + all_scene_query_info[dataset_name + '/' + scene] = query_info + # print(scene, query_info.keys()) + + tracking = False + + full_log = '' + failed_cases = [] + success_cases = [] + poses = {} + err_ths_cnt = [0, 0, 0, 0] + + seg_results = {} + time_results = { + 'feat': [], + 'rec': [], + 'loc': [], + 'ref': [], + 'total': [], + } + n_total = 0 + + loc_scene_names = config['localization']['loc_scene_name'] + # loader = loader[8990:] + for bid, pred in tqdm(enumerate(loader), total=len(loader)): + pred = loader[bid] + image_name = pred['file_name'] # [0] + scene_name = pred['scene_name'] # [0] # dataset_scene + if len(loc_scene_names) > 0: + skip = True + for loc_scene in loc_scene_names: + if scene_name.find(loc_scene) > 0: + skip = False + break + if skip: + continue + with torch.no_grad(): + for k in pred: + if k.find('name') >= 0: + continue + if k != 'image0' and k != 'image1' and k != 'depth0' and k != 'depth1': + if type(pred[k]) == np.ndarray: + pred[k] = Variable(torch.from_numpy(pred[k]).float().cuda())[None] + elif type(pred[k]) == torch.Tensor: + pred[k] = Variable(pred[k].float().cuda()) + elif type(pred[k]) == list: + continue + else: + pred[k] = Variable(torch.stack(pred[k]).float().cuda()) + print('scene: ', scene_name, image_name) + + n_total += 1 + with torch.no_grad(): + img = pred['image'] + while isinstance(img, list): + img = img[0] + + new_im = torch.from_numpy(img).permute(2, 0, 1).cuda().float() + if img_transforms is not None: + new_im = img_transforms(new_im)[None] + else: + new_im = new_im[None] + img = (img * 255).astype(np.uint8) + + fn = image_name + camera_model, width, height, params = all_scene_query_info[scene_name][fn] + camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params) + curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=scene_name) + gt_sub_map = locMap.sub_maps[curr_frame.scene_name] + if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys(): + curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec'] + curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec'] + + t_start = time.time() + encoder_out = local_feat.extract_local_global(data={'image': new_im}, + config= + { + # 'min_keypoints': 128, + 'max_keypoints': config['eval_max_keypoints'], + } + ) + t_feat = time.time() - t_start + # global_descriptors_cuda = encoder_out['global_descriptors'] + # scores_cuda = encoder_out['scores'][0][None] + # kpts_cuda = encoder_out['keypoints'][0][None] + # descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1) + + sparse_scores = pred['scores'] + sparse_descs = pred['descriptors'] + sparse_kpts = pred['keypoints'] + gt_seg = pred['gt_seg'] + + curr_frame.add_keypoints(keypoints=np.hstack([sparse_kpts[0].cpu().numpy(), + sparse_scores[0].cpu().numpy().reshape(-1, 1)]), + descriptors=sparse_descs[0].cpu().numpy()) + curr_frame.time_feat = t_feat + + t_start = time.time() + _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'], + semi_descs=encoder_out['mid_features'], + # kpts=kpts_cuda[0], + kpts=sparse_kpts[0], + norm_desc=config['norm_desc']) + rec_out = rec_model({'scores': sparse_scores, + 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1), + 'keypoints': sparse_kpts, + 'image': new_im}) + t_rec = time.time() - t_start + curr_frame.time_rec = t_rec + + pred = { + # 'scores': scores_cuda, + # 'keypoints': kpts_cuda, + # 'descriptors': descriptors_cuda, + # 'global_descriptors': global_descriptors_cuda, + 'image_size': np.array([img.shape[1], img.shape[0]])[None], + } + + pred = {**pred, **rec_out} + pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C] + + pred_seg = pred_seg[0].cpu().numpy() + kpts = sparse_kpts[0].cpu().numpy() + img_pred_seg = vis_seg_point(img=img, kpts=kpts, segs=pred_seg, seg_color=seg_color, radius=9) + show_text = 'kpts: {:d}'.format(kpts.shape[0]) + img_pred_seg = cv2.putText(img=img_pred_seg, text=show_text, + org=(50, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + curr_frame.image_rec = img_pred_seg + + if show: + cv2.imshow('img', img) + key = cv2.waitKey(1) + if key == ord('q'): + exit(0) + elif key == ord('s'): + show_time = -1 + elif key == ord('c'): + show_time = 1 + + segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C] + curr_frame.add_segmentations(segmentations=segmentations, + filtering_threshold=config['localization']['pre_filtering_th']) + + # Step1: do tracker first + success = not mTracker.lost and tracking + if success: + success = mTracker.run(frame=curr_frame) + if not success: + success = locMap.run(q_frame=curr_frame) + if success: + curr_frame.update_point3ds() + if tracking: + mTracker.lost = False + mTracker.last_frame = curr_frame + # ''' + pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C] + pred_seg = pred_seg[0].cpu().numpy() + gt_seg = gt_seg[0].cpu().numpy() + iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=pred_seg.shape[0], + ignored_ids=[0]) # 0 - background + prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0]) + + kpts = sparse_kpts[0].cpu().numpy() + if scene not in seg_results.keys(): + seg_results[scene] = { + 'day': { + 'prec': [], + 'iou': [], + 'kpts': [], + }, + 'night': { + 'prec': [], + 'iou': [], + 'kpts': [], + + } + } + if fn.find('night') >= 0: + seg_results[scene]['night']['prec'].append(prec) + seg_results[scene]['night']['iou'].append(iou) + seg_results[scene]['night']['kpts'].append(kpts.shape[0]) + else: + seg_results[scene]['day']['prec'].append(prec) + seg_results[scene]['day']['iou'].append(iou) + seg_results[scene]['day']['kpts'].append(kpts.shape[0]) + + print_text = 'name: {:s}, kpts: {:d}, iou: {:.3f}, prec: {:.3f}'.format(fn, kpts.shape[0], iou, + prec) + print(print_text) + # ''' + + t_feat = curr_frame.time_feat + t_rec = curr_frame.time_rec + t_loc = curr_frame.time_loc + t_ref = curr_frame.time_ref + t_total = t_feat + t_rec + t_loc + t_ref + time_results['feat'].append(t_feat) + time_results['rec'].append(t_rec) + time_results['loc'].append(t_loc) + time_results['ref'].append(t_ref) + time_results['total'].append(t_total) + + poses[scene + '/' + fn] = (curr_frame.qvec, curr_frame.tvec) + q_err, t_err = curr_frame.compute_pose_error() + if q_err <= 5 and t_err <= 0.05: + err_ths_cnt[0] = err_ths_cnt[0] + 1 + if q_err <= 2 and t_err <= 0.25: + err_ths_cnt[1] = err_ths_cnt[1] + 1 + if q_err <= 5 and t_err <= 0.5: + err_ths_cnt[2] = err_ths_cnt[2] + 1 + if q_err <= 10 and t_err <= 5: + err_ths_cnt[3] = err_ths_cnt[3] + 1 + + if success: + success_cases.append(scene + '/' + fn) + print_text = 'qname: {:s} localization success {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( + scene + '/' + fn, len(success_cases), n_total, q_err, t_err, err_ths_cnt[0], + err_ths_cnt[1], + err_ths_cnt[2], + err_ths_cnt[3], + n_total, + t_feat, t_rec, t_loc, t_ref, t_total + ) + else: + failed_cases.append(scene + '/' + fn) + print_text = 'qname: {:s} localization fail {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( + scene + '/' + fn, len(failed_cases), n_total, q_err, t_err, err_ths_cnt[0], + err_ths_cnt[1], + err_ths_cnt[2], + err_ths_cnt[3], + n_total, t_feat, t_rec, t_loc, t_ref, t_total) + print(print_text) diff --git a/third_party/pram/localization/loc_by_rec_online.py b/third_party/pram/localization/loc_by_rec_online.py new file mode 100644 index 0000000000000000000000000000000000000000..58afed6eb439b23b4a0bc7daf45d50098bcc4fc2 --- /dev/null +++ b/third_party/pram/localization/loc_by_rec_online.py @@ -0,0 +1,225 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> loc_by_rec +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 08/02/2024 15:26 +==================================================''' +import torch +import pycolmap +from localization.multimap3d import MultiMap3D +from localization.frame import Frame +import yaml, cv2, time +import numpy as np +import os.path as osp +import threading +from recognition.vis_seg import vis_seg_point, generate_color_dic +from tools.common import resize_img +from localization.viewer import Viewer +from localization.tracker import Tracker +from localization.utils import read_query_info +from tools.common import puttext_with_background + + +def loc_by_rec_online(rec_model, config, local_feat, img_transforms=None): + seg_color = generate_color_dic(n_seg=2000) + dataset_path = config['dataset_path'] + show = config['localization']['show'] + if show: + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + + locMap = MultiMap3D(config=config, save_dir=None) + if config['dataset'][0] in ['Aachen']: + viewer_config = {'scene': 'outdoor', + 'image_size_indoor': 4, + 'image_line_width_indoor': 8, } + elif config['dataset'][0] in ['C']: + viewer_config = {'scene': 'outdoor'} + elif config['dataset'][0] in ['12Scenes', '7Scenes']: + viewer_config = {'scene': 'indoor', } + else: + viewer_config = {'scene': 'outdoor', + 'image_size_indoor': 0.4, + 'image_line_width_indoor': 2, } + # start viewer + mViewer = Viewer(locMap=locMap, seg_color=seg_color, config=viewer_config) + mViewer.refinement = locMap.do_refinement + # locMap.viewer = mViewer + viewer_thread = threading.Thread(target=mViewer.run) + viewer_thread.start() + + # start tracker + mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config) + + dataset_name = config['dataset'][0] + all_scene_query_info = {} + with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + + # multiple scenes in a single dataset + err_ths_cnt = [0, 0, 0, 0] + + show_time = -1 + scenes = scene_config['scenes'] + n_total = 0 + for scene in scenes: + if len(config['localization']['loc_scene_name']) > 0: + if scene not in config['localization']['loc_scene_name']: + continue + + query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path']) + query_info = read_query_info(query_fn=query_path) + all_scene_query_info[dataset_name + '/' + scene] = query_info + image_path = osp.join(dataset_path, dataset_name, scene) + for fn in sorted(query_info.keys()): + # for fn in sorted(query_info.keys())[880:][::5]: # darwinRGB-loc-outdoor-aligned + # for fn in sorted(query_info.keys())[3161:][::5]: # darwinRGB-loc-indoor-aligned + # for fn in sorted(query_info.keys())[2840:][::5]: # darwinRGB-loc-indoor-aligned + + # for fn in sorted(query_info.keys())[2100:][::5]: # darwinRGB-loc-outdoor + # for fn in sorted(query_info.keys())[4360:][::5]: # darwinRGB-loc-indoor + # for fn in sorted(query_info.keys())[1380:]: # Cam-Church + # for fn in sorted(query_info.keys())[::5]: #ACUED-test2 + # for fn in sorted(query_info.keys())[1260:]: # jesus aligned + # for fn in sorted(query_info.keys())[1260:]: # jesus aligned + # for fn in sorted(query_info.keys())[4850:]: + img = cv2.imread(osp.join(image_path, fn)) # BGR + + camera_model, width, height, params = all_scene_query_info[dataset_name + '/' + scene][fn] + # camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params) + camera = pycolmap.Camera(model=camera_model, width=int(width), height=int(height), params=params) + curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=dataset_name + '/' + scene) + gt_sub_map = locMap.sub_maps[curr_frame.scene_name] + if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys(): + curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec'] + curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec'] + + with torch.no_grad(): + if config['image_dim'] == 1: + img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img_cuda = torch.from_numpy(img_gray / 255)[None].cuda().float() + else: + img_cuda = torch.from_numpy(img / 255).permute(2, 0, 1).cuda().float() + if img_transforms is not None: + img_cuda = img_transforms(img_cuda)[None] + else: + img_cuda = img_cuda[None] + + t_start = time.time() + encoder_out = local_feat.extract_local_global(data={'image': img_cuda}, + config={'min_keypoints': 128, + 'max_keypoints': config['eval_max_keypoints'], + } + ) + t_feat = time.time() - t_start + # global_descriptors_cuda = encoder_out['global_descriptors'] + scores_cuda = encoder_out['scores'][0][None] + kpts_cuda = encoder_out['keypoints'][0][None] + descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1) + + curr_frame.add_keypoints(keypoints=np.hstack([kpts_cuda[0].cpu().numpy(), + scores_cuda[0].cpu().numpy().reshape(-1, 1)]), + descriptors=descriptors_cuda[0].cpu().numpy()) + curr_frame.time_feat = t_feat + + t_start = time.time() + _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'], + semi_descs=encoder_out['mid_features'], + kpts=kpts_cuda[0], + norm_desc=config['norm_desc']) + rec_out = rec_model({'scores': scores_cuda, + 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1), + 'keypoints': kpts_cuda, + 'image': img_cuda}) + t_rec = time.time() - t_start + curr_frame.time_rec = t_rec + + pred = { + 'scores': scores_cuda, + 'keypoints': kpts_cuda, + 'descriptors': descriptors_cuda, + # 'global_descriptors': global_descriptors_cuda, + 'image_size': np.array([img.shape[1], img.shape[0]])[None], + } + + pred = {**pred, **rec_out} + pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C] + + pred_seg = pred_seg[0].cpu().numpy() + kpts = kpts_cuda[0].cpu().numpy() + segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C] + curr_frame.add_segmentations(segmentations=segmentations, + filtering_threshold=config['localization']['pre_filtering_th']) + + img_pred_seg = vis_seg_point(img=img, kpts=curr_frame.keypoints, + segs=curr_frame.seg_ids + 1, seg_color=seg_color, radius=9) + show_text = 'kpts: {:d}'.format(kpts.shape[0]) + img_pred_seg = cv2.putText(img=img_pred_seg, + text=show_text, + org=(50, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + curr_frame.image_rec = img_pred_seg + + if show: + img_text = puttext_with_background(image=img, text='Press C - continue | S - pause | Q - exit', + org=(30, 50), + bg_color=(255, 255, 255), + text_color=(0, 0, 255), + fontScale=1, thickness=2) + cv2.imshow('img', img_text) + key = cv2.waitKey(show_time) + if key == ord('q'): + exit(0) + elif key == ord('s'): + show_time = -1 + elif key == ord('c'): + show_time = 1 + + # Step1: do tracker first + success = not mTracker.lost and mViewer.tracking + if success: + success = mTracker.run(frame=curr_frame) + if success: + mViewer.update(curr_frame=curr_frame) + + if not success: + # success = locMap.run(q_frame=curr_frame, q_segs=segmentations) + success = locMap.run(q_frame=curr_frame) + if success: + mViewer.update(curr_frame=curr_frame) + + if success: + curr_frame.update_point3ds() + if mViewer.tracking: + mTracker.lost = False + mTracker.last_frame = curr_frame + + time.sleep(50 / 1000) + locMap.do_refinement = mViewer.refinement + + n_total = n_total + 1 + q_err, t_err = curr_frame.compute_pose_error() + if q_err <= 5 and t_err <= 0.05: + err_ths_cnt[0] = err_ths_cnt[0] + 1 + if q_err <= 2 and t_err <= 0.25: + err_ths_cnt[1] = err_ths_cnt[1] + 1 + if q_err <= 5 and t_err <= 0.5: + err_ths_cnt[2] = err_ths_cnt[2] + 1 + if q_err <= 10 and t_err <= 5: + err_ths_cnt[3] = err_ths_cnt[3] + 1 + time_total = curr_frame.time_feat + curr_frame.time_rec + curr_frame.time_loc + curr_frame.time_ref + print_text = 'qname: {:s} localization {:b}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( + scene + '/' + fn, success, q_err, t_err, + err_ths_cnt[0], + err_ths_cnt[1], + err_ths_cnt[2], + err_ths_cnt[3], + n_total, + curr_frame.time_feat, curr_frame.time_rec, curr_frame.time_loc, curr_frame.time_ref, time_total + ) + print(print_text) + + mViewer.terminate() + viewer_thread.join() diff --git a/third_party/pram/localization/localizer.py b/third_party/pram/localization/localizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0777b9cc6d7f70aa8c3699f360684cd24054a488 --- /dev/null +++ b/third_party/pram/localization/localizer.py @@ -0,0 +1,217 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> hloc +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 16:45 +==================================================''' + +import os +import os.path as osp +from tqdm import tqdm +import argparse +import time +import logging +import h5py +import numpy as np +from pathlib import Path +from colmap_utils.read_write_model import read_model +from colmap_utils.parsers import parse_image_lists_with_intrinsics +# localization +from localization.match_features_batch import confs +from localization.base_model import dynamic_load +from localization import matchers +from localization.utils import compute_pose_error, read_gt_pose, read_retrieval_results +from localization.pose_estimator import pose_estimator_hloc, pose_estimator_iterative + + +def run(args): + if args.gt_pose_fn is not None: + gt_poses = read_gt_pose(path=args.gt_pose_fn) + else: + gt_poses = {} + retrievals = read_retrieval_results(args.retrieval) + + save_root = args.save_root # path to save + os.makedirs(save_root, exist_ok=True) + matcher_name = args.matcher_method # matching method + print('matcher: ', confs[args.matcher_method]['model']['name']) + Model = dynamic_load(matchers, confs[args.matcher_method]['model']['name']) + matcher = Model(confs[args.matcher_method]['model']).eval().cuda() + + local_feat_name = args.features.as_posix().split("/")[-1].split(".")[0] # name of local features + save_fn = '{:s}_{:s}'.format(local_feat_name, matcher_name) + if args.use_hloc: + save_fn = 'hloc_' + save_fn + save_fn = osp.join(save_root, save_fn) + + queries = parse_image_lists_with_intrinsics(args.queries) + _, db_images, points3D = read_model(str(args.reference_sfm), '.bin') + db_name_to_id = {image.name: i for i, image in db_images.items()} + feature_file = h5py.File(args.features, 'r') + + tag = '' + if args.do_covisible_opt: + tag = tag + "_o" + str(int(args.obs_thresh)) + 'op' + str(int(args.covisibility_frame)) + tag = tag + "th" + str(int(args.opt_thresh)) + if args.iters > 0: + tag = tag + "i" + str(int(args.iters)) + + log_fn = save_fn + tag + vis_dir = save_fn + tag + results = save_fn + tag + + full_log_fn = log_fn + '_full.log' + loc_log_fn = log_fn + '_loc.npy' + results = Path(results + '.txt') + vis_dir = Path(vis_dir) + if vis_dir is not None: + Path(vis_dir).mkdir(exist_ok=True) + print("save_fn: ", log_fn) + + logging.info('Starting localization...') + poses = {} + failed_cases = [] + n_total = 0 + n_failed = 0 + full_log_info = '' + loc_results = {} + + error_ths = ((0.25, 2), (0.5, 5), (5, 10)) + success = [0, 0, 0] + total_loc_time = [] + + for qname, qinfo in tqdm(queries): + kpq = feature_file[qname]['keypoints'].__array__() + n_total += 1 + time_start = time.time() + + if qname in retrievals.keys(): + cans = retrievals[qname] + db_ids = [db_name_to_id[v] for v in cans] + else: + cans = [] + db_ids = [] + time_coarse = time.time() + + if args.use_hloc: + output = pose_estimator_hloc(qname=qname, qinfo=qinfo, db_ids=db_ids, db_images=db_images, + points3D=points3D, + feature_file=feature_file, + thresh=args.ransac_thresh, + image_dir=args.image_dir, + matcher=matcher, + log_info='', + query_img_prefix='', + db_img_prefix='') + else: # should be faster and more accurate than hloc + t_start = time.time() + output = pose_estimator_iterative(qname=qname, + qinfo=qinfo, + matcher=matcher, + db_ids=db_ids, + db_images=db_images, + points3D=points3D, + feature_file=feature_file, + thresh=args.ransac_thresh, + image_dir=args.image_dir, + do_covisibility_opt=args.do_covisible_opt, + covisibility_frame=args.covisibility_frame, + log_info='', + inlier_th=args.inlier_thresh, + obs_th=args.obs_thresh, + opt_th=args.opt_thresh, + gt_qvec=gt_poses[qname]['qvec'] if qname in gt_poses.keys() else None, + gt_tvec=gt_poses[qname]['tvec'] if qname in gt_poses.keys() else None, + query_img_prefix='', + db_img_prefix='database', + ) + time_full = time.time() + + qvec = output['qvec'] + tvec = output['tvec'] + loc_time = time_full - time_start + total_loc_time.append(loc_time) + + poses[qname] = (qvec, tvec) + print_text = "All {:d}/{:d} failed cases, time[cs/fn]: {:.2f}/{:.2f}".format( + n_failed, n_total, + time_coarse - time_start, + time_full - time_coarse, + ) + + if qname in gt_poses.keys(): + gt_qvec = gt_poses[qname]['qvec'] + gt_tvec = gt_poses[qname]['tvec'] + + q_error, t_error = compute_pose_error(pred_qcw=qvec, pred_tcw=tvec, gt_qcw=gt_qvec, gt_tcw=gt_tvec) + + for error_idx, th in enumerate(error_ths): + if t_error <= th[0] and q_error <= th[1]: + success[error_idx] += 1 + print_text += ( + ', q_error:{:.2f} t_error:{:.2f} {:d}/{:d}/{:d}/{:d}, time: {:.2f}, {:d}pts'.format(q_error, t_error, + success[0], + success[1], + success[2], n_total, + loc_time, + kpq.shape[0])) + if output['num_inliers'] == 0: + failed_cases.append(qname) + + loc_results[qname] = { + 'keypoints_query': output['keypoints_query'], + 'points3D_ids': output['points3D_ids'], + } + full_log_info = full_log_info + output['log_info'] + full_log_info += (print_text + "\n") + print(print_text) + + logs_path = f'{results}.failed' + with open(logs_path, 'w') as f: + for v in failed_cases: + print(v) + f.write(v + "\n") + + logging.info(f'Localized {len(poses)} / {len(queries)} images.') + logging.info(f'Writing poses to {results}...') + # logging.info(f'Mean loc time: {np.mean(total_loc_time)}...') + print('Mean loc time: {:.2f}...'.format(np.mean(total_loc_time))) + with open(results, 'w') as f: + for q in poses: + qvec, tvec = poses[q] + qvec = ' '.join(map(str, qvec)) + tvec = ' '.join(map(str, tvec)) + name = q + f.write(f'{name} {qvec} {tvec}\n') + + with open(full_log_fn, 'w') as f: + f.write(full_log_info) + + np.save(loc_log_fn, loc_results) + print('Save logs to ', loc_log_fn) + logging.info('Done!') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image_dir', type=str, required=True) + parser.add_argument('--dataset', type=str, required=True) + parser.add_argument('--reference_sfm', type=Path, required=True) + parser.add_argument('--queries', type=Path, required=True) + parser.add_argument('--features', type=Path, required=True) + parser.add_argument('--ransac_thresh', type=float, default=12) + parser.add_argument('--covisibility_frame', type=int, default=50) + parser.add_argument('--do_covisible_opt', action='store_true') + parser.add_argument('--use_hloc', action='store_true') + parser.add_argument('--matcher_method', type=str, default="NNM") + parser.add_argument('--inlier_thresh', type=int, default=50) + parser.add_argument('--obs_thresh', type=float, default=3) + parser.add_argument('--opt_thresh', type=float, default=12) + parser.add_argument('--save_root', type=str, required=True) + parser.add_argument('--retrieval', type=Path, default=None) + parser.add_argument('--gt_pose_fn', type=str, default=None) + + args = parser.parse_args() + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + run(args=args) diff --git a/third_party/pram/localization/match_features.py b/third_party/pram/localization/match_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1b4edccff67db24d97fadb47024eb09c026ce8 --- /dev/null +++ b/third_party/pram/localization/match_features.py @@ -0,0 +1,156 @@ +import argparse +import torch +from pathlib import Path +import h5py +import logging +from tqdm import tqdm +import pprint + +import localization.matchers as matchers +from localization.base_model import dynamic_load +from colmap_utils.parsers import names_to_pair + +confs = { + 'gm': { + 'output': 'gm', + 'model': { + 'name': 'gm', + 'weight_path': 'weights/imp_gm.900.pth', + 'sinkhorn_iterations': 20, + }, + }, + 'gml': { + 'output': 'gml', + 'model': { + 'name': 'gml', + 'weight_path': 'weights/imp_gml.920.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'adagml': { + 'output': 'adagml', + 'model': { + 'name': 'adagml', + 'weight_path': 'weights/imp_adagml.80.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'superglue': { + 'output': 'superglue', + 'model': { + 'name': 'superglue', + 'weights': 'outdoor', + 'sinkhorn_iterations': 20, + 'weight_path': 'weights/superglue_outdoor.pth', + }, + }, + 'NNM': { + 'output': 'NNM', + 'model': { + 'name': 'nearest_neighbor', + 'do_mutual_check': True, + 'distance_threshold': None, + }, + }, +} + + +@torch.no_grad() +def main(conf, pairs, features, export_dir, exhaustive=False): + logging.info('Matching local features with configuration:' + f'\n{pprint.pformat(conf)}') + + feature_path = Path(export_dir, features + '.h5') + assert feature_path.exists(), feature_path + feature_file = h5py.File(str(feature_path), 'r') + pairs_name = pairs.stem + if not exhaustive: + assert pairs.exists(), pairs + with open(pairs, 'r') as f: + pair_list = f.read().rstrip('\n').split('\n') + elif exhaustive: + logging.info(f'Writing exhaustive match pairs to {pairs}.') + assert not pairs.exists(), pairs + + # get the list of images from the feature file + images = [] + feature_file.visititems( + lambda name, obj: images.append(obj.parent.name.strip('/')) + if isinstance(obj, h5py.Dataset) else None) + images = list(set(images)) + + pair_list = [' '.join((images[i], images[j])) + for i in range(len(images)) for j in range(i)] + with open(str(pairs), 'w') as f: + f.write('\n'.join(pair_list)) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Model = dynamic_load(matchers, conf['model']['name']) + model = Model(conf['model']).eval().to(device) + + match_name = f'{features}-{conf["output"]}-{pairs_name}' + match_path = Path(export_dir, match_name + '.h5') + + match_file = h5py.File(str(match_path), 'a') + + matched = set() + for pair in tqdm(pair_list, smoothing=.1): + name0, name1 = pair.split(' ') + pair = names_to_pair(name0, name1) + + # Avoid to recompute duplicates to save time + if len({(name0, name1), (name1, name0)} & matched) \ + or pair in match_file: + continue + + data = {} + feats0, feats1 = feature_file[name0], feature_file[name1] + for k in feats1.keys(): + # data[k + '0'] = feats0[k].__array__() + if k == 'descriptors': + data[k + '0'] = feats0[k][()].transpose() # [N D] + else: + data[k + '0'] = feats0[k][()] + for k in feats1.keys(): + # data[k + '1'] = feats1[k].__array__() + # data[k + '1'] = feats1[k][()].transpose() # [N D] + if k == 'descriptors': + data[k + '1'] = feats1[k][()].transpose() # [N D] + else: + data[k + '1'] = feats1[k][()] + data = {k: torch.from_numpy(v)[None].float().to(device) + for k, v in data.items()} + + # some matchers might expect an image but only use its size + data['image0'] = torch.empty((1, 1,) + tuple(feats0['image_size'])[::-1]) + data['image1'] = torch.empty((1, 1,) + tuple(feats1['image_size'])[::-1]) + + pred = model(data) + grp = match_file.create_group(pair) + matches = pred['matches0'][0].cpu().short().numpy() + grp.create_dataset('matches0', data=matches) + + if 'matching_scores0' in pred: + scores = pred['matching_scores0'][0].cpu().half().numpy() + grp.create_dataset('matching_scores0', data=scores) + + matched |= {(name0, name1), (name1, name0)} + + match_file.close() + logging.info('Finished exporting matches.') + + return match_path + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--features', type=str, required=True) + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys())) + parser.add_argument('--exhaustive', action='store_true') + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir, + exhaustive=args.exhaustive) diff --git a/third_party/pram/localization/match_features_batch.py b/third_party/pram/localization/match_features_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0dc9d4a1e4288892c365616e45304a19e93c3e --- /dev/null +++ b/third_party/pram/localization/match_features_batch.py @@ -0,0 +1,242 @@ +import argparse +import torch +from pathlib import Path +import h5py +import logging +from tqdm import tqdm +import pprint +from queue import Queue +from threading import Thread +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import localization.matchers as matchers +from localization.base_model import dynamic_load +from colmap_utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval + +confs = { + 'gm': { + 'output': 'gm', + 'model': { + 'name': 'gm', + 'weight_path': 'weights/imp_gm.900.pth', + 'sinkhorn_iterations': 20, + }, + }, + 'gml': { + 'output': 'gml', + 'model': { + 'name': 'gml', + 'weight_path': 'weights/imp_gml.920.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'adagml': { + 'output': 'adagml', + 'model': { + 'name': 'adagml', + 'weight_path': 'weights/imp_adagml.80.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'superglue': { + 'output': 'superglue', + 'model': { + 'name': 'superglue', + 'weights': 'outdoor', + 'sinkhorn_iterations': 20, + 'weight_path': 'weights/superglue_outdoor.pth', + }, + }, + 'NNM': { + 'output': 'NNM', + 'model': { + 'name': 'nearest_neighbor', + 'do_mutual_check': True, + 'distance_threshold': None, + }, + }, +} + + +class WorkQueue: + def __init__(self, work_fn, num_threads=1): + self.queue = Queue(num_threads) + self.threads = [ + Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads) + ] + for thread in self.threads: + thread.start() + + def join(self): + for thread in self.threads: + self.queue.put(None) + for thread in self.threads: + thread.join() + + def thread_fn(self, work_fn): + item = self.queue.get() + while item is not None: + work_fn(item) + item = self.queue.get() + + def put(self, data): + self.queue.put(data) + + +class FeaturePairsDataset(torch.utils.data.Dataset): + def __init__(self, pairs, feature_path_q, feature_path_r): + self.pairs = pairs + self.feature_path_q = feature_path_q + self.feature_path_r = feature_path_r + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + data = {} + with h5py.File(self.feature_path_q, "r") as fd: + grp = fd[name0] + for k, v in grp.items(): + data[k + "0"] = torch.from_numpy(v.__array__()).float() + if k == 'descriptors': + data[k + '0'] = data[k + '0'].t() + # some matchers might expect an image but only use its size + data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + with h5py.File(self.feature_path_r, "r") as fd: + grp = fd[name1] + for k, v in grp.items(): + data[k + "1"] = torch.from_numpy(v.__array__()).float() + if k == 'descriptors': + data[k + '1'] = data[k + '1'].t() + data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + return data + + def __len__(self): + return len(self.pairs) + + +def writer_fn(inp, match_path): + pair, pred = inp + with h5py.File(str(match_path), "a", libver="latest") as fd: + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + matches = pred["matches0"][0].cpu().short().numpy() + grp.create_dataset("matches0", data=matches) + if "matching_scores0" in pred: + scores = pred["matching_scores0"][0].cpu().half().numpy() + grp.create_dataset("matching_scores0", data=scores) + + +def main( + conf: Dict, + pairs: Path, + features: Union[Path, str], + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, + features_ref: Optional[Path] = None, + overwrite: bool = False, +) -> Path: + if isinstance(features, Path) or Path(features).exists(): + features_q = features + if matches is None: + raise ValueError( + "Either provide both features and matches as Path" " or both as names." + ) + else: + if export_dir is None: + raise ValueError( + "Provide an export_dir if features is not" f" a file path: {features}." + ) + features_q = Path(export_dir, features + ".h5") + if matches is None: + matches = Path(export_dir, f'{features}-{conf["output"]}-{pairs.stem}.h5') + + if features_ref is None: + features_ref = features_q + match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite) + + return matches + + +def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): + """Avoid to recompute duplicates to save time.""" + pairs = set() + for i, j in pairs_all: + if (j, i) not in pairs: + pairs.add((i, j)) + pairs = list(pairs) + if match_path is not None and match_path.exists(): + with h5py.File(str(match_path), "r", libver="latest") as fd: + pairs_filtered = [] + for i, j in pairs: + if ( + names_to_pair(i, j) in fd + or names_to_pair(j, i) in fd + or names_to_pair_old(i, j) in fd + or names_to_pair_old(j, i) in fd + ): + continue + pairs_filtered.append((i, j)) + return pairs_filtered + return pairs + + +@torch.no_grad() +def match_from_paths( + conf: Dict, + pairs_path: Path, + match_path: Path, + feature_path_q: Path, + feature_path_ref: Path, + overwrite: bool = False, +) -> Path: + logging.info( + "Matching local features with configuration:" f"\n{pprint.pformat(conf)}" + ) + + if not feature_path_q.exists(): + raise FileNotFoundError(f"Query feature file {feature_path_q}.") + if not feature_path_ref.exists(): + raise FileNotFoundError(f"Reference feature file {feature_path_ref}.") + match_path.parent.mkdir(exist_ok=True, parents=True) + + assert pairs_path.exists(), pairs_path + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + if len(pairs) == 0: + logging.info("Skipping the matching.") + return + + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(matchers, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref) + loader = torch.utils.data.DataLoader( + dataset, num_workers=4, batch_size=1, shuffle=False, pin_memory=True + ) + writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5) + + for idx, data in enumerate(tqdm(loader, smoothing=0.1)): + data = { + k: v if k.startswith("image") else v.to(device, non_blocking=True) + for k, v in data.items() + } + pred = model(data) + pair = names_to_pair(*pairs[idx]) + writer_queue.put((pair, pred)) + writer_queue.join() + logging.info("Finished exporting matches.") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--features', type=str, required=True) + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys())) + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir) diff --git a/third_party/pram/localization/matchers/__init__.py b/third_party/pram/localization/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7edac76f912b1e5ebb0401b6cc7a5d3c64ce963a --- /dev/null +++ b/third_party/pram/localization/matchers/__init__.py @@ -0,0 +1,3 @@ +def get_matcher(matcher): + mod = __import__(f'{__name__}.{matcher}', fromlist=['']) + return getattr(mod, 'Model') diff --git a/third_party/pram/localization/matchers/adagml.py b/third_party/pram/localization/matchers/adagml.py new file mode 100644 index 0000000000000000000000000000000000000000..31a4bd2aa74bef934543b79567f148f5b8b7b092 --- /dev/null +++ b/third_party/pram/localization/matchers/adagml.py @@ -0,0 +1,41 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> adagml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 11/02/2024 14:34 +==================================================''' +import torch +from localization.base_model import BaseModel +from nets.adagml import AdaGML as GMatcher + + +class AdaGML(BaseModel): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def _init(self, conf): + self.net = GMatcher(config=conf).eval() + state_dict = torch.load(conf['weight_path'], map_location='cpu')['model'] + self.net.load_state_dict(state_dict, strict=True) + + def _forward(self, data): + with torch.no_grad(): + return self.net(data) diff --git a/third_party/pram/localization/matchers/gm.py b/third_party/pram/localization/matchers/gm.py new file mode 100644 index 0000000000000000000000000000000000000000..2484cdb521d28a8cc0b5be7148919cd46bc67b32 --- /dev/null +++ b/third_party/pram/localization/matchers/gm.py @@ -0,0 +1,44 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File r2d2 -> gm +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 25/05/2023 10:09 +==================================================''' +import torch +from localization.base_model import BaseModel +from nets.gm import GM as GMatcher + + +class GM(BaseModel): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def _init(self, conf): + self.net = GMatcher(config=conf).eval() + state_dict = torch.load(conf['weight_path'], map_location='cpu')['model'] + self.net.load_state_dict(state_dict, strict=True) + + def _forward(self, data): + with torch.no_grad(): + return self.net(data) diff --git a/third_party/pram/localization/matchers/gml.py b/third_party/pram/localization/matchers/gml.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9acdeaf3c7bd9670c1f7c49e2bbf709f1e8b4a --- /dev/null +++ b/third_party/pram/localization/matchers/gml.py @@ -0,0 +1,45 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> gml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 15/01/2024 11:01 +==================================================''' +import torch +from localization.base_model import BaseModel +from nets.gml import GML as GMatcher + + +class GML(BaseModel): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def _init(self, conf): + self.net = GMatcher(config=conf).eval() + state_dict = torch.load(conf['weight_path'], map_location='cpu')['model'] + self.net.load_state_dict(state_dict, strict=True) + + def _forward(self, data): + with torch.no_grad(): + # print(data['keypoints0'].shape, data['descriptors0'].shape, data['image0'].shape) + return self.net(data) diff --git a/third_party/pram/localization/matchers/nearest_neighbor.py b/third_party/pram/localization/matchers/nearest_neighbor.py new file mode 100644 index 0000000000000000000000000000000000000000..42b8078747535a269dab6131b4f20c0857c36c03 --- /dev/null +++ b/third_party/pram/localization/matchers/nearest_neighbor.py @@ -0,0 +1,56 @@ +import torch +from localization.base_model import BaseModel + + +def find_nn(sim, ratio_thresh, distance_thresh): + sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True) + dist_nn = 2 * (1 - sim_nn) + mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device) + if ratio_thresh: + mask = mask & (dist_nn[..., 0] <= (ratio_thresh ** 2) * dist_nn[..., 1]) + if distance_thresh: + mask = mask & (dist_nn[..., 0] <= distance_thresh ** 2) + matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1)) + scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0)) + return matches, scores + + +def mutual_check(m0, m1): + inds0 = torch.arange(m0.shape[-1], device=m0.device) + loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0))) + ok = (m0 > -1) & (inds0 == loop) + m0_new = torch.where(ok, m0, m0.new_tensor(-1)) + return m0_new + + +class NearestNeighbor(BaseModel): + default_conf = { + 'ratio_threshold': None, + 'distance_threshold': None, + 'do_mutual_check': True, + } + required_inputs = ['descriptors0', 'descriptors1'] + + def _init(self, conf): + pass + + def _forward(self, data): + sim = torch.einsum( + 'bdn,bdm->bnm', data['descriptors0'], data['descriptors1']) + matches0, scores0 = find_nn( + sim, self.conf['ratio_threshold'], self.conf['distance_threshold']) + # matches1, scores1 = find_nn( + # sim.transpose(1, 2), self.conf['ratio_threshold'], + # self.conf['distance_threshold']) + if self.conf['do_mutual_check']: + # print("with mutual check") + matches1, scores1 = find_nn( + sim.transpose(1, 2), self.conf['ratio_threshold'], + self.conf['distance_threshold']) + matches0 = mutual_check(matches0, matches1) + # else: + # print("no mutual check") + return { + 'matches0': matches0, + 'matching_scores0': scores0, + } diff --git a/third_party/pram/localization/multimap3d.py b/third_party/pram/localization/multimap3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6100b4f4bfeb1d3f8bc94598723979e830bf4172 --- /dev/null +++ b/third_party/pram/localization/multimap3d.py @@ -0,0 +1,379 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> multimap3d +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 13:47 +==================================================''' +import numpy as np +import os +import os.path as osp +import time +import cv2 +import torch +import yaml +from copy import deepcopy +from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches +from localization.base_model import dynamic_load +import localization.matchers as matchers +from localization.match_features_batch import confs as matcher_confs +from nets.gm import GM +from tools.common import resize_img +from localization.singlemap3d import SingleMap3D +from localization.frame import Frame + + +class MultiMap3D: + def __init__(self, config, viewer=None, save_dir=None): + self.config = config + self.save_dir = save_dir + + self.scenes = [] + self.sid_scene_name = [] + self.sub_maps = {} + self.scene_name_start_sid = {} + + self.loc_config = config['localization'] + self.save_dir = save_dir + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + + self.matching_method = config['localization']['matching_method'] + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Model = dynamic_load(matchers, self.matching_method) + self.matcher = Model(matcher_confs[self.matching_method]['model']).eval().to(device) + + self.initialize_map(config=config) + self.loc_config = config['localization'] + + self.viewer = viewer + + # options + self.do_refinement = self.loc_config['do_refinement'] + self.refinement_method = self.loc_config['refinement_method'] + self.semantic_matching = self.loc_config['semantic_matching'] + self.do_pre_filtering = self.loc_config['pre_filtering_th'] > 0 + self.pre_filtering_th = self.loc_config['pre_filtering_th'] + + def initialize_map(self, config): + n_class = 0 + datasets = config['dataset'] + + for name in datasets: + config_path = osp.join(config['config_path'], '{:s}.yaml'.format(name)) + dataset_name = name + + with open(config_path, 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + + scenes = scene_config['scenes'] + for sid, scene in enumerate(scenes): + self.scenes.append(name + '/' + scene) + + new_config = deepcopy(config) + new_config['dataset_path'] = osp.join(config['dataset_path'], dataset_name, scene) + new_config['landmark_path'] = osp.join(config['landmark_path'], dataset_name, scene) + new_config['n_cluster'] = scene_config[scene]['n_cluster'] + new_config['cluster_mode'] = scene_config[scene]['cluster_mode'] + new_config['cluster_method'] = scene_config[scene]['cluster_method'] + new_config['gt_pose_path'] = scene_config[scene]['gt_pose_path'] + new_config['image_path_prefix'] = scene_config[scene]['image_path_prefix'] + sub_map = SingleMap3D(config=new_config, + matcher=self.matcher, + with_compress=config['localization']['with_compress'], + start_sid=n_class) + self.sub_maps[dataset_name + '/' + scene] = sub_map + + n_scene_class = scene_config[scene]['n_cluster'] + self.sid_scene_name = self.sid_scene_name + [dataset_name + '/' + scene for ni in range(n_scene_class)] + self.scene_name_start_sid[dataset_name + '/' + scene] = n_class + n_class = n_class + n_scene_class + + # break + print('Load {} sub_maps from {} datasets'.format(len(self.sub_maps), len(datasets))) + + def run(self, q_frame: Frame): + show = self.loc_config['show'] + seg_color = generate_color_dic(n_seg=2000) + if show: + cv2.namedWindow('loc', cv2.WINDOW_NORMAL) + + q_loc_segs = self.process_segmentations(segs=torch.from_numpy(q_frame.segmentations), + topk=self.loc_config['seg_k']) + q_pred_segs_top1 = q_frame.seg_ids # initial results + + q_scene_name = q_frame.scene_name + q_name = q_frame.name + q_full_name = osp.join(q_scene_name, q_name) + + q_loc_sids = {} + for v in q_loc_segs: + q_loc_sids[v[0]] = (v[1], v[2]) + query_sids = list(q_loc_sids.keys()) + + for i, sid in enumerate(query_sids): + t_start = time.time() + q_kpt_ids = q_loc_sids[sid][0] + print(q_scene_name, q_name, sid) + + sid = sid - 1 # start from 0, confused! + + pred_scene_name = self.sid_scene_name[sid] + start_seg_id = self.scene_name_start_sid[pred_scene_name] + pred_sid_in_sub_scene = sid - self.scene_name_start_sid[pred_scene_name] + pred_sub_map = self.sub_maps[pred_scene_name] + pred_image_path_prefix = pred_sub_map.image_path_prefix + + print('pred/gt scene: {:s}, {:s}, sid: {:d}'.format(pred_scene_name, q_scene_name, pred_sid_in_sub_scene)) + print('{:s}/{:s}, pred: {:s}, sid: {:d}, order: {:d}'.format(q_scene_name, q_name, pred_scene_name, sid, + i)) + + if (q_kpt_ids.shape[0] >= self.loc_config['min_kpts'] + and self.semantic_matching + and pred_sub_map.check_semantic_consistency(q_frame=q_frame, + sid=pred_sid_in_sub_scene, + overlap_ratio=0.5)): + semantic_matching = True + else: + q_kpt_ids = np.arange(q_frame.keypoints.shape[0]) + semantic_matching = False + print_text = f'Semantic matching - {semantic_matching}! Query kpts {q_kpt_ids.shape[0]} for {i}th seg {sid}' + print(print_text) + ret = pred_sub_map.localize_with_ref_frame(q_frame=q_frame, + q_kpt_ids=q_kpt_ids, + sid=pred_sid_in_sub_scene, + semantic_matching=semantic_matching) + + q_frame.time_loc = q_frame.time_loc + time.time() - t_start # accumulate tracking time + + if show: + reference_frame = pred_sub_map.reference_frames[ret['reference_frame_id']] + ref_img = cv2.imread(osp.join(self.config['dataset_path'], pred_scene_name, pred_image_path_prefix, + reference_frame.name)) + q_img_seg = vis_seg_point(img=q_frame.image, kpts=q_frame.keypoints[q_kpt_ids, :2], + segs=q_frame.seg_ids[q_kpt_ids] + 1, + seg_color=seg_color) + matched_points3D_ids = ret['matched_point3D_ids'] + ref_sids = np.array([pred_sub_map.point3Ds[v].seg_id for v in matched_points3D_ids]) + \ + self.scene_name_start_sid[pred_scene_name] + 1 # start from 1 as bg is 0 + ref_img_seg = vis_seg_point(img=ref_img, kpts=ret['matched_ref_keypoints'], segs=ref_sids, + seg_color=seg_color) + q_matched_kpts = ret['matched_keypoints'] + ref_matched_kpts = ret['matched_ref_keypoints'] + img_loc_matching = plot_matches(img1=q_img_seg, img2=ref_img_seg, + pts1=q_matched_kpts, pts2=ref_matched_kpts, + inliers=np.array([True for i in range(q_matched_kpts.shape[0])]), + radius=9, line_thickness=3 + ) + + q_frame.image_matching_tmp = img_loc_matching + q_frame.reference_frame_name_tmp = osp.join(self.config['dataset_path'], + pred_scene_name, + pred_image_path_prefix, + reference_frame.name) + # ret['image_matching'] = img_loc_matching + # ret['reference_frame_name'] = osp.join(self.config['dataset_path'], + # pred_scene_name, + # pred_image_path_prefix, + # reference_frame.name) + q_ref_img_matching = np.hstack([resize_img(q_img_seg, nh=512), + resize_img(ref_img_seg, nh=512), + resize_img(img_loc_matching, nh=512)]) + + ret['order'] = i + ret['matched_scene_name'] = pred_scene_name + if not ret['success']: + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + print_text = f'Localization failed with {num_matches}/{q_kpt_ids.shape[0]} matches and {num_inliers} inliers, order {i}' + print(print_text) + + if show: + show_text = 'FAIL! order: {:d}/{:d}-{:d}/{:d}'.format(i, len(q_loc_segs), + num_matches, + q_kpt_ids.shape[0]) + q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'], + radius=9 + 2, thickness=2) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + q_frame.image_inlier_tmp = q_img_inlier + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + continue + + if show: + q_err, t_err = q_frame.compute_pose_error() + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + show_text = 'order: {:d}/{:d}, k/m/i: {:d}/{:d}/{:d}'.format( + i, len(q_loc_segs), q_kpt_ids.shape[0], num_matches, num_inliers) + q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'], + radius=9 + 2, thickness=2) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + q_frame.image_inlier_tmp = q_img_inlier + + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + + success = self.verify_and_update(q_frame=q_frame, ret=ret) + + if not success: + continue + else: + break + + if q_frame.tracking_status is None: + print('Failed to find a proper reference frame.') + return False + + # do refinement + if not self.do_refinement: + return True + else: + t_start = time.time() + pred_sub_map = self.sub_maps[q_frame.matched_scene_name] + if q_frame.tracking_status is True and np.sum(q_frame.matched_inliers) >= 64: + ret = pred_sub_map.refine_pose(q_frame=q_frame, refinement_method=self.loc_config['refinement_method']) + else: + ret = pred_sub_map.refine_pose(q_frame=q_frame, + refinement_method='matching') # do not trust the pose for projection + + q_frame.time_ref = time.time() - t_start + + inlier_mask = np.array(ret['inliers']) + + q_frame.qvec = ret['qvec'] + q_frame.tvec = ret['tvec'] + q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask] + q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask] + q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask] + q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask] + q_frame.matched_sids = ret['matched_sids'][inlier_mask] + q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask] + + q_frame.refinement_reference_frame_ids = ret['refinement_reference_frame_ids'] + q_frame.reference_frame_id = ret['reference_frame_id'] + + q_err, t_err = q_frame.compute_pose_error() + ref_full_name = q_frame.matched_scene_name + '/' + pred_sub_map.reference_frames[ + q_frame.reference_frame_id].name + print_text = 'Localization of {:s} success with inliers {:d}/{:d} with ref_name: {:s}, order: {:d}, q_err: {:.2f}, t_err: {:.2f}'.format( + q_full_name, ret['num_inliers'], len(ret['inliers']), ref_full_name, q_frame.matched_order, q_err, + t_err) + print(print_text) + + if show: + q_err, t_err = q_frame.compute_pose_error() + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + show_text = 'Ref:{:d}/{:d},r_err:{:.2f}/t_err:{:.2f}'.format(num_matches, num_inliers, q_err, + t_err) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 130), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + q_frame.image_inlier = q_img_inlier + + return True + + def verify_and_update(self, q_frame: Frame, ret: dict): + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + if q_frame.matched_keypoints is None or np.sum(q_frame.matched_inliers) < num_inliers: + self.update_query_frame(q_frame=q_frame, ret=ret) + + q_err, t_err = q_frame.compute_pose_error(pred_qvec=ret['qvec'], pred_tvec=ret['tvec']) + + if num_inliers < self.loc_config['min_inliers']: + print_text = 'Failed due to insufficient {:d} inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format( + ret['num_inliers'], ret['order'], q_err, t_err) + print(print_text) + q_frame.tracking_status = False + return False + else: + print_text = 'Succeed! Find {}/{} 2D-3D inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format( + num_inliers, num_matches, ret['order'], q_err, t_err) + print(print_text) + q_frame.tracking_status = True + return True + + def update_query_frame(self, q_frame, ret): + q_frame.matched_scene_name = ret['matched_scene_name'] + q_frame.reference_frame_id = ret['reference_frame_id'] + q_frame.qvec = ret['qvec'] + q_frame.tvec = ret['tvec'] + + inlier_mask = np.array(ret['inliers']) + q_frame.matched_keypoints = ret['matched_keypoints'] + q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'] + q_frame.matched_xyzs = ret['matched_xyzs'] + q_frame.matched_point3D_ids = ret['matched_point3D_ids'] + q_frame.matched_sids = ret['matched_sids'] + q_frame.matched_inliers = np.array(ret['inliers']) + q_frame.matched_order = ret['order'] + + if q_frame.image_inlier_tmp is not None: + q_frame.image_inlier = deepcopy(q_frame.image_inlier_tmp) + if q_frame.image_matching_tmp is not None: + q_frame.image_matching = deepcopy(q_frame.image_matching_tmp) + if q_frame.reference_frame_name_tmp is not None: + q_frame.reference_frame_name = q_frame.reference_frame_name_tmp + + # inlier_mask = np.array(ret['inliers']) + # q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask] + # q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask] + # q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask] + # q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask] + # q_frame.matched_sids = ret['matched_sids'][inlier_mask] + # q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask] + + # print('update_query_frame: ', q_frame.matched_keypoint_ids.shape, q_frame.matched_keypoints.shape, + # q_frame.matched_xyzs.shape, q_frame.matched_xyzs.shape, np.sum(q_frame.matched_inliers)) + + def process_segmentations(self, segs, topk=10): + pred_values, pred_ids = torch.topk(segs, k=segs.shape[-1], largest=True, dim=-1) # [N, C] + pred_values = pred_values.numpy() + pred_ids = pred_ids.numpy() + + out = [] + used_sids = [] + for k in range(segs.shape[-1]): + values_k = pred_values[:, k] + ids_k = pred_ids[:, k] + uids = np.unique(ids_k) + + out_k = [] + for sid in uids: + if sid == 0: + continue + if sid in used_sids: + continue + used_sids.append(sid) + ids = np.where(ids_k == sid)[0] + score = np.mean(values_k[ids]) + # score = np.median(values_k[ids]) + # score = 100 - k + # out_k.append((ids.shape[0], sid - 1, ids, score)) + out_k.append((ids.shape[0], sid, ids, score)) + + out_k = sorted(out_k, key=lambda item: item[0], reverse=True) + for v in out_k: + out.append((v[1], v[2], v[3])) # [sid, ids, score] + if len(out) >= topk: + return out + return out diff --git a/third_party/pram/localization/point3d.py b/third_party/pram/localization/point3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1babf427759c5f588f44023e9e1bf2648a073b --- /dev/null +++ b/third_party/pram/localization/point3d.py @@ -0,0 +1,21 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> point3d +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 10:13 +==================================================''' +import numpy as np + + +class Point3D: + def __init__(self, id: int, xyz: np.ndarray, error: float, refframe_id: int, seg_id: int = None, + descriptor: np.ndarray = None, rgb: np.ndarray = None, frame_ids: np.ndarray = None): + self.id = id + self.xyz = xyz + self.rgb = rgb + self.error = error + self.seg_id = seg_id + self.refframe_id = refframe_id + self.frame_ids = frame_ids + self.descriptor = descriptor diff --git a/third_party/pram/localization/pose_estimator.py b/third_party/pram/localization/pose_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..5d28d6001d38cfd5f6f6135c611293ab5e83cf0a --- /dev/null +++ b/third_party/pram/localization/pose_estimator.py @@ -0,0 +1,612 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> pose_estimation +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 08/02/2024 11:01 +==================================================''' +import torch +import numpy as np +import pycolmap +import cv2 +import os +import time +import os.path as osp +from collections import defaultdict + + +def get_covisibility_frames(frame_id, all_images, points3D, covisibility_frame=50): + observed = all_images[frame_id].point3D_ids + covis = defaultdict(int) + for pid in observed: + if pid == -1: + continue + for img_id in points3D[pid].image_ids: + if img_id != frame_id: + covis[img_id] += 1 + + print('Find {:d} connected frames'.format(len(covis.keys()))) + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= covisibility_frame: + sel_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + ind_top = np.argpartition(covis_num, -covisibility_frame) + ind_top = ind_top[-covisibility_frame:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + sel_covis_ids = [covis_ids[i] for i in ind_top] + + print('Retain {:d} valid connected frames'.format(len(sel_covis_ids))) + return sel_covis_ids + + +def feature_matching(query_data, db_data, matcher): + db_3D_ids = db_data['db_3D_ids'] + if db_3D_ids is None: + with torch.no_grad(): + match_data = { + 'keypoints0': torch.from_numpy(query_data['keypoints'])[None].float().cuda(), + 'scores0': torch.from_numpy(query_data['scores'])[None].float().cuda(), + 'descriptors0': torch.from_numpy(query_data['descriptors'])[None].float().cuda(), + 'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]), + + 'keypoints1': torch.from_numpy(db_data['keypoints'])[None].float().cuda(), + 'scores1': torch.from_numpy(db_data['scores'])[None].float().cuda(), + 'descriptors1': torch.from_numpy(db_data['descriptors'])[None].float().cuda(), # [B, N, D] + 'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]), + } + matches = matcher(match_data)['matches0'][0].cpu().numpy() + del match_data + else: + masks = (db_3D_ids != -1) + valid_ids = [i for i in range(masks.shape[0]) if masks[i]] + if len(valid_ids) == 0: + return np.zeros(shape=(query_data['keypoints'].shape[0],), dtype=int) - 1 + with torch.no_grad(): + match_data = { + 'keypoints0': torch.from_numpy(query_data['keypoints'])[None].float().cuda(), + 'scores0': torch.from_numpy(query_data['scores'])[None].float().cuda(), + 'descriptors0': torch.from_numpy(query_data['descriptors'])[None].float().cuda(), + 'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]), + + 'keypoints1': torch.from_numpy(db_data['keypoints'])[masks][None].float().cuda(), + 'scores1': torch.from_numpy(db_data['scores'])[masks][None].float().cuda(), + 'descriptors1': torch.from_numpy(db_data['descriptors'][masks])[None].float().cuda(), + 'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]), + } + matches = matcher(match_data)['matches0'][0].cpu().numpy() + del match_data + + for i in range(matches.shape[0]): + if matches[i] >= 0: + matches[i] = valid_ids[matches[i]] + + return matches + + +def find_2D_3D_matches(query_data, db_id, points3D, feature_file, db_images, matcher, obs_th=0): + kpq = query_data['keypoints'] + db_name = db_images[db_id].name + kpdb = feature_file[db_name]['keypoints'][()] + desc_db = feature_file[db_name]["descriptors"][()] + desc_db = desc_db.transpose() + + # print('db_desc: ', desc_db.shape, query_data['descriptors'].shape) + + points3D_ids = db_images[db_id].point3D_ids + matches = feature_matching(query_data=query_data, + db_data={ + 'keypoints': kpdb, + 'scores': feature_file[db_name]['scores'][()], + 'descriptors': desc_db, + 'db_3D_ids': points3D_ids, + 'image_size': feature_file[db_name]['image_size'][()] + }, + matcher=matcher) + mkpdb = [] + mp3d_ids = [] + q_ids = [] + mkpq = [] + mp3d = [] + valid_matches = [] + for idx in range(matches.shape[0]): + if matches[idx] == -1: + continue + if points3D_ids[matches[idx]] == -1: + continue + id_3D = points3D_ids[matches[idx]] + + # reject 3d points without enough observations + if len(points3D[id_3D].image_ids) < obs_th: + continue + mp3d.append(points3D[id_3D].xyz) + mp3d_ids.append(id_3D) + + mkpq.append(kpq[idx]) + mkpdb.append(kpdb[matches[idx]]) + q_ids.append(idx) + valid_matches.append(matches[idx]) + + mp3d = np.array(mp3d, float).reshape(-1, 3) + mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5 + return mp3d, mkpq, mp3d_ids, q_ids + + +# hfnet, cvpr 2019 +def pose_estimator_hloc(qname, qinfo, db_ids, db_images, points3D, + feature_file, + thresh, + image_dir, + matcher, + log_info=None, + query_img_prefix='', + db_img_prefix=''): + kpq = feature_file[qname]['keypoints'][()] + score_q = feature_file[qname]['scores'][()] + desc_q = feature_file[qname]['descriptors'][()] + desc_q = desc_q.transpose() + imgsize_q = feature_file[qname]['image_size'][()] + query_data = { + 'keypoints': kpq, + 'scores': score_q, + 'descriptors': desc_q, + 'image_size': imgsize_q, + } + + camera_model, width, height, params = qinfo + cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params) + cfg = { + 'model': camera_model, + 'width': width, + 'height': height, + 'params': params, + } + all_mkpts = [] + all_mp3ds = [] + all_points3D_ids = [] + best_db_id = db_ids[0] + best_db_name = db_images[best_db_id].name + + t_start = time.time() + + for cluster_idx, db_id in enumerate(db_ids): + mp3d, mkpq, mp3d_ids, q_ids = find_2D_3D_matches( + query_data=query_data, + db_id=db_id, + points3D=points3D, + feature_file=feature_file, + db_images=db_images, + matcher=matcher, + obs_th=3) + if mp3d.shape[0] > 0: + all_mkpts.append(mkpq) + all_mp3ds.append(mp3d) + all_points3D_ids = all_points3D_ids + mp3d_ids + + if len(all_mkpts) == 0: + print_text = 'Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, best_db_name) + print(print_text) + if log_info is not None: + log_info = log_info + print_text + '\n' + + qvec = db_images[best_db_id].qvec + tvec = db_images[best_db_id].tvec + + return { + 'qvec': qvec, + 'tvec': tvec, + 'log_info': log_info, + 'qname': qname, + 'dbname': best_db_name, + 'num_inliers': 0, + 'order': -1, + 'keypoints_query': np.array([]), + 'points3D_ids': [], + 'time': time.time() - t_start, + } + + all_mkpts = np.vstack(all_mkpts) + all_mp3ds = np.vstack(all_mp3ds) + + ret = pycolmap.absolute_pose_estimation(all_mkpts, all_mp3ds, cam, + estimation_options={ + "ransac": {"max_error": thresh}}, + refinement_options={}, + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + success = ret['success'] + + if success: + print_text = 'qname: {:s} localization success with {:d}/{:d} inliers'.format(qname, ret['num_inliers'], + all_mp3ds.shape[0]) + print(print_text) + if log_info is not None: + log_info = log_info + print_text + '\n' + + qvec = ret['qvec'] + tvec = ret['tvec'] + ret['cfg'] = cfg + num_inliers = ret['num_inliers'] + inliers = ret['inliers'] + return { + 'qvec': qvec, + 'tvec': tvec, + 'log_info': log_info, + 'qname': qname, + 'dbname': best_db_name, + 'num_inliers': num_inliers, + 'order': -1, + 'keypoints_query': np.array([all_mkpts[i] for i in range(len(inliers)) if inliers[i]]), + 'points3D_ids': [all_points3D_ids[i] for i in range(len(inliers)) if inliers[i]], + 'time': time.time() - t_start, + } + else: + print_text = 'Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, best_db_name) + print(print_text) + if log_info is not None: + log_info = log_info + print_text + '\n' + + qvec = db_images[best_db_id].qvec + tvec = db_images[best_db_id].tvec + + return { + 'qvec': qvec, + 'tvec': tvec, + 'log_info': log_info, + 'qname': qname, + 'dbname': best_db_name, + 'num_inliers': 0, + 'order': -1, + 'keypoints_query': np.array([]), + 'points3D_ids': [], + 'time': time.time() - t_start, + } + + +def pose_refinement(query_data, + query_cam, feature_file, db_frame_id, db_images, points3D, matcher, + covisibility_frame=50, + obs_th=3, + opt_th=12, + qvec=None, + tvec=None, + log_info='', + **kwargs, + ): + db_ids = get_covisibility_frames(frame_id=db_frame_id, all_images=db_images, points3D=points3D, + covisibility_frame=covisibility_frame) + + mp3d = [] + mkpq = [] + mkpdb = [] + all_3D_ids = [] + all_score_q = [] + kpq = query_data['keypoints'] + for i, db_id in enumerate(db_ids): + db_name = db_images[db_id].name + kpdb = feature_file[db_name]['keypoints'][()] + scores_db = feature_file[db_name]['scores'][()] + imgsize_db = feature_file[db_name]['image_size'][()] + desc_db = feature_file[db_name]["descriptors"][()] + desc_db = desc_db.transpose() + + points3D_ids = db_images[db_id].point3D_ids + if points3D_ids.size == 0: + print("No 3D points in this db image: ", db_name) + continue + + matches = feature_matching(query_data=query_data, + db_data={'keypoints': kpdb, + 'scores': scores_db, + 'descriptors': desc_db, + 'image_size': imgsize_db, + 'db_3D_ids': points3D_ids, + }, + matcher=matcher, + ) + valid = np.where(matches > -1)[0] + valid = valid[points3D_ids[matches[valid]] != -1] + inliers = [] + for idx in valid: + id_3D = points3D_ids[matches[idx]] + if len(points3D[id_3D].image_ids) < obs_th: + continue + + inliers.append(True) + + mp3d.append(points3D[id_3D].xyz) + mkpq.append(kpq[idx]) + mkpdb.append(kpdb[matches[idx]]) + all_3D_ids.append(id_3D) + + mp3d = np.array(mp3d, float).reshape(-1, 3) + mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5 + print_text = 'Get {:d} covisible frames with {:d} matches from cluster optimization'.format(len(db_ids), + mp3d.shape[0]) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + + # cam = pycolmap.Camera(model=cfg['model'], params=cfg['params']) + ret = pycolmap.absolute_pose_estimation(mkpq, mp3d, + query_cam, + estimation_options={ + "ransac": {"max_error": opt_th}}, + refinement_options={}, + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + if not ret['success']: + ret['mkpq'] = mkpq + ret['3D_ids'] = all_3D_ids + ret['db_ids'] = db_ids + ret['score_q'] = all_score_q + ret['log_info'] = log_info + ret['qvec'] = qvec + ret['tvec'] = tvec + ret['inliers'] = [False for i in range(mkpq.shape[0])] + ret['num_inliers'] = 0 + ret['keypoints_query'] = np.array([]) + ret['points3D_ids'] = [] + return ret + + ret_inliers = ret['inliers'] + loc_keypoints_query = np.array([mkpq[i] for i in range(len(ret_inliers)) if ret_inliers[i]]) + loc_points3D_ids = [all_3D_ids[i] for i in range(len(ret_inliers)) if ret_inliers[i]] + + ret['mkpq'] = mkpq + ret['3D_ids'] = all_3D_ids + ret['db_ids'] = db_ids + ret['log_info'] = log_info + ret['keypoints_query'] = loc_keypoints_query + ret['points3D_ids'] = loc_points3D_ids + + return ret + + +# proposed in efficient large-scale localization by global instance recognition, cvpr 2022 +def pose_estimator_iterative(qname, qinfo, db_ids, db_images, points3D, feature_file, thresh, image_dir, + matcher, + inlier_th=50, + log_info=None, + do_covisibility_opt=False, + covisibility_frame=50, + vis_dir=None, + obs_th=0, + opt_th=12, + gt_qvec=None, + gt_tvec=None, + query_img_prefix='', + db_img_prefix='', + ): + print("qname: ", qname) + db_name_to_id = {image.name: i for i, image in db_images.items()} + # q_img = cv2.imread(osp.join(image_dir, query_img_prefix, qname)) + + kpq = feature_file[qname]['keypoints'][()] + score_q = feature_file[qname]['scores'][()] + imgsize_q = feature_file[qname]['image_size'][()] + desc_q = feature_file[qname]['descriptors'][()] + desc_q = desc_q.transpose() # [N D] + query_data = { + 'keypoints': kpq, + 'scores': score_q, + 'descriptors': desc_q, + 'image_size': imgsize_q, + } + camera_model, width, height, params = qinfo + + best_results = { + 'tvec': None, + 'qvec': None, + 'num_inliers': 0, + 'single_num_inliers': 0, + 'db_id': -1, + 'order': -1, + 'qname': qname, + 'optimize': False, + 'dbname': db_images[db_ids[0]].name, + "ret_source": "", + "inliers": [], + 'keypoints_query': np.array([]), + 'points3D_ids': [], + } + + cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params) + + for cluster_idx, db_id in enumerate(db_ids): + db_name = db_images[db_id].name + mp3d, mkpq, mp3d_ids, q_ids = find_2D_3D_matches( + query_data=query_data, + db_id=db_id, + points3D=points3D, + feature_file=feature_file, + db_images=db_images, + matcher=matcher, + obs_th=obs_th) + + if mp3d.shape[0] < 8: + print_text = "qname: {:s} dbname: {:s}({:d}/{:d}) failed because of insufficient 3d points {:d}".format( + qname, + db_name, + cluster_idx + 1, + len(db_ids), + mp3d.shape[0]) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + continue + + ret = pycolmap.absolute_pose_estimation(mkpq, mp3d, cam, + estimation_options={ + "ransac": {"max_error": thresh}}, + refinement_options={}, + ) + + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + if not ret["success"]: + print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) failed after matching".format(qname, db_name, + cluster_idx + 1, + len(db_ids)) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + continue + + inliers = ret['inliers'] + num_inliers = ret['num_inliers'] + inlier_p3d_ids = [mp3d_ids[i] for i in range(len(inliers)) if inliers[i]] + inlier_mkpq = [mkpq[i] for i in range(len(inliers)) if inliers[i]] + loc_keypoints_query = np.array(inlier_mkpq) + loc_points3D_ids = inlier_p3d_ids + + if ret['num_inliers'] > best_results['num_inliers']: + best_results['qvec'] = ret['qvec'] + best_results['tvec'] = ret['tvec'] + best_results['inlier'] = ret['inliers'] + best_results['num_inliers'] = ret['num_inliers'] + best_results['dbname'] = db_name + best_results['order'] = cluster_idx + 1 + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + if ret['num_inliers'] < inlier_th: + print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) failed insufficient {:d} inliers".format(qname, + db_name, + cluster_idx + 1, + len(db_ids), + num_inliers, + ) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + continue + + print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) initialization succeed with {:d} inliers".format( + qname, + db_name, + cluster_idx + 1, + len(db_ids), + ret["num_inliers"] + ) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + + if do_covisibility_opt: + ret = pose_refinement(qname=qname, + query_cam=cam, + feature_file=feature_file, + db_frame_id=db_id, + db_images=db_images, + points3D=points3D, + thresh=thresh, + covisibility_frame=covisibility_frame, + matcher=matcher, + obs_th=obs_th, + opt_th=opt_th, + qvec=ret['qvec'], + tvec=ret['tvec'], + log_info='', + image_dir=image_dir, + vis_dir=vis_dir, + gt_qvec=gt_qvec, + gt_tvec=gt_tvec, + ) + + loc_keypoints_query = ret['keypoints_query'] + loc_points3D_ids = ret['points3D_ids'] + + log_info = log_info + ret['log_info'] + print_text = 'Find {:d} inliers after optimization'.format(ret['num_inliers']) + print(print_text) + if log_info is not None: + log_info += (print_text + "\n") + + # localization succeed + qvec = ret['qvec'] + tvec = ret['tvec'] + num_inliers = ret['num_inliers'] + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + best_results['qvec'] = qvec + best_results['tvec'] = tvec + best_results['num_inliers'] = num_inliers + best_results['log_info'] = log_info + + return best_results + + if best_results['num_inliers'] >= 10: # 20 for aachen + qvec = best_results['qvec'] + tvec = best_results['tvec'] + best_dbname = best_results['dbname'] + + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + if do_covisibility_opt: + ret = pose_refinement(qname=qname, + query_cam=cam, + feature_file=feature_file, + db_frame_id=db_name_to_id[best_dbname], + db_images=db_images, + points3D=points3D, + thresh=thresh, + covisibility_frame=covisibility_frame, + matcher=matcher, + obs_th=obs_th, + opt_th=opt_th, + qvec=qvec, + tvec=tvec, + log_info='', + image_dir=image_dir, + vis_dir=vis_dir, + gt_qvec=gt_qvec, + gt_tvec=gt_tvec, + ) + + # localization succeed + qvec = ret['qvec'] + tvec = ret['tvec'] + num_inliers = ret['num_inliers'] + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + best_results['qvec'] = qvec + best_results['tvec'] = tvec + best_results['num_inliers'] = num_inliers + best_results['log_info'] = log_info + + return best_results + + closest = db_images[db_ids[0][0]] + print_text = 'Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, closest.name) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + + best_results['qvec'] = closest.qvec + best_results['tvec'] = closest.tvec + best_results['num_inliers'] = -1 + best_results['log_info'] = log_info + + return best_results diff --git a/third_party/pram/localization/refframe.py b/third_party/pram/localization/refframe.py new file mode 100644 index 0000000000000000000000000000000000000000..b7eeafd44557ffdfda5829dab00dd5df125148b4 --- /dev/null +++ b/third_party/pram/localization/refframe.py @@ -0,0 +1,147 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> refframe +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 10:06 +==================================================''' +import numpy as np +from localization.camera import Camera +from colmap_utils.camera_intrinsics import intrinsics_from_camera +from colmap_utils.read_write_model import qvec2rotmat + + +class RefFrame: + def __init__(self, camera: Camera, id: int, qvec: np.ndarray, tvec: np.ndarray, + point3D_ids: np.ndarray = None, keypoints: np.ndarray = None, + name: str = None, scene_name: str = None): + self.camera = camera + self.id = id + self.qvec = qvec + self.tvec = tvec + self.name = name + self.scene_name = scene_name + self.width = camera.width + self.height = camera.height + self.image_size = np.array([self.height, self.width]) + + self.point3D_ids = point3D_ids + self.keypoints = keypoints + self.descriptors = None + self.keypoint_segs = None + self.xyzs = None + + def get_keypoints_by_sid(self, sid: int): + mask = (self.keypoint_segs == sid) + return { + 'point3D_ids': self.point3D_ids[mask], + 'keypoints': self.keypoints[mask][:, :2], + 'descriptors': self.descriptors[mask], + 'scores': self.keypoints[mask][:, 2], + 'xyzs': self.xyzs[mask], + 'camera': self.camera, + } + + valid_p3d_ids = [] + valid_kpts = [] + valid_descs = [] + valid_scores = [] + valid_xyzs = [] + for i, v in enumerate(self.point3D_ids): + if v in point3Ds.keys(): + p3d = point3Ds[v] + if p3d.seg_id == sid: + valid_kpts.append(self.keypoints[i]) + valid_p3d_ids.append(v) + valid_xyzs.append(p3d.xyz) + valid_descs.append(p3d.descriptor) + valid_scores.append(p3d.error) + return { + 'point3D_ids': np.array(valid_p3d_ids), + 'keypoints': np.array(valid_kpts), + 'descriptors': np.array(valid_descs), + 'scores': np.array(valid_scores), + 'xyzs': np.array(valid_xyzs), + } + + def get_keypoints(self): + return { + 'point3D_ids': self.point3D_ids, + 'keypoints': self.keypoints[:, :2], + 'descriptors': self.descriptors, + 'scores': self.keypoints[:, 2], + 'xyzs': self.xyzs, + 'camera': self.camera, + } + + valid_p3d_ids = [] + valid_kpts = [] + valid_descs = [] + valid_scores = [] + valid_xyzs = [] + for i, v in enumerate(self.point3D_ids): + if v in point3Ds.keys(): + p3d = point3Ds[v] + valid_kpts.append(self.keypoints[i]) + valid_p3d_ids.append(v) + valid_xyzs.append(p3d.xyz) + valid_descs.append(p3d.descriptor) + valid_scores.append(p3d.error) + return { + 'points3D_ids': np.array(valid_p3d_ids), + 'keypoints': np.array(valid_kpts), + 'descriptors': np.array(valid_descs), + 'scores': 1 / np.clip(np.array(valid_scores) * 5, a_min=1., a_max=20.), + 'xyzs': np.array(valid_xyzs), + 'camera': self.camera, + } + + def associate_keypoints_with_point3Ds(self, point3Ds: dict): + xyzs = [] + descs = [] + scores = [] + p3d_ids = [] + kpt_sids = [] + for i, v in enumerate(self.point3D_ids): + if v in point3Ds.keys(): + p3d = point3Ds[v] + p3d_ids.append(v) + xyzs.append(p3d.xyz) + descs.append(p3d.descriptor) + scores.append(p3d.error) + + kpt_sids.append(p3d.seg_id) + + xyzs = np.array(xyzs) + if xyzs.shape[0] == 0: + return False + + descs = np.array(descs) + scores = 1 / np.clip(np.array(scores) * 5, a_min=1., a_max=20.) + p3d_ids = np.array(p3d_ids) + uvs = self.project(xyzs=xyzs) + self.keypoints = np.hstack([uvs, scores.reshape(-1, 1)]) + self.descriptors = descs + self.point3D_ids = p3d_ids + self.xyzs = xyzs + self.keypoint_segs = np.array(kpt_sids) + + return True + + def project(self, xyzs): + ''' + :param xyzs: [N, 3] + :return: + ''' + K = intrinsics_from_camera(camera_model=self.camera.model, params=self.camera.params) # [3, 3] + Rcw = qvec2rotmat(self.qvec) + tcw = self.tvec.reshape(3, 1) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3:] = tcw + xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1))]) # [N 4] + + xyzs_cam = Tcw @ xyzs_homo.transpose() # [4, N] + uvs = K @ xyzs_cam[:3, :] # [3, N] + uvs[:2, :] = uvs[:2, :] / uvs[2, :] + return uvs[:2, :].transpose() diff --git a/third_party/pram/localization/singlemap3d.py b/third_party/pram/localization/singlemap3d.py new file mode 100644 index 0000000000000000000000000000000000000000..77fc0ef2c78321044bb8f8f2952ccb278ea28d8f --- /dev/null +++ b/third_party/pram/localization/singlemap3d.py @@ -0,0 +1,532 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> map3d +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 10:25 +==================================================''' +import numpy as np +from collections import defaultdict +import os.path as osp +import pycolmap +import logging +import time + +import torch + +from localization.refframe import RefFrame +from localization.frame import Frame +from localization.point3d import Point3D +from colmap_utils.read_write_model import qvec2rotmat, read_model, read_compressed_model +from localization.utils import read_gt_pose + + +class SingleMap3D: + def __init__(self, config, matcher, with_compress=False, start_sid: int = 0): + self.config = config + self.matcher = matcher + self.image_path_prefix = self.config['image_path_prefix'] + self.start_sid = start_sid # for a dataset with multiple scenes + if not with_compress: + cameras, images, p3ds = read_model( + path=osp.join(config['landmark_path'], 'model'), ext='.bin') + p3d_descs = np.load(osp.join(config['landmark_path'], 'point3D_desc.npy'), + allow_pickle=True)[()] + else: + cameras, images, p3ds = read_compressed_model( + path=osp.join(config['landmark_path'], 'compress_model_{:s}'.format(config['cluster_method'])), + ext='.bin') + p3d_descs = np.load(osp.join(config['landmark_path'], 'compress_model_{:s}/point3D_desc.npy'.format( + config['cluster_method'])), allow_pickle=True)[()] + + print('Load {} cameras {} images {} 3D points'.format(len(cameras), len(images), len(p3d_descs))) + + seg_data = np.load( + osp.join(config['landmark_path'], 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'], + config['cluster_mode'], + config['cluster_method'])), + allow_pickle=True)[()] + + p3d_id = seg_data['id'] + seg_id = seg_data['label'] + p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + seg_p3d = {} + for k in p3d_seg.keys(): + sid = p3d_seg[k] + if sid in seg_p3d.keys(): + seg_p3d[sid].append(k) + else: + seg_p3d[sid] = [k] + + print('Load {} segments and {} 3d points'.format(len(seg_p3d.keys()), len(p3d_seg.keys()))) + seg_vrf = np.load( + osp.join(config['landmark_path'], 'point3D_vrf_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'], + config['cluster_mode'], + config['cluster_method'])), + allow_pickle=True)[()] + + # construct 3D map + self.initialize_point3Ds(p3ds=p3ds, p3d_descs=p3d_descs, p3d_seg=p3d_seg) + self.initialize_ref_frames(cameras=cameras, images=images) + + all_vrf_frame_ids = [] + self.seg_ref_frame_ids = {} + for sid in seg_vrf.keys(): + self.seg_ref_frame_ids[sid] = [] + for vi in seg_vrf[sid].keys(): + vrf_frame_id = seg_vrf[sid][vi]['image_id'] + self.seg_ref_frame_ids[sid].append(vrf_frame_id) + if with_compress and vrf_frame_id in self.reference_frames.keys(): + self.reference_frames[vrf_frame_id].point3D_ids = seg_vrf[sid][vi]['original_points3d'] + + all_vrf_frame_ids.extend(self.seg_ref_frame_ids[sid]) + + if with_compress: + all_ref_ids = list(self.reference_frames.keys()) + for fid in all_ref_ids: + valid = self.reference_frames[fid].associate_keypoints_with_point3Ds(point3Ds=self.point3Ds) + if not valid: + del self.reference_frames[fid] + + all_vrf_frame_ids = np.unique(all_vrf_frame_ids) + all_vrf_frame_ids = [v for v in all_vrf_frame_ids if v in self.reference_frames.keys()] + self.build_covisibility_graph(frame_ids=all_vrf_frame_ids, n_frame=config['localization'][ + 'covisibility_frame']) # build covisible frames for vrf frames only + + logging.info( + f'Construct {len(self.reference_frames.keys())} ref frames and {len(self.point3Ds.keys())} 3d points') + + self.gt_poses = {} + if config['gt_pose_path'] is not None: + gt_pose_path = osp.join(config['dataset_path'], config['gt_pose_path']) + self.read_gt_pose(path=gt_pose_path) + + def read_gt_pose(self, path, prefix=''): + self.gt_poses = read_gt_pose(path=path) + print('Load {} gt poses'.format(len(self.gt_poses.keys()))) + + def initialize_point3Ds(self, p3ds, p3d_descs, p3d_seg): + self.point3Ds = {} + for id in p3ds.keys(): + if id not in p3d_seg.keys(): + continue + self.point3Ds[id] = Point3D(id=id, xyz=p3ds[id].xyz, error=p3ds[id].error, + refframe_id=-1, rgb=p3ds[id].rgb, + descriptor=p3d_descs[id], seg_id=p3d_seg[id], + frame_ids=p3ds[id].image_ids) + + def initialize_ref_frames(self, cameras, images): + self.reference_frames = {} + for id in images.keys(): + im = images[id] + cam = cameras[im.camera_id] + self.reference_frames[id] = RefFrame(camera=cam, id=id, qvec=im.qvec, tvec=im.tvec, + point3D_ids=im.point3D_ids, + keypoints=im.xys, name=im.name) + + def localize_with_ref_frame(self, q_frame: Frame, q_kpt_ids: np.ndarray, sid, semantic_matching=False): + ref_frame_id = self.seg_ref_frame_ids[sid][0] + ref_frame = self.reference_frames[ref_frame_id] + if semantic_matching and sid > 0: + ref_data = ref_frame.get_keypoints_by_sid(sid=sid) + else: + ref_data = ref_frame.get_keypoints() + + q_descs = q_frame.descriptors[q_kpt_ids] + q_kpts = q_frame.keypoints[q_kpt_ids, :2] + q_scores = q_frame.keypoints[q_kpt_ids, 2] + + xyzs = ref_data['xyzs'] + point3D_ids = ref_data['point3D_ids'] + ref_sids = np.array([self.point3Ds[v].seg_id for v in point3D_ids]) + with torch.no_grad(): + indices0 = self.matcher({ + 'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(q_scores)[None].cuda().float(), + 'image_shape0': (1, 3, q_frame.camera.width, q_frame.camera.height), + + 'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(), + 'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(), + 'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(), + 'image_shape1': (1, 3, ref_frame.camera.width, ref_frame.camera.height), + } + )['matches0'][0].cpu().numpy() + + valid = indices0 >= 0 + mkpts = q_kpts[valid] + mkpt_ids = q_kpt_ids[valid] + mxyzs = xyzs[indices0[valid]] + mpoint3D_ids = point3D_ids[indices0[valid]] + matched_sids = ref_sids[indices0[valid]] + matched_ref_keypoints = ref_data['keypoints'][indices0[valid]] + + # print('mkpts: ', mkpts.shape, mxyzs.shape, np.sum(indices0 >= 0)) + # cfg = q_frame.camera._asdict() + # q_cam = pycolmap.Camera(model=q_frame.camera.model, ) + # config = {"estimation": {"ransac": {"max_error": ransac_thresh}}, **(config or {})} + ret = pycolmap.absolute_pose_estimation(mkpts + 0.5, + mxyzs, + q_frame.camera, + estimation_options={ + "ransac": {"max_error": self.config['localization']['threshold']}}, + refinement_options={}, + # max_error_px=self.config['localization']['threshold'] + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + ret['matched_keypoints'] = mkpts + ret['matched_keypoint_ids'] = mkpt_ids + ret['matched_xyzs'] = mxyzs + ret['reference_frame_id'] = ref_frame_id + ret['matched_point3D_ids'] = mpoint3D_ids + ret['matched_sids'] = matched_sids + ret['matched_ref_keypoints'] = matched_ref_keypoints + + if not ret['success']: + ret['num_inliers'] = 0 + ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool) + return ret + + def match(self, query_data, ref_data): + q_descs = query_data['descriptors'] + q_kpts = query_data['keypoints'] + q_scores = query_data['scores'] + xyzs = ref_data['xyzs'] + points3D_ids = ref_data['point3D_ids'] + with torch.no_grad(): + indices0 = self.matcher({ + 'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(q_scores)[None].cuda().float(), + 'image_shape0': (1, 3, query_data['camera'].width, query_data['camera'].height), + + 'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(), + 'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(), + 'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(), + 'image_shape1': (1, 3, ref_data['camera'].width, ref_data['camera'].height), + } + )['matches0'][0].cpu().numpy() + + valid = indices0 >= 0 + mkpts = q_kpts[valid] + mkpt_ids = np.where(valid)[0] + mxyzs = xyzs[indices0[valid]] + mpoints3D_ids = points3D_ids[indices0[valid]] + + return { + 'matched_keypoints': mkpts, + 'matched_xyzs': mxyzs, + 'matched_point3D_ids': mpoints3D_ids, + 'matched_keypoint_ids': mkpt_ids, + } + + def build_covisibility_graph(self, frame_ids: list = None, n_frame: int = 20): + def find_covisible_frames(frame_id): + observed = self.reference_frames[frame_id].point3D_ids + covis = defaultdict(int) + for pid in observed: + if pid == -1: + continue + if pid not in self.point3Ds.keys(): + continue + for img_id in self.point3Ds[pid].frame_ids: + covis[img_id] += 1 + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= n_frame: + sel_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + ind_top = np.argpartition(covis_num, -n_frame) + ind_top = ind_top[-n_frame:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + sel_covis_ids = [covis_ids[i] for i in ind_top] + + return sel_covis_ids + + if frame_ids is None: + frame_ids = list(self.referece_frames.keys()) + + self.covisible_graph = defaultdict() + for frame_id in frame_ids: + self.covisible_graph[frame_id] = find_covisible_frames(frame_id=frame_id) + + def refine_pose(self, q_frame: Frame, refinement_method='matching'): + if refinement_method == 'matching': + return self.refine_pose_by_matching(q_frame=q_frame) + elif refinement_method == 'projection': + return self.refine_pose_by_projection(q_frame=q_frame) + else: + raise NotImplementedError + + def refine_pose_by_matching(self, q_frame): + ref_frame_id = q_frame.reference_frame_id + db_ids = self.covisible_graph[ref_frame_id] + print('Find {} covisible frames'.format(len(db_ids))) + loc_success = q_frame.tracking_status + if loc_success and ref_frame_id in db_ids: + init_kpts = q_frame.matched_keypoints + init_kpt_ids = q_frame.matched_keypoint_ids + init_point3D_ids = q_frame.matched_point3D_ids + init_xyzs = np.array([self.point3Ds[v].xyz for v in init_point3D_ids]).reshape(-1, 3) + list(db_ids).remove(ref_frame_id) + else: + init_kpts = None + init_xyzs = None + init_point3D_ids = None + + matched_xyzs = [] + matched_kpts = [] + matched_point3D_ids = [] + matched_kpt_ids = [] + for idx, frame_id in enumerate(db_ids): + ref_data = self.reference_frames[frame_id].get_keypoints() + match_out = self.match(query_data={ + 'keypoints': q_frame.keypoints[:, :2], + 'scores': q_frame.keypoints[:, 2], + 'descriptors': q_frame.descriptors, + 'camera': q_frame.camera, }, + ref_data=ref_data) + if match_out['matched_keypoints'].shape[0] > 0: + matched_kpts.append(match_out['matched_keypoints']) + matched_xyzs.append(match_out['matched_xyzs']) + matched_point3D_ids.append(match_out['matched_point3D_ids']) + matched_kpt_ids.append(match_out['matched_keypoint_ids']) + if len(matched_kpts) > 1: + matched_kpts = np.vstack(matched_kpts) + matched_xyzs = np.vstack(matched_xyzs).reshape(-1, 3) + matched_point3D_ids = np.hstack(matched_point3D_ids) + matched_kpt_ids = np.hstack(matched_kpt_ids) + else: + matched_kpts = matched_kpts[0] + matched_xyzs = matched_xyzs[0] + matched_point3D_ids = matched_point3D_ids[0] + matched_kpt_ids = matched_kpt_ids[0] + if init_kpts is not None and init_kpts.shape[0] > 0: + matched_kpts = np.vstack([matched_kpts, init_kpts]) + matched_xyzs = np.vstack([matched_xyzs, init_xyzs]) + matched_point3D_ids = np.hstack([matched_point3D_ids, init_point3D_ids]) + matched_kpt_ids = np.hstack([matched_kpt_ids, init_kpt_ids]) + + matched_sids = np.array([self.point3Ds[v].seg_id for v in matched_point3D_ids]) + + print_text = 'Refinement by matching. Get {:d} covisible frames with {:d} matches for optimization'.format( + len(db_ids), matched_xyzs.shape[0]) + print(print_text) + + t_start = time.time() + ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, + matched_xyzs, + q_frame.camera, + estimation_options={ + 'ransac': { + 'max_error': self.config['localization']['threshold'], + 'min_num_trials': 1000, + 'max_num_trials': 10000, + 'confidence': 0.995, + }}, + refinement_options={}, + # max_error_px=self.config['localization']['threshold'], + # min_num_trials=1000, max_num_trials=10000, confidence=0.995) + ) + print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start)) + + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + ret['matched_keypoints'] = matched_kpts + ret['matched_keypoint_ids'] = matched_kpt_ids + ret['matched_xyzs'] = matched_xyzs + ret['matched_point3D_ids'] = matched_point3D_ids + ret['matched_sids'] = matched_sids + + if ret['success']: + inlier_mask = np.array(ret['inliers']) + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids[inlier_mask], + candidate_frame_ids=self.covisible_graph.keys()) + else: + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids, + candidate_frame_ids=self.covisible_graph.keys()) + + ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][ + 'covisibility_frame']] + ret['reference_frame_id'] = best_reference_frame_ids[0] + + return ret + + @torch.no_grad() + def refine_pose_by_projection(self, q_frame): + q_Rcw = qvec2rotmat(q_frame.qvec) + q_tcw = q_frame.tvec + q_Tcw = np.eye(4, dtype=float) # [4 4] + q_Tcw[:3, :3] = q_Rcw + q_Tcw[:3, 3] = q_tcw + cam = q_frame.camera + imw = cam.width + imh = cam.height + K = q_frame.get_intrinsics() # [3, 3] + reference_frame_id = q_frame.reference_frame_id + covis_frame_ids = self.covisible_graph[reference_frame_id] + if reference_frame_id not in covis_frame_ids: + covis_frame_ids.append(reference_frame_id) + all_point3D_ids = [] + + for frame_id in covis_frame_ids: + all_point3D_ids.extend(list(self.reference_frames[frame_id].point3D_ids)) + + all_point3D_ids = np.unique(all_point3D_ids) + all_xyzs = [] + all_descs = [] + all_sids = [] + for pid in all_point3D_ids: + all_xyzs.append(self.point3Ds[pid].xyz) + all_descs.append(self.point3Ds[pid].descriptor) + all_sids.append(self.point3Ds[pid].seg_id) + + all_xyzs = np.array(all_xyzs) # [N 3] + all_descs = np.array(all_descs) # [N 3] + all_point3D_ids = np.array(all_point3D_ids) + all_sids = np.array(all_sids) + + # move to gpu (distortion is not included) + # proj_uv = pycolmap.camera.img_from_cam( + # np.array([1, 1, 1]).reshape(1, 3), + # ) + all_xyzs_cuda = torch.from_numpy(all_xyzs).cuda() + ones = torch.ones(size=(all_xyzs_cuda.shape[0], 1), dtype=all_xyzs_cuda.dtype).cuda() + all_xyzs_cuda_homo = torch.cat([all_xyzs_cuda, ones], dim=1) # [N 4] + K_cuda = torch.from_numpy(K).cuda() + proj_uvs = K_cuda @ (torch.from_numpy(q_Tcw).cuda() @ all_xyzs_cuda_homo.t())[:3, :] # [3, N] + proj_uvs[0] /= proj_uvs[2] + proj_uvs[1] /= proj_uvs[2] + mask = (proj_uvs[2] > 0) * (proj_uvs[2] < 100) * (proj_uvs[0] >= 0) * (proj_uvs[0] < imw) * ( + proj_uvs[1] >= 0) * (proj_uvs[1] < imh) + + proj_uvs = proj_uvs[:, mask] + + print('Projection: out of range {:d}/{:d}'.format(all_xyzs_cuda.shape[0], proj_uvs.shape[1])) + + mxyzs = all_xyzs[mask.cpu().numpy()] + mpoint3D_ids = all_point3D_ids[mask.cpu().numpy()] + msids = all_sids[mask.cpu().numpy()] + + q_kpts_cuda = torch.from_numpy(q_frame.keypoints[:, :2]).cuda() + proj_error = q_kpts_cuda[..., None] - proj_uvs[:2][None] + proj_error = torch.sqrt(torch.sum(proj_error ** 2, dim=1)) # [M N] + out_of_range_mask = (proj_error >= 2 * self.config['localization']['threshold']) + + q_descs_cuda = torch.from_numpy(q_frame.descriptors).cuda().float() # [M D] + all_descs_cuda = torch.from_numpy(all_descs).cuda().float()[mask] # [N D] + desc_dist = torch.sqrt(2 - 2 * q_descs_cuda @ all_descs_cuda.t() + 1e-6) + desc_dist[out_of_range_mask] = desc_dist[out_of_range_mask] + 100 + dists, ids = torch.topk(desc_dist, k=2, largest=False, dim=1) + # apply nn ratio + ratios = dists[:, 0] / dists[:, 1] # smaller, better + ratio_mask = (ratios <= 0.995) * (dists[:, 0] < 100) + ratio_mask = ratio_mask.cpu().numpy() + ids = ids.cpu().numpy()[ratio_mask, 0] + + ratio_num = torch.sum(ratios <= 0.995) + proj_num = torch.sum(dists[:, 0] < 100) + + print('Projection: after ratio {:d}/{:d}, ratio {:d}, proj {:d}'.format(q_kpts_cuda.shape[0], + np.sum(ratio_mask), + ratio_num, proj_num)) + + mkpts = q_frame.keypoints[ratio_mask] + mkpt_ids = np.where(ratio_mask)[0] + mxyzs = mxyzs[ids] + mpoint3D_ids = mpoint3D_ids[ids] + msids = msids[ids] + print('projection: ', mkpts.shape, mkpt_ids.shape, mxyzs.shape, mpoint3D_ids.shape, msids.shape) + + t_start = time.time() + ret = pycolmap.absolute_pose_estimation(mkpts[:, :2] + 0.5, mxyzs, q_frame.camera, + estimation_options={ + "ransac": {"max_error": self.config['localization']['threshold']}}, + refinement_options={}, + # max_error_px=self.config['localization']['threshold'] + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + # inlier_mask = np.ones(shape=(mkpts.shape[0],), dtype=bool).tolist() + # ret = pycolmap.pose_refinement(q_frame.tvec, q_frame.qvec, mkpts[:, :2] + 0.5, mxyzs, inlier_mask, cfg) + # ret['num_inliers'] = np.sum(inlier_mask).astype(int) + # ret['inliers'] = np.array(inlier_mask) + + print_text = 'Refinement by projection. Get {:d} inliers of {:d} matches for optimization'.format( + ret['num_inliers'], mxyzs.shape[0]) + print(print_text) + print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start)) + + ret['matched_keypoints'] = mkpts + ret['matched_xyzs'] = mxyzs + ret['matched_point3D_ids'] = mpoint3D_ids + ret['matched_sids'] = msids + ret['matched_keypoint_ids'] = mkpt_ids + + if ret['success']: + inlier_mask = np.array(ret['inliers']) + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids[inlier_mask], + candidate_frame_ids=self.covisible_graph.keys()) + else: + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids, + candidate_frame_ids=self.covisible_graph.keys()) + + ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][ + 'covisibility_frame']] + ret['reference_frame_id'] = best_reference_frame_ids[0] + + if not ret['success']: + ret['num_inliers'] = 0 + ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool) + + return ret + + def find_reference_frames(self, matched_point3D_ids, candidate_frame_ids=None): + covis_frames = defaultdict(int) + for pid in matched_point3D_ids: + for im_id in self.point3Ds[pid].frame_ids: + if candidate_frame_ids is not None and im_id in candidate_frame_ids: + covis_frames[im_id] += 1 + + covis_ids = np.array(list(covis_frames.keys())) + covis_num = np.array([covis_frames[i] for i in covis_ids]) + sorted_idxes = np.argsort(covis_num)[::-1] # larger to small + sorted_frame_ids = covis_ids[sorted_idxes] + return sorted_frame_ids + + def check_semantic_consistency(self, q_frame: Frame, sid, overlap_ratio=0.5): + ref_frame_id = self.seg_ref_frame_ids[sid][0] + ref_frame = self.reference_frames[ref_frame_id] + + q_sids = q_frame.seg_ids + ref_sids = np.array([self.point3Ds[v].seg_id for v in ref_frame.point3D_ids]) + self.start_sid + overlap_sids = np.intersect1d(q_sids, ref_sids) + + overlap_num1 = 0 + overlap_num2 = 0 + for sid in overlap_sids: + overlap_num1 += np.sum(q_sids == sid) + overlap_num2 += np.sum(ref_sids == sid) + + ratio1 = overlap_num1 / q_sids.shape[0] + ratio2 = overlap_num2 / ref_sids.shape[0] + + # print('semantic_check: ', overlap_sids, overlap_num1, ratio1, overlap_num2, ratio2) + + return min(ratio1, ratio2) >= overlap_ratio diff --git a/third_party/pram/localization/tracker.py b/third_party/pram/localization/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..a401fea82c2372cfdf301ab2d2fb34981facf4fe --- /dev/null +++ b/third_party/pram/localization/tracker.py @@ -0,0 +1,338 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> tracker +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/02/2024 16:58 +==================================================''' +import time +import cv2 +import numpy as np +import torch +import pycolmap +from localization.frame import Frame +from localization.base_model import dynamic_load +import localization.matchers as matchers +from localization.match_features_batch import confs as matcher_confs +from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches +from tools.common import resize_img + + +class Tracker: + def __init__(self, locMap, matcher, config): + self.locMap = locMap + self.matcher = matcher + self.config = config + self.loc_config = config['localization'] + + self.lost = True + + self.curr_frame = None + self.last_frame = None + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Model = dynamic_load(matchers, 'nearest_neighbor') + self.nn_matcher = Model(matcher_confs['NNM']['model']).eval().to(device) + + def run(self, frame: Frame): + print('Start tracking...') + show = self.config['localization']['show'] + self.curr_frame = frame + ref_img = self.last_frame.image + curr_img = self.curr_frame.image + q_kpts = frame.keypoints + + t_start = time.time() + ret = self.track_last_frame(curr_frame=self.curr_frame, last_frame=self.last_frame) + self.curr_frame.time_loc = self.curr_frame.time_loc + time.time() - t_start + + if show: + curr_matched_kpts = ret['matched_keypoints'] + ref_matched_kpts = ret['matched_ref_keypoints'] + img_loc_matching = plot_matches(img1=curr_img, img2=ref_img, + pts1=curr_matched_kpts, + pts2=ref_matched_kpts, + inliers=np.array([True for i in range(curr_matched_kpts.shape[0])]), + radius=9, line_thickness=3) + self.curr_frame.image_matching = img_loc_matching + + q_ref_img_matching = resize_img(img_loc_matching, nh=512) + + if not ret['success']: + show_text = 'Tracking FAILED!' + img_inlier = vis_inlier(img=curr_img, kpts=curr_matched_kpts, + inliers=[False for i in range(curr_matched_kpts.shape[0])], radius=9 + 2, + thickness=2) + q_img_inlier = cv2.putText(img=img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + return False + + ret['matched_scene_name'] = self.last_frame.scene_name + success = self.verify_and_update(q_frame=self.curr_frame, ret=ret) + + if not success: + return False + + if ret['num_inliers'] < 256: + # refinement is necessary for tracking last frame + t_start = time.time() + ret = self.locMap.sub_maps[self.last_frame.matched_scene_name].refine_pose(self.curr_frame, + refinement_method= + self.loc_config[ + 'refinement_method']) + self.curr_frame.time_ref = self.curr_frame.time_ref + time.time() - t_start + ret['matched_scene_name'] = self.last_frame.scene_name + success = self.verify_and_update(q_frame=self.curr_frame, ret=ret) + + if show: + q_err, t_err = self.curr_frame.compute_pose_error() + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + show_text = 'Tracking, k/m/i: {:d}/{:d}/{:d}'.format(q_kpts.shape[0], num_matches, num_inliers) + q_img_inlier = vis_inlier(img=curr_img, kpts=ret['matched_keypoints'], inliers=ret['inliers'], + radius=9 + 2, thickness=2) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + self.curr_frame.image_inlier = q_img_inlier + + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + + self.lost = success + return success + + def verify_and_update(self, q_frame: Frame, ret: dict): + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + + q_frame.qvec = ret['qvec'] + q_frame.tvec = ret['tvec'] + + q_err, t_err = q_frame.compute_pose_error() + + if num_inliers < self.loc_config['min_inliers']: + print_text = 'Failed due to insufficient {:d} inliers, q_err: {:.2f}, t_err: {:.2f}'.format( + ret['num_inliers'], q_err, t_err) + print(print_text) + q_frame.tracking_status = False + q_frame.clear_localization_track() + return False + else: + print_text = 'Succeed! Find {}/{} 2D-3D inliers,q_err: {:.2f}, t_err: {:.2f}'.format( + num_inliers, num_matches, q_err, t_err) + print(print_text) + q_frame.tracking_status = True + + self.update_current_frame(curr_frame=q_frame, ret=ret) + return True + + def update_current_frame(self, curr_frame: Frame, ret: dict): + curr_frame.qvec = ret['qvec'] + curr_frame.tvec = ret['tvec'] + + curr_frame.matched_scene_name = ret['matched_scene_name'] + curr_frame.reference_frame_id = ret['reference_frame_id'] + inliers = np.array(ret['inliers']) + + curr_frame.matched_keypoints = ret['matched_keypoints'][inliers] + curr_frame.matched_xyzs = ret['matched_xyzs'][inliers] + curr_frame.matched_point3D_ids = ret['matched_point3D_ids'][inliers] + curr_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inliers] + curr_frame.matched_sids = ret['matched_sids'][inliers] + + def track_last_frame(self, curr_frame: Frame, last_frame: Frame): + curr_kpts = curr_frame.keypoints[:, :2] + curr_scores = curr_frame.keypoints[:, 2] + curr_descs = curr_frame.descriptors + curr_kpt_ids = np.arange(curr_kpts.shape[0]) + + last_kpts = last_frame.keypoints[:, :2] + last_scores = last_frame.keypoints[:, 2] + last_descs = last_frame.descriptors + last_xyzs = last_frame.xyzs + last_point3D_ids = last_frame.point3D_ids + last_sids = last_frame.seg_ids + + # ''' + indices = self.matcher({ + 'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(curr_scores)[None].cuda().float(), + 'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height), + + 'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(), + 'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(), + 'scores1': torch.from_numpy(last_scores)[None].cuda().float(), + 'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height), + })['matches0'][0].cpu().numpy() + ''' + + indices = self.nn_matcher({ + 'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None], + 'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None], + })['matches0'][0].cpu().numpy() + ''' + + valid = (indices >= 0) + + matched_point3D_ids = last_point3D_ids[indices[valid]] + point3D_mask = (matched_point3D_ids >= 0) + matched_point3D_ids = matched_point3D_ids[point3D_mask] + matched_sids = last_sids[indices[valid]][point3D_mask] + + matched_kpts = curr_kpts[valid][point3D_mask] + matched_kpt_ids = curr_kpt_ids[valid][point3D_mask] + matched_xyzs = last_xyzs[indices[valid]][point3D_mask] + matched_last_kpts = last_kpts[indices[valid]][point3D_mask] + + print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0], + last_kpts.shape[0])) + + # print('tracking: ', matched_kpts.shape, matched_xyzs.shape) + ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, matched_xyzs, + curr_frame.camera, + estimation_options={ + "ransac": {"max_error": self.config['localization']['threshold']}}, + refinement_options={}, + # max_error_px=self.config['localization']['threshold'] + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + ret['matched_keypoints'] = matched_kpts + ret['matched_keypoint_ids'] = matched_kpt_ids + ret['matched_ref_keypoints'] = matched_last_kpts + ret['matched_xyzs'] = matched_xyzs + ret['matched_point3D_ids'] = matched_point3D_ids + ret['matched_sids'] = matched_sids + ret['reference_frame_id'] = last_frame.reference_frame_id + ret['matched_scene_name'] = last_frame.matched_scene_name + return ret + + def track_last_frame_fast(self, curr_frame: Frame, last_frame: Frame): + curr_kpts = curr_frame.keypoints[:, :2] + curr_scores = curr_frame.keypoints[:, 2] + curr_descs = curr_frame.descriptors + curr_kpt_ids = np.arange(curr_kpts.shape[0]) + + last_point3D_ids = last_frame.point3D_ids + point3D_mask = (last_point3D_ids >= 0) + last_kpts = last_frame.keypoints[:, :2][point3D_mask] + last_scores = last_frame.keypoints[:, 2][point3D_mask] + last_descs = last_frame.descriptors[point3D_mask] + last_xyzs = last_frame.xyzs[point3D_mask] + last_sids = last_frame.seg_ids[point3D_mask] + + minx = np.min(last_kpts[:, 0]) + maxx = np.max(last_kpts[:, 0]) + miny = np.min(last_kpts[:, 1]) + maxy = np.max(last_kpts[:, 1]) + curr_mask = (curr_kpts[:, 0] >= minx) * (curr_kpts[:, 0] <= maxx) * (curr_kpts[:, 1] >= miny) * ( + curr_kpts[:, 1] <= maxy) + + curr_kpts = curr_kpts[curr_mask] + curr_scores = curr_scores[curr_mask] + curr_descs = curr_descs[curr_mask] + curr_kpt_ids = curr_kpt_ids[curr_mask] + # ''' + indices = self.matcher({ + 'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(curr_scores)[None].cuda().float(), + 'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height), + + 'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(), + 'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(), + 'scores1': torch.from_numpy(last_scores)[None].cuda().float(), + 'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height), + })['matches0'][0].cpu().numpy() + ''' + + indices = self.nn_matcher({ + 'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None], + 'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None], + })['matches0'][0].cpu().numpy() + ''' + + valid = (indices >= 0) + + matched_point3D_ids = last_point3D_ids[indices[valid]] + matched_sids = last_sids[indices[valid]] + + matched_kpts = curr_kpts[valid] + matched_kpt_ids = curr_kpt_ids[valid] + matched_xyzs = last_xyzs[indices[valid]] + matched_last_kpts = last_kpts[indices[valid]] + + print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0], + last_kpts.shape[0])) + + # print('tracking: ', matched_kpts.shape, matched_xyzs.shape) + ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, matched_xyzs, + curr_frame.camera._asdict(), + max_error_px=self.config['localization']['threshold']) + + ret['matched_keypoints'] = matched_kpts + ret['matched_keypoint_ids'] = matched_kpt_ids + ret['matched_ref_keypoints'] = matched_last_kpts + ret['matched_xyzs'] = matched_xyzs + ret['matched_point3D_ids'] = matched_point3D_ids + ret['matched_sids'] = matched_sids + ret['reference_frame_id'] = last_frame.reference_frame_id + ret['matched_scene_name'] = last_frame.matched_scene_name + return ret + + @torch.no_grad() + def match_frame(self, frame: Frame, reference_frame: Frame): + print('match: ', frame.keypoints.shape, reference_frame.keypoints.shape) + matches = self.matcher({ + 'descriptors0': torch.from_numpy(frame.descriptors)[None].cuda().float(), + 'keypoints0': torch.from_numpy(frame.keypoints[:, :2])[None].cuda().float(), + 'scores0': torch.from_numpy(frame.keypoints[:, 2])[None].cuda().float(), + 'image_shape0': (1, 3, frame.image_size[0], frame.image_size[1]), + + # 'descriptors0': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(), + # 'keypoints0': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(), + # 'scores0': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(), + # 'image_shape0': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]), + + 'descriptors1': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(), + 'keypoints1': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(), + 'scores1': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(), + 'image_shape1': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]), + + })['matches0'][0].cpu().numpy() + + ids1 = np.arange(matches.shape[0]) + ids2 = matches + ids1 = ids1[matches >= 0] + ids2 = ids2[matches >= 0] + + mask_p3ds = reference_frame.points3d_mask[ids2] + ids1 = ids1[mask_p3ds] + ids2 = ids2[mask_p3ds] + + return ids1, ids2 diff --git a/third_party/pram/localization/triangulation.py b/third_party/pram/localization/triangulation.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b885ec4be9c328353af9c0b0aaf136d694556a --- /dev/null +++ b/third_party/pram/localization/triangulation.py @@ -0,0 +1,317 @@ +# code is from hloc https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/triangulation.py +import argparse +import contextlib +import io +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pycolmap +from tqdm import tqdm + +from colmap_utils.database import COLMAPDatabase +from colmap_utils.geometry import compute_epipolar_errors +from colmap_utils.io import get_keypoints, get_matches +from colmap_utils.parsers import parse_retrieval +import logging + + +class OutputCapture: + def __init__(self, verbose: bool): + self.verbose = verbose + + def __enter__(self): + if not self.verbose: + self.capture = contextlib.redirect_stdout(io.StringIO()) + self.out = self.capture.__enter__() + + def __exit__(self, exc_type, *args): + if not self.verbose: + self.capture.__exit__(exc_type, *args) + if exc_type is not None: + # logger.error("Failed with output:\n%s", self.out.getvalue()) + logging.error("Failed with output:\n%s", self.out.getvalue()) + sys.stdout.flush() + + +def create_db_from_model( + reconstruction: pycolmap.Reconstruction, database_path: Path +) -> Dict[str, int]: + if database_path.exists(): + # logger.warning("The database already exists, deleting it.") + logging.warning("The database already exists, deleting it.") + database_path.unlink() + + db = COLMAPDatabase.connect(database_path) + db.create_tables() + + for i, camera in reconstruction.cameras.items(): + db.add_camera( + camera.model.value, + camera.width, + camera.height, + camera.params, + camera_id=i, + prior_focal_length=True, + ) + + for i, image in reconstruction.images.items(): + db.add_image(image.name, image.camera_id, image_id=i) + + db.commit() + db.close() + return {image.name: i for i, image in reconstruction.images.items()} + + +def import_features( + image_ids: Dict[str, int], database_path: Path, features_path: Path +): + # logger.info("Importing features into the database...") + logging.info("Importing features into the database...") + db = COLMAPDatabase.connect(database_path) + + for image_name, image_id in tqdm(image_ids.items()): + keypoints = get_keypoints(features_path, image_name) + keypoints += 0.5 # COLMAP origin + db.add_keypoints(image_id, keypoints) + + db.commit() + db.close() + + +def import_matches( + image_ids: Dict[str, int], + database_path: Path, + pairs_path: Path, + matches_path: Path, + min_match_score: Optional[float] = None, + skip_geometric_verification: bool = False, +): + # logger.info("Importing matches into the database...") + logging.info("Importing matches into the database...") + + with open(str(pairs_path), "r") as f: + pairs = [p.split() for p in f.readlines()] + + db = COLMAPDatabase.connect(database_path) + + matched = set() + for name0, name1 in tqdm(pairs): + id0, id1 = image_ids[name0], image_ids[name1] + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matches, scores = get_matches(matches_path, name0, name1) + if min_match_score: + matches = matches[scores > min_match_score] + db.add_matches(id0, id1, matches) + matched |= {(id0, id1), (id1, id0)} + + if skip_geometric_verification: + db.add_two_view_geometry(id0, id1, matches) + + db.commit() + db.close() + + +def estimation_and_geometric_verification( + database_path: Path, pairs_path: Path, verbose: bool = False +): + # logger.info("Performing geometric verification of the matches...") + logging.info("Performing geometric verification of the matches...") + with OutputCapture(verbose): + with pycolmap.ostream(): + pycolmap.verify_matches( + database_path, + pairs_path, + options=dict(ransac=dict(max_num_trials=20000, min_inlier_ratio=0.1)), + ) + + +def geometric_verification( + image_ids: Dict[str, int], + reference: pycolmap.Reconstruction, + database_path: Path, + features_path: Path, + pairs_path: Path, + matches_path: Path, + max_error: float = 4.0, +): + # logger.info("Performing geometric verification of the matches...") + logging.info("Performing geometric verification of the matches...") + + pairs = parse_retrieval(pairs_path) + db = COLMAPDatabase.connect(database_path) + + inlier_ratios = [] + matched = set() + for name0 in tqdm(pairs): + id0 = image_ids[name0] + image0 = reference.images[id0] + cam0 = reference.cameras[image0.camera_id] + kps0, noise0 = get_keypoints(features_path, name0, return_uncertainty=True) + noise0 = 1.0 if noise0 is None else noise0 + if len(kps0) > 0: + kps0 = np.stack(cam0.cam_from_img(kps0)) + else: + kps0 = np.zeros((0, 2)) + + for name1 in pairs[name0]: + id1 = image_ids[name1] + image1 = reference.images[id1] + cam1 = reference.cameras[image1.camera_id] + kps1, noise1 = get_keypoints(features_path, name1, return_uncertainty=True) + noise1 = 1.0 if noise1 is None else noise1 + if len(kps1) > 0: + kps1 = np.stack(cam1.cam_from_img(kps1)) + else: + kps1 = np.zeros((0, 2)) + + matches = get_matches(matches_path, name0, name1)[0] + + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matched |= {(id0, id1), (id1, id0)} + + if matches.shape[0] == 0: + db.add_two_view_geometry(id0, id1, matches) + continue + + cam1_from_cam0 = image1.cam_from_world * image0.cam_from_world.inverse() + errors0, errors1 = compute_epipolar_errors( + cam1_from_cam0, kps0[matches[:, 0]], kps1[matches[:, 1]] + ) + valid_matches = np.logical_and( + errors0 <= cam0.cam_from_img_threshold(noise0 * max_error), + errors1 <= cam1.cam_from_img_threshold(noise1 * max_error), + ) + # TODO: We could also add E to the database, but we need + # to reverse the transformations if id0 > id1 in utils/database.py. + db.add_two_view_geometry(id0, id1, matches[valid_matches, :]) + inlier_ratios.append(np.mean(valid_matches)) + # logger.info( + logging.info( + "mean/med/min/max valid matches %.2f/%.2f/%.2f/%.2f%%.", + np.mean(inlier_ratios) * 100, + np.median(inlier_ratios) * 100, + np.min(inlier_ratios) * 100, + np.max(inlier_ratios) * 100, + ) + + db.commit() + db.close() + + +def run_triangulation( + model_path: Path, + database_path: Path, + image_dir: Path, + reference_model: pycolmap.Reconstruction, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + model_path.mkdir(parents=True, exist_ok=True) + # logger.info("Running 3D triangulation...") + logging.info("Running 3D triangulation...") + if options is None: + options = {} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstruction = pycolmap.triangulate_points( + reference_model, database_path, image_dir, model_path, options=options + ) + return reconstruction + + +def main( + sfm_dir: Path, + reference_sfm_model: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + skip_geometric_verification: bool = False, + estimate_two_view_geometries: bool = False, + min_match_score: Optional[float] = None, + verbose: bool = False, + mapper_options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + assert reference_sfm_model.exists(), reference_sfm_model + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / "database.db" + reference = pycolmap.Reconstruction(reference_sfm_model) + + image_ids = create_db_from_model(reference, database) + import_features(image_ids, database, features) + import_matches( + image_ids, + database, + pairs, + matches, + min_match_score, + skip_geometric_verification, + ) + if not skip_geometric_verification: + if estimate_two_view_geometries: + estimation_and_geometric_verification(database, pairs, verbose) + else: + geometric_verification( + image_ids, reference, database, features, pairs, matches + ) + reconstruction = run_triangulation( + sfm_dir, database, image_dir, reference, verbose, mapper_options + ) + # logger.info( + logging.info( + "Finished the triangulation with statistics:\n%s", reconstruction.summary() + ) + stats = reconstruction.summary() + with open(sfm_dir / 'statics.txt', 'w') as f: + f.write(stats + '\n') + + # logging.info(f'Statistics:\n{pprint.pformat(stats)}') + return reconstruction + + +def parse_option_args(args: List[str], default_options) -> Dict[str, Any]: + options = {} + for arg in args: + idx = arg.find("=") + if idx == -1: + raise ValueError("Options format: key1=value1 key2=value2 etc.") + key, value = arg[:idx], arg[idx + 1:] + if not hasattr(default_options, key): + raise ValueError( + f'Unknown option "{key}", allowed options and default values' + f" for {default_options.summary()}" + ) + value = eval(value) + target_type = type(getattr(default_options, key)) + if not isinstance(value, target_type): + raise ValueError( + f'Incorrect type for option "{key}":' f" {type(value)} vs {target_type}" + ) + options[key] = value + return options + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sfm_dir", type=Path, required=True) + parser.add_argument("--reference_sfm_model", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + + parser.add_argument("--skip_geometric_verification", action="store_true") + parser.add_argument("--min_match_score", type=float) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args().__dict__ + + main(**args) diff --git a/third_party/pram/localization/utils.py b/third_party/pram/localization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5861afceba6bed7518921145505b01caf66954 --- /dev/null +++ b/third_party/pram/localization/utils.py @@ -0,0 +1,83 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> utils +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 15:27 +==================================================''' +import numpy as np +from colmap_utils.read_write_model import qvec2rotmat + + +def read_query_info(query_fn: str, name_prefix='') -> dict: + results = {} + with open(query_fn, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + name, camera_model, width, height = l[:4] + params = np.array(l[4:], float) + info = (camera_model, int(width), int(height), params) + results[name_prefix + name] = info + print('Load {} query images'.format(len(results.keys()))) + return results + + +def quaternion_angular_error(q1, q2): + """ + angular error between two quaternions + :param q1: (4, ) + :param q2: (4, ) + :return: + """ + d = abs(np.dot(q1, q2)) + d = min(1.0, max(-1.0, d)) + theta = 2 * np.arccos(d) * 180 / np.pi + return theta + + +def compute_pose_error(pred_qcw, pred_tcw, gt_qcw, gt_tcw): + pred_Rcw = qvec2rotmat(qvec=pred_qcw) + pred_tcw = np.array(pred_tcw, float).reshape(3, 1) + pred_twc = -pred_Rcw.transpose() @ pred_tcw + + gt_Rcw = qvec2rotmat(gt_qcw) + gt_tcw = np.array(gt_tcw, float).reshape(3, 1) + gt_twc = -gt_Rcw.transpose() @ gt_tcw + + t_error_xyz = pred_twc - gt_twc + t_error = np.sqrt(np.sum(t_error_xyz ** 2)) + + q_error = quaternion_angular_error(q1=pred_qcw, q2=gt_qcw) + + return q_error, t_error + + +def read_retrieval_results(path): + output = {} + with open(path, "r") as f: + lines = f.readlines() + for p in lines: + p = p.strip("\n").split(" ") + + if p[1] == "no_match": + continue + if p[0] in output.keys(): + output[p[0]].append(p[1]) + else: + output[p[0]] = [p[1]] + return output + + +def read_gt_pose(path): + gt_poses = {} + with open(path, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split(' ') + gt_poses[l[0]] = { + 'qvec': np.array([float(v) for v in l[1:5]], float), + 'tvec': np.array([float(v) for v in l[5:]], float), + } + + return gt_poses diff --git a/third_party/pram/localization/viewer.py b/third_party/pram/localization/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..33899f60ab362e240b7b0e6736a157a7aa041d31 --- /dev/null +++ b/third_party/pram/localization/viewer.py @@ -0,0 +1,548 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> viewer +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 05/03/2024 16:50 +==================================================''' +import cv2 +import numpy as np +import pypangolin as pangolin +from OpenGL.GL import * +import time +import threading +from colmap_utils.read_write_model import qvec2rotmat +from tools.common import resize_image_with_padding +from localization.frame import Frame + + +class Viewer: + default_config = { + 'image_size_indoor': 0.1, + 'image_line_width_indoor': 1, + + 'image_size_outdoor': 1, + 'image_line_width_outdoor': 3, + + 'point_size_indoor': 1, + 'point_size_outdoor': 1, + + 'image_width': 640, + 'image_height': 480, + + 'viewpoint_x': 0, + 'viewpoint_y': -1, + 'viewpoint_z': -5, + 'viewpoint_F': 512, + + 'scene': 'indoor', + } + + def __init__(self, locMap, seg_color, config={}): + self.config = {**self.default_config, **config} + self.viewpoint_x = self.config['viewpoint_x'] + self.viewpoint_y = self.config['viewpoint_y'] + self.viewpoint_z = self.config['viewpoint_z'] + self.viewpoint_F = self.config['viewpoint_F'] + self.img_width = self.config['image_width'] + self.img_height = self.config['image_height'] + + if self.config['scene'] == 'indoor': + self.image_size = self.config['image_size_indoor'] + self.image_line_width = self.config['image_line_width_indoor'] + self.point_size = self.config['point_size_indoor'] + + else: + self.image_size = self.config['image_size_outdoor'] + self.image_line_width = self.config['image_line_width_outdoor'] + self.point_size = self.config['point_size_outdoor'] + self.viewpoint_z = -150 + + self.locMap = locMap + self.seg_colors = seg_color + + # current camera pose + self.frame = None + self.Tcw = np.eye(4, dtype=float) + self.Twc = np.linalg.inv(self.Tcw) + self.gt_Tcw = None + self.gt_Twc = None + + self.scene = None + self.current_vrf_id = None + self.reference_frame_ids = None + self.subMap = None + self.seg_point_clouds = None + self.point_clouds = None + + self.start_seg_id = 1 + self.stop = False + + self.refinement = False + self.tracking = False + + # time + self.time_feat = np.NAN + self.time_rec = np.NAN + self.time_loc = np.NAN + self.time_ref = np.NAN + + # image + self.image_rec = None + + def draw_3d_points_white(self): + if self.point_clouds is None: + return + + point_size = self.point_size * 0.5 + glColor4f(0.9, 0.95, 1.0, 0.6) + glPointSize(point_size) + pangolin.glDrawPoints(self.point_clouds) + + def draw_seg_3d_points(self): + if self.seg_point_clouds is None: + return + for sid in self.seg_point_clouds.keys(): + xyzs = self.seg_point_clouds[sid] + point_size = self.point_size * 0.5 + bgr = self.seg_colors[sid + self.start_seg_id + 1] + glColor3f(bgr[2] / 255, bgr[1] / 255, bgr[0] / 255) + glPointSize(point_size) + pangolin.glDrawPoints(xyzs) + + def draw_ref_3d_points(self, use_seg_color=False): + if self.reference_frame_ids is None: + return + + ref_point3D_ids = [] + for fid in self.reference_frame_ids: + pids = self.subMap.reference_frames[fid].point3D_ids + ref_point3D_ids.extend(list(pids)) + + ref_point3D_ids = np.unique(ref_point3D_ids).tolist() + + point_size = self.point_size * 5 + glPointSize(point_size) + glBegin(GL_POINTS) + + for pid in ref_point3D_ids: + if pid not in self.subMap.point3Ds.keys(): + continue + xyz = self.subMap.point3Ds[pid].xyz + rgb = self.subMap.point3Ds[pid].rgb + sid = self.subMap.point3Ds[pid].seg_id + if use_seg_color: + bgr = self.seg_colors[sid + self.start_seg_id + 1] + glColor3f(bgr[2] / 255, bgr[1] / 255, bgr[0] / 255) + else: + glColor3f(rgb[0] / 255, rgb[1] / 255, rgb[2] / 255) + + glVertex3f(xyz[0], xyz[1], xyz[2]) + + glEnd() + + def draw_vrf_frames(self): + if self.subMap is None: + return + w = self.image_size * 1.0 + image_line_width = self.image_line_width * 1.0 + h = w * 0.75 + z = w * 0.6 + for sid in self.subMap.seg_ref_frame_ids.keys(): + frame_id = self.subMap.seg_ref_frame_ids[sid][0] + qvec = self.subMap.reference_frames[frame_id].qvec + tcw = self.subMap.reference_frames[frame_id].tvec + + Rcw = qvec2rotmat(qvec) + + twc = -Rcw.T @ tcw + Rwc = Rcw.T + + Twc = np.column_stack((Rwc, twc)) + Twc = np.vstack((Twc, (0, 0, 0, 1))) + + glPushMatrix() + + glMultMatrixf(Twc.T) + + glLineWidth(image_line_width) + glColor3f(1, 0, 0) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def draw_current_vrf_frame(self): + if self.current_vrf_id is None: + return + qvec = self.subMap.reference_frames[self.current_vrf_id].qvec + tcw = self.subMap.reference_frames[self.current_vrf_id].tvec + Rcw = qvec2rotmat(qvec) + twc = -Rcw.T @ tcw + Rwc = Rcw.T + Twc = np.column_stack((Rwc, twc)) + Twc = np.vstack((Twc, (0, 0, 0, 1))) + + camera_line_width = self.image_line_width * 2 + w = self.image_size * 2 + h = w * 0.75 + z = w * 0.6 + + glPushMatrix() + + glMultMatrixf(Twc.T) # note the .T + + glLineWidth(camera_line_width) + glColor3f(1, 0, 0) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def draw_current_frame(self, Tcw, color=(0, 1.0, 0)): + Twc = np.linalg.inv(Tcw) + + camera_line_width = self.image_line_width * 2 + w = self.image_size * 2 + h = w * 0.75 + z = w * 0.6 + + glPushMatrix() + + glMultMatrixf(Twc.T) # not the .T + + glLineWidth(camera_line_width) + glColor3f(color[0], color[1], color[2]) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def draw_ref_frames(self): + if self.reference_frame_ids is None: + return + w = self.image_size * 1.5 + image_line_width = self.image_line_width * 1.5 + h = w * 0.75 + z = w * 0.6 + for fid in self.reference_frame_ids: + qvec = self.subMap.reference_frames[fid].qvec + tcw = self.subMap.reference_frames[fid].tvec + Rcw = qvec2rotmat(qvec) + + twc = -Rcw.T @ tcw + Rwc = Rcw.T + + Twc = np.column_stack((Rwc, twc)) + Twc = np.vstack((Twc, (0, 0, 0, 1))) + + glPushMatrix() + + glMultMatrixf(Twc.T) + + glLineWidth(image_line_width) + glColor3f(100 / 255, 140 / 255, 17 / 255) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def terminate(self): + lock = threading.Lock() + lock.acquire() + self.stop = True + lock.release() + + def update_point_clouds(self): + # for fast drawing + seg_point_clouds = {} + point_clouds = [] + for pid in self.subMap.point3Ds.keys(): + sid = self.subMap.point3Ds[pid].seg_id + xyz = self.subMap.point3Ds[pid].xyz + if sid in seg_point_clouds.keys(): + seg_point_clouds[sid].append(xyz.reshape(3, 1)) + else: + seg_point_clouds[sid] = [xyz.reshape(3, 1)] + + point_clouds.append(xyz.reshape(3, 1)) + + self.seg_point_clouds = seg_point_clouds + self.point_clouds = point_clouds + + def update(self, curr_frame: Frame): + lock = threading.Lock() + lock.acquire() + + # self.frame = curr_frame + self.current_vrf_id = curr_frame.reference_frame_id + self.reference_frame_ids = [self.current_vrf_id] + + # self.reference_frame_ids = curr_frame.refinement_reference_frame_ids + # if self.reference_frame_ids is None: + # self.reference_frame_ids = [self.current_vrf_id] + self.subMap = self.locMap.sub_maps[curr_frame.matched_scene_name] + self.start_seg_id = self.locMap.scene_name_start_sid[curr_frame.matched_scene_name] + + if self.scene is None or self.scene != curr_frame.matched_scene_name: + self.scene = curr_frame.matched_scene_name + self.update_point_clouds() + + if curr_frame.qvec is not None: + Rcw = qvec2rotmat(curr_frame.qvec) + Tcw = np.column_stack((Rcw, curr_frame.tvec)) + self.Tcw = np.vstack((Tcw, (0, 0, 0, 1))) + Rwc = Rcw.T + twc = -Rcw.T @ curr_frame.tvec + Twc = np.column_stack((Rwc, twc)) + self.Twc = np.vstack((Twc, (0, 0, 0, 1))) + + if curr_frame.gt_qvec is not None: + gt_Rcw = qvec2rotmat(curr_frame.gt_qvec) + gt_Tcw = np.column_stack((gt_Rcw, curr_frame.gt_tvec)) + self.gt_Tcw = np.vstack((gt_Tcw, (0, 0, 0, 1))) + gt_Rwc = gt_Rcw.T + gt_twc = -gt_Rcw.T @ curr_frame.gt_tvec + gt_Twc = np.column_stack((gt_Rwc, gt_twc)) + self.gt_Twc = np.vstack((gt_Twc, (0, 0, 0, 1))) + else: + self.gt_Tcw = None + self.gt_Twc = None + + # update time + self.time_feat = curr_frame.time_feat + self.time_rec = curr_frame.time_rec + self.time_loc = curr_frame.time_loc + self.time_ref = curr_frame.time_ref + + # update image + image_rec_inlier = np.hstack([curr_frame.image_rec, curr_frame.image_inlier]) + image_rec_inlier = resize_image_with_padding(image=image_rec_inlier, nw=self.img_width * 2, nh=self.img_height) + image_matching = resize_image_with_padding(image=curr_frame.image_matching, nw=self.img_width * 2, + nh=self.img_height) + image_rec_matching_inliers = resize_image_with_padding(image=np.vstack([image_rec_inlier, image_matching]), + nw=self.img_width * 2, nh=self.img_height * 2) + + self.image_rec = cv2.cvtColor(image_rec_matching_inliers, cv2.COLOR_BGR2RGB) + lock.release() + + def run(self): + pangolin.CreateWindowAndBind("Map reviewer", 640, 480) + glEnable(GL_DEPTH_TEST) + glEnable(GL_BLEND) + glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + + pangolin.CreatePanel("menu").SetBounds(pangolin.Attach(0), + pangolin.Attach(1), + pangolin.Attach(0), + # pangolin.Attach.Pix(-175), + pangolin.Attach.Pix(175), + # pangolin.Attach(1) + ) + + menu = pangolin.Var("menu") + menu.Tracking = (False, pangolin.VarMeta(toggle=True)) + menu.FollowCamera = (True, pangolin.VarMeta(toggle=True)) + menu.ShowPoints = (True, pangolin.VarMeta(toggle=True)) + menu.ShowSegs = (False, pangolin.VarMeta(toggle=True)) + menu.ShowRefSegs = (True, pangolin.VarMeta(toggle=True)) + menu.ShowRefPoints = (False, pangolin.VarMeta(toggle=True)) + menu.ShowVRFFrame = (True, pangolin.VarMeta(toggle=True)) + menu.ShowAllVRFs = (False, pangolin.VarMeta(toggle=True)) + menu.ShowRefFrames = (False, pangolin.VarMeta(toggle=True)) + + menu.Refinement = (self.refinement, pangolin.VarMeta(toggle=True)) + + menu.featTime = 'NaN' + menu.recTime = 'NaN' + menu.locTime = 'NaN' + menu.refTime = 'NaN' + menu.totalTime = 'NaN' + + pm = pangolin.ProjectionMatrix(640, 480, self.viewpoint_F, self.viewpoint_F, 320, 240, 0.1, + 10000) + + # /camera position,viewpoint position,axis direction + mv = pangolin.ModelViewLookAt(self.viewpoint_x, + self.viewpoint_y, + self.viewpoint_z, + 0, 0, 0, + # 0.0, -1.0, 0.0, + pangolin.AxisZ, + ) + + s_cam = pangolin.OpenGlRenderState(pm, mv) + # Attach bottom, Attach top, Attach left, Attach right, + scale = 0.42 + d_img_rec = pangolin.Display('image_rec').SetBounds(pangolin.Attach(1 - scale), + pangolin.Attach(1), + pangolin.Attach( + 1 - 0.3), + pangolin.Attach(1), + self.img_width / self.img_height + ) # .SetLock(0, 1) + + handler = pangolin.Handler3D(s_cam) + + d_cam = pangolin.Display('3D').SetBounds( + pangolin.Attach(0), # bottom + pangolin.Attach(1), # top + pangolin.Attach.Pix(175), # left + # pangolin.Attach.Pix(0), # left + pangolin.Attach(1), # right + -640 / 480, # aspect + ).SetHandler(handler) + + d_img_rec_texture = pangolin.GlTexture(self.img_width * 2, self.img_height * 2, GL_RGB, False, 0, GL_RGB, + GL_UNSIGNED_BYTE) + while not pangolin.ShouldQuit() and not self.stop: + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + + # glClearColor(1.0, 1.0, 1.0, 1.0) + glClearColor(0.0, 0.0, 0.0, 1.0) + + d_cam.Activate(s_cam) + if menu.FollowCamera: + s_cam.Follow(pangolin.OpenGlMatrix(self.Twc.astype(np.float32)), follow=True) + + # pangolin.glDrawColouredCube() + if menu.ShowPoints: + self.draw_3d_points_white() + + if menu.ShowRefPoints: + self.draw_ref_3d_points(use_seg_color=False) + if menu.ShowRefSegs: + self.draw_ref_3d_points(use_seg_color=True) + + if menu.ShowSegs: + self.draw_seg_3d_points() + + if menu.ShowAllVRFs: + self.draw_vrf_frames() + + if menu.ShowRefFrames: + self.draw_ref_frames() + + if menu.ShowVRFFrame: + self.draw_current_vrf_frame() + + if menu.Refinement: + self.refinement = True + else: + self.refinement = False + + if menu.Tracking: + self.tracking = True + else: + self.tracking = False + + self.draw_current_frame(Tcw=self.Tcw) + + if self.gt_Tcw is not None: # draw gt pose with color (0, 0, 1.0) + self.draw_current_frame(Tcw=self.gt_Tcw, color=(0., 0., 1.0)) + + d_img_rec.Activate() + glColor4f(1, 1, 1, 1) + + if self.image_rec is not None: + d_img_rec_texture.Upload(self.image_rec, GL_RGB, GL_UNSIGNED_BYTE) + d_img_rec_texture.RenderToViewportFlipY() + + time_total = 0 + if self.time_feat != np.NAN: + menu.featTime = '{:.2f}s'.format(self.time_feat) + time_total = time_total + self.time_feat + if self.time_rec != np.NAN: + menu.recTime = '{:.2f}s'.format(self.time_rec) + time_total = time_total + self.time_rec + if self.time_loc != np.NAN: + menu.locTime = '{:.2f}s'.format(self.time_loc) + time_total = time_total + self.time_loc + if self.time_ref != np.NAN: + menu.refTime = '{:.2f}s'.format(self.time_ref) + time_total = time_total + self.time_ref + menu.totalTime = '{:.2f}s'.format(time_total) + + time.sleep(50 / 1000) + + pangolin.FinishFrame() diff --git a/third_party/pram/main.py b/third_party/pram/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0f32b1e9087dcf7edd152911cf09bef93f0555d5 --- /dev/null +++ b/third_party/pram/main.py @@ -0,0 +1,228 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> train +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:26 +==================================================''' +import argparse +import os +import os.path as osp +import torch +import torchvision.transforms.transforms as tvt +import yaml +import torch.utils.data as Data +import torch.multiprocessing as mp +import torch.distributed as dist + +from nets.segnet import SegNet +from nets.segnetvit import SegNetViT +from dataset.utils import collect_batch +from dataset.get_dataset import compose_datasets +from tools.common import torch_set_gpu +from trainer import Trainer + +from nets.sfd2 import ResNet4x, DescriptorCompressor +from nets.superpoint import SuperPoint + +torch.set_grad_enabled(True) + +parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--config', type=str, required=True, help='config of specifications') +parser.add_argument('--landmark_path', type=str, default=None, help='path of landmarks') + + +def load_feat_network(config): + if config['feature'] == 'spp': + net = SuperPoint(config={ + 'weight_path': '/scratches/flyer_2/fx221/Research/Code/third_weights/superpoint_v1.pth', + }).eval() + elif config['feature'] == 'resnet4x': + net = ResNet4x(inputdim=3, outdim=128) + net.load_state_dict( + torch.load('weights/sfd2_20230511_210205_resnet4x.79.pth', map_location='cpu')['state_dict'], + strict=True) + net.eval() + else: + print('Please input correct feature {:s}'.format(config['feature'])) + net = None + + if config['feat_dim'] != 128: + desc_compressor = DescriptorCompressor(inputdim=128, outdim=config['feat_dim']).eval() + if config['feat_dim'] == 64: + desc_compressor.load_state_dict( + torch.load('weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O64.pth', + map_location='cpu'), + strict=True) + elif config['feat_dim'] == 32: + desc_compressor.load_state_dict( + torch.load('weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O32.pth', + map_location='cpu'), + strict=True) + else: + desc_compressor = None + else: + desc_compressor = None + return net, desc_compressor + + +def get_model(config): + desc_dim = 256 if config['feature'] == 'spp' else 128 + if config['use_mid_feature']: + desc_dim = 256 + model_config = { + 'network': { + 'descriptor_dim': desc_dim, + 'n_layers': config['layers'], + 'ac_fn': config['ac_fn'], + 'norm_fn': config['norm_fn'], + 'n_class': config['n_class'], + 'output_dim': config['output_dim'], + 'with_cls': config['with_cls'], + 'with_sc': config['with_sc'], + 'with_score': config['with_score'], + } + } + + if config['network'] == 'segnet': + model = SegNet(model_config.get('network', {})) + config['with_cls'] = False + elif config['network'] == 'segnetvit': + model = SegNetViT(model_config.get('network', {})) + config['with_cls'] = False + else: + raise 'ERROR! {:s} model does not exist'.format(config['network']) + + if config['local_rank'] == 0: + if config['weight_path'] is not None: + state_dict = torch.load(osp.join(config['save_path'], config['weight_path']), map_location='cpu')['model'] + model.load_state_dict(state_dict, strict=True) + print('Load weight from {:s}'.format(osp.join(config['save_path'], config['weight_path']))) + + if config['resume_path'] is not None and not config['eval']: # only for training + model.load_state_dict( + torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')['model'], + strict=True) + print('Load resume weight from {:s}'.format(osp.join(config['save_path'], config['resume_path']))) + + return model + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms): + print('In train_DDP..., rank: ', rank) + torch.cuda.set_device(rank) + + device = torch.device(f'cuda:{rank}') + if feat_model is not None: + feat_model.to(device) + model.to(device) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + setup(rank=rank, world_size=world_size) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, + shuffle=True, + rank=rank, + num_replicas=world_size, + drop_last=True, # important? + ) + train_loader = torch.utils.data.DataLoader(train_set, + batch_size=config['batch_size'] // world_size, + num_workers=config['workers'] // world_size, + # num_workers=1, + pin_memory=True, + # persistent_workers=True, + shuffle=False, # must be False + drop_last=True, + collate_fn=collect_batch, + prefetch_factor=4, + sampler=train_sampler) + config['local_rank'] = rank + + if rank == 0: + test_set = test_set + else: + test_set = None + + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_set, + config=config, img_transforms=img_transforms) + trainer.train() + + +if __name__ == '__main__': + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = yaml.load(f, Loader=yaml.Loader) + torch_set_gpu(gpus=config['gpu']) + if config['local_rank'] == 0: + print(config) + + if config['feature'] == 'spp': + img_transforms = None + else: + img_transforms = [] + img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + img_transforms = tvt.Compose(img_transforms) + feat_model, desc_compressor = load_feat_network(config=config) + + dataset = config['dataset'] + if config['eval'] or config['loc']: + if not config['online']: + from localization.loc_by_rec_eval import loc_by_rec_eval + + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1) + config['n_class'] = test_set.n_class + + model = get_model(config=config) + loc_by_rec_eval(rec_model=model.cuda().eval(), + loader=test_set, + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) + else: + from localization.loc_by_rec_online import loc_by_rec_online + + model = get_model(config=config) + loc_by_rec_online(rec_model=model.cuda().eval(), + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) + exit(0) + + train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None) + if config['do_eval']: + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None) + else: + test_set = None + config['n_class'] = train_set.n_class + model = get_model(config=config) + + if not config['with_dist'] or len(config['gpu']) == 1: + config['with_dist'] = False + model = model.cuda() + train_loader = Data.DataLoader(dataset=train_set, + shuffle=True, + batch_size=config['batch_size'], + drop_last=True, + collate_fn=collect_batch, + num_workers=config['workers']) + if test_set is not None: + test_loader = Data.DataLoader(dataset=test_set, + shuffle=False, + batch_size=1, + drop_last=False, + collate_fn=collect_batch, + num_workers=4) + else: + test_loader = None + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader, + config=config, img_transforms=img_transforms) + trainer.train() + else: + mp.spawn(train_DDP, nprocs=len(config['gpu']), + args=(len(config['gpu']), model, config, train_set, test_set, feat_model, img_transforms), + join=True) diff --git a/third_party/pram/nets/adagml.py b/third_party/pram/nets/adagml.py new file mode 100644 index 0000000000000000000000000000000000000000..c6980334a8980a105dc91d4586b3a342fb4e648e --- /dev/null +++ b/third_party/pram/nets/adagml.py @@ -0,0 +1,536 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> adagml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 11/02/2024 14:29 +==================================================''' +import torch +from torch import nn +import torch.nn.functional as F +from typing import Callable +import time +import numpy as np + +torch.backends.cudnn.deterministic = True + +eps = 1e-8 + + +def arange_like(x, dim: int): + return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 + + +def dual_softmax(M, dustbin): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1) + return torch.exp(score) + + +def sinkhorn(M, r, c, iteration): + p = torch.softmax(M, dim=-1) + u = torch.ones_like(r) + v = torch.ones_like(c) + for _ in range(iteration): + u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) + v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) + p = p * u.unsqueeze(-1) * v.unsqueeze(-2) + return p + + +def sink_algorithm(M, dustbin, iteration): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + r = torch.ones([M.shape[0], M.shape[1] - 1], device='cuda') + r = torch.cat([r, torch.ones([M.shape[0], 1], device='cuda') * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1], device='cuda') + c = torch.cat([c, torch.ones([M.shape[0], 1], device='cuda') * M.shape[2]], dim=-1) + p = sinkhorn(M, r, c, iteration) + return p + + +def normalize_keypoints(kpts, image_shape): + """ Normalize keypoints locations based on image image_shape""" + _, _, height, width = image_shape + one = kpts.new_tensor(1) + size = torch.stack([one * width, one * height])[None] + center = size / 2 + scaling = size.max(1, keepdim=True).values * 0.7 + return (kpts - center[:, None, :]) / scaling[:, None, :] + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb( + freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, + gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ encode position vector """ + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(3, 32), + nn.LayerNorm(32, elementwise_affine=True), + nn.GELU(), + nn.Linear(32, 64), + nn.LayerNorm(64, elementwise_affine=True), + nn.GELU(), + nn.Linear(64, 128), + nn.LayerNorm(128, elementwise_affine=True), + nn.GELU(), + nn.Linear(128, 256), + ) + + def forward(self, kpts, scores): + inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1] + return self.encoder(torch.cat(inputs, dim=-1)) + + +class PoolingLayer(nn.Module): + def __init__(self, hidden_dim: int, score_dim: int = 2): + super().__init__() + + self.score_enc = nn.Sequential( + nn.Linear(score_dim, hidden_dim), + nn.LayerNorm(hidden_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), + ) + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.predict = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.LayerNorm(hidden_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(hidden_dim, 1), + ) + + def forward(self, x, score): + score_ = self.score_enc(score) + x_ = self.proj(x) + confidence = self.predict(torch.cat([x_, score_], -1)) + confidence = torch.sigmoid(confidence) + + return confidence + + +class Attention(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + s = q.shape[-1] ** -0.5 + attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1) + return torch.einsum('...ij,...jd->...id', attn, v), torch.mean(torch.mean(attn, dim=1), dim=1) + + +class SelfMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + + assert feat_dim % num_heads == 0 + self.head_dim = feat_dim // num_heads + self.qkv = nn.Linear(feat_dim, hidden_dim * 3) + self.attn = Attention() + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim) + ) + + def forward_(self, x, encoding=None): + qkv = self.qkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + if encoding is not None: + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + attn, attn_score = self.attn(q, k, v) + message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2)) + return x + self.mlp(torch.cat([x, message], -1)), attn_score + + def forward(self, x0, x1, encoding0=None, encoding1=None): + x0_, att_score00 = self.forward_(x=x0, encoding=encoding0) + x1_, att_score11 = self.forward_(x=x1, encoding=encoding1) + return x0_, x1_, att_score00, att_score11 + + +class CrossMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + assert hidden_dim % num_heads == 0 + dim_head = hidden_dim // num_heads + self.scale = dim_head ** -0.5 + self.to_qk = nn.Linear(feat_dim, hidden_dim) + self.to_v = nn.Linear(feat_dim, hidden_dim) + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim), + ) + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward(self, x0, x1): + qk0 = self.to_qk(x0) + qk1 = self.to_qk(x1) + v0 = self.to_v(x0) + v1 = self.to_v(x1) + + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1)) + + qk0, qk1 = qk0 * self.scale ** 0.5, qk1 * self.scale ** 0.5 + sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1) + m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0) + + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), + m0, m1) + m0, m1 = self.map_(self.proj, m0, m1) + x0 = x0 + self.mlp(torch.cat([x0, m0], -1)) + x1 = x1 + self.mlp(torch.cat([x1, m1], -1)) + return x0, x1, torch.mean(torch.mean(attn10, dim=1), dim=1), torch.mean(torch.mean(attn01, dim=1), dim=1) + + +class AdaGML(nn.Module): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': True, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + 'min_confidence': 0.9, + + 'classification_background_weight': 0.05, + 'pretrained': True, + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + self.n_layers = self.config['n_layers'] + self.first_layer_pooling = 0 + self.n_min_tokens = self.config['n_min_tokens'] + self.min_confidence = self.config['min_confidence'] + self.classification_background_weight = self.config['classification_background_weight'] + + self.with_sinkhorn = self.config['with_sinkhorn'] + self.match_threshold = self.config['match_threshold'] + self.sinkhorn_iterations = self.config['sinkhorn_iterations'] + + self.input_proj = nn.Linear(self.config['descriptor_dim'], self.config['hidden_dim']) + + self.self_attn = nn.ModuleList( + [SelfMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + self.cross_attn = nn.ModuleList( + [CrossMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + + head_dim = self.config['hidden_dim'] // 4 + self.poseenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim) + self.out_proj = nn.ModuleList( + [nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) for _ in range(self.n_layers)] + ) + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + + self.pooling = nn.ModuleList( + [PoolingLayer(score_dim=2, hidden_dim=self.config['hidden_dim']) for _ in range(self.n_layers)] + ) + # self.pretrained = config['pretrained'] + # if self.pretrained: + # bin_score.requires_grad = False + # for m in [self.input_proj, self.out_proj, self.poseenc, self.self_attn, self.cross_attn]: + # for p in m.parameters(): + # p.requires_grad = False + + def forward(self, data, mode=0): + if not self.training: + if mode == 0: + return self.produce_matches(data=data) + else: + return self.run(data=data) + return self.forward_train(data=data) + + def forward_train(self, data: dict, p=0.2, **kwargs): + pass + + def produce_matches(self, data: dict, p: float = 0.2, **kwargs): + desc0, desc1 = data['descriptors0'], data['descriptors1'] + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + scores0, scores1 = data['scores0'], data['scores1'] + + # Keypoint normalization. + if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys(): + norm_kpts0 = data['norm_keypoints0'] + norm_kpts1 = data['norm_keypoints1'] + elif 'image0' in data.keys() and 'image1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape) + norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape) + elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']) + norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1']) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + desc0 = desc0.detach() # [B, N, D] + desc1 = desc1.detach() + + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + enc0 = self.poseenc(norm_kpts0) + enc1 = self.poseenc(norm_kpts1) + + nI = self.config['n_layers'] + nB = desc0.shape[0] + m = desc0.shape[1] + n = desc1.shape[1] + dev = desc0.device + + ind0 = torch.arange(0, m, device=dev)[None] + ind1 = torch.arange(0, n, device=dev)[None] + + do_pooling = True + + for ni in range(nI): + desc0, desc1, att_score00, att_score11 = self.self_attn[ni](desc0, desc1, enc0, enc1) + desc0, desc1, att_score01, att_score10 = self.cross_attn[ni](desc0, desc1) + + att_score0 = torch.cat([att_score00.unsqueeze(-1), att_score01.unsqueeze(-1)], dim=-1) + att_score1 = torch.cat([att_score11.unsqueeze(-1), att_score10.unsqueeze(-1)], dim=-1) + + conf0 = self.pooling[ni](desc0, att_score0).squeeze(-1) + conf1 = self.pooling[ni](desc1, att_score1).squeeze(-1) + + if do_pooling and ni >= 1: + if desc0.shape[1] >= self.n_min_tokens: + mask0 = conf0 > self.confidence_threshold(layer_index=ni) + ind0 = ind0[mask0][None] + desc0 = desc0[mask0][None] + enc0 = enc0[:, :, mask0][:, None] + + if desc1.shape[1] >= self.n_min_tokens: + mask1 = conf1 > self.confidence_threshold(layer_index=ni) + ind1 = ind1[mask1][None] + desc1 = desc1[mask1][None] + enc1 = enc1[:, :, mask1][:, None] + + # print('pooling: ', ni, desc0.shape, desc1.shape) + # print('ni: {:d}: pooling: {:.4f}'.format(ni, time.time() - t_start)) + # t_start = time.time() + if self.check_if_stop(confidences0=conf0, confidences1=conf1, layer_index=ni, num_points=m + n): + # print('ni:{:d}: checking: {:.4f}'.format(ni, time.time() - t_start)) + break + + if ni == nI: ni = nI - 1 + d = desc0.shape[-1] + mdesc0 = self.out_proj[ni](desc0) / d ** .25 + mdesc1 = self.out_proj[ni](desc1) / d ** .25 + + dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1) + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + valid = indices0 > -1 + m_indices0 = torch.where(valid)[1] + m_indices1 = indices0[valid] + + mind0 = ind0[0, m_indices0] + mind1 = ind1[0, m_indices1] + + indices0_full = torch.full((nB, m), -1, device=dev, dtype=indices0.dtype) + indices0_full[:, mind0] = mind1 + + mscores0_full = torch.zeros((nB, m), device=dev) + mscores0_full[:, ind0] = mscores0 + + indices0 = indices0_full + mscores0 = mscores0_full + + output = { + 'matches0': indices0, # use -1 for invalid match + # 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + } + + return output + + def run(self, data, p=0.2): + desc0 = data['desc1'] + # print('desc0: ', torch.sum(desc0 ** 2, dim=-1)) + # desc0 = torch.nn.functional.normalize(desc0, dim=-1) + desc0 = desc0.detach() + + desc1 = data['desc2'] + # desc1 = torch.nn.functional.normalize(desc1, dim=-1) + desc1 = desc1.detach() + + kpts0 = data['x1'][:, :, :2] + kpts1 = data['x2'][:, :, :2] + # kpts0 = normalize_keypoints(kpts=kpts0, image_shape=data['image_shape1']) + # kpts1 = normalize_keypoints(kpts=kpts1, image_shape=data['image_shape2']) + scores0 = data['x1'][:, :, -1] + scores1 = data['x2'][:, :, -1] + + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + enc0 = self.poseenc(kpts0) + enc1 = self.poseenc(kpts1) + + nB = desc0.shape[0] + nI = self.n_layers + m, n = desc0.shape[1], desc1.shape[1] + dev = desc0.device + ind0 = torch.arange(0, m, device=dev)[None] + ind1 = torch.arange(0, n, device=dev)[None] + do_pooling = True + + for ni in range(nI): + desc0, desc1, att_score00, att_score11 = self.self_attn[ni](desc0, desc1, enc0, enc1) + desc0, desc1, att_score01, att_score10 = self.cross_attn[ni](desc0, desc1) + + att_score0 = torch.cat([att_score00.unsqueeze(-1), att_score01.unsqueeze(-1)], dim=-1) + att_score1 = torch.cat([att_score11.unsqueeze(-1), att_score10.unsqueeze(-1)], dim=-1) + + conf0 = self.pooling[ni](desc0, att_score0).squeeze(-1) + conf1 = self.pooling[ni](desc1, att_score1).squeeze(-1) + + if do_pooling and ni >= 1: + if desc0.shape[1] >= self.n_min_tokens: + mask0 = conf0 > self.confidence_threshold(layer_index=ni) + ind0 = ind0[mask0][None] + desc0 = desc0[mask0][None] + enc0 = enc0[:, :, mask0][:, None] + + if desc1.shape[1] >= self.n_min_tokens: + mask1 = conf1 > self.confidence_threshold(layer_index=ni) + ind1 = ind1[mask1][None] + desc1 = desc1[mask1][None] + enc1 = enc1[:, :, mask1][:, None] + if desc0.shape[1] <= 5 or desc1.shape[1] <= 5: + return { + 'index0': torch.zeros(size=(1,), device=desc0.device).long(), + 'index1': torch.zeros(size=(1,), device=desc1.device).long(), + } + + if self.check_if_stop(confidences0=conf0, confidences1=conf1, layer_index=ni, + num_points=m + n): + break + + if ni == nI: ni = -1 + d = desc0.shape[-1] + mdesc0 = self.out_proj[ni](desc0) / d ** .25 + mdesc1 = self.out_proj[ni](desc1) / d ** .25 + + dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1) + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + valid = indices0 > -1 + m_indices0 = torch.where(valid)[1] + m_indices1 = indices0[valid] + + mind0 = ind0[0, m_indices0] + mind1 = ind1[0, m_indices1] + + output = { + # 'p': score, + 'index0': mind0, + 'index1': mind1, + } + + return output + + def compute_score(self, dist, dustbin, iteration): + if self.with_sinkhorn: + score = sink_algorithm(M=dist, dustbin=dustbin, + iteration=iteration) # [nI * nB, N, M] + else: + score = dual_softmax(M=dist, dustbin=dustbin) + return score + + def compute_matches(self, scores, p=0.2): + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + indices0, indices1 = max0.indices, max1.indices + mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) + mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) + zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = torch.where(mutual0, max0.values, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + # valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > p) + valid1 = mutual1 & valid0.gather(1, indices1) + indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + + return indices0, indices1, mscores0, mscores1 + + def confidence_threshold(self, layer_index: int): + """scaled confidence threshold""" + # threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers) + threshold = 0.5 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers) + return np.clip(threshold, 0, 1) + + def check_if_stop(self, + confidences0: torch.Tensor, + confidences1: torch.Tensor, + layer_index: int, num_points: int) -> torch.Tensor: + """ evaluate stopping condition""" + confidences = torch.cat([confidences0, confidences1], -1) + threshold = self.confidence_threshold(layer_index) + pos = 1.0 - (confidences < threshold).float().sum() / num_points + # print('check_stop: ', pos) + return pos > 0.95 + + def stop_iteration(self, m_last, n_last, m_current, n_current, confidence=0.975): + prob = (m_current + n_current) / (m_last + n_last) + # print('prob: ', prob) + return prob > confidence diff --git a/third_party/pram/nets/gm.py b/third_party/pram/nets/gm.py new file mode 100644 index 0000000000000000000000000000000000000000..232a364ce60acb49cb6af26b72a881cbec18c1a9 --- /dev/null +++ b/third_party/pram/nets/gm.py @@ -0,0 +1,264 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> gm +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 10:47 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from nets.layers import KeypointEncoder, AttentionalPropagation +from nets.utils import normalize_keypoints, arange_like + +eps = 1e-8 + + +def dual_softmax(M, dustbin): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1) + return torch.exp(score) + + +def sinkhorn(M, r, c, iteration): + p = torch.softmax(M, dim=-1) + u = torch.ones_like(r) + v = torch.ones_like(c) + for _ in range(iteration): + u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) + v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) + p = p * u.unsqueeze(-1) * v.unsqueeze(-2) + return p + + +def sink_algorithm(M, dustbin, iteration): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + r = torch.ones([M.shape[0], M.shape[1] - 1], device='cuda') + r = torch.cat([r, torch.ones([M.shape[0], 1], device='cuda') * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1], device='cuda') + c = torch.cat([c, torch.ones([M.shape[0], 1], device='cuda') * M.shape[2]], dim=-1) + p = sinkhorn(M, r, c, iteration) + return p + + +class AttentionalGNN(nn.Module): + def __init__(self, feature_dim: int, layer_names: list, hidden_dim: int = 256, ac_fn: str = 'relu', + norm_fn: str = 'bn'): + super().__init__() + self.layers = nn.ModuleList([ + AttentionalPropagation(feature_dim=feature_dim, num_heads=4, hidden_dim=hidden_dim, ac_fn=ac_fn, + norm_fn=norm_fn) + for _ in range(len(layer_names))]) + self.names = layer_names + + def forward(self, desc0, desc1): + # desc0s = [] + # desc1s = [] + + for i, (layer, name) in enumerate(zip(self.layers, self.names)): + if name == 'cross': + src0, src1 = desc1, desc0 + else: + src0, src1 = desc0, desc1 + delta0 = layer(desc0, src0) + # prob0 = layer.attn.prob + delta1 = layer(desc1, src1) + # prob1 = layer.attn.prob + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + + # if name == 'cross': + # desc0s.append(desc0) + # desc1s.append(desc1) + return [desc0], [desc1] + + def predict(self, desc0, desc1, n_it=-1): + for i, (layer, name) in enumerate(zip(self.layers, self.names)): + if name == 'cross': + src0, src1 = desc1, desc0 + else: + src0, src1 = desc0, desc1 + delta0 = layer(desc0, src0) + # prob0 = layer.attn.prob + delta1 = layer(desc1, src1) + # prob1 = layer.attn.prob + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + + if name == 'cross' and i == n_it: + break + return [desc0], [desc1] + + +class GM(nn.Module): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + print('gm: ', self.config) + + self.n_layers = self.config['n_layers'] + + self.with_sinkhorn = self.config['with_sinkhorn'] + self.match_threshold = self.config['match_threshold'] + + self.sinkhorn_iterations = self.config['sinkhorn_iterations'] + self.kenc = KeypointEncoder( + self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + self.config['keypoint_encoder'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn']) + self.gnn = AttentionalGNN( + feature_dim=self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + hidden_dim=self.config['hidden_dim'], + layer_names=self.config['GNN_layers'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'], + ) + + self.final_proj = nn.ModuleList([nn.Conv1d( + self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + kernel_size=1, bias=True) for _ in range(self.n_layers)]) + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + + self.match_net = None # GraphLoss(config=self.config) + + self.self_prob0 = None + self.self_prob1 = None + self.cross_prob0 = None + self.cross_prob1 = None + + self.desc_compressor = None + + def forward_train(self, data): + pass + + def produce_matches(self, data, p=0.2, n_it=-1, **kwargs): + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + scores0, scores1 = data['scores0'], data['scores1'] + if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints + shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] + return { + 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0], + 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0], + 'matching_scores0': kpts0.new_zeros(shape0)[0], + 'matching_scores1': kpts1.new_zeros(shape1)[0], + 'skip_train': True + } + + if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys(): + norm_kpts0 = data['norm_keypoints0'] + norm_kpts1 = data['norm_keypoints1'] + elif 'image0' in data.keys() and 'image1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape) + norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape) + elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']) + norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1']) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + # Keypoint MLP encoder. + enc0, enc1 = self.encode_keypoint(norm_kpts0=norm_kpts0, norm_kpts1=norm_kpts1, scores0=scores0, + scores1=scores1) + + if self.config['descriptor_dim'] > 0: + desc0, desc1 = data['descriptors0'], data['descriptors1'] + desc0 = desc0.transpose(0, 2, 1) # [B, N, D ] -> [B, D, N] + desc1 = desc1.transpose(0, 2, 1) # [B, N, D ] -> [B, D, N] + with torch.no_grad(): + if desc0.shape[1] != self.config['descriptor_dim']: + desc0 = self.desc_compressor(desc0) + if desc1.shape[1] != self.config['descriptor_dim']: + desc1 = self.desc_compressor(desc1) + desc0 = desc0 + enc0 + desc1 = desc1 + enc1 + else: + desc0 = enc0 + desc1 = enc1 + + desc0s, desc1s = self.gnn.predict(desc0, desc1, n_it=n_it) + + mdescs0 = self.final_proj[n_it](desc0s[-1]) + mdescs1 = self.final_proj[n_it](desc1s[-1]) + dist = torch.einsum('bdn,bdm->bnm', mdescs0, mdescs1) + if self.config['descriptor_dim'] > 0: + dist = dist / self.config['descriptor_dim'] ** .5 + else: + dist = dist / 128 ** .5 + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + + output = { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores1, + } + + return output + + def forward(self, data, mode=0): + if not self.training: + return self.produce_matches(data=data, n_it=-1) + return self.forward_train(data=data) + + def encode_keypoint(self, norm_kpts0, norm_kpts1, scores0, scores1): + return self.kenc(norm_kpts0, scores0), self.kenc(norm_kpts1, scores1) + + def compute_distance(self, desc0, desc1, layer_id=-1): + mdesc0 = self.final_proj[layer_id](desc0) + mdesc1 = self.final_proj[layer_id](desc1) + dist = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) + dist = dist / self.config['descriptor_dim'] ** .5 + return dist + + def compute_score(self, dist, dustbin, iteration): + if self.with_sinkhorn: + score = sink_algorithm(M=dist, dustbin=dustbin, + iteration=iteration) # [nI * nB, N, M] + else: + score = dual_softmax(M=dist, dustbin=dustbin) + return score + + def compute_matches(self, scores, p=0.2): + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + indices0, indices1 = max0.indices, max1.indices + mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) + mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) + zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = torch.where(mutual0, max0.values, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + # valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > p) + valid1 = mutual1 & valid0.gather(1, indices1) + indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + + return indices0, indices1, mscores0, mscores1 diff --git a/third_party/pram/nets/gml.py b/third_party/pram/nets/gml.py new file mode 100644 index 0000000000000000000000000000000000000000..996de5f01211e0a315f7f9b4ce35d561dfc74b2f --- /dev/null +++ b/third_party/pram/nets/gml.py @@ -0,0 +1,319 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> gml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 10:56 +==================================================''' +import torch +from torch import nn +import torch.nn.functional as F +from typing import Callable +from .utils import arange_like, normalize_keypoints +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.backends.cudnn.deterministic = True + +eps = 1e-8 + + +def dual_softmax(M, dustbin): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1) + return torch.exp(score) + + +def sinkhorn(M, r, c, iteration): + p = torch.softmax(M, dim=-1) + u = torch.ones_like(r) + v = torch.ones_like(c) + for _ in range(iteration): + u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) + v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) + p = p * u.unsqueeze(-1) * v.unsqueeze(-2) + return p + + +def sink_algorithm(M, dustbin, iteration): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + r = torch.ones([M.shape[0], M.shape[1] - 1], device=device) + r = torch.cat([r, torch.ones([M.shape[0], 1], device=device) * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1], device=device) + c = torch.cat([c, torch.ones([M.shape[0], 1], device=device) * M.shape[2]], dim=-1) + p = sinkhorn(M, r, c, iteration) + return p + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb( + freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, + gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ encode position vector """ + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(3, 32), + nn.LayerNorm(32, elementwise_affine=True), + nn.GELU(), + nn.Linear(32, 64), + nn.LayerNorm(64, elementwise_affine=True), + nn.GELU(), + nn.Linear(64, 128), + nn.LayerNorm(128, elementwise_affine=True), + nn.GELU(), + nn.Linear(128, 256), + ) + + def forward(self, kpts, scores): + inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1] + return self.encoder(torch.cat(inputs, dim=-1)) + + +class Attention(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + s = q.shape[-1] ** -0.5 + attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1) + return torch.einsum('...ij,...jd->...id', attn, v) + + +class SelfMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + + assert feat_dim % num_heads == 0 + self.head_dim = feat_dim // num_heads + self.qkv = nn.Linear(feat_dim, hidden_dim * 3) + self.attn = Attention() + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim) + ) + + def forward_(self, x, encoding=None): + qkv = self.qkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + if encoding is not None: + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + attn = self.attn(q, k, v) + message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2)) + return x + self.mlp(torch.cat([x, message], -1)) + + def forward(self, x0, x1, encoding0=None, encoding1=None): + return self.forward_(x0, encoding0), self.forward_(x1, encoding1) + + +class CrossMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + assert hidden_dim % num_heads == 0 + dim_head = hidden_dim // num_heads + self.scale = dim_head ** -0.5 + self.to_qk = nn.Linear(feat_dim, hidden_dim) + self.to_v = nn.Linear(feat_dim, hidden_dim) + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim), + ) + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward(self, x0, x1): + qk0 = self.to_qk(x0) + qk1 = self.to_qk(x1) + v0 = self.to_v(x0) + v1 = self.to_v(x1) + + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1)) + + qk0, qk1 = qk0 * self.scale ** 0.5, qk1 * self.scale ** 0.5 + sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1) + m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0) + + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), + m0, m1) + m0, m1 = self.map_(self.proj, m0, m1) + x0 = x0 + self.mlp(torch.cat([x0, m0], -1)) + x1 = x1 + self.mlp(torch.cat([x1, m1], -1)) + return x0, x1 + + +class GML(nn.Module): + ''' + the architecture of lightglue, but trained with imp + ''' + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + self.n_layers = self.config['n_layers'] + + self.with_sinkhorn = self.config['with_sinkhorn'] + self.match_threshold = self.config['match_threshold'] + self.sinkhorn_iterations = self.config['sinkhorn_iterations'] + + self.input_proj = nn.Linear(self.config['descriptor_dim'], self.config['hidden_dim']) + + self.self_attn = nn.ModuleList( + [SelfMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + self.cross_attn = nn.ModuleList( + [CrossMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + + head_dim = self.config['hidden_dim'] // 4 + self.poseenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim) + self.out_proj = nn.ModuleList( + [nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) for _ in range(self.n_layers)] + ) + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + + def forward(self, data, mode=0): + if not self.training: + return self.produce_matches(data=data) + return self.forward_train(data=data) + + def forward_train(self, data: dict, p=0.2, **kwargs): + pass + + def produce_matches(self, data: dict, p=0.2, **kwargs): + desc0, desc1 = data['descriptors0'], data['descriptors1'] + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + # Keypoint normalization. + if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys(): + norm_kpts0 = data['norm_keypoints0'] + norm_kpts1 = data['norm_keypoints1'] + elif 'image0' in data.keys() and 'image1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape).float() + norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape).float() + elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']).float() + norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1']).float() + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + enc0 = self.poseenc(norm_kpts0) + enc1 = self.poseenc(norm_kpts1) + + nI = self.n_layers + # nI = 5 + + for i in range(nI): + desc0, desc1 = self.self_attn[i](desc0, desc1, enc0, enc1) + desc0, desc1 = self.cross_attn[i](desc0, desc1) + + d = desc0.shape[-1] + mdesc0 = self.out_proj[nI - 1](desc0) / d ** .25 + mdesc1 = self.out_proj[nI - 1](desc1) / d ** .25 + + dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1) + + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + + output = { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores1, + } + + return output + + def compute_score(self, dist, dustbin, iteration): + if self.with_sinkhorn: + score = sink_algorithm(M=dist, dustbin=dustbin, + iteration=iteration) # [nI * nB, N, M] + else: + score = dual_softmax(M=dist, dustbin=dustbin) + return score + + def compute_matches(self, scores, p=0.2): + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + indices0, indices1 = max0.indices, max1.indices + mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) + mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) + zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = torch.where(mutual0, max0.values, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + # valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > p) + valid1 = mutual1 & valid0.gather(1, indices1) + indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + + return indices0, indices1, mscores0, mscores1 diff --git a/third_party/pram/nets/layers.py b/third_party/pram/nets/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..417488e6a163327895eb435567c4255c7827bca2 --- /dev/null +++ b/third_party/pram/nets/layers.py @@ -0,0 +1,109 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> layers +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:46 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from einops import rearrange + + +def MLP(channels: list, do_bn=True, ac_fn='relu', norm_fn='bn'): + """ Multi-layer perceptron """ + n = len(channels) + layers = [] + for i in range(1, n): + layers.append( + nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) + if i < (n - 1): + if norm_fn == 'in': + layers.append(nn.InstanceNorm1d(channels[i], eps=1e-3)) + elif norm_fn == 'bn': + layers.append(nn.BatchNorm1d(channels[i], eps=1e-3)) + if ac_fn == 'relu': + layers.append(nn.ReLU()) + elif ac_fn == 'gelu': + layers.append(nn.GELU()) + elif ac_fn == 'lrelu': + layers.append(nn.LeakyReLU(negative_slope=0.1)) + # if norm_fn == 'ln': + # layers.append(nn.LayerNorm(channels[i])) + return nn.Sequential(*layers) + + +class MultiHeadedAttention(nn.Module): + def __init__(self, num_heads: int, d_model: int): + super().__init__() + assert d_model % num_heads == 0 + self.dim = d_model // num_heads + self.num_heads = num_heads + self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) + self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) + + def forward(self, query, key, value, M=None): + ''' + :param query: [B, D, N] + :param key: [B, D, M] + :param value: [B, D, M] + :param M: [B, N, M] + :return: + ''' + + batch_dim = query.size(0) + query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) + for l, x in zip(self.proj, (query, key, value))] # [B, D, NH, N] + dim = query.shape[1] + scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5 + + if M is not None: + # print('M: ', scores.shape, M.shape, torch.sum(M, dim=2)) + # scores = scores * M[:, None, :, :].expand_as(scores) + # with torch.no_grad(): + mask = (1 - M[:, None, :, :]).repeat(1, scores.shape[1], 1, 1).bool() # [B, H, N, M] + scores = scores.masked_fill(mask, -torch.finfo(scores.dtype).max) + prob = F.softmax(scores, dim=-1) # * (~mask).float() # * mask.float() + else: + prob = F.softmax(scores, dim=-1) + + x = torch.einsum('bhnm,bdhm->bdhn', prob, value) + self.prob = prob + + out = self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1)) + + return out + + +class AttentionalPropagation(nn.Module): + def __init__(self, feature_dim: int, num_heads: int, ac_fn='relu', norm_fn='bn'): + super().__init__() + self.attn = MultiHeadedAttention(num_heads, feature_dim) + self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim], ac_fn=ac_fn, norm_fn=norm_fn) + nn.init.constant_(self.mlp[-1].bias, 0.0) + + def forward(self, x, source, M=None): + message = self.attn(x, source, source, M=M) + self.prob = self.attn.prob + + out = self.mlp(torch.cat([x, message], dim=1)) + return out + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self, input_dim, feature_dim, layers, ac_fn='relu', norm_fn='bn'): + super().__init__() + self.input_dim = input_dim + self.encoder = MLP([input_dim] + layers + [feature_dim], ac_fn=ac_fn, norm_fn=norm_fn) + nn.init.constant_(self.encoder[-1].bias, 0.0) + + def forward(self, kpts, scores=None): + if self.input_dim == 2: + return self.encoder(kpts.transpose(1, 2)) + else: + inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] # [B, 2, N] + [B, 1, N] + return self.encoder(torch.cat(inputs, dim=1)) diff --git a/third_party/pram/nets/load_segnet.py b/third_party/pram/nets/load_segnet.py new file mode 100644 index 0000000000000000000000000000000000000000..51b8c5bc3fc1c25a8e52dd21cc6f3f4e79b418aa --- /dev/null +++ b/third_party/pram/nets/load_segnet.py @@ -0,0 +1,31 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> load_segnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 09/04/2024 15:39 +==================================================''' +from nets.segnet import SegNet +from nets.segnetvit import SegNetViT + + +def load_segnet(network, n_class, desc_dim, n_layers, output_dim): + model_config = { + 'network': { + 'descriptor_dim': desc_dim, + 'n_layers': n_layers, + 'n_class': n_class, + 'output_dim': output_dim, + 'with_score': False, + } + } + + if network == 'segnet': + model = SegNet(model_config.get('network', {})) + # config['with_cls'] = False + elif network == 'segnetvit': + model = SegNetViT(model_config.get('network', {})) + else: + raise 'ERROR! {:s} model does not exist'.format(config['network']) + + return model diff --git a/third_party/pram/nets/retnet.py b/third_party/pram/nets/retnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f3346fcd82193683ec72d0e55a2429d18a974b --- /dev/null +++ b/third_party/pram/nets/retnet.py @@ -0,0 +1,174 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> retnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 22/02/2024 15:23 +==================================================''' +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File glretrieve -> retnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 15/02/2024 10:55 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class ResBlock(nn.Module): + def __init__(self, inplanes, outplanes, stride=1, groups=32, dilation=1, norm_layer=None, ac_fn=None): + super(ResBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = conv1x1(inplanes, outplanes) + self.bn1 = norm_layer(outplanes) + self.conv2 = conv3x3(outplanes, outplanes, stride, groups, dilation) + self.bn2 = norm_layer(outplanes) + self.conv3 = conv1x1(outplanes, outplanes) + self.bn3 = norm_layer(outplanes) + if ac_fn is None: + self.ac_fn = nn.ReLU(inplace=True) + else: + self.ac_fn = ac_fn + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.ac_fn(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.ac_fn(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += identity + out = self.ac_fn(out) + + return out + + +class GeneralizedMeanPooling(nn.Module): + r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. + The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` + - At p = infinity, one gets Max Pooling + - At p = 1, one gets Average Pooling + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + """ + + def __init__(self, norm, output_size=1, eps=1e-6): + super(GeneralizedMeanPooling, self).__init__() + assert norm > 0 + self.p = float(norm) + self.output_size = output_size + self.eps = eps + + def forward(self, x): + x = x.clamp(min=self.eps).pow(self.p) + return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p) + + def __repr__(self): + return self.__class__.__name__ + '(' \ + + str(self.p) + ', ' \ + + 'output_size=' + str(self.output_size) + ')' + + +class GeneralizedMeanPoolingP(GeneralizedMeanPooling): + """ Same, but norm is trainable + """ + + def __init__(self, norm=3, output_size=1, eps=1e-6): + super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) + self.p = nn.Parameter(torch.ones(1) * norm) + + +class Flatten(nn.Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +class L2Norm(nn.Module): + def __init__(self, dim=1): + super().__init__() + self.dim = dim + + def forward(self, input): + return F.normalize(input, p=2, dim=self.dim) + + +class RetNet(nn.Module): + def __init__(self, indim=256, outdim=1024): + super().__init__() + + ac_fn = nn.GELU() + + self.convs = nn.Sequential( + # no batch normalization + + nn.Conv2d(in_channels=indim, out_channels=512, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(512), + # nn.ReLU(), + + ResBlock(512, 512, groups=32, stride=1, ac_fn=ac_fn), + ResBlock(512, 512, groups=32, stride=1, ac_fn=ac_fn), + + nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(1024), + # nn.ReLU(), + ResBlock(inplanes=1024, outplanes=1024, groups=32, stride=1, ac_fn=ac_fn), + ResBlock(inplanes=1024, outplanes=1024, groups=32, stride=1, ac_fn=ac_fn), + ) + + self.pool = GeneralizedMeanPoolingP() + self.fc = nn.Linear(1024, out_features=outdim) + + def initialize(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + out = self.convs(x) + out = self.pool(out).reshape(x.shape[0], -1) + out = self.fc(out) + out = F.normalize(out, p=2, dim=1) + return out + + +if __name__ == '__main__': + mode = RetNet(indim=256, outdim=1024) + state_dict = mode.state_dict() + keys = state_dict.keys() + print(keys) + shapes = [state_dict[v].shape for v in keys] + print(shapes) diff --git a/third_party/pram/nets/segnet.py b/third_party/pram/nets/segnet.py new file mode 100644 index 0000000000000000000000000000000000000000..632a38cb83ca77a23b5c1e1276996bd5574c3a0b --- /dev/null +++ b/third_party/pram/nets/segnet.py @@ -0,0 +1,120 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> segnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:46 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from nets.layers import MLP, KeypointEncoder +from nets.layers import AttentionalPropagation +from nets.utils import normalize_keypoints + + +class SegGNN(nn.Module): + def __init__(self, feature_dim: int, n_layers: int, ac_fn: str = 'relu', norm_fn: str = 'bn', **kwargs): + super().__init__() + self.layers = nn.ModuleList([ + AttentionalPropagation(feature_dim, 4, ac_fn=ac_fn, norm_fn=norm_fn) + for _ in range(n_layers) + ]) + + def forward(self, desc): + for i, layer in enumerate(self.layers): + delta = layer(desc, desc) + desc = desc + delta + + return desc + + +class SegNet(nn.Module): + default_config = { + 'descriptor_dim': 256, + 'output_dim': 1024, + 'n_class': 512, + 'keypoint_encoder': [32, 64, 128, 256], + 'n_layers': 9, + 'ac_fn': 'relu', + 'norm_fn': 'in', + 'with_score': False, + # 'with_global': False, + 'with_cls': False, + 'with_sc': False, + } + + def __init__(self, config={}): + super().__init__() + self.config = {**self.default_config, **config} + self.with_cls = self.config['with_cls'] + self.with_sc = self.config['with_sc'] + + self.n_layers = self.config['n_layers'] + self.gnn = SegGNN( + feature_dim=self.config['descriptor_dim'], + n_layers=self.config['n_layers'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'], + ) + + self.with_score = self.config['with_score'] + self.kenc = KeypointEncoder( + input_dim=3 if self.with_score else 2, + feature_dim=self.config['descriptor_dim'], + layers=self.config['keypoint_encoder'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'] + ) + + self.seg = MLP(channels=[self.config['descriptor_dim'], + self.config['output_dim'], + self.config['n_class']], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'] + ) + + if self.with_sc: + self.sc = MLP(channels=[self.config['descriptor_dim'], + self.config['output_dim'], + 3], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'] + ) + + def preprocess(self, data): + desc0 = data['seg_descriptors'] + desc0 = desc0.transpose(1, 2) # [B, N, D] - > [B, D, N] + + if 'norm_keypoints' in data.keys(): + norm_kpts0 = data['norm_keypoints'] + elif 'image' in data.keys(): + kpts0 = data['keypoints'] + norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + # Keypoint MLP encoder. + if self.with_score: + scores0 = data['scores'] + else: + scores0 = None + enc0 = self.kenc(norm_kpts0, scores0) + + return desc0, enc0 + + def forward(self, data): + desc, enc = self.preprocess(data=data) + desc = desc + enc + + desc = self.gnn(desc) + cls_output = self.seg(desc) # [B, C, N] + output = { + 'prediction': cls_output.transpose(-1, -2).contiguous(), + } + + if self.with_sc: + sc_output = self.sc(desc) + output['sc'] = sc_output + + return output diff --git a/third_party/pram/nets/segnetvit.py b/third_party/pram/nets/segnetvit.py new file mode 100644 index 0000000000000000000000000000000000000000..7919b545c26d3098df84d2e8e909d7ed69809dcd --- /dev/null +++ b/third_party/pram/nets/segnetvit.py @@ -0,0 +1,203 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> segnetvit +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:52 +==================================================''' + +import torch +from torch import nn +import torch.nn.functional as F +from nets.utils import normalize_keypoints + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb( + freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, + gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ encode position vector """ + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(2, 32), + nn.LayerNorm(32, elementwise_affine=True), + nn.GELU(), + nn.Linear(32, 64), + nn.LayerNorm(64, elementwise_affine=True), + nn.GELU(), + nn.Linear(64, 128), + nn.LayerNorm(128, elementwise_affine=True), + nn.GELU(), + nn.Linear(128, 256), + ) + + def forward(self, kpts, scores=None): + if scores is not None: + inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1] + return self.encoder(torch.cat(inputs, dim=-1)) + else: + return self.encoder(kpts) + + +class Attention(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + s = q.shape[-1] ** -0.5 + attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1) + return torch.einsum('...ij,...jd->...id', attn, v) + + +class SelfMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + + assert feat_dim % num_heads == 0 + self.head_dim = feat_dim // num_heads + self.qkv = nn.Linear(feat_dim, hidden_dim * 3) + self.attn = Attention() + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim) + ) + + def forward(self, x, encoding=None): + qkv = self.qkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + if encoding is not None: + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + attn = self.attn(q, k, v) + message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2)) + return x + self.mlp(torch.cat([x, message], -1)) + + +class SegGNNViT(nn.Module): + def __init__(self, feature_dim: int, n_layers: int, hidden_dim: int = 256, num_heads: int = 4, **kwargs): + super(SegGNNViT, self).__init__() + self.layers = nn.ModuleList([ + SelfMultiHeadAttention(feat_dim=feature_dim, hidden_dim=hidden_dim, num_heads=num_heads) + for _ in range(n_layers) + ]) + + def forward(self, desc, encoding=None): + for i, layer in enumerate(self.layers): + desc = layer(desc, encoding) + # desc = desc + delta // should be removed as this is already done in self-attention + return desc + + +class SegNetViT(nn.Module): + default_config = { + 'descriptor_dim': 256, + 'output_dim': 1024, + 'n_class': 512, + 'keypoint_encoder': [32, 64, 128, 256], + 'n_layers': 15, + 'num_heads': 4, + 'hidden_dim': 256, + 'with_score': False, + 'with_global': False, + 'with_cls': False, + 'with_sc': False, + } + + def __init__(self, config={}): + super(SegNetViT, self).__init__() + self.config = {**self.default_config, **config} + self.with_cls = self.config['with_cls'] + self.with_sc = self.config['with_sc'] + + self.n_layers = self.config['n_layers'] + self.gnn = SegGNNViT( + feature_dim=self.config['hidden_dim'], + n_layers=self.config['n_layers'], + hidden_dim=self.config['hidden_dim'], + num_heads=self.config['num_heads'], + ) + + self.with_score = self.config['with_score'] + self.kenc = LearnableFourierPositionalEncoding(2, self.config['hidden_dim'] // self.config['num_heads'], + self.config['hidden_dim'] // self.config['num_heads']) + + self.input_proj = nn.Linear(in_features=self.config['descriptor_dim'], + out_features=self.config['hidden_dim']) + self.seg = nn.Sequential( + nn.Linear(in_features=self.config['hidden_dim'], out_features=self.config['output_dim']), + nn.LayerNorm(self.config['output_dim'], elementwise_affine=True), + nn.GELU(), + nn.Linear(self.config['output_dim'], self.config['n_class']) + ) + + if self.with_sc: + self.sc = nn.Sequential( + nn.Linear(in_features=config['hidden_dim'], out_features=self.config['output_dim']), + nn.LayerNorm(self.config['output_dim'], elementwise_affine=True), + nn.GELU(), + nn.Linear(self.config['output_dim'], 3) + ) + + def preprocess(self, data): + desc0 = data['seg_descriptors'] + if 'norm_keypoints' in data.keys(): + norm_kpts0 = data['norm_keypoints'] + elif 'image' in data.keys(): + kpts0 = data['keypoints'] + norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + enc0 = self.kenc(norm_kpts0) + + return desc0, enc0 + + def forward(self, data): + desc, enc = self.preprocess(data=data) + desc = self.input_proj(desc) + + desc = self.gnn(desc, enc) + seg_output = self.seg(desc) # [B, N, C] + + output = { + 'prediction': seg_output, + } + + if self.with_sc: + sc_output = self.sc(desc) + output['sc'] = sc_output + + return output diff --git a/third_party/pram/nets/sfd2.py b/third_party/pram/nets/sfd2.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c5a099b001ed9cf9e8a82b1b77dc9f7d9e31c8 --- /dev/null +++ b/third_party/pram/nets/sfd2.py @@ -0,0 +1,596 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> sfd2 +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 14:53 +==================================================''' +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import torchvision.transforms as tvf + +RGB_mean = [0.485, 0.456, 0.406] +RGB_std = [0.229, 0.224, 0.225] + +norm_RGB = tvf.Compose([tvf.Normalize(mean=RGB_mean, std=RGB_std)]) + + +def simple_nms(scores, nms_radius: int): + """ Fast Non-maximum suppression to remove nearby points """ + assert (nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def remove_borders(keypoints, scores, border: int, height: int, width: int): + """ Removes keypoints too close to the border """ + mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) + mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints, scores, k: int): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s: int = 8): + """ Interpolate descriptors at keypoint locations """ + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(keypoints)[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', align_corners=True) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + return descriptors + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=False, groups=1, dilation=1): + if not use_bn: + return nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation), + nn.ReLU(inplace=True), + ) + else: + return nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + +class ResBlock(nn.Module): + def __init__(self, inplanes, outplanes, stride=1, groups=32, dilation=1, norm_layer=None): + super(ResBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = conv1x1(inplanes, outplanes) + self.bn1 = norm_layer(outplanes) + self.conv2 = conv3x3(outplanes, outplanes, stride, groups, dilation) + self.bn2 = norm_layer(outplanes) + self.conv3 = conv1x1(outplanes, outplanes) + self.bn3 = norm_layer(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += identity + out = self.relu(out) + + return out + + +class ResNet4x(nn.Module): + default_config = { + 'conf_th': 0.005, + 'remove_borders': 4, + 'min_keypoints': 128, + 'max_keypoints': 4096, + } + + def __init__(self, inputdim=3, outdim=128, desc_compressor=None): + super().__init__() + self.outdim = outdim + self.desc_compressor = desc_compressor + + d1, d2, d3, d4, d5, d6 = 64, 128, 256, 256, 256, 256 + self.conv1a = conv(in_channels=inputdim, out_channels=d1, kernel_size=3, use_bn=True) + self.conv1b = conv(in_channels=d1, out_channels=d1, kernel_size=3, stride=2, use_bn=True) + + self.conv2a = conv(in_channels=d1, out_channels=d2, kernel_size=3, use_bn=True) + self.conv2b = conv(in_channels=d2, out_channels=d2, kernel_size=3, stride=2, use_bn=True) + + self.conv3a = conv(in_channels=d2, out_channels=d3, kernel_size=3, use_bn=True) + self.conv3b = conv(in_channels=d3, out_channels=d3, kernel_size=3, use_bn=True) + + self.conv4 = nn.Sequential( + ResBlock(inplanes=256, outplanes=256, groups=32), + ResBlock(inplanes=256, outplanes=256, groups=32), + ResBlock(inplanes=256, outplanes=256, groups=32), + ) + + self.convPa = nn.Sequential( + torch.nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), + ) + self.convDa = nn.Sequential( + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + ) + + self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0) + self.convDb = torch.nn.Conv2d(256, outdim, kernel_size=1, stride=1, padding=0) + + def det(self, x): + out1a = self.conv1a(x) + out1b = self.conv1b(out1a) + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) + + out4 = self.conv4(out3b) + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + + # Descriptor Head + cDa = self.convDa(out4) + desc = self.convDb(cDa) + desc = F.normalize(desc, dim=1) + + return score, desc + + def forward(self, batch): + out1a = self.conv1a(batch['image']) + out1b = self.conv1b(out1a) + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) + + out4 = self.conv4(out3b) + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + + # Descriptor Head + cDa = self.convDa(out4) + desc = self.convDb(cDa) + desc = F.normalize(desc, dim=1) + + return { + 'dense_features': desc, + 'scores': score, + 'logits': logits, + 'semi_map': semi, + } + + def extract_patches(self, batch): + out1a = self.conv1a(batch['image']) + out1b = self.conv1b(out1a) + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) + + out4 = self.conv4(out3b) + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + + # Descriptor Head + cDa = self.convDa(out4) + desc = self.convDb(cDa) + desc = F.normalize(desc, dim=1) + + return { + 'dense_features': desc, + 'scores': score, + 'logits': logits, + 'semi_map': semi, + } + + def extract_local_global(self, data, + config={ + 'conf_th': 0.005, + 'remove_borders': 4, + 'min_keypoints': 128, + 'max_keypoints': 4096, + } + ): + + config = {**self.default_config, **config} + + b, ic, ih, iw = data['image'].shape + out1a = self.conv1a(data['image']) + out1b = self.conv1b(out1a) # 64 + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) # 128 + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) # 256 + + out4 = self.conv4(out3b) # 256 + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + if Hc * 8 != ih or Wc * 8 != iw: + score = F.interpolate(score.unsqueeze(1), size=[ih, iw], align_corners=True, mode='bilinear') + score = score.squeeze(1) + # extract keypoints + nms_scores = simple_nms(scores=score, nms_radius=4) + keypoints = [ + torch.nonzero(s >= config['conf_th']) + for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + if len(scores[0]) <= config['min_keypoints']: + keypoints = [ + torch.nonzero(s >= config['conf_th'] * 0.5) + for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, config['remove_borders'], ih, iw) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with highest score + if config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + # Descriptor Head + cDa = self.convDa(out4) + desc_map = self.convDb(cDa) + desc_map = F.normalize(desc_map, dim=1) + + descriptors = [sample_descriptors(k[None], d[None], 4)[0] + for k, d in zip(keypoints, desc_map)] + + return { + 'score_map': score, + 'desc_map': desc_map, + 'mid_features': out4, + 'global_descriptors': [out1b, out2b, out3b, out4], + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + } + + def sample(self, score_map, semi_descs, kpts, s=4, norm_desc=True): + # print('sample: ', score_map.shape, semi_descs.shape, kpts.shape) + b, c, h, w = semi_descs.shape + norm_kpts = kpts - s / 2 + 0.5 + norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(norm_kpts)[None] + norm_kpts = norm_kpts * 2 - 1 + # args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + descriptors = torch.nn.functional.grid_sample( + semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True) + + if norm_desc: + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + else: + descriptors = descriptors.reshape(b, c, -1) + + # print('max: ', torch.min(kpts[:, 1].long()), torch.max(kpts[:, 1].long()), torch.min(kpts[:, 0].long()), + # torch.max(kpts[:, 0].long())) + scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()] + + return scores, descriptors.squeeze(0) + + +class DescriptorCompressor(nn.Module): + def __init__(self, inputdim: int, outdim: int): + super().__init__() + self.inputdim = inputdim + self.outdim = outdim + self.conv = nn.Conv1d(in_channels=inputdim, out_channels=outdim, kernel_size=1, padding=0, bias=True) + + def forward(self, x): + # b, c, n = x.shape + out = self.conv(x) + out = F.normalize(out, p=2, dim=1) + return out + + +def extract_sfd2_return(model, img, conf_th=0.001, + mask=None, + topK=-1, + min_keypoints=0, + **kwargs): + old_bm = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = False # speedup + + img = norm_RGB(img.squeeze()) + img = img[None] + img = img.cuda() + + B, one, H, W = img.shape + + all_pts = [] + all_descs = [] + + if 'scales' in kwargs.keys(): + scales = kwargs.get('scales') + else: + scales = [1.0] + + for s in scales: + if s == 1.0: + new_img = img + else: + nh = int(H * s) + nw = int(W * s) + new_img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=True) + nh, nw = new_img.shape[2:] + + with torch.no_grad(): + heatmap, coarse_desc = model.det(new_img) + + # print("nh, nw, heatmap, desc: ", nh, nw, heatmap.shape, coarse_desc.shape) + if len(heatmap.size()) == 3: + heatmap = heatmap.unsqueeze(1) + if len(heatmap.size()) == 2: + heatmap = heatmap.unsqueeze(0) + heatmap = heatmap.unsqueeze(1) + # print(heatmap.shape) + if heatmap.size(2) != nh or heatmap.size(3) != nw: + heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True) + + conf_thresh = conf_th + nms_dist = 3 + border_remove = 4 + scores = simple_nms(heatmap, nms_radius=nms_dist) + keypoints = [ + torch.nonzero(s > conf_thresh) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + # print('scores in return: ', len(scores[0])) + + # print(keypoints[0].shape) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + scores = scores[0].data.cpu().numpy().squeeze() + keypoints = keypoints[0].data.cpu().numpy().squeeze() + pts = keypoints.transpose() + pts[2, :] = scores + + inds = np.argsort(pts[2, :]) + pts = pts[:, inds[::-1]] # Sort by confidence. + # Remove points along border. + bord = border_remove + toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord)) + toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord)) + toremove = np.logical_or(toremoveW, toremoveH) + pts = pts[:, ~toremove] + + # valid_idex = heatmap > conf_thresh + # valid_score = heatmap[valid_idex] + # """ + # --- Process descriptor. + # coarse_desc = coarse_desc.data.cpu().numpy().squeeze() + D = coarse_desc.size(1) + if pts.shape[1] == 0: + desc = np.zeros((D, 0)) + else: + if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw: + desc = coarse_desc[:, :, pts[1, :], pts[0, :]] + desc = desc.data.cpu().numpy().reshape(D, -1) + else: + # Interpolate into descriptor map using 2D point locations. + samp_pts = torch.from_numpy(pts[:2, :].copy()) + samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1. + samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1. + samp_pts = samp_pts.transpose(0, 1).contiguous() + samp_pts = samp_pts.view(1, 1, -1, 2) + samp_pts = samp_pts.float() + samp_pts = samp_pts.cuda() + desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True) + desc = desc.data.cpu().numpy().reshape(D, -1) + desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :] + + if pts.shape[1] == 0: + continue + + # print(pts.shape, heatmap.shape, new_img.shape, img.shape, nw, nh, W, H) + pts[0, :] = pts[0, :] * W / nw + pts[1, :] = pts[1, :] * H / nh + all_pts.append(np.transpose(pts, [1, 0])) + all_descs.append(np.transpose(desc, [1, 0])) + + all_pts = np.vstack(all_pts) + all_descs = np.vstack(all_descs) + + torch.backends.cudnn.benchmark = old_bm + + if all_pts.shape[0] == 0: + return None, None, None + + keypoints = all_pts[:, 0:2] + scores = all_pts[:, 2] + descriptors = all_descs + + if mask is not None: + # cv2.imshow("mask", mask) + # cv2.waitKey(0) + labels = [] + others = [] + keypoints_with_labels = [] + scores_with_labels = [] + descriptors_with_labels = [] + keypoints_without_labels = [] + scores_without_labels = [] + descriptors_without_labels = [] + + id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0]) + # print(img.shape, id_img.shape) + + for i in range(keypoints.shape[0]): + x = keypoints[i, 0] + y = keypoints[i, 1] + # print("x-y", x, y, int(x), int(y)) + gid = id_img[int(y), int(x)] + if gid == 0: + keypoints_without_labels.append(keypoints[i]) + scores_without_labels.append(scores[i]) + descriptors_without_labels.append(descriptors[i]) + others.append(0) + else: + keypoints_with_labels.append(keypoints[i]) + scores_with_labels.append(scores[i]) + descriptors_with_labels.append(descriptors[i]) + labels.append(gid) + + if topK > 0: + if topK <= len(keypoints_with_labels): + idxes = np.array(scores_with_labels, float).argsort()[::-1][:topK] + keypoints = np.array(keypoints_with_labels, float)[idxes] + scores = np.array(scores_with_labels, float)[idxes] + labels = np.array(labels, np.int32)[idxes] + descriptors = np.array(descriptors_with_labels, float)[idxes] + elif topK >= len(keypoints_with_labels) + len(keypoints_without_labels): + # keypoints = np.vstack([keypoints_with_labels, keypoints_without_labels]) + # scores = np.vstack([scorescc_with_labels, scores_without_labels]) + # descriptors = np.vstack([descriptors_with_labels, descriptors_without_labels]) + # labels = np.vstack([labels, others]) + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in range(len(others)): + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + else: + n = topK - len(keypoints_with_labels) + idxes = np.array(scores_without_labels, float).argsort()[::-1][:n] + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in idxes: + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + keypoints = np.array(keypoints, float) + descriptors = np.array(descriptors, float) + # print(keypoints.shape, descriptors.shape) + return {"keypoints": np.array(keypoints, float), + "descriptors": np.array(descriptors, float), + "scores": np.array(scores, np.float), + "labels": np.array(labels, np.int32), + } + else: + # print(topK) + if topK > 0: + idxes = np.array(scores, dtype=float).argsort()[::-1][:topK] + keypoints = np.array(keypoints[idxes], dtype=float) + scores = np.array(scores[idxes], dtype=float) + descriptors = np.array(descriptors[idxes], dtype=float) + + keypoints = np.array(keypoints, dtype=float) + scores = np.array(scores, dtype=float) + descriptors = np.array(descriptors, dtype=float) + + # print(keypoints.shape, descriptors.shape) + + return {"keypoints": np.array(keypoints, dtype=float), + "descriptors": descriptors, + "scores": scores, + } + + +def load_sfd2(weight_path): + net = ResNet4x(inputdim=3, outdim=128) + net.load_state_dict(torch.load(weight_path, map_location='cpu')['state_dict'], strict=True) + # print('Load sfd2 from {:s}'.format(weight_path)) + return net diff --git a/third_party/pram/nets/superpoint.py b/third_party/pram/nets/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6751016bd71cbbbb072243b3c1aebc100f632693 --- /dev/null +++ b/third_party/pram/nets/superpoint.py @@ -0,0 +1,607 @@ +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. +# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +from pathlib import Path +import torch +from torch import nn +import numpy as np +import cv2 +import torch.nn.functional as F + + +def simple_nms(scores, nms_radius: int): + """ Fast Non-maximum suppression to remove nearby points """ + assert (nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def remove_borders(keypoints, scores, border: int, height: int, width: int): + """ Removes keypoints too close to the border """ + mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) + mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints, scores, k: int): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s: int = 8): + """ Interpolate descriptors at keypoint locations """ + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(keypoints)[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + return descriptors + + +class SuperPoint(nn.Module): + """SuperPoint Convolutional Detector and Descriptor + + SuperPoint: Self-Supervised Interest Point Detection and + Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew + Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629 + + """ + default_config = { + 'descriptor_dim': 256, + 'nms_radius': 3, + 'keypoint_threshold': 0.001, + 'max_keypoints': -1, + 'min_keypoints': 32, + 'remove_borders': 4, + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) # 64 + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) # 64 + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) # 128 + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) # 128 + + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256 + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256 + self.convDb = nn.Conv2d( + c5, self.config['descriptor_dim'], + kernel_size=1, stride=1, padding=0) + + # path = Path(__file__).parent / 'weights/superpoint_v1.pth' + path = config['weight_path'] + self.load_state_dict(torch.load(str(path), map_location='cpu'), strict=True) + + mk = self.config['max_keypoints'] + if mk == 0 or mk < -1: + raise ValueError('\"max_keypoints\" must be positive or \"-1\"') + + print('Loaded SuperPoint model') + + def extract_global(self, data): + # Shared Encoder + x0 = self.relu(self.conv1a(data['image'])) + x0 = self.relu(self.conv1b(x0)) + x0 = self.pool(x0) + x1 = self.relu(self.conv2a(x0)) + x1 = self.relu(self.conv2b(x1)) + x1 = self.pool(x1) + x2 = self.relu(self.conv3a(x1)) + x2 = self.relu(self.conv3b(x2)) + x2 = self.pool(x2) + x3 = self.relu(self.conv4a(x2)) + x3 = self.relu(self.conv4b(x3)) + + x4 = self.relu(self.convDa(x3)) + + # print('ex_g: ', x0.shape, x1.shape, x2.shape, x3.shape, x4.shape) + + return [x0, x1, x2, x3, x4] + + def extract_local_global(self, data): + # Shared Encoder + b, ic, ih, iw = data['image'].shape + x0 = self.relu(self.conv1a(data['image'])) + x0 = self.relu(self.conv1b(x0)) + x0 = self.pool(x0) + x1 = self.relu(self.conv2a(x0)) + x1 = self.relu(self.conv2b(x1)) + x1 = self.pool(x1) + x2 = self.relu(self.conv3a(x1)) + x2 = self.relu(self.conv3b(x2)) + x2 = self.pool(x2) + x3 = self.relu(self.conv4a(x2)) + x3 = self.relu(self.conv4b(x3)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x3)) + score = self.convPb(cPa) + score = torch.nn.functional.softmax(score, 1)[:, :-1] + # print(scores.shape) + b, _, h, w = score.shape + score = score.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + score = score.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + score = torch.nn.functional.interpolate(score.unsqueeze(1), size=(ih, iw), align_corners=True, + mode='bilinear') + score = score.squeeze(1) + + # extract kpts + nms_scores = simple_nms(scores=score, nms_radius=self.config['nms_radius']) + keypoints = [ + torch.nonzero(s >= self.config['keypoint_threshold']) + for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + if len(scores[0]) <= self.config['min_keypoints']: + keypoints = [ + torch.nonzero(s >= self.config['keypoint_threshold'] * 0.5) + for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, self.config['remove_borders'], ih, iw) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with the highest score + if self.config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, self.config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x3)) + desc_map = self.convDb(cDa) + desc_map = torch.nn.functional.normalize(desc_map, p=2, dim=1) + descriptors = [sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, desc_map)] + + return { + 'score_map': score, + 'desc_map': desc_map, + 'mid_features': cDa, # 256 + 'global_descriptors': [x0, x1, x2, x3, cDa], + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + } + + def sample(self, score_map, semi_descs, kpts, s=8, norm_desc=True): + # print('sample: ', score_map.shape, semi_descs.shape, kpts.shape) + b, c, h, w = semi_descs.shape + norm_kpts = kpts - s / 2 + 0.5 + norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(norm_kpts)[None] + norm_kpts = norm_kpts * 2 - 1 + # args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + descriptors = torch.nn.functional.grid_sample( + semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True) + if norm_desc: + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + else: + descriptors = descriptors.reshape(b, c, -1) + + # print('max: ', torch.min(kpts[:, 1].long()), torch.max(kpts[:, 1].long()), torch.min(kpts[:, 0].long()), + # torch.max(kpts[:, 0].long())) + scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()] + + return scores, descriptors.squeeze(0) + + def extract(self, data): + """ Compute keypoints, scores, descriptors for image """ + # Shared Encoder + x = self.relu(self.conv1a(data['image'])) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + return scores, descriptors + + def det(self, image): + """ Compute keypoints, scores, descriptors for image """ + # Shared Encoder + x = self.relu(self.conv1a(image)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + # print(scores.shape) + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + return scores, descriptors + + def forward(self, data): + """ Compute keypoints, scores, descriptors for image """ + # Shared Encoder + x = self.relu(self.conv1a(data['image'])) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + # print(scores.shape) + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + scores = simple_nms(scores, self.config['nms_radius']) + + # Extract keypoints + keypoints = [ + torch.nonzero(s > self.config['keypoint_threshold']) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, self.config['remove_borders'], h * 8, w * 8) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with highest score + if self.config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, self.config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + # Extract descriptors + # print(keypoints[0].shape) + descriptors = [sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors)] + + return { + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + 'global_descriptor': x, + } + + +def extract_descriptor(sample_pts, coarse_desc, H, W): + ''' + :param samplt_pts: + :param coarse_desc: + :return: + ''' + with torch.no_grad(): + norm_sample_pts = torch.zeros_like(sample_pts) + norm_sample_pts[0, :] = (sample_pts[0, :] / (float(W) / 2.)) - 1. # x + norm_sample_pts[1, :] = (sample_pts[1, :] / (float(H) / 2.)) - 1. # y + norm_sample_pts = norm_sample_pts.transpose(0, 1).contiguous() + norm_sample_pts = norm_sample_pts.view(1, 1, -1, 2).float() + sample_desc = torch.nn.functional.grid_sample(coarse_desc[None], norm_sample_pts, mode='bilinear', + align_corners=False) + sample_desc = torch.nn.functional.normalize(sample_desc, dim=1).squeeze(2).squeeze(0) + return sample_desc + + +def extract_sp_return(model, img, conf_th=0.005, + mask=None, + topK=-1, + **kwargs): + old_bm = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = False # speedup + + # print(img.shape) + img = img.cuda() + # if len(img.shape) == 3: # gray image + # img = img[None] + + B, one, H, W = img.shape + + all_pts = [] + all_descs = [] + + if 'scales' in kwargs.keys(): + scales = kwargs.get('scales') + else: + scales = [1.0] + + for s in scales: + if s == 1.0: + new_img = img + else: + nh = int(H * s) + nw = int(W * s) + new_img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=True) + nh, nw = new_img.shape[2:] + + with torch.no_grad(): + heatmap, coarse_desc = model.det(new_img) + + # print("nh, nw, heatmap, desc: ", nh, nw, heatmap.shape, coarse_desc.shape) + if len(heatmap.size()) == 3: + heatmap = heatmap.unsqueeze(1) + if len(heatmap.size()) == 2: + heatmap = heatmap.unsqueeze(0) + heatmap = heatmap.unsqueeze(1) + # print(heatmap.shape) + if heatmap.size(2) != nh or heatmap.size(3) != nw: + heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True) + + conf_thresh = conf_th + nms_dist = 4 + border_remove = 4 + scores = simple_nms(heatmap, nms_radius=nms_dist) + keypoints = [ + torch.nonzero(s > conf_thresh) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + # print(keypoints[0].shape) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + scores = scores[0].data.cpu().numpy().squeeze() + keypoints = keypoints[0].data.cpu().numpy().squeeze() + pts = keypoints.transpose() + pts[2, :] = scores + + inds = np.argsort(pts[2, :]) + pts = pts[:, inds[::-1]] # Sort by confidence. + # Remove points along border. + bord = border_remove + toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord)) + toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord)) + toremove = np.logical_or(toremoveW, toremoveH) + pts = pts[:, ~toremove] + + # valid_idex = heatmap > conf_thresh + # valid_score = heatmap[valid_idex] + # """ + # --- Process descriptor. + # coarse_desc = coarse_desc.data.cpu().numpy().squeeze() + D = coarse_desc.size(1) + if pts.shape[1] == 0: + desc = np.zeros((D, 0)) + else: + if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw: + desc = coarse_desc[:, :, pts[1, :], pts[0, :]] + desc = desc.data.cpu().numpy().reshape(D, -1) + else: + # Interpolate into descriptor map using 2D point locations. + samp_pts = torch.from_numpy(pts[:2, :].copy()) + samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1. + samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1. + samp_pts = samp_pts.transpose(0, 1).contiguous() + samp_pts = samp_pts.view(1, 1, -1, 2) + samp_pts = samp_pts.float() + samp_pts = samp_pts.cuda() + desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True) + desc = desc.data.cpu().numpy().reshape(D, -1) + desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :] + + if pts.shape[1] == 0: + continue + + # print(pts.shape, heatmap.shape, new_img.shape, img.shape, nw, nh, W, H) + pts[0, :] = pts[0, :] * W / nw + pts[1, :] = pts[1, :] * H / nh + all_pts.append(np.transpose(pts, [1, 0])) + all_descs.append(np.transpose(desc, [1, 0])) + + all_pts = np.vstack(all_pts) + all_descs = np.vstack(all_descs) + + torch.backends.cudnn.benchmark = old_bm + + if all_pts.shape[0] == 0: + return None, None, None + + keypoints = all_pts[:, 0:2] + scores = all_pts[:, 2] + descriptors = all_descs + + if mask is not None: + # cv2.imshow("mask", mask) + # cv2.waitKey(0) + labels = [] + others = [] + keypoints_with_labels = [] + scores_with_labels = [] + descriptors_with_labels = [] + keypoints_without_labels = [] + scores_without_labels = [] + descriptors_without_labels = [] + + id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0]) + # print(img.shape, id_img.shape) + + for i in range(keypoints.shape[0]): + x = keypoints[i, 0] + y = keypoints[i, 1] + # print("x-y", x, y, int(x), int(y)) + gid = id_img[int(y), int(x)] + if gid == 0: + keypoints_without_labels.append(keypoints[i]) + scores_without_labels.append(scores[i]) + descriptors_without_labels.append(descriptors[i]) + others.append(0) + else: + keypoints_with_labels.append(keypoints[i]) + scores_with_labels.append(scores[i]) + descriptors_with_labels.append(descriptors[i]) + labels.append(gid) + + if topK > 0: + if topK <= len(keypoints_with_labels): + idxes = np.array(scores_with_labels, float).argsort()[::-1][:topK] + keypoints = np.array(keypoints_with_labels, float)[idxes] + scores = np.array(scores_with_labels, float)[idxes] + labels = np.array(labels, np.int32)[idxes] + descriptors = np.array(descriptors_with_labels, float)[idxes] + elif topK >= len(keypoints_with_labels) + len(keypoints_without_labels): + # keypoints = np.vstack([keypoints_with_labels, keypoints_without_labels]) + # scores = np.vstack([scorescc_with_labels, scores_without_labels]) + # descriptors = np.vstack([descriptors_with_labels, descriptors_without_labels]) + # labels = np.vstack([labels, others]) + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in range(len(others)): + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + else: + n = topK - len(keypoints_with_labels) + idxes = np.array(scores_without_labels, float).argsort()[::-1][:n] + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in idxes: + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + keypoints = np.array(keypoints, float) + descriptors = np.array(descriptors, float) + # print(keypoints.shape, descriptors.shape) + return {"keypoints": np.array(keypoints, float), + "descriptors": np.array(descriptors, float), + "scores": np.array(scores, float), + "labels": np.array(labels, np.int32), + } + else: + # print(topK) + if topK > 0: + idxes = np.array(scores, dtype=float).argsort()[::-1][:topK] + keypoints = np.array(keypoints[idxes], dtype=float) + scores = np.array(scores[idxes], dtype=float) + descriptors = np.array(descriptors[idxes], dtype=float) + + keypoints = np.array(keypoints, dtype=float) + scores = np.array(scores, dtype=float) + descriptors = np.array(descriptors, dtype=float) + + # print(keypoints.shape, descriptors.shape) + + return {"keypoints": np.array(keypoints, dtype=float), + "descriptors": descriptors, + "scores": scores, + } diff --git a/third_party/pram/nets/utils.py b/third_party/pram/nets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..066a00510c19e0c87cf5d07a36cea2a90dd0e3eb --- /dev/null +++ b/third_party/pram/nets/utils.py @@ -0,0 +1,24 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> utils +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 10:48 +==================================================''' +import torch + +eps = 1e-8 + + +def arange_like(x, dim: int): + return x.new_ones(x.shape[dim]).cumsum(0) - 1 + + +def normalize_keypoints(kpts, image_shape): + """ Normalize keypoints locations based on image image_shape""" + _, _, height, width = image_shape + one = kpts.new_tensor(1) + size = torch.stack([one * width, one * height])[None] + center = size / 2 + scaling = size.max(1, keepdim=True).values * 0.7 + return (kpts - center[:, None, :]) / scaling[:, None, :] diff --git a/third_party/pram/recognition/recmap.py b/third_party/pram/recognition/recmap.py new file mode 100644 index 0000000000000000000000000000000000000000..c159de286e96fdb594428e88e370e1a7edbecb79 --- /dev/null +++ b/third_party/pram/recognition/recmap.py @@ -0,0 +1,1118 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> recmap +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 11:02 +==================================================''' +import argparse +import torch +import os +import os.path as osp +import numpy as np +import cv2 +import yaml +import multiprocessing as mp +from copy import deepcopy +import logging +import h5py +from tqdm import tqdm +import open3d as o3d +from sklearn.cluster import KMeans, Birch +from collections import defaultdict +from colmap_utils.read_write_model import read_model, qvec2rotmat, write_cameras_binary, write_images_binary +from colmap_utils.read_write_model import write_points3d_binary, Image, Point3D, Camera +from colmap_utils.read_write_model import write_compressed_points3d_binary, write_compressed_images_binary +from recognition.vis_seg import generate_color_dic, vis_seg_point, plot_kpts + + +class RecMap: + def __init__(self): + self.cameras = None + self.images = None + self.points3D = None + self.pcd = o3d.geometry.PointCloud() + self.seg_color_dict = generate_color_dic(n_seg=1000) + + def load_sfm_model(self, path: str, ext='.bin'): + self.cameras, self.images, self.points3D = read_model(path, ext) + self.name_to_id = {image.name: i for i, image in self.images.items()} + print('Load {:d} cameras, {:d} images, {:d} points'.format(len(self.cameras), len(self.images), + len(self.points3D))) + + def remove_statics_outlier(self, nb_neighbors: int = 20, std_ratio: float = 2.0): + xyzs = [] + p3d_ids = [] + for p3d_id in self.points3D.keys(): + xyzs.append(self.points3D[p3d_id].xyz) + p3d_ids.append(p3d_id) + + xyzs = np.array(xyzs) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(xyzs) + new_pcd, inlier_ids = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) + + new_point3Ds = {} + for i in inlier_ids: + new_point3Ds[p3d_ids[i]] = self.points3D[p3d_ids[i]] + self.points3D = new_point3Ds + n_outlier = xyzs.shape[0] - len(inlier_ids) + ratio = n_outlier / xyzs.shape[0] + print('Remove {:d} - {:d} = {:d}/{:.2f}% points'.format(xyzs.shape[0], len(inlier_ids), n_outlier, ratio * 100)) + + def load_segmentation(self, path: str): + data = np.load(path, allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + self.seg_p3d = {} + for pid in self.p3d_seg.keys(): + sid = self.p3d_seg[pid] + if sid not in self.seg_p3d.keys(): + self.seg_p3d[sid] = [pid] + else: + self.seg_p3d[sid].append(pid) + + if 'xyz' not in data.keys(): + all_xyz = [] + for pid in p3d_id: + xyz = self.points3D[pid].xyz + all_xyz.append(xyz) + data['xyz'] = np.array(all_xyz) + np.save(path, data) + print('Add xyz to ', path) + + def cluster(self, k=512, mode='xyz', min_obs=3, save_fn=None, method='kmeans', **kwargs): + if save_fn is not None: + if osp.isfile(save_fn): + print('{:s} exists.'.format(save_fn)) + return + all_xyz = [] + point3D_ids = [] + for p3d in self.points3D.values(): + track_len = len(p3d.point2D_idxs) + if track_len < min_obs: + continue + all_xyz.append(p3d.xyz) + point3D_ids.append(p3d.id) + + xyz = np.array(all_xyz) + point3D_ids = np.array(point3D_ids) + + if mode.find('x') < 0: + xyz[:, 0] = 0 + if mode.find('y') < 0: + xyz[:, 1] = 0 + if mode.find('z') < 0: + xyz[:, 2] = 0 + + if method == 'kmeans': + model = KMeans(n_clusters=k, random_state=0, verbose=True).fit(xyz) + elif method == 'birch': + model = Birch(threshold=kwargs.get('threshold'), n_clusters=k).fit(xyz) # 0.01 for indoor + else: + print('Method {:s} for clustering does not exist'.format(method)) + exit(0) + labels = np.array(model.labels_).reshape(-1) + if save_fn is not None: + np.save(save_fn, { + 'id': np.array(point3D_ids), # should be assigned to self.points3D_ids + 'label': np.array(labels), + 'xyz': np.array(all_xyz), + }) + + def assign_point3D_descriptor(self, feature_fn: str, save_fn=None, n_process=1): + ''' + assign each 3d point a descriptor for localization + :param feature_fn: file name of features [h5py] + :param save_fn: + :param n_process: + :return: + ''' + + def run(start_id, end_id, points3D_desc): + for pi in tqdm(range(start_id, end_id), total=end_id - start_id): + p3d_id = all_p3d_ids[pi] + img_list = self.points3D[p3d_id].image_ids + kpt_ids = self.points3D[p3d_id].point2D_idxs + all_descs = [] + for img_id, p2d_id in zip(img_list, kpt_ids): + if img_id not in self.images.keys(): + continue + img_fn = self.images[img_id].name + desc = feat_file[img_fn]['descriptors'][()].transpose()[p2d_id] + all_descs.append(desc) + + if len(all_descs) == 1: + points3D_desc[p3d_id] = all_descs[0] + else: + all_descs = np.array(all_descs) # [n, d] + dist = all_descs @ all_descs.transpose() # [n, n] + dist = 2 - 2 * dist + md_dist = np.median(dist, axis=-1) # [n] + min_id = np.argmin(md_dist) + points3D_desc[p3d_id] = all_descs[min_id] + + if osp.isfile(save_fn): + print('{:s} exists.'.format(save_fn)) + return + p3D_desc = {} + feat_file = h5py.File(feature_fn, 'r') + all_p3d_ids = sorted(self.points3D.keys()) + + if n_process > 1: + if len(all_p3d_ids) <= n_process: + run(start_id=0, end_id=len(all_p3d_ids), points3D_desc=p3D_desc) + else: + manager = mp.Manager() + output = manager.dict() # necessary otherwise empty + n_sample_per_process = len(all_p3d_ids) // n_process + jobs = [] + for i in range(n_process): + start_id = i * n_sample_per_process + if i == n_process - 1: + end_id = len(all_p3d_ids) + else: + end_id = (i + 1) * n_sample_per_process + p = mp.Process( + target=run, + args=(start_id, end_id, output), + ) + jobs.append(p) + p.start() + + for p in jobs: + p.join() + + p3D_desc = {} + for k in output.keys(): + p3D_desc[k] = output[k] + else: + run(start_id=0, end_id=len(all_p3d_ids), points3D_desc=p3D_desc) + + if save_fn is not None: + np.save(save_fn, p3D_desc) + + def reproject(self, img_id, xyzs): + qvec = self.images[img_id].qvec + Rcw = qvec2rotmat(qvec=qvec) + tvec = self.images[img_id].tvec + tcw = tvec.reshape(3, ) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3] = tcw + # intrinsics + cam = self.cameras[self.images[img_id].camera_id] + K = self.get_intrinsics_from_camera(camera=cam) + + xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1), dtype=float)]) + kpts = K @ ((Tcw @ xyzs_homo.transpose())[:3, :]) # [3, N] + kpts = kpts.transpose() # [N, 3] + kpts[:, 0] = kpts[:, 0] / kpts[:, 2] + kpts[:, 1] = kpts[:, 1] / kpts[:, 2] + + return kpts + + def find_covisible_frame_ids(self, image_id, images, points3D): + covis = defaultdict(int) + p3d_ids = images[image_id].point3D_ids + + for pid in p3d_ids: + if pid == -1: + continue + if pid not in points3D.keys(): + continue + for im in points3D[pid].image_ids: + covis[im] += 1 + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + ind_top = np.argsort(covis_num)[::-1] + sorted_covis_ids = [covis_ids[i] for i in ind_top] + return sorted_covis_ids + + def create_virtual_frame_3(self, save_fn=None, save_vrf_dir=None, show_time=-1, ignored_cameras=[], + min_cover_ratio=0.9, + depth_scale=1.2, + radius=15, + min_obs=120, + topk_imgs=500, + n_vrf=10, + covisible_frame=20, + **kwargs): + def reproject(img_id, xyzs): + qvec = self.images[img_id].qvec + Rcw = qvec2rotmat(qvec=qvec) + tvec = self.images[img_id].tvec + tcw = tvec.reshape(3, ) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3] = tcw + # intrinsics + cam = self.cameras[self.images[img_id].camera_id] + K = self.get_intrinsics_from_camera(camera=cam) + + xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1), dtype=float)]) + kpts = K @ ((Tcw @ xyzs_homo.transpose())[:3, :]) # [3, N] + kpts = kpts.transpose() # [N, 3] + kpts[:, 0] = kpts[:, 0] / kpts[:, 2] + kpts[:, 1] = kpts[:, 1] / kpts[:, 2] + + return kpts + + def find_best_vrf_by_covisibility(p3d_id_list): + all_img_ids = [] + all_xyzs = [] + + img_ids_full = [] + img_id_obs = {} + for pid in p3d_id_list: + if pid not in self.points3D.keys(): + continue + all_xyzs.append(self.points3D[pid].xyz) + + img_ids = self.points3D[pid].image_ids + for iid in img_ids: + if iid in all_img_ids: + continue + # valid_p3ds = [v for v in self.images[iid].point3D_ids if v > 0 and v in p3d_id_list] + if len(ignored_cameras) > 0: + ignore = False + img_name = self.images[iid].name + for c in ignored_cameras: + if img_name.find(c) >= 0: + ignore = True + break + if ignore: + continue + # valid_p3ds = np.intersect1d(np.array(self.images[iid].point3D_ids), np.array(p3d_id_list)).tolist() + valid_p3ds = [v for v in self.images[iid].point3D_ids if v > 0] + img_ids_full.append(iid) + if len(valid_p3ds) < min_obs: + continue + + all_img_ids.append(iid) + img_id_obs[iid] = len(valid_p3ds) + all_xyzs = np.array(all_xyzs) + + print('Find {} 3D points and {} images'.format(len(p3d_id_list), len(img_id_obs.keys()))) + top_img_ids_by_obs = sorted(img_id_obs.items(), key=lambda item: item[1], reverse=True) # [(key, value), ] + all_img_ids = [] + for item in top_img_ids_by_obs: + all_img_ids.append(item[0]) + if len(all_img_ids) >= topk_imgs: + break + + # all_img_ids = all_img_ids[:200] + if len(all_img_ids) == 0: + print('no valid img ids with obs over {:d}'.format(min_obs)) + all_img_ids = img_ids_full + + img_observations = {} + p3d_id_array = np.array(p3d_id_list) + for idx, img_id in enumerate(all_img_ids): + valid_p3ds = [v for v in self.images[img_id].point3D_ids if v > 0] + mask = np.array([False for i in range(len(p3d_id_list))]) + for pid in valid_p3ds: + found_idx = np.where(p3d_id_array == pid)[0] + if found_idx.shape[0] == 0: + continue + mask[found_idx[0]] = True + + img_observations[img_id] = mask + + unobserved_p3d_ids = np.array([True for i in range(len(p3d_id_list))]) + + candidate_img_ids = [] + total_cover_ratio = 0 + while total_cover_ratio < min_cover_ratio: + best_img_id = -1 + best_img_obs = -1 + for idx, im_id in enumerate(all_img_ids): + if im_id in candidate_img_ids: + continue + obs_i = np.sum(img_observations[im_id] * unobserved_p3d_ids) + if obs_i > best_img_obs: + best_img_id = im_id + best_img_obs = obs_i + + if best_img_id >= 0: + # keep the valid img_id + candidate_img_ids.append(best_img_id) + # update the unobserved mask + unobserved_p3d_ids[img_observations[best_img_id]] = False + total_cover_ratio = 1 - np.sum(unobserved_p3d_ids) / len(p3d_id_list) + print(len(candidate_img_ids), best_img_obs, best_img_obs / len(p3d_id_list), total_cover_ratio) + + if best_img_obs / len(p3d_id_list) < 0.01: + break + + if len(candidate_img_ids) >= n_vrf: + break + else: + break + + return candidate_img_ids + # return [(v, img_observations[v]) for v in candidate_img_ids] + + if save_vrf_dir is not None: + os.makedirs(save_vrf_dir, exist_ok=True) + + seg_ref = {} + for sid in self.seg_p3d.keys(): + if sid == -1: # ignore invalid segment + continue + all_p3d_ids = self.seg_p3d[sid] + candidate_img_ids = find_best_vrf_by_covisibility(p3d_id_list=all_p3d_ids) + + seg_ref[sid] = {} + for can_idx, img_id in enumerate(candidate_img_ids): + cam = self.cameras[self.images[img_id].camera_id] + width = cam.width + height = cam.height + qvec = self.images[img_id].qvec + tvec = self.images[img_id].tvec + + img_name = self.images[img_id].name + orig_p3d_ids = [p for p in self.images[img_id].point3D_ids if p in self.points3D.keys() and p >= 0] + orig_xyzs = [] + new_xyzs = [] + for pid in all_p3d_ids: + if pid in orig_p3d_ids: + orig_xyzs.append(self.points3D[pid].xyz) + else: + if pid in self.points3D.keys(): + new_xyzs.append(self.points3D[pid].xyz) + + if len(orig_xyzs) == 0: + continue + + orig_xyzs = np.array(orig_xyzs) + new_xyzs = np.array(new_xyzs) + + print('img: ', osp.join(kwargs.get('image_root'), img_name)) + img = cv2.imread(osp.join(kwargs.get('image_root'), img_name)) + orig_kpts = reproject(img_id=img_id, xyzs=orig_xyzs) + max_depth = depth_scale * np.max(orig_kpts[:, 2]) + orig_kpts = orig_kpts[:, :2] + mask_ori = (orig_kpts[:, 0] >= 0) & (orig_kpts[:, 0] < width) & (orig_kpts[:, 1] >= 0) & ( + orig_kpts[:, 1] < height) + orig_kpts = orig_kpts[mask_ori] + + if orig_kpts.shape[0] == 0: + continue + + img_kpt = plot_kpts(img=img, kpts=orig_kpts, radius=[3 for i in range(orig_kpts.shape[0])], + colors=[(0, 0, 255) for i in range(orig_kpts.shape[0])], thickness=-1) + if new_xyzs.shape[0] == 0: + img_all = img_kpt + else: + new_kpts = reproject(img_id=img_id, xyzs=new_xyzs) + mask_depth = (new_kpts[:, 2] > 0) & (new_kpts[:, 2] <= max_depth) + mask_in_img = (new_kpts[:, 0] >= 0) & (new_kpts[:, 0] < width) & (new_kpts[:, 1] >= 0) & ( + new_kpts[:, 1] < height) + dist_all_orig = torch.from_numpy(new_kpts[:, :2])[..., None] - \ + torch.from_numpy(orig_kpts[:, :2].transpose())[None] + dist_all_orig = torch.sqrt(torch.sum(dist_all_orig ** 2, dim=1)) # [N, M] + min_dist = torch.min(dist_all_orig, dim=1)[0].numpy() + mask_close_to_img = (min_dist <= radius) + + mask_new = (mask_depth & mask_in_img & mask_close_to_img) + + cover_ratio = np.sum(mask_ori) + np.sum(mask_new) + cover_ratio = cover_ratio / len(all_p3d_ids) + + print('idx: {:d}, img: ori {:d}/{:d}/{:.2f}, new {:d}/{:d}'.format(can_idx, + orig_kpts.shape[0], + np.sum(mask_ori), + cover_ratio * 100, + new_kpts.shape[0], + np.sum(mask_new))) + + new_kpts = new_kpts[mask_new] + + # img_all = img_kpt + img_all = plot_kpts(img=img_kpt, kpts=new_kpts, radius=[3 for i in range(new_kpts.shape[0])], + colors=[(0, 255, 0) for i in range(new_kpts.shape[0])], thickness=-1) + + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + cv2.imshow('img', img_all) + + if save_vrf_dir is not None: + cv2.imwrite(osp.join(save_vrf_dir, + 'seg-{:05d}_can-{:05d}_'.format(sid, can_idx) + img_name.replace('/', '+')), + img_all) + + key = cv2.waitKey(show_time) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + + covisile_frame_ids = self.find_covisible_frame_ids(image_id=img_id, images=self.images, + points3D=self.points3D) + seg_ref[sid][can_idx] = { + 'image_name': img_name, + 'image_id': img_id, + 'qvec': deepcopy(qvec), + 'tvec': deepcopy(tvec), + 'camera': { + 'model': cam.model, + 'params': cam.params, + 'width': cam.width, + 'height': cam.height, + }, + 'original_points3d': np.array( + [v for v in self.images[img_id].point3D_ids if v >= 0 and v in self.points3D.keys()]), + 'covisible_frame_ids': np.array(covisile_frame_ids[:covisible_frame]), + } + # save vrf info + if save_fn is not None: + print('Save {} segments with virtual reference image information to {}'.format(len(seg_ref.keys()), + save_fn)) + np.save(save_fn, seg_ref) + + def visualize_3Dpoints(self): + xyz = [] + rgb = [] + for point3D in self.points3D.values(): + xyz.append(point3D.xyz) + rgb.append(point3D.rgb / 255) + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(xyz) + pcd.colors = o3d.utility.Vector3dVector(rgb) + o3d.visualization.draw_geometries([pcd]) + + def visualize_segmentation(self, p3d_segs, points3D): + p3d_ids = p3d_segs.keys() + xyzs = [] + rgbs = [] + for pid in p3d_ids: + xyzs.append(points3D[pid].xyz) + seg_color = self.seg_color_dict[p3d_segs[pid]] + rgbs.append(np.array([seg_color[2], seg_color[1], seg_color[0]]) / 255) + xyzs = np.array(xyzs) + rgbs = np.array(rgbs) + + self.pcd.points = o3d.utility.Vector3dVector(xyzs) + self.pcd.colors = o3d.utility.Vector3dVector(rgbs) + + o3d.visualization.draw_geometries([self.pcd]) + + def visualize_segmentation_on_image(self, p3d_segs, image_path, feat_path): + vis_color = generate_color_dic(n_seg=1024) + feat_file = h5py.File(feat_path, 'r') + + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + for mi in sorted(self.images.keys()): + im = self.images[mi] + im_name = im.name + p3d_ids = im.point3D_ids + p2ds = feat_file[im_name]['keypoints'][()] + image = cv2.imread(osp.join(image_path, im_name)) + print('img_name: ', im_name) + + sems = [] + for pid in p3d_ids: + if pid in p3d_segs.keys(): + sems.append(p3d_segs[pid] + 1) + else: + sems.append(0) + sems = np.array(sems) + + sems = np.array(sems) + mask = sems > 0 + img_seg = vis_seg_point(img=image, kpts=p2ds[mask], segs=sems[mask], seg_color=vis_color) + + cv2.imshow('img', img_seg) + key = cv2.waitKey(0) + if key == ord('q'): + exit(0) + elif key == ord('r'): + # cv2.destroyAllWindows() + return + + def extract_query_p3ds(self, log_fn, feat_fn, save_fn=None): + if save_fn is not None: + if osp.isfile(save_fn): + print('{:s} exists'.format(save_fn)) + return + + loc_log = np.load(log_fn, allow_pickle=True)[()] + fns = loc_log.keys() + feat_file = h5py.File(feat_fn, 'r') + + out = {} + for fn in tqdm(fns, total=len(fns)): + matched_kpts = loc_log[fn]['keypoints_query'] + matched_p3ds = loc_log[fn]['points3D_ids'] + + query_kpts = feat_file[fn]['keypoints'][()].astype(float) + query_p3d_ids = np.zeros(shape=(query_kpts.shape[0],), dtype=int) - 1 + print('matched kpts: {}, query kpts: {}'.format(matched_kpts.shape[0], query_kpts.shape[0])) + + if matched_kpts.shape[0] > 0: + # [M, 2, 1] - [1, 2, N] = [M, 2, N] + dist = torch.from_numpy(matched_kpts).unsqueeze(-1) - torch.from_numpy( + query_kpts.transpose()).unsqueeze(0) + dist = torch.sum(dist ** 2, dim=1) # [M, N] + values, idxes = torch.topk(dist, dim=1, largest=False, k=1) # find the matches kpts with dist of 0 + values = values.numpy() + idxes = idxes.numpy() + for i in range(values.shape[0]): + if values[i, 0] < 1: + query_p3d_ids[idxes[i, 0]] = matched_p3ds[i] + + out[fn] = query_p3d_ids + np.save(save_fn, out) + feat_file.close() + + def compute_mean_scale_p3ds(self, min_obs=5, save_fn=None): + if save_fn is not None: + if osp.isfile(save_fn): + with open(save_fn, 'r') as f: + lines = f.readlines() + l = lines[0].strip().split() + self.mean_xyz = np.array([float(v) for v in l[:3]]) + self.scale_xyz = np.array([float(v) for v in l[3:]]) + print('{} exists'.format(save_fn)) + return + + all_xyzs = [] + for pid in self.points3D.keys(): + p3d = self.points3D[pid] + obs = len(p3d.point2D_idxs) + if obs < min_obs: + continue + all_xyzs.append(p3d.xyz) + + all_xyzs = np.array(all_xyzs) + mean_xyz = np.ceil(np.mean(all_xyzs, axis=0)) + all_xyz_ = all_xyzs - mean_xyz + + dx = np.max(abs(all_xyz_[:, 0])) + dy = np.max(abs(all_xyz_[:, 1])) + dz = np.max(abs(all_xyz_[:, 2])) + scale_xyz = np.ceil(np.array([dx, dy, dz], dtype=float).reshape(3, )) + scale_xyz[scale_xyz < 1] = 1 + scale_xyz[scale_xyz == 0] = 1 + + # self.mean_xyz = mean_xyz + # self.scale_xyz = scale_xyz + # + # if save_fn is not None: + # with open(save_fn, 'w') as f: + # text = '{:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'.format(mean_xyz[0], mean_xyz[1], mean_xyz[2], + # scale_xyz[0], scale_xyz[1], scale_xyz[2]) + # f.write(text + '\n') + + def compute_statics_inlier(self, xyz, nb_neighbors=20, std_ratio=2.0): + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(xyz) + + new_pcd, inlier_ids = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) + return inlier_ids + + def export_features_to_directory(self, feat_fn, save_dir, with_descriptors=True): + def print_grp_name(grp_name, object): + try: + n_subgroups = len(object.keys()) + except: + n_subgroups = 0 + dataset_list.append(object.name) + + dataset_list = [] + feat_file = h5py.File(feat_fn, 'r') + feat_file.visititems(print_grp_name) + all_keys = [] + os.makedirs(save_dir, exist_ok=True) + for fn in dataset_list: + subs = fn[1:].split('/')[:-1] # remove the first '/' + subs = '/'.join(map(str, subs)) + if subs in all_keys: + continue + all_keys.append(subs) + + for fn in tqdm(all_keys, total=len(all_keys)): + feat = feat_file[fn] + data = { + # 'descriptors': feat['descriptors'][()].transpose(), + 'scores': feat['scores'][()], + 'keypoints': feat['keypoints'][()], + 'image_size': feat['image_size'][()] + } + np.save(osp.join(save_dir, fn.replace('/', '+')), data) + feat_file.close() + + def get_intrinsics_from_camera(self, camera): + if camera.model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = camera.params[0] + cx = camera.params[1] + cy = camera.params[2] + elif camera.model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = camera.params[0] + fy = camera.params[1] + cx = camera.params[2] + cy = camera.params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K + + def compress_map_by_projection_v2(self, vrf_path, point3d_desc_path, vrf_frames=1, covisible_frames=20, radius=20, + nkpts=-1, save_dir=None): + def sparsify_by_grid(h, w, uvs, scores): + nh = np.ceil(h / radius).astype(int) + nw = np.ceil(w / radius).astype(int) + grid = {} + for ip in range(uvs.shape[0]): + p = uvs[ip] + iw = np.rint(p[0] // radius).astype(int) + ih = np.rint(p[1] // radius).astype(int) + idx = ih * nw + iw + if idx in grid.keys(): + if scores[ip] <= grid[idx]['score']: + continue + else: + grid[idx]['score'] = scores[ip] + grid[idx]['ip'] = ip + else: + grid[idx] = { + 'score': scores[ip], + 'ip': ip + } + + retained_ips = [grid[v]['ip'] for v in grid.keys()] + retained_ips = np.array(retained_ips) + return retained_ips + + def choose_valid_p3ds(current_frame_id, covisible_frame_ids, reserved_images): + curr_p3d_ids = [] + curr_xyzs = [] + for pid in self.images[current_frame_id].point3D_ids: + if pid == -1: + continue + if pid not in self.points3D.keys(): + continue + curr_p3d_ids.append(pid) + curr_xyzs.append(self.points3D[pid].xyz) + curr_xyzs = np.array(curr_xyzs) # [N, 3] + curr_xyzs_homo = np.hstack([curr_xyzs, np.ones((curr_xyzs.shape[0], 1), dtype=curr_xyzs.dtype)]) # [N, 4] + + curr_mask = np.array([True for mi in range(curr_xyzs.shape[0])]) # keep all at first + for iim in covisible_frame_ids: + cam_id = self.images[iim].camera_id + width = self.cameras[cam_id].width + height = self.cameras[cam_id].height + qvec = self.images[iim].qvec + tcw = self.images[iim].tvec + Rcw = qvec2rotmat(qvec=qvec) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3] = tcw.reshape(3, ) + + uvs = reserved_images[iim]['xys'] + K = self.get_intrinsics_from_camera(camera=self.cameras[cam_id]) + proj_xys = K @ (Tcw @ curr_xyzs_homo.transpose())[:3, :] # [3, ] + proj_xys = proj_xys.transpose() + depth = proj_xys[:, 2] + proj_xys[:, 0] = proj_xys[:, 0] / depth + proj_xys[:, 1] = proj_xys[:, 1] / depth + + mask_in_image = (proj_xys[:, 0] >= 0) * (proj_xys[:, 0] < width) * (proj_xys[:, 1] >= 0) * ( + proj_xys[:, 1] < height) + mask_depth = proj_xys[:, 2] > 0 + + dist_proj_uv = torch.from_numpy(proj_xys[:, :2])[..., None] - \ + torch.from_numpy(uvs[:, :2].transpose())[None] + dist_proj_uv = torch.sqrt(torch.sum(dist_proj_uv ** 2, dim=1)) # [N, M] + min_dist = torch.min(dist_proj_uv, dim=1)[0].numpy() + mask_close_to_img = (min_dist <= radius) + + mask = mask_in_image * mask_depth * mask_close_to_img # p3ds to be discarded + + curr_mask = curr_mask * (1 - mask) + + chosen_p3d_ids = [] + for mi in range(curr_mask.shape[0]): + if curr_mask[mi]: + chosen_p3d_ids.append(curr_p3d_ids[mi]) + + return chosen_p3d_ids + + vrf_data = np.load(vrf_path, allow_pickle=True)[()] + p3d_ids_in_vrf = [] + image_ids_in_vrf = [] + for sid in vrf_data.keys(): + svrf = vrf_data[sid] + svrf_keys = [vi for vi in range(vrf_frames)] + for vi in svrf_keys: + if vi not in svrf.keys(): + continue + image_id = svrf[vi]['image_id'] + if image_id in image_ids_in_vrf: + continue + image_ids_in_vrf.append(image_id) + for pid in svrf[vi]['original_points3d']: + if pid in p3d_ids_in_vrf: + continue + p3d_ids_in_vrf.append(pid) + + print('Find {:d} images and {:d} 3D points in vrf'.format(len(image_ids_in_vrf), len(p3d_ids_in_vrf))) + + # first_vrf_images_covis = {} + retained_image_ids = {} + for frame_id in image_ids_in_vrf: + observed = self.images[frame_id].point3D_ids + xys = self.images[frame_id].xys + covis = defaultdict(int) + valid_xys = [] + valid_p3d_ids = [] + for xy, pid in zip(xys, observed): + if pid == -1: + continue + if pid not in self.points3D.keys(): + continue + valid_xys.append(xy) + valid_p3d_ids.append(pid) + for img_id in self.points3D[pid].image_ids: + covis[img_id] += 1 + + retained_image_ids[frame_id] = { + 'xys': np.array(valid_xys), + 'p3d_ids': valid_p3d_ids, + } + + print('Find {:d} valid connected frames'.format(len(covis.keys()))) + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= covisible_frames: + sel_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + ind_top = np.argpartition(covis_num, -covisible_frames) + ind_top = ind_top[-covisible_frames:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + sel_covis_ids = [covis_ids[i] for i in ind_top] + + covis_frame_ids = [frame_id] + for iim in sel_covis_ids: + if iim == frame_id: + continue + if iim in retained_image_ids.keys(): + covis_frame_ids.append(iim) + continue + + chosen_p3d_ids = choose_valid_p3ds(current_frame_id=iim, covisible_frame_ids=covis_frame_ids, + reserved_images=retained_image_ids) + if len(chosen_p3d_ids) == 0: + continue + + xys = [] + for xy, pid in zip(self.images[iim].xys, self.images[iim].point3D_ids): + if pid in chosen_p3d_ids: + xys.append(xy) + xys = np.array(xys) + + covis_frame_ids.append(iim) + retained_image_ids[iim] = { + 'xys': xys, + 'p3d_ids': chosen_p3d_ids, + } + + new_images = {} + new_point3Ds = {} + new_cameras = {} + for iim in retained_image_ids.keys(): + p3d_ids = retained_image_ids[iim]['p3d_ids'] + ''' this step reduces the performance + for v in self.images[iim].point3D_ids: + if v == -1 or v not in self.points3D: + continue + if v in p3d_ids: + continue + p3d_ids.append(v) + ''' + + xyzs = np.array([self.points3D[pid].xyz for pid in p3d_ids]) + obs = np.array([len(self.points3D[pid].point2D_idxs) for pid in p3d_ids]) + xys = self.images[iim].xys + cam_id = self.images[iim].camera_id + name = self.images[iim].name + qvec = self.images[iim].qvec + tvec = self.images[iim].tvec + + if nkpts > 0 and len(p3d_ids) > nkpts: + proj_uvs = self.reproject(img_id=iim, xyzs=xyzs) + width = self.cameras[cam_id].width + height = self.cameras[cam_id].height + sparsified_idxs = sparsify_by_grid(h=height, w=width, uvs=proj_uvs[:, :2], scores=obs) + + print('org / new kpts: ', len(p3d_ids), sparsified_idxs.shape) + + p3d_ids = [p3d_ids[k] for k in sparsified_idxs] + + new_images[iim] = Image(id=iim, qvec=qvec, tvec=tvec, + camera_id=cam_id, + name=name, + xys=np.array([]), + point3D_ids=np.array(p3d_ids)) + + if cam_id not in new_cameras.keys(): + new_cameras[cam_id] = self.cameras[cam_id] + + for pid in p3d_ids: + if pid in new_point3Ds.keys(): + new_point3Ds[pid]['image_ids'].append(iim) + else: + xyz = self.points3D[pid].xyz + rgb = self.points3D[pid].rgb + error = self.points3D[pid].error + + new_point3Ds[pid] = { + 'image_ids': [iim], + 'rgb': rgb, + 'xyz': xyz, + 'error': error + } + + new_point3Ds_to_save = {} + for pid in new_point3Ds.keys(): + image_ids = new_point3Ds[pid]['image_ids'] + if len(image_ids) == 0: + continue + xyz = new_point3Ds[pid]['xyz'] + rgb = new_point3Ds[pid]['rgb'] + error = new_point3Ds[pid]['error'] + + new_point3Ds_to_save[pid] = Point3D(id=pid, xyz=xyz, rgb=rgb, error=error, image_ids=np.array(image_ids), + point2D_idxs=np.array([])) + + print('Retain {:d}/{:d} images and {:d}/{:d} 3D points'.format(len(new_images), len(self.images), + len(new_point3Ds), len(self.points3D))) + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + # write_images_binary(images=new_image_ids, + # path_to_model_file=osp.join(save_dir, 'images.bin')) + # write_points3d_binary(points3D=new_point3Ds, + # path_to_model_file=osp.join(save_dir, 'points3D.bin')) + write_compressed_images_binary(images=new_images, + path_to_model_file=osp.join(save_dir, 'images.bin')) + write_cameras_binary(cameras=new_cameras, + path_to_model_file=osp.join(save_dir, 'cameras.bin')) + write_compressed_points3d_binary(points3D=new_point3Ds_to_save, + path_to_model_file=osp.join(save_dir, 'points3D.bin')) + + # Save 3d descriptors + p3d_desc = np.load(point3d_desc_path, allow_pickle=True)[()] + comp_p3d_desc = {} + for k in new_point3Ds_to_save.keys(): + if k not in p3d_desc.keys(): + print(k) + continue + comp_p3d_desc[k] = deepcopy(p3d_desc[k]) + np.save(osp.join(save_dir, point3d_desc_path.split('/')[-1]), comp_p3d_desc) + print('Save data to {:s}'.format(save_dir)) + + +def process_dataset(dataset, dataset_dir, sfm_dir, save_dir, feature='sfd2', matcher='gml'): + # dataset_dir = '/scratches/flyer_3/fx221/dataset' + # sfm_dir = '/scratches/flyer_2/fx221/localization/outputs' # your sfm results (cameras, images, points3D) and features + # save_dir = '/scratches/flyer_3/fx221/exp/localizer' + # local_feat = 'sfd2' + # matcher = 'gml' + # hloc_results_dir = '/scratches/flyer_2/fx221/exp/sgd2' + + # config_path = 'configs/datasets/CUED.yaml' + # config_path = 'configs/datasets/7Scenes.yaml' + # config_path = 'configs/datasets/12Scenes.yaml' + # config_path = 'configs/datasets/CambridgeLandmarks.yaml' + # config_path = 'configs/datasets/Aachen.yaml' + + # config_path = 'configs/datasets/Aria.yaml' + # config_path = 'configs/datasets/DarwinRGB.yaml' + # config_path = 'configs/datasets/ACUED.yaml' + # config_path = 'configs/datasets/JesusCollege.yaml' + # config_path = 'configs/datasets/CUED2Kings.yaml' + + config_path = 'configs/datasets/{:s}.yaml'.format(dataset) + with open(config_path, 'rt') as f: + configs = yaml.load(f, Loader=yaml.Loader) + print(configs) + + dataset = configs['dataset'] + all_scenes = configs['scenes'] + for scene in all_scenes: + n_cluster = configs[scene]['n_cluster'] + cluster_mode = configs[scene]['cluster_mode'] + cluster_method = configs[scene]['cluster_method'] + # if scene not in ['heads']: + # continue + + print('scene: ', scene, cluster_mode, cluster_method) + # hloc_path = osp.join(hloc_root, dataset, scene) + sfm_path = osp.join(sfm_dir, scene) + save_path = osp.join(save_dir, feature + '-' + matcher, dataset, scene) + + n_vrf = 1 + n_cov = 30 + radius = 20 + n_kpts = 0 + + if dataset in ['Aachen']: + image_path = osp.join(dataset_dir, scene, 'images/images_upright') + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 32 + + elif dataset in ['CambridgeLandmarks', ]: + image_path = osp.join(dataset_dir, scene) + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 64 + elif dataset in ['Aria']: + image_path = osp.join(dataset_dir, scene) + min_obs = 150 + filtering_outliers = False + threshold = 0.01 + radius = 15 + elif dataset in ['DarwinRGB']: + image_path = osp.join(dataset_dir, scene) + min_obs = 150 + filtering_outliers = True + threshold = 0.2 + radius = 16 + elif dataset in ['ACUED']: + image_path = osp.join(dataset_dir, scene) + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 32 + elif dataset in ['7Scenes', '12Scenes']: + image_path = osp.join(dataset_dir, scene) + min_obs = 150 + filtering_outliers = False + threshold = 0.01 + radius = 15 + else: + image_path = osp.join(dataset_dir, scene) + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 32 + + # comp_map_sub_path = 'comp_model_n{:d}_{:s}_{:s}_vrf{:d}_cov{:d}_r{:d}_np{:d}_projection_v2'.format(n_cluster, + # cluster_mode, + # cluster_method, + # n_vrf, + # n_cov, + # radius, + # n_kpts) + comp_map_sub_path = 'compress_model_{:s}'.format(cluster_method) + seg_fn = osp.join(save_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method)) + vrf_fn = osp.join(save_path, + 'point3D_vrf_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method)) + vrf_img_dir = osp.join(save_path, + 'point3D_vrf_n{:d}_{:s}_{:s}'.format(n_cluster, cluster_mode, cluster_method)) + # p3d_query_fn = osp.join(save_path, + # 'point3D_query_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method)) + comp_map_path = osp.join(save_path, comp_map_sub_path) + + os.makedirs(save_path, exist_ok=True) + + rmap = RecMap() + rmap.load_sfm_model(path=osp.join(sfm_path, 'sfm_{:s}-{:s}'.format(feature, matcher))) + if filtering_outliers: + rmap.remove_statics_outlier(nb_neighbors=20, std_ratio=2.0) + + # extract keypoints to train the recognition model (descriptors are recomputed from augmented db images) + # we do this for ddp training (reading h5py file is not supported) + rmap.export_features_to_directory(feat_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(feature)), + save_dir=osp.join(save_path, 'feats')) # only once for training + + rmap.cluster(k=n_cluster, mode=cluster_mode, save_fn=seg_fn, method=cluster_method, threshold=threshold) + # rmap.visualize_3Dpoints() + rmap.load_segmentation(path=seg_fn) + # rmap.visualize_segmentation(p3d_segs=rmap.p3d_seg, points3D=rmap.points3D) + + # Assign each 3D point a desciptor and discard all 2D images and descriptors - for localization + rmap.assign_point3D_descriptor( + feature_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(feature)), + save_fn=osp.join(save_path, 'point3D_desc.npy'.format(n_cluster, cluster_mode)), + n_process=32) # only once + + # exit(0) + # rmap.visualize_segmentation_on_image(p3d_segs=rmap.p3d_seg, image_path=image_path, feat_path=feat_path) + + # for query images only - for evaluation + # rmap.extract_query_p3ds( + # log_fn=osp.join(hloc_path, 'hloc_feats-{:s}_{:s}_loc.npy'.format(local_feat, matcher)), + # feat_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(local_feat)), + # save_fn=p3d_query_fn) + # continue + + # up-to-date + rmap.create_virtual_frame_3( + save_fn=vrf_fn, + save_vrf_dir=vrf_img_dir, + image_root=image_path, + show_time=5, + min_cover_ratio=0.9, + radius=radius, + depth_scale=2.5, # 1.2 by default + min_obs=min_obs, + n_vrf=10, + covisible_frame=n_cov, + ignored_cameras=[]) + + # up-to-date + rmap.compress_map_by_projection_v2( + vrf_frames=n_vrf, + vrf_path=vrf_fn, + point3d_desc_path=osp.join(save_path, 'point3D_desc.npy'), + save_dir=comp_map_path, + covisible_frames=n_cov, + radius=radius, + nkpts=n_kpts, + ) + + # exit(0) + # soft_link_compress_path = osp.join(save_path, 'compress_model_{:s}'.format(cluster_method)) + os.chdir(save_path) + # if osp.isdir(soft_link_compress_path): + # os.unlink(soft_link_compress_path) + # os.symlink(comp_map_sub_path, 'compress_model_{:s}'.format(cluster_method)) + # create a soft link of the full model for training + if not osp.isdir('model'): + os.symlink(osp.join(sfm_path, 'sfm_{:s}-{:s}'.format(feature, matcher)), '3D-models') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, required=True, help='dataset name') + parser.add_argument('--dataset_dir', type=str, required=True, help='dataset dir') + parser.add_argument('--sfm_dir', type=str, required=True, help='sfm dir') + parser.add_argument('--save_dir', type=str, required=True, help='dir to save the landmarks data') + parser.add_argument('--feature', type=str, default='sfd2', help='feature name e.g., SP, SFD2') + parser.add_argument('--matcher', type=str, default='gml', help='matcher name e.g., SG, LSG, gml') + + args = parser.parse_args() + + process_dataset( + dataset=args.dataset, + dataset_dir=args.dataset_dir, + sfm_dir=args.sfm_dir, + save_dir=args.save_dir, + feature=args.feature, + matcher=args.matcher) diff --git a/third_party/pram/recognition/vis_seg.py b/third_party/pram/recognition/vis_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef9b2365787e5921a66c74ff6c0b5ec3e49a31a --- /dev/null +++ b/third_party/pram/recognition/vis_seg.py @@ -0,0 +1,225 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> vis_seg +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 11:06 +==================================================''' +import cv2 +import numpy as np +from copy import deepcopy + + +def myHash(text: str): + hash = 0 + for ch in text: + hash = (hash * 7879 ^ ord(ch) * 5737) & 0xFFFFFFFF + return hash + + +def generate_color_dic(n_seg=1000): + out = {} + for i in range(n_seg + 1): + sid = i + if sid == 0: + color = (0, 0, 255) # [b, g, r] + else: + # rgb_new = hash(str(sid * 319993)) + rgb_new = myHash(str(sid * 319993)) + r = (rgb_new & 0xFF0000) >> 16 + g = (rgb_new & 0x00FF00) >> 8 + b = rgb_new & 0x0000FF + color = (b, g, r) + out[i] = color + return out + + +def vis_seg_point(img, kpts, segs=None, seg_color=None, radius=7, thickness=-1): + outimg = deepcopy(img) + for i in range(kpts.shape[0]): + # print(kpts[i]) + if segs is not None and seg_color is not None: + color = seg_color[segs[i]] + else: + color = (0, 255, 0) + outimg = cv2.circle(outimg, + center=(int(kpts[i, 0]), int(kpts[i, 1])), + color=color, + radius=radius, + thickness=thickness, ) + + return outimg + + +def vis_corr_incorr_point(img, kpts, pred_segs, gt_segs, radius=7, thickness=-1): + outimg = deepcopy(img) + for i in range(kpts.shape[0]): + # print(kpts[i]) + p_seg = pred_segs[i] + g_seg = gt_segs[i] + if p_seg == g_seg: + if g_seg != 0: + color = (0, 255, 0) + else: + color = (255, 0, 0) + else: + color = (0, 0, 255) + outimg = cv2.circle(outimg, + center=(int(kpts[i, 0]), int(kpts[i, 1])), + color=color, + radius=radius, + thickness=thickness, ) + return outimg + + +def vis_inlier(img, kpts, inliers, radius=7, thickness=1, with_outlier=True): + outimg = deepcopy(img) + for i in range(kpts.shape[0]): + if not with_outlier: + if not inliers[i]: + continue + if inliers[i]: + color = (0, 255, 0) + else: + color = (0, 0, 255) + outimg = cv2.rectangle(outimg, + pt1=(int(kpts[i, 0] - radius), int(kpts[i, 1] - radius)), + pt2=(int(kpts[i, 0] + radius), int(kpts[i, 1] + radius)), + color=color, + thickness=thickness, ) + + return outimg + + +def vis_global_seg(cls, seg_color, radius=7, thickness=-1): + all_patches = [] + for i in range(cls.shape[0]): + if cls[i] == 0: + continue + color = seg_color[i] + patch = np.zeros(shape=(radius, radius, 3), dtype=np.uint8) + patch[..., 0] = color[0] + patch[..., 1] = color[1] + patch[..., 2] = color[2] + + all_patches.append(patch) + if len(all_patches) == 0: + color = seg_color[0] + patch = np.zeros(shape=(radius, radius, 3), dtype=np.uint8) + patch[..., 0] = color[0] + patch[..., 1] = color[1] + patch[..., 2] = color[2] + all_patches.append(patch) + return np.vstack(all_patches) + + +def plot_matches(img1, img2, pts1, pts2, inliers, radius=3, line_thickness=2, horizon=True, plot_outlier=False, + confs=None): + rows1 = img1.shape[0] + cols1 = img1.shape[1] + rows2 = img2.shape[0] + cols2 = img2.shape[1] + # r = 3 + if horizon: + img_out = np.zeros((max([rows1, rows2]), cols1 + cols2, 3), dtype='uint8') + # Place the first image to the left + img_out[:rows1, :cols1] = img1 + # Place the next image to the right of it + img_out[:rows2, cols1:] = img2 # np.dstack([img2, img2, img2]) + for idx in range(inliers.shape[0]): + # if idx % 10 > 0: + # continue + if inliers[idx]: + color = (0, 255, 0) + else: + if not plot_outlier: + continue + color = (0, 0, 255) + pt1 = pts1[idx] + pt2 = pts2[idx] + + if confs is not None: + nr = int(radius * confs[idx]) + else: + nr = radius + img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2) + + img_out = cv2.circle(img_out, (int(pt2[0]) + cols1, int(pt2[1])), nr, color, 2) + + img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]) + cols1, int(pt2[1])), color, + line_thickness) + else: + img_out = np.zeros((rows1 + rows2, max([cols1, cols2]), 3), dtype='uint8') + # Place the first image to the left + img_out[:rows1, :cols1] = img1 + # Place the next image to the right of it + img_out[rows1:, :cols2] = img2 # np.dstack([img2, img2, img2]) + + for idx in range(inliers.shape[0]): + # print("idx: ", inliers[idx]) + # if idx % 10 > 0: + # continue + if inliers[idx]: + color = (0, 255, 0) + else: + if not plot_outlier: + continue + color = (0, 0, 255) + + if confs is not None: + nr = int(radius * confs[idx]) + else: + nr = radius + + pt1 = pts1[idx] + pt2 = pts2[idx] + img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2) + + img_out = cv2.circle(img_out, (int(pt2[0]), int(pt2[1]) + rows1), nr, color, 2) + + img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1]) + rows1), color, + line_thickness) + + return img_out + + +def plot_kpts(img, kpts, radius=None, colors=None, r=3, color=(0, 0, 255), nh=-1, nw=-1, shape='o', show_text=None, + thickness=5): + img_out = deepcopy(img) + for i in range(kpts.shape[0]): + pt = kpts[i] + if radius is not None: + if shape == 'o': + img_out = cv2.circle(img_out, center=(int(pt[0]), int(pt[1])), radius=radius[i], + color=color if colors is None else colors[i], + thickness=thickness) + elif shape == '+': + img_out = cv2.line(img_out, pt1=(int(pt[0] - radius[i]), int(pt[1])), + pt2=(int(pt[0] + radius[i]), int(pt[1])), + color=color if colors is None else colors[i], + thickness=5) + img_out = cv2.line(img_out, pt1=(int(pt[0]), int(pt[1] - radius[i])), + pt2=(int(pt[0]), int(pt[1] + radius[i])), color=color, + thickness=thickness) + else: + if shape == 'o': + img_out = cv2.circle(img_out, center=(int(pt[0]), int(pt[1])), radius=r, + color=color if colors is None else colors[i], + thickness=thickness) + elif shape == '+': + img_out = cv2.line(img_out, pt1=(int(pt[0] - r), int(pt[1])), + pt2=(int(pt[0] + r), int(pt[1])), color=color if colors is None else colors[i], + thickness=thickness) + img_out = cv2.line(img_out, pt1=(int(pt[0]), int(pt[1] - r)), + pt2=(int(pt[0]), int(pt[1] + r)), color=color if colors is None else colors[i], + thickness=thickness) + + if show_text is not None: + img_out = cv2.putText(img_out, show_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, + (0, 0, 255), 3) + if nh == -1 and nw == -1: + return img_out + if nh > 0: + return cv2.resize(img_out, dsize=(int(img.shape[1] / img.shape[0] * nh), nh)) + if nw > 0: + return cv2.resize(img_out, dsize=(nw, int(img.shape[0] / img.shape[1] * nw))) diff --git a/third_party/pram/sfm_scripts/reconstruct_12scenes.sh b/third_party/pram/sfm_scripts/reconstruct_12scenes.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f79e356a73f897f9e5a3db5cdf4cbf4b689275c --- /dev/null +++ b/third_party/pram/sfm_scripts/reconstruct_12scenes.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# you need to use your own path + +dataset_dir=/scratches/flyer_3/fx221/dataset/12Scenes +ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/12Scenes +output_dir=/scratches/flyer_2/fx221/localization/outputs/12Scenes + +feat=sfd2 +matcher=gm + +#feat=superpoint-n4096 +#matcher=superglue + +extract_feat_db=1 +match_db=1 +triangulation=1 +localize=1 + +ransac_thresh=8 +opt_thresh=8 +covisibility_frame=20 +inlier_thresh=30 +obs_thresh=3 + + +#for scene in apt1 apt2 office1 office2 +for scene in apt2 office1 office2 +do + echo $scene + + if [ "$scene" = "apt1" ]; then + all_subscenes='kitchen living' + elif [ "$scene" = "apt2" ]; then + all_subscenes='bed kitchen living luke' + elif [ "$scene" = "office1" ]; then + all_subscenes='gates362 gates381 lounge manolis' + elif [ "$scene" = "office2" ]; then + all_subscenes='5a 5b' + fi + + for subscene in $all_subscenes + do + echo $subscene + + image_dir=$dataset_dir/$scene/$subscene + ref_sfm=$ref_sfm_dir/$scene/$subscene/3D-models + db_pair=$ref_sfm_dir/$scene/$subscene/pairs-db-covis20.txt + outputs=$output_dir/$scene/$subscene + query_pair=$ref_sfm_dir/$scene/$subscene/pairs-query-netvlad20.txt + gt_pose_fn=$ref_sfm_dir/$scene/$subscene/queries_poses.txt + query_fn=$ref_sfm_dir/$scene/$subscene/queries_with_intrinsics.txt + + if [ "$extract_feat_db" -gt "0" ]; then + python3 -m loc.extract_features --image_dir $image_dir --export_dir $outputs/ --conf $feat + fi + + if [ "$match_db" -gt "0" ]; then + python3 -m loc.match_features --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat + fi + + if [ "$triangulation" -gt "0" ]; then + python3 -m loc.triangulation \ + --sfm_dir $outputs/sfm_$feat-$matcher \ + --reference_sfm_model $ref_sfm \ + --image_dir $image_dir \ + --pairs $db_pair \ + --features $outputs/feats-$feat.h5 \ + --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 + fi + + if [ "$localize" -gt "0" ]; then + python3 -m loc.localizer \ + --dataset 12Scenes \ + --image_dir $image_dir \ + --save_root $outputs \ + --gt_pose_fn $gt_pose_fn \ + --retrieval $query_pair \ + --reference_sfm $outputs/sfm_$feat-$matcher \ + --queries $query_fn \ + --features $outputs/feats-$feat.h5 \ + --matcher_method $matcher \ + --ransac_thresh $ransac_thresh \ + --covisibility_frame $covisibility_frame \ + --obs_thresh $obs_thresh \ + --opt_thresh $opt_thresh \ + --inlier_thresh $inlier_thresh \ + --use_hloc + fi + done + +done diff --git a/third_party/pram/sfm_scripts/reconstruct_7scenes.sh b/third_party/pram/sfm_scripts/reconstruct_7scenes.sh new file mode 100644 index 0000000000000000000000000000000000000000..91fb16dabc2a294476c0865fc4a5e12e2b4cf0b7 --- /dev/null +++ b/third_party/pram/sfm_scripts/reconstruct_7scenes.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +# you need to use your own path +dataset_dir=/scratches/flyer_3/fx221/dataset/7Scenes +ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/7Scenes +output_dir=/scratches/flyer_2/fx221/publications/test_pram/7Scenes + +# keypoints and matcher used for sfm +feat=sfd2 +matcher=gml + + +extract_feat_db=1 +match_db=1 +triangulation=1 +localize=0 + + +ransac_thresh=12 +opt_thresh=12 +covisibility_frame=20 +inlier_thresh=30 +obs_thresh=3 + + +for scene in heads fire office stairs pumpkin redkitchen chess +#for scene in fire office pumpkin redkitchen chess +#for scene in chess +do + echo $scene + image_dir=$dataset_dir/$scene + ref_sfm=$ref_sfm_dir/$scene/3D-models + db_pair=$ref_sfm_dir/$scene/pairs-db-covis20.txt + outputs=$output_dir/$scene + query_pair=$ref_sfm_dir/$scene/pairs-query-netvlad20.txt + gt_pose_fn=$ref_sfm_dir/$scene/queries_poses.txt + query_fn=$ref_sfm_dir/$scene/queries_with_intrinsics.txt + + if [ "$extract_feat_db" -gt "0" ]; then + python3 -m localization.extract_features --image_dir $image_dir --export_dir $outputs/ --conf $feat + fi + + if [ "$match_db" -gt "0" ]; then + python3 -m localization.match_features --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat + fi + + if [ "$triangulation" -gt "0" ]; then + python3 -m localization.triangulation \ + --sfm_dir $outputs/sfm_$feat-$matcher \ + --reference_sfm_model $ref_sfm \ + --image_dir $image_dir \ + --pairs $db_pair \ + --features $outputs/feats-$feat.h5 \ + --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 + fi + + if [ "$localize" -gt "0" ]; then + python3 -m localization.localizer \ + --dataset 7Scenes \ + --image_dir $image_dir \ + --save_root $outputs \ + --gt_pose_fn $gt_pose_fn \ + --retrieval $query_pair \ + --reference_sfm $outputs/sfm_$feat-$matcher \ + --queries $query_fn \ + --features $outputs/feats-$feat.h5 \ + --matcher_method $matcher \ + --ransac_thresh $ransac_thresh \ + --covisibility_frame $covisibility_frame \ + --obs_thresh $obs_thresh \ + --opt_thresh $opt_thresh \ + --inlier_thresh $inlier_thresh \ + --use_hloc + fi +done \ No newline at end of file diff --git a/third_party/pram/sfm_scripts/reconstruct_aachen.sh b/third_party/pram/sfm_scripts/reconstruct_aachen.sh new file mode 100644 index 0000000000000000000000000000000000000000..510485e521511f1948060c5d0de5f56984586c8d --- /dev/null +++ b/third_party/pram/sfm_scripts/reconstruct_aachen.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# you need to use your own path +dataset_dir=/scratches/flyer_3/fx221/dataset/Aachen/Aachenv11 +ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/Aachen/Aachenv11 +output_dir=/scratches/flyer_2/fx221/localization/outputs/Aachen/Aachenv11 + +# fixed +output=$output_dir +ref_sfm=$ref_sfm_dir/3D-models +db_pair=$ref_sfm_dir/pairs-db-covis20.txt +query_pair=$ref_sfm_dir/pairs-query-netvlad50.txt +gt_pose_fn=$ref_sfm_dir/queries_pose_spp_spg.txt +query_fn=$ref_sfm_dir/queries_with_intrinsics.txt + + + +feat=sfd2 +matcher=gm + +#feat=superpoint-n4096 +#matcher=superglue + +extract_feat_db=1 +match_db=1 +triangulation=1 +localize=1 + +if [ "$extract_feat_db" -gt "0" ]; then + python3 -m loc.extract_features --image_dir $dataset/images/images_upright --export_dir $outputs/ --conf $feat +fi + +if [ "$match_db" -gt "0" ]; then + python3 -m loc.match_features --pairs $ref_sfm_dir/pairs-db-covis20.txt --export_dir $outputs/ --conf $matcher --features feats-$feat +fi + +if [ "$triangulation" -gt "0" ]; then + python3 -m loc.triangulation \ + --sfm_dir $outputs/sfm_$feat-$matcher \ + --reference_sfm_model $ref_sfm \ + --image_dir $dataset/images/images_upright \ + --pairs $db_pair \ + --features $outputs/feats-$feat.h5 \ + --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 +fi + +ransac_thresh=15 +opt_thresh=15 +covisibility_frame=30 +inlier_thresh=80 +obs_thresh=3 + +if [ "$localize" -gt "0" ]; then + python3 -m loc.localizer \ + --dataset aachen_v1.1 \ + --image_dir $image_dir \ + --save_root $outputs \ + --gt_pose_fn $gt_pose_fn \ + --retrieval $query_pair \ + --reference_sfm $outputs/sfm_$feat-$matcher \ + --queries $query_fn \ + --features $outputs/feats-$feat.h5 \ + --matcher_method $matcher \ + --ransac_thresh $ransac_thresh \ + --covisibility_frame $covisibility_frame \ + --obs_thresh $obs_thresh \ + --opt_thresh $opt_thresh \ + --inlier_thresh $inlier_thresh \ + --use_hloc +fi \ No newline at end of file diff --git a/third_party/pram/sfm_scripts/reconstruct_cambridge.sh b/third_party/pram/sfm_scripts/reconstruct_cambridge.sh new file mode 100644 index 0000000000000000000000000000000000000000..f1ee967cf94e16e4a2f1848436d236df9a273858 --- /dev/null +++ b/third_party/pram/sfm_scripts/reconstruct_cambridge.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# you need to use your own path +dataset_dir=/scratches/flyer_3/fx221/dataset/CambridgeLandmarks +ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/CambridgeLandmarks +output_dir=/scratches/flyer_2/fx221/localization/outputs/CambridgeLandmarks + + +feat=sfd2 +matcher=gm + +extract_feat_db=0 +match_db=0 +triangulation=0 +localize=1 + +ransac_thresh=12 +opt_thresh=12 +covisibility_frame=20 +inlier_thresh=30 +radius=30 +obs_thresh=3 + + +#for scene in GreatCourt ShopFacade KingsCollege OldHospital StMarysChurch +for scene in StMarysChurch +#for scene in GreatCourt ShopFacade +do + echo $scene + + image_dir=$dataset_dir/$scene + ref_sfm=$ref_sfm_dir/$scene/3D-models + db_pair=$ref_sfm_dir/$scene/pairs-db-covis20.txt + outputs=$output_dir/$scene + query_pair=$ref_sfm_dir/$scene/pairs-query-netvlad20.txt + gt_pose_fn=$ref_sfm_dir/$scene/queries_poses.txt + query_fn=$ref_sfm_dir/$scene/queries_with_intrinsics.txt + + if [ "$extract_feat_db" -gt "0" ]; then + python3 -m loc.extract_features --image_dir $image_dir --export_dir $outputs/ --conf $feat + fi + + if [ "$match_db" -gt "0" ]; then + python3 -m loc.match_features --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat + fi + + if [ "$triangulation" -gt "0" ]; then + python3 -m loc.triangulation \ + --sfm_dir $outputs/sfm_$feat-$matcher \ + --reference_sfm_model $ref_sfm \ + --image_dir $image_dir\ + --pairs $db_pair \ + --features $outputs/feats-$feat.h5 \ + --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 + fi + + if [ "$localize" -gt "0" ]; then + python3 -m loc.localizer \ + --dataset cambridge \ + --image_dir $image_dir \ + --save_root $outputs\ + --gt_pose_fn $gt_pose_fn \ + --retrieval $query_pair \ + --reference_sfm $outputs/sfm_$feat-$matcher \ + --queries $query_fn \ + --features $outputs/feats-$feat.h5 \ + --matcher_method adagm2 \ + --ransac_thresh $ransac_thresh \ + --covisibility_frame $covisibility_frame \ + --obs_thresh $obs_thresh \ + --opt_thresh $opt_thresh \ + --inlier_thresh $inlier_thresh \ + --use_hloc + fi + +done \ No newline at end of file diff --git a/third_party/pram/tools/common.py b/third_party/pram/tools/common.py new file mode 100644 index 0000000000000000000000000000000000000000..8990012575324ed593ebc07bec88d47602005d5f --- /dev/null +++ b/third_party/pram/tools/common.py @@ -0,0 +1,125 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> common +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 15:05 +==================================================''' +import os +import torch +import json +import yaml +import cv2 +import numpy as np +from typing import Tuple +from copy import deepcopy + + +def load_args(args, save_path): + with open(save_path, "r") as f: + args.__dict__ = json.load(f) + + +def save_args_yaml(args, save_path): + with open(save_path, 'w') as f: + yaml.dump(args, f) + + +def merge_tags(tags: list, connection='_'): + out = '' + for i, t in enumerate(tags): + if i == 0: + out = out + t + else: + out = out + connection + t + return out + + +def torch_set_gpu(gpus): + if type(gpus) is int: + gpus = [gpus] + + cuda = all(gpu >= 0 for gpu in gpus) + + if cuda: + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus]) + # print(os.environ['CUDA_VISIBLE_DEVICES']) + assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % ( + os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES']) + torch.backends.cudnn.benchmark = True # speed-up cudnn + torch.backends.cudnn.fastest = True # even more speed-up? + print('Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES']) + + else: + print('Launching on CPU') + + return cuda + + +def resize_img(img, nh=-1, nw=-1, rmax=-1, mode=cv2.INTER_NEAREST): + assert nh > 0 or nw > 0 or rmax > 0 + if nh > 0: + return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * nh), nh), interpolation=mode) + if nw > 0: + return cv2.resize(img, dsize=(nw, int(img.shape[0] / img.shape[1] * nw)), interpolation=mode) + if rmax > 0: + oh, ow = img.shape[0], img.shape[1] + if oh > ow: + return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * rmax), rmax), interpolation=mode) + else: + return cv2.resize(img, dsize=(rmax, int(img.shape[0] / img.shape[1] * rmax)), interpolation=mode) + + return cv2.resize(img, dsize=(nw, nh), interpolation=mode) + + +def resize_image_with_padding(image: np.array, nw: int, nh: int, padding_color: Tuple[int] = (0, 0, 0)) -> np.array: + """Maintains aspect ratio and resizes with padding. + Params: + image: Image to be resized. + new_shape: Expected (width, height) of new image. + padding_color: Tuple in BGR of padding color + Returns: + image: Resized image with padding + """ + original_shape = (image.shape[1], image.shape[0]) # (w, h) + ratio_w = nw / original_shape[0] + ratio_h = nh / original_shape[1] + + if ratio_w == ratio_h: + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_NEAREST) + + ratio = ratio_w if ratio_w < ratio_h else ratio_h + + new_size = tuple([int(x * ratio) for x in original_shape]) + image = cv2.resize(image, new_size, interpolation=cv2.INTER_NEAREST) + delta_w = nw - new_size[0] if nw > new_size[0] else new_size[0] - nw + delta_h = nh - new_size[1] if nh > new_size[1] else new_size[1] - nh + + left, right = delta_w // 2, delta_w - (delta_w // 2) + top, bottom = delta_h // 2, delta_h - (delta_h // 2) + + # print('top, bottom, left, right: ', top, bottom, left, right) + image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color) + return image + + +def puttext_with_background(image, text, org=(0, 0), fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=1, text_color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA, bg_color=None): + out_img = deepcopy(image) + if bg_color is not None: + (text_width, text_height), baseline = cv2.getTextSize(text, + fontFace, + fontScale=fontScale, + thickness=thickness) + box_coords = ( + (org[0], org[1] + baseline), + (org[0] + text_width + 2, org[1] - text_height - 2)) + + cv2.rectangle(out_img, box_coords[0], box_coords[1], bg_color, cv2.FILLED) + out_img = cv2.putText(img=out_img, text=text, + org=org, + fontFace=fontFace, + fontScale=fontScale, color=text_color, + thickness=thickness, lineType=lineType) + return out_img diff --git a/third_party/pram/tools/geometry.py b/third_party/pram/tools/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..d781a4172dd7f6ad8a4a26e252f614483ebd01e3 --- /dev/null +++ b/third_party/pram/tools/geometry.py @@ -0,0 +1,74 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> geometry +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 11:08 +==================================================''' +import numpy as np + + +def nms_fast(in_corners, H, W, dist_thresh): + """ + Run a faster approximate Non-Max-Suppression on numpy corners shaped: + 3xN [x_i,y_i,conf_i]^T + + Algo summary: Create a grid sized HxW. Assign each corner location a 1, rest + are zeros. Iterate through all the 1's and convert them either to -1 or 0. + Suppress points by setting nearby values to 0. + + Grid Value Legend: + -1 : Kept. + 0 : Empty or suppressed. + 1 : To be processed (converted to either kept or supressed). + + NOTE: The NMS first rounds points to integers, so NMS distance might not + be exactly dist_thresh. It also assumes points are within image boundaries. + + Inputs + in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T. + H - Image height. + W - Image width. + dist_thresh - Distance to suppress, measured as an infinty norm distance. + Returns + nmsed_corners - 3xN numpy matrix with surviving corners. + nmsed_inds - N length numpy vector with surviving corner indices. + """ + grid = np.zeros((H, W)).astype(int) # Track NMS data. + inds = np.zeros((H, W)).astype(int) # Store indices of points. + # Sort by confidence and round to nearest int. + inds1 = np.argsort(-in_corners[2, :]) + corners = in_corners[:, inds1] + rcorners = corners[:2, :].round().astype(int) # Rounded corners. + # Check for edge case of 0 or 1 corners. + if rcorners.shape[1] == 0: + return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int) + if rcorners.shape[1] == 1: + out = np.vstack((rcorners, in_corners[2])).reshape(3, 1) + return out, np.zeros((1)).astype(int) + # Initialize the grid. + for i, rc in enumerate(rcorners.T): + grid[rcorners[1, i], rcorners[0, i]] = 1 + inds[rcorners[1, i], rcorners[0, i]] = i + # Pad the border of the grid, so that we can NMS points near the border. + pad = dist_thresh + grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant') + # Iterate through points, highest to lowest conf, suppress neighborhood. + count = 0 + for i, rc in enumerate(rcorners.T): + # Account for top and left padding. + pt = (rc[0] + pad, rc[1] + pad) + if grid[pt[1], pt[0]] == 1: # If not yet suppressed. + grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0 + grid[pt[1], pt[0]] = -1 + count += 1 + # Get all surviving -1's and return sorted array of remaining corners. + keepy, keepx = np.where(grid == -1) + keepy, keepx = keepy - pad, keepx - pad + inds_keep = inds[keepy, keepx] + out = corners[:, inds_keep] + values = out[-1, :] + inds2 = np.argsort(-values) + out = out[:, inds2] + out_inds = inds1[inds_keep[inds2]] + return out_inds diff --git a/third_party/pram/tools/image_to_video.py b/third_party/pram/tools/image_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f281fd2cf0ef5eb2752117610c042b8764f5f1 --- /dev/null +++ b/third_party/pram/tools/image_to_video.py @@ -0,0 +1,66 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> image_to_video +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/09/2023 20:15 +==================================================''' +import cv2 +import os +import os.path as osp + +import numpy as np +from tqdm import tqdm +import argparse + +from tools.common import resize_img + +parser = argparse.ArgumentParser(description='Image2Video', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--image_dir', type=str, required=True) +parser.add_argument('--video_path', type=str, required=True) +parser.add_argument('--height', type=int, default=-1) +parser.add_argument('--fps', type=int, default=30) + + +def imgs2video(img_dir, video_path, fps=30, height=1024): + img_fns = os.listdir(img_dir) + # print(img_fns) + img_fns = [v for v in img_fns if v.split('.')[-1] in ['jpg', 'png']] + img_fns = sorted(img_fns) + # print(img_fns) + # 输出视频路径 + # fps = 1 + + img = cv2.imread(osp.join(img_dir, img_fns[0])) + if height == -1: + height = img.shape[1] + new_img = resize_img(img=img, nh=height) + img_size = (new_img.shape[1], height) + + # fourcc = cv2.cv.CV_FOURCC('M','J','P','G')#opencv2.4 + # fourcc = cv2.VideoWriter_fourcc('I','4','2','0') + + fourcc = cv2.VideoWriter_fourcc(*'MP4V') # 设置输出视频为mp4格式 + # fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') # 设置输出视频为mp4格式 + videoWriter = cv2.VideoWriter(video_path, fourcc, fps, img_size) + + for i in tqdm(range(3700, len(img_fns)), total=len(img_fns)): + # fn = img_fns[i].split('-') + im_name = os.path.join(img_dir, img_fns[i]) + print(im_name) + frame = cv2.imread(im_name, 1) + frame = np.flip(frame, 0) + + frame = cv2.resize(frame, dsize=img_size) + # print(frame.shape) + # exit(0) + cv2.imshow("frame", frame) + cv2.waitKey(1) + videoWriter.write(frame) + + videoWriter.release() + + +if __name__ == '__main__': + args = parser.parse_args() + imgs2video(img_dir=args.image_dir, video_path=args.video_path, fps=args.fps, height=args.height) diff --git a/third_party/pram/tools/metrics.py b/third_party/pram/tools/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..22e14374931fa9ba4151632b65b41c65d6ba55f7 --- /dev/null +++ b/third_party/pram/tools/metrics.py @@ -0,0 +1,216 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> metrics +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 16:32 +==================================================''' +import torch +import numpy as np +import torch.nn.functional as F + + +class SeqIOU: + def __init__(self, n_class, ignored_sids=[]): + self.n_class = n_class + self.ignored_sids = ignored_sids + self.class_iou = np.zeros(n_class) + self.precisions = [] + + def add(self, pred, target): + for i in range(self.n_class): + inter = np.sum((pred == target) * (target == i)) + union = np.sum(target == i) + np.sum(pred == i) - inter + if union > 0: + self.class_iou[i] = inter / union + + acc = (pred == target) + if len(self.ignored_sids) == 0: + acc_ratio = np.sum(acc) / pred.shape[0] + else: + pred_mask = (pred >= 0) + target_mask = (target >= 0) + for i in self.ignored_sids: + pred_mask = pred_mask & (pred == i) + target_mask = target_mask & (target == i) + + acc = acc & (1 - pred_mask) + tgt = (1 - target_mask) + if np.sum(tgt) == 0: + acc_ratio = 0 + else: + acc_ratio = np.sum(acc) / np.sum(tgt) + + self.precisions.append(acc_ratio) + + def get_mean_iou(self): + return np.mean(self.class_iou) + + def get_mean_precision(self): + return np.mean(self.precisions) + + def clear(self): + self.precisions = [] + self.class_iou = np.zeros(self.n_class) + + +def compute_iou(pred: np.ndarray, target: np.ndarray, n_class: int, ignored_ids=[]) -> float: + class_iou = np.zeros(n_class) + for i in range(n_class): + if i in ignored_ids: + continue + inter = np.sum((pred == target) * (target == i)) + union = np.sum(target == i) + np.sum(pred == i) - inter + if union > 0: + class_iou[i] = inter / union + + return np.mean(class_iou) + # return class_iou + + +def compute_precision(pred: np.ndarray, target: np.ndarray, ignored_ids: list = []) -> float: + acc = (pred == target) + if len(ignored_ids) == 0: + return np.sum(acc) / pred.shape[0] + else: + pred_mask = (pred >= 0) + target_mask = (target >= 0) + for i in ignored_ids: + pred_mask = pred_mask & (pred == i) + target_mask = target_mask & (target == i) + + acc = acc & (1 - pred_mask) + tgt = (1 - target_mask) + if np.sum(tgt) == 0: + return 0 + return np.sum(acc) / np.sum(tgt) + + +def compute_cls_corr(pred: torch.Tensor, target: torch.Tensor, k: int = 20) -> torch.Tensor: + bs = pred.shape[0] + _, target_ids = torch.topk(target, k=k, dim=1) + target_ids = target_ids.cpu().numpy() + _, top_ids = torch.topk(pred, k=k, dim=1) # [B, k, 1] + top_ids = top_ids.cpu().numpy() + acc = 0 + for i in range(bs): + # print('top_ids: ', i, top_ids[i], target_ids[i]) + overlap = [v for v in top_ids[i] if v in target_ids[i] and v >= 0] + acc = acc + len(overlap) / k + acc = acc / bs + return torch.from_numpy(np.array([acc])).to(pred.device) + + +def compute_corr_incorr(pred: torch.Tensor, target: torch.Tensor, ignored_ids: list = []) -> tuple: + ''' + :param pred: [B, N, C] + :param target: [B, N] + :param ignored_ids: [] + :return: + ''' + pred_ids = torch.max(pred, dim=-1)[1] + if len(ignored_ids) == 0: + acc = (pred_ids == target) + inacc = torch.logical_not(acc) + acc_ratio = torch.sum(acc) / torch.numel(target) + inacc_ratio = torch.sum(inacc) / torch.numel(target) + else: + acc = (pred_ids == target) + inacc = torch.logical_not(acc) + + mask = torch.zeros_like(acc) + for i in ignored_ids: + mask = torch.logical_and(mask, (target == i)) + + acc = torch.logical_and(acc, torch.logical_not(mask)) + acc_ratio = torch.sum(acc) / torch.numel(target) + inacc_ratio = torch.sum(inacc) / torch.numel(target) + + return acc_ratio, inacc_ratio + + +def compute_seg_loss_weight(pred: torch.Tensor, + target: torch.Tensor, + background_id: int = 0, + weight_background: float = 0.1) -> torch.Tensor: + ''' + :param pred: [B, C, N] + :param target: [B, N] + :param background_id: + :param weight_background: + :return: + ''' + pred = pred.transpose(-2, -1).contiguous() # [B, N, C] -> [B, C, N] + weight = torch.ones(size=(pred.shape[1],), device=pred.device).float() + pred = torch.log_softmax(pred, dim=1) + weight[background_id] = weight_background + seg_loss = F.cross_entropy(pred, target.long(), weight=weight) + return seg_loss + + +def compute_cls_loss_ce(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + cls_loss = torch.zeros(size=[], device=pred.device) + if len(pred.shape) == 2: + n_valid = torch.sum(target > 0) + cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred, target, reduction='sum') + cls_loss = cls_loss / n_valid + else: + for i in range(pred.shape[-1]): + cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred[..., i], target[..., i], reduction='sum') + n_valid = torch.sum(target > 0) + cls_loss = cls_loss / n_valid + + return cls_loss + + +def compute_cls_loss_kl(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + cls_loss = torch.zeros(size=[], device=pred.device) + if len(pred.shape) == 2: + cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred, dim=-1), + torch.softmax(target, dim=-1), + reduction='sum') + else: + for i in range(pred.shape[-1]): + cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred[..., i], dim=-1), + torch.softmax(target[..., i], dim=-1), + reduction='sum') + + cls_loss = cls_loss / pred.shape[-1] + + return cls_loss + + +def compute_sc_loss_l1(pred: torch.Tensor, target: torch.Tensor, mean_xyz=None, scale_xyz=None, mask=None): + ''' + :param pred: [B, N, C] + :param target: [B, N, C] + :param mean_xyz: + :param scale_xyz: + :param mask: + :return: + ''' + loss = (pred - target) + loss = torch.abs(loss).mean(dim=1) + if mask is not None: + return torch.mean(loss[mask]) + else: + return torch.mean(loss) + + +def compute_sc_loss_geo(pred: torch.Tensor, P, K, p2ds, mean_xyz, scale_xyz, max_value=20, mask=None): + b, c, n = pred.shape + p3ds = (pred * scale_xyz[..., None].repeat(1, 1, n) + mean_xyz[..., None].repeat(1, 1, n)) + p3ds_homo = torch.cat( + [pred, torch.ones(size=(p3ds.shape[0], 1, p3ds.shape[2]), dtype=p3ds.dtype, device=p3ds.device)], + dim=1) # [B, 4, N] + p3ds = torch.matmul(K, torch.matmul(P, p3ds_homo)[:, :3, :]) # [B, 3, N] + # print('p3ds: ', p3ds.shape, P.shape, K.shape, p2ds.shape) + + p2ds_ = p3ds[:, :2, :] / p3ds[:, 2:, :] + + loss = ((p2ds_ - p2ds.permute(0, 2, 1)) ** 2).sum(1) + loss = torch.clamp_max(loss, max=max_value) + if mask is not None: + return torch.mean(loss[mask]) + else: + return torch.mean(loss) diff --git a/third_party/pram/tools/video_to_image.py b/third_party/pram/tools/video_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..7283f3ba24d432410ea326a7d9aedbe011b60ed2 --- /dev/null +++ b/third_party/pram/tools/video_to_image.py @@ -0,0 +1,38 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> video_to_image +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 13/01/2024 15:29 +==================================================''' +import argparse +import os +import os.path as osp +import cv2 + +parser = argparse.ArgumentParser(description='Image2Video', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--image_path', type=str, required=True) +parser.add_argument('--video_path', type=str, required=True) +parser.add_argument('--height', type=int, default=-1) +parser.add_argument('--sample_ratio', type=int, default=-1) + + +def main(args): + video = cv2.VideoCapture(args.video_path) + nframe = 0 + while True: + ret, frame = video.read() + if ret: + if args.sample_ratio > 0: + if nframe % args.sample_ratio != 0: + nframe += 1 + continue + cv2.imwrite(osp.join(args.image_path, '{:06d}.png'.format(nframe)), frame) + nframe += 1 + else: + break + + +if __name__ == '__main__': + args = parser.parse_args() + main(args=args) diff --git a/third_party/pram/tools/visualize_landmarks.py b/third_party/pram/tools/visualize_landmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..7f8bcba35c14b929de1159c3a9491a98e1f0aebb --- /dev/null +++ b/third_party/pram/tools/visualize_landmarks.py @@ -0,0 +1,171 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> visualize_landmarks +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 22/03/2024 10:39 +==================================================''' +import os +import os.path as osp +import numpy as np +from tqdm import tqdm +from colmap_utils.read_write_model import read_model, write_model, Point3D, Image, read_compressed_model +from recognition.vis_seg import generate_color_dic + + +def reconstruct_map(valid_image_ids, valid_p3d_ids, cameras, images, point3Ds, p3d_seg: dict): + new_point3Ds = {} + new_images = {} + + valid_p3d_ids_ = [] + for pid in tqdm(valid_p3d_ids, total=len(valid_p3d_ids)): + + if pid == -1: + continue + if pid not in point3Ds.keys(): + continue + + if pid not in p3d_seg.keys(): + continue + + sid = map_seg[pid] + if sid == -1: + continue + valid_p3d_ids_.append(pid) + + valid_p3d_ids = valid_p3d_ids_ + print('valid_p3ds: ', len(valid_p3d_ids)) + + # for im_id in tqdm(images.keys(), total=len(images.keys())): + for im_id in tqdm(valid_image_ids, total=len(valid_image_ids)): + im = images[im_id] + # print('im: ', im) + # exit(0) + pids = im.point3D_ids + valid_pids = [] + # for v in pids: + # if v not in valid_p3d_ids: + # valid_pids.append(-1) + # else: + # valid_pids.append(v) + + new_im = Image(id=im_id, qvec=im.qvec, tvec=im.tvec, camera_id=im.camera_id, name=im.name, xys=im.xys, + point3D_ids=pids) + new_images[im_id] = new_im + + for pid in tqdm(valid_p3d_ids, total=len(valid_p3d_ids)): + sid = map_seg[pid] + + xyz = points3D[pid].xyz + if show_2D: + xyz[1] = 0 + rgb = points3D[pid].rgb + else: + bgr = seg_color[sid + sid_start] + rgb = np.array([bgr[2], bgr[1], bgr[0]]) + + error = points3D[pid].error + + p3d = Point3D(id=pid, xyz=xyz, rgb=rgb, error=error, + image_ids=points3D[pid].image_ids, + point2D_idxs=points3D[pid].point2D_idxs) + new_point3Ds[pid] = p3d + + return cameras, new_images, new_point3Ds + + +if __name__ == '__main__': + save_root = '/scratches/flyer_3/fx221/exp/localizer/vis_clustering/' + seg_color = generate_color_dic(n_seg=2000) + data_root = '/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gm' + show_2D = False + + compress_map = False + # compress_map = True + + # scene = 'Aachen/Aachenv11' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n512_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n512_xz_birch.npy' + + # + # scene = 'CambridgeLandmarks/GreatCourt' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + + # scene = 'CambridgeLandmarks/KingsCollege' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 33 + # vrf_file_name = 'point3D_vrf_n32_xy_birch.npy' + + # scene = 'CambridgeLandmarks/StMarysChurch' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 32 * 4 + 1 + # vrf_file_name = 'point3D_vrf_n32_xz_birch.npy' + + # scene = '7Scenes/office' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 33 + + # scene = '7Scenes/chess' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n16_xz_birch.npy' + + # scene = '7Scenes/redkitchen' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 16 * 5 + 1 + # vrf_file_name = 'point3D_vrf_n16_xz_birch.npy' + + # scene = '12Scenes/apt1/kitchen' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n16_xy_birch.npy' + + # data_root = '/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gml2' + # scene = 'JesusCollege/jesuscollege' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n256_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n256_xy_birch.npy' + + scene = 'DarwinRGB/darwin' + seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n128_xy_birch.npy'), allow_pickle=True)[()] + sid_start = 1 + vrf_file_name = 'point3D_vrf_n128_xy_birch.npy' + + cameras, images, points3D = read_model(osp.join(data_root, scene, 'model'), ext='.bin') + print('Load {:d} 3D points from map'.format(len(points3D.keys()))) + + if compress_map: + vrf_data = np.load(osp.join(data_root, scene, vrf_file_name), allow_pickle=True)[()] + valid_image_ids = [vrf_data[v][0]['image_id'] for v in vrf_data.keys()] + else: + valid_image_ids = list(images.keys()) + + if compress_map: + _, _, compress_points3D = read_compressed_model(osp.join(data_root, scene, 'compress_model_birch'), + ext='.bin') + print('Load {:d} 3D points from compressed map'.format(len(compress_points3D.keys()))) + valid_p3d_ids = list(compress_points3D.keys()) + else: + valid_p3d_ids = list(points3D.keys()) + + save_path = osp.join(save_root, scene) + + if compress_map: + save_path = save_path + '_comp' + if show_2D: + save_path = save_path + '_2D' + + os.makedirs(save_path, exist_ok=True) + p3d_id = seg_data['id'] + seg_id = seg_data['label'] + map_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + + new_cameras, new_images, new_point3Ds = reconstruct_map(valid_image_ids=valid_image_ids, + valid_p3d_ids=valid_p3d_ids, cameras=cameras, images=images, + point3Ds=points3D, p3d_seg=map_seg) + + # write_model(cameras=cameras, images=images, points3D=new_point3Ds, + # path=save_path, ext='.bin') + write_model(cameras=new_cameras, images=new_images, points3D=new_point3Ds, path=save_path, ext='.bin') diff --git a/third_party/pram/train.py b/third_party/pram/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2657f455d29c7c7c5417d8efa7aacaef4207ed --- /dev/null +++ b/third_party/pram/train.py @@ -0,0 +1,170 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> train +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 03/04/2024 16:33 +==================================================''' +import argparse +import os +import os.path as osp +import torch +import torchvision.transforms.transforms as tvt +import yaml +import torch.utils.data as Data +import torch.multiprocessing as mp +import torch.distributed as dist + +from nets.sfd2 import load_sfd2 +from nets.segnet import SegNet +from nets.segnetvit import SegNetViT +from nets.load_segnet import load_segnet +from dataset.utils import collect_batch +from dataset.get_dataset import compose_datasets +from tools.common import torch_set_gpu +from trainer import Trainer + + +def get_model(config): + desc_dim = 256 if config['feature'] == 'spp' else 128 + if config['use_mid_feature']: + desc_dim = 256 + model_config = { + 'network': { + 'descriptor_dim': desc_dim, + 'n_layers': config['layers'], + 'ac_fn': config['ac_fn'], + 'norm_fn': config['norm_fn'], + 'n_class': config['n_class'], + 'output_dim': config['output_dim'], + # 'with_cls': config['with_cls'], + # 'with_sc': config['with_sc'], + 'with_score': config['with_score'], + } + } + + if config['network'] == 'segnet': + model = SegNet(model_config.get('network', {})) + config['with_cls'] = False + elif config['network'] == 'segnetvit': + model = SegNetViT(model_config.get('network', {})) + config['with_cls'] = False + else: + raise 'ERROR! {:s} model does not exist'.format(config['network']) + + return model + + +parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--config', type=str, required=True, help='config of specifications') +# parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks') +parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth') + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms): + print('In train_DDP..., rank: ', rank) + torch.cuda.set_device(rank) + + device = torch.device(f'cuda:{rank}') + if feat_model is not None: + feat_model.to(device) + model.to(device) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + setup(rank=rank, world_size=world_size) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, + shuffle=True, + rank=rank, + num_replicas=world_size, + drop_last=True, # important? + ) + train_loader = torch.utils.data.DataLoader(train_set, + batch_size=config['batch_size'] // world_size, + num_workers=config['workers'] // world_size, + # num_workers=1, + pin_memory=True, + # persistent_workers=True, + shuffle=False, # must be False + drop_last=True, + collate_fn=collect_batch, + prefetch_factor=4, + sampler=train_sampler) + config['local_rank'] = rank + + if rank == 0: + test_set = test_set + else: + test_set = None + + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_set, + config=config, img_transforms=img_transforms) + trainer.train() + + +if __name__ == '__main__': + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = yaml.load(f, Loader=yaml.Loader) + torch_set_gpu(gpus=config['gpu']) + if config['local_rank'] == 0: + print(config) + + img_transforms = [] + img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + img_transforms = tvt.Compose(img_transforms) + + feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval() + print('Load SFD2 weight from {:s}'.format(args.feat_weight_path)) + + dataset = config['dataset'] + train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None) + if config['do_eval']: + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None) + else: + test_set = None + config['n_class'] = train_set.n_class + # model = get_model(config=config) + model = load_segnet(network=config['network'], + n_class=config['n_class'], + desc_dim=256 if config['use_mid_feature'] else 128, + n_layers=config['layers'], + output_dim=config['output_dim']) + if config['local_rank'] == 0: + if config['resume_path'] is not None: # only for training + model.load_state_dict( + torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')['model'], + strict=True) + print('Load resume weight from {:s}'.format(osp.join(config['save_path'], config['resume_path']))) + + if not config['with_dist'] or len(config['gpu']) == 1: + config['with_dist'] = False + model = model.cuda() + train_loader = Data.DataLoader(dataset=train_set, + shuffle=True, + batch_size=config['batch_size'], + drop_last=True, + collate_fn=collect_batch, + num_workers=config['workers']) + if test_set is not None: + test_loader = Data.DataLoader(dataset=test_set, + shuffle=False, + batch_size=1, + drop_last=False, + collate_fn=collect_batch, + num_workers=4) + else: + test_loader = None + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader, + config=config, img_transforms=img_transforms) + trainer.train() + else: + mp.spawn(train_DDP, nprocs=len(config['gpu']), + args=(len(config['gpu']), model, config, train_set, test_set, feat_model, img_transforms), + join=True) diff --git a/third_party/pram/trainer.py b/third_party/pram/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..002e349323ec587843ea4119a0bc32b343bd34dd --- /dev/null +++ b/third_party/pram/trainer.py @@ -0,0 +1,404 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> trainer +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 15:04 +==================================================''' +import datetime +import os +import os.path as osp +import numpy as np +from pathlib import Path +from tensorboardX import SummaryWriter +from tqdm import tqdm +import torch.optim as optim +import torch.nn.functional as F + +import shutil +import torch +from torch.autograd import Variable +from tools.common import save_args_yaml, merge_tags +from tools.metrics import compute_iou, compute_precision, SeqIOU, compute_corr_incorr, compute_seg_loss_weight +from tools.metrics import compute_cls_loss_ce, compute_cls_corr + + +class Trainer: + def __init__(self, model, train_loader, feat_model=None, eval_loader=None, config=None, img_transforms=None): + self.model = model + self.train_loader = train_loader + self.eval_loader = eval_loader + self.config = config + self.with_aug = self.config['with_aug'] + self.with_cls = False # self.config['with_cls'] + self.with_sc = False # self.config['with_sc'] + self.img_transforms = img_transforms + self.feat_model = feat_model.cuda().eval() if feat_model is not None else None + + self.init_lr = self.config['lr'] + self.min_lr = self.config['min_lr'] + + params = [p for p in self.model.parameters() if p.requires_grad] + self.optimizer = optim.AdamW(params=params, lr=self.init_lr) + self.num_epochs = self.config['epochs'] + + if config['resume_path'] is not None: + log_dir = config['resume_path'].split('/')[-2] + resume_log = torch.load(osp.join(osp.join(config['save_path'], config['resume_path'])), map_location='cpu') + self.epoch = resume_log['epoch'] + 1 + if 'iteration' in resume_log.keys(): + self.iteration = resume_log['iteration'] + else: + self.iteration = len(self.train_loader) * self.epoch + self.min_loss = resume_log['min_loss'] + else: + self.iteration = 0 + self.epoch = 0 + self.min_loss = 1e10 + + now = datetime.datetime.now() + all_tags = [now.strftime("%Y%m%d_%H%M%S")] + dataset_name = merge_tags(self.config['dataset'], '') + all_tags = all_tags + [self.config['network'], 'L' + str(self.config['layers']), + dataset_name, + str(self.config['feature']), 'B' + str(self.config['batch_size']), + 'K' + str(self.config['max_keypoints']), 'od' + str(self.config['output_dim']), + 'nc' + str(self.config['n_class'])] + if self.config['use_mid_feature']: + all_tags.append('md') + # if self.with_cls: + # all_tags.append(self.config['cls_loss']) + # if self.with_sc: + # all_tags.append(self.config['sc_loss']) + if self.with_aug: + all_tags.append('A') + + all_tags.append(self.config['cluster_method']) + log_dir = merge_tags(tags=all_tags, connection='_') + + if config['local_rank'] == 0: + self.save_dir = osp.join(self.config['save_path'], log_dir) + os.makedirs(self.save_dir, exist_ok=True) + + print("save_dir: ", self.save_dir) + + self.log_file = open(osp.join(self.save_dir, "log.txt"), "a+") + save_args_yaml(args=config, save_path=Path(self.save_dir, "args.yaml")) + self.writer = SummaryWriter(self.save_dir) + + self.tag = log_dir + + self.do_eval = self.config['do_eval'] + if self.do_eval: + self.eval_fun = None + self.seq_metric = SeqIOU(n_class=self.config['n_class'], ignored_sids=[0]) + + def preprocess_input(self, pred): + for k in pred.keys(): + if k.find('name') >= 0: + continue + if k != 'image' and k != 'depth': + if type(pred[k]) == torch.Tensor: + pred[k] = Variable(pred[k].float().cuda()) + else: + pred[k] = Variable(torch.stack(pred[k]).float().cuda()) + + if self.with_aug: + new_scores = [] + new_descs = [] + global_descs = [] + with torch.no_grad(): + for i, im in enumerate(pred['image']): + img = torch.from_numpy(im[0]).cuda().float().permute(2, 0, 1) + # img = self.img_transforms(img)[None] + if self.img_transforms is not None: + img = self.img_transforms(img)[None] + else: + img = img[None] + out = self.feat_model.extract_local_global(data={'image': img}) + global_descs.append(out['global_descriptors']) + + seg_scores, seg_descs = self.feat_model.sample(score_map=out['score_map'], + semi_descs=out['mid_features'] if self.config[ + 'use_mid_feature'] else out['desc_map'], + kpts=pred['keypoints'][i], + norm_desc=self.config['norm_desc']) # [D, N] + new_scores.append(seg_scores[None]) + new_descs.append(seg_descs[None]) + pred['global_descriptors'] = global_descs + pred['scores'] = torch.cat(new_scores, dim=0) + pred['seg_descriptors'] = torch.cat(new_descs, dim=0).permute(0, 2, 1) # -> [B, N, D] + + def process_epoch(self): + self.model.train() + + epoch_cls_losses = [] + epoch_seg_losses = [] + epoch_losses = [] + epoch_acc_corr = [] + epoch_acc_incorr = [] + epoch_cls_acc = [] + + epoch_sc_losses = [] + + for bidx, pred in tqdm(enumerate(self.train_loader), total=len(self.train_loader)): + self.preprocess_input(pred) + if 0 <= self.config['its_per_epoch'] <= bidx: + break + + data = self.model(pred) + for k, v in pred.items(): + pred[k] = v + pred = {**pred, **data} + + seg_loss = compute_seg_loss_weight(pred=pred['prediction'], + target=pred['gt_seg'], + background_id=0, + weight_background=0.1) + acc_corr, acc_incorr = compute_corr_incorr(pred=pred['prediction'], + target=pred['gt_seg'], + ignored_ids=[0]) + + if self.with_cls: + pred_cls_dist = pred['classification'] + gt_cls_dist = pred['gt_cls_dist'] + if len(pred_cls_dist.shape) > 2: + gt_cls_dist_full = gt_cls_dist.unsqueeze(-1).repeat(1, 1, pred_cls_dist.shape[-1]) + else: + gt_cls_dist_full = gt_cls_dist.unsqueeze(-1) + cls_loss = compute_cls_loss_ce(pred=pred_cls_dist, target=gt_cls_dist_full) + loss = seg_loss + cls_loss + + # gt_n_seg = pred['gt_n_seg'] + cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist) + else: + loss = seg_loss + cls_loss = torch.zeros_like(seg_loss) + cls_acc = torch.zeros_like(seg_loss) + + if self.with_sc: + pass + else: + sc_loss = torch.zeros_like(seg_loss) + + epoch_losses.append(loss.item()) + epoch_seg_losses.append(seg_loss.item()) + epoch_cls_losses.append(cls_loss.item()) + epoch_sc_losses.append(sc_loss.item()) + + epoch_acc_corr.append(acc_corr.item()) + epoch_acc_incorr.append(acc_incorr.item()) + epoch_cls_acc.append(cls_acc.item()) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + self.iteration += 1 + + lr = min(self.config['lr'] * self.config['decay_rate'] ** (self.iteration - self.config['decay_iter']), + self.config['lr']) + if lr < self.min_lr: + lr = self.min_lr + + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + + if self.config['local_rank'] == 0 and bidx % self.config['log_intervals'] == 0: + print_text = 'Epoch [{:d}/{:d}], Step [{:d}/{:d}/{:d}], Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]'.format( + self.epoch, + self.num_epochs, bidx, + len(self.train_loader), + self.iteration, + seg_loss.item(), + cls_loss.item(), + sc_loss.item(), + loss.item(), + + np.mean(epoch_acc_corr), + np.mean(epoch_acc_incorr), + np.mean(epoch_cls_acc) + ) + + print(print_text) + self.log_file.write(print_text + '\n') + + info = { + 'lr': lr, + 'loss': loss.item(), + 'cls_loss': cls_loss.item(), + 'sc_loss': sc_loss.item(), + 'acc_corr': acc_corr.item(), + 'acc_incorr': acc_incorr.item(), + 'acc_cls': cls_acc.item(), + } + + for k, v in info.items(): + self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.iteration) + + if self.config['local_rank'] == 0: + print_text = 'Epoch [{:d}/{:d}], AVG Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]\n'.format( + self.epoch, + self.num_epochs, + np.mean(epoch_seg_losses), + np.mean(epoch_cls_losses), + np.mean(epoch_sc_losses), + np.mean(epoch_losses), + np.mean(epoch_acc_corr), + np.mean(epoch_acc_incorr), + np.mean(epoch_cls_acc), + ) + print(print_text) + self.log_file.write(print_text + '\n') + self.log_file.flush() + return np.mean(epoch_losses) + + def eval_seg(self, loader): + print('Start to do evaluation...') + + self.model.eval() + self.seq_metric.clear() + mean_iou_day = [] + mean_iou_night = [] + mean_prec_day = [] + mean_prec_night = [] + mean_cls_day = [] + mean_cls_night = [] + + for bid, pred in tqdm(enumerate(loader), total=len(loader)): + for k in pred.keys(): + if k.find('name') >= 0: + continue + if k != 'image' and k != 'depth': + if type(pred[k]) == torch.Tensor: + pred[k] = Variable(pred[k].float().cuda()) + elif type(pred[k]) == np.ndarray: + pred[k] = Variable(torch.from_numpy(pred[k]).float()[None].cuda()) + else: + pred[k] = Variable(torch.stack(pred[k]).float().cuda()) + + if self.with_aug: + with torch.no_grad(): + if isinstance(pred['image'][0], list): + img = pred['image'][0][0] + else: + img = pred['image'][0] + + img = torch.from_numpy(img).cuda().float().permute(2, 0, 1) + if self.img_transforms is not None: + img = self.img_transforms(img)[None] + else: + img = img[None] + + encoder_out = self.feat_model.extract_local_global(data={'image': img}) + global_descriptors = [encoder_out['global_descriptors']] + pred['global_descriptors'] = global_descriptors + if self.config['use_mid_feature']: + scores, descs = self.feat_model.sample(score_map=encoder_out['score_map'], + semi_descs=encoder_out['mid_features'], + kpts=pred['keypoints'][0], + norm_desc=self.config['norm_desc']) + # print('eval: ', scores.shape, descs.shape) + pred['scores'] = scores[None] + pred['seg_descriptors'] = descs[None].permute(0, 2, 1) # -> [B, N, D] + else: + pred['seg_descriptors'] = pred['descriptors'] + + image_name = pred['file_name'][0] + with torch.no_grad(): + out = self.model(pred) + pred = {**pred, **out} + + pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C] + pred_seg = pred_seg[0].cpu().numpy() + gt_seg = pred['gt_seg'][0].cpu().numpy() + iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=self.config['n_class'], ignored_ids=[0]) + prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0]) + + if self.with_cls: + pred_cls_dist = pred['classification'] + gt_cls_dist = pred['gt_cls_dist'] + cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist).item() + else: + cls_acc = 0. + + if image_name.find('night') >= 0: + mean_iou_night.append(iou) + mean_prec_night.append(prec) + mean_cls_night.append(cls_acc) + else: + mean_iou_day.append(iou) + mean_prec_day.append(prec) + mean_cls_day.append(cls_acc) + + print_txt = 'Eval Epoch {:d}, iou day/night {:.3f}/{:.3f}, prec day/night {:.3f}/{:.3f}, cls day/night {:.3f}/{:.3f}'.format( + self.epoch, np.mean(mean_iou_day), np.mean(mean_iou_night), + np.mean(mean_prec_day), np.mean(mean_prec_night), + np.mean(mean_cls_day), np.mean(mean_cls_night)) + self.log_file.write(print_txt + '\n') + print(print_txt) + + info = { + 'mean_iou_day': np.mean(mean_iou_day), + 'mean_iou_night': np.mean(mean_iou_night), + 'mean_prec_day': np.mean(mean_prec_day), + 'mean_prec_night': np.mean(mean_prec_night), + } + + for k, v in info.items(): + self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.epoch) + + return np.mean(mean_prec_night) + + def train(self): + if self.config['local_rank'] == 0: + print('Start to train the model from epoch: {:d}'.format(self.epoch)) + hist_values = [] + min_value = self.min_loss + + epoch = self.epoch + while epoch < self.num_epochs: + if self.config['with_dist']: + self.train_loader.sampler.set_epoch(epoch=epoch) + self.epoch = epoch + + train_loss = self.process_epoch() + + # return with loss INF/NAN + if train_loss is None: + continue + + if self.config['local_rank'] == 0: + if self.do_eval and self.epoch % self.config['eval_n_epoch'] == 0: # and self.epoch >= 50: + eval_ratio = self.eval_seg(loader=self.eval_loader) + + hist_values.append(eval_ratio) # higher better + else: + hist_values.append(-train_loss) # lower better + + checkpoint_path = os.path.join(self.save_dir, + '%s.%02d.pth' % (self.config['network'], self.epoch)) + checkpoint = { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'model': self.model.state_dict(), + 'min_loss': min_value, + } + # for multi-gpu training + if len(self.config['gpu']) > 1: + checkpoint['model'] = self.model.module.state_dict() + + torch.save(checkpoint, checkpoint_path) + + if hist_values[-1] < min_value: + min_value = hist_values[-1] + best_checkpoint_path = os.path.join( + self.save_dir, + '%s.best.pth' % (self.tag) + ) + shutil.copy(checkpoint_path, best_checkpoint_path) + # important!!! + epoch += 1 + + if self.config['local_rank'] == 0: + self.log_file.close() diff --git a/ui/config.yaml b/ui/config.yaml index d94cc3f67789b454c248b10468b9b2354ba358a9..28d0a5106718e25e6e3fd31cfe95d270bb0d3b17 100644 --- a/ui/config.yaml +++ b/ui/config.yaml @@ -389,9 +389,8 @@ matcher_zoo: sfd2+imp: matcher: imp feature: sfd2 - enable: false + enable: true dense: false - skip_ci: true info: name: SFD2+IMP #dispaly name source: "CVPR 2023" @@ -403,9 +402,8 @@ matcher_zoo: sfd2+mnn: matcher: NN-mutual feature: sfd2 - enable: false + enable: true dense: false - skip_ci: true info: name: SFD2+MNN #dispaly name source: "CVPR 2023"