diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c72653fb43fdb1e87c3c25e94257e81e24d87907 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +build/ +# lib +bin/ +cmake_modules/ +cmake-build-debug/ +.idea/ +.vscode/ +*.pyc +flagged +.ipynb_checkpoints +__pycache__ +Untitled* +experiments +third_party/REKD +hloc/matchers/dedode.py +gradio_cached_examples +*.mp4 +hloc/matchers/quadtree.py +third_party/QuadTreeAttention +desktop.ini +*.egg-info +output.pkl +log.txt +experiments* +gen_example.py +datasets/lines/terrace0.JPG +datasets/lines/terrace1.JPG +datasets/South-Building* diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..7455862d5fe993d55e63f79fb63f1d274f25774e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +# Use an official conda-based Python image as a parent image +FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime +LABEL maintainer vincentqyw +ARG PYTHON_VERSION=3.10.10 + +# Set the working directory to /code +WORKDIR /code + +# Install Git and Git LFS +RUN apt-get update && apt-get install -y git-lfs +RUN git lfs install + +# Clone the Git repository +RUN git clone https://huggingface.co/spaces/Realcat/image-matching-webui /code + +RUN conda create -n imw python=${PYTHON_VERSION} +RUN echo "source activate imw" > ~/.bashrc +ENV PATH /opt/conda/envs/imw/bin:$PATH + +# Make RUN commands use the new environment +SHELL ["conda", "run", "-n", "imw", "/bin/bash", "-c"] +RUN pip install --upgrade pip +RUN pip install -r requirements.txt +RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y + +# Export port +EXPOSE 7860 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 7e4e6ce7a10f7dbba861de1e668a74c9ed4afd60..1c73b52835d9a630e3174dfe7bed069470bdf798 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,155 @@ --- title: Imatchui -emoji: 🐨 +emoji: 🤗 colorFrom: red -colorTo: indigo +colorTo: yellow sdk: gradio sdk_version: 5.4.0 app_file: app.py -pinned: false +pinned: true +license: apache-2.0 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +[![Contributors][contributors-shield]][contributors-url] +[![Forks][forks-shield]][forks-url] +[![Stargazers][stars-shield]][stars-url] +[![Issues][issues-shield]][issues-url] + +

+


Image Matching WebUI
Identify matching points between two images

+

+ +## Description + +This simple tool efficiently matches image pairs using multiple famous image matching algorithms. The tool features a Graphical User Interface (GUI) designed using [gradio](https://gradio.app/). You can effortlessly select two images and a matching algorithm and obtain a precise matching result. +**Note**: the images source can be either local images or webcam images. + +Try it on + + Open In Studio + + +Here is a demo of the tool: + +![demo](assets/demo.gif) + +The tool currently supports various popular image matching algorithms, namely: +- [x] [EfficientLoFTR](https://github.com/zju3dv/EfficientLoFTR), CVPR 2024 +- [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024 +- [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024 +- [x] [OmniGlue](https://github.com/Vincentqyw/omniglue-onnx), CVPR 2024 +- [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024 +- [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024 +- [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024 +- [ ] [Mickey](https://github.com/nianticlabs/mickey), CVPR 2024 +- [x] [GIM](https://github.com/xuelunshen/gim), ICLR 2024 +- [ ] [DUSt3R](https://github.com/naver/dust3r), arXiv 2023 +- [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023 +- [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023 +- [x] [SFD2](https://github.com/feixue94/sfd2), CVPR 2023 +- [x] [IMP](https://github.com/feixue94/imp-release), CVPR 2023 +- [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023 +- [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023 +- [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023 +- [x] [GlueStick](https://github.com/cvg/GlueStick), ICCV 2023 +- [ ] [ConvMatch](https://github.com/SuhZhang/ConvMatch), AAAI 2023 +- [x] [LoFTR](https://github.com/zju3dv/LoFTR), CVPR 2021 +- [x] [SOLD2](https://github.com/cvg/SOLD2), CVPR 2021 +- [ ] [LineTR](https://github.com/yosungho/LineTR), RA-L 2021 +- [x] [DKM](https://github.com/Parskatt/DKM), CVPR 2023 +- [ ] [NCMNet](https://github.com/xinliu29/NCMNet), CVPR 2023 +- [x] [TopicFM](https://github.com/Vincentqyw/TopicFM), AAAI 2023 +- [x] [AspanFormer](https://github.com/Vincentqyw/ml-aspanformer), ECCV 2022 +- [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022 +- [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022 +- [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022 +- [x] [CoTR](https://github.com/ubc-vision/COTR), ICCV 2021 +- [x] [ALIKE](https://github.com/Shiaoming/ALIKE), TMM 2022 +- [x] [RoRD](https://github.com/UditSinghParihar/RoRD), IROS 2021 +- [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021 +- [x] [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork), CVPRW 2018 +- [x] [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork), CVPR 2020 +- [x] [D2Net](https://github.com/Vincentqyw/d2-net), CVPR 2019 +- [x] [R2D2](https://github.com/naver/r2d2), NeurIPS 2019 +- [x] [DISK](https://github.com/cvlab-epfl/disk), NeurIPS 2020 +- [ ] [Key.Net](https://github.com/axelBarroso/Key.Net), ICCV 2019 +- [ ] [OANet](https://github.com/zjhthu/OANet), ICCV 2019 +- [x] [SOSNet](https://github.com/scape-research/SOSNet), CVPR 2019 +- [x] [HardNet](https://github.com/DagnyT/hardnet), NeurIPS 2017 +- [x] [SIFT](https://docs.opencv.org/4.x/da/df5/tutorial_py_sift_intro.html), IJCV 2004 + +## How to use + +### HuggingFace / Lightning AI + +Just try it on + + Open In Studio + + +or deploy it locally following the instructions below. + +### Requirements +``` bash +git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git +cd image-matching-webui +conda env create -f environment.yaml +conda activate imw +``` + +or using [docker](https://hub.docker.com/r/vincentqin/image-matching-webui): + +``` bash +docker pull vincentqin/image-matching-webui:latest +docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860 +``` + +### Run demo +``` bash +python3 ./app.py +``` +then open http://localhost:7860 in your browser. + +![](assets/gui.jpg) + +### Add your own feature / matcher + +I provide an example to add local feature in [hloc/extractors/example.py](hloc/extractors/example.py). Then add feature settings in `confs` in file [hloc/extract_features.py](hloc/extract_features.py). Last step is adding some settings to `model_zoo` in file [ui/config.yaml](ui/config.yaml). + +## Contributions welcome! + +External contributions are very much welcome. Please follow the [PEP8 style guidelines](https://www.python.org/dev/peps/pep-0008/) using a linter like flake8 (reformat using command `python -m black .`). This is a non-exhaustive list of features that might be valuable additions: + +- [x] add webcam support +- [x] add [line feature matching](https://github.com/Vincentqyw/LineSegmentsDetection) algorithms +- [x] example to add a new feature extractor / matcher +- [x] ransac to filter outliers +- [ ] add [rotation images](https://github.com/pidahbus/deep-image-orientation-angle-detection) options before matching +- [ ] support export matches to colmap ([#issue 6](https://github.com/Vincentqyw/image-matching-webui/issues/6)) +- [ ] add config file to set default parameters +- [ ] dynamically load models and reduce GPU overload + +Adding local features / matchers as submodules is very easy. For example, to add the [GlueStick](https://github.com/cvg/GlueStick): + +``` bash +git submodule add https://github.com/cvg/GlueStick.git third_party/GlueStick +``` + +If remote submodule repositories are updated, don't forget to pull submodules with `git submodule update --remote`, if you only want to update one submodule, use `git submodule update --remote third_party/GlueStick`. + +## Resources +- [Image Matching: Local Features & Beyond](https://image-matching-workshop.github.io) +- [Long-term Visual Localization](https://www.visuallocalization.net) + +## Acknowledgement + +This code is built based on [Hierarchical-Localization](https://github.com/cvg/Hierarchical-Localization). We express our gratitude to the authors for their valuable source code. + +[contributors-shield]: https://img.shields.io/github/contributors/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[contributors-url]: https://github.com/Vincentqyw/image-matching-webui/graphs/contributors +[forks-shield]: https://img.shields.io/github/forks/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[forks-url]: https://github.com/Vincentqyw/image-matching-webui/network/members +[stars-shield]: https://img.shields.io/github/stars/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[stars-url]: https://github.com/Vincentqyw/image-matching-webui/stargazers +[issues-shield]: https://img.shields.io/github/issues/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[issues-url]: https://github.com/Vincentqyw/image-matching-webui/issues \ No newline at end of file diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/client.py b/api/client.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd751c6bc359e8edf162aa67f30f8240a90de3a --- /dev/null +++ b/api/client.py @@ -0,0 +1,225 @@ +import argparse +import base64 +import os +import pickle +import time +from typing import Dict, List + +import cv2 +import numpy as np +import requests + +ENDPOINT = "http://127.0.0.1:8001" +if "REMOTE_URL_RAILWAY" in os.environ: + ENDPOINT = os.environ["REMOTE_URL_RAILWAY"] + +print(f"API ENDPOINT: {ENDPOINT}") + +API_VERSION = f"{ENDPOINT}/version" +API_URL_MATCH = f"{ENDPOINT}/v1/match" +API_URL_EXTRACT = f"{ENDPOINT}/v1/extract" + + +def read_image(path: str) -> str: + """ + Read an image from a file, encode it as a JPEG and then as a base64 string. + + Args: + path (str): The path to the image to read. + + Returns: + str: The base64 encoded image. + """ + # Read the image from the file + img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + + # Encode the image as a png, NO COMPRESSION!!! + retval, buffer = cv2.imencode(".png", img) + + # Encode the JPEG as a base64 string + b64img = base64.b64encode(buffer).decode("utf-8") + + return b64img + + +def do_api_requests(url=API_URL_EXTRACT, **kwargs): + """ + Helper function to send an API request to the image matching service. + + Args: + url (str): The URL of the API endpoint to use. Defaults to the + feature extraction endpoint. + **kwargs: Additional keyword arguments to pass to the API. + + Returns: + List[Dict[str, np.ndarray]]: A list of dictionaries containing the + extracted features. The keys are "keypoints", "descriptors", and + "scores", and the values are ndarrays of shape (N, 2), (N, ?), + and (N,), respectively. + """ + # Set up the request body + reqbody = { + # List of image data base64 encoded + "data": [], + # List of maximum number of keypoints to extract from each image + "max_keypoints": [100, 100], + # List of timestamps for each image (not used?) + "timestamps": ["0", "1"], + # Whether to convert the images to grayscale + "grayscale": 0, + # List of image height and width + "image_hw": [[640, 480], [320, 240]], + # Type of feature to extract + "feature_type": 0, + # List of rotation angles for each image + "rotates": [0.0, 0.0], + # List of scale factors for each image + "scales": [1.0, 1.0], + # List of reference points for each image (not used) + "reference_points": [[640, 480], [320, 240]], + # Whether to binarize the descriptors + "binarize": True, + } + # Update the request body with the additional keyword arguments + reqbody.update(kwargs) + try: + # Send the request + r = requests.post(url, json=reqbody) + if r.status_code == 200: + # Return the response + return r.json() + else: + # Print an error message if the response code is not 200 + print(f"Error: Response code {r.status_code} - {r.text}") + except Exception as e: + # Print an error message if an exception occurs + print(f"An error occurred: {e}") + + +def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]: + """ + Send a request to the API to generate a match between two images. + + Args: + path0 (str): The path to the first image. + path1 (str): The path to the second image. + + Returns: + Dict[str, np.ndarray]: A dictionary containing the generated matches. + The keys are "keypoints0", "keypoints1", "matches0", and "matches1", + and the values are ndarrays of shape (N, 2), (N, 2), (N, 2), and + (N, 2), respectively. + """ + files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")} + try: + # TODO: replace files with post json + response = requests.post(API_URL_MATCH, files=files) + pred = {} + if response.status_code == 200: + pred = response.json() + for key in list(pred.keys()): + pred[key] = np.array(pred[key]) + else: + print( + f"Error: Response code {response.status_code} - {response.text}" + ) + finally: + files["image0"].close() + files["image1"].close() + return pred + + +def send_request_extract( + input_images: str, viz: bool = False +) -> List[Dict[str, np.ndarray]]: + """ + Send a request to the API to extract features from an image. + + Args: + input_images (str): The path to the image. + + Returns: + List[Dict[str, np.ndarray]]: A list of dictionaries containing the + extracted features. The keys are "keypoints", "descriptors", and + "scores", and the values are ndarrays of shape (N, 2), (N, 128), + and (N,), respectively. + """ + image_data = read_image(input_images) + inputs = { + "data": [image_data], + } + response = do_api_requests( + url=API_URL_EXTRACT, + **inputs, + ) + print("Keypoints detected: {}".format(len(response[0]["keypoints"]))) + + # draw matching, debug only + if viz: + from hloc.utils.viz import plot_keypoints + from ui.viz import fig2im, plot_images + + kpts = np.array(response[0]["keypoints_orig"]) + if "image_orig" in response[0].keys(): + img_orig = np.array(["image_orig"]) + + output_keypoints = plot_images([img_orig], titles="titles", dpi=300) + plot_keypoints([kpts]) + output_keypoints = fig2im(output_keypoints) + cv2.imwrite( + "demo_match.jpg", + output_keypoints[:, :, ::-1].copy(), # RGB -> BGR + ) + return response + + +def get_api_version(): + try: + response = requests.get(API_VERSION).json() + print("API VERSION: {}".format(response["version"])) + except Exception as e: + print(f"An error occurred: {e}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Send text to stable audio server and receive generated audio." + ) + parser.add_argument( + "--image0", + required=False, + help="Path for the file's melody", + default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg", + ) + parser.add_argument( + "--image1", + required=False, + help="Path for the file's melody", + default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg", + ) + args = parser.parse_args() + + # get api version + get_api_version() + + # request match + # for i in range(10): + # t1 = time.time() + # preds = send_request_match(args.image0, args.image1) + # t2 = time.time() + # print( + # "Time cost1: {} seconds, matched: {}".format( + # (t2 - t1), len(preds["mmkeypoints0_orig"]) + # ) + # ) + + # request extract + for i in range(10): + t1 = time.time() + preds = send_request_extract(args.image0) + t2 = time.time() + print(f"Time cost2: {(t2 - t1)} seconds") + + # dump preds + with open("preds.pkl", "wb") as f: + pickle.dump(preds, f) diff --git a/api/server.py b/api/server.py new file mode 100644 index 0000000000000000000000000000000000000000..1a1edc5e75a7b1353364d3fba56d4aa94fabe0b9 --- /dev/null +++ b/api/server.py @@ -0,0 +1,499 @@ +# server.py +import base64 +import io +import sys +import warnings +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +import uvicorn +from fastapi import FastAPI, File, UploadFile +from fastapi.exceptions import HTTPException +from fastapi.responses import JSONResponse +from PIL import Image + +sys.path.append(str(Path(__file__).parents[1])) + +from api.types import ImagesInput +from hloc import DEVICE, extract_features, logger, match_dense, match_features +from hloc.utils.viz import add_text, plot_keypoints +from ui import get_version +from ui.utils import filter_matches, get_feature_model, get_model +from ui.viz import display_matches, fig2im, plot_images + +warnings.simplefilter("ignore") + + +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(io.BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + logger.warning(f"API cannot decode image: {e}") + raise HTTPException( + status_code=500, detail="Invalid encoded image" + ) from e + + +def to_base64_nparray(encoding: str) -> np.ndarray: + return np.array(decode_base64_to_image(encoding)).astype("uint8") + + +class ImageMatchingAPI(torch.nn.Module): + default_conf = { + "ransac": { + "enable": True, + "estimator": "poselib", + "geometry": "homography", + "method": "RANSAC", + "reproj_threshold": 3, + "confidence": 0.9999, + "max_iter": 10000, + }, + } + + def __init__( + self, + conf: dict = {}, + device: str = "cpu", + detect_threshold: float = 0.015, + max_keypoints: int = 1024, + match_threshold: float = 0.2, + ) -> None: + """ + Initializes an instance of the ImageMatchingAPI class. + + Args: + conf (dict): A dictionary containing the configuration parameters. + device (str, optional): The device to use for computation. Defaults to "cpu". + detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015. + max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024. + match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2. + + Returns: + None + """ + super().__init__() + self.device = device + self.conf = {**self.default_conf, **conf} + self._updata_config(detect_threshold, max_keypoints, match_threshold) + self._init_models() + if device == "cuda": + memory_allocated = torch.cuda.memory_allocated(device) + memory_reserved = torch.cuda.memory_reserved(device) + logger.info( + f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB" + ) + logger.info( + f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB" + ) + self.pred = None + + def parse_match_config(self, conf): + if conf["dense"]: + return { + **conf, + "matcher": match_dense.confs.get( + conf["matcher"]["model"]["name"] + ), + "dense": True, + } + else: + return { + **conf, + "feature": extract_features.confs.get( + conf["feature"]["model"]["name"] + ), + "matcher": match_features.confs.get( + conf["matcher"]["model"]["name"] + ), + "dense": False, + } + + def _updata_config( + self, + detect_threshold: float = 0.015, + max_keypoints: int = 1024, + match_threshold: float = 0.2, + ): + self.dense = self.conf["dense"] + if self.conf["dense"]: + try: + self.conf["matcher"]["model"][ + "match_threshold" + ] = match_threshold + except TypeError as e: + logger.error(e) + else: + self.conf["feature"]["model"]["max_keypoints"] = max_keypoints + self.conf["feature"]["model"][ + "keypoint_threshold" + ] = detect_threshold + self.extract_conf = self.conf["feature"] + + self.match_conf = self.conf["matcher"] + + def _init_models(self): + # initialize matcher + self.matcher = get_model(self.match_conf) + # initialize extractor + if self.dense: + self.extractor = None + else: + self.extractor = get_feature_model(self.conf["feature"]) + + def _forward(self, img0, img1): + if self.dense: + pred = match_dense.match_images( + self.matcher, + img0, + img1, + self.match_conf["preprocessing"], + device=self.device, + ) + last_fixed = "{}".format( # noqa: F841 + self.match_conf["model"]["name"] + ) + else: + pred0 = extract_features.extract( + self.extractor, img0, self.extract_conf["preprocessing"] + ) + pred1 = extract_features.extract( + self.extractor, img1, self.extract_conf["preprocessing"] + ) + pred = match_features.match_images(self.matcher, pred0, pred1) + return pred + + @torch.inference_mode() + def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]: + """Extract features from a single image. + + Args: + img0 (np.ndarray): image + + Returns: + Dict[str, np.ndarray]: feature dict + """ + + # setting prams + self.extractor.conf["max_keypoints"] = kwargs.get("max_keypoints", 512) + self.extractor.conf["keypoint_threshold"] = kwargs.get( + "keypoint_threshold", 0.0 + ) + + pred = extract_features.extract( + self.extractor, img0, self.extract_conf["preprocessing"] + ) + pred = { + k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v + for k, v in pred.items() + } + # back to origin scale + s0 = pred["original_size"] / pred["size"] + pred["keypoints_orig"] = ( + match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5 + ) + # TODO: rotate back + + binarize = kwargs.get("binarize", False) + if binarize: + assert "descriptors" in pred + pred["descriptors"] = (pred["descriptors"] > 0).astype(np.uint8) + pred["descriptors"] = pred["descriptors"].T # N x DIM + return pred + + @torch.inference_mode() + def forward( + self, + img0: np.ndarray, + img1: np.ndarray, + ) -> Dict[str, np.ndarray]: + """ + Forward pass of the image matching API. + + Args: + img0: A 3D NumPy array of shape (H, W, C) representing the first image. + Values are in the range [0, 1] and are in RGB mode. + img1: A 3D NumPy array of shape (H, W, C) representing the second image. + Values are in the range [0, 1] and are in RGB mode. + + Returns: + A dictionary containing the following keys: + - image0_orig: The original image 0. + - image1_orig: The original image 1. + - keypoints0_orig: The keypoints detected in image 0. + - keypoints1_orig: The keypoints detected in image 1. + - mkeypoints0_orig: The raw matches between image 0 and image 1. + - mkeypoints1_orig: The raw matches between image 1 and image 0. + - mmkeypoints0_orig: The RANSAC inliers in image 0. + - mmkeypoints1_orig: The RANSAC inliers in image 1. + - mconf: The confidence scores for the raw matches. + - mmconf: The confidence scores for the RANSAC inliers. + """ + # Take as input a pair of images (not a batch) + assert isinstance(img0, np.ndarray) + assert isinstance(img1, np.ndarray) + self.pred = self._forward(img0, img1) + if self.conf["ransac"]["enable"]: + self.pred = self._geometry_check(self.pred) + return self.pred + + def _geometry_check( + self, + pred: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Filter matches using RANSAC. If keypoints are available, filter by keypoints. + If lines are available, filter by lines. If both keypoints and lines are + available, filter by keypoints. + + Args: + pred (Dict[str, Any]): dict of matches, including original keypoints. + See :func:`filter_matches` for the expected keys. + + Returns: + Dict[str, Any]: filtered matches + """ + pred = filter_matches( + pred, + ransac_method=self.conf["ransac"]["method"], + ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"], + ransac_confidence=self.conf["ransac"]["confidence"], + ransac_max_iter=self.conf["ransac"]["max_iter"], + ) + return pred + + def visualize( + self, + log_path: Optional[Path] = None, + ) -> None: + """ + Visualize the matches. + + Args: + log_path (Path, optional): The directory to save the images. Defaults to None. + + Returns: + None + """ + if self.conf["dense"]: + postfix = str(self.conf["matcher"]["model"]["name"]) + else: + postfix = "{}_{}".format( + str(self.conf["feature"]["model"]["name"]), + str(self.conf["matcher"]["model"]["name"]), + ) + titles = [ + "Image 0 - Keypoints", + "Image 1 - Keypoints", + ] + pred: Dict[str, Any] = self.pred + image0: np.ndarray = pred["image0_orig"] + image1: np.ndarray = pred["image1_orig"] + output_keypoints: np.ndarray = plot_images( + [image0, image1], titles=titles, dpi=300 + ) + if ( + "keypoints0_orig" in pred.keys() + and "keypoints1_orig" in pred.keys() + ): + plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]]) + text: str = ( + f"# keypoints0: {len(pred['keypoints0_orig'])} \n" + + f"# keypoints1: {len(pred['keypoints1_orig'])}" + ) + add_text(0, text, fs=15) + output_keypoints = fig2im(output_keypoints) + # plot images with raw matches + titles = [ + "Image 0 - Raw matched keypoints", + "Image 1 - Raw matched keypoints", + ] + output_matches_raw, num_matches_raw = display_matches( + pred, titles=titles, tag="KPTS_RAW" + ) + # plot images with ransac matches + titles = [ + "Image 0 - Ransac matched keypoints", + "Image 1 - Ransac matched keypoints", + ] + output_matches_ransac, num_matches_ransac = display_matches( + pred, titles=titles, tag="KPTS_RANSAC" + ) + if log_path is not None: + img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png" + img_matches_raw_path: Path = ( + log_path / f"img_matches_raw_{postfix}.png" + ) + img_matches_ransac_path: Path = ( + log_path / f"img_matches_ransac_{postfix}.png" + ) + cv2.imwrite( + str(img_keypoints_path), + output_keypoints[:, :, ::-1].copy(), # RGB -> BGR + ) + cv2.imwrite( + str(img_matches_raw_path), + output_matches_raw[:, :, ::-1].copy(), # RGB -> BGR + ) + cv2.imwrite( + str(img_matches_ransac_path), + output_matches_ransac[:, :, ::-1].copy(), # RGB -> BGR + ) + plt.close("all") + + +class ImageMatchingService: + def __init__(self, conf: dict, device: str): + self.conf = conf + self.api = ImageMatchingAPI(conf=conf, device=device) + self.app = FastAPI() + self.register_routes() + + def register_routes(self): + + @self.app.get("/version") + async def version(): + return {"version": get_version()} + + @self.app.post("/v1/match") + async def match( + image0: UploadFile = File(...), image1: UploadFile = File(...) + ): + """ + Handle the image matching request and return the processed result. + + Args: + image0 (UploadFile): The first image file for matching. + image1 (UploadFile): The second image file for matching. + + Returns: + JSONResponse: A JSON response containing the filtered match results + or an error message in case of failure. + """ + try: + # Load the images from the uploaded files + image0_array = self.load_image(image0) + image1_array = self.load_image(image1) + + # Perform image matching using the API + output = self.api(image0_array, image1_array) + + # Keys to skip in the output + skip_keys = ["image0_orig", "image1_orig"] + + # Postprocess the output to filter unwanted data + pred = self.postprocess(output, skip_keys) + + # Return the filtered prediction as a JSON response + return JSONResponse(content=pred) + except Exception as e: + # Return an error message with status code 500 in case of exception + return JSONResponse(content={"error": str(e)}, status_code=500) + + @self.app.post("/v1/extract") + async def extract(input_info: ImagesInput): + """ + Extract keypoints and descriptors from images. + + Args: + input_info: An object containing the image data and options. + + Returns: + A list of dictionaries containing the keypoints and descriptors. + """ + try: + preds = [] + for i, input_image in enumerate(input_info.data): + # Load the image from the input data + image_array = to_base64_nparray(input_image) + # Extract keypoints and descriptors + output = self.api.extract( + image_array, + max_keypoints=input_info.max_keypoints[i], + binarize=input_info.binarize, + ) + # Do not return the original image and image_orig + # skip_keys = ["image", "image_orig"] + skip_keys = [] + + # Postprocess the output + pred = self.postprocess(output, skip_keys) + preds.append(pred) + # Return the list of extracted features + return JSONResponse(content=preds) + except Exception as e: + # Return an error message if an exception occurs + return JSONResponse(content={"error": str(e)}, status_code=500) + + def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray: + """ + Reads an image from a file path or an UploadFile object. + + Args: + file_path: A file path or an UploadFile object. + + Returns: + A numpy array representing the image. + """ + if isinstance(file_path, str): + file_path = Path(file_path).resolve(strict=False) + else: + file_path = file_path.file + with Image.open(file_path) as img: + image_array = np.array(img) + return image_array + + def postprocess( + self, output: dict, skip_keys: list, binarize: bool = True + ) -> dict: + pred = {} + for key, value in output.items(): + if key in skip_keys: + continue + if isinstance(value, np.ndarray): + pred[key] = value.tolist() + return pred + + def run(self, host: str = "0.0.0.0", port: int = 8001): + uvicorn.run(self.app, host=host, port=port) + + +if __name__ == "__main__": + conf = { + "feature": { + "output": "feats-superpoint-n4096-rmax1600", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "matcher": { + "output": "matches-NN-mutual", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "match_threshold": 0.2, + }, + }, + "dense": False, + } + + service = ImageMatchingService(conf=conf, device=DEVICE) + service.run() diff --git a/api/test/CMakeLists.txt b/api/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..200c17d8e34add0e787d6ca32bdbed9e3c4213a3 --- /dev/null +++ b/api/test/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.10) +project(imatchui) + +set(OpenCV_DIR /usr/include/opencv4) +find_package(OpenCV REQUIRED) + +find_package(Boost REQUIRED COMPONENTS system) +if(Boost_FOUND) + include_directories(${Boost_INCLUDE_DIRS}) +endif() + +add_executable(client client.cpp) + +target_include_directories(client PRIVATE ${Boost_LIBRARIES} ${OpenCV_INCLUDE_DIRS}) + +target_link_libraries(client PRIVATE curl jsoncpp b64 ${OpenCV_LIBS}) diff --git a/api/test/build_and_run.sh b/api/test/build_and_run.sh new file mode 100644 index 0000000000000000000000000000000000000000..40921bb9b925c67722247df7ab901668d713e888 --- /dev/null +++ b/api/test/build_and_run.sh @@ -0,0 +1,16 @@ +# g++ main.cpp -I/usr/include/opencv4 -lcurl -ljsoncpp -lb64 -lopencv_core -lopencv_imgcodecs -o main +# sudo apt-get update +# sudo apt-get install libboost-all-dev -y +# sudo apt-get install libcurl4-openssl-dev libjsoncpp-dev libb64-dev libopencv-dev -y + +cd build +cmake .. +make -j12 + +echo " ======== RUN DEMO ========" + +./client + +echo " ======== END DEMO ========" + +cd .. diff --git a/api/test/client.cpp b/api/test/client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d80c8474a21a83374ddcbec721919b60901c7d2 --- /dev/null +++ b/api/test/client.cpp @@ -0,0 +1,84 @@ +#include +#include +#include "helper.h" + +int main() { + std::string img_path = "../../../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg"; + cv::Mat original_img = cv::imread(img_path, cv::IMREAD_GRAYSCALE); + + if (original_img.empty()) { + throw std::runtime_error("Failed to decode image"); + } + + // Convert the image to Base64 + std::string base64_img = image_to_base64(original_img); + + // Convert the Base64 back to an image + cv::Mat decoded_img = base64_to_image(base64_img); + cv::imwrite("decoded_image.jpg", decoded_img); + cv::imwrite("original_img.jpg", original_img); + + // The images should be identical + if (cv::countNonZero(original_img != decoded_img) != 0) { + std::cerr << "The images are not identical" << std::endl; + return -1; + } else { + std::cout << "The images are identical!" << std::endl; + } + + // construct params + APIParams params{ + .data = {base64_img}, + .max_keypoints = {100, 100}, + .timestamps = {"0", "1"}, + .grayscale = {0}, + .image_hw = {{480, 640}, {240, 320}}, + .feature_type = 0, + .rotates = {0.0f, 0.0f}, + .scales = {1.0f, 1.0f}, + .reference_points = { + {1.23e+2f, 1.2e+1f}, + {5.0e-1f, 3.0e-1f}, + {2.3e+2f, 2.2e+1f}, + {6.0e-1f, 4.0e-1f} + }, + .binarize = {1} + }; + + KeyPointResults kpts_results; + + // Convert the parameters to JSON + Json::Value jsonData = paramsToJson(params); + std::string url = "http://127.0.0.1:8001/v1/extract"; + Json::StreamWriterBuilder writer; + std::string output = Json::writeString(writer, jsonData); + + CURL* curl; + CURLcode res; + std::string readBuffer; + + curl_global_init(CURL_GLOBAL_DEFAULT); + curl = curl_easy_init(); + if (curl) { + struct curl_slist* hs = NULL; + hs = curl_slist_append(hs, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, output.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + res = curl_easy_perform(curl); + + if (res != CURLE_OK) + fprintf(stderr, "curl_easy_perform() failed: %s\n", + curl_easy_strerror(res)); + else { + // std::cout << "Response from server: " << readBuffer << std::endl; + kpts_results = decode_response(readBuffer); + } + curl_easy_cleanup(curl); + } + curl_global_cleanup(); + + return 0; +} diff --git a/api/test/helper.h b/api/test/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..029291e8e97b6cb8bb40014912014f3f229447b1 --- /dev/null +++ b/api/test/helper.h @@ -0,0 +1,410 @@ + +#include +#include +#include +#include +#include +#include + +// base64 to image +#include +#include +#include + +/// Parameters used in the API +struct APIParams { + /// A list of images, base64 encoded + std::vector data; + + /// The maximum number of keypoints to detect for each image + std::vector max_keypoints; + + /// The timestamps of the images + std::vector timestamps; + + /// Whether to convert the images to grayscale + bool grayscale; + + /// The height and width of each image + std::vector> image_hw; + + /// The type of feature detector to use + int feature_type; + + /// The rotations of the images + std::vector rotates; + + /// The scales of the images + std::vector scales; + + /// The reference points of the images + std::vector> reference_points; + + /// Whether to binarize the descriptors + bool binarize; +}; + +/** + * @brief Contains the results of a keypoint detector. + * + * @details Stores the keypoints and descriptors for each image. + */ +class KeyPointResults { +public: + KeyPointResults() {} + + /** + * @brief Constructor. + * + * @param kp The keypoints for each image. + */ + KeyPointResults(const std::vector>& kp, + const std::vector& desc) + : keypoints(kp), descriptors(desc) {} + + /** + * @brief Append keypoints to the result. + * + * @param kpts The keypoints to append. + */ + inline void append_keypoints(std::vector& kpts) { + keypoints.emplace_back(kpts); + } + + /** + * @brief Append descriptors to the result. + * + * @param desc The descriptors to append. + */ + inline void append_descriptors(cv::Mat& desc) { + descriptors.emplace_back(desc); + } + + /** + * @brief Get the keypoints. + * + * @return The keypoints. + */ + inline std::vector> get_keypoints() { + return keypoints; + } + + /** + * @brief Get the descriptors. + * + * @return The descriptors. + */ + inline std::vector get_descriptors() { + return descriptors; + } + +private: + std::vector> keypoints; + std::vector descriptors; + std::vector> scores; +}; + + +/** + * @brief Decodes a base64 encoded string. + * + * @param base64 The base64 encoded string to decode. + * @return The decoded string. + */ +std::string base64_decode(const std::string& base64) { + using namespace boost::archive::iterators; + using It = transform_width, 8, 6>; + + // Find the position of the last non-whitespace character + auto end = base64.find_last_not_of(" \t\n\r"); + if (end != std::string::npos) { + // Move one past the last non-whitespace character + end += 1; + } + + // Decode the base64 string and return the result + return std::string(It(base64.begin()), It(base64.begin() + end)); +} + + + +/** + * @brief Decodes a base64 string into an OpenCV image + * + * @param base64 The base64 encoded string + * @return The decoded OpenCV image + */ +cv::Mat base64_to_image(const std::string& base64) { + // Decode the base64 string + std::string decodedStr = base64_decode(base64); + + // Decode the image + std::vector data(decodedStr.begin(), decodedStr.end()); + cv::Mat img = cv::imdecode(data, cv::IMREAD_GRAYSCALE); + + // Check for errors + if (img.empty()) { + throw std::runtime_error("Failed to decode image"); + } + + return img; +} + + +/** + * @brief Encodes an OpenCV image into a base64 string + * + * This function takes an OpenCV image and encodes it into a base64 string. + * The image is first encoded as a PNG image, and then the resulting + * bytes are encoded as a base64 string. + * + * @param img The OpenCV image + * @return The base64 encoded string + * + * @throws std::runtime_error if the image is empty or encoding fails + */ +std::string image_to_base64(cv::Mat &img) { + if (img.empty()) { + throw std::runtime_error("Failed to read image"); + } + + // Encode the image as a PNG + std::vector buf; + if (!cv::imencode(".png", img, buf)) { + throw std::runtime_error("Failed to encode image"); + } + + // Encode the bytes as a base64 string + using namespace boost::archive::iterators; + using It = base64_from_binary::const_iterator, 6, 8>>; + std::string base64(It(buf.begin()), It(buf.end())); + + // Pad the string with '=' characters to a multiple of 4 bytes + base64.append((3 - buf.size() % 3) % 3, '='); + + return base64; +} + + +/** + * @brief Callback function for libcurl to write data to a string + * + * This function is used as a callback for libcurl to write data to a string. + * It takes the contents, size, and nmemb as parameters, and writes the data to + * the string. + * + * @param contents The data to write + * @param size The size of the data + * @param nmemb The number of members in the data + * @param s The string to write the data to + * @return The number of bytes written + */ +size_t WriteCallback(void* contents, size_t size, size_t nmemb, std::string* s) { + size_t newLength = size * nmemb; + try { + // Resize the string to fit the new data + s->resize(s->size() + newLength); + } catch (std::bad_alloc& e) { + // If there's an error allocating memory, return 0 + return 0; + } + + // Copy the data to the string + std::copy(static_cast(contents), + static_cast(contents) + newLength, + s->begin() + s->size() - newLength); + return newLength; +} + +// Helper functions + +/** + * @brief Helper function to convert a type to a Json::Value + * + * This function takes a value of type T and converts it to a Json::Value. + * It is used to simplify the process of converting a type to a Json::Value. + * + * @param val The value to convert + * @return The converted Json::Value + */ +template +Json::Value toJson(const T& val) { + return Json::Value(val); +} + +/** + * @brief Converts a vector to a Json::Value + * + * This function takes a vector of type T and converts it to a Json::Value. + * Each element in the vector is appended to the Json::Value array. + * + * @param vec The vector to convert to Json::Value + * @return The Json::Value representing the vector + */ +template +Json::Value vectorToJson(const std::vector& vec) { + Json::Value json(Json::arrayValue); + for (const auto& item : vec) { + json.append(item); + } + return json; +} + +/** + * @brief Converts a nested vector to a Json::Value + * + * This function takes a nested vector of type T and converts it to a Json::Value. + * Each sub-vector is converted to a Json::Value array and appended to the main Json::Value array. + * + * @param vec The nested vector to convert to Json::Value + * @return The Json::Value representing the nested vector + */ +template +Json::Value nestedVectorToJson(const std::vector>& vec) { + Json::Value json(Json::arrayValue); + for (const auto& subVec : vec) { + json.append(vectorToJson(subVec)); + } + return json; +} + + + +/** + * @brief Converts the APIParams struct to a Json::Value + * + * This function takes an APIParams struct and converts it to a Json::Value. + * The Json::Value is a JSON object with the following fields: + * - data: a JSON array of base64 encoded images + * - max_keypoints: a JSON array of integers, max number of keypoints for each image + * - timestamps: a JSON array of timestamps, one for each image + * - grayscale: a JSON boolean, whether to convert images to grayscale + * - image_hw: a nested JSON array, each sub-array contains the height and width of an image + * - feature_type: a JSON integer, the type of feature detector to use + * - rotates: a JSON array of doubles, the rotation of each image + * - scales: a JSON array of doubles, the scale of each image + * - reference_points: a nested JSON array, each sub-array contains the reference points of an image + * - binarize: a JSON boolean, whether to binarize the descriptors + * + * @param params The APIParams struct to convert + * @return The Json::Value representing the APIParams struct + */ +Json::Value paramsToJson(const APIParams& params) { + Json::Value json; + json["data"] = vectorToJson(params.data); + json["max_keypoints"] = vectorToJson(params.max_keypoints); + json["timestamps"] = vectorToJson(params.timestamps); + json["grayscale"] = toJson(params.grayscale); + json["image_hw"] = nestedVectorToJson(params.image_hw); + json["feature_type"] = toJson(params.feature_type); + json["rotates"] = vectorToJson(params.rotates); + json["scales"] = vectorToJson(params.scales); + json["reference_points"] = nestedVectorToJson(params.reference_points); + json["binarize"] = toJson(params.binarize); + return json; +} + +template +cv::Mat jsonToMat(Json::Value json) { + int rows = json.size(); + int cols = json[0].size(); + + // Create a single array to hold all the data. + std::vector data; + data.reserve(rows * cols); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + data.push_back(static_cast(json[i][j].asInt())); + } + } + + // Create a cv::Mat object that points to the data. + cv::Mat mat(rows, cols, CV_8UC1, data.data()); // Change the type if necessary. + // cv::Mat mat(cols, rows,CV_8UC1, data.data()); // Change the type if necessary. + + return mat; +} + + + +/** + * @brief Decodes the response of the server and prints the keypoints + * + * This function takes the response of the server, a JSON string, and decodes + * it. It then prints the keypoints and draws them on the original image. + * + * @param response The response of the server + * @return The keypoints and descriptors + */ +KeyPointResults decode_response(const std::string& response, bool viz=true) { + Json::CharReaderBuilder builder; + Json::CharReader* reader = builder.newCharReader(); + + Json::Value jsonData; + std::string errors; + + // Parse the JSON response + bool parsingSuccessful = reader->parse(response.c_str(), + response.c_str() + response.size(), &jsonData, &errors); + delete reader; + + if (!parsingSuccessful) { + // Handle error + std::cout << "Failed to parse the JSON, errors:" << std::endl; + std::cout << errors << std::endl; + return KeyPointResults(); + } + + KeyPointResults kpts_results; + + // Iterate over the images + for (const auto& jsonItem : jsonData) { + auto jkeypoints = jsonItem["keypoints"]; + auto jkeypoints_orig = jsonItem["keypoints_orig"]; + auto jdescriptors = jsonItem["descriptors"]; + auto jscores = jsonItem["scores"]; + auto jimageSize = jsonItem["image_size"]; + auto joriginalSize = jsonItem["original_size"]; + auto jsize = jsonItem["size"]; + + std::vector vkeypoints; + std::vector vscores; + + // Iterate over the keypoints + int counter = 0; + for (const auto& keypoint : jkeypoints_orig) { + if (counter < 10) { + // Print the first 10 keypoints + std::cout << keypoint[0].asFloat() << ", " + << keypoint[1].asFloat() << std::endl; + } + counter++; + // Convert the Json::Value to a cv::KeyPoint + vkeypoints.emplace_back(cv::KeyPoint(keypoint[0].asFloat(), + keypoint[1].asFloat(), 0.0)); + } + + if (viz && jsonItem.isMember("image_orig")) { + + auto jimg_orig = jsonItem["image_orig"]; + cv::Mat img = jsonToMat(jimg_orig); + cv::imwrite("viz_image_orig.jpg", img); + + // Draw keypoints on the image + cv::Mat imgWithKeypoints; + cv::drawKeypoints(img, vkeypoints, + imgWithKeypoints, cv::Scalar(0, 0, 255)); + + // Write the image with keypoints + std::string filename = "viz_image_orig_keypoints.jpg"; + cv::imwrite(filename, imgWithKeypoints); + } + + // Iterate over the descriptors + cv::Mat descriptors = jsonToMat(jdescriptors); + kpts_results.append_keypoints(vkeypoints); + kpts_results.append_descriptors(descriptors); + } + return kpts_results; +} diff --git a/api/types.py b/api/types.py new file mode 100644 index 0000000000000000000000000000000000000000..db17dce8a6824f8887720fdbc6b0b2513bdb17eb --- /dev/null +++ b/api/types.py @@ -0,0 +1,16 @@ +from typing import List + +from pydantic import BaseModel + + +class ImagesInput(BaseModel): + data: List[str] = [] + max_keypoints: List[int] = [] + timestamps: List[str] = [] + grayscale: bool = False + image_hw: List[List[int]] = [[], []] + feature_type: int = 0 + rotates: List[float] = [] + scales: List[float] = [] + reference_points: List[List[float]] = [] + binarize: bool = False diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b168e266b562be651ab217b46a30145cac712914 --- /dev/null +++ b/app.py @@ -0,0 +1,28 @@ +import argparse +from pathlib import Path +from ui.app_class import ImageMatchingApp + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--server_name", + type=str, + default="0.0.0.0", + help="server name", + ) + parser.add_argument( + "--server_port", + type=int, + default=7860, + help="server port", + ) + parser.add_argument( + "--config", + type=str, + default=Path(__file__).parent / "ui/config.yaml", + help="config file", + ) + args = parser.parse_args() + ImageMatchingApp( + args.server_name, args.server_port, config=args.config + ).run() diff --git a/build_docker.sh b/build_docker.sh new file mode 100644 index 0000000000000000000000000000000000000000..836deb8ae6d9b9c65cf7e2588b9acd474a129d6f --- /dev/null +++ b/build_docker.sh @@ -0,0 +1,3 @@ +docker build -t image-matching-webui:latest . --no-cache +docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest +docker push vincentqin/image-matching-webui:latest \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..7455862d5fe993d55e63f79fb63f1d274f25774e --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,27 @@ +# Use an official conda-based Python image as a parent image +FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime +LABEL maintainer vincentqyw +ARG PYTHON_VERSION=3.10.10 + +# Set the working directory to /code +WORKDIR /code + +# Install Git and Git LFS +RUN apt-get update && apt-get install -y git-lfs +RUN git lfs install + +# Clone the Git repository +RUN git clone https://huggingface.co/spaces/Realcat/image-matching-webui /code + +RUN conda create -n imw python=${PYTHON_VERSION} +RUN echo "source activate imw" > ~/.bashrc +ENV PATH /opt/conda/envs/imw/bin:$PATH + +# Make RUN commands use the new environment +SHELL ["conda", "run", "-n", "imw", "/bin/bash", "-c"] +RUN pip install --upgrade pip +RUN pip install -r requirements.txt +RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y + +# Export port +EXPOSE 7860 diff --git a/docker/build_docker.bat b/docker/build_docker.bat new file mode 100644 index 0000000000000000000000000000000000000000..9f3fc687e1185de2866a1dbe221599549abdbce8 --- /dev/null +++ b/docker/build_docker.bat @@ -0,0 +1,3 @@ +docker build -t image-matching-webui:latest . --no-cache +# docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest +# docker push vincentqin/image-matching-webui:latest diff --git a/docker/run_docker.bat b/docker/run_docker.bat new file mode 100644 index 0000000000000000000000000000000000000000..da7686293c14465f0899c4b022f89fcc03db93b3 --- /dev/null +++ b/docker/run_docker.bat @@ -0,0 +1 @@ +docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860 diff --git a/docker/run_docker.sh b/docker/run_docker.sh new file mode 100644 index 0000000000000000000000000000000000000000..da7686293c14465f0899c4b022f89fcc03db93b3 --- /dev/null +++ b/docker/run_docker.sh @@ -0,0 +1 @@ +docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860 diff --git a/format.sh b/format.sh new file mode 100644 index 0000000000000000000000000000000000000000..ada71402e3a1b431e0c82e3f542700e2224e3a58 --- /dev/null +++ b/format.sh @@ -0,0 +1,3 @@ +python -m flake8 ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py +python -m isort ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py +python -m black ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py \ No newline at end of file diff --git a/hloc/__init__.py b/hloc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7db5d9f07ad8bd04f704eeaf8cb599f99194623 --- /dev/null +++ b/hloc/__init__.py @@ -0,0 +1,63 @@ +import logging +import sys + +import torch +from packaging import version + +__version__ = "1.5" + +LOG_PATH = "log.txt" + + +def read_logs(): + sys.stdout.flush() + with open(LOG_PATH, "r") as f: + return f.read() + + +def flush_logs(): + sys.stdout.flush() + logs = open(LOG_PATH, "w") + logs.close() + + +formatter = logging.Formatter( + fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", + datefmt="%Y/%m/%d %H:%M:%S", +) + +logs_file = open(LOG_PATH, "w") +logs_file.close() + +file_handler = logging.FileHandler(filename=LOG_PATH) +file_handler.setFormatter(formatter) +file_handler.setLevel(logging.INFO) +stdout_handler = logging.StreamHandler() +stdout_handler.setFormatter(formatter) +stdout_handler.setLevel(logging.INFO) +logger = logging.getLogger("hloc") +logger.setLevel(logging.INFO) +logger.addHandler(file_handler) +logger.addHandler(stdout_handler) +logger.propagate = False + +try: + import pycolmap +except ImportError: + logger.warning("pycolmap is not installed, some features may not work.") +else: + min_version = version.parse("0.6.0") + found_version = pycolmap.__version__ + if found_version != "dev": + version = version.parse(found_version) + if version < min_version: + s = f"pycolmap>={min_version}" + logger.warning( + "hloc requires %s but found pycolmap==%s, " + 'please upgrade with `pip install --upgrade "%s"`', + s, + found_version, + s, + ) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/hloc/colmap_from_nvm.py b/hloc/colmap_from_nvm.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3ad896b88f2cb484918d1b395bbee91b7c6c29 --- /dev/null +++ b/hloc/colmap_from_nvm.py @@ -0,0 +1,220 @@ +import argparse +import sqlite3 +from collections import defaultdict +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +from . import logger +from .utils.read_write_model import ( + CAMERA_MODEL_NAMES, + Camera, + Image, + Point3D, + write_model, +) + + +def recover_database_images_and_ids(database_path): + images = {} + cameras = {} + db = sqlite3.connect(str(database_path)) + ret = db.execute("SELECT name, image_id, camera_id FROM images;") + for name, image_id, camera_id in ret: + images[name] = image_id + cameras[name] = camera_id + db.close() + logger.info( + f"Found {len(images)} images and {len(cameras)} cameras in database." + ) + return images, cameras + + +def quaternion_to_rotation_matrix(qvec): + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + R = np.array( + [ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w, + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w, + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y, + ], + ] + ) + return R + + +def camera_center_to_translation(c, qvec): + R = quaternion_to_rotation_matrix(qvec) + return (-1) * np.matmul(R, c) + + +def read_nvm_model( + nvm_path, intrinsics_path, image_ids, camera_ids, skip_points=False +): + with open(intrinsics_path, "r") as f: + raw_intrinsics = f.readlines() + + logger.info(f"Reading {len(raw_intrinsics)} cameras...") + cameras = {} + for intrinsics in raw_intrinsics: + intrinsics = intrinsics.strip("\n").split(" ") + name, camera_model, width, height = intrinsics[:4] + params = [float(p) for p in intrinsics[4:]] + camera_model = CAMERA_MODEL_NAMES[camera_model] + assert len(params) == camera_model.num_params + camera_id = camera_ids[name] + camera = Camera( + id=camera_id, + model=camera_model.model_name, + width=int(width), + height=int(height), + params=params, + ) + cameras[camera_id] = camera + + nvm_f = open(nvm_path, "r") + line = nvm_f.readline() + while line == "\n" or line.startswith("NVM_V3"): + line = nvm_f.readline() + num_images = int(line) + assert num_images == len(cameras) + + logger.info(f"Reading {num_images} images...") + image_idx_to_db_image_id = [] + image_data = [] + i = 0 + while i < num_images: + line = nvm_f.readline() + if line == "\n": + continue + data = line.strip("\n").split(" ") + image_data.append(data) + image_idx_to_db_image_id.append(image_ids[data[0]]) + i += 1 + + line = nvm_f.readline() + while line == "\n": + line = nvm_f.readline() + num_points = int(line) + + if skip_points: + logger.info(f"Skipping {num_points} points.") + num_points = 0 + else: + logger.info(f"Reading {num_points} points...") + points3D = {} + image_idx_to_keypoints = defaultdict(list) + i = 0 + pbar = tqdm(total=num_points, unit="pts") + while i < num_points: + line = nvm_f.readline() + if line == "\n": + continue + + data = line.strip("\n").split(" ") + x, y, z, r, g, b, num_observations = data[:7] + obs_image_ids, point2D_idxs = [], [] + for j in range(int(num_observations)): + s = 7 + 4 * j + img_index, kp_index, kx, ky = data[s : s + 4] + image_idx_to_keypoints[int(img_index)].append( + (int(kp_index), float(kx), float(ky), i) + ) + db_image_id = image_idx_to_db_image_id[int(img_index)] + obs_image_ids.append(db_image_id) + point2D_idxs.append(kp_index) + + point = Point3D( + id=i, + xyz=np.array([x, y, z], float), + rgb=np.array([r, g, b], int), + error=1.0, # fake + image_ids=np.array(obs_image_ids, int), + point2D_idxs=np.array(point2D_idxs, int), + ) + points3D[i] = point + + i += 1 + pbar.update(1) + pbar.close() + + logger.info("Parsing image data...") + images = {} + for i, data in enumerate(image_data): + # Skip the focal length. Skip the distortion and terminal 0. + name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data + qvec = np.array([qw, qx, qy, qz], float) + c = np.array([cx, cy, cz], float) + t = camera_center_to_translation(c, qvec) + + if i in image_idx_to_keypoints: + # NVM only stores triangulated 2D keypoints: add dummy ones + keypoints = image_idx_to_keypoints[i] + point2D_idxs = np.array([d[0] for d in keypoints]) + tri_xys = np.array([[x, y] for _, x, y, _ in keypoints]) + tri_ids = np.array([i for _, _, _, i in keypoints]) + + num_2Dpoints = max(point2D_idxs) + 1 + xys = np.zeros((num_2Dpoints, 2), float) + point3D_ids = np.full(num_2Dpoints, -1, int) + xys[point2D_idxs] = tri_xys + point3D_ids[point2D_idxs] = tri_ids + else: + xys = np.zeros((0, 2), float) + point3D_ids = np.full(0, -1, int) + + image_id = image_ids[name] + image = Image( + id=image_id, + qvec=qvec, + tvec=t, + camera_id=camera_ids[name], + name=name, + xys=xys, + point3D_ids=point3D_ids, + ) + images[image_id] = image + + return cameras, images, points3D + + +def main(nvm, intrinsics, database, output, skip_points=False): + assert nvm.exists(), nvm + assert intrinsics.exists(), intrinsics + assert database.exists(), database + + image_ids, camera_ids = recover_database_images_and_ids(database) + + logger.info("Reading the NVM model...") + model = read_nvm_model( + nvm, intrinsics, image_ids, camera_ids, skip_points=skip_points + ) + + logger.info("Writing the COLMAP model...") + output.mkdir(exist_ok=True, parents=True) + write_model(*model, path=str(output), ext=".bin") + logger.info("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--nvm", required=True, type=Path) + parser.add_argument("--intrinsics", required=True, type=Path) + parser.add_argument("--database", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--skip_points", action="store_true") + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/extract_features.py b/hloc/extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..d268990cb1e69d8ae560dcdaaa66af823d753247 --- /dev/null +++ b/hloc/extract_features.py @@ -0,0 +1,618 @@ +import argparse +import collections.abc as collections +import pprint +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, List, Optional, Union + +import cv2 +import h5py +import numpy as np +import PIL.Image +import torch +import torchvision.transforms.functional as F +from tqdm import tqdm + +from . import extractors, logger +from .utils.base_model import dynamic_load +from .utils.io import list_h5_names, read_image +from .utils.parsers import parse_image_lists + +""" +A set of standard configurations that can be directly selected from the command +line using their name. Each is a dictionary with the following entries: + - output: the name of the feature file that will be generated. + - model: the model configuration, as passed to a feature extractor. + - preprocessing: how to preprocess the images read from disk. +""" +confs = { + "superpoint_aachen": { + "output": "feats-superpoint-n4096-r1024", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + # Resize images to 1600px even if they are originally smaller. + # Improves the keypoint localization if the images are of good quality. + "superpoint_max": { + "output": "feats-superpoint-n4096-rmax1600", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "superpoint_inloc": { + "output": "feats-superpoint-n4096-r1600", + "model": { + "name": "superpoint", + "nms_radius": 4, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "r2d2": { + "output": "feats-r2d2-n5000-r1024", + "model": { + "name": "r2d2", + "max_keypoints": 5000, + "reliability_threshold": 0.7, + "repetability_threshold": 0.7, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "d2net-ss": { + "output": "feats-d2net-ss-n5000-r1600", + "model": { + "name": "d2net", + "multiscale": False, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "d2net-ms": { + "output": "feats-d2net-ms-n5000-r1600", + "model": { + "name": "d2net", + "multiscale": True, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "rord": { + "output": "feats-rord-ss-n5000-r1600", + "model": { + "name": "rord", + "multiscale": False, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "rootsift": { + "output": "feats-rootsift-n5000-r1600", + "model": { + "name": "dog", + "descriptor": "rootsift", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "sift": { + "output": "feats-sift-n5000-r1600", + "model": { + "name": "sift", + "rootsift": True, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "sosnet": { + "output": "feats-sosnet-n5000-r1600", + "model": { + "name": "dog", + "descriptor": "sosnet", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "hardnet": { + "output": "feats-hardnet-n5000-r1600", + "model": { + "name": "dog", + "descriptor": "hardnet", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "disk": { + "output": "feats-disk-n5000-r1600", + "model": { + "name": "disk", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "xfeat": { + "output": "feats-xfeat-n5000-r1600", + "model": { + "name": "xfeat", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "alike": { + "output": "feats-alike-n5000-r1600", + "model": { + "name": "alike", + "max_keypoints": 5000, + "use_relu": True, + "multiscale": False, + "detection_threshold": 0.5, + "top_k": -1, + "sub_pixel": False, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "lanet": { + "output": "feats-lanet-n5000-r1600", + "model": { + "name": "lanet", + "keypoint_threshold": 0.1, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "darkfeat": { + "output": "feats-darkfeat-n5000-r1600", + "model": { + "name": "darkfeat", + "max_keypoints": 5000, + "reliability_threshold": 0.7, + "repetability_threshold": 0.7, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "dedode": { + "output": "feats-dedode-n5000-r1600", + "model": { + "name": "dedode", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1600, + "width": 768, + "height": 768, + "dfactor": 8, + }, + }, + "example": { + "output": "feats-example-n2000-r1024", + "model": { + "name": "example", + "keypoint_threshold": 0.1, + "max_keypoints": 2000, + "model_name": "model.pth", + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 768, + "height": 768, + "dfactor": 8, + }, + }, + "sfd2": { + "output": "feats-sfd2-n4096-r1600", + "model": { + "name": "sfd2", + "max_keypoints": 4096, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "conf_th": 0.001, + "multiscale": False, + "scales": [1.0], + }, + }, + # Global descriptors + "dir": { + "output": "global-feats-dir", + "model": {"name": "dir"}, + "preprocessing": {"resize_max": 1024}, + }, + "netvlad": { + "output": "global-feats-netvlad", + "model": {"name": "netvlad"}, + "preprocessing": {"resize_max": 1024}, + }, + "openibl": { + "output": "global-feats-openibl", + "model": {"name": "openibl"}, + "preprocessing": {"resize_max": 1024}, + }, + "cosplace": { + "output": "global-feats-cosplace", + "model": {"name": "cosplace"}, + "preprocessing": {"resize_max": 1024}, + }, + "eigenplaces": { + "output": "global-feats-eigenplaces", + "model": {"name": "eigenplaces"}, + "preprocessing": {"resize_max": 1024}, + }, +} + + +def resize_image(image, size, interp): + if interp.startswith("cv2_"): + interp = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper()) + h, w = image.shape[:2] + if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]): + interp = cv2.INTER_LINEAR + resized = cv2.resize(image, size, interpolation=interp) + elif interp.startswith("pil_"): + interp = getattr(PIL.Image, interp[len("pil_") :].upper()) + resized = PIL.Image.fromarray(image.astype(np.uint8)) + resized = resized.resize(size, resample=interp) + resized = np.asarray(resized, dtype=image.dtype) + else: + raise ValueError(f"Unknown interpolation {interp}.") + return resized + + +class ImageDataset(torch.utils.data.Dataset): + default_conf = { + "globs": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"], + "grayscale": False, + "resize_max": None, + "force_resize": False, + "interpolation": "cv2_area", # pil_linear is more accurate but slower + } + + def __init__(self, root, conf, paths=None): + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.root = root + + if paths is None: + paths = [] + for g in conf.globs: + paths += list(Path(root).glob("**/" + g)) + if len(paths) == 0: + raise ValueError(f"Could not find any image in root: {root}.") + paths = sorted(list(set(paths))) + self.names = [i.relative_to(root).as_posix() for i in paths] + logger.info(f"Found {len(self.names)} images in root {root}.") + else: + if isinstance(paths, (Path, str)): + self.names = parse_image_lists(paths) + elif isinstance(paths, collections.Iterable): + self.names = [ + p.as_posix() if isinstance(p, Path) else p for p in paths + ] + else: + raise ValueError(f"Unknown format for path argument {paths}.") + + for name in self.names: + if not (root / name).exists(): + raise ValueError( + f"Image {name} does not exists in root: {root}." + ) + + def __getitem__(self, idx): + name = self.names[idx] + image = read_image(self.root / name, self.conf.grayscale) + image = image.astype(np.float32) + size = image.shape[:2][::-1] + + if self.conf.resize_max and ( + self.conf.force_resize or max(size) > self.conf.resize_max + ): + scale = self.conf.resize_max / max(size) + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, self.conf.interpolation) + + if self.conf.grayscale: + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = image / 255.0 + + data = { + "image": image, + "original_size": np.array(size), + } + return data + + def __len__(self): + return len(self.names) + + +def extract(model, image_0, conf): + default_conf = { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "cache_images": False, + "force_resize": False, + "width": 320, + "height": 240, + "interpolation": "cv2_area", + } + conf = SimpleNamespace(**{**default_conf, **conf}) + device = "cuda" if torch.cuda.is_available() else "cpu" + + def preprocess(image: np.ndarray, conf: SimpleNamespace): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + if conf.resize_max: + scale = conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + if conf.force_resize: + image = resize_image(image, (conf.width, conf.height), "cv2_area") + size_new = (conf.width, conf.height) + scale = np.array(size) / np.array(size_new) + if conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // conf.dfactor * conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new, antialias=True) + input_ = image.to(device, non_blocking=True)[None] + data = { + "image": input_, + "image_orig": image_0, + "original_size": np.array(size), + "size": np.array(image.shape[1:][::-1]), + } + return data + + # convert to grayscale if needed + if len(image_0.shape) == 3 and conf.grayscale: + image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY) + else: + image0 = image_0 + # comment following lines, image is always RGB mode + # if not conf.grayscale and len(image_0.shape) == 3: + # image0 = image_0[:, :, ::-1] # BGR to RGB + data = preprocess(image0, conf) + pred = model({"image": data["image"]}) + pred["image_size"] = data["original_size"] + pred = {**pred, **data} + return pred + + +@torch.no_grad() +def main( + conf: Dict, + image_dir: Path, + export_dir: Optional[Path] = None, + as_half: bool = True, + image_list: Optional[Union[Path, List[str]]] = None, + feature_path: Optional[Path] = None, + overwrite: bool = False, +) -> Path: + logger.info( + "Extracting local features with configuration:" + f"\n{pprint.pformat(conf)}" + ) + + dataset = ImageDataset(image_dir, conf["preprocessing"], image_list) + if feature_path is None: + feature_path = Path(export_dir, conf["output"] + ".h5") + feature_path.parent.mkdir(exist_ok=True, parents=True) + skip_names = set( + list_h5_names(feature_path) + if feature_path.exists() and not overwrite + else () + ) + dataset.names = [n for n in dataset.names if n not in skip_names] + if len(dataset.names) == 0: + logger.info("Skipping the extraction.") + return feature_path + + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(extractors, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + loader = torch.utils.data.DataLoader( + dataset, num_workers=1, shuffle=False, pin_memory=True + ) + for idx, data in enumerate(tqdm(loader)): + name = dataset.names[idx] + pred = model({"image": data["image"].to(device, non_blocking=True)}) + pred = {k: v[0].cpu().numpy() for k, v in pred.items()} + + pred["image_size"] = original_size = data["original_size"][0].numpy() + if "keypoints" in pred: + size = np.array(data["image"].shape[-2:][::-1]) + scales = (original_size / size).astype(np.float32) + pred["keypoints"] = (pred["keypoints"] + 0.5) * scales[None] - 0.5 + if "scales" in pred: + pred["scales"] *= scales.mean() + # add keypoint uncertainties scaled to the original resolution + uncertainty = getattr(model, "detection_noise", 1) * scales.mean() + + if as_half: + for k in pred: + dt = pred[k].dtype + if (dt == np.float32) and (dt != np.float16): + pred[k] = pred[k].astype(np.float16) + + with h5py.File(str(feature_path), "a", libver="latest") as fd: + try: + if name in fd: + del fd[name] + grp = fd.create_group(name) + for k, v in pred.items(): + grp.create_dataset(k, data=v) + if "keypoints" in pred: + grp["keypoints"].attrs["uncertainty"] = uncertainty + except OSError as error: + if "No space left on device" in error.args[0]: + logger.error( + "Out of disk space: storing features on disk can take " + "significant space, did you enable the as_half flag?" + ) + del grp, fd[name] + raise error + + del pred + + logger.info("Finished exporting features.") + return feature_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image_dir", type=Path, required=True) + parser.add_argument("--export_dir", type=Path, required=True) + parser.add_argument( + "--conf", + type=str, + default="superpoint_aachen", + choices=list(confs.keys()), + ) + parser.add_argument("--as_half", action="store_true") + parser.add_argument("--image_list", type=Path) + parser.add_argument("--feature_path", type=Path) + args = parser.parse_args() + main(confs[args.conf], args.image_dir, args.export_dir, args.as_half) diff --git a/hloc/extractors/__init__.py b/hloc/extractors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/extractors/alike.py b/hloc/extractors/alike.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6ae550c443551f44baba44cc4612e9a1f048cc --- /dev/null +++ b/hloc/extractors/alike.py @@ -0,0 +1,55 @@ +import sys +from pathlib import Path + +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + +alike_path = Path(__file__).parent / "../../third_party/ALIKE" +sys.path.append(str(alike_path)) +from alike import ALike as Alike_ +from alike import configs + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Alike(BaseModel): + default_conf = { + "model_name": "alike-t", # 'alike-t', 'alike-s', 'alike-n', 'alike-l' + "use_relu": True, + "multiscale": False, + "max_keypoints": 1000, + "detection_threshold": 0.5, + "top_k": -1, + "sub_pixel": False, + } + + required_inputs = ["image"] + + def _init(self, conf): + self.net = Alike_( + **configs[conf["model_name"]], + device=device, + top_k=conf["top_k"], + scores_th=conf["detection_threshold"], + n_limit=conf["max_keypoints"], + ) + logger.info("Load Alike model done.") + + def _forward(self, data): + image = data["image"] + image = image.permute(0, 2, 3, 1).squeeze() + image = image.cpu().numpy() * 255.0 + pred = self.net(image, sub_pixel=self.conf["sub_pixel"]) + + keypoints = pred["keypoints"] + descriptors = pred["descriptors"] + scores = pred["scores"] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/hloc/extractors/cosplace.py b/hloc/extractors/cosplace.py new file mode 100644 index 0000000000000000000000000000000000000000..8d13a84d57d80bee090709623cce74453784844b --- /dev/null +++ b/hloc/extractors/cosplace.py @@ -0,0 +1,44 @@ +""" +Code for loading models trained with CosPlace as a global features extractor +for geolocalization through image retrieval. +Multiple models are available with different backbones. Below is a summary of +models available (backbone : list of available output descriptors +dimensionality). For example you can use a model based on a ResNet50 with +descriptors dimensionality 1024. + ResNet18: [32, 64, 128, 256, 512] + ResNet50: [32, 64, 128, 256, 512, 1024, 2048] + ResNet101: [32, 64, 128, 256, 512, 1024, 2048] + ResNet152: [32, 64, 128, 256, 512, 1024, 2048] + VGG16: [ 64, 128, 256, 512] + +CosPlace paper: https://arxiv.org/abs/2204.02287 +""" + +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + + +class CosPlace(BaseModel): + default_conf = {"backbone": "ResNet50", "fc_output_dim": 2048} + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "gmberton/CosPlace", + "get_trained_model", + backbone=conf["backbone"], + fc_output_dim=conf["fc_output_dim"], + ).eval() + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + self.norm_rgb = tvf.Normalize(mean=mean, std=std) + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + desc = self.net(image) + return { + "global_descriptor": desc, + } diff --git a/hloc/extractors/d2net.py b/hloc/extractors/d2net.py new file mode 100644 index 0000000000000000000000000000000000000000..3f92437714dcf63b1f81fa28ee86e5d3d1a9cddb --- /dev/null +++ b/hloc/extractors/d2net.py @@ -0,0 +1,68 @@ +import subprocess +import sys +from pathlib import Path + +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + +d2net_path = Path(__file__).parent / "../../third_party/d2net" +sys.path.append(str(d2net_path)) +from lib.model_test import D2Net as _D2Net +from lib.pyramid import process_multiscale + + +class D2Net(BaseModel): + default_conf = { + "model_name": "d2_tf.pth", + "checkpoint_dir": d2net_path / "models", + "use_relu": True, + "multiscale": False, + "max_keypoints": 1024, + } + required_inputs = ["image"] + + def _init(self, conf): + model_file = conf["checkpoint_dir"] / conf["model_name"] + if not model_file.exists(): + model_file.parent.mkdir(exist_ok=True) + cmd = [ + "wget", + "--quiet", + "https://dusmanu.com/files/d2-net/" + conf["model_name"], + "-O", + str(model_file), + ] + subprocess.run(cmd, check=True) + + self.net = _D2Net( + model_file=model_file, use_relu=conf["use_relu"], use_cuda=False + ) + logger.info("Load D2Net model done.") + + def _forward(self, data): + image = data["image"] + image = image.flip(1) # RGB -> BGR + norm = image.new_tensor([103.939, 116.779, 123.68]) + image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization + + if self.conf["multiscale"]: + keypoints, scores, descriptors = process_multiscale(image, self.net) + else: + keypoints, scores, descriptors = process_multiscale( + image, self.net, scales=[1] + ) + keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale + + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[idxs] + scores = scores[idxs] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/hloc/extractors/darkfeat.py b/hloc/extractors/darkfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..38a1e5bc99f8db490d4cdb3fc47be331151d819b --- /dev/null +++ b/hloc/extractors/darkfeat.py @@ -0,0 +1,64 @@ +import subprocess +import sys +from pathlib import Path + +from hloc import logger + +from ..utils.base_model import BaseModel + +darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat" +sys.path.append(str(darkfeat_path)) +from darkfeat import DarkFeat as DarkFeat_ + + +class DarkFeat(BaseModel): + default_conf = { + "model_name": "DarkFeat.pth", + "max_keypoints": 1000, + "detection_threshold": 0.5, + "sub_pixel": False, + } + weight_urls = { + "DarkFeat.pth": "https://drive.google.com/uc?id=1Thl6m8NcmQ7zSAF-1_xaFs3F4H8UU6HX&confirm=t", + } + proxy = "http://localhost:1080" + required_inputs = ["image"] + + def _init(self, conf): + model_path = darkfeat_path / "checkpoints" / conf["model_name"] + link = self.weight_urls[conf["model_name"]] + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + cmd_wo_proxy = ["gdown", link, "-O", str(model_path)] + cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy] + logger.info( + f"Downloading the DarkFeat model with `{cmd_wo_proxy}`." + ) + try: + subprocess.run(cmd_wo_proxy, check=True) + except subprocess.CalledProcessError as e: + logger.info(f"Downloading the model failed `{e}`.") + logger.info(f"Downloading the DarkFeat model with `{cmd}`.") + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + logger.error("Failed to download the DarkFeat model.") + raise e + + self.net = DarkFeat_(model_path) + logger.info("Load DarkFeat model done.") + + def _forward(self, data): + pred = self.net({"image": data["image"]}) + keypoints = pred["keypoints"] + descriptors = pred["descriptors"] + scores = pred["scores"] + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[:, idxs] + scores = scores[idxs] + return { + "keypoints": keypoints[None], # 1 x N x 2 + "scores": scores[None], # 1 x N + "descriptors": descriptors[None], # 1 x 128 x N + } diff --git a/hloc/extractors/dedode.py b/hloc/extractors/dedode.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d7130a1d6c0db65fbb2e1e40fddb90bc2e3096 --- /dev/null +++ b/hloc/extractors/dedode.py @@ -0,0 +1,111 @@ +import subprocess +import sys +from pathlib import Path + +import torch +import torchvision.transforms as transforms + +from hloc import logger + +from ..utils.base_model import BaseModel + +dedode_path = Path(__file__).parent / "../../third_party/DeDoDe" +sys.path.append(str(dedode_path)) + +from DeDoDe import dedode_descriptor_B, dedode_detector_L +from DeDoDe.utils import to_pixel_coords + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class DeDoDe(BaseModel): + default_conf = { + "name": "dedode", + "model_detector_name": "dedode_detector_L.pth", + "model_descriptor_name": "dedode_descriptor_B.pth", + "max_keypoints": 2000, + "match_threshold": 0.2, + "dense": False, # Now fixed to be false + } + required_inputs = [ + "image", + ] + weight_urls = { + "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth", + "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth", + } + + # Initialize the line matcher + def _init(self, conf): + model_detector_path = ( + dedode_path / "pretrained" / conf["model_detector_name"] + ) + model_descriptor_path = ( + dedode_path / "pretrained" / conf["model_descriptor_name"] + ) + + self.normalizer = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + # Download the model. + if not model_detector_path.exists(): + model_detector_path.parent.mkdir(exist_ok=True) + link = self.weight_urls[conf["model_detector_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_detector_path)] + logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.") + subprocess.run(cmd, check=True) + + if not model_descriptor_path.exists(): + model_descriptor_path.parent.mkdir(exist_ok=True) + link = self.weight_urls[conf["model_descriptor_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_descriptor_path)] + logger.info( + f"Downloading the DeDoDe descriptor model with `{cmd}`." + ) + subprocess.run(cmd, check=True) + + # load the model + weights_detector = torch.load(model_detector_path, map_location="cpu") + weights_descriptor = torch.load( + model_descriptor_path, map_location="cpu" + ) + self.detector = dedode_detector_L( + weights=weights_detector, device=device + ) + self.descriptor = dedode_descriptor_B( + weights=weights_descriptor, device=device + ) + logger.info("Load DeDoDe model done.") + + def _forward(self, data): + """ + data: dict, keys: {'image0','image1'} + image shape: N x C x H x W + color mode: RGB + """ + img0 = self.normalizer(data["image"].squeeze()).float()[None] + H_A, W_A = img0.shape[2:] + + # step 1: detect keypoints + detections_A = None + batch_A = {"image": img0} + if self.conf["dense"]: + detections_A = self.detector.detect_dense(batch_A) + else: + detections_A = self.detector.detect( + batch_A, num_keypoints=self.conf["max_keypoints"] + ) + keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"] + + # step 2: describe keypoints + # dim: 1 x N x 256 + description_A = self.descriptor.describe_keypoints( + batch_A, keypoints_A + )["descriptions"] + keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A) + + return { + "keypoints": keypoints_A, # 1 x N x 2 + "descriptors": description_A.permute(0, 2, 1), # 1 x 256 x N + "scores": P_A, # 1 x N + } diff --git a/hloc/extractors/dir.py b/hloc/extractors/dir.py new file mode 100644 index 0000000000000000000000000000000000000000..2d47256b1a4f2d74a99fc0320293ba1b3bf88bb4 --- /dev/null +++ b/hloc/extractors/dir.py @@ -0,0 +1,81 @@ +import os +import sys +from pathlib import Path +from zipfile import ZipFile + +import gdown +import sklearn +import torch + +from ..utils.base_model import BaseModel + +sys.path.append( + str(Path(__file__).parent / "../../third_party/deep-image-retrieval") +) +os.environ["DB_ROOT"] = "" # required by dirtorch + +from dirtorch.extract_features import load_model # noqa: E402 +from dirtorch.utils import common # noqa: E402 + +# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca, +# which has been deprecated in sklearn v0.24 +# and must be explicitly imported with `from sklearn.decomposition import PCA`. +# This is a hacky workaround to maintain forward compatibility. +sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca + + +class DIR(BaseModel): + default_conf = { + "model_name": "Resnet-101-AP-GeM", + "whiten_name": "Landmarks_clean", + "whiten_params": { + "whitenp": 0.25, + "whitenv": None, + "whitenm": 1.0, + }, + "pooling": "gem", + "gemp": 3, + } + required_inputs = ["image"] + + dir_models = { + "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy", + } + + def _init(self, conf): + checkpoint = Path( + torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt" + ) + if not checkpoint.exists(): + checkpoint.parent.mkdir(exist_ok=True, parents=True) + link = self.dir_models[conf["model_name"]] + gdown.download(str(link), str(checkpoint) + ".zip", quiet=False) + zf = ZipFile(str(checkpoint) + ".zip", "r") + zf.extractall(checkpoint.parent) + zf.close() + os.remove(str(checkpoint) + ".zip") + + self.net = load_model(checkpoint, False) # first load on CPU + if conf["whiten_name"]: + assert conf["whiten_name"] in self.net.pca + + def _forward(self, data): + image = data["image"] + assert image.shape[1] == 3 + mean = self.net.preprocess["mean"] + std = self.net.preprocess["std"] + image = image - image.new_tensor(mean)[:, None, None] + image = image / image.new_tensor(std)[:, None, None] + + desc = self.net(image) + desc = desc.unsqueeze(0) # batch dimension + if self.conf["whiten_name"]: + pca = self.net.pca[self.conf["whiten_name"]] + desc = common.whiten_features( + desc.cpu().numpy(), pca, **self.conf["whiten_params"] + ) + desc = torch.from_numpy(desc) + + return { + "global_descriptor": desc, + } diff --git a/hloc/extractors/disk.py b/hloc/extractors/disk.py new file mode 100644 index 0000000000000000000000000000000000000000..762061016eaa262f4f7468ad9b8ba3889410b142 --- /dev/null +++ b/hloc/extractors/disk.py @@ -0,0 +1,35 @@ +import kornia + +from hloc import logger + +from ..utils.base_model import BaseModel + + +class DISK(BaseModel): + default_conf = { + "weights": "depth", + "max_keypoints": None, + "nms_window_size": 5, + "detection_threshold": 0.0, + "pad_if_not_divisible": True, + } + required_inputs = ["image"] + + def _init(self, conf): + self.model = kornia.feature.DISK.from_pretrained(conf["weights"]) + logger.info("Load DISK model done.") + + def _forward(self, data): + image = data["image"] + features = self.model( + image, + n=self.conf["max_keypoints"], + window_size=self.conf["nms_window_size"], + score_threshold=self.conf["detection_threshold"], + pad_if_not_divisible=self.conf["pad_if_not_divisible"], + ) + return { + "keypoints": [f.keypoints for f in features][0][None], + "scores": [f.detection_scores for f in features][0][None], + "descriptors": [f.descriptors.t() for f in features][0][None], + } diff --git a/hloc/extractors/dog.py b/hloc/extractors/dog.py new file mode 100644 index 0000000000000000000000000000000000000000..b280bbc42376f3af827002bb85ff4996ccdf50b4 --- /dev/null +++ b/hloc/extractors/dog.py @@ -0,0 +1,135 @@ +import kornia +import numpy as np +import pycolmap +import torch +from kornia.feature.laf import ( + extract_patches_from_pyramid, + laf_from_center_scale_ori, +) + +from ..utils.base_model import BaseModel + +EPS = 1e-6 + + +def sift_to_rootsift(x): + x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS) + x = np.sqrt(x.clip(min=EPS)) + x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS) + return x + + +class DoG(BaseModel): + default_conf = { + "options": { + "first_octave": 0, + "peak_threshold": 0.01, + }, + "descriptor": "rootsift", + "max_keypoints": -1, + "patch_size": 32, + "mr_size": 12, + } + required_inputs = ["image"] + detection_noise = 1.0 + max_batch_size = 1024 + + def _init(self, conf): + if conf["descriptor"] == "sosnet": + self.describe = kornia.feature.SOSNet(pretrained=True) + elif conf["descriptor"] == "hardnet": + self.describe = kornia.feature.HardNet(pretrained=True) + elif conf["descriptor"] not in ["sift", "rootsift"]: + raise ValueError(f'Unknown descriptor: {conf["descriptor"]}') + + self.sift = None # lazily instantiated on the first image + self.dummy_param = torch.nn.Parameter(torch.empty(0)) + self.device = torch.device("cpu") + + def to(self, *args, **kwargs): + device = kwargs.get("device") + if device is None: + match = [a for a in args if isinstance(a, (torch.device, str))] + if len(match) > 0: + device = match[0] + if device is not None: + self.device = torch.device(device) + return super().to(*args, **kwargs) + + def _forward(self, data): + image = data["image"] + image_np = image.cpu().numpy()[0, 0] + assert image.shape[1] == 1 + assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS + + if self.sift is None: + device = self.dummy_param.device + use_gpu = pycolmap.has_cuda and device.type == "cuda" + options = {**self.conf["options"]} + if self.conf["descriptor"] == "rootsift": + options["normalization"] = pycolmap.Normalization.L1_ROOT + else: + options["normalization"] = pycolmap.Normalization.L2 + self.sift = pycolmap.Sift( + options=pycolmap.SiftExtractionOptions(options), + device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"), + ) + keypoints, descriptors = self.sift.extract(image_np) + scales = keypoints[:, 2] + oris = np.rad2deg(keypoints[:, 3]) + + if self.conf["descriptor"] in ["sift", "rootsift"]: + # We still renormalize because COLMAP does not normalize well, + # maybe due to numerical errors + if self.conf["descriptor"] == "rootsift": + descriptors = sift_to_rootsift(descriptors) + descriptors = torch.from_numpy(descriptors) + elif self.conf["descriptor"] in ("sosnet", "hardnet"): + center = keypoints[:, :2] + 0.5 + laf_scale = scales * self.conf["mr_size"] / 2 + laf_ori = -oris + lafs = laf_from_center_scale_ori( + torch.from_numpy(center)[None], + torch.from_numpy(laf_scale)[None, :, None, None], + torch.from_numpy(laf_ori)[None, :, None], + ).to(image.device) + patches = extract_patches_from_pyramid( + image, lafs, PS=self.conf["patch_size"] + )[0] + descriptors = patches.new_zeros((len(patches), 128)) + if len(patches) > 0: + for start_idx in range(0, len(patches), self.max_batch_size): + end_idx = min(len(patches), start_idx + self.max_batch_size) + descriptors[start_idx:end_idx] = self.describe( + patches[start_idx:end_idx] + ) + else: + raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}') + + keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y + scales = torch.from_numpy(scales) + oris = torch.from_numpy(oris) + scores = keypoints.new_zeros(len(keypoints)) # no scores for SIFT yet + + if self.conf["max_keypoints"] != -1: + # TODO: check that the scores from PyCOLMAP are 100% correct, + # follow https://github.com/mihaidusmanu/pycolmap/issues/8 + max_number = ( + scores.shape[0] + if scores.shape[0] < self.conf["max_keypoints"] + else self.conf["max_keypoints"] + ) + values, indices = torch.topk(scores, max_number) + keypoints = keypoints[indices] + scales = scales[indices] + oris = oris[indices] + scores = scores[indices] + descriptors = descriptors[indices] + + return { + "keypoints": keypoints[None], + "scales": scales[None], + "oris": oris[None], + "scores": scores[None], + "descriptors": descriptors.T[None], + } diff --git a/hloc/extractors/eigenplaces.py b/hloc/extractors/eigenplaces.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9953b27c00682c830842736fd0bdab93857f14 --- /dev/null +++ b/hloc/extractors/eigenplaces.py @@ -0,0 +1,57 @@ +""" +Code for loading models trained with EigenPlaces (or CosPlace) as a global +features extractor for geolocalization through image retrieval. +Multiple models are available with different backbones. Below is a summary of +models available (backbone : list of available output descriptors +dimensionality). For example you can use a model based on a ResNet50 with +descriptors dimensionality 1024. + +EigenPlaces trained models: + ResNet18: [ 256, 512] + ResNet50: [128, 256, 512, 2048] + ResNet101: [128, 256, 512, 2048] + VGG16: [ 512] + +CosPlace trained models: + ResNet18: [32, 64, 128, 256, 512] + ResNet50: [32, 64, 128, 256, 512, 1024, 2048] + ResNet101: [32, 64, 128, 256, 512, 1024, 2048] + ResNet152: [32, 64, 128, 256, 512, 1024, 2048] + VGG16: [ 64, 128, 256, 512] + +EigenPlaces paper (ICCV 2023): https://arxiv.org/abs/2308.10832 +CosPlace paper (CVPR 2022): https://arxiv.org/abs/2204.02287 +""" + +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + + +class EigenPlaces(BaseModel): + default_conf = { + "variant": "EigenPlaces", + "backbone": "ResNet101", + "fc_output_dim": 2048, + } + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "gmberton/" + conf["variant"], + "get_trained_model", + backbone=conf["backbone"], + fc_output_dim=conf["fc_output_dim"], + ).eval() + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + self.norm_rgb = tvf.Normalize(mean=mean, std=std) + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + desc = self.net(image) + return { + "global_descriptor": desc, + } diff --git a/hloc/extractors/example.py b/hloc/extractors/example.py new file mode 100644 index 0000000000000000000000000000000000000000..3d952c4014e006d74409a8f32ee7159d58305de5 --- /dev/null +++ b/hloc/extractors/example.py @@ -0,0 +1,56 @@ +import sys +from pathlib import Path + +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +example_path = Path(__file__).parent / "../../third_party/example" +sys.path.append(str(example_path)) + +# import some modules here + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Example(BaseModel): + # change to your default configs + default_conf = { + "name": "example", + "keypoint_threshold": 0.1, + "max_keypoints": 2000, + "model_name": "model.pth", + } + required_inputs = ["image"] + + def _init(self, conf): + # set checkpoints paths if needed + model_path = example_path / "checkpoints" / f'{conf["model_name"]}' + if not model_path.exists(): + logger.info(f"No model found at {model_path}") + + # init model + self.net = callable + # self.net = ExampleNet(is_test=True) + state_dict = torch.load(model_path, map_location="cpu") + self.net.load_state_dict(state_dict["model_state"]) + logger.info("Load example model done.") + + def _forward(self, data): + # data: dict, keys: 'image' + # image color mode: RGB + # image value range in [0, 1] + image = data["image"] + + # B: batch size, N: number of keypoints + # keypoints shape: B x N x 2, type: torch tensor + # scores shape: B x N, type: torch tensor + # descriptors shape: B x 128 x N, type: torch tensor + keypoints, scores, descriptors = self.net(image) + + return { + "keypoints": keypoints, + "scores": scores, + "descriptors": descriptors, + } diff --git a/hloc/extractors/fire.py b/hloc/extractors/fire.py new file mode 100644 index 0000000000000000000000000000000000000000..980f18e63d1a395835891c8e6595cfc66c21db2d --- /dev/null +++ b/hloc/extractors/fire.py @@ -0,0 +1,72 @@ +import logging +import subprocess +import sys +from pathlib import Path + +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + +logger = logging.getLogger(__name__) +fire_path = Path(__file__).parent / "../../third_party/fire" +sys.path.append(str(fire_path)) + + +import fire_network + + +class FIRe(BaseModel): + default_conf = { + "global": True, + "asmk": False, + "model_name": "fire_SfM_120k.pth", + "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params + "features_num": 1000, # TODO:not supported now + "asmk_name": "asmk_codebook.bin", # TODO:not supported now + "config_name": "eval_fire.yml", + } + required_inputs = ["image"] + + # Models exported using + fire_models = { + "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth", + "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth", + } + + def _init(self, conf): + assert conf["model_name"] in self.fire_models.keys() + # Config paths + model_path = fire_path / "model" / conf["model_name"] + + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + link = self.fire_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_path)] + logger.info(f"Downloading the FIRe model with `{cmd}`.") + subprocess.run(cmd, check=True) + + logger.info("Loading fire model...") + + # Load net + state = torch.load(model_path) + state["net_params"]["pretrained"] = None + net = fire_network.init_network(**state["net_params"]) + net.load_state_dict(state["state_dict"]) + self.net = net + + self.norm_rgb = tvf.Normalize( + **dict(zip(["mean", "std"], net.runtime["mean_std"])) + ) + + # params + self.scales = conf["scales"] + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + + # Feature extraction. + desc = self.net.forward_global(image, scales=self.scales) + + return {"global_descriptor": desc} diff --git a/hloc/extractors/fire_local.py b/hloc/extractors/fire_local.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e9ba9f4c3d86280e8232f61263b729ccb933be --- /dev/null +++ b/hloc/extractors/fire_local.py @@ -0,0 +1,84 @@ +import subprocess +import sys +from pathlib import Path + +import torch +import torchvision.transforms as tvf + +from .. import logger +from ..utils.base_model import BaseModel + +fire_path = Path(__file__).parent / "../../third_party/fire" + +sys.path.append(str(fire_path)) + + +import fire_network + +EPS = 1e-6 + + +class FIRe(BaseModel): + default_conf = { + "global": True, + "asmk": False, + "model_name": "fire_SfM_120k.pth", + "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params + "features_num": 1000, + "asmk_name": "asmk_codebook.bin", + "config_name": "eval_fire.yml", + } + required_inputs = ["image"] + + # Models exported using + fire_models = { + "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth", + "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth", + } + + def _init(self, conf): + assert conf["model_name"] in self.fire_models.keys() + + # Config paths + model_path = fire_path / "model" / conf["model_name"] + config_path = fire_path / conf["config_name"] # noqa: F841 + asmk_bin_path = fire_path / "model" / conf["asmk_name"] # noqa: F841 + + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + link = self.fire_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_path)] + logger.info(f"Downloading the FIRe model with `{cmd}`.") + subprocess.run(cmd, check=True) + + logger.info("Loading fire model...") + + # Load net + state = torch.load(model_path) + state["net_params"]["pretrained"] = None + net = fire_network.init_network(**state["net_params"]) + net.load_state_dict(state["state_dict"]) + self.net = net + + self.norm_rgb = tvf.Normalize( + **dict(zip(["mean", "std"], net.runtime["mean_std"])) + ) + + # params + self.scales = conf["scales"] + self.features_num = conf["features_num"] + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + + local_desc = self.net.forward_local( + image, features_num=self.features_num, scales=self.scales + ) + + logger.info(f"output[0].shape = {local_desc[0].shape}\n") + + return { + # 'global_descriptor': desc + "local_descriptor": local_desc + } diff --git a/hloc/extractors/lanet.py b/hloc/extractors/lanet.py new file mode 100644 index 0000000000000000000000000000000000000000..59ec07bceb60540d8c616e3e76e96aae3bc24595 --- /dev/null +++ b/hloc/extractors/lanet.py @@ -0,0 +1,66 @@ +import sys +from pathlib import Path + +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + +lib_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(lib_path)) +from lanet.network_v0.model import PointModel + +lanet_path = Path(__file__).parent / "../../third_party/lanet" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class LANet(BaseModel): + default_conf = { + "model_name": "v0", + "keypoint_threshold": 0.1, + "max_keypoints": 1024, + } + required_inputs = ["image"] + + def _init(self, conf): + model_path = ( + lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth' + ) + if not model_path.exists(): + logger.warning(f"No model found at {model_path}, start downloading") + self.net = PointModel(is_test=True) + state_dict = torch.load(model_path, map_location="cpu") + self.net.load_state_dict(state_dict["model_state"]) + logger.info("Load LANet model done.") + + def _forward(self, data): + image = data["image"] + keypoints, scores, descriptors = self.net(image) + _, _, Hc, Wc = descriptors.shape + + # Scores & Descriptors + kpts_score = torch.cat([keypoints, scores], dim=1).view(3, -1).t() + descriptors = descriptors.view(256, Hc, Wc).view(256, -1).t() + + # Filter based on confidence threshold + descriptors = descriptors[ + kpts_score[:, 0] > self.conf["keypoint_threshold"], : + ] + kpts_score = kpts_score[ + kpts_score[:, 0] > self.conf["keypoint_threshold"], : + ] + keypoints = kpts_score[:, 1:] + scores = kpts_score[:, 0] + + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[idxs] + scores = scores[idxs] + + return { + "keypoints": keypoints[None], + "scores": scores[None], + "descriptors": descriptors.T[None], + } diff --git a/hloc/extractors/netvlad.py b/hloc/extractors/netvlad.py new file mode 100644 index 0000000000000000000000000000000000000000..c7938820d0ea0c84b738ef5564aa1dbad5532236 --- /dev/null +++ b/hloc/extractors/netvlad.py @@ -0,0 +1,152 @@ +import subprocess +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +from scipy.io import loadmat + +from .. import logger +from ..utils.base_model import BaseModel + +EPS = 1e-6 + + +class NetVLADLayer(nn.Module): + def __init__(self, input_dim=512, K=64, score_bias=False, intranorm=True): + super().__init__() + self.score_proj = nn.Conv1d( + input_dim, K, kernel_size=1, bias=score_bias + ) + centers = nn.parameter.Parameter(torch.empty([input_dim, K])) + nn.init.xavier_uniform_(centers) + self.register_parameter("centers", centers) + self.intranorm = intranorm + self.output_dim = input_dim * K + + def forward(self, x): + b = x.size(0) + scores = self.score_proj(x) + scores = F.softmax(scores, dim=1) + diff = x.unsqueeze(2) - self.centers.unsqueeze(0).unsqueeze(-1) + desc = (scores.unsqueeze(1) * diff).sum(dim=-1) + if self.intranorm: + # From the official MATLAB implementation. + desc = F.normalize(desc, dim=1) + desc = desc.view(b, -1) + desc = F.normalize(desc, dim=1) + return desc + + +class NetVLAD(BaseModel): + default_conf = {"model_name": "VGG16-NetVLAD-Pitts30K", "whiten": True} + required_inputs = ["image"] + + # Models exported using + # https://github.com/uzh-rpg/netvlad_tf_open/blob/master/matlab/net_class2struct.m. + dir_models = { + "VGG16-NetVLAD-Pitts30K": "https://cvg-data.inf.ethz.ch/hloc/netvlad/Pitts30K_struct.mat", + "VGG16-NetVLAD-TokyoTM": "https://cvg-data.inf.ethz.ch/hloc/netvlad/TokyoTM_struct.mat", + } + + def _init(self, conf): + assert conf["model_name"] in self.dir_models.keys() + + # Download the checkpoint. + checkpoint = Path( + torch.hub.get_dir(), "netvlad", conf["model_name"] + ".mat" + ) + if not checkpoint.exists(): + checkpoint.parent.mkdir(exist_ok=True, parents=True) + link = self.dir_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(checkpoint)] + logger.info(f"Downloading the NetVLAD model with `{cmd}`.") + subprocess.run(cmd, check=True) + + # Create the network. + # Remove classification head. + backbone = list(models.vgg16().children())[0] + # Remove last ReLU + MaxPool2d. + self.backbone = nn.Sequential(*list(backbone.children())[:-2]) + + self.netvlad = NetVLADLayer() + + if conf["whiten"]: + self.whiten = nn.Linear(self.netvlad.output_dim, 4096) + + # Parse MATLAB weights using https://github.com/uzh-rpg/netvlad_tf_open + mat = loadmat(checkpoint, struct_as_record=False, squeeze_me=True) + + # CNN weights. + for layer, mat_layer in zip( + self.backbone.children(), mat["net"].layers + ): + if isinstance(layer, nn.Conv2d): + w = mat_layer.weights[0] # Shape: S x S x IN x OUT + b = mat_layer.weights[1] # Shape: OUT + # Prepare for PyTorch - enforce float32 and right shape. + # w should have shape: OUT x IN x S x S + # b should have shape: OUT + w = torch.tensor(w).float().permute([3, 2, 0, 1]) + b = torch.tensor(b).float() + # Update layer weights. + layer.weight = nn.Parameter(w) + layer.bias = nn.Parameter(b) + + # NetVLAD weights. + score_w = mat["net"].layers[30].weights[0] # D x K + # centers are stored as opposite in official MATLAB code + center_w = -mat["net"].layers[30].weights[1] # D x K + # Prepare for PyTorch - make sure it is float32 and has right shape. + # score_w should have shape K x D x 1 + # center_w should have shape D x K + score_w = torch.tensor(score_w).float().permute([1, 0]).unsqueeze(-1) + center_w = torch.tensor(center_w).float() + # Update layer weights. + self.netvlad.score_proj.weight = nn.Parameter(score_w) + self.netvlad.centers = nn.Parameter(center_w) + + # Whitening weights. + if conf["whiten"]: + w = mat["net"].layers[33].weights[0] # Shape: 1 x 1 x IN x OUT + b = mat["net"].layers[33].weights[1] # Shape: OUT + # Prepare for PyTorch - make sure it is float32 and has right shape + w = torch.tensor(w).float().squeeze().permute([1, 0]) # OUT x IN + b = torch.tensor(b.squeeze()).float() # Shape: OUT + # Update layer weights. + self.whiten.weight = nn.Parameter(w) + self.whiten.bias = nn.Parameter(b) + + # Preprocessing parameters. + self.preprocess = { + "mean": mat["net"].meta.normalization.averageImage[0, 0], + "std": np.array([1, 1, 1], dtype=np.float32), + } + + def _forward(self, data): + image = data["image"] + assert image.shape[1] == 3 + assert image.min() >= -EPS and image.max() <= 1 + EPS + image = torch.clamp(image * 255, 0.0, 255.0) # Input should be 0-255. + mean = self.preprocess["mean"] + std = self.preprocess["std"] + image = image - image.new_tensor(mean).view(1, -1, 1, 1) + image = image / image.new_tensor(std).view(1, -1, 1, 1) + + # Feature extraction. + descriptors = self.backbone(image) + b, c, _, _ = descriptors.size() + descriptors = descriptors.view(b, c, -1) + + # NetVLAD layer. + descriptors = F.normalize(descriptors, dim=1) # Pre-normalization. + desc = self.netvlad(descriptors) + + # Whiten if needed. + if hasattr(self, "whiten"): + desc = self.whiten(desc) + desc = F.normalize(desc, dim=1) # Final L2 normalization. + + return {"global_descriptor": desc} diff --git a/hloc/extractors/openibl.py b/hloc/extractors/openibl.py new file mode 100644 index 0000000000000000000000000000000000000000..9e332a4e0016fceb184dd850bd3b6f86231dad54 --- /dev/null +++ b/hloc/extractors/openibl.py @@ -0,0 +1,26 @@ +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + + +class OpenIBL(BaseModel): + default_conf = { + "model_name": "vgg16_netvlad", + } + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "yxgeee/OpenIBL", conf["model_name"], pretrained=True + ).eval() + mean = [0.48501960784313836, 0.4579568627450961, 0.4076039215686255] + std = [0.00392156862745098, 0.00392156862745098, 0.00392156862745098] + self.norm_rgb = tvf.Normalize(mean=mean, std=std) + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + desc = self.net(image) + return { + "global_descriptor": desc, + } diff --git a/hloc/extractors/r2d2.py b/hloc/extractors/r2d2.py new file mode 100644 index 0000000000000000000000000000000000000000..359d89c96a5590764bae0604989c2d738c814bd9 --- /dev/null +++ b/hloc/extractors/r2d2.py @@ -0,0 +1,65 @@ +import sys +from pathlib import Path + +import torchvision.transforms as tvf + +from hloc import logger + +from ..utils.base_model import BaseModel + +r2d2_path = Path(__file__).parent / "../../third_party/r2d2" +sys.path.append(str(r2d2_path)) +from extract import NonMaxSuppression, extract_multiscale, load_network + + +class R2D2(BaseModel): + default_conf = { + "model_name": "r2d2_WASF_N16.pt", + "max_keypoints": 5000, + "scale_factor": 2**0.25, + "min_size": 256, + "max_size": 1024, + "min_scale": 0, + "max_scale": 1, + "reliability_threshold": 0.7, + "repetability_threshold": 0.7, + } + required_inputs = ["image"] + + def _init(self, conf): + model_fn = r2d2_path / "models" / conf["model_name"] + self.norm_rgb = tvf.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + self.net = load_network(model_fn) + self.detector = NonMaxSuppression( + rel_thr=conf["reliability_threshold"], + rep_thr=conf["repetability_threshold"], + ) + logger.info("Load R2D2 model done.") + + def _forward(self, data): + img = data["image"] + img = self.norm_rgb(img) + + xys, desc, scores = extract_multiscale( + self.net, + img, + self.detector, + scale_f=self.conf["scale_factor"], + min_size=self.conf["min_size"], + max_size=self.conf["max_size"], + min_scale=self.conf["min_scale"], + max_scale=self.conf["max_scale"], + ) + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + xy = xys[idxs, :2] + desc = desc[idxs].t() + scores = scores[idxs] + + pred = { + "keypoints": xy[None], + "descriptors": desc[None], + "scores": scores[None], + } + return pred diff --git a/hloc/extractors/rekd.py b/hloc/extractors/rekd.py new file mode 100644 index 0000000000000000000000000000000000000000..c4fbb5fd583d0371c1dba900c5e3719391bed3e0 --- /dev/null +++ b/hloc/extractors/rekd.py @@ -0,0 +1,72 @@ +import sys +from pathlib import Path + +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + +rekd_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(rekd_path)) +from REKD.training.model.REKD import REKD as REKD_ + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class REKD(BaseModel): + default_conf = { + "model_name": "v0", + "keypoint_threshold": 0.1, + } + required_inputs = ["image"] + + def _init(self, conf): + model_path = ( + rekd_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth' + ) + if not model_path.exists(): + print(f"No model found at {model_path}") + self.net = REKD_(is_test=True) + state_dict = torch.load(model_path, map_location="cpu") + self.net.load_state_dict(state_dict["model_state"]) + logger.info("Load REKD model done.") + + def _forward(self, data): + image = data["image"] + keypoints, scores, descriptors = self.net(image) + _, _, Hc, Wc = descriptors.shape + + # Scores & Descriptors + kpts_score = ( + torch.cat([keypoints, scores], dim=1) + .view(3, -1) + .t() + .cpu() + .detach() + .numpy() + ) + descriptors = ( + descriptors.view(256, Hc, Wc) + .view(256, -1) + .t() + .cpu() + .detach() + .numpy() + ) + + # Filter based on confidence threshold + descriptors = descriptors[ + kpts_score[:, 0] > self.conf["keypoint_threshold"], : + ] + kpts_score = kpts_score[ + kpts_score[:, 0] > self.conf["keypoint_threshold"], : + ] + keypoints = kpts_score[:, 1:] + scores = kpts_score[:, 0] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/hloc/extractors/rord.py b/hloc/extractors/rord.py new file mode 100644 index 0000000000000000000000000000000000000000..cf1d5249ef88c7f90d0e98fa911bbba9d4067a99 --- /dev/null +++ b/hloc/extractors/rord.py @@ -0,0 +1,76 @@ +import subprocess +import sys +from pathlib import Path + +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + +rord_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(rord_path)) +from RoRD.lib.model_test import D2Net as _RoRD +from RoRD.lib.pyramid import process_multiscale + + +class RoRD(BaseModel): + default_conf = { + "model_name": "rord.pth", + "checkpoint_dir": rord_path / "RoRD" / "models", + "use_relu": True, + "multiscale": False, + "max_keypoints": 1024, + } + required_inputs = ["image"] + weight_urls = { + "rord.pth": "https://drive.google.com/uc?id=12414ZGKwgPAjNTGtNrlB4VV9l7W76B2o&confirm=t", + } + proxy = "http://localhost:1080" + + def _init(self, conf): + model_path = conf["checkpoint_dir"] / conf["model_name"] + link = self.weight_urls[conf["model_name"]] + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + cmd_wo_proxy = ["gdown", link, "-O", str(model_path)] + cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy] + logger.info(f"Downloading the RoRD model with `{cmd_wo_proxy}`.") + try: + subprocess.run(cmd_wo_proxy, check=True) + except subprocess.CalledProcessError as e: + logger.info(f"Downloading failed {e}.") + logger.info(f"Downloading the RoRD model with {cmd}.") + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + logger.error(f"Failed to download the RoRD model: {e}") + self.net = _RoRD( + model_file=model_path, use_relu=conf["use_relu"], use_cuda=False + ) + logger.info("Load RoRD model done.") + + def _forward(self, data): + image = data["image"] + image = image.flip(1) # RGB -> BGR + norm = image.new_tensor([103.939, 116.779, 123.68]) + image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization + + if self.conf["multiscale"]: + keypoints, scores, descriptors = process_multiscale(image, self.net) + else: + keypoints, scores, descriptors = process_multiscale( + image, self.net, scales=[1] + ) + keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale + + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[idxs] + scores = scores[idxs] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd6188faa8ac8bfa647e6d5bcb3a9dfc07a2f30 --- /dev/null +++ b/hloc/extractors/sfd2.py @@ -0,0 +1,41 @@ +import sys +from pathlib import Path + +import torchvision.transforms as tvf + +from .. import logger +from ..utils.base_model import BaseModel + +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.sfd2 import load_sfd2 + + +class SFD2(BaseModel): + default_conf = { + "max_keypoints": 4096, + "model_name": "sfd2_20230511_210205_resnet4x.79.pth", + "conf_th": 0.001, + } + required_inputs = ["image"] + + def _init(self, conf): + self.conf = {**self.default_conf, **conf} + self.norm_rgb = tvf.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + model_path = tp_path / "pram" / "weights" / self.conf["model_name"] + self.net = load_sfd2(weight_path=model_path).eval() + + logger.info("Load SFD2 model done.") + + def _forward(self, data): + pred = self.net.extract_local_global( + data={"image": self.norm_rgb(data["image"])}, config=self.conf + ) + out = { + "keypoints": pred["keypoints"][0][None], + "scores": pred["scores"][0][None], + "descriptors": pred["descriptors"][0][None], + } + return out diff --git a/hloc/extractors/sift.py b/hloc/extractors/sift.py new file mode 100644 index 0000000000000000000000000000000000000000..09576f98355595ea1c8e0105bac98887a320b675 --- /dev/null +++ b/hloc/extractors/sift.py @@ -0,0 +1,225 @@ +import warnings + +import cv2 +import numpy as np +import torch +from kornia.color import rgb_to_grayscale +from omegaconf import OmegaConf +from packaging import version + +try: + import pycolmap +except ImportError: + pycolmap = None +from hloc import logger + +from ..utils.base_model import BaseModel + + +def filter_dog_point( + points, scales, angles, image_shape, nms_radius, scores=None +): + h, w = image_shape + ij = np.round(points - 0.5).astype(int).T[::-1] + + # Remove duplicate points (identical coordinates). + # Pick highest scale or score + s = scales if scores is None else scores + buffer = np.zeros((h, w)) + np.maximum.at(buffer, tuple(ij), s) + keep = np.where(buffer[tuple(ij)] == s)[0] + + # Pick lowest angle (arbitrary). + ij = ij[:, keep] + buffer[:] = np.inf + o_abs = np.abs(angles[keep]) + np.minimum.at(buffer, tuple(ij), o_abs) + mask = buffer[tuple(ij)] == o_abs + ij = ij[:, mask] + keep = keep[mask] + + if nms_radius > 0: + # Apply NMS on the remaining points + buffer[:] = 0 + buffer[tuple(ij)] = s[keep] # scores or scale + + local_max = torch.nn.functional.max_pool2d( + torch.from_numpy(buffer).unsqueeze(0), + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ).squeeze(0) + is_local_max = buffer == local_max.numpy() + keep = keep[is_local_max[tuple(ij)]] + return keep + + +def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor: + x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps) + x.clip_(min=eps).sqrt_() + return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps) + + +def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray: + """ + Detect keypoints using OpenCV Detector. + Optionally, perform description. + Args: + features: OpenCV based keypoints detector and descriptor + image: Grayscale image of uint8 data type + Returns: + keypoints: 1D array of detected cv2.KeyPoint + scores: 1D array of responses + descriptors: 1D array of descriptors + """ + detections, descriptors = features.detectAndCompute(image, None) + points = np.array([k.pt for k in detections], dtype=np.float32) + scores = np.array([k.response for k in detections], dtype=np.float32) + scales = np.array([k.size for k in detections], dtype=np.float32) + angles = np.deg2rad( + np.array([k.angle for k in detections], dtype=np.float32) + ) + return points, scores, scales, angles, descriptors + + +class SIFT(BaseModel): + default_conf = { + "rootsift": True, + "nms_radius": 0, # None to disable filtering entirely. + "max_keypoints": 4096, + "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda} + "detection_threshold": 0.0066667, # from COLMAP + "edge_threshold": 10, + "first_octave": -1, # only used by pycolmap, the default of COLMAP + "num_octaves": 4, + } + + required_data_keys = ["image"] + + def _init(self, conf): + self.conf = OmegaConf.create(self.conf) + backend = self.conf.backend + if backend.startswith("pycolmap"): + if pycolmap is None: + raise ImportError( + "Cannot find module pycolmap: install it with pip" + "or use backend=opencv." + ) + options = { + "peak_threshold": self.conf.detection_threshold, + "edge_threshold": self.conf.edge_threshold, + "first_octave": self.conf.first_octave, + "num_octaves": self.conf.num_octaves, + "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy. + } + device = ( + "auto" + if backend == "pycolmap" + else backend.replace("pycolmap_", "") + ) + if ( + backend == "pycolmap_cpu" or not pycolmap.has_cuda + ) and pycolmap.__version__ < "0.5.0": + warnings.warn( + "The pycolmap CPU SIFT is buggy in version < 0.5.0, " + "consider upgrading pycolmap or use the CUDA version.", + stacklevel=1, + ) + else: + options["max_num_features"] = self.conf.max_keypoints + self.sift = pycolmap.Sift(options=options, device=device) + elif backend == "opencv": + self.sift = cv2.SIFT_create( + contrastThreshold=self.conf.detection_threshold, + nfeatures=self.conf.max_keypoints, + edgeThreshold=self.conf.edge_threshold, + nOctaveLayers=self.conf.num_octaves, + ) + else: + backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"} + raise ValueError( + f"Unknown backend: {backend} not in " + f"{{{','.join(backends)}}}." + ) + logger.info("Load SIFT model done.") + + def extract_single_image(self, image: torch.Tensor): + image_np = image.cpu().numpy().squeeze(0) + + if self.conf.backend.startswith("pycolmap"): + if version.parse(pycolmap.__version__) >= version.parse("0.5.0"): + detections, descriptors = self.sift.extract(image_np) + scores = None # Scores are not exposed by COLMAP anymore. + else: + detections, scores, descriptors = self.sift.extract(image_np) + keypoints = detections[:, :2] # Keep only (x, y). + scales, angles = detections[:, -2:].T + if scores is not None and ( + self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda + ): + # Set the scores as a combination of abs. response and scale. + scores = np.abs(scores) * scales + elif self.conf.backend == "opencv": + # TODO: Check if opencv keypoints are already in corner convention + keypoints, scores, scales, angles, descriptors = run_opencv_sift( + self.sift, (image_np * 255.0).astype(np.uint8) + ) + pred = { + "keypoints": keypoints, + "scales": scales, + "oris": angles, + "descriptors": descriptors, + } + if scores is not None: + pred["scores"] = scores + + # sometimes pycolmap returns points outside the image. We remove them + if self.conf.backend.startswith("pycolmap"): + is_inside = ( + pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]]) + ).all(-1) + pred = {k: v[is_inside] for k, v in pred.items()} + + if self.conf.nms_radius is not None: + keep = filter_dog_point( + pred["keypoints"], + pred["scales"], + pred["oris"], + image_np.shape, + self.conf.nms_radius, + scores=pred.get("scores"), + ) + pred = {k: v[keep] for k, v in pred.items()} + + pred = {k: torch.from_numpy(v) for k, v in pred.items()} + if scores is not None: + # Keep the k keypoints with highest score + num_points = self.conf.max_keypoints + if num_points is not None and len(pred["keypoints"]) > num_points: + indices = torch.topk(pred["scores"], num_points).indices + pred = {k: v[indices] for k, v in pred.items()} + return pred + + def _forward(self, data: dict) -> dict: + image = data["image"] + if image.shape[1] == 3: + image = rgb_to_grayscale(image) + device = image.device + image = image.cpu() + pred = [] + for k in range(len(image)): + img = image[k] + if "image_size" in data.keys(): + # avoid extracting points in padded areas + w, h = data["image_size"][k] + img = img[:, :h, :w] + p = self.extract_single_image(img) + pred.append(p) + pred = { + k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0] + } + if self.conf.rootsift: + pred["descriptors"] = sift_to_rootsift(pred["descriptors"]) + pred["descriptors"] = pred["descriptors"].permute(0, 2, 1) + pred["keypoint_scores"] = pred["scores"].clone() + return pred diff --git a/hloc/extractors/superpoint.py b/hloc/extractors/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ee618392ae9d976b40d1c43a6628892a09d993fd --- /dev/null +++ b/hloc/extractors/superpoint.py @@ -0,0 +1,51 @@ +import sys +from pathlib import Path + +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from SuperGluePretrainedNetwork.models import superpoint # noqa E402 + + +# The original keypoint sampling is incorrect. We patch it here but +# we don't fix it upstream to not impact exisiting evaluations. +def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8): + """Interpolate descriptors at keypoint locations""" + b, c, h, w = descriptors.shape + keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s) + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + descriptors = torch.nn.functional.grid_sample( + descriptors, + keypoints.view(b, 1, -1, 2), + mode="bilinear", + align_corners=False, + ) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1 + ) + return descriptors + + +class SuperPoint(BaseModel): + default_conf = { + "nms_radius": 4, + "keypoint_threshold": 0.005, + "max_keypoints": -1, + "remove_borders": 4, + "fix_sampling": False, + } + required_inputs = ["image"] + detection_noise = 2.0 + + def _init(self, conf): + if conf["fix_sampling"]: + superpoint.sample_descriptors = sample_descriptors_fix_sampling + self.net = superpoint.SuperPoint(conf) + logger.info("Load SuperPoint model done.") + + def _forward(self, data): + return self.net(data, self.conf) diff --git a/hloc/extractors/xfeat.py b/hloc/extractors/xfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc230f247a79021db8b194ac5ce1d0ff7f37e89 --- /dev/null +++ b/hloc/extractors/xfeat.py @@ -0,0 +1,33 @@ +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + + +class XFeat(BaseModel): + default_conf = { + "keypoint_threshold": 0.005, + "max_keypoints": -1, + } + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "verlab/accelerated_features", + "XFeat", + pretrained=True, + top_k=self.conf["max_keypoints"], + ) + logger.info("Load XFeat(sparse) model done.") + + def _forward(self, data): + pred = self.net.detectAndCompute( + data["image"], top_k=self.conf["max_keypoints"] + )[0] + pred = { + "keypoints": pred["keypoints"][None], + "scores": pred["scores"][None], + "descriptors": pred["descriptors"].T[None], + } + return pred diff --git a/hloc/localize_inloc.py b/hloc/localize_inloc.py new file mode 100644 index 0000000000000000000000000000000000000000..1e003b1678bb84a544ec51ecf3ddef83e09e406d --- /dev/null +++ b/hloc/localize_inloc.py @@ -0,0 +1,183 @@ +import argparse +import pickle +from pathlib import Path + +import cv2 +import h5py +import numpy as np +import pycolmap +import torch +from scipy.io import loadmat +from tqdm import tqdm + +from . import logger +from .utils.parsers import names_to_pair, parse_retrieval + + +def interpolate_scan(scan, kp): + h, w, c = scan.shape + kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1 + assert np.all(kp > -1) and np.all(kp < 1) + scan = torch.from_numpy(scan).permute(2, 0, 1)[None] + kp = torch.from_numpy(kp)[None, None] + grid_sample = torch.nn.functional.grid_sample + + # To maximize the number of points that have depth: + # do bilinear interpolation first and then nearest for the remaining points + interp_lin = grid_sample(scan, kp, align_corners=True, mode="bilinear")[ + 0, :, 0 + ] + interp_nn = torch.nn.functional.grid_sample( + scan, kp, align_corners=True, mode="nearest" + )[0, :, 0] + interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin) + valid = ~torch.any(torch.isnan(interp), 0) + + kp3d = interp.T.numpy() + valid = valid.numpy() + return kp3d, valid + + +def get_scan_pose(dataset_dir, rpath): + split_image_rpath = rpath.split("/") + floor_name = split_image_rpath[-3] + scan_id = split_image_rpath[-2] + image_name = split_image_rpath[-1] + building_name = image_name[:3] + + path = Path( + dataset_dir, + "database/alignments", + floor_name, + f"transformations/{building_name}_trans_{scan_id}.txt", + ) + with open(path) as f: + raw_lines = f.readlines() + + P_after_GICP = np.array( + [ + np.fromstring(raw_lines[7], sep=" "), + np.fromstring(raw_lines[8], sep=" "), + np.fromstring(raw_lines[9], sep=" "), + np.fromstring(raw_lines[10], sep=" "), + ] + ) + + return P_after_GICP + + +def pose_from_cluster( + dataset_dir, q, retrieved, feature_file, match_file, skip=None +): + height, width = cv2.imread(str(dataset_dir / q)).shape[:2] + cx = 0.5 * width + cy = 0.5 * height + focal_length = 4032.0 * 28.0 / 36.0 + + all_mkpq = [] + all_mkpr = [] + all_mkp3d = [] + all_indices = [] + kpq = feature_file[q]["keypoints"].__array__() + num_matches = 0 + + for i, r in enumerate(retrieved): + kpr = feature_file[r]["keypoints"].__array__() + pair = names_to_pair(q, r) + m = match_file[pair]["matches0"].__array__() + v = m > -1 + + if skip and (np.count_nonzero(v) < skip): + continue + + mkpq, mkpr = kpq[v], kpr[m[v]] + num_matches += len(mkpq) + + scan_r = loadmat(Path(dataset_dir, r + ".mat"))["XYZcut"] + mkp3d, valid = interpolate_scan(scan_r, mkpr) + Tr = get_scan_pose(dataset_dir, r) + mkp3d = (Tr[:3, :3] @ mkp3d.T + Tr[:3, -1:]).T + + all_mkpq.append(mkpq[valid]) + all_mkpr.append(mkpr[valid]) + all_mkp3d.append(mkp3d[valid]) + all_indices.append(np.full(np.count_nonzero(valid), i)) + + all_mkpq = np.concatenate(all_mkpq, 0) + all_mkpr = np.concatenate(all_mkpr, 0) + all_mkp3d = np.concatenate(all_mkp3d, 0) + all_indices = np.concatenate(all_indices, 0) + + cfg = { + "model": "SIMPLE_PINHOLE", + "width": width, + "height": height, + "params": [focal_length, cx, cy], + } + ret = pycolmap.absolute_pose_estimation(all_mkpq, all_mkp3d, cfg, 48.00) + ret["cfg"] = cfg + return ret, all_mkpq, all_mkpr, all_mkp3d, all_indices, num_matches + + +def main(dataset_dir, retrieval, features, matches, results, skip_matches=None): + assert retrieval.exists(), retrieval + assert features.exists(), features + assert matches.exists(), matches + + retrieval_dict = parse_retrieval(retrieval) + queries = list(retrieval_dict.keys()) + + feature_file = h5py.File(features, "r", libver="latest") + match_file = h5py.File(matches, "r", libver="latest") + + poses = {} + logs = { + "features": features, + "matches": matches, + "retrieval": retrieval, + "loc": {}, + } + logger.info("Starting localization...") + for q in tqdm(queries): + db = retrieval_dict[q] + ret, mkpq, mkpr, mkp3d, indices, num_matches = pose_from_cluster( + dataset_dir, q, db, feature_file, match_file, skip_matches + ) + + poses[q] = (ret["qvec"], ret["tvec"]) + logs["loc"][q] = { + "db": db, + "PnP_ret": ret, + "keypoints_query": mkpq, + "keypoints_db": mkpr, + "3d_points": mkp3d, + "indices_db": indices, + "num_matches": num_matches, + } + + logger.info(f"Writing poses to {results}...") + with open(results, "w") as f: + for q in queries: + qvec, tvec = poses[q] + qvec = " ".join(map(str, qvec)) + tvec = " ".join(map(str, tvec)) + name = q.split("/")[-1] + f.write(f"{name} {qvec} {tvec}\n") + + logs_path = f"{results}_logs.pkl" + logger.info(f"Writing logs to {logs_path}...") + with open(logs_path, "wb") as f: + pickle.dump(logs, f) + logger.info("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_dir", type=Path, required=True) + parser.add_argument("--retrieval", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + parser.add_argument("--results", type=Path, required=True) + parser.add_argument("--skip_matches", type=int) + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/localize_sfm.py b/hloc/localize_sfm.py new file mode 100644 index 0000000000000000000000000000000000000000..a1cb672254936ba6b6c9576fa6078f00458c714c --- /dev/null +++ b/hloc/localize_sfm.py @@ -0,0 +1,247 @@ +import argparse +import pickle +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Union + +import numpy as np +import pycolmap +from tqdm import tqdm + +from . import logger +from .utils.io import get_keypoints, get_matches +from .utils.parsers import parse_image_lists, parse_retrieval + + +def do_covisibility_clustering( + frame_ids: List[int], reconstruction: pycolmap.Reconstruction +): + clusters = [] + visited = set() + for frame_id in frame_ids: + # Check if already labeled + if frame_id in visited: + continue + + # New component + clusters.append([]) + queue = {frame_id} + while len(queue): + exploration_frame = queue.pop() + + # Already part of the component + if exploration_frame in visited: + continue + visited.add(exploration_frame) + clusters[-1].append(exploration_frame) + + observed = reconstruction.images[exploration_frame].points2D + connected_frames = { + obs.image_id + for p2D in observed + if p2D.has_point3D() + for obs in reconstruction.points3D[ + p2D.point3D_id + ].track.elements + } + connected_frames &= set(frame_ids) + connected_frames -= visited + queue |= connected_frames + + clusters = sorted(clusters, key=len, reverse=True) + return clusters + + +class QueryLocalizer: + def __init__(self, reconstruction, config=None): + self.reconstruction = reconstruction + self.config = config or {} + + def localize(self, points2D_all, points2D_idxs, points3D_id, query_camera): + points2D = points2D_all[points2D_idxs] + points3D = [self.reconstruction.points3D[j].xyz for j in points3D_id] + ret = pycolmap.absolute_pose_estimation( + points2D, + points3D, + query_camera, + estimation_options=self.config.get("estimation", {}), + refinement_options=self.config.get("refinement", {}), + ) + return ret + + +def pose_from_cluster( + localizer: QueryLocalizer, + qname: str, + query_camera: pycolmap.Camera, + db_ids: List[int], + features_path: Path, + matches_path: Path, + **kwargs, +): + kpq = get_keypoints(features_path, qname) + kpq += 0.5 # COLMAP coordinates + + kp_idx_to_3D = defaultdict(list) + kp_idx_to_3D_to_db = defaultdict(lambda: defaultdict(list)) + num_matches = 0 + for i, db_id in enumerate(db_ids): + image = localizer.reconstruction.images[db_id] + if image.num_points3D == 0: + logger.debug(f"No 3D points found for {image.name}.") + continue + points3D_ids = np.array( + [p.point3D_id if p.has_point3D() else -1 for p in image.points2D] + ) + + matches, _ = get_matches(matches_path, qname, image.name) + matches = matches[points3D_ids[matches[:, 1]] != -1] + num_matches += len(matches) + for idx, m in matches: + id_3D = points3D_ids[m] + kp_idx_to_3D_to_db[idx][id_3D].append(i) + # avoid duplicate observations + if id_3D not in kp_idx_to_3D[idx]: + kp_idx_to_3D[idx].append(id_3D) + + idxs = list(kp_idx_to_3D.keys()) + mkp_idxs = [i for i in idxs for _ in kp_idx_to_3D[i]] + mp3d_ids = [j for i in idxs for j in kp_idx_to_3D[i]] + ret = localizer.localize(kpq, mkp_idxs, mp3d_ids, query_camera, **kwargs) + if ret is not None: + ret["camera"] = query_camera + + # mostly for logging and post-processing + mkp_to_3D_to_db = [ + (j, kp_idx_to_3D_to_db[i][j]) for i in idxs for j in kp_idx_to_3D[i] + ] + log = { + "db": db_ids, + "PnP_ret": ret, + "keypoints_query": kpq[mkp_idxs], + "points3D_ids": mp3d_ids, + "points3D_xyz": None, # we don't log xyz anymore because of file size + "num_matches": num_matches, + "keypoint_index_to_db": (mkp_idxs, mkp_to_3D_to_db), + } + return ret, log + + +def main( + reference_sfm: Union[Path, pycolmap.Reconstruction], + queries: Path, + retrieval: Path, + features: Path, + matches: Path, + results: Path, + ransac_thresh: int = 12, + covisibility_clustering: bool = False, + prepend_camera_name: bool = False, + config: Dict = None, +): + assert retrieval.exists(), retrieval + assert features.exists(), features + assert matches.exists(), matches + + queries = parse_image_lists(queries, with_intrinsics=True) + retrieval_dict = parse_retrieval(retrieval) + + logger.info("Reading the 3D model...") + if not isinstance(reference_sfm, pycolmap.Reconstruction): + reference_sfm = pycolmap.Reconstruction(reference_sfm) + db_name_to_id = {img.name: i for i, img in reference_sfm.images.items()} + + config = { + "estimation": {"ransac": {"max_error": ransac_thresh}}, + **(config or {}), + } + localizer = QueryLocalizer(reference_sfm, config) + + cam_from_world = {} + logs = { + "features": features, + "matches": matches, + "retrieval": retrieval, + "loc": {}, + } + logger.info("Starting localization...") + for qname, qcam in tqdm(queries): + if qname not in retrieval_dict: + logger.warning( + f"No images retrieved for query image {qname}. Skipping..." + ) + continue + db_names = retrieval_dict[qname] + db_ids = [] + for n in db_names: + if n not in db_name_to_id: + logger.warning(f"Image {n} was retrieved but not in database") + continue + db_ids.append(db_name_to_id[n]) + + if covisibility_clustering: + clusters = do_covisibility_clustering(db_ids, reference_sfm) + best_inliers = 0 + best_cluster = None + logs_clusters = [] + for i, cluster_ids in enumerate(clusters): + ret, log = pose_from_cluster( + localizer, qname, qcam, cluster_ids, features, matches + ) + if ret is not None and ret["num_inliers"] > best_inliers: + best_cluster = i + best_inliers = ret["num_inliers"] + logs_clusters.append(log) + if best_cluster is not None: + ret = logs_clusters[best_cluster]["PnP_ret"] + cam_from_world[qname] = ret["cam_from_world"] + logs["loc"][qname] = { + "db": db_ids, + "best_cluster": best_cluster, + "log_clusters": logs_clusters, + "covisibility_clustering": covisibility_clustering, + } + else: + ret, log = pose_from_cluster( + localizer, qname, qcam, db_ids, features, matches + ) + if ret is not None: + cam_from_world[qname] = ret["cam_from_world"] + else: + closest = reference_sfm.images[db_ids[0]] + cam_from_world[qname] = closest.cam_from_world + log["covisibility_clustering"] = covisibility_clustering + logs["loc"][qname] = log + + logger.info(f"Localized {len(cam_from_world)} / {len(queries)} images.") + logger.info(f"Writing poses to {results}...") + with open(results, "w") as f: + for query, t in cam_from_world.items(): + qvec = " ".join(map(str, t.rotation.quat[[3, 0, 1, 2]])) + tvec = " ".join(map(str, t.translation)) + name = query.split("/")[-1] + if prepend_camera_name: + name = query.split("/")[-2] + "/" + name + f.write(f"{name} {qvec} {tvec}\n") + + logs_path = f"{results}_logs.pkl" + logger.info(f"Writing logs to {logs_path}...") + # TODO: Resolve pickling issue with pycolmap objects. + with open(logs_path, "wb") as f: + pickle.dump(logs, f) + logger.info("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--reference_sfm", type=Path, required=True) + parser.add_argument("--queries", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + parser.add_argument("--retrieval", type=Path, required=True) + parser.add_argument("--results", type=Path, required=True) + parser.add_argument("--ransac_thresh", type=float, default=12.0) + parser.add_argument("--covisibility_clustering", action="store_true") + parser.add_argument("--prepend_camera_name", action="store_true") + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/match_dense.py b/hloc/match_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..30d35422cf085744f3436da3c321edbb1ae69b4f --- /dev/null +++ b/hloc/match_dense.py @@ -0,0 +1,1121 @@ +import argparse +import pprint +from collections import Counter, defaultdict +from itertools import chain +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union + +import cv2 +import h5py +import numpy as np +import torch +import torchvision.transforms.functional as F +from scipy.spatial import KDTree +from tqdm import tqdm + +from . import logger, matchers +from .extract_features import read_image, resize_image +from .match_features import find_unique_new_pairs +from .utils.base_model import dynamic_load +from .utils.io import list_h5_names +from .utils.parsers import names_to_pair, parse_retrieval + +device = "cuda" if torch.cuda.is_available() else "cpu" + +confs = { + # Best quality but loads of points. Only use for small scenes + "loftr": { + "output": "matches-loftr", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + "eloftr": { + "output": "matches-eloftr", + "model": { + "name": "eloftr", + "weights": "weights/eloftr_outdoor.ckpt", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 32, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + # "loftr_quadtree": { + # "output": "matches-loftr-quadtree", + # "model": { + # "name": "quadtree", + # "weights": "outdoor", + # "max_keypoints": 2000, + # "match_threshold": 0.2, + # }, + # "preprocessing": { + # "grayscale": True, + # "resize_max": 1024, + # "dfactor": 8, + # "width": 640, + # "height": 480, + # "force_resize": True, + # }, + # "max_error": 1, # max error for assigned keypoints (in px) + # "cell_size": 1, # size of quantization patch (max 1 kp/patch) + # }, + "cotr": { + "output": "matches-cotr", + "model": { + "name": "cotr", + "weights": "out/default", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + # Semi-scalable loftr which limits detected keypoints + "loftr_aachen": { + "output": "matches-loftr_aachen", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 2, # max error for assigned keypoints (in px) + "cell_size": 8, # size of quantization patch (max 1 kp/patch) + }, + # Use for matching superpoint feats with loftr + "loftr_superpoint": { + "output": "matches-loftr_aachen", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 4, # max error for assigned keypoints (in px) + "cell_size": 4, # size of quantization patch (max 1 kp/patch) + }, + # Use topicfm for matching feats + "topicfm": { + "output": "matches-topicfm", + "model": { + "name": "topicfm", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + }, + }, + # Use aspanformer for matching feats + "aspanformer": { + "output": "matches-aspanformer", + "model": { + "name": "aspanformer", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "duster": { + "output": "matches-duster", + "model": { + "name": "duster", + "weights": "vit_large", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 512, + "dfactor": 16, + }, + }, + "mast3r": { + "output": "matches-mast3r", + "model": { + "name": "mast3r", + "weights": "vit_large", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 512, + "dfactor": 16, + }, + }, + "xfeat_lightglue": { + "output": "matches-xfeat_lightglue", + "model": { + "name": "xfeat_lightglue", + "max_keypoints": 8000, + }, + "preprocessing": { + "grayscale": False, + "force_resize": False, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "xfeat_dense": { + "output": "matches-xfeat_dense", + "model": { + "name": "xfeat_dense", + "max_keypoints": 8000, + }, + "preprocessing": { + "grayscale": False, + "force_resize": False, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "dkm": { + "output": "matches-dkm", + "model": { + "name": "dkm", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 80, + "height": 60, + "dfactor": 8, + }, + }, + "roma": { + "output": "matches-roma", + "model": { + "name": "roma", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 320, + "height": 240, + "dfactor": 8, + }, + }, + "gim(dkm)": { + "output": "matches-gim", + "model": { + "name": "gim", + "weights": "gim_dkm_100h.ckpt", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 320, + "height": 240, + "dfactor": 8, + }, + }, + "omniglue": { + "output": "matches-omniglue", + "model": { + "name": "omniglue", + "match_threshold": 0.2, + "max_keypoints": 2000, + "features": "null", + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "sold2": { + "output": "matches-sold2", + "model": { + "name": "sold2", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "gluestick": { + "output": "matches-gluestick", + "model": { + "name": "gluestick", + "use_lines": True, + "max_keypoints": 1000, + "max_lines": 300, + "force_num_keypoints": False, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, +} + + +def to_cpts(kpts, ps): + if ps > 0.0: + kpts = np.round(np.round((kpts + 0.5) / ps) * ps - 0.5, 2) + return [tuple(cpt) for cpt in kpts] + + +def assign_keypoints( + kpts: np.ndarray, + other_cpts: Union[List[Tuple], np.ndarray], + max_error: float, + update: bool = False, + ref_bins: Optional[List[Counter]] = None, + scores: Optional[np.ndarray] = None, + cell_size: Optional[int] = None, +): + if not update: + # Without update this is just a NN search + if len(other_cpts) == 0 or len(kpts) == 0: + return np.full(len(kpts), -1) + dist, kpt_ids = KDTree(np.array(other_cpts)).query(kpts) + valid = dist <= max_error + kpt_ids[~valid] = -1 + return kpt_ids + else: + ps = cell_size if cell_size is not None else max_error + ps = max(ps, max_error) + # With update we quantize and bin (optionally) + assert isinstance(other_cpts, list) + kpt_ids = [] + cpts = to_cpts(kpts, ps) + bpts = to_cpts(kpts, int(max_error)) + cp_to_id = {val: i for i, val in enumerate(other_cpts)} + for i, (cpt, bpt) in enumerate(zip(cpts, bpts)): + try: + kid = cp_to_id[cpt] + except KeyError: + kid = len(cp_to_id) + cp_to_id[cpt] = kid + other_cpts.append(cpt) + if ref_bins is not None: + ref_bins.append(Counter()) + if ref_bins is not None: + score = scores[i] if scores is not None else 1 + ref_bins[cp_to_id[cpt]][bpt] += score + kpt_ids.append(kid) + return np.array(kpt_ids) + + +def get_grouped_ids(array): + # Group array indices based on its values + # all duplicates are grouped as a set + idx_sort = np.argsort(array) + sorted_array = array[idx_sort] + _, ids, _ = np.unique(sorted_array, return_counts=True, return_index=True) + res = np.split(idx_sort, ids[1:]) + return res + + +def get_unique_matches(match_ids, scores): + if len(match_ids.shape) == 1: + return [0] + + isets1 = get_grouped_ids(match_ids[:, 0]) + isets2 = get_grouped_ids(match_ids[:, 1]) + uid1s = [ids[scores[ids].argmax()] for ids in isets1 if len(ids) > 0] + uid2s = [ids[scores[ids].argmax()] for ids in isets2 if len(ids) > 0] + uids = list(set(uid1s).intersection(uid2s)) + return match_ids[uids], scores[uids] + + +def matches_to_matches0(matches, scores): + if len(matches) == 0: + return np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.float16) + n_kps0 = np.max(matches[:, 0]) + 1 + matches0 = -np.ones((n_kps0,)) + scores0 = np.zeros((n_kps0,)) + matches0[matches[:, 0]] = matches[:, 1] + scores0[matches[:, 0]] = scores + return matches0.astype(np.int32), scores0.astype(np.float16) + + +def kpids_to_matches0(kpt_ids0, kpt_ids1, scores): + valid = (kpt_ids0 != -1) & (kpt_ids1 != -1) + matches = np.dstack([kpt_ids0[valid], kpt_ids1[valid]]) + matches = matches.reshape(-1, 2) + scores = scores[valid] + + # Remove n-to-1 matches + matches, scores = get_unique_matches(matches, scores) + return matches_to_matches0(matches, scores) + + +def scale_keypoints(kpts, scale): + if np.any(scale != 1.0): + kpts *= kpts.new_tensor(scale) + return kpts + + +class ImagePairDataset(torch.utils.data.Dataset): + default_conf = { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "cache_images": False, + } + + def __init__(self, image_dir, conf, pairs): + self.image_dir = image_dir + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.pairs = pairs + if self.conf.cache_images: + image_names = set(sum(pairs, ())) # unique image names in pairs + logger.info( + f"Loading and caching {len(image_names)} unique images." + ) + self.images = {} + self.scales = {} + for name in tqdm(image_names): + image = read_image(self.image_dir / name, self.conf.grayscale) + self.images[name], self.scales[name] = self.preprocess(image) + + def preprocess(self, image: np.ndarray): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + + if self.conf.resize_max: + scale = self.conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + + if self.conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // self.conf.dfactor * self.conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + if self.conf.cache_images: + image0, scale0 = self.images[name0], self.scales[name0] + image1, scale1 = self.images[name1], self.scales[name1] + else: + image0 = read_image(self.image_dir / name0, self.conf.grayscale) + image1 = read_image(self.image_dir / name1, self.conf.grayscale) + image0, scale0 = self.preprocess(image0) + image1, scale1 = self.preprocess(image1) + return image0, image1, scale0, scale1, name0, name1 + + +@torch.no_grad() +def match_dense( + conf: Dict, + pairs: List[Tuple[str, str]], + image_dir: Path, + match_path: Path, # out + existing_refs: Optional[List] = [], +): + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(matchers, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + dataset = ImagePairDataset(image_dir, conf["preprocessing"], pairs) + loader = torch.utils.data.DataLoader( + dataset, num_workers=16, batch_size=1, shuffle=False + ) + + logger.info("Performing dense matching...") + with h5py.File(str(match_path), "a") as fd: + for data in tqdm(loader, smoothing=0.1): + # load image-pair data + image0, image1, scale0, scale1, (name0,), (name1,) = data + scale0, scale1 = scale0[0].numpy(), scale1[0].numpy() + image0, image1 = image0.to(device), image1.to(device) + + # match semi-dense + # for consistency with pairs_from_*: refine kpts of image0 + if name0 in existing_refs: + # special case: flip to enable refinement in query image + pred = model({"image0": image1, "image1": image0}) + pred = { + **pred, + "keypoints0": pred["keypoints1"], + "keypoints1": pred["keypoints0"], + } + else: + # usual case + pred = model({"image0": image0, "image1": image1}) + + # Rescale keypoints and move to cpu + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5 + kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5 + kpts0 = kpts0.cpu().numpy() + kpts1 = kpts1.cpu().numpy() + scores = pred["scores"].cpu().numpy() + + # Write matches and matching scores in hloc format + pair = names_to_pair(name0, name1) + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + + # Write dense matching output + grp.create_dataset("keypoints0", data=kpts0) + grp.create_dataset("keypoints1", data=kpts1) + grp.create_dataset("scores", data=scores) + del model, loader + + +# default: quantize all! +def load_keypoints( + conf: Dict, feature_paths_refs: List[Path], quantize: Optional[set] = None +): + name2ref = { + n: i for i, p in enumerate(feature_paths_refs) for n in list_h5_names(p) + } + + existing_refs = set(name2ref.keys()) + if quantize is None: + quantize = existing_refs # quantize all + if len(existing_refs) > 0: + logger.info(f"Loading keypoints from {len(existing_refs)} images.") + + # Load query keypoints + cpdict = defaultdict(list) + bindict = defaultdict(list) + for name in existing_refs: + with h5py.File(str(feature_paths_refs[name2ref[name]]), "r") as fd: + kps = fd[name]["keypoints"].__array__() + if name not in quantize: + cpdict[name] = kps + else: + if "scores" in fd[name].keys(): + kp_scores = fd[name]["scores"].__array__() + else: + # we set the score to 1.0 if not provided + # increase for more weight on reference keypoints for + # stronger anchoring + kp_scores = [1.0 for _ in range(kps.shape[0])] + # bin existing keypoints of reference images for association + assign_keypoints( + kps, + cpdict[name], + conf["max_error"], + True, + bindict[name], + kp_scores, + conf["cell_size"], + ) + return cpdict, bindict + + +def aggregate_matches( + conf: Dict, + pairs: List[Tuple[str, str]], + match_path: Path, + feature_path: Path, + required_queries: Optional[Set[str]] = None, + max_kps: Optional[int] = None, + cpdict: Dict[str, Iterable] = defaultdict(list), + bindict: Dict[str, List[Counter]] = defaultdict(list), +): + if required_queries is None: + required_queries = set(sum(pairs, ())) + # default: do not overwrite existing features in feature_path! + required_queries -= set(list_h5_names(feature_path)) + + # if an entry in cpdict is provided as np.ndarray we assume it is fixed + required_queries -= set( + [k for k, v in cpdict.items() if isinstance(v, np.ndarray)] + ) + + # sort pairs for reduced RAM + pairs_per_q = Counter(list(chain(*pairs))) + pairs_score = [min(pairs_per_q[i], pairs_per_q[j]) for i, j in pairs] + pairs = [p for _, p in sorted(zip(pairs_score, pairs))] + + if len(required_queries) > 0: + logger.info( + f"Aggregating keypoints for {len(required_queries)} images." + ) + n_kps = 0 + with h5py.File(str(match_path), "a") as fd: + for name0, name1 in tqdm(pairs, smoothing=0.1): + pair = names_to_pair(name0, name1) + grp = fd[pair] + kpts0 = grp["keypoints0"].__array__() + kpts1 = grp["keypoints1"].__array__() + scores = grp["scores"].__array__() + + # Aggregate local features + update0 = name0 in required_queries + update1 = name1 in required_queries + + # in localization we do not want to bin the query kp + # assumes that the query is name0! + if update0 and not update1 and max_kps is None: + max_error0 = cell_size0 = 0.0 + else: + max_error0 = conf["max_error"] + cell_size0 = conf["cell_size"] + + # Get match ids and extend query keypoints (cpdict) + mkp_ids0 = assign_keypoints( + kpts0, + cpdict[name0], + max_error0, + update0, + bindict[name0], + scores, + cell_size0, + ) + mkp_ids1 = assign_keypoints( + kpts1, + cpdict[name1], + conf["max_error"], + update1, + bindict[name1], + scores, + conf["cell_size"], + ) + + # Build matches from assignments + matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores) + + assert kpts0.shape[0] == scores.shape[0] + grp.create_dataset("matches0", data=matches0) + grp.create_dataset("matching_scores0", data=scores0) + + # Convert bins to kps if finished, and store them + for name in (name0, name1): + pairs_per_q[name] -= 1 + if pairs_per_q[name] > 0 or name not in required_queries: + continue + kp_score = [c.most_common(1)[0][1] for c in bindict[name]] + cpdict[name] = [c.most_common(1)[0][0] for c in bindict[name]] + cpdict[name] = np.array(cpdict[name], dtype=np.float32) + + # Select top-k query kps by score (reassign matches later) + if max_kps: + top_k = min(max_kps, cpdict[name].shape[0]) + top_k = np.argsort(kp_score)[::-1][:top_k] + cpdict[name] = cpdict[name][top_k] + kp_score = np.array(kp_score)[top_k] + + # Write query keypoints + with h5py.File(feature_path, "a") as kfd: + if name in kfd: + del kfd[name] + kgrp = kfd.create_group(name) + kgrp.create_dataset("keypoints", data=cpdict[name]) + kgrp.create_dataset("score", data=kp_score) + n_kps += cpdict[name].shape[0] + del bindict[name] + + if len(required_queries) > 0: + avg_kp_per_image = round(n_kps / len(required_queries), 1) + logger.info( + f"Finished assignment, found {avg_kp_per_image} " + f"keypoints/image (avg.), total {n_kps}." + ) + return cpdict + + +def assign_matches( + pairs: List[Tuple[str, str]], + match_path: Path, + keypoints: Union[List[Path], Dict[str, np.array]], + max_error: float, +): + if isinstance(keypoints, list): + keypoints = load_keypoints({}, keypoints, kpts_as_bin=set([])) + assert len(set(sum(pairs, ())) - set(keypoints.keys())) == 0 + with h5py.File(str(match_path), "a") as fd: + for name0, name1 in tqdm(pairs): + pair = names_to_pair(name0, name1) + grp = fd[pair] + kpts0 = grp["keypoints0"].__array__() + kpts1 = grp["keypoints1"].__array__() + scores = grp["scores"].__array__() + + # NN search across cell boundaries + mkp_ids0 = assign_keypoints(kpts0, keypoints[name0], max_error) + mkp_ids1 = assign_keypoints(kpts1, keypoints[name1], max_error) + + matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores) + + # overwrite matches0 and matching_scores0 + del grp["matches0"], grp["matching_scores0"] + grp.create_dataset("matches0", data=matches0) + grp.create_dataset("matching_scores0", data=scores0) + + +@torch.no_grad() +def match_and_assign( + conf: Dict, + pairs_path: Path, + image_dir: Path, + match_path: Path, # out + feature_path_q: Path, # out + feature_paths_refs: Optional[List[Path]] = [], + max_kps: Optional[int] = 8192, + overwrite: bool = False, +) -> Path: + for path in feature_paths_refs: + if not path.exists(): + raise FileNotFoundError(f"Reference feature file {path}.") + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + required_queries = set(sum(pairs, ())) + + name2ref = { + n: i for i, p in enumerate(feature_paths_refs) for n in list_h5_names(p) + } + existing_refs = required_queries.intersection(set(name2ref.keys())) + + # images which require feature extraction + required_queries = required_queries - existing_refs + + if feature_path_q.exists(): + existing_queries = set(list_h5_names(feature_path_q)) + feature_paths_refs.append(feature_path_q) + existing_refs = set.union(existing_refs, existing_queries) + if not overwrite: + required_queries = required_queries - existing_queries + + if len(pairs) == 0 and len(required_queries) == 0: + logger.info("All pairs exist. Skipping dense matching.") + return + + # extract semi-dense matches + match_dense(conf, pairs, image_dir, match_path, existing_refs=existing_refs) + + logger.info("Assigning matches...") + + # Pre-load existing keypoints + cpdict, bindict = load_keypoints( + conf, feature_paths_refs, quantize=required_queries + ) + + # Reassign matches by aggregation + cpdict = aggregate_matches( + conf, + pairs, + match_path, + feature_path=feature_path_q, + required_queries=required_queries, + max_kps=max_kps, + cpdict=cpdict, + bindict=bindict, + ) + + # Invalidate matches that are far from selected bin by reassignment + if max_kps is not None: + logger.info(f'Reassign matches with max_error={conf["max_error"]}.') + assign_matches(pairs, match_path, cpdict, max_error=conf["max_error"]) + + +def scale_lines(lines, scale): + if np.any(scale != 1.0): + lines *= lines.new_tensor(scale) + return lines + + +def match(model, path_0, path_1, conf): + default_conf = { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "cache_images": False, + "force_resize": False, + "width": 320, + "height": 240, + } + + def preprocess(image: np.ndarray): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + if conf.resize_max: + scale = conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + if conf.force_resize: + size = image.shape[:2][::-1] + image = resize_image(image, (conf.width, conf.height), "cv2_area") + size_new = (conf.width, conf.height) + scale = np.array(size) / np.array(size_new) + if conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // conf.dfactor * conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new, antialias=True) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + conf = SimpleNamespace(**{**default_conf, **conf}) + image0 = read_image(path_0, conf.grayscale) + image1 = read_image(path_1, conf.grayscale) + image0, scale0 = preprocess(image0) + image1, scale1 = preprocess(image1) + image0 = image0.to(device)[None] + image1 = image1.to(device)[None] + pred = model({"image0": image0, "image1": image1}) + + # Rescale keypoints and move to cpu + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5 + kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5 + + ret = { + "image0": image0.squeeze().cpu().numpy(), + "image1": image1.squeeze().cpu().numpy(), + "keypoints0": kpts0.cpu().numpy(), + "keypoints1": kpts1.cpu().numpy(), + } + if "mconf" in pred.keys(): + ret["mconf"] = pred["mconf"].cpu().numpy() + return ret + + +@torch.no_grad() +def match_images(model, image_0, image_1, conf, device="cpu"): + default_conf = { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "cache_images": False, + "force_resize": False, + "width": 320, + "height": 240, + } + + def preprocess(image: np.ndarray): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + if conf.resize_max: + scale = conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + if conf.force_resize: + size = image.shape[:2][::-1] + image = resize_image(image, (conf.width, conf.height), "cv2_area") + size_new = (conf.width, conf.height) + scale = np.array(size) / np.array(size_new) + if conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // conf.dfactor * conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + conf = SimpleNamespace(**{**default_conf, **conf}) + + if len(image_0.shape) == 3 and conf.grayscale: + image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY) + else: + image0 = image_0 + if len(image_0.shape) == 3 and conf.grayscale: + image1 = cv2.cvtColor(image_1, cv2.COLOR_RGB2GRAY) + else: + image1 = image_1 + + # comment following lines, image is always RGB mode + # if not conf.grayscale and len(image0.shape) == 3: + # image0 = image0[:, :, ::-1] # BGR to RGB + # if not conf.grayscale and len(image1.shape) == 3: + # image1 = image1[:, :, ::-1] # BGR to RGB + + image0, scale0 = preprocess(image0) + image1, scale1 = preprocess(image1) + image0 = image0.to(device)[None] + image1 = image1.to(device)[None] + pred = model({"image0": image0, "image1": image1}) + + s0 = np.array(image_0.shape[:2][::-1]) / np.array(image0.shape[-2:][::-1]) + s1 = np.array(image_1.shape[:2][::-1]) / np.array(image1.shape[-2:][::-1]) + + # Rescale keypoints and move to cpu + if "keypoints0" in pred.keys() and "keypoints1" in pred.keys(): + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5 + kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5 + + ret = { + "image0": image0.squeeze().cpu().numpy(), + "image1": image1.squeeze().cpu().numpy(), + "image0_orig": image_0, + "image1_orig": image_1, + "keypoints0": kpts0.cpu().numpy(), + "keypoints1": kpts1.cpu().numpy(), + "keypoints0_orig": kpts0_origin.cpu().numpy(), + "keypoints1_orig": kpts1_origin.cpu().numpy(), + "mkeypoints0": kpts0.cpu().numpy(), + "mkeypoints1": kpts1.cpu().numpy(), + "mkeypoints0_orig": kpts0_origin.cpu().numpy(), + "mkeypoints1_orig": kpts1_origin.cpu().numpy(), + "original_size0": np.array(image_0.shape[:2][::-1]), + "original_size1": np.array(image_1.shape[:2][::-1]), + "new_size0": np.array(image0.shape[-2:][::-1]), + "new_size1": np.array(image1.shape[-2:][::-1]), + "scale0": s0, + "scale1": s1, + } + if "mconf" in pred.keys(): + ret["mconf"] = pred["mconf"].cpu().numpy() + elif "scores" in pred.keys(): # adapting loftr + ret["mconf"] = pred["scores"].cpu().numpy() + else: + ret["mconf"] = np.ones_like(kpts0.cpu().numpy()[:, 0]) + if "lines0" in pred.keys() and "lines1" in pred.keys(): + if "keypoints0" in pred.keys() and "keypoints1" in pred.keys(): + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5 + kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5 + kpts0_origin = kpts0_origin.cpu().numpy() + kpts1_origin = kpts1_origin.cpu().numpy() + else: + kpts0_origin, kpts1_origin = ( + None, + None, + ) # np.zeros([0]), np.zeros([0]) + lines0, lines1 = pred["lines0"], pred["lines1"] + lines0_raw, lines1_raw = pred["raw_lines0"], pred["raw_lines1"] + + lines0_raw = torch.from_numpy(lines0_raw.copy()) + lines1_raw = torch.from_numpy(lines1_raw.copy()) + lines0_raw = scale_lines(lines0_raw + 0.5, s0) - 0.5 + lines1_raw = scale_lines(lines1_raw + 0.5, s1) - 0.5 + + lines0 = torch.from_numpy(lines0.copy()) + lines1 = torch.from_numpy(lines1.copy()) + lines0 = scale_lines(lines0 + 0.5, s0) - 0.5 + lines1 = scale_lines(lines1 + 0.5, s1) - 0.5 + + ret = { + "image0_orig": image_0, + "image1_orig": image_1, + "line0": lines0_raw.cpu().numpy(), + "line1": lines1_raw.cpu().numpy(), + "line0_orig": lines0.cpu().numpy(), + "line1_orig": lines1.cpu().numpy(), + "line_keypoints0_orig": kpts0_origin, + "line_keypoints1_orig": kpts1_origin, + } + del pred + torch.cuda.empty_cache() + return ret + + +@torch.no_grad() +def main( + conf: Dict, + pairs: Path, + image_dir: Path, + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, # out + features: Optional[Path] = None, # out + features_ref: Optional[Path] = None, + max_kps: Optional[int] = 8192, + overwrite: bool = False, +) -> Path: + logger.info( + "Extracting semi-dense features with configuration:" + f"\n{pprint.pformat(conf)}" + ) + + if features is None: + features = "feats_" + + if isinstance(features, Path): + features_q = features + if matches is None: + raise ValueError( + "Either provide both features and matches as Path" + " or both as names." + ) + else: + if export_dir is None: + raise ValueError( + "Provide an export_dir if features and matches" + f" are not file paths: {features}, {matches}." + ) + features_q = Path(export_dir, f'{features}{conf["output"]}.h5') + if matches is None: + matches = Path(export_dir, f'{conf["output"]}_{pairs.stem}.h5') + + if features_ref is None: + features_ref = [] + elif isinstance(features_ref, list): + features_ref = list(features_ref) + elif isinstance(features_ref, Path): + features_ref = [features_ref] + else: + raise TypeError(str(features_ref)) + + match_and_assign( + conf, + pairs, + image_dir, + matches, + features_q, + features_ref, + max_kps, + overwrite, + ) + + return features_q, matches + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + parser.add_argument("--export_dir", type=Path, required=True) + parser.add_argument( + "--matches", type=Path, default=confs["loftr"]["output"] + ) + parser.add_argument( + "--features", type=str, default="feats_" + confs["loftr"]["output"] + ) + parser.add_argument( + "--conf", type=str, default="loftr", choices=list(confs.keys()) + ) + args = parser.parse_args() + main( + confs[args.conf], + args.pairs, + args.image_dir, + args.export_dir, + args.matches, + args.features, + ) diff --git a/hloc/match_features.py b/hloc/match_features.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4b0bd0cc3078789ac980177479c224129ad4cc --- /dev/null +++ b/hloc/match_features.py @@ -0,0 +1,441 @@ +import argparse +import pprint +from functools import partial +from pathlib import Path +from queue import Queue +from threading import Thread +from typing import Dict, List, Optional, Tuple, Union + +import h5py +import numpy as np +import torch +from tqdm import tqdm + +from . import logger, matchers +from .utils.base_model import dynamic_load +from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval + +""" +A set of standard configurations that can be directly selected from the command +line using their name. Each is a dictionary with the following entries: + - output: the name of the match file that will be generated. + - model: the model configuration, as passed to a feature matcher. +""" +confs = { + "superglue": { + "output": "matches-superglue", + "model": { + "name": "superglue", + "weights": "outdoor", + "sinkhorn_iterations": 50, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "superglue-fast": { + "output": "matches-superglue-it5", + "model": { + "name": "superglue", + "weights": "outdoor", + "sinkhorn_iterations": 5, + "match_threshold": 0.2, + }, + }, + "superpoint-lightglue": { + "output": "matches-lightglue", + "model": { + "name": "lightglue", + "match_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "superpoint", + "model_name": "superpoint_lightglue.pth", + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "disk-lightglue": { + "output": "matches-disk-lightglue", + "model": { + "name": "lightglue", + "match_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "disk", + "model_name": "disk_lightglue.pth", + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "sift-lightglue": { + "output": "matches-sift-lightglue", + "model": { + "name": "lightglue", + "match_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "sift", + "add_scale_ori": True, + "model_name": "sift_lightglue.pth", + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "sgmnet": { + "output": "matches-sgmnet", + "model": { + "name": "sgmnet", + "seed_top_k": [256, 256], + "seed_radius_coe": 0.01, + "net_channels": 128, + "layer_num": 9, + "head": 4, + "seedlayer": [0, 6], + "use_mc_seeding": True, + "use_score_encoding": False, + "conf_bar": [1.11, 0.1], + "sink_iter": [10, 100], + "detach_iter": 1000000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "NN-superpoint": { + "output": "matches-NN-mutual-dist.7", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "distance_threshold": 0.7, + "match_threshold": 0.2, + }, + }, + "NN-ratio": { + "output": "matches-NN-mutual-ratio.8", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "ratio_threshold": 0.8, + "match_threshold": 0.2, + }, + }, + "NN-mutual": { + "output": "matches-NN-mutual", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "match_threshold": 0.2, + }, + }, + "Dual-Softmax": { + "output": "matches-Dual-Softmax", + "model": { + "name": "dual_softmax", + "match_threshold": 0.01, + "inv_temperature": 20, + }, + }, + "adalam": { + "output": "matches-adalam", + "model": { + "name": "adalam", + "match_threshold": 0.2, + }, + }, + "imp": { + "output": "matches-imp", + "model": { + "name": "imp", + "match_threshold": 0.2, + }, + }, +} + + +class WorkQueue: + def __init__(self, work_fn, num_threads=1): + self.queue = Queue(num_threads) + self.threads = [ + Thread(target=self.thread_fn, args=(work_fn,)) + for _ in range(num_threads) + ] + for thread in self.threads: + thread.start() + + def join(self): + for thread in self.threads: + self.queue.put(None) + for thread in self.threads: + thread.join() + + def thread_fn(self, work_fn): + item = self.queue.get() + while item is not None: + work_fn(item) + item = self.queue.get() + + def put(self, data): + self.queue.put(data) + + +class FeaturePairsDataset(torch.utils.data.Dataset): + def __init__(self, pairs, feature_path_q, feature_path_r): + self.pairs = pairs + self.feature_path_q = feature_path_q + self.feature_path_r = feature_path_r + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + data = {} + with h5py.File(self.feature_path_q, "r") as fd: + grp = fd[name0] + for k, v in grp.items(): + data[k + "0"] = torch.from_numpy(v.__array__()).float() + # some matchers might expect an image but only use its size + data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + with h5py.File(self.feature_path_r, "r") as fd: + grp = fd[name1] + for k, v in grp.items(): + data[k + "1"] = torch.from_numpy(v.__array__()).float() + data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + return data + + def __len__(self): + return len(self.pairs) + + +def writer_fn(inp, match_path): + pair, pred = inp + with h5py.File(str(match_path), "a", libver="latest") as fd: + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + matches = pred["matches0"][0].cpu().short().numpy() + grp.create_dataset("matches0", data=matches) + if "matching_scores0" in pred: + scores = pred["matching_scores0"][0].cpu().half().numpy() + grp.create_dataset("matching_scores0", data=scores) + + +def main( + conf: Dict, + pairs: Path, + features: Union[Path, str], + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, + features_ref: Optional[Path] = None, + overwrite: bool = False, +) -> Path: + if isinstance(features, Path) or Path(features).exists(): + features_q = features + if matches is None: + raise ValueError( + "Either provide both features and matches as Path" + " or both as names." + ) + else: + if export_dir is None: + raise ValueError( + "Provide an export_dir if features is not" + f" a file path: {features}." + ) + features_q = Path(export_dir, features + ".h5") + if matches is None: + matches = Path( + export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5' + ) + + if features_ref is None: + features_ref = features_q + match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite) + + return matches + + +def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): + """Avoid to recompute duplicates to save time.""" + pairs = set() + for i, j in pairs_all: + if (j, i) not in pairs: + pairs.add((i, j)) + pairs = list(pairs) + if match_path is not None and match_path.exists(): + with h5py.File(str(match_path), "r", libver="latest") as fd: + pairs_filtered = [] + for i, j in pairs: + if ( + names_to_pair(i, j) in fd + or names_to_pair(j, i) in fd + or names_to_pair_old(i, j) in fd + or names_to_pair_old(j, i) in fd + ): + continue + pairs_filtered.append((i, j)) + return pairs_filtered + return pairs + + +@torch.no_grad() +def match_from_paths( + conf: Dict, + pairs_path: Path, + match_path: Path, + feature_path_q: Path, + feature_path_ref: Path, + overwrite: bool = False, +) -> Path: + logger.info( + "Matching local features with configuration:" + f"\n{pprint.pformat(conf)}" + ) + + if not feature_path_q.exists(): + raise FileNotFoundError(f"Query feature file {feature_path_q}.") + if not feature_path_ref.exists(): + raise FileNotFoundError(f"Reference feature file {feature_path_ref}.") + match_path.parent.mkdir(exist_ok=True, parents=True) + + assert pairs_path.exists(), pairs_path + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + if len(pairs) == 0: + logger.info("Skipping the matching.") + return + + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(matchers, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref) + loader = torch.utils.data.DataLoader( + dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True + ) + writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5) + + for idx, data in enumerate(tqdm(loader, smoothing=0.1)): + data = { + k: v if k.startswith("image") else v.to(device, non_blocking=True) + for k, v in data.items() + } + pred = model(data) + pair = names_to_pair(*pairs[idx]) + writer_queue.put((pair, pred)) + writer_queue.join() + logger.info("Finished exporting matches.") + + +def scale_keypoints(kpts, scale): + if np.any(scale != 1.0): + kpts *= kpts.new_tensor(scale) + return kpts + + +@torch.no_grad() +def match_images(model, feat0, feat1): + # forward pass to match keypoints + desc0 = feat0["descriptors"][0] + desc1 = feat1["descriptors"][0] + if len(desc0.shape) == 2: + desc0 = desc0.unsqueeze(0) + if len(desc1.shape) == 2: + desc1 = desc1.unsqueeze(0) + if isinstance(feat0["keypoints"], list): + feat0["keypoints"] = feat0["keypoints"][0][None] + if isinstance(feat1["keypoints"], list): + feat1["keypoints"] = feat1["keypoints"][0][None] + input_dict = { + "image0": feat0["image"], + "keypoints0": feat0["keypoints"], + "scores0": feat0["scores"][0].unsqueeze(0), + "descriptors0": desc0, + "image1": feat1["image"], + "keypoints1": feat1["keypoints"], + "scores1": feat1["scores"][0].unsqueeze(0), + "descriptors1": desc1, + } + if "scales" in feat0: + input_dict = {**input_dict, "scales0": feat0["scales"]} + if "scales" in feat1: + input_dict = {**input_dict, "scales1": feat1["scales"]} + if "oris" in feat0: + input_dict = {**input_dict, "oris0": feat0["oris"]} + if "oris" in feat1: + input_dict = {**input_dict, "oris1": feat1["oris"]} + pred = model(input_dict) + pred = { + k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v + for k, v in pred.items() + } + kpts0, kpts1 = ( + feat0["keypoints"][0].cpu().numpy(), + feat1["keypoints"][0].cpu().numpy(), + ) + matches, confid = pred["matches0"], pred["matching_scores0"] + # Keep the matching keypoints. + valid = matches > -1 + mkpts0 = kpts0[valid] + mkpts1 = kpts1[matches[valid]] + mconfid = confid[valid] + # rescale the keypoints to their original size + s0 = feat0["original_size"] / feat0["size"] + s1 = feat1["original_size"] / feat1["size"] + kpts0_origin = scale_keypoints(torch.from_numpy(kpts0 + 0.5), s0) - 0.5 + kpts1_origin = scale_keypoints(torch.from_numpy(kpts1 + 0.5), s1) - 0.5 + + mkpts0_origin = scale_keypoints(torch.from_numpy(mkpts0 + 0.5), s0) - 0.5 + mkpts1_origin = scale_keypoints(torch.from_numpy(mkpts1 + 0.5), s1) - 0.5 + + ret = { + "image0_orig": feat0["image_orig"], + "image1_orig": feat1["image_orig"], + "keypoints0": kpts0, + "keypoints1": kpts1, + "keypoints0_orig": kpts0_origin.numpy(), + "keypoints1_orig": kpts1_origin.numpy(), + "mkeypoints0": mkpts0, + "mkeypoints1": mkpts1, + "mkeypoints0_orig": mkpts0_origin.numpy(), + "mkeypoints1_orig": mkpts1_origin.numpy(), + "mconf": mconfid.numpy(), + } + del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin + torch.cuda.empty_cache() + + return ret + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--export_dir", type=Path) + parser.add_argument( + "--features", type=str, default="feats-superpoint-n4096-r1024" + ) + parser.add_argument("--matches", type=Path) + parser.add_argument( + "--conf", type=str, default="superglue", choices=list(confs.keys()) + ) + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir) diff --git a/hloc/matchers/__init__.py b/hloc/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fd381eb391604db3d3c3c03d278c0e08de2531 --- /dev/null +++ b/hloc/matchers/__init__.py @@ -0,0 +1,3 @@ +def get_matcher(matcher): + mod = __import__(f"{__name__}.{matcher}", fromlist=[""]) + return getattr(mod, "Model") diff --git a/hloc/matchers/adalam.py b/hloc/matchers/adalam.py new file mode 100644 index 0000000000000000000000000000000000000000..7820428a5a087d0b5d6855de15c0230327ce7dc1 --- /dev/null +++ b/hloc/matchers/adalam.py @@ -0,0 +1,68 @@ +import torch +from kornia.feature.adalam import AdalamFilter +from kornia.utils.helpers import get_cuda_device_if_available + +from ..utils.base_model import BaseModel + + +class AdaLAM(BaseModel): + # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html. + default_conf = { + "area_ratio": 100, + "search_expansion": 4, + "ransac_iters": 128, + "min_inliers": 6, + "min_confidence": 200, + "orientation_difference_threshold": 30, + "scale_rate_threshold": 1.5, + "detected_scale_rate_threshold": 5, + "refit": True, + "force_seed_mnn": True, + "device": get_cuda_device_if_available(), + } + required_inputs = [ + "image0", + "image1", + "descriptors0", + "descriptors1", + "keypoints0", + "keypoints1", + "scales0", + "scales1", + "oris0", + "oris1", + ] + + def _init(self, conf): + self.adalam = AdalamFilter(conf) + + def _forward(self, data): + assert data["keypoints0"].size(0) == 1 + if data["keypoints0"].size(1) < 2 or data["keypoints1"].size(1) < 2: + matches = torch.zeros( + (0, 2), dtype=torch.int64, device=data["keypoints0"].device + ) + else: + matches = self.adalam.match_and_filter( + data["keypoints0"][0], + data["keypoints1"][0], + data["descriptors0"][0].T, + data["descriptors1"][0].T, + data["image0"].shape[2:], + data["image1"].shape[2:], + data["oris0"][0], + data["oris1"][0], + data["scales0"][0], + data["scales1"][0], + ) + matches_new = torch.full( + (data["keypoints0"].size(1),), + -1, + dtype=torch.int64, + device=data["keypoints0"].device, + ) + matches_new[matches[:, 0]] = matches[:, 1] + return { + "matches0": matches_new.unsqueeze(0), + "matching_scores0": torch.zeros(matches_new.size(0)).unsqueeze(0), + } diff --git a/hloc/matchers/aspanformer.py b/hloc/matchers/aspanformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ce750747d6e73d911904d1566a757478c376aba9 --- /dev/null +++ b/hloc/matchers/aspanformer.py @@ -0,0 +1,112 @@ +import subprocess +import sys +from pathlib import Path + +import torch + +from hloc import logger +from hloc.utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer +from ASpanFormer.src.config.default import get_cfg_defaults +from ASpanFormer.src.utils.misc import lower_config + +aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer" + + +class ASpanFormer(BaseModel): + default_conf = { + "weights": "outdoor", + "match_threshold": 0.2, + "sinkhorn_iterations": 20, + "max_keypoints": 2048, + "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py", + "model_name": "weights_aspanformer.tar", + } + required_inputs = ["image0", "image1"] + proxy = "http://localhost:1080" + aspanformer_models = { + "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t" + } + + def _init(self, conf): + model_path = ( + aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt") + ) + # Download the model. + if not model_path.exists(): + # model_path.parent.mkdir(exist_ok=True) + tar_path = aspanformer_path / conf["model_name"] + if not tar_path.exists(): + link = self.aspanformer_models[conf["model_name"]] + cmd = [ + "gdown", + link, + "-O", + str(tar_path), + "--proxy", + self.proxy, + ] + cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)] + logger.info( + f"Downloading the Aspanformer model with `{cmd_wo_proxy}`." + ) + try: + subprocess.run(cmd_wo_proxy, check=True) + except subprocess.CalledProcessError as e: + logger.info(f"Downloading failed {e}.") + logger.info( + f"Downloading the Aspanformer model with `{cmd}`." + ) + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + logger.error( + f"Failed to download the Aspanformer model: {e}" + ) + + cmd = ["tar", "-xvf", str(tar_path), "-C", str(aspanformer_path)] + logger.info(f"Unzip model file `{cmd}`.") + subprocess.run(cmd, check=True) + + config = get_cfg_defaults() + config.merge_from_file(conf["config_path"]) + _config = lower_config(config) + + # update: match threshold + _config["aspan"]["match_coarse"]["thr"] = conf["match_threshold"] + _config["aspan"]["match_coarse"]["skh_iters"] = conf[ + "sinkhorn_iterations" + ] + + self.net = _ASpanFormer(config=_config["aspan"]) + weight_path = model_path + state_dict = torch.load(str(weight_path), map_location="cpu")[ + "state_dict" + ] + self.net.load_state_dict(state_dict, strict=False) + logger.info("Loaded Aspanformer model") + + def _forward(self, data): + data_ = { + "image0": data["image0"], + "image1": data["image1"], + } + self.net(data_, online_resize=True) + pred = { + "keypoints0": data_["mkpts0_f"], + "keypoints1": data_["mkpts1_f"], + "mconf": data_["mconf"], + } + scores = data_["mconf"] + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + scores = scores[keep] + pred["keypoints0"], pred["keypoints1"], pred["mconf"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + scores, + ) + return pred diff --git a/hloc/matchers/cotr.py b/hloc/matchers/cotr.py new file mode 100644 index 0000000000000000000000000000000000000000..44d74f642339133eea2da5beae15d1899e5920bc --- /dev/null +++ b/hloc/matchers/cotr.py @@ -0,0 +1,77 @@ +import argparse +import sys +from pathlib import Path + +import numpy as np +import torch +from torchvision.transforms import ToPILImage + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party/COTR")) +from COTR.inference.sparse_engine import SparseEngine +from COTR.models import build_model +from COTR.options.options import * # noqa: F403 +from COTR.options.options_utils import * # noqa: F403 +from COTR.utils import utils as utils_cotr + +utils_cotr.fix_randomness(0) +torch.set_grad_enabled(False) + +cotr_path = Path(__file__).parent / "../../third_party/COTR" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class COTR(BaseModel): + default_conf = { + "weights": "out/default", + "match_threshold": 0.2, + "max_keypoints": -1, + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) # noqa: F405 + opt = parser.parse_args() + opt.command = " ".join(sys.argv) + opt.load_weights_path = str( + cotr_path / conf["weights"] / "checkpoint.pth.tar" + ) + + layer_2_channels = { + "layer1": 256, + "layer2": 512, + "layer3": 1024, + "layer4": 2048, + } + opt.dim_feedforward = layer_2_channels[opt.layer] + + model = build_model(opt) + model = model.to(device) + weights = torch.load(opt.load_weights_path, map_location="cpu")[ + "model_state_dict" + ] + utils_cotr.safe_load_weights(model, weights) + self.net = model.eval() + self.to_pil_func = ToPILImage(mode="RGB") + + def _forward(self, data): + img_a = np.array(self.to_pil_func(data["image0"][0].cpu())) + img_b = np.array(self.to_pil_func(data["image1"][0].cpu())) + corrs = SparseEngine( + self.net, 32, mode="tile" + ).cotr_corr_multiscale_with_cycle_consistency( + img_a, + img_b, + np.linspace(0.5, 0.0625, 4), + 1, + max_corrs=self.conf["max_keypoints"], + queries_a=None, + ) + pred = { + "keypoints0": torch.from_numpy(corrs[:, :2]), + "keypoints1": torch.from_numpy(corrs[:, 2:]), + } + return pred diff --git a/hloc/matchers/dkm.py b/hloc/matchers/dkm.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d702e01500421f526d70a13e40d91d1f5f7096 --- /dev/null +++ b/hloc/matchers/dkm.py @@ -0,0 +1,70 @@ +import subprocess +import sys +from pathlib import Path + +import torch +from PIL import Image + +from .. import logger +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from DKM.dkm import DKMv3_outdoor + +dkm_path = Path(__file__).parent / "../../third_party/DKM" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class DKMv3(BaseModel): + default_conf = { + "model_name": "DKMv3_outdoor.pth", + "match_threshold": 0.2, + "checkpoint_dir": dkm_path / "pretrained", + "max_keypoints": -1, + } + required_inputs = [ + "image0", + "image1", + ] + # Models exported using + dkm_models = { + "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth", + "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth", + } + + def _init(self, conf): + model_path = dkm_path / "pretrained" / conf["model_name"] + + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + link = self.dkm_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_path)] + logger.info(f"Downloading the DKMv3 model with `{cmd}`.") + subprocess.run(cmd, check=True) + self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device) + logger.info("Loading DKMv3 model done") + + def _forward(self, data): + img0 = data["image0"].cpu().numpy().squeeze() * 255 + img1 = data["image1"].cpu().numpy().squeeze() * 255 + img0 = img0.transpose(1, 2, 0) + img1 = img1.transpose(1, 2, 0) + img0 = Image.fromarray(img0.astype("uint8")) + img1 = Image.fromarray(img1.astype("uint8")) + W_A, H_A = img0.size + W_B, H_B = img1.size + + warp, certainty = self.net.match(img0, img1, device=device) + matches, certainty = self.net.sample( + warp, certainty, num=self.conf["max_keypoints"] + ) + kpts1, kpts2 = self.net.to_pixel_coordinates( + matches, H_A, W_A, H_B, W_B + ) + pred = { + "keypoints0": kpts1, + "keypoints1": kpts2, + "mconf": certainty, + } + return pred diff --git a/hloc/matchers/dual_softmax.py b/hloc/matchers/dual_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..1c073ae66fdd064a27140e0cb566aa1d78ad2e6e --- /dev/null +++ b/hloc/matchers/dual_softmax.py @@ -0,0 +1,76 @@ +import numpy as np +import torch + +from ..utils.base_model import BaseModel + + +# borrow from dedode +def dual_softmax_matcher( + desc_A: tuple["B", "C", "N"], # noqa: F821 + desc_B: tuple["B", "C", "M"], # noqa: F821 + threshold=0.1, + inv_temperature=20, + normalize=True, +): + B, C, N = desc_A.shape + if len(desc_A.shape) < 3: + desc_A, desc_B = desc_A[None], desc_B[None] + if normalize: + desc_A = desc_A / desc_A.norm(dim=1, keepdim=True) + desc_B = desc_B / desc_B.norm(dim=1, keepdim=True) + sim = ( + torch.einsum("b c n, b c m -> b n m", desc_A, desc_B) * inv_temperature + ) + P = sim.softmax(dim=-2) * sim.softmax(dim=-1) + mask = torch.nonzero( + (P == P.max(dim=-1, keepdim=True).values) + * (P == P.max(dim=-2, keepdim=True).values) + * (P > threshold) + ) + mask = mask.cpu().numpy() + matches0 = np.ones((B, P.shape[-2]), dtype=int) * (-1) + scores0 = np.zeros((B, P.shape[-2]), dtype=float) + matches0[:, mask[:, 1]] = mask[:, 2] + tmp_P = P.cpu().numpy() + scores0[:, mask[:, 1]] = tmp_P[mask[:, 0], mask[:, 1], mask[:, 2]] + matches0 = torch.from_numpy(matches0).to(P.device) + scores0 = torch.from_numpy(scores0).to(P.device) + return matches0, scores0 + + +class DualSoftMax(BaseModel): + default_conf = { + "match_threshold": 0.2, + "inv_temperature": 20, + } + # shape: B x DIM x M + required_inputs = ["descriptors0", "descriptors1"] + + def _init(self, conf): + pass + + def _forward(self, data): + if ( + data["descriptors0"].size(-1) == 0 + or data["descriptors1"].size(-1) == 0 + ): + matches0 = torch.full( + data["descriptors0"].shape[:2], + -1, + device=data["descriptors0"].device, + ) + return { + "matches0": matches0, + "matching_scores0": torch.zeros_like(matches0), + } + + matches0, scores0 = dual_softmax_matcher( + data["descriptors0"], + data["descriptors1"], + threshold=self.conf["match_threshold"], + inv_temperature=self.conf["inv_temperature"], + ) + return { + "matches0": matches0, # 1 x M + "matching_scores0": scores0, + } diff --git a/hloc/matchers/duster.py b/hloc/matchers/duster.py new file mode 100644 index 0000000000000000000000000000000000000000..2243d8aad04ee52df61044a5a945ec55414860f9 --- /dev/null +++ b/hloc/matchers/duster.py @@ -0,0 +1,129 @@ +import os +import sys +import urllib.request +from pathlib import Path + +import numpy as np +import torch +import torchvision.transforms as tfm + +from .. import logger +from ..utils.base_model import BaseModel + +duster_path = Path(__file__).parent / "../../third_party/dust3r" +sys.path.append(str(duster_path)) + +from dust3r.cloud_opt import GlobalAlignerMode, global_aligner +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.utils.geometry import find_reciprocal_matches, xy_grid + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Duster(BaseModel): + default_conf = { + "name": "Duster3r", + "model_path": duster_path / "model_weights/duster_vit_large.pth", + "max_keypoints": 3000, + "vit_patch_size": 16, + } + + def _init(self, conf): + self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + self.model_path = self.conf["model_path"] + self.download_weights() + # self.net = load_model(self.model_path, device) + self.net = AsymmetricCroCo3DStereo.from_pretrained( + self.model_path + # "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt" + ).to(device) + logger.info("Loaded Dust3r model") + + def download_weights(self): + url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth" + + self.model_path.parent.mkdir(parents=True, exist_ok=True) + if not os.path.isfile(self.model_path): + logger.info("Downloading Duster(ViT large)... (takes a while)") + urllib.request.urlretrieve(url, self.model_path) + + def preprocess(self, img): + # the super-class already makes sure that img0,img1 have + # same resolution and that h == w + _, h, _ = img.shape + imsize = h + if not ((h % self.vit_patch_size) == 0): + imsize = int( + self.vit_patch_size * round(h / self.vit_patch_size, 0) + ) + img = tfm.functional.resize(img, imsize, antialias=True) + + _, new_h, new_w = img.shape + if not ((new_w % self.vit_patch_size) == 0): + safe_w = int( + self.vit_patch_size * round(new_w / self.vit_patch_size, 0) + ) + img = tfm.functional.resize(img, (new_h, safe_w), antialias=True) + + img = self.normalize(img).unsqueeze(0) + + return img + + def _forward(self, data): + img0, img1 = data["image0"], data["image1"] + mean = torch.tensor([0.5, 0.5, 0.5]).to(device) + std = torch.tensor([0.5, 0.5, 0.5]).to(device) + + img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + + images = [ + {"img": img0, "idx": 0, "instance": 0}, + {"img": img1, "idx": 1, "instance": 1}, + ] + pairs = make_pairs( + images, scene_graph="complete", prefilter=None, symmetrize=True + ) + output = inference(pairs, self.net, device, batch_size=1) + scene = global_aligner( + output, device=device, mode=GlobalAlignerMode.PairViewer + ) + # retrieve useful values from scene: + imgs = scene.imgs + confidence_masks = scene.get_masks() + pts3d = scene.get_pts3d() + pts2d_list, pts3d_list = [], [] + for i in range(2): + conf_i = confidence_masks[i].cpu().numpy() + pts2d_list.append( + xy_grid(*imgs[i].shape[:2][::-1])[conf_i] + ) # imgs[i].shape[:2] = (H, W) + pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i]) + + if len(pts3d_list[1]) == 0: + pred = { + "keypoints0": torch.zeros([0, 2]), + "keypoints1": torch.zeros([0, 2]), + } + logger.warning(f"Matched {0} points") + else: + reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches( + *pts3d_list + ) + logger.info(f"Found {num_matches} matches") + mkpts1 = pts2d_list[1][reciprocal_in_P2] + mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2] + top_k = self.conf["max_keypoints"] + if top_k is not None and len(mkpts0) > top_k: + keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype( + int + ) + mkpts0 = mkpts0[keep] + mkpts1 = mkpts1[keep] + pred = { + "keypoints0": torch.from_numpy(mkpts0), + "keypoints1": torch.from_numpy(mkpts1), + } + return pred diff --git a/hloc/matchers/eloftr.py b/hloc/matchers/eloftr.py new file mode 100644 index 0000000000000000000000000000000000000000..d22906de8bf7cc912745c21b950458829dee5d19 --- /dev/null +++ b/hloc/matchers/eloftr.py @@ -0,0 +1,92 @@ +import sys +import warnings +from copy import deepcopy +from pathlib import Path + +import torch + +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) + +from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_ +from EfficientLoFTR.src.loftr import ( + full_default_cfg, + opt_default_cfg, + reparameter, +) + +from hloc import logger + +from ..utils.base_model import BaseModel + + +class ELoFTR(BaseModel): + default_conf = { + "weights": "weights/eloftr_outdoor.ckpt", + "match_threshold": 0.2, + # "sinkhorn_iterations": 20, + "max_keypoints": -1, + # You can choose model type in ['full', 'opt'] + "model_type": "full", # 'full' for best quality, 'opt' for best efficiency + # You can choose numerical precision in ['fp32', 'mp', 'fp16']. 'fp16' for best efficiency + "precision": "fp32", + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + + if self.conf["model_type"] == "full": + _default_cfg = deepcopy(full_default_cfg) + elif self.conf["model_type"] == "opt": + _default_cfg = deepcopy(opt_default_cfg) + + if self.conf["precision"] == "mp": + _default_cfg["mp"] = True + elif self.conf["precision"] == "fp16": + _default_cfg["half"] = True + model_path = tp_path / "EfficientLoFTR" / self.conf["weights"] + cfg = _default_cfg + cfg["match_coarse"]["thr"] = conf["match_threshold"] + # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"] + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + matcher = ELoFTR_(config=cfg) + matcher.load_state_dict(state_dict) + self.net = reparameter(matcher) + + if self.conf["precision"] == "fp16": + self.net = self.net.half() + logger.info(f"Loaded Efficient LoFTR with weights {conf['weights']}") + + def _forward(self, data): + # For consistency with hloc pairs, we refine kpts in image0! + rename = { + "keypoints0": "keypoints1", + "keypoints1": "keypoints0", + "image0": "image1", + "image1": "image0", + "mask0": "mask1", + "mask1": "mask0", + } + data_ = {rename[k]: v for k, v in data.items()} + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pred = self.net(data_) + pred = { + "keypoints0": data_["mkpts0_f"], + "keypoints1": data_["mkpts1_f"], + } + scores = data_["mconf"] + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + pred["keypoints0"], pred["keypoints1"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + ) + scores = scores[keep] + + # Switch back indices + pred = {(rename[k] if k in rename else k): v for k, v in pred.items()} + pred["scores"] = scores + return pred diff --git a/hloc/matchers/gim.py b/hloc/matchers/gim.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ccffa145314570a9419af3fd021740916d177f --- /dev/null +++ b/hloc/matchers/gim.py @@ -0,0 +1,143 @@ +import subprocess +import sys +from pathlib import Path + +import gdown +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +gim_path = Path(__file__).parent / "../../third_party/gim" +sys.path.append(str(gim_path)) + +from dkm.models.model_zoo.DKMv3 import DKMv3 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class GIM(BaseModel): + default_conf = { + "model_name": "gim_dkm_100h.ckpt", + "match_threshold": 0.2, + "checkpoint_dir": gim_path / "weights", + } + required_inputs = [ + "image0", + "image1", + ] + model_dict = { + "gim_lightglue_100h.ckpt": "https://github.com/xuelunshen/gim/blob/main/weights/gim_lightglue_100h.ckpt", + "gim_dkm_100h.ckpt": "https://drive.google.com/file/d/1gk97V4IROnR1Nprq10W9NCFUv2mxXR_-/view", + } + + def _init(self, conf): + conf["model_name"] = str(conf["weights"]) + if conf["model_name"] not in self.model_dict: + raise ValueError(f"Unknown GIM model {conf['model_name']}.") + model_path = conf["checkpoint_dir"] / conf["model_name"] + + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + model_link = self.model_dict[conf["model_name"]] + if "drive.google.com" in model_link: + gdown.download(model_link, output=str(model_path), fuzzy=True) + else: + cmd = ["wget", "--quiet", model_link, "-O", str(model_path)] + subprocess.run(cmd, check=True) + logger.info("Downloaded GIM model succeeed!") + + self.aspect_ratio = 896 / 672 + model = DKMv3(None, 672, 896, upsample_preds=True) + state_dict = torch.load(str(model_path), map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("model."): + state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) + if "encoder.net.fc" in k: + state_dict.pop(k) + model.load_state_dict(state_dict) + + self.net = model + logger.info("Loaded GIM model") + + def pad_image(self, image, aspect_ratio): + new_width = max(image.shape[3], int(image.shape[2] * aspect_ratio)) + new_height = max(image.shape[2], int(image.shape[3] / aspect_ratio)) + pad_width = new_width - image.shape[3] + pad_height = new_height - image.shape[2] + return torch.nn.functional.pad( + image, + ( + pad_width // 2, + pad_width - pad_width // 2, + pad_height // 2, + pad_height - pad_height // 2, + ), + ) + + def rescale_kpts(self, sparse_matches, shape0, shape1): + kpts0 = torch.stack( + ( + shape0[1] * (sparse_matches[:, 0] + 1) / 2, + shape0[0] * (sparse_matches[:, 1] + 1) / 2, + ), + dim=-1, + ) + kpts1 = torch.stack( + ( + shape1[1] * (sparse_matches[:, 2] + 1) / 2, + shape1[0] * (sparse_matches[:, 3] + 1) / 2, + ), + dim=-1, + ) + return kpts0, kpts1 + + def compute_mask(self, kpts0, kpts1, orig_shape0, orig_shape1): + mask = ( + (kpts0[:, 0] > 0) + & (kpts0[:, 1] > 0) + & (kpts1[:, 0] > 0) + & (kpts1[:, 1] > 0) + ) + mask &= ( + (kpts0[:, 0] <= (orig_shape0[1] - 1)) + & (kpts1[:, 0] <= (orig_shape1[1] - 1)) + & (kpts0[:, 1] <= (orig_shape0[0] - 1)) + & (kpts1[:, 1] <= (orig_shape1[0] - 1)) + ) + return mask + + def _forward(self, data): + image0, image1 = self.pad_image( + data["image0"], self.aspect_ratio + ), self.pad_image(data["image1"], self.aspect_ratio) + dense_matches, dense_certainty = self.net.match(image0, image1) + sparse_matches, mconf = self.net.sample( + dense_matches, dense_certainty, self.conf["max_keypoints"] + ) + kpts0, kpts1 = self.rescale_kpts( + sparse_matches, image0.shape[-2:], image1.shape[-2:] + ) + mask = self.compute_mask( + kpts0, kpts1, data["image0"].shape[-2:], data["image1"].shape[-2:] + ) + b_ids, i_ids = torch.where(mconf[None]) + pred = { + "keypoints0": kpts0[i_ids], + "keypoints1": kpts1[i_ids], + "confidence": mconf[i_ids], + "batch_indexes": b_ids, + } + scores, b_ids = pred["confidence"], pred["batch_indexes"] + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + pred["confidence"], pred["batch_indexes"] = scores[mask], b_ids[mask] + pred["keypoints0"], pred["keypoints1"] = kpts0[mask], kpts1[mask] + + out = { + "keypoints0": pred["keypoints0"], + "keypoints1": pred["keypoints1"], + } + return out diff --git a/hloc/matchers/gluestick.py b/hloc/matchers/gluestick.py new file mode 100644 index 0000000000000000000000000000000000000000..b14614e23f58fd9d1bcb9a39d73a18d5d12ee6df --- /dev/null +++ b/hloc/matchers/gluestick.py @@ -0,0 +1,111 @@ +import subprocess +import sys +from pathlib import Path + +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +gluestick_path = Path(__file__).parent / "../../third_party/GlueStick" +sys.path.append(str(gluestick_path)) + +from gluestick import batch_to_np +from gluestick.models.two_view_pipeline import TwoViewPipeline + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class GlueStick(BaseModel): + default_conf = { + "name": "two_view_pipeline", + "model_name": "checkpoint_GlueStick_MD.tar", + "use_lines": True, + "max_keypoints": 1000, + "max_lines": 300, + "force_num_keypoints": False, + } + required_inputs = [ + "image0", + "image1", + ] + + gluestick_models = { + "checkpoint_GlueStick_MD.tar": "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar", + } + + # Initialize the line matcher + def _init(self, conf): + model_path = ( + gluestick_path / "resources" / "weights" / conf["model_name"] + ) + + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + link = self.gluestick_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_path)] + logger.info(f"Downloading the Gluestick model with `{cmd}`.") + subprocess.run(cmd, check=True) + logger.info("Loading GlueStick model...") + + gluestick_conf = { + "name": "two_view_pipeline", + "use_lines": True, + "extractor": { + "name": "wireframe", + "sp_params": { + "force_num_keypoints": False, + "max_num_keypoints": 1000, + }, + "wireframe_params": { + "merge_points": True, + "merge_line_endpoints": True, + }, + "max_n_lines": 300, + }, + "matcher": { + "name": "gluestick", + "weights": str(model_path), + "trainable": False, + }, + "ground_truth": { + "from_pose_depth": False, + }, + } + gluestick_conf["extractor"]["sp_params"]["max_num_keypoints"] = conf[ + "max_keypoints" + ] + gluestick_conf["extractor"]["sp_params"]["force_num_keypoints"] = conf[ + "force_num_keypoints" + ] + gluestick_conf["extractor"]["max_n_lines"] = conf["max_lines"] + self.net = TwoViewPipeline(gluestick_conf) + + def _forward(self, data): + pred = self.net(data) + + pred = batch_to_np(pred) + kp0, kp1 = pred["keypoints0"], pred["keypoints1"] + m0 = pred["matches0"] + + line_seg0, line_seg1 = pred["lines0"], pred["lines1"] + line_matches = pred["line_matches0"] + + valid_matches = m0 != -1 + match_indices = m0[valid_matches] + matched_kps0 = kp0[valid_matches] + matched_kps1 = kp1[match_indices] + + valid_matches = line_matches != -1 + match_indices = line_matches[valid_matches] + matched_lines0 = line_seg0[valid_matches] + matched_lines1 = line_seg1[match_indices] + + pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1 + pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1 + pred["keypoints0"], pred["keypoints1"] = torch.from_numpy( + matched_kps0 + ), torch.from_numpy(matched_kps1) + pred = {**pred, **data} + return pred diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py new file mode 100644 index 0000000000000000000000000000000000000000..05c3cb96b05410985ca97f89d8fe55a4d71be501 --- /dev/null +++ b/hloc/matchers/imp.py @@ -0,0 +1,46 @@ +import sys +from pathlib import Path + +import torch + +from .. import DEVICE, logger +from ..utils.base_model import BaseModel + +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.gml import GML + + +class IMP(BaseModel): + default_conf = { + "match_threshold": 0.2, + "features": "sfd2", + "model_name": "imp_gml.920.pth", + "sinkhorn_iterations": 20, + } + required_inputs = [ + "image0", + "keypoints0", + "scores0", + "descriptors0", + "image1", + "keypoints1", + "scores1", + "descriptors1", + ] + + def _init(self, conf): + self.conf = {**self.default_conf, **conf} + weight_path = tp_path / "pram" / "weights" / self.conf["model_name"] + # self.net = nets.gml(self.conf).eval().to(DEVICE) + self.net = GML(self.conf).eval().to(DEVICE) + self.net.load_state_dict( + torch.load(weight_path, map_location="cpu")["model"], strict=True + ) + logger.info("Load IMP model done.") + + def _forward(self, data): + data["descriptors0"] = data["descriptors0"].transpose(2, 1).float() + data["descriptors1"] = data["descriptors1"].transpose(2, 1).float() + + return self.net.produce_matches(data, p=0.2) diff --git a/hloc/matchers/lightglue.py b/hloc/matchers/lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..4a36be64b4e4dbe95d45bb1f52c869fe067de58f --- /dev/null +++ b/hloc/matchers/lightglue.py @@ -0,0 +1,63 @@ +import sys +from pathlib import Path + +from .. import logger +from ..utils.base_model import BaseModel + +lightglue_path = Path(__file__).parent / "../../third_party/LightGlue" +sys.path.append(str(lightglue_path)) +from lightglue import LightGlue as LG + + +class LightGlue(BaseModel): + default_conf = { + "match_threshold": 0.2, + "filter_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "superpoint", + "model_name": "superpoint_lightglue.pth", + "flash": True, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "add_scale_ori": False, + } + required_inputs = [ + "image0", + "keypoints0", + "scores0", + "descriptors0", + "image1", + "keypoints1", + "scores1", + "descriptors1", + ] + + def _init(self, conf): + weight_path = lightglue_path / "weights" / conf["model_name"] + conf["weights"] = str(weight_path) + conf["filter_threshold"] = conf["match_threshold"] + self.net = LG(**conf) + logger.info("Load lightglue model done.") + + def _forward(self, data): + input = {} + input["image0"] = { + "image": data["image0"], + "keypoints": data["keypoints0"], + "descriptors": data["descriptors0"].permute(0, 2, 1), + } + if "scales0" in data: + input["image0"] = {**input["image0"], "scales": data["scales0"]} + if "oris0" in data: + input["image0"] = {**input["image0"], "oris": data["oris0"]} + + input["image1"] = { + "image": data["image1"], + "keypoints": data["keypoints1"], + "descriptors": data["descriptors1"].permute(0, 2, 1), + } + if "scales1" in data: + input["image1"] = {**input["image1"], "scales": data["scales1"]} + if "oris1" in data: + input["image1"] = {**input["image1"], "oris": data["oris1"]} + return self.net(input) diff --git a/hloc/matchers/loftr.py b/hloc/matchers/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..a1405b7073a80ab946ec8d724642a8f8ab9de9ba --- /dev/null +++ b/hloc/matchers/loftr.py @@ -0,0 +1,58 @@ +import warnings + +import torch +from kornia.feature import LoFTR as LoFTR_ +from kornia.feature.loftr.loftr import default_cfg + +from hloc import logger + +from ..utils.base_model import BaseModel + + +class LoFTR(BaseModel): + default_conf = { + "weights": "outdoor", + "match_threshold": 0.2, + "sinkhorn_iterations": 20, + "max_keypoints": -1, + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + cfg = default_cfg + cfg["match_coarse"]["thr"] = conf["match_threshold"] + cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"] + self.net = LoFTR_(pretrained=conf["weights"], config=cfg) + logger.info(f"Loaded LoFTR with weights {conf['weights']}") + + def _forward(self, data): + # For consistency with hloc pairs, we refine kpts in image0! + rename = { + "keypoints0": "keypoints1", + "keypoints1": "keypoints0", + "image0": "image1", + "image1": "image0", + "mask0": "mask1", + "mask1": "mask0", + } + data_ = {rename[k]: v for k, v in data.items()} + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pred = self.net(data_) + + scores = pred["confidence"] + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + pred["keypoints0"], pred["keypoints1"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + ) + scores = scores[keep] + + # Switch back indices + pred = {(rename[k] if k in rename else k): v for k, v in pred.items()} + pred["scores"] = scores + del pred["confidence"] + return pred diff --git a/hloc/matchers/mast3r.py b/hloc/matchers/mast3r.py new file mode 100644 index 0000000000000000000000000000000000000000..46489cc278f90df6e1039ecca1361c4461afb574 --- /dev/null +++ b/hloc/matchers/mast3r.py @@ -0,0 +1,111 @@ +import os +import sys +import urllib.request +from pathlib import Path + +import numpy as np +import torch +import torchvision.transforms as tfm + +from .. import logger + +mast3r_path = Path(__file__).parent / "../../third_party/mast3r" +sys.path.append(str(mast3r_path)) + +dust3r_path = Path(__file__).parent / "../../third_party/dust3r" +sys.path.append(str(dust3r_path)) + +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from mast3r.fast_nn import fast_reciprocal_NNs +from mast3r.model import AsymmetricMASt3R + +from hloc.matchers.duster import Duster + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Mast3r(Duster): + default_conf = { + "name": "Mast3r", + "model_path": mast3r_path + / "model_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth", + "max_keypoints": 2000, + "vit_patch_size": 16, + } + + def _init(self, conf): + self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + self.model_path = self.conf["model_path"] + self.download_weights() + self.net = AsymmetricMASt3R.from_pretrained(self.model_path).to(device) + logger.info("Loaded Mast3r model") + + def download_weights(self): + url = "https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" + + self.model_path.parent.mkdir(parents=True, exist_ok=True) + if not os.path.isfile(self.model_path): + logger.info("Downloading Mast3r(ViT large)... (takes a while)") + urllib.request.urlretrieve(url, self.model_path) + logger.info("Downloading Mast3r(ViT large)... done!") + + def _forward(self, data): + img0, img1 = data["image0"], data["image1"] + mean = torch.tensor([0.5, 0.5, 0.5]).to(device) + std = torch.tensor([0.5, 0.5, 0.5]).to(device) + + img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + + images = [ + {"img": img0, "idx": 0, "instance": 0}, + {"img": img1, "idx": 1, "instance": 1}, + ] + pairs = make_pairs( + images, scene_graph="complete", prefilter=None, symmetrize=True + ) + output = inference(pairs, self.net, device, batch_size=1) + + # at this stage, you have the raw dust3r predictions + _, pred1 = output["view1"], output["pred1"] + _, pred2 = output["view2"], output["pred2"] + + desc1, desc2 = ( + pred1["desc"][1].squeeze(0).detach(), + pred2["desc"][1].squeeze(0).detach(), + ) + + # find 2D-2D matches between the two images + matches_im0, matches_im1 = fast_reciprocal_NNs( + desc1, + desc2, + subsample_or_initxy1=2, + device=device, + dist="dot", + block_size=2**13, + ) + + mkpts0 = matches_im0.copy() + mkpts1 = matches_im1.copy() + + if len(mkpts0) == 0: + pred = { + "keypoints0": torch.zeros([0, 2]), + "keypoints1": torch.zeros([0, 2]), + } + logger.warning(f"Matched {0} points") + else: + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(mkpts0) > top_k: + keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype( + int + ) + mkpts0 = mkpts0[keep] + mkpts1 = mkpts1[keep] + pred = { + "keypoints0": torch.from_numpy(mkpts0), + "keypoints1": torch.from_numpy(mkpts1), + } + return pred diff --git a/hloc/matchers/mickey.py b/hloc/matchers/mickey.py new file mode 100644 index 0000000000000000000000000000000000000000..3d60ff5f229ba31a0922406fe54c0588fd6a4273 --- /dev/null +++ b/hloc/matchers/mickey.py @@ -0,0 +1,67 @@ +import subprocess +import sys +from pathlib import Path + +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +mickey_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(mickey_path)) + +from mickey.config.default import cfg +from mickey.lib.models.builder import build_model + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Mickey(BaseModel): + default_conf = { + "config_path": "config.yaml", + "model_name": "mickey.ckpt", + "max_keypoints": 3000, + } + required_inputs = [ + "image0", + "image1", + ] + weight_urls = "https://storage.googleapis.com/niantic-lon-static/research/mickey/assets/mickey_weights.zip" + + # Initialize the line matcher + def _init(self, conf): + model_path = mickey_path / "mickey/mickey_weights" / conf["model_name"] + zip_path = mickey_path / "mickey/mickey_weights.zip" + config_path = model_path.parent / self.conf["config_path"] + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True, parents=True) + link = self.weight_urls + if not zip_path.exists(): + cmd = ["wget", "--quiet", link, "-O", str(zip_path)] + logger.info(f"Downloading the Mickey model with {cmd}.") + subprocess.run(cmd, check=True) + cmd = ["unzip", "-d", str(model_path.parent.parent), str(zip_path)] + logger.info(f"Running {cmd}.") + subprocess.run(cmd, check=True) + + logger.info("Loading mickey model...") + cfg.merge_from_file(config_path) + self.net = build_model(cfg, checkpoint=model_path) + logger.info("Load Mickey model done.") + + def _forward(self, data): + # data['K_color0'] = torch.from_numpy(K['im0.jpg']).unsqueeze(0).to(device) + # data['K_color1'] = torch.from_numpy(K['im1.jpg']).unsqueeze(0).to(device) + pred = self.net(data) + pred = { + **pred, + **data, + } + inliers = data["inliers_list"] + pred = { + "keypoints0": inliers[:, :2], + "keypoints1": inliers[:, 2:4], + } + + return pred diff --git a/hloc/matchers/nearest_neighbor.py b/hloc/matchers/nearest_neighbor.py new file mode 100644 index 0000000000000000000000000000000000000000..1d42d6b6cf48399f23d22a6f6949ef3d16e9c4e7 --- /dev/null +++ b/hloc/matchers/nearest_neighbor.py @@ -0,0 +1,74 @@ +import torch + +from ..utils.base_model import BaseModel + + +def find_nn(sim, ratio_thresh, distance_thresh): + sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True) + dist_nn = 2 * (1 - sim_nn) + mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device) + if ratio_thresh: + mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1]) + if distance_thresh: + mask = mask & (dist_nn[..., 0] <= distance_thresh**2) + matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1)) + scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0)) + return matches, scores + + +def mutual_check(m0, m1): + inds0 = torch.arange(m0.shape[-1], device=m0.device) + loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0))) + ok = (m0 > -1) & (inds0 == loop) + m0_new = torch.where(ok, m0, m0.new_tensor(-1)) + return m0_new + + +class NearestNeighbor(BaseModel): + default_conf = { + "ratio_threshold": None, + "distance_threshold": None, + "do_mutual_check": True, + } + required_inputs = ["descriptors0", "descriptors1"] + + def _init(self, conf): + pass + + def _forward(self, data): + if ( + data["descriptors0"].size(-1) == 0 + or data["descriptors1"].size(-1) == 0 + ): + matches0 = torch.full( + data["descriptors0"].shape[:2], + -1, + device=data["descriptors0"].device, + ) + return { + "matches0": matches0, + "matching_scores0": torch.zeros_like(matches0), + } + ratio_threshold = self.conf["ratio_threshold"] + if ( + data["descriptors0"].size(-1) == 1 + or data["descriptors1"].size(-1) == 1 + ): + ratio_threshold = None + sim = torch.einsum( + "bdn,bdm->bnm", data["descriptors0"], data["descriptors1"] + ) + matches0, scores0 = find_nn( + sim, ratio_threshold, self.conf["distance_threshold"] + ) + if self.conf["do_mutual_check"]: + matches1, scores1 = find_nn( + sim.transpose(1, 2), + ratio_threshold, + self.conf["distance_threshold"], + ) + matches0 = mutual_check(matches0, matches1) + return { + "matches0": matches0, + "matching_scores0": scores0, + } diff --git a/hloc/matchers/omniglue.py b/hloc/matchers/omniglue.py new file mode 100644 index 0000000000000000000000000000000000000000..c02d7f35f10706565d109987b12daa166703113e --- /dev/null +++ b/hloc/matchers/omniglue.py @@ -0,0 +1,81 @@ +import subprocess +import sys +from pathlib import Path + +import numpy as np +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +thirdparty_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(thirdparty_path)) +from omniglue.src import omniglue + +omniglue_path = thirdparty_path / "omniglue" + + +class OmniGlue(BaseModel): + default_conf = { + "match_threshold": 0.02, + "max_keypoints": 2048, + } + required_inputs = ["image0", "image1"] + dino_v2_link_dict = { + "dinov2_vitb14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" + } + + def _init(self, conf): + logger.info("Loading OmniGlue model") + og_model_path = omniglue_path / "models" / "omniglue.onnx" + sp_model_path = omniglue_path / "models" / "sp_v6.onnx" + dino_model_path = ( + omniglue_path / "models" / "dinov2_vitb14_pretrain.pth" # ~330MB + ) + if not dino_model_path.exists(): + link = self.dino_v2_link_dict.get(dino_model_path.name, None) + if link is not None: + cmd = ["wget", "--quiet", link, "-O", str(dino_model_path)] + logger.info(f"Downloading the dinov2 model with `{cmd}`.") + subprocess.run(cmd, check=True) + else: + logger.error(f"Invalid dinov2 model: {dino_model_path.name}") + self.net = omniglue.OmniGlue( + og_export=str(og_model_path), + sp_export=str(sp_model_path), + dino_export=str(dino_model_path), + max_keypoints=self.conf["max_keypoints"], + ) + logger.info("Loaded OmniGlue model done!") + + def _forward(self, data): + image0_rgb_np = data["image0"][0].permute(1, 2, 0).cpu().numpy() * 255 + image1_rgb_np = data["image1"][0].permute(1, 2, 0).cpu().numpy() * 255 + image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255 + image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255 + match_kp0, match_kp1, match_confidences = self.net.FindMatches( + image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"] + ) + # filter matches + match_threshold = self.conf["match_threshold"] + keep_idx = [] + for i in range(match_kp0.shape[0]): + if match_confidences[i] > match_threshold: + keep_idx.append(i) + scores = torch.from_numpy(match_confidences[keep_idx]).reshape(-1, 1) + pred = { + "keypoints0": torch.from_numpy(match_kp0[keep_idx]), + "keypoints1": torch.from_numpy(match_kp1[keep_idx]), + "mconf": scores, + } + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + scores = scores[keep] + pred["keypoints0"], pred["keypoints1"], pred["mconf"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + scores, + ) + return pred diff --git a/hloc/matchers/roma.py b/hloc/matchers/roma.py new file mode 100644 index 0000000000000000000000000000000000000000..01949160f98478c6c1620f60ddfa83cf555490f8 --- /dev/null +++ b/hloc/matchers/roma.py @@ -0,0 +1,98 @@ +import subprocess +import sys +from pathlib import Path + +import torch +from PIL import Image + +from .. import logger +from ..utils.base_model import BaseModel + +roma_path = Path(__file__).parent / "../../third_party/RoMa" +sys.path.append(str(roma_path)) +from romatch.models.model_zoo import roma_model + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Roma(BaseModel): + default_conf = { + "name": "two_view_pipeline", + "model_name": "roma_outdoor.pth", + "model_utils_name": "dinov2_vitl14_pretrain.pth", + "max_keypoints": 3000, + } + required_inputs = [ + "image0", + "image1", + ] + weight_urls = { + "roma": { + "roma_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", + "roma_indoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", + }, + "dinov2_vitl14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", + } + + # Initialize the line matcher + def _init(self, conf): + model_path = roma_path / "pretrained" / conf["model_name"] + dinov2_weights = roma_path / "pretrained" / conf["model_utils_name"] + + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + link = self.weight_urls["roma"][conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_path)] + logger.info(f"Downloading the Roma model with `{cmd}`.") + subprocess.run(cmd, check=True) + + if not dinov2_weights.exists(): + dinov2_weights.parent.mkdir(exist_ok=True) + link = self.weight_urls[conf["model_utils_name"]] + cmd = ["wget", "--quiet", link, "-O", str(dinov2_weights)] + logger.info(f"Downloading the dinov2 model with `{cmd}`.") + subprocess.run(cmd, check=True) + + logger.info("Loading Roma model") + # load the model + weights = torch.load(model_path, map_location="cpu") + dinov2_weights = torch.load(dinov2_weights, map_location="cpu") + + self.net = roma_model( + resolution=(14 * 8 * 6, 14 * 8 * 6), + upsample_preds=False, + weights=weights, + dinov2_weights=dinov2_weights, + device=device, + # temp fix issue: https://github.com/Parskatt/RoMa/issues/26 + amp_dtype=torch.float32, + ) + logger.info("Load Roma model done.") + + def _forward(self, data): + img0 = data["image0"].cpu().numpy().squeeze() * 255 + img1 = data["image1"].cpu().numpy().squeeze() * 255 + img0 = img0.transpose(1, 2, 0) + img1 = img1.transpose(1, 2, 0) + img0 = Image.fromarray(img0.astype("uint8")) + img1 = Image.fromarray(img1.astype("uint8")) + W_A, H_A = img0.size + W_B, H_B = img1.size + + # Match + warp, certainty = self.net.match(img0, img1, device=device) + # Sample matches for estimation + matches, certainty = self.net.sample( + warp, certainty, num=self.conf["max_keypoints"] + ) + kpts1, kpts2 = self.net.to_pixel_coordinates( + matches, H_A, W_A, H_B, W_B + ) + pred = { + "keypoints0": kpts1, + "keypoints1": kpts2, + "mconf": certainty, + } + + return pred diff --git a/hloc/matchers/sgmnet.py b/hloc/matchers/sgmnet.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2f72a007aff422ff6fd71ee12aec1520f01f33 --- /dev/null +++ b/hloc/matchers/sgmnet.py @@ -0,0 +1,144 @@ +import subprocess +import sys +from collections import OrderedDict, namedtuple +from pathlib import Path + +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet" +sys.path.append(str(sgmnet_path)) + +from sgmnet import matcher as SGM_Model + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class SGMNet(BaseModel): + default_conf = { + "name": "SGM", + "model_name": "model_best.pth", + "seed_top_k": [256, 256], + "seed_radius_coe": 0.01, + "net_channels": 128, + "layer_num": 9, + "head": 4, + "seedlayer": [0, 6], + "use_mc_seeding": True, + "use_score_encoding": False, + "conf_bar": [1.11, 0.1], + "sink_iter": [10, 100], + "detach_iter": 1000000, + "match_threshold": 0.2, + } + required_inputs = [ + "image0", + "image1", + ] + weight_urls = { + "model_best.pth": "https://drive.google.com/uc?id=1Ca0WmKSSt2G6P7m8YAOlSAHEFar_TAWb&confirm=t", + } + proxy = "http://localhost:1080" + + # Initialize the line matcher + def _init(self, conf): + sgmnet_weights = sgmnet_path / "weights/sgm/root" / conf["model_name"] + + link = self.weight_urls[conf["model_name"]] + tar_path = sgmnet_path / "weights.tar.gz" + # Download the model. + if not sgmnet_weights.exists(): + if not tar_path.exists(): + cmd = [ + "gdown", + link, + "-O", + str(tar_path), + "--proxy", + self.proxy, + ] + cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)] + logger.info( + f"Downloading the SGMNet model with `{cmd_wo_proxy}`." + ) + try: + subprocess.run(cmd_wo_proxy, check=True) + except subprocess.CalledProcessError as e: + logger.info(f"Downloading failed {e}.") + logger.info(f"Downloading the SGMNet model with `{cmd}`.") + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + logger.error("Failed to download the SGMNet model.") + raise e + cmd = ["tar", "-xvf", str(tar_path), "-C", str(sgmnet_path)] + logger.info(f"Unzip model file `{cmd}`.") + subprocess.run(cmd, check=True) + + # config + config = namedtuple("config", conf.keys())(*conf.values()) + self.net = SGM_Model(config) + checkpoint = torch.load(sgmnet_weights, map_location="cpu") + # for ddp model + if ( + list(checkpoint["state_dict"].items())[0][0].split(".")[0] + == "module" + ): + new_stat_dict = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + new_stat_dict[key[7:]] = value + checkpoint["state_dict"] = new_stat_dict + self.net.load_state_dict(checkpoint["state_dict"]) + logger.info("Load SGMNet model done.") + + def _forward(self, data): + x1 = data["keypoints0"].squeeze() # N x 2 + x2 = data["keypoints1"].squeeze() + score1 = data["scores0"].reshape(-1, 1) # N x 1 + score2 = data["scores1"].reshape(-1, 1) + desc1 = data["descriptors0"].permute(0, 2, 1) # 1 x N x 128 + desc2 = data["descriptors1"].permute(0, 2, 1) + size1 = ( + torch.tensor(data["image0"].shape[2:]).flip(0).to(x1.device) + ) # W x H -> x & y + size2 = ( + torch.tensor(data["image1"].shape[2:]).flip(0).to(x2.device) + ) # W x H + norm_x1 = self.normalize_size(x1, size1) + norm_x2 = self.normalize_size(x2, size2) + + x1 = torch.cat((norm_x1, score1), dim=-1) # N x 3 + x2 = torch.cat((norm_x2, score2), dim=-1) + input = {"x1": x1[None], "x2": x2[None], "desc1": desc1, "desc2": desc2} + input = { + k: v.to(device).float() if isinstance(v, torch.Tensor) else v + for k, v in input.items() + } + pred = self.net(input, test_mode=True) + + p = pred["p"] # shape: N * M + indices0 = self.match_p(p[0, :-1, :-1]) + pred = { + "matches0": indices0.unsqueeze(0), + "matching_scores0": torch.zeros(indices0.size(0)).unsqueeze(0), + } + return pred + + def match_p(self, p): + score, index = torch.topk(p, k=1, dim=-1) + _, index2 = torch.topk(p, k=1, dim=-2) + mask_th, index, index2 = ( + score[:, 0] > self.conf["match_threshold"], + index[:, 0], + index2.squeeze(0), + ) + mask_mc = index2[index] == torch.arange(len(p)).to(device) + mask = mask_th & mask_mc + indices0 = torch.where(mask, index, index.new_tensor(-1)) + return indices0 + + def normalize_size(self, x, size, scale=1): + norm_fac = size.max() + return (x - size / 2 + 0.5) / (norm_fac * scale) diff --git a/hloc/matchers/sold2.py b/hloc/matchers/sold2.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ac07f6a4e1c3f4af0ab79fd908cd6a350503d8 --- /dev/null +++ b/hloc/matchers/sold2.py @@ -0,0 +1,153 @@ +import subprocess +import sys +from pathlib import Path + +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +sold2_path = Path(__file__).parent / "../../third_party/SOLD2" +sys.path.append(str(sold2_path)) + +from sold2.model.line_matcher import LineMatcher + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class SOLD2(BaseModel): + default_conf = { + "weights": "sold2_wireframe.tar", + "match_threshold": 0.2, + "checkpoint_dir": sold2_path / "pretrained", + "detect_thresh": 0.25, + "multiscale": False, + "valid_thresh": 1e-3, + "num_blocks": 20, + "overlap_ratio": 0.5, + } + required_inputs = [ + "image0", + "image1", + ] + + weight_urls = { + "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download", + } + + # Initialize the line matcher + def _init(self, conf): + checkpoint_path = conf["checkpoint_dir"] / conf["weights"] + + # Download the model. + if not checkpoint_path.exists(): + checkpoint_path.parent.mkdir(exist_ok=True) + link = self.weight_urls[conf["weights"]] + cmd = ["wget", "--quiet", link, "-O", str(checkpoint_path)] + logger.info(f"Downloading the SOLD2 model with `{cmd}`.") + subprocess.run(cmd, check=True) + + mode = "dynamic" # 'dynamic' or 'static' + match_config = { + "model_cfg": { + "model_name": "lcnn_simple", + "model_architecture": "simple", + # Backbone related config + "backbone": "lcnn", + "backbone_cfg": { + "input_channel": 1, # Use RGB images or grayscale images. + "depth": 4, + "num_stacks": 2, + "num_blocks": 1, + "num_classes": 5, + }, + # Junction decoder related config + "junction_decoder": "superpoint_decoder", + "junc_decoder_cfg": {}, + # Heatmap decoder related config + "heatmap_decoder": "pixel_shuffle", + "heatmap_decoder_cfg": {}, + # Descriptor decoder related config + "descriptor_decoder": "superpoint_descriptor", + "descriptor_decoder_cfg": {}, + # Shared configurations + "grid_size": 8, + "keep_border_valid": True, + # Threshold of junction detection + "detection_thresh": 0.0153846, # 1/65 + "max_num_junctions": 300, + # Threshold of heatmap detection + "prob_thresh": 0.5, + # Weighting related parameters + "weighting_policy": mode, + # [Heatmap loss] + "w_heatmap": 0.0, + "w_heatmap_class": 1, + "heatmap_loss_func": "cross_entropy", + "heatmap_loss_cfg": {"policy": mode}, + # [Heatmap consistency loss] + # [Junction loss] + "w_junc": 0.0, + "junction_loss_func": "superpoint", + "junction_loss_cfg": {"policy": mode}, + # [Descriptor loss] + "w_desc": 0.0, + "descriptor_loss_func": "regular_sampling", + "descriptor_loss_cfg": { + "dist_threshold": 8, + "grid_size": 4, + "margin": 1, + "policy": mode, + }, + }, + "line_detector_cfg": { + "detect_thresh": 0.25, # depending on your images, you might need to tune this parameter + "num_samples": 64, + "sampling_method": "local_max", + "inlier_thresh": 0.9, + "use_candidate_suppression": True, + "nms_dist_tolerance": 3.0, + "use_heatmap_refinement": True, + "heatmap_refine_cfg": { + "mode": "local", + "ratio": 0.2, + "valid_thresh": 1e-3, + "num_blocks": 20, + "overlap_ratio": 0.5, + }, + }, + "multiscale": False, + "line_matcher_cfg": { + "cross_check": True, + "num_samples": 5, + "min_dist_pts": 8, + "top_k_candidates": 10, + "grid_size": 4, + }, + } + self.net = LineMatcher( + match_config["model_cfg"], + checkpoint_path, + device, + match_config["line_detector_cfg"], + match_config["line_matcher_cfg"], + match_config["multiscale"], + ) + + def _forward(self, data): + img0 = data["image0"] + img1 = data["image1"] + pred = self.net([img0, img1]) + line_seg1 = pred["line_segments"][0] + line_seg2 = pred["line_segments"][1] + matches = pred["matches"] + + valid_matches = matches != -1 + match_indices = matches[valid_matches] + matched_lines1 = line_seg1[valid_matches][:, :, ::-1] + matched_lines2 = line_seg2[match_indices][:, :, ::-1] + + pred["raw_lines0"], pred["raw_lines1"] = line_seg1, line_seg2 + pred["lines0"], pred["lines1"] = matched_lines1, matched_lines2 + pred = {**pred, **data} + return pred diff --git a/hloc/matchers/superglue.py b/hloc/matchers/superglue.py new file mode 100644 index 0000000000000000000000000000000000000000..6fae344e8dfe0fd46f090c6915036e5f9c09a635 --- /dev/null +++ b/hloc/matchers/superglue.py @@ -0,0 +1,33 @@ +import sys +from pathlib import Path + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from SuperGluePretrainedNetwork.models.superglue import ( # noqa: E402 + SuperGlue as SG, +) + + +class SuperGlue(BaseModel): + default_conf = { + "weights": "outdoor", + "sinkhorn_iterations": 100, + "match_threshold": 0.2, + } + required_inputs = [ + "image0", + "keypoints0", + "scores0", + "descriptors0", + "image1", + "keypoints1", + "scores1", + "descriptors1", + ] + + def _init(self, conf): + self.net = SG(conf) + + def _forward(self, data): + return self.net(data) diff --git a/hloc/matchers/topicfm.py b/hloc/matchers/topicfm.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4701cc0dbe4952712f4718e26256022dd0b522 --- /dev/null +++ b/hloc/matchers/topicfm.py @@ -0,0 +1,54 @@ +import sys +from pathlib import Path + +import torch + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from TopicFM.src import get_model_cfg +from TopicFM.src.models.topic_fm import TopicFM as _TopicFM + +topicfm_path = Path(__file__).parent / "../../third_party/TopicFM" + + +class TopicFM(BaseModel): + default_conf = { + "weights": "outdoor", + "match_threshold": 0.2, + "n_sampling_topics": 4, + "max_keypoints": -1, + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + _conf = dict(get_model_cfg()) + _conf["match_coarse"]["thr"] = conf["match_threshold"] + _conf["coarse"]["n_samples"] = conf["n_sampling_topics"] + weight_path = topicfm_path / "pretrained/model_best.ckpt" + self.net = _TopicFM(config=_conf) + ckpt_dict = torch.load(weight_path, map_location="cpu") + self.net.load_state_dict(ckpt_dict["state_dict"]) + + def _forward(self, data): + data_ = { + "image0": data["image0"], + "image1": data["image1"], + } + self.net(data_) + pred = { + "keypoints0": data_["mkpts0_f"], + "keypoints1": data_["mkpts1_f"], + "mconf": data_["mconf"], + } + scores = data_["mconf"] + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + scores = scores[keep] + pred["keypoints0"], pred["keypoints1"], pred["mconf"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + scores, + ) + return pred diff --git a/hloc/matchers/xfeat_dense.py b/hloc/matchers/xfeat_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..00d660fed15530b78b4445299059cc152eeeea33 --- /dev/null +++ b/hloc/matchers/xfeat_dense.py @@ -0,0 +1,58 @@ +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + + +class XFeatDense(BaseModel): + default_conf = { + "keypoint_threshold": 0.005, + "max_keypoints": 8000, + } + required_inputs = [ + "image0", + "image1", + ] + + def _init(self, conf): + self.net = torch.hub.load( + "verlab/accelerated_features", + "XFeat", + pretrained=True, + top_k=self.conf["max_keypoints"], + ) + logger.info("Load XFeat(dense) model done.") + + def _forward(self, data): + # Compute coarse feats + out0 = self.net.detectAndComputeDense( + data["image0"], top_k=self.conf["max_keypoints"] + ) + out1 = self.net.detectAndComputeDense( + data["image1"], top_k=self.conf["max_keypoints"] + ) + + # Match batches of pairs + idxs_list = self.net.batch_match( + out0["descriptors"], out1["descriptors"] + ) + B = len(data["image0"]) + + # Refine coarse matches + # this part is harder to batch, currently iterate + matches = [] + for b in range(B): + matches.append( + self.net.refine_matches( + out0, out1, matches=idxs_list, batch_idx=b + ) + ) + # we use results from one batch + matches = matches[0] + pred = { + "keypoints0": matches[:, :2], + "keypoints1": matches[:, 2:], + "mconf": torch.ones_like(matches[:, 0]), + } + return pred diff --git a/hloc/matchers/xfeat_lightglue.py b/hloc/matchers/xfeat_lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..f87536178d333248589806a32496743283aec292 --- /dev/null +++ b/hloc/matchers/xfeat_lightglue.py @@ -0,0 +1,48 @@ +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + + +class XFeatLightGlue(BaseModel): + default_conf = { + "keypoint_threshold": 0.005, + "max_keypoints": 8000, + } + required_inputs = [ + "image0", + "image1", + ] + + def _init(self, conf): + self.net = torch.hub.load( + "verlab/accelerated_features", + "XFeat", + pretrained=True, + top_k=self.conf["max_keypoints"], + ) + logger.info("Load XFeat(dense) model done.") + + def _forward(self, data): + # we use results from one batch + im0 = data["image0"] + im1 = data["image1"] + # Compute coarse feats + out0 = self.net.detectAndCompute(im0, top_k=self.conf["max_keypoints"])[ + 0 + ] + out1 = self.net.detectAndCompute(im1, top_k=self.conf["max_keypoints"])[ + 0 + ] + out0.update({"image_size": (im0.shape[-1], im0.shape[-2])}) # W H + out1.update({"image_size": (im1.shape[-1], im1.shape[-2])}) # W H + mkpts_0, mkpts_1 = self.net.match_lighterglue(out0, out1) + mkpts_0 = torch.from_numpy(mkpts_0) # n x 2 + mkpts_1 = torch.from_numpy(mkpts_1) # n x 2 + pred = { + "keypoints0": mkpts_0, + "keypoints1": mkpts_1, + "mconf": torch.ones_like(mkpts_0[:, 0]), + } + return pred diff --git a/hloc/pairs_from_covisibility.py b/hloc/pairs_from_covisibility.py new file mode 100644 index 0000000000000000000000000000000000000000..49f3e57f2bd1aec20e12ecca6df8f94a68b7fd4e --- /dev/null +++ b/hloc/pairs_from_covisibility.py @@ -0,0 +1,60 @@ +import argparse +from collections import defaultdict +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +from . import logger +from .utils.read_write_model import read_model + + +def main(model, output, num_matched): + logger.info("Reading the COLMAP model...") + cameras, images, points3D = read_model(model) + + logger.info("Extracting image pairs from covisibility info...") + pairs = [] + for image_id, image in tqdm(images.items()): + matched = image.point3D_ids != -1 + points3D_covis = image.point3D_ids[matched] + + covis = defaultdict(int) + for point_id in points3D_covis: + for image_covis_id in points3D[point_id].image_ids: + if image_covis_id != image_id: + covis[image_covis_id] += 1 + + if len(covis) == 0: + logger.info(f"Image {image_id} does not have any covisibility.") + continue + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= num_matched: + top_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + # get covisible image ids with top k number of common matches + ind_top = np.argpartition(covis_num, -num_matched) + ind_top = ind_top[-num_matched:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + top_covis_ids = [covis_ids[i] for i in ind_top] + assert covis_num[ind_top[0]] == np.max(covis_num) + + for i in top_covis_ids: + pair = (image.name, images[i].name) + pairs.append(pair) + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join([i, j]) for i, j in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--num_matched", required=True, type=int) + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/pairs_from_exhaustive.py b/hloc/pairs_from_exhaustive.py new file mode 100644 index 0000000000000000000000000000000000000000..438b8141e344e0f6b7644514919bfc69075cbc3d --- /dev/null +++ b/hloc/pairs_from_exhaustive.py @@ -0,0 +1,66 @@ +import argparse +import collections.abc as collections +from pathlib import Path +from typing import List, Optional, Union + +from . import logger +from .utils.io import list_h5_names +from .utils.parsers import parse_image_lists + + +def main( + output: Path, + image_list: Optional[Union[Path, List[str]]] = None, + features: Optional[Path] = None, + ref_list: Optional[Union[Path, List[str]]] = None, + ref_features: Optional[Path] = None, +): + if image_list is not None: + if isinstance(image_list, (str, Path)): + names_q = parse_image_lists(image_list) + elif isinstance(image_list, collections.Iterable): + names_q = list(image_list) + else: + raise ValueError(f"Unknown type for image list: {image_list}") + elif features is not None: + names_q = list_h5_names(features) + else: + raise ValueError("Provide either a list of images or a feature file.") + + self_matching = False + if ref_list is not None: + if isinstance(ref_list, (str, Path)): + names_ref = parse_image_lists(ref_list) + elif isinstance(image_list, collections.Iterable): + names_ref = list(ref_list) + else: + raise ValueError( + f"Unknown type for reference image list: {ref_list}" + ) + elif ref_features is not None: + names_ref = list_h5_names(ref_features) + else: + self_matching = True + names_ref = names_q + + pairs = [] + for i, n1 in enumerate(names_q): + for j, n2 in enumerate(names_ref): + if self_matching and j <= i: + continue + pairs.append((n1, n2)) + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join([i, j]) for i, j in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--image_list", type=Path) + parser.add_argument("--features", type=Path) + parser.add_argument("--ref_list", type=Path) + parser.add_argument("--ref_features", type=Path) + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/pairs_from_poses.py b/hloc/pairs_from_poses.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b4f88f92834412f1753e7e3414e0f75e762367 --- /dev/null +++ b/hloc/pairs_from_poses.py @@ -0,0 +1,70 @@ +import argparse +from pathlib import Path + +import numpy as np +import scipy.spatial + +from . import logger +from .pairs_from_retrieval import pairs_from_score_matrix +from .utils.read_write_model import read_images_binary + +DEFAULT_ROT_THRESH = 30 # in degrees + + +def get_pairwise_distances(images): + ids = np.array(list(images.keys())) + Rs = [] + ts = [] + for id_ in ids: + image = images[id_] + R = image.qvec2rotmat() + t = image.tvec + Rs.append(R) + ts.append(t) + Rs = np.stack(Rs, 0) + ts = np.stack(ts, 0) + + # Invert the poses from world-to-camera to camera-to-world. + Rs = Rs.transpose(0, 2, 1) + ts = -(Rs @ ts[:, :, None])[:, :, 0] + + dist = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(ts)) + + # Instead of computing the angle between two camera orientations, + # we compute the angle between the principal axes, as two images rotated + # around their principal axis still observe the same scene. + axes = Rs[:, :, -1] + dots = np.einsum("mi,ni->mn", axes, axes, optimize=True) + dR = np.rad2deg(np.arccos(np.clip(dots, -1.0, 1.0))) + + return ids, dist, dR + + +def main(model, output, num_matched, rotation_threshold=DEFAULT_ROT_THRESH): + logger.info("Reading the COLMAP model...") + images = read_images_binary(model / "images.bin") + + logger.info(f"Obtaining pairwise distances between {len(images)} images...") + ids, dist, dR = get_pairwise_distances(images) + scores = -dist + + invalid = dR >= rotation_threshold + np.fill_diagonal(invalid, True) + pairs = pairs_from_score_matrix(scores, invalid, num_matched) + pairs = [(images[ids[i]].name, images[ids[j]].name) for i, j in pairs] + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join(p) for p in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--num_matched", required=True, type=int) + parser.add_argument( + "--rotation_threshold", default=DEFAULT_ROT_THRESH, type=float + ) + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/pairs_from_retrieval.py b/hloc/pairs_from_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..6948fe64bdc467946f07a3376aa5d6cc38474859 --- /dev/null +++ b/hloc/pairs_from_retrieval.py @@ -0,0 +1,137 @@ +import argparse +import collections.abc as collections +from pathlib import Path +from typing import Optional + +import h5py +import numpy as np +import torch + +from . import logger +from .utils.io import list_h5_names +from .utils.parsers import parse_image_lists +from .utils.read_write_model import read_images_binary + + +def parse_names(prefix, names, names_all): + if prefix is not None: + if not isinstance(prefix, str): + prefix = tuple(prefix) + names = [n for n in names_all if n.startswith(prefix)] + if len(names) == 0: + raise ValueError( + f"Could not find any image with the prefix `{prefix}`." + ) + elif names is not None: + if isinstance(names, (str, Path)): + names = parse_image_lists(names) + elif isinstance(names, collections.Iterable): + names = list(names) + else: + raise ValueError( + f"Unknown type of image list: {names}." + "Provide either a list or a path to a list file." + ) + else: + names = names_all + return names + + +def get_descriptors(names, path, name2idx=None, key="global_descriptor"): + if name2idx is None: + with h5py.File(str(path), "r", libver="latest") as fd: + desc = [fd[n][key].__array__() for n in names] + else: + desc = [] + for n in names: + with h5py.File(str(path[name2idx[n]]), "r", libver="latest") as fd: + desc.append(fd[n][key].__array__()) + return torch.from_numpy(np.stack(desc, 0)).float() + + +def pairs_from_score_matrix( + scores: torch.Tensor, + invalid: np.array, + num_select: int, + min_score: Optional[float] = None, +): + assert scores.shape == invalid.shape + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + invalid = torch.from_numpy(invalid).to(scores.device) + if min_score is not None: + invalid |= scores < min_score + scores.masked_fill_(invalid, float("-inf")) + + topk = torch.topk(scores, num_select, dim=1) + indices = topk.indices.cpu().numpy() + valid = topk.values.isfinite().cpu().numpy() + + pairs = [] + for i, j in zip(*np.where(valid)): + pairs.append((i, indices[i, j])) + return pairs + + +def main( + descriptors, + output, + num_matched, + query_prefix=None, + query_list=None, + db_prefix=None, + db_list=None, + db_model=None, + db_descriptors=None, +): + logger.info("Extracting image pairs from a retrieval database.") + + # We handle multiple reference feature files. + # We only assume that names are unique among them and map names to files. + if db_descriptors is None: + db_descriptors = descriptors + if isinstance(db_descriptors, (Path, str)): + db_descriptors = [db_descriptors] + name2db = { + n: i for i, p in enumerate(db_descriptors) for n in list_h5_names(p) + } + db_names_h5 = list(name2db.keys()) + query_names_h5 = list_h5_names(descriptors) + + if db_model: + images = read_images_binary(db_model / "images.bin") + db_names = [i.name for i in images.values()] + else: + db_names = parse_names(db_prefix, db_list, db_names_h5) + if len(db_names) == 0: + raise ValueError("Could not find any database image.") + query_names = parse_names(query_prefix, query_list, query_names_h5) + + device = "cuda" if torch.cuda.is_available() else "cpu" + db_desc = get_descriptors(db_names, db_descriptors, name2db) + query_desc = get_descriptors(query_names, descriptors) + sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device)) + + # Avoid self-matching + self = np.array(query_names)[:, None] == np.array(db_names)[None] + pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) + pairs = [(query_names[i], db_names[j]) for i, j in pairs] + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join([i, j]) for i, j in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--descriptors", type=Path, required=True) + parser.add_argument("--output", type=Path, required=True) + parser.add_argument("--num_matched", type=int, required=True) + parser.add_argument("--query_prefix", type=str, nargs="+") + parser.add_argument("--query_list", type=Path) + parser.add_argument("--db_prefix", type=str, nargs="+") + parser.add_argument("--db_list", type=Path) + parser.add_argument("--db_model", type=Path) + parser.add_argument("--db_descriptors", type=Path) + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/pipelines/4Seasons/README.md b/hloc/pipelines/4Seasons/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ad23ac8348ae9f0963611bc9a342240d5ae97255 --- /dev/null +++ b/hloc/pipelines/4Seasons/README.md @@ -0,0 +1,43 @@ +# 4Seasons dataset + +This pipeline localizes sequences from the [4Seasons dataset](https://arxiv.org/abs/2009.06364) and can reproduce our winning submission to the challenge of the [ECCV 2020 Workshop on Map-based Localization for Autonomous Driving](https://sites.google.com/view/mlad-eccv2020/home). + +## Installation + +Download the sequences from the [challenge webpage](https://sites.google.com/view/mlad-eccv2020/challenge) and run: +```bash +unzip recording_2020-04-07_10-20-32.zip -d datasets/4Seasons/reference +unzip recording_2020-03-24_17-36-22.zip -d datasets/4Seasons/training +unzip recording_2020-03-03_12-03-23.zip -d datasets/4Seasons/validation +unzip recording_2020-03-24_17-45-31.zip -d datasets/4Seasons/test0 +unzip recording_2020-04-23_19-37-00.zip -d datasets/4Seasons/test1 +``` +Note that the provided scripts might modify the dataset files by deleting unused images to speed up the feature extraction + +## Pipeline + +The process is presented in our workshop talk, whose recording can be found [here](https://youtu.be/M-X6HX1JxYk?t=5245). + +We first triangulate a 3D model from the given poses of the reference sequence: +```bash +python3 -m hloc.pipelines.4Seasons.prepare_reference +``` + +We then relocalize a given sequence: +```bash +python3 -m hloc.pipelines.4Seasons.localize --sequence [training|validation|test0|test1] +``` + +The final submission files can be found in `outputs/4Seasons/submission_hloc+superglue/`. The script will also evaluate these results if the training or validation sequences are selected. + +## Results + +We evaluate the localization recall at distance thresholds 0.1m, 0.2m, and 0.5m. + +| Methods | test0 | test1 | +| -------------------- | ---------------------- | ---------------------- | +| **hloc + SuperGlue** | **91.8 / 97.7 / 99.2** | **67.3 / 93.5 / 98.7** | +| Baseline SuperGlue | 21.2 / 33.9 / 60.0 | 12.4 / 26.5 / 54.4 | +| Baseline R2D2 | 21.5 / 33.1 / 53.0 | 12.3 / 23.7 / 42.0 | +| Baseline D2Net | 12.5 / 29.3 / 56.7 | 7.5 / 21.4 / 47.7 | +| Baseline SuperPoint | 15.5 / 27.5 / 47.5 | 9.0 / 19.4 / 36.4 | diff --git a/hloc/pipelines/4Seasons/__init__.py b/hloc/pipelines/4Seasons/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/pipelines/4Seasons/localize.py b/hloc/pipelines/4Seasons/localize.py new file mode 100644 index 0000000000000000000000000000000000000000..50ed957bcf159915d0be98fa9b54c8bce0059b56 --- /dev/null +++ b/hloc/pipelines/4Seasons/localize.py @@ -0,0 +1,89 @@ +import argparse +from pathlib import Path + +from ... import extract_features, localize_sfm, logger, match_features +from .utils import ( + delete_unused_images, + evaluate_submission, + generate_localization_pairs, + generate_query_lists, + get_timestamps, + prepare_submission, +) + +relocalization_files = { + "training": "RelocalizationFilesTrain//relocalizationFile_recording_2020-03-24_17-36-22.txt", # noqa: E501 + "validation": "RelocalizationFilesVal/relocalizationFile_recording_2020-03-03_12-03-23.txt", # noqa: E501 + "test0": "RelocalizationFilesTest/relocalizationFile_recording_2020-03-24_17-45-31_*.txt", # noqa: E501 + "test1": "RelocalizationFilesTest/relocalizationFile_recording_2020-04-23_19-37-00_*.txt", # noqa: E501 +} + +parser = argparse.ArgumentParser() +parser.add_argument( + "--sequence", + type=str, + required=True, + choices=["training", "validation", "test0", "test1"], + help="Sequence to be relocalized.", +) +parser.add_argument( + "--dataset", + type=Path, + default="datasets/4Seasons", + help="Path to the dataset, default: %(default)s", +) +parser.add_argument( + "--outputs", + type=Path, + default="outputs/4Seasons", + help="Path to the output directory, default: %(default)s", +) +args = parser.parse_args() +sequence = args.sequence + +data_dir = args.dataset +ref_dir = data_dir / "reference" +assert ref_dir.exists(), f"{ref_dir} does not exist" +seq_dir = data_dir / sequence +assert seq_dir.exists(), f"{seq_dir} does not exist" +seq_images = seq_dir / "undistorted_images" +reloc = ref_dir / relocalization_files[sequence] + +output_dir = args.outputs +output_dir.mkdir(exist_ok=True, parents=True) +query_list = output_dir / f"{sequence}_queries_with_intrinsics.txt" +ref_pairs = output_dir / "pairs-db-dist20.txt" +ref_sfm = output_dir / "sfm_superpoint+superglue" +results_path = output_dir / f"localization_{sequence}_hloc+superglue.txt" +submission_dir = output_dir / "submission_hloc+superglue" + +num_loc_pairs = 10 +loc_pairs = output_dir / f"pairs-query-{sequence}-dist{num_loc_pairs}.txt" + +fconf = extract_features.confs["superpoint_max"] +mconf = match_features.confs["superglue"] + +# Not all query images that are used for the evaluation +# To save time in feature extraction, we delete unsused images. +timestamps = get_timestamps(reloc, 1) +delete_unused_images(seq_images, timestamps) + +# Generate a list of query images with their intrinsics. +generate_query_lists(timestamps, seq_dir, query_list) + +# Generate the localization pairs from the given reference frames. +generate_localization_pairs(sequence, reloc, num_loc_pairs, ref_pairs, loc_pairs) + +# Extract, match, amd localize. +ffile = extract_features.main(fconf, seq_images, output_dir) +mfile = match_features.main(mconf, loc_pairs, fconf["output"], output_dir) +localize_sfm.main(ref_sfm, query_list, loc_pairs, ffile, mfile, results_path) + +# Convert the absolute poses to relative poses with the reference frames. +submission_dir.mkdir(exist_ok=True) +prepare_submission(results_path, reloc, ref_dir / "poses.txt", submission_dir) + +# If not a test sequence: evaluation the localization accuracy +if "test" not in sequence: + logger.info("Evaluating the relocalization submission...") + evaluate_submission(submission_dir, reloc) diff --git a/hloc/pipelines/4Seasons/prepare_reference.py b/hloc/pipelines/4Seasons/prepare_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..f47aee778ba24ef89a1cc4418f5db9cfab209b9d --- /dev/null +++ b/hloc/pipelines/4Seasons/prepare_reference.py @@ -0,0 +1,51 @@ +import argparse +from pathlib import Path + +from ... import extract_features, match_features, pairs_from_poses, triangulation +from .utils import build_empty_colmap_model, delete_unused_images, get_timestamps + +parser = argparse.ArgumentParser() +parser.add_argument( + "--dataset", + type=Path, + default="datasets/4Seasons", + help="Path to the dataset, default: %(default)s", +) +parser.add_argument( + "--outputs", + type=Path, + default="outputs/4Seasons", + help="Path to the output directory, default: %(default)s", +) +args = parser.parse_args() + +ref_dir = args.dataset / "reference" +assert ref_dir.exists(), f"{ref_dir} does not exist" +ref_images = ref_dir / "undistorted_images" + +output_dir = args.outputs +output_dir.mkdir(exist_ok=True, parents=True) +ref_sfm_empty = output_dir / "sfm_reference_empty" +ref_sfm = output_dir / "sfm_superpoint+superglue" + +num_ref_pairs = 20 +ref_pairs = output_dir / f"pairs-db-dist{num_ref_pairs}.txt" + +fconf = extract_features.confs["superpoint_max"] +mconf = match_features.confs["superglue"] + +# Only reference images that have a pose are used in the pipeline. +# To save time in feature extraction, we delete unsused images. +delete_unused_images(ref_images, get_timestamps(ref_dir / "poses.txt", 0)) + +# Build an empty COLMAP model containing only camera and images +# from the provided poses and intrinsics. +build_empty_colmap_model(ref_dir, ref_sfm_empty) + +# Match reference images that are spatially close. +pairs_from_poses.main(ref_sfm_empty, ref_pairs, num_ref_pairs) + +# Extract, match, and triangulate the reference SfM model. +ffile = extract_features.main(fconf, ref_images, output_dir) +mfile = match_features.main(mconf, ref_pairs, fconf["output"], output_dir) +triangulation.main(ref_sfm, ref_sfm_empty, ref_images, ref_pairs, ffile, mfile) diff --git a/hloc/pipelines/4Seasons/utils.py b/hloc/pipelines/4Seasons/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e5aace9dd9a31b9c39a58691c0c23795313bc462 --- /dev/null +++ b/hloc/pipelines/4Seasons/utils.py @@ -0,0 +1,231 @@ +import glob +import logging +import os +from pathlib import Path + +import numpy as np + +from ...utils.parsers import parse_retrieval +from ...utils.read_write_model import ( + Camera, + Image, + qvec2rotmat, + rotmat2qvec, + write_model, +) + +logger = logging.getLogger(__name__) + + +def get_timestamps(files, idx): + """Extract timestamps from a pose or relocalization file.""" + lines = [] + for p in files.parent.glob(files.name): + with open(p) as f: + lines += f.readlines() + timestamps = set() + for line in lines: + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + ts = line.replace(",", " ").split()[idx] + timestamps.add(ts) + return timestamps + + +def delete_unused_images(root, timestamps): + """Delete all images in root if they are not contained in timestamps.""" + images = glob.glob((root / "**/*.png").as_posix(), recursive=True) + deleted = 0 + for image in images: + ts = Path(image).stem + if ts not in timestamps: + os.remove(image) + deleted += 1 + logger.info(f"Deleted {deleted} images in {root}.") + + +def camera_from_calibration_file(id_, path): + """Create a COLMAP camera from an MLAD calibration file.""" + with open(path, "r") as f: + data = f.readlines() + model, fx, fy, cx, cy = data[0].split()[:5] + width, height = data[1].split() + assert model == "Pinhole" + model_name = "PINHOLE" + params = [float(i) for i in [fx, fy, cx, cy]] + camera = Camera( + id=id_, model=model_name, width=int(width), height=int(height), params=params + ) + return camera + + +def parse_poses(path, colmap=False): + """Parse a list of poses in COLMAP or MLAD quaternion convention.""" + poses = [] + with open(path) as f: + for line in f.readlines(): + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + data = line.replace(",", " ").split() + ts, p = data[0], np.array(data[1:], float) + if colmap: + q, t = np.split(p, [4]) + else: + t, q = np.split(p, [3]) + q = q[[3, 0, 1, 2]] # xyzw to wxyz + R = qvec2rotmat(q) + poses.append((ts, R, t)) + return poses + + +def parse_relocalization(path, has_poses=False): + """Parse a relocalization file, possibly with poses.""" + reloc = [] + with open(path) as f: + for line in f.readlines(): + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + data = line.replace(",", " ").split() + out = data[:2] # ref_ts, q_ts + if has_poses: + assert len(data) == 9 + t, q = np.split(np.array(data[2:], float), [3]) + q = q[[3, 0, 1, 2]] # xyzw to wxyz + R = qvec2rotmat(q) + out += [R, t] + reloc.append(out) + return reloc + + +def build_empty_colmap_model(root, sfm_dir): + """Build a COLMAP model with images and cameras only.""" + calibration = "Calibration/undistorted_calib_{}.txt" + cam0 = camera_from_calibration_file(0, root / calibration.format(0)) + cam1 = camera_from_calibration_file(1, root / calibration.format(1)) + cameras = {0: cam0, 1: cam1} + + T_0to1 = np.loadtxt(root / "Calibration/undistorted_calib_stereo.txt") + poses = parse_poses(root / "poses.txt") + images = {} + id_ = 0 + for ts, R_cam0_to_w, t_cam0_to_w in poses: + R_w_to_cam0 = R_cam0_to_w.T + t_w_to_cam0 = -(R_w_to_cam0 @ t_cam0_to_w) + + R_w_to_cam1 = T_0to1[:3, :3] @ R_w_to_cam0 + t_w_to_cam1 = T_0to1[:3, :3] @ t_w_to_cam0 + T_0to1[:3, 3] + + for idx, (R_w_to_cam, t_w_to_cam) in enumerate( + zip([R_w_to_cam0, R_w_to_cam1], [t_w_to_cam0, t_w_to_cam1]) + ): + image = Image( + id=id_, + qvec=rotmat2qvec(R_w_to_cam), + tvec=t_w_to_cam, + camera_id=idx, + name=f"cam{idx}/{ts}.png", + xys=np.zeros((0, 2), float), + point3D_ids=np.full(0, -1, int), + ) + images[id_] = image + id_ += 1 + + sfm_dir.mkdir(exist_ok=True, parents=True) + write_model(cameras, images, {}, path=str(sfm_dir), ext=".bin") + + +def generate_query_lists(timestamps, seq_dir, out_path): + """Create a list of query images with intrinsics from timestamps.""" + cam0 = camera_from_calibration_file( + 0, seq_dir / "Calibration/undistorted_calib_0.txt" + ) + intrinsics = [cam0.model, cam0.width, cam0.height] + cam0.params + intrinsics = [str(p) for p in intrinsics] + data = map(lambda ts: " ".join([f"cam0/{ts}.png"] + intrinsics), timestamps) + with open(out_path, "w") as f: + f.write("\n".join(data)) + + +def generate_localization_pairs(sequence, reloc, num, ref_pairs, out_path): + """Create the matching pairs for the localization. + We simply lookup the corresponding reference frame + and extract its `num` closest frames from the existing pair list. + """ + if "test" in sequence: + # hard pairs will be overwritten by easy ones if available + relocs = [str(reloc).replace("*", d) for d in ["hard", "moderate", "easy"]] + else: + relocs = [reloc] + query_to_ref_ts = {} + for reloc in relocs: + with open(reloc, "r") as f: + for line in f.readlines(): + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + ref_ts, q_ts = line.split()[:2] + query_to_ref_ts[q_ts] = ref_ts + + ts_to_name = "cam0/{}.png".format + ref_pairs = parse_retrieval(ref_pairs) + loc_pairs = [] + for q_ts, ref_ts in query_to_ref_ts.items(): + ref_name = ts_to_name(ref_ts) + selected = [ref_name] + ref_pairs[ref_name][: num - 1] + loc_pairs.extend([" ".join((ts_to_name(q_ts), s)) for s in selected]) + with open(out_path, "w") as f: + f.write("\n".join(loc_pairs)) + + +def prepare_submission(results, relocs, poses_path, out_dir): + """Obtain relative poses from estimated absolute and reference poses.""" + gt_poses = parse_poses(poses_path) + all_T_ref0_to_w = {ts: (R, t) for ts, R, t in gt_poses} + + pred_poses = parse_poses(results, colmap=True) + all_T_w_to_q0 = {Path(name).stem: (R, t) for name, R, t in pred_poses} + + for reloc in relocs.parent.glob(relocs.name): + relative_poses = [] + reloc_ts = parse_relocalization(reloc) + for ref_ts, q_ts in reloc_ts: + R_w_to_q0, t_w_to_q0 = all_T_w_to_q0[q_ts] + R_ref0_to_w, t_ref0_to_w = all_T_ref0_to_w[ref_ts] + + R_ref0_to_q0 = R_w_to_q0 @ R_ref0_to_w + t_ref0_to_q0 = R_w_to_q0 @ t_ref0_to_w + t_w_to_q0 + + tvec = t_ref0_to_q0.tolist() + qvec = rotmat2qvec(R_ref0_to_q0)[[1, 2, 3, 0]] # wxyz to xyzw + + out = [ref_ts, q_ts] + list(map(str, tvec)) + list(map(str, qvec)) + relative_poses.append(" ".join(out)) + + out_path = out_dir / reloc.name + with open(out_path, "w") as f: + f.write("\n".join(relative_poses)) + logger.info(f"Submission file written to {out_path}.") + + +def evaluate_submission(submission_dir, relocs, ths=[0.1, 0.2, 0.5]): + """Compute the relocalization recall from predicted and ground truth poses.""" + for reloc in relocs.parent.glob(relocs.name): + poses_gt = parse_relocalization(reloc, has_poses=True) + poses_pred = parse_relocalization(submission_dir / reloc.name, has_poses=True) + poses_pred = {(ref_ts, q_ts): (R, t) for ref_ts, q_ts, R, t in poses_pred} + + error = [] + for ref_ts, q_ts, R_gt, t_gt in poses_gt: + R, t = poses_pred[(ref_ts, q_ts)] + e = np.linalg.norm(t - t_gt) + error.append(e) + + error = np.array(error) + recall = [np.mean(error <= th) for th in ths] + s = f"Relocalization evaluation {submission_dir.name}/{reloc.name}\n" + s += " / ".join([f"{th:>7}m" for th in ths]) + "\n" + s += " / ".join([f"{100*r:>7.3f}%" for r in recall]) + logger.info(s) diff --git a/hloc/pipelines/7Scenes/README.md b/hloc/pipelines/7Scenes/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2124779c43ec8d1ffc552e07790d39c3578526a9 --- /dev/null +++ b/hloc/pipelines/7Scenes/README.md @@ -0,0 +1,65 @@ +# 7Scenes dataset + +## Installation + +Download the images from the [7Scenes project page](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/): +```bash +export dataset=datasets/7scenes +for scene in chess fire heads office pumpkin redkitchen stairs; \ +do wget http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/$scene.zip -P $dataset \ +&& unzip $dataset/$scene.zip -d $dataset && unzip $dataset/$scene/'*.zip' -d $dataset/$scene; done +``` + +Download the SIFT SfM models and DenseVLAD image pairs, courtesy of Torsten Sattler: +```bash +function download { +wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$1" -O $2 && rm -rf /tmp/cookies.txt +unzip $2 -d $dataset && rm $2; +} +download 1cu6KUR7WHO7G4EO49Qi3HEKU6n_yYDjb $dataset/7scenes_sfm_triangulated.zip +download 1IbS2vLmxr1N0f3CEnd_wsYlgclwTyvB1 $dataset/7scenes_densevlad_retrieval_top_10.zip +``` + +Download the rendered depth maps, courtesy of Eric Brachmann for [DSAC\*](https://github.com/vislearn/dsacstar): +```bash +wget https://heidata.uni-heidelberg.de/api/access/datafile/4037 -O $dataset/7scenes_rendered_depth.tar.gz +mkdir $dataset/depth/ +tar xzf $dataset/7scenes_rendered_depth.tar.gz -C $dataset/depth/ && rm $dataset/7scenes_rendered_depth.tar.gz +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.7Scenes.pipeline [--use_dense_depth] +``` +By default, hloc triangulates a sparse point cloud that can be noisy in indoor environements due to image noise and lack of texture. With the flag `--use_dense_depth`, the pipeline improves the accuracy of the sparse point cloud using dense depth maps provided by the dataset. The original depth maps captured by the RGBD sensor are miscalibrated, so we use depth maps rendered from the mesh obtained by fusing the RGBD data. + +## Results +We report the median error in translation/rotation in cm/deg over all scenes: +| Method \ Scene | Chess | Fire | Heads | Office | Pumpkin | Kitchen | Stairs | +| ------------------------------- | -------------- | -------------- | -------------- | -------------- | -------------- | -------------- | ---------- | +| Active Search | 3/0.87 | **2**/1.01 | **1**/0.82 | 4/1.15 | 7/1.69 | 5/1.72 | 4/**1.01** | +| DSAC* | **2**/1.10 | **2**/1.24 | **1**/1.82 | **3**/1.15 | **4**/1.34 | 4/1.68 | **3**/1.16 | +| **SuperPoint+SuperGlue** (sfm) | **2**/0.84 | **2**/0.93 | **1**/**0.74** | **3**/0.92 | 5/1.27 | 4/1.40 | 5/1.47 | +| **SuperPoint+SuperGlue** (RGBD) | **2**/**0.80** | **2**/**0.77** | **1**/0.79 | **3**/**0.80** | **4**/**1.07** | **3**/**1.13** | 4/1.15 | + +## Citation +Please cite the following paper if you use the 7Scenes dataset: +``` +@inproceedings{shotton2013scene, + title={Scene coordinate regression forests for camera relocalization in {RGB-D} images}, + author={Shotton, Jamie and Glocker, Ben and Zach, Christopher and Izadi, Shahram and Criminisi, Antonio and Fitzgibbon, Andrew}, + booktitle={CVPR}, + year={2013} +} +``` + +Also cite DSAC* if you use dense depth maps with the flag `--use_dense_depth`: +``` +@article{brachmann2020dsacstar, + title={Visual Camera Re-Localization from {RGB} and {RGB-D} Images Using {DSAC}}, + author={Brachmann, Eric and Rother, Carsten}, + journal={TPAMI}, + year={2021} +} +``` diff --git a/hloc/pipelines/7Scenes/__init__.py b/hloc/pipelines/7Scenes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/pipelines/7Scenes/create_gt_sfm.py b/hloc/pipelines/7Scenes/create_gt_sfm.py new file mode 100644 index 0000000000000000000000000000000000000000..95dfa461e17de99e0bdde0c52c5f02568c4fbab3 --- /dev/null +++ b/hloc/pipelines/7Scenes/create_gt_sfm.py @@ -0,0 +1,134 @@ +from pathlib import Path + +import numpy as np +import PIL.Image +import pycolmap +import torch +from tqdm import tqdm + +from ...utils.read_write_model import read_model, write_model + + +def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera): + assert len(depth) == len(p2D) + p2D_norm = np.stack(pycolmap.Camera(camera._asdict()).image_to_world(p2D)) + p2D_h = np.concatenate([p2D_norm, np.ones_like(p2D_norm[:, :1])], 1) + p3D_c = p2D_h * depth[:, None] + p3D_w = (p3D_c - t_w2c) @ R_w2c + return p3D_w + + +def interpolate_depth(depth, kp): + h, w = depth.shape + kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1 + assert np.all(kp > -1) and np.all(kp < 1) + depth = torch.from_numpy(depth)[None, None] + kp = torch.from_numpy(kp)[None, None] + grid_sample = torch.nn.functional.grid_sample + + # To maximize the number of points that have depth: + # do bilinear interpolation first and then nearest for the remaining points + interp_lin = grid_sample(depth, kp, align_corners=True, mode="bilinear")[0, :, 0] + interp_nn = torch.nn.functional.grid_sample( + depth, kp, align_corners=True, mode="nearest" + )[0, :, 0] + interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin) + valid = ~torch.any(torch.isnan(interp), 0) + + interp_depth = interp.T.numpy().flatten() + valid = valid.numpy() + return interp_depth, valid + + +def image_path_to_rendered_depth_path(image_name): + parts = image_name.split("/") + name = "_".join(["".join(parts[0].split("-")), parts[1]]) + name = name.replace("color", "pose") + name = name.replace("png", "depth.tiff") + return name + + +def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1): + p3D = (p3D @ R.T) + t + visible = p3D[:, -1] >= eps # keep points in front of the camera + p2D_norm = p3D[:, :-1] / p3D[:, -1:].clip(min=eps) + p2D = np.stack(pycolmap.Camera(camera._asdict()).world_to_image(p2D_norm)) + size = np.array([camera.width - pad - 1, camera.height - pad - 1]) + valid = np.all((p2D >= pad) & (p2D <= size), -1) + valid &= visible + return p2D[valid], valid + + +def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path): + cameras, images, points3D = read_model(sfm_path) + for imgid, img in tqdm(images.items()): + image_name = img.name + depth_name = image_path_to_rendered_depth_path(image_name) + + depth = PIL.Image.open(Path(depth_folder_path) / depth_name) + depth = np.array(depth).astype("float64") + depth = depth / 1000.0 # mm to meter + depth[(depth == 0.0) | (depth > 1000.0)] = np.nan + + R_w2c, t_w2c = img.qvec2rotmat(), img.tvec + camera = cameras[img.camera_id] + p3D_ids = img.point3D_ids + p3Ds = np.stack([points3D[i].xyz for i in p3D_ids[p3D_ids != -1]], 0) + + p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera) + invalid_p3D_ids = p3D_ids[p3D_ids != -1][~valids_projected] + interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds) + scs = scene_coordinates( + p2Ds[valids_backprojected], + R_w2c, + t_w2c, + interp_depth[valids_backprojected], + camera, + ) + invalid_p3D_ids = np.append( + invalid_p3D_ids, + p3D_ids[p3D_ids != -1][valids_projected][~valids_backprojected], + ) + for p3did in invalid_p3D_ids: + if p3did == -1: + continue + else: + obs_imgids = points3D[p3did].image_ids + invalid_imgids = list(np.where(obs_imgids == img.id)[0]) + points3D[p3did] = points3D[p3did]._replace( + image_ids=np.delete(obs_imgids, invalid_imgids), + point2D_idxs=np.delete( + points3D[p3did].point2D_idxs, invalid_imgids + ), + ) + + new_p3D_ids = p3D_ids.copy() + sub_p3D_ids = new_p3D_ids[new_p3D_ids != -1] + valids = np.ones(np.count_nonzero(new_p3D_ids != -1), dtype=bool) + valids[~valids_projected] = False + valids[valids_projected] = valids_backprojected + sub_p3D_ids[~valids] = -1 + new_p3D_ids[new_p3D_ids != -1] = sub_p3D_ids + img = img._replace(point3D_ids=new_p3D_ids) + + assert len(img.point3D_ids[img.point3D_ids != -1]) == len( + scs + ), f"{len(scs)}, {len(img.point3D_ids[img.point3D_ids != -1])}" + for i, p3did in enumerate(img.point3D_ids[img.point3D_ids != -1]): + points3D[p3did] = points3D[p3did]._replace(xyz=scs[i]) + images[imgid] = img + + output_path.mkdir(parents=True, exist_ok=True) + write_model(cameras, images, points3D, output_path) + + +if __name__ == "__main__": + dataset = Path("datasets/7scenes") + outputs = Path("outputs/7Scenes") + + SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"] + for scene in SCENES: + sfm_path = outputs / scene / "sfm_superpoint+superglue" + depth_path = dataset / f"depth/7scenes_{scene}/train/depth" + output_path = outputs / scene / "sfm_superpoint+superglue+depth" + correct_sfm_with_gt_depth(sfm_path, depth_path, output_path) diff --git a/hloc/pipelines/7Scenes/pipeline.py b/hloc/pipelines/7Scenes/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc28c6d29828ddbe4b9efaf1678be624a998819 --- /dev/null +++ b/hloc/pipelines/7Scenes/pipeline.py @@ -0,0 +1,139 @@ +import argparse +from pathlib import Path + +from ... import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + triangulation, +) +from ..Cambridge.utils import create_query_list_with_intrinsics, evaluate +from .create_gt_sfm import correct_sfm_with_gt_depth +from .utils import create_reference_sfm + +SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"] + + +def run_scene( + images, + gt_dir, + retrieval, + outputs, + results, + num_covis, + use_dense_depth, + depth_dir=None, +): + outputs.mkdir(exist_ok=True, parents=True) + ref_sfm_sift = outputs / "sfm_sift" + ref_sfm = outputs / "sfm_superpoint+superglue" + query_list = outputs / "query_list_with_intrinsics.txt" + + feature_conf = { + "output": "feats-superpoint-n4096-r1024", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + }, + "preprocessing": { + "globs": ["*.color.png"], + "grayscale": True, + "resize_max": 1024, + }, + } + matcher_conf = match_features.confs["superglue"] + matcher_conf["model"]["sinkhorn_iterations"] = 5 + + test_list = gt_dir / "list_test.txt" + create_reference_sfm(gt_dir, ref_sfm_sift, test_list) + create_query_list_with_intrinsics(gt_dir, query_list, test_list) + + features = extract_features.main(feature_conf, images, outputs, as_half=True) + + sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt" + pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + if not (use_dense_depth and ref_sfm.exists()): + triangulation.main( + ref_sfm, ref_sfm_sift, images, sfm_pairs, features, sfm_matches + ) + if use_dense_depth: + assert depth_dir is not None + ref_sfm_fix = outputs / "sfm_superpoint+superglue+depth" + correct_sfm_with_gt_depth(ref_sfm, depth_dir, ref_sfm_fix) + ref_sfm = ref_sfm_fix + + loc_matches = match_features.main( + matcher_conf, retrieval, feature_conf["output"], outputs + ) + + localize_sfm.main( + ref_sfm, + query_list, + retrieval, + features, + loc_matches, + results, + covisibility_clustering=False, + prepend_camera_name=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+") + parser.add_argument("--overwrite", action="store_true") + parser.add_argument( + "--dataset", + type=Path, + default="datasets/7scenes", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/7scenes", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument("--use_dense_depth", action="store_true") + parser.add_argument( + "--num_covis", + type=int, + default=30, + help="Number of image pairs for SfM, default: %(default)s", + ) + args = parser.parse_args() + + gt_dirs = args.dataset / "7scenes_sfm_triangulated/{scene}/triangulated" + retrieval_dirs = args.dataset / "7scenes_densevlad_retrieval_top_10" + + all_results = {} + for scene in args.scenes: + logger.info(f'Working on scene "{scene}".') + results = ( + args.outputs + / scene + / "results_{}.txt".format("dense" if args.use_dense_depth else "sparse") + ) + if args.overwrite or not results.exists(): + run_scene( + args.dataset / scene, + Path(str(gt_dirs).format(scene=scene)), + retrieval_dirs / f"{scene}_top10.txt", + args.outputs / scene, + results, + args.num_covis, + args.use_dense_depth, + depth_dir=args.dataset / f"depth/7scenes_{scene}/train/depth", + ) + all_results[scene] = results + + for scene in args.scenes: + logger.info(f'Evaluate scene "{scene}".') + gt_dir = Path(str(gt_dirs).format(scene=scene)) + evaluate(gt_dir, all_results[scene], gt_dir / "list_test.txt") diff --git a/hloc/pipelines/7Scenes/utils.py b/hloc/pipelines/7Scenes/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb021286e550de6fe89e03370c11e3f7d567c5f --- /dev/null +++ b/hloc/pipelines/7Scenes/utils.py @@ -0,0 +1,34 @@ +import logging + +import numpy as np + +from hloc.utils.read_write_model import read_model, write_model + +logger = logging.getLogger(__name__) + + +def create_reference_sfm(full_model, ref_model, blacklist=None, ext=".bin"): + """Create a new COLMAP model with only training images.""" + logger.info("Creating the reference model.") + ref_model.mkdir(exist_ok=True) + cameras, images, points3D = read_model(full_model, ext) + + if blacklist is not None: + with open(blacklist, "r") as f: + blacklist = f.read().rstrip().split("\n") + + images_ref = dict() + for id_, image in images.items(): + if blacklist and image.name in blacklist: + continue + images_ref[id_] = image + + points3D_ref = dict() + for id_, point3D in points3D.items(): + ref_ids = [i for i in point3D.image_ids if i in images_ref] + if len(ref_ids) == 0: + continue + points3D_ref[id_] = point3D._replace(image_ids=np.array(ref_ids)) + + write_model(cameras, images_ref, points3D_ref, ref_model, ".bin") + logger.info(f"Kept {len(images_ref)} images out of {len(images)}.") diff --git a/hloc/pipelines/Aachen/README.md b/hloc/pipelines/Aachen/README.md new file mode 100644 index 0000000000000000000000000000000000000000..57b66d6ad1e5cdb3e74c6c1866d394d487c608d1 --- /dev/null +++ b/hloc/pipelines/Aachen/README.md @@ -0,0 +1,16 @@ +# Aachen-Day-Night dataset + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/aachen +wget -r -np -nH -R "index.html*,aachen_v1_1.zip" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset +unzip $dataset/images/database_and_query_images.zip -d $dataset +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.Aachen.pipeline +``` diff --git a/hloc/pipelines/Aachen/__init__.py b/hloc/pipelines/Aachen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/pipelines/Aachen/pipeline.py b/hloc/pipelines/Aachen/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e31ce7255ce5e178f505a6fb415b5bbf13b76879 --- /dev/null +++ b/hloc/pipelines/Aachen/pipeline.py @@ -0,0 +1,109 @@ +import argparse +from pathlib import Path +from pprint import pformat + +from ... import ( + colmap_from_nvm, + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images_upright/" + + outputs = args.outputs # where everything will be saved + sift_sfm = outputs / "sfm_sift" # from which we extract the reference poses + reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build + sfm_pairs = ( + outputs / f"pairs-db-covis{args.num_covis}.txt" + ) # top-k most covisible in SIFT model + loc_pairs = ( + outputs / f"pairs-query-netvlad{args.num_loc}.txt" + ) # top-k retrieved by NetVLAD + results = outputs / f"Aachen_hloc_superpoint+superglue_netvlad{args.num_loc}.txt" + + # list the standard configurations available + logger.info("Configs for feature extractors:\n%s", pformat(extract_features.confs)) + logger.info("Configs for feature matchers:\n%s", pformat(match_features.confs)) + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_aachen"] + matcher_conf = match_features.confs["superglue"] + + features = extract_features.main(feature_conf, images, outputs) + + colmap_from_nvm.main( + dataset / "3D-models/aachen_cvpr2018_db.nvm", + dataset / "3D-models/database_intrinsics.txt", + dataset / "aachen.db", + sift_sfm, + ) + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix="query", + db_model=reference_sfm, + ) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + reference_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + ) # not required with SuperPoint+SuperGlue + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/aachen", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=50, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + run(args) diff --git a/hloc/pipelines/Aachen_v1_1/README.md b/hloc/pipelines/Aachen_v1_1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c17e751777b56e36b8633c1eec37ff656f2d3979 --- /dev/null +++ b/hloc/pipelines/Aachen_v1_1/README.md @@ -0,0 +1,17 @@ +# Aachen-Day-Night dataset v1.1 + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/aachen_v1.1 +wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset +unzip $dataset/images/database_and_query_images.zip -d $dataset +unzip $dataset/aachen_v1_1.zip -d $dataset +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.Aachen_v1_1.pipeline +``` diff --git a/hloc/pipelines/Aachen_v1_1/__init__.py b/hloc/pipelines/Aachen_v1_1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/pipelines/Aachen_v1_1/pipeline.py b/hloc/pipelines/Aachen_v1_1/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0753604624d31984952942ab5b297a247e4d5123 --- /dev/null +++ b/hloc/pipelines/Aachen_v1_1/pipeline.py @@ -0,0 +1,104 @@ +import argparse +from pathlib import Path +from pprint import pformat + +from ... import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images_upright/" + sift_sfm = dataset / "3D-models/aachen_v_1_1" + + outputs = args.outputs # where everything will be saved + reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build + sfm_pairs = ( + outputs / f"pairs-db-covis{args.num_covis}.txt" + ) # top-k most covisible in SIFT model + loc_pairs = ( + outputs / f"pairs-query-netvlad{args.num_loc}.txt" + ) # top-k retrieved by NetVLAD + results = ( + outputs / f"Aachen-v1.1_hloc_superpoint+superglue_netvlad{args.num_loc}.txt" + ) + + # list the standard configurations available + logger.info("Configs for feature extractors:\n%s", pformat(extract_features.confs)) + logger.info("Configs for feature matchers:\n%s", pformat(match_features.confs)) + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_max"] + matcher_conf = match_features.confs["superglue"] + + features = extract_features.main(feature_conf, images, outputs) + + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix="query", + db_model=reference_sfm, + ) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + reference_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + ) # not required with SuperPoint+SuperGlue + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/aachen_v1.1", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen_v1.1", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=50, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + run(args) diff --git a/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py b/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0a769897604fa9838106c7a38ba585ceeefe5c --- /dev/null +++ b/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py @@ -0,0 +1,104 @@ +import argparse +from pathlib import Path +from pprint import pformat + +from ... import ( + extract_features, + localize_sfm, + logger, + match_dense, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images_upright/" + sift_sfm = dataset / "3D-models/aachen_v_1_1" + + outputs = args.outputs # where everything will be saved + outputs.mkdir() + reference_sfm = outputs / "sfm_loftr" # the SfM model we will build + sfm_pairs = ( + outputs / f"pairs-db-covis{args.num_covis}.txt" + ) # top-k most covisible in SIFT model + loc_pairs = ( + outputs / f"pairs-query-netvlad{args.num_loc}.txt" + ) # top-k retrieved by NetVLAD + results = outputs / f"Aachen-v1.1_hloc_loftr_netvlad{args.num_loc}.txt" + + # list the standard configurations available + logger.info("Configs for dense feature matchers:\n%s", pformat(match_dense.confs)) + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + matcher_conf = match_dense.confs["loftr_aachen"] + + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + features, sfm_matches = match_dense.main( + matcher_conf, sfm_pairs, images, outputs, max_kps=8192, overwrite=False + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix="query", + db_model=reference_sfm, + ) + features, loc_matches = match_dense.main( + matcher_conf, + loc_pairs, + images, + outputs, + features=features, + max_kps=None, + matches=sfm_matches, + ) + + localize_sfm.main( + reference_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + ) # not required with loftr + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/aachen_v1.1", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen_v1.1", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=50, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() diff --git a/hloc/pipelines/CMU/README.md b/hloc/pipelines/CMU/README.md new file mode 100644 index 0000000000000000000000000000000000000000..566ba352c53ada2a13dce21c8ec1041b56969d03 --- /dev/null +++ b/hloc/pipelines/CMU/README.md @@ -0,0 +1,16 @@ +# Extended CMU Seasons dataset + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/cmu_extended +wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Extended-CMU-Seasons/ -P $dataset +for slice in $dataset/*.tar; do tar -xf $slice -C $dataset && rm $slice; done +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.CMU.pipeline +``` diff --git a/hloc/pipelines/CMU/__init__.py b/hloc/pipelines/CMU/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/pipelines/CMU/pipeline.py b/hloc/pipelines/CMU/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..4706a05c6aef134dc244419501b22a2ba95ede04 --- /dev/null +++ b/hloc/pipelines/CMU/pipeline.py @@ -0,0 +1,133 @@ +import argparse +from pathlib import Path + +from ... import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + +TEST_SLICES = [2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18, 19, 20, 21] + + +def generate_query_list(dataset, path, slice_): + cameras = {} + with open(dataset / "intrinsics.txt", "r") as f: + for line in f.readlines(): + if line[0] == "#" or line == "\n": + continue + data = line.split() + cameras[data[0]] = data[1:] + assert len(cameras) == 2 + + queries = dataset / f"{slice_}/test-images-{slice_}.txt" + with open(queries, "r") as f: + queries = [q.rstrip("\n") for q in f.readlines()] + + out = [[q] + cameras[q.split("_")[2]] for q in queries] + with open(path, "w") as f: + f.write("\n".join(map(" ".join, out))) + + +def run_slice(slice_, root, outputs, num_covis, num_loc): + dataset = root / slice_ + ref_images = dataset / "database" + query_images = dataset / "query" + sift_sfm = dataset / "sparse" + + outputs = outputs / slice_ + outputs.mkdir(exist_ok=True, parents=True) + query_list = dataset / "queries_with_intrinsics.txt" + sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt" + loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt" + ref_sfm = outputs / "sfm_superpoint+superglue" + results = outputs / f"CMU_hloc_superpoint+superglue_netvlad{num_loc}.txt" + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_aachen"] + matcher_conf = match_features.confs["superglue"] + + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=num_covis) + features = extract_features.main(feature_conf, ref_images, outputs, as_half=True) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + triangulation.main(ref_sfm, sift_sfm, ref_images, sfm_pairs, features, sfm_matches) + + generate_query_list(root, query_list, slice_) + global_descriptors = extract_features.main(retrieval_conf, ref_images, outputs) + global_descriptors = extract_features.main(retrieval_conf, query_images, outputs) + pairs_from_retrieval.main( + global_descriptors, loc_pairs, num_loc, query_list=query_list, db_model=ref_sfm + ) + + features = extract_features.main(feature_conf, query_images, outputs, as_half=True) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + ref_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--slices", + type=str, + default="*", + help="a single number, an interval (e.g. 2-6), " + "or a Python-style list or int (e.g. [2, 3, 4]", + ) + parser.add_argument( + "--dataset", + type=Path, + default="datasets/cmu_extended", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen_extended", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=10, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + + if args.slice == "*": + slices = TEST_SLICES + if "-" in args.slices: + min_, max_ = args.slices.split("-") + slices = list(range(int(min_), int(max_) + 1)) + else: + slices = eval(args.slices) + if isinstance(slices, int): + slices = [slices] + + for slice_ in slices: + logger.info("Working on slice %s.", slice_) + run_slice( + f"slice{slice_}", args.dataset, args.outputs, args.num_covis, args.num_loc + ) diff --git a/hloc/pipelines/Cambridge/README.md b/hloc/pipelines/Cambridge/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d5ae07b71c48a98fa9235f0dfb0234c3c18c74c6 --- /dev/null +++ b/hloc/pipelines/Cambridge/README.md @@ -0,0 +1,47 @@ +# Cambridge Landmarks dataset + +## Installation + +Download the dataset from the [PoseNet project page](http://mi.eng.cam.ac.uk/projects/relocalisation/): +```bash +export dataset=datasets/cambridge +export scenes=( "KingsCollege" "OldHospital" "StMarysChurch" "ShopFacade" "GreatCourt" ) +export IDs=( "251342" "251340" "251294" "251336" "251291" ) +for i in "${!scenes[@]}"; do +wget https://www.repository.cam.ac.uk/bitstream/handle/1810/${IDs[i]}/${scenes[i]}.zip -P $dataset \ +&& unzip $dataset/${scenes[i]}.zip -d $dataset && rm $dataset/${scenes[i]}.zip; done +``` + +Download the SIFT SfM models, courtesy of Torsten Sattler: +```bash +export fileid=1esqzZ1zEQlzZVic-H32V6kkZvc4NeS15 +export filename=$dataset/CambridgeLandmarks_Colmap_Retriangulated_1024px.zip +wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$fileid" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$fileid" -O $filename && rm -rf /tmp/cookies.txt +unzip $filename -d $dataset +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.Cambridge.pipeline +``` + +## Results +We report the median error in translation/rotation in cm/deg over all scenes: +| Method \ Scene | Court | King's | Hospital | Shop | St. Mary's | +| ------------------------ | --------------- | --------------- | --------------- | -------------- | -------------- | +| Active Search | 24/0.13 | 13/0.22 | 20/0.36 | **4**/0.21 | 8/0.25 | +| DSAC* | 49/0.3 | 15/0.3 | 21/0.4 | 5/0.3 | 13/0.4 | +| **SuperPoint+SuperGlue** | **17**/**0.11** | **12**/**0.21** | **14**/**0.30** | **4**/**0.19** | **7**/**0.22** | + +## Citation + +Please cite the following paper if you use the Cambridge Landmarks dataset: +``` +@inproceedings{kendall2015posenet, + title={{PoseNet}: A convolutional network for real-time {6-DoF} camera relocalization}, + author={Kendall, Alex and Grimes, Matthew and Cipolla, Roberto}, + booktitle={ICCV}, + year={2015} +} +``` diff --git a/hloc/pipelines/Cambridge/__init__.py b/hloc/pipelines/Cambridge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/pipelines/Cambridge/pipeline.py b/hloc/pipelines/Cambridge/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..3a676e5af411858f2459cb7f58f777b30be67d29 --- /dev/null +++ b/hloc/pipelines/Cambridge/pipeline.py @@ -0,0 +1,140 @@ +import argparse +from pathlib import Path + +from ... import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) +from .utils import create_query_list_with_intrinsics, evaluate, scale_sfm_images + +SCENES = ["KingsCollege", "OldHospital", "ShopFacade", "StMarysChurch", "GreatCourt"] + + +def run_scene(images, gt_dir, outputs, results, num_covis, num_loc): + ref_sfm_sift = gt_dir / "model_train" + test_list = gt_dir / "list_query.txt" + + outputs.mkdir(exist_ok=True, parents=True) + ref_sfm = outputs / "sfm_superpoint+superglue" + ref_sfm_scaled = outputs / "sfm_sift_scaled" + query_list = outputs / "query_list_with_intrinsics.txt" + sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt" + loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt" + + feature_conf = { + "output": "feats-superpoint-n4096-r1024", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + }, + } + matcher_conf = match_features.confs["superglue"] + retrieval_conf = extract_features.confs["netvlad"] + + create_query_list_with_intrinsics( + gt_dir / "empty_all", query_list, test_list, ext=".txt", image_dir=images + ) + with open(test_list, "r") as f: + query_seqs = {q.split("/")[0] for q in f.read().rstrip().split("\n")} + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + num_loc, + db_model=ref_sfm_sift, + query_prefix=query_seqs, + ) + + features = extract_features.main(feature_conf, images, outputs, as_half=True) + pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + scale_sfm_images(ref_sfm_sift, ref_sfm_scaled, images) + triangulation.main( + ref_sfm, ref_sfm_scaled, images, sfm_pairs, features, sfm_matches + ) + + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + ref_sfm, + query_list, + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + prepend_camera_name=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+") + parser.add_argument("--overwrite", action="store_true") + parser.add_argument( + "--dataset", + type=Path, + default="datasets/cambridge", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/cambridge", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=10, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + + gt_dirs = args.dataset / "CambridgeLandmarks_Colmap_Retriangulated_1024px" + + all_results = {} + for scene in args.scenes: + logger.info(f'Working on scene "{scene}".') + results = args.outputs / scene / "results.txt" + if args.overwrite or not results.exists(): + run_scene( + args.dataset / scene, + gt_dirs / scene, + args.outputs / scene, + results, + args.num_covis, + args.num_loc, + ) + all_results[scene] = results + + for scene in args.scenes: + logger.info(f'Evaluate scene "{scene}".') + evaluate( + gt_dirs / scene / "empty_all", + all_results[scene], + gt_dirs / scene / "list_query.txt", + ext=".txt", + ) diff --git a/hloc/pipelines/Cambridge/utils.py b/hloc/pipelines/Cambridge/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36460f067369065668837fa317b1c0f7047e9203 --- /dev/null +++ b/hloc/pipelines/Cambridge/utils.py @@ -0,0 +1,145 @@ +import logging + +import cv2 +import numpy as np + +from hloc.utils.read_write_model import ( + qvec2rotmat, + read_cameras_binary, + read_cameras_text, + read_images_binary, + read_images_text, + read_model, + write_model, +) + +logger = logging.getLogger(__name__) + + +def scale_sfm_images(full_model, scaled_model, image_dir): + """Duplicate the provided model and scale the camera intrinsics so that + they match the original image resolution - makes everything easier. + """ + logger.info("Scaling the COLMAP model to the original image size.") + scaled_model.mkdir(exist_ok=True) + cameras, images, points3D = read_model(full_model) + + scaled_cameras = {} + for id_, image in images.items(): + name = image.name + img = cv2.imread(str(image_dir / name)) + assert img is not None, image_dir / name + h, w = img.shape[:2] + + cam_id = image.camera_id + if cam_id in scaled_cameras: + assert scaled_cameras[cam_id].width == w + assert scaled_cameras[cam_id].height == h + continue + + camera = cameras[cam_id] + assert camera.model == "SIMPLE_RADIAL" + sx = w / camera.width + sy = h / camera.height + assert sx == sy, (sx, sy) + scaled_cameras[cam_id] = camera._replace( + width=w, height=h, params=camera.params * np.array([sx, sx, sy, 1.0]) + ) + + write_model(scaled_cameras, images, points3D, scaled_model) + + +def create_query_list_with_intrinsics( + model, out, list_file=None, ext=".bin", image_dir=None +): + """Create a list of query images with intrinsics from the colmap model.""" + if ext == ".bin": + images = read_images_binary(model / "images.bin") + cameras = read_cameras_binary(model / "cameras.bin") + else: + images = read_images_text(model / "images.txt") + cameras = read_cameras_text(model / "cameras.txt") + + name2id = {image.name: i for i, image in images.items()} + if list_file is None: + names = list(name2id) + else: + with open(list_file, "r") as f: + names = f.read().rstrip().split("\n") + data = [] + for name in names: + image = images[name2id[name]] + camera = cameras[image.camera_id] + w, h, params = camera.width, camera.height, camera.params + + if image_dir is not None: + # Check the original image size and rescale the camera intrinsics + img = cv2.imread(str(image_dir / name)) + assert img is not None, image_dir / name + h_orig, w_orig = img.shape[:2] + assert camera.model == "SIMPLE_RADIAL" + sx = w_orig / w + sy = h_orig / h + assert sx == sy, (sx, sy) + w, h = w_orig, h_orig + params = params * np.array([sx, sx, sy, 1.0]) + + p = [name, camera.model, w, h] + params.tolist() + data.append(" ".join(map(str, p))) + with open(out, "w") as f: + f.write("\n".join(data)) + + +def evaluate(model, results, list_file=None, ext=".bin", only_localized=False): + predictions = {} + with open(results, "r") as f: + for data in f.read().rstrip().split("\n"): + data = data.split() + name = data[0] + q, t = np.split(np.array(data[1:], float), [4]) + predictions[name] = (qvec2rotmat(q), t) + if ext == ".bin": + images = read_images_binary(model / "images.bin") + else: + images = read_images_text(model / "images.txt") + name2id = {image.name: i for i, image in images.items()} + + if list_file is None: + test_names = list(name2id) + else: + with open(list_file, "r") as f: + test_names = f.read().rstrip().split("\n") + + errors_t = [] + errors_R = [] + for name in test_names: + if name not in predictions: + if only_localized: + continue + e_t = np.inf + e_R = 180.0 + else: + image = images[name2id[name]] + R_gt, t_gt = image.qvec2rotmat(), image.tvec + R, t = predictions[name] + e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0) + cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1.0, 1.0) + e_R = np.rad2deg(np.abs(np.arccos(cos))) + errors_t.append(e_t) + errors_R.append(e_R) + + errors_t = np.array(errors_t) + errors_R = np.array(errors_R) + + med_t = np.median(errors_t) + med_R = np.median(errors_R) + out = f"Results for file {results.name}:" + out += f"\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg" + + out += "\nPercentage of test images localized within:" + threshs_t = [0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5.0] + threshs_R = [1.0, 2.0, 3.0, 5.0, 2.0, 5.0, 10.0] + for th_t, th_R in zip(threshs_t, threshs_R): + ratio = np.mean((errors_t < th_t) & (errors_R < th_R)) + out += f"\n\t{th_t*100:.0f}cm, {th_R:.0f}deg : {ratio*100:.2f}%" + logger.info(out) diff --git a/hloc/pipelines/RobotCar/README.md b/hloc/pipelines/RobotCar/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9881d153d4930cf32b5481ecd4fa2c900fa58c8c --- /dev/null +++ b/hloc/pipelines/RobotCar/README.md @@ -0,0 +1,16 @@ +# RobotCar Seasons dataset + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/robotcar +wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/RobotCar-Seasons/ -P $dataset +for condition in $dataset/images/*.zip; do unzip condition -d $dataset/images/; done +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.RobotCar.pipeline +``` diff --git a/hloc/pipelines/RobotCar/__init__.py b/hloc/pipelines/RobotCar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/pipelines/RobotCar/colmap_from_nvm.py b/hloc/pipelines/RobotCar/colmap_from_nvm.py new file mode 100644 index 0000000000000000000000000000000000000000..e90ed72b5391990d26961b5acfaaada6517ac191 --- /dev/null +++ b/hloc/pipelines/RobotCar/colmap_from_nvm.py @@ -0,0 +1,176 @@ +import argparse +import logging +import sqlite3 +from collections import defaultdict +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +from ...colmap_from_nvm import ( + camera_center_to_translation, + recover_database_images_and_ids, +) +from ...utils.read_write_model import ( + CAMERA_MODEL_IDS, + Camera, + Image, + Point3D, + write_model, +) + +logger = logging.getLogger(__name__) + + +def read_nvm_model(nvm_path, database_path, image_ids, camera_ids, skip_points=False): + # Extract the intrinsics from the db file instead of the NVM model + db = sqlite3.connect(str(database_path)) + ret = db.execute("SELECT camera_id, model, width, height, params FROM cameras;") + cameras = {} + for camera_id, camera_model, width, height, params in ret: + params = np.fromstring(params, dtype=np.double).reshape(-1) + camera_model = CAMERA_MODEL_IDS[camera_model] + assert len(params) == camera_model.num_params, ( + len(params), + camera_model.num_params, + ) + camera = Camera( + id=camera_id, + model=camera_model.model_name, + width=int(width), + height=int(height), + params=params, + ) + cameras[camera_id] = camera + + nvm_f = open(nvm_path, "r") + line = nvm_f.readline() + while line == "\n" or line.startswith("NVM_V3"): + line = nvm_f.readline() + num_images = int(line) + # assert num_images == len(cameras), (num_images, len(cameras)) + + logger.info(f"Reading {num_images} images...") + image_idx_to_db_image_id = [] + image_data = [] + i = 0 + while i < num_images: + line = nvm_f.readline() + if line == "\n": + continue + data = line.strip("\n").lstrip("./").split(" ") + image_data.append(data) + image_idx_to_db_image_id.append(image_ids[data[0]]) + i += 1 + + line = nvm_f.readline() + while line == "\n": + line = nvm_f.readline() + num_points = int(line) + + if skip_points: + logger.info(f"Skipping {num_points} points.") + num_points = 0 + else: + logger.info(f"Reading {num_points} points...") + points3D = {} + image_idx_to_keypoints = defaultdict(list) + i = 0 + pbar = tqdm(total=num_points, unit="pts") + while i < num_points: + line = nvm_f.readline() + if line == "\n": + continue + + data = line.strip("\n").split(" ") + x, y, z, r, g, b, num_observations = data[:7] + obs_image_ids, point2D_idxs = [], [] + for j in range(int(num_observations)): + s = 7 + 4 * j + img_index, kp_index, kx, ky = data[s : s + 4] + image_idx_to_keypoints[int(img_index)].append( + (int(kp_index), float(kx), float(ky), i) + ) + db_image_id = image_idx_to_db_image_id[int(img_index)] + obs_image_ids.append(db_image_id) + point2D_idxs.append(kp_index) + + point = Point3D( + id=i, + xyz=np.array([x, y, z], float), + rgb=np.array([r, g, b], int), + error=1.0, # fake + image_ids=np.array(obs_image_ids, int), + point2D_idxs=np.array(point2D_idxs, int), + ) + points3D[i] = point + + i += 1 + pbar.update(1) + pbar.close() + + logger.info("Parsing image data...") + images = {} + for i, data in enumerate(image_data): + # Skip the focal length. Skip the distortion and terminal 0. + name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data + qvec = np.array([qw, qx, qy, qz], float) + c = np.array([cx, cy, cz], float) + t = camera_center_to_translation(c, qvec) + + if i in image_idx_to_keypoints: + # NVM only stores triangulated 2D keypoints: add dummy ones + keypoints = image_idx_to_keypoints[i] + point2D_idxs = np.array([d[0] for d in keypoints]) + tri_xys = np.array([[x, y] for _, x, y, _ in keypoints]) + tri_ids = np.array([i for _, _, _, i in keypoints]) + + num_2Dpoints = max(point2D_idxs) + 1 + xys = np.zeros((num_2Dpoints, 2), float) + point3D_ids = np.full(num_2Dpoints, -1, int) + xys[point2D_idxs] = tri_xys + point3D_ids[point2D_idxs] = tri_ids + else: + xys = np.zeros((0, 2), float) + point3D_ids = np.full(0, -1, int) + + image_id = image_ids[name] + image = Image( + id=image_id, + qvec=qvec, + tvec=t, + camera_id=camera_ids[name], + name=name.replace("png", "jpg"), # some hack required for RobotCar + xys=xys, + point3D_ids=point3D_ids, + ) + images[image_id] = image + + return cameras, images, points3D + + +def main(nvm, database, output, skip_points=False): + assert nvm.exists(), nvm + assert database.exists(), database + + image_ids, camera_ids = recover_database_images_and_ids(database) + + logger.info("Reading the NVM model...") + model = read_nvm_model( + nvm, database, image_ids, camera_ids, skip_points=skip_points + ) + + logger.info("Writing the COLMAP model...") + output.mkdir(exist_ok=True, parents=True) + write_model(*model, path=str(output), ext=".bin") + logger.info("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--nvm", required=True, type=Path) + parser.add_argument("--database", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--skip_points", action="store_true") + args = parser.parse_args() + main(**args.__dict__) diff --git a/hloc/pipelines/RobotCar/pipeline.py b/hloc/pipelines/RobotCar/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7ee314480d09b2200b9ff1992a3217e24bea2f --- /dev/null +++ b/hloc/pipelines/RobotCar/pipeline.py @@ -0,0 +1,143 @@ +import argparse +import glob +from pathlib import Path + +from ... import ( + extract_features, + localize_sfm, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) +from . import colmap_from_nvm + +CONDITIONS = [ + "dawn", + "dusk", + "night", + "night-rain", + "overcast-summer", + "overcast-winter", + "rain", + "snow", + "sun", +] + + +def generate_query_list(dataset, image_dir, path): + h, w = 1024, 1024 + intrinsics_filename = "intrinsics/{}_intrinsics.txt" + cameras = {} + for side in ["left", "right", "rear"]: + with open(dataset / intrinsics_filename.format(side), "r") as f: + fx = f.readline().split()[1] + fy = f.readline().split()[1] + cx = f.readline().split()[1] + cy = f.readline().split()[1] + assert fx == fy + params = ["SIMPLE_RADIAL", w, h, fx, cx, cy, 0.0] + cameras[side] = [str(p) for p in params] + + queries = glob.glob((image_dir / "**/*.jpg").as_posix(), recursive=True) + queries = [ + Path(q).relative_to(image_dir.parents[0]).as_posix() for q in sorted(queries) + ] + + out = [[q] + cameras[Path(q).parent.name] for q in queries] + with open(path, "w") as f: + f.write("\n".join(map(" ".join, out))) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images/" + + outputs = args.outputs # where everything will be saved + outputs.mkdir(exist_ok=True, parents=True) + query_list = outputs / "{condition}_queries_with_intrinsics.txt" + sift_sfm = outputs / "sfm_sift" + reference_sfm = outputs / "sfm_superpoint+superglue" + sfm_pairs = outputs / f"pairs-db-covis{args.num_covis}.txt" + loc_pairs = outputs / f"pairs-query-netvlad{args.num_loc}.txt" + results = outputs / f"RobotCar_hloc_superpoint+superglue_netvlad{args.num_loc}.txt" + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_aachen"] + matcher_conf = match_features.confs["superglue"] + + for condition in CONDITIONS: + generate_query_list( + dataset, images / condition, str(query_list).format(condition=condition) + ) + + features = extract_features.main(feature_conf, images, outputs, as_half=True) + + colmap_from_nvm.main( + dataset / "3D-models/all-merged/all.nvm", + dataset / "3D-models/overcast-reference.db", + sift_sfm, + ) + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + # TODO: do per location and per camera + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix=CONDITIONS, + db_model=reference_sfm, + ) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + reference_sfm, + Path(str(query_list).format(condition="*")), + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + prepend_camera_name=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/robotcar", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/robotcar", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=20, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() diff --git a/hloc/pipelines/__init__.py b/hloc/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hloc/reconstruction.py b/hloc/reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4a90a72a73f6a34d99ffedae1e5da1e8683454 --- /dev/null +++ b/hloc/reconstruction.py @@ -0,0 +1,199 @@ +import argparse +import multiprocessing +import shutil +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pycolmap + +from . import logger +from .triangulation import ( + OutputCapture, + estimation_and_geometric_verification, + import_features, + import_matches, + parse_option_args, +) +from .utils.database import COLMAPDatabase + + +def create_empty_db(database_path: Path): + if database_path.exists(): + logger.warning("The database already exists, deleting it.") + database_path.unlink() + logger.info("Creating an empty database...") + db = COLMAPDatabase.connect(database_path) + db.create_tables() + db.commit() + db.close() + + +def import_images( + image_dir: Path, + database_path: Path, + camera_mode: pycolmap.CameraMode, + image_list: Optional[List[str]] = None, + options: Optional[Dict[str, Any]] = None, +): + logger.info("Importing images into the database...") + if options is None: + options = {} + images = list(image_dir.iterdir()) + if len(images) == 0: + raise IOError(f"No images found in {image_dir}.") + with pycolmap.ostream(): + pycolmap.import_images( + database_path, + image_dir, + camera_mode, + image_list=image_list or [], + options=options, + ) + + +def get_image_ids(database_path: Path) -> Dict[str, int]: + db = COLMAPDatabase.connect(database_path) + images = {} + for name, image_id in db.execute("SELECT name, image_id FROM images;"): + images[name] = image_id + db.close() + return images + + +def run_reconstruction( + sfm_dir: Path, + database_path: Path, + image_dir: Path, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + models_path = sfm_dir / "models" + models_path.mkdir(exist_ok=True, parents=True) + logger.info("Running 3D reconstruction...") + if options is None: + options = {} + options = {"num_threads": min(multiprocessing.cpu_count(), 16), **options} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstructions = pycolmap.incremental_mapping( + database_path, image_dir, models_path, options=options + ) + + if len(reconstructions) == 0: + logger.error("Could not reconstruct any model!") + return None + logger.info(f"Reconstructed {len(reconstructions)} model(s).") + + largest_index = None + largest_num_images = 0 + for index, rec in reconstructions.items(): + num_images = rec.num_reg_images() + if num_images > largest_num_images: + largest_index = index + largest_num_images = num_images + assert largest_index is not None + logger.info( + f"Largest model is #{largest_index} " + f"with {largest_num_images} images." + ) + + for filename in ["images.bin", "cameras.bin", "points3D.bin"]: + if (sfm_dir / filename).exists(): + (sfm_dir / filename).unlink() + shutil.move( + str(models_path / str(largest_index) / filename), str(sfm_dir) + ) + return reconstructions[largest_index] + + +def main( + sfm_dir: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + camera_mode: pycolmap.CameraMode = pycolmap.CameraMode.AUTO, + verbose: bool = False, + skip_geometric_verification: bool = False, + min_match_score: Optional[float] = None, + image_list: Optional[List[str]] = None, + image_options: Optional[Dict[str, Any]] = None, + mapper_options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / "database.db" + + create_empty_db(database) + import_images(image_dir, database, camera_mode, image_list, image_options) + image_ids = get_image_ids(database) + import_features(image_ids, database, features) + import_matches( + image_ids, + database, + pairs, + matches, + min_match_score, + skip_geometric_verification, + ) + if not skip_geometric_verification: + estimation_and_geometric_verification(database, pairs, verbose) + reconstruction = run_reconstruction( + sfm_dir, database, image_dir, verbose, mapper_options + ) + if reconstruction is not None: + logger.info( + f"Reconstruction statistics:\n{reconstruction.summary()}" + + f"\n\tnum_input_images = {len(image_ids)}" + ) + return reconstruction + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sfm_dir", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + + parser.add_argument( + "--camera_mode", + type=str, + default="AUTO", + choices=list(pycolmap.CameraMode.__members__.keys()), + ) + parser.add_argument("--skip_geometric_verification", action="store_true") + parser.add_argument("--min_match_score", type=float) + parser.add_argument("--verbose", action="store_true") + + parser.add_argument( + "--image_options", + nargs="+", + default=[], + help="List of key=value from {}".format( + pycolmap.ImageReaderOptions().todict() + ), + ) + parser.add_argument( + "--mapper_options", + nargs="+", + default=[], + help="List of key=value from {}".format( + pycolmap.IncrementalMapperOptions().todict() + ), + ) + args = parser.parse_args().__dict__ + + image_options = parse_option_args( + args.pop("image_options"), pycolmap.ImageReaderOptions() + ) + mapper_options = parse_option_args( + args.pop("mapper_options"), pycolmap.IncrementalMapperOptions() + ) + + main(**args, image_options=image_options, mapper_options=mapper_options) diff --git a/hloc/triangulation.py b/hloc/triangulation.py new file mode 100644 index 0000000000000000000000000000000000000000..385fed97e1e2093d9e05331c1525f11f95c885cd --- /dev/null +++ b/hloc/triangulation.py @@ -0,0 +1,320 @@ +import argparse +import contextlib +import io +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pycolmap +from tqdm import tqdm + +from . import logger +from .utils.database import COLMAPDatabase +from .utils.geometry import compute_epipolar_errors +from .utils.io import get_keypoints, get_matches +from .utils.parsers import parse_retrieval + + +class OutputCapture: + def __init__(self, verbose: bool): + self.verbose = verbose + + def __enter__(self): + if not self.verbose: + self.capture = contextlib.redirect_stdout(io.StringIO()) + self.out = self.capture.__enter__() + + def __exit__(self, exc_type, *args): + if not self.verbose: + self.capture.__exit__(exc_type, *args) + if exc_type is not None: + logger.error("Failed with output:\n%s", self.out.getvalue()) + sys.stdout.flush() + + +def create_db_from_model( + reconstruction: pycolmap.Reconstruction, database_path: Path +) -> Dict[str, int]: + if database_path.exists(): + logger.warning("The database already exists, deleting it.") + database_path.unlink() + + db = COLMAPDatabase.connect(database_path) + db.create_tables() + + for i, camera in reconstruction.cameras.items(): + db.add_camera( + camera.model.value, + camera.width, + camera.height, + camera.params, + camera_id=i, + prior_focal_length=True, + ) + + for i, image in reconstruction.images.items(): + db.add_image(image.name, image.camera_id, image_id=i) + + db.commit() + db.close() + return {image.name: i for i, image in reconstruction.images.items()} + + +def import_features( + image_ids: Dict[str, int], database_path: Path, features_path: Path +): + logger.info("Importing features into the database...") + db = COLMAPDatabase.connect(database_path) + + for image_name, image_id in tqdm(image_ids.items()): + keypoints = get_keypoints(features_path, image_name) + keypoints += 0.5 # COLMAP origin + db.add_keypoints(image_id, keypoints) + + db.commit() + db.close() + + +def import_matches( + image_ids: Dict[str, int], + database_path: Path, + pairs_path: Path, + matches_path: Path, + min_match_score: Optional[float] = None, + skip_geometric_verification: bool = False, +): + logger.info("Importing matches into the database...") + + with open(str(pairs_path), "r") as f: + pairs = [p.split() for p in f.readlines()] + + db = COLMAPDatabase.connect(database_path) + + matched = set() + for name0, name1 in tqdm(pairs): + id0, id1 = image_ids[name0], image_ids[name1] + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matches, scores = get_matches(matches_path, name0, name1) + if min_match_score: + matches = matches[scores > min_match_score] + db.add_matches(id0, id1, matches) + matched |= {(id0, id1), (id1, id0)} + + if skip_geometric_verification: + db.add_two_view_geometry(id0, id1, matches) + + db.commit() + db.close() + + +def estimation_and_geometric_verification( + database_path: Path, pairs_path: Path, verbose: bool = False +): + logger.info("Performing geometric verification of the matches...") + with OutputCapture(verbose): + with pycolmap.ostream(): + pycolmap.verify_matches( + database_path, + pairs_path, + options=dict( + ransac=dict(max_num_trials=20000, min_inlier_ratio=0.1) + ), + ) + + +def geometric_verification( + image_ids: Dict[str, int], + reference: pycolmap.Reconstruction, + database_path: Path, + features_path: Path, + pairs_path: Path, + matches_path: Path, + max_error: float = 4.0, +): + logger.info("Performing geometric verification of the matches...") + + pairs = parse_retrieval(pairs_path) + db = COLMAPDatabase.connect(database_path) + + inlier_ratios = [] + matched = set() + for name0 in tqdm(pairs): + id0 = image_ids[name0] + image0 = reference.images[id0] + cam0 = reference.cameras[image0.camera_id] + kps0, noise0 = get_keypoints( + features_path, name0, return_uncertainty=True + ) + noise0 = 1.0 if noise0 is None else noise0 + if len(kps0) > 0: + kps0 = np.stack(cam0.cam_from_img(kps0)) + else: + kps0 = np.zeros((0, 2)) + + for name1 in pairs[name0]: + id1 = image_ids[name1] + image1 = reference.images[id1] + cam1 = reference.cameras[image1.camera_id] + kps1, noise1 = get_keypoints( + features_path, name1, return_uncertainty=True + ) + noise1 = 1.0 if noise1 is None else noise1 + if len(kps1) > 0: + kps1 = np.stack(cam1.cam_from_img(kps1)) + else: + kps1 = np.zeros((0, 2)) + + matches = get_matches(matches_path, name0, name1)[0] + + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matched |= {(id0, id1), (id1, id0)} + + if matches.shape[0] == 0: + db.add_two_view_geometry(id0, id1, matches) + continue + + cam1_from_cam0 = ( + image1.cam_from_world * image0.cam_from_world.inverse() + ) + errors0, errors1 = compute_epipolar_errors( + cam1_from_cam0, kps0[matches[:, 0]], kps1[matches[:, 1]] + ) + valid_matches = np.logical_and( + errors0 <= cam0.cam_from_img_threshold(noise0 * max_error), + errors1 <= cam1.cam_from_img_threshold(noise1 * max_error), + ) + # TODO: We could also add E to the database, but we need + # to reverse the transformations if id0 > id1 in utils/database.py. + db.add_two_view_geometry(id0, id1, matches[valid_matches, :]) + inlier_ratios.append(np.mean(valid_matches)) + logger.info( + "mean/med/min/max valid matches %.2f/%.2f/%.2f/%.2f%%.", + np.mean(inlier_ratios) * 100, + np.median(inlier_ratios) * 100, + np.min(inlier_ratios) * 100, + np.max(inlier_ratios) * 100, + ) + + db.commit() + db.close() + + +def run_triangulation( + model_path: Path, + database_path: Path, + image_dir: Path, + reference_model: pycolmap.Reconstruction, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + model_path.mkdir(parents=True, exist_ok=True) + logger.info("Running 3D triangulation...") + if options is None: + options = {} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstruction = pycolmap.triangulate_points( + reference_model, + database_path, + image_dir, + model_path, + options=options, + ) + return reconstruction + + +def main( + sfm_dir: Path, + reference_model: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + skip_geometric_verification: bool = False, + estimate_two_view_geometries: bool = False, + min_match_score: Optional[float] = None, + verbose: bool = False, + mapper_options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + assert reference_model.exists(), reference_model + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / "database.db" + reference = pycolmap.Reconstruction(reference_model) + + image_ids = create_db_from_model(reference, database) + import_features(image_ids, database, features) + import_matches( + image_ids, + database, + pairs, + matches, + min_match_score, + skip_geometric_verification, + ) + if not skip_geometric_verification: + if estimate_two_view_geometries: + estimation_and_geometric_verification(database, pairs, verbose) + else: + geometric_verification( + image_ids, reference, database, features, pairs, matches + ) + reconstruction = run_triangulation( + sfm_dir, database, image_dir, reference, verbose, mapper_options + ) + logger.info( + "Finished the triangulation with statistics:\n%s", + reconstruction.summary(), + ) + return reconstruction + + +def parse_option_args(args: List[str], default_options) -> Dict[str, Any]: + options = {} + for arg in args: + idx = arg.find("=") + if idx == -1: + raise ValueError("Options format: key1=value1 key2=value2 etc.") + key, value = arg[:idx], arg[idx + 1 :] + if not hasattr(default_options, key): + raise ValueError( + f'Unknown option "{key}", allowed options and default values' + f" for {default_options.summary()}" + ) + value = eval(value) + target_type = type(getattr(default_options, key)) + if not isinstance(value, target_type): + raise ValueError( + f'Incorrect type for option "{key}":' + f" {type(value)} vs {target_type}" + ) + options[key] = value + return options + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sfm_dir", type=Path, required=True) + parser.add_argument("--reference_sfm_model", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + + parser.add_argument("--skip_geometric_verification", action="store_true") + parser.add_argument("--min_match_score", type=float) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args().__dict__ + + mapper_options = parse_option_args( + args.pop("mapper_options"), pycolmap.IncrementalMapperOptions() + ) + + main(**args, mapper_options=mapper_options) diff --git a/hloc/utils/__init__.py b/hloc/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c1e6e13ec689af7d948e5155ca773ee038df7bb --- /dev/null +++ b/hloc/utils/__init__.py @@ -0,0 +1,13 @@ +import os +import logging +import sys +from .. import logger + + +def do_system(cmd, verbose=False): + if verbose: + logger.info(f"Run cmd: `{cmd}`.") + err = os.system(cmd) + if err: + logger.info(f"Run cmd err.") + sys.exit(err) diff --git a/hloc/utils/base_model.py b/hloc/utils/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f560a2664eeb7ff53b49e169289ab284c16cb0ec --- /dev/null +++ b/hloc/utils/base_model.py @@ -0,0 +1,47 @@ +import sys +from abc import ABCMeta, abstractmethod +from torch import nn +from copy import copy +import inspect + + +class BaseModel(nn.Module, metaclass=ABCMeta): + default_conf = {} + required_inputs = [] + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + self.conf = conf = {**self.default_conf, **conf} + self.required_inputs = copy(self.required_inputs) + self._init(conf) + sys.stdout.flush() + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + for key in self.required_inputs: + assert key in data, "Missing key {} in data".format(key) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + +def dynamic_load(root, model): + module_path = f"{root.__name__}.{model}" + module = __import__(module_path, fromlist=[""]) + classes = inspect.getmembers(module, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == module_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseModel)] + assert len(classes) == 1, classes + return classes[0][1] + # return getattr(module, 'Model') diff --git a/hloc/utils/database.py b/hloc/utils/database.py new file mode 100644 index 0000000000000000000000000000000000000000..683c250594c9fe990567a6c0099d5a0631f23b0d --- /dev/null +++ b/hloc/utils/database.py @@ -0,0 +1,414 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +# This script is based on an original implementation by True Price. + +import sqlite3 +import sys + +import numpy as np + +IS_PYTHON3 = sys.version_info[0] >= 3 + +MAX_IMAGE_ID = 2**31 - 1 + +CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( + camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + model INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + params BLOB, + prior_focal_length INTEGER NOT NULL)""" + +CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( + image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE, + camera_id INTEGER NOT NULL, + prior_qw REAL, + prior_qx REAL, + prior_qy REAL, + prior_qz REAL, + prior_tx REAL, + prior_ty REAL, + prior_tz REAL, + CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), + FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) +""".format( + MAX_IMAGE_ID +) + +CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ +CREATE TABLE IF NOT EXISTS two_view_geometries ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + config INTEGER NOT NULL, + F BLOB, + E BLOB, + H BLOB, + qvec BLOB, + tvec BLOB) +""" + +CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) +""" + +CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB)""" + +CREATE_NAME_INDEX = "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" + +CREATE_ALL = "; ".join( + [ + CREATE_CAMERAS_TABLE, + CREATE_IMAGES_TABLE, + CREATE_KEYPOINTS_TABLE, + CREATE_DESCRIPTORS_TABLE, + CREATE_MATCHES_TABLE, + CREATE_TWO_VIEW_GEOMETRIES_TABLE, + CREATE_NAME_INDEX, + ] +) + + +def image_ids_to_pair_id(image_id1, image_id2): + if image_id1 > image_id2: + image_id1, image_id2 = image_id2, image_id1 + return image_id1 * MAX_IMAGE_ID + image_id2 + + +def pair_id_to_image_ids(pair_id): + image_id2 = pair_id % MAX_IMAGE_ID + image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID + return image_id1, image_id2 + + +def array_to_blob(array): + if IS_PYTHON3: + return array.tobytes() + else: + return np.getbuffer(array) + + +def blob_to_array(blob, dtype, shape=(-1,)): + if IS_PYTHON3: + return np.fromstring(blob, dtype=dtype).reshape(*shape) + else: + return np.frombuffer(blob, dtype=dtype).reshape(*shape) + + +class COLMAPDatabase(sqlite3.Connection): + @staticmethod + def connect(database_path): + return sqlite3.connect(str(database_path), factory=COLMAPDatabase) + + def __init__(self, *args, **kwargs): + super(COLMAPDatabase, self).__init__(*args, **kwargs) + + self.create_tables = lambda: self.executescript(CREATE_ALL) + self.create_cameras_table = lambda: self.executescript(CREATE_CAMERAS_TABLE) + self.create_descriptors_table = lambda: self.executescript( + CREATE_DESCRIPTORS_TABLE + ) + self.create_images_table = lambda: self.executescript(CREATE_IMAGES_TABLE) + self.create_two_view_geometries_table = lambda: self.executescript( + CREATE_TWO_VIEW_GEOMETRIES_TABLE + ) + self.create_keypoints_table = lambda: self.executescript(CREATE_KEYPOINTS_TABLE) + self.create_matches_table = lambda: self.executescript(CREATE_MATCHES_TABLE) + self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) + + def add_camera( + self, model, width, height, params, prior_focal_length=False, camera_id=None + ): + params = np.asarray(params, np.float64) + cursor = self.execute( + "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", + ( + camera_id, + model, + width, + height, + array_to_blob(params), + prior_focal_length, + ), + ) + return cursor.lastrowid + + def add_image( + self, + name, + camera_id, + prior_q=np.full(4, np.NaN), + prior_t=np.full(3, np.NaN), + image_id=None, + ): + cursor = self.execute( + "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + image_id, + name, + camera_id, + prior_q[0], + prior_q[1], + prior_q[2], + prior_q[3], + prior_t[0], + prior_t[1], + prior_t[2], + ), + ) + return cursor.lastrowid + + def add_keypoints(self, image_id, keypoints): + assert len(keypoints.shape) == 2 + assert keypoints.shape[1] in [2, 4, 6] + + keypoints = np.asarray(keypoints, np.float32) + self.execute( + "INSERT INTO keypoints VALUES (?, ?, ?, ?)", + (image_id,) + keypoints.shape + (array_to_blob(keypoints),), + ) + + def add_descriptors(self, image_id, descriptors): + descriptors = np.ascontiguousarray(descriptors, np.uint8) + self.execute( + "INSERT INTO descriptors VALUES (?, ?, ?, ?)", + (image_id,) + descriptors.shape + (array_to_blob(descriptors),), + ) + + def add_matches(self, image_id1, image_id2, matches): + assert len(matches.shape) == 2 + assert matches.shape[1] == 2 + + if image_id1 > image_id2: + matches = matches[:, ::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + self.execute( + "INSERT INTO matches VALUES (?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches),), + ) + + def add_two_view_geometry( + self, + image_id1, + image_id2, + matches, + F=np.eye(3), + E=np.eye(3), + H=np.eye(3), + qvec=np.array([1.0, 0.0, 0.0, 0.0]), + tvec=np.zeros(3), + config=2, + ): + assert len(matches.shape) == 2 + assert matches.shape[1] == 2 + + if image_id1 > image_id2: + matches = matches[:, ::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + F = np.asarray(F, dtype=np.float64) + E = np.asarray(E, dtype=np.float64) + H = np.asarray(H, dtype=np.float64) + qvec = np.asarray(qvec, dtype=np.float64) + tvec = np.asarray(tvec, dtype=np.float64) + self.execute( + "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (pair_id,) + + matches.shape + + ( + array_to_blob(matches), + config, + array_to_blob(F), + array_to_blob(E), + array_to_blob(H), + array_to_blob(qvec), + array_to_blob(tvec), + ), + ) + + +def example_usage(): + import os + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--database_path", default="database.db") + args = parser.parse_args() + + if os.path.exists(args.database_path): + print("ERROR: database path already exists -- will not modify it.") + return + + # Open the database. + + db = COLMAPDatabase.connect(args.database_path) + + # For convenience, try creating all the tables upfront. + + db.create_tables() + + # Create dummy cameras. + + model1, width1, height1, params1 = ( + 0, + 1024, + 768, + np.array((1024.0, 512.0, 384.0)), + ) + model2, width2, height2, params2 = ( + 2, + 1024, + 768, + np.array((1024.0, 512.0, 384.0, 0.1)), + ) + + camera_id1 = db.add_camera(model1, width1, height1, params1) + camera_id2 = db.add_camera(model2, width2, height2, params2) + + # Create dummy images. + + image_id1 = db.add_image("image1.png", camera_id1) + image_id2 = db.add_image("image2.png", camera_id1) + image_id3 = db.add_image("image3.png", camera_id2) + image_id4 = db.add_image("image4.png", camera_id2) + + # Create dummy keypoints. + # + # Note that COLMAP supports: + # - 2D keypoints: (x, y) + # - 4D keypoints: (x, y, theta, scale) + # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) + + num_keypoints = 1000 + keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) + keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) + + db.add_keypoints(image_id1, keypoints1) + db.add_keypoints(image_id2, keypoints2) + db.add_keypoints(image_id3, keypoints3) + db.add_keypoints(image_id4, keypoints4) + + # Create dummy matches. + + M = 50 + matches12 = np.random.randint(num_keypoints, size=(M, 2)) + matches23 = np.random.randint(num_keypoints, size=(M, 2)) + matches34 = np.random.randint(num_keypoints, size=(M, 2)) + + db.add_matches(image_id1, image_id2, matches12) + db.add_matches(image_id2, image_id3, matches23) + db.add_matches(image_id3, image_id4, matches34) + + # Commit the data to the file. + + db.commit() + + # Read and check cameras. + + rows = db.execute("SELECT * FROM cameras") + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id1 + assert model == model1 and width == width1 and height == height1 + assert np.allclose(params, params1) + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id2 + assert model == model2 and width == width2 and height == height2 + assert np.allclose(params, params2) + + # Read and check keypoints. + + keypoints = dict( + (image_id, blob_to_array(data, np.float32, (-1, 2))) + for image_id, data in db.execute("SELECT image_id, data FROM keypoints") + ) + + assert np.allclose(keypoints[image_id1], keypoints1) + assert np.allclose(keypoints[image_id2], keypoints2) + assert np.allclose(keypoints[image_id3], keypoints3) + assert np.allclose(keypoints[image_id4], keypoints4) + + # Read and check matches. + + pair_ids = [ + image_ids_to_pair_id(*pair) + for pair in ( + (image_id1, image_id2), + (image_id2, image_id3), + (image_id3, image_id4), + ) + ] + + matches = dict( + (pair_id_to_image_ids(pair_id), blob_to_array(data, np.uint32, (-1, 2))) + for pair_id, data in db.execute("SELECT pair_id, data FROM matches") + ) + + assert np.all(matches[(image_id1, image_id2)] == matches12) + assert np.all(matches[(image_id2, image_id3)] == matches23) + assert np.all(matches[(image_id3, image_id4)] == matches34) + + # Clean up. + + db.close() + + if os.path.exists(args.database_path): + os.remove(args.database_path) + + +if __name__ == "__main__": + example_usage() diff --git a/hloc/utils/geometry.py b/hloc/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..5995cccda40354f2346ad84fa966614b4ccbfed0 --- /dev/null +++ b/hloc/utils/geometry.py @@ -0,0 +1,16 @@ +import numpy as np +import pycolmap + + +def to_homogeneous(p): + return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1) + + +def compute_epipolar_errors(j_from_i: pycolmap.Rigid3d, p2d_i, p2d_j): + j_E_i = j_from_i.essential_matrix() + l2d_j = to_homogeneous(p2d_i) @ j_E_i.T + l2d_i = to_homogeneous(p2d_j) @ j_E_i + dist = np.abs(np.sum(to_homogeneous(p2d_i) * l2d_i, axis=1)) + errors_i = dist / np.linalg.norm(l2d_i[:, :2], axis=1) + errors_j = dist / np.linalg.norm(l2d_j[:, :2], axis=1) + return errors_i, errors_j diff --git a/hloc/utils/io.py b/hloc/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd55d4c30b41c3754634a164312dc5e8c294274 --- /dev/null +++ b/hloc/utils/io.py @@ -0,0 +1,77 @@ +from typing import Tuple +from pathlib import Path +import numpy as np +import cv2 +import h5py + +from .parsers import names_to_pair, names_to_pair_old + + +def read_image(path, grayscale=False): + if grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise ValueError(f"Cannot read image {path}.") + if not grayscale and len(image.shape) == 3: + image = image[:, :, ::-1] # BGR to RGB + return image + + +def list_h5_names(path): + names = [] + with h5py.File(str(path), "r", libver="latest") as fd: + + def visit_fn(_, obj): + if isinstance(obj, h5py.Dataset): + names.append(obj.parent.name.strip("/")) + + fd.visititems(visit_fn) + return list(set(names)) + + +def get_keypoints( + path: Path, name: str, return_uncertainty: bool = False +) -> np.ndarray: + with h5py.File(str(path), "r", libver="latest") as hfile: + dset = hfile[name]["keypoints"] + p = dset.__array__() + uncertainty = dset.attrs.get("uncertainty") + if return_uncertainty: + return p, uncertainty + return p + + +def find_pair(hfile: h5py.File, name0: str, name1: str): + pair = names_to_pair(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair(name1, name0) + if pair in hfile: + return pair, True + # older, less efficient format + pair = names_to_pair_old(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair_old(name1, name0) + if pair in hfile: + return pair, True + raise ValueError( + f"Could not find pair {(name0, name1)}... " + "Maybe you matched with a different list of pairs? " + ) + + +def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]: + with h5py.File(str(path), "r", libver="latest") as hfile: + pair, reverse = find_pair(hfile, name0, name1) + matches = hfile[pair]["matches0"].__array__() + scores = hfile[pair]["matching_scores0"].__array__() + idx = np.where(matches != -1)[0] + matches = np.stack([idx, matches[idx]], -1) + if reverse: + matches = np.flip(matches, -1) + scores = scores[idx] + return matches, scores diff --git a/hloc/utils/parsers.py b/hloc/utils/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..9407dcf916d67170e3f8d19581041a774a61e84b --- /dev/null +++ b/hloc/utils/parsers.py @@ -0,0 +1,59 @@ +import logging +from collections import defaultdict +from pathlib import Path + +import numpy as np +import pycolmap + +logger = logging.getLogger(__name__) + + +def parse_image_list(path, with_intrinsics=False): + images = [] + with open(path, "r") as f: + for line in f: + line = line.strip("\n") + if len(line) == 0 or line[0] == "#": + continue + name, *data = line.split() + if with_intrinsics: + model, width, height, *params = data + params = np.array(params, float) + cam = pycolmap.Camera( + model=model, width=int(width), height=int(height), params=params + ) + images.append((name, cam)) + else: + images.append(name) + + assert len(images) > 0 + logger.info(f"Imported {len(images)} images from {path.name}") + return images + + +def parse_image_lists(paths, with_intrinsics=False): + images = [] + files = list(Path(paths.parent).glob(paths.name)) + assert len(files) > 0 + for lfile in files: + images += parse_image_list(lfile, with_intrinsics=with_intrinsics) + return images + + +def parse_retrieval(path): + retrieval = defaultdict(list) + with open(path, "r") as f: + for p in f.read().rstrip("\n").split("\n"): + if len(p) == 0: + continue + q, r = p.split() + retrieval[q].append(r) + return dict(retrieval) + + +def names_to_pair(name0, name1, separator="/"): + return separator.join((name0.replace("/", "-"), name1.replace("/", "-"))) + + +def names_to_pair_old(name0, name1): + return names_to_pair(name0, name1, separator="_") diff --git a/hloc/utils/read_write_model.py b/hloc/utils/read_write_model.py new file mode 100644 index 0000000000000000000000000000000000000000..197921ded6d9cad3f365fd68a225822dc5411aee --- /dev/null +++ b/hloc/utils/read_write_model.py @@ -0,0 +1,588 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +import argparse +import collections +import logging +import os +import struct + +import numpy as np + +logger = logging.getLogger(__name__) + + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"] +) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = dict( + [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] +) +CAMERA_MODEL_NAMES = dict( + [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] +) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, model=model, width=width, height=height, params=params + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ" + ) + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = ( + "# Camera list with one line of data per camera:\n" + + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + + "# Number of cameras: {}\n".format(len(cameras)) + ) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum( + (len(img.point3D_ids) for _, img in images.items()) + ) / len(images) + HEADER = ( + "# Image list with two lines of data per image:\n" + + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum( + (len(pt.image_ids) for _, pt in points3D.items()) + ) / len(points3D) + HEADER = ( + "# 3D point list with one line of data per point:\n" + + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" # noqa: E501 + + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3D_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if ( + os.path.isfile(os.path.join(path, "cameras" + ext)) + and os.path.isfile(os.path.join(path, "images" + ext)) + and os.path.isfile(os.path.join(path, "points3D" + ext)) + ): + return True + + return False + + +def read_model(path, ext=""): + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + try: + cameras, images, points3D = read_model(os.path.join(path, "model/")) + logger.warning("This SfM file structure was deprecated in hloc v1.1") + return cameras, images, points3D + except FileNotFoundError: + raise FileNotFoundError( + f"Could not find binary or text COLMAP model at {path}" + ) + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def main(): + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + help="input model format", + default="", + ) + parser.add_argument("--output_model", help="path to output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="outut model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, images, points3D, path=args.output_model, ext=args.output_format + ) + + +if __name__ == "__main__": + main() diff --git a/hloc/utils/viz.py b/hloc/utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..360466b8a4487bad8b7c0b687081c15f99c6d63f --- /dev/null +++ b/hloc/utils/viz.py @@ -0,0 +1,144 @@ +""" +2D visualization primitives based on Matplotlib. + +1) Plot images with `plot_images`. +2) Call `plot_keypoints` or `plot_matches` any number of times. +3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. +""" + +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import numpy as np + + +def cm_RdGn(x): + """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" + x = np.clip(x, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) + return np.clip(c, 0, 1) + + +def plot_images( + imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True, figsize=4.5 +): + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. + adaptive: whether the figure size should fit the image aspect ratios. + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + if adaptive: + ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H + else: + ratios = [4 / 3] * n + figsize = [sum(ratios) * figsize, figsize] + fig, axs = plt.subplots( + 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} + ) + if n == 1: + axs = [axs] + for i, (img, ax) in enumerate(zip(imgs, axs)): + ax.imshow(img, cmap=plt.get_cmap(cmaps[i])) + ax.set_axis_off() + if titles: + ax.set_title(titles[i]) + fig.tight_layout(pad=pad) + return fig + +def plot_keypoints(kpts, colors="lime", ps=4): + """Plot keypoints for existing images. + Args: + kpts: list of ndarrays of size (N, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float. + """ + if not isinstance(colors, list): + colors = [colors] * len(kpts) + axes = plt.gcf().axes + try: + for a, k, c in zip(axes, kpts, colors): + a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) + except IndexError as e: + pass + + +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0): + """Plot matches for a pair of existing images. + Args: + kpts0, kpts1: corresponding keypoints of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + ps: size of the end points (no endpoint if ps=0) + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. + """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + for i in range(len(kpts0)): + fig.add_artist( + matplotlib.patches.ConnectionPatch( + xyA=(kpts0[i, 0], kpts0[i, 1]), + coordsA=ax0.transData, + xyB=(kpts1[i, 0], kpts1[i, 1]), + coordsB=ax1.transData, + zorder=1, + color=color[i], + linewidth=lw, + alpha=a, + ) + ) + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + if ps > 0: + ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) + ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) + + +def add_text( + idx, + text, + pos=(0.01, 0.99), + fs=15, + color="w", + lcolor="k", + lwidth=2, + ha="left", + va="top", +): + ax = plt.gcf().axes[idx] + t = ax.text( + *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes + ) + if lcolor is not None: + t.set_path_effects( + [ + path_effects.Stroke(linewidth=lwidth, foreground=lcolor), + path_effects.Normal(), + ] + ) + + +def save_plot(path, **kw): + """Save the current figure without any white margin.""" + plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) diff --git a/hloc/utils/viz_3d.py b/hloc/utils/viz_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e608f7828306c43ef2a6a7898752d70469f9bac9 --- /dev/null +++ b/hloc/utils/viz_3d.py @@ -0,0 +1,203 @@ +""" +3D visualization based on plotly. +Works for a small number of points and cameras, might be slow otherwise. + +1) Initialize a figure with `init_figure` +2) Add 3D points, camera frustums, or both as a pycolmap.Reconstruction + +Written by Paul-Edouard Sarlin and Philipp Lindenberger. +""" + +from typing import Optional + +import numpy as np +import plotly.graph_objects as go +import pycolmap + + +def to_homogeneous(points): + pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype) + return np.concatenate([points, pad], axis=-1) + + +def init_figure(height: int = 800) -> go.Figure: + """Initialize a 3D figure.""" + fig = go.Figure() + axes = dict( + visible=False, + showbackground=False, + showgrid=False, + showline=False, + showticklabels=True, + autorange=True, + ) + fig.update_layout( + template="plotly_dark", + height=height, + scene_camera=dict( + eye=dict(x=0.0, y=-0.1, z=-2), + up=dict(x=0, y=-1.0, z=0), + projection=dict(type="orthographic"), + ), + scene=dict( + xaxis=axes, + yaxis=axes, + zaxis=axes, + aspectmode="data", + dragmode="orbit", + ), + margin=dict(l=0, r=0, b=0, t=0, pad=0), + legend=dict(orientation="h", yanchor="top", y=0.99, xanchor="left", x=0.1), + ) + return fig + + +def plot_points( + fig: go.Figure, + pts: np.ndarray, + color: str = "rgba(255, 0, 0, 1)", + ps: int = 2, + colorscale: Optional[str] = None, + name: Optional[str] = None, +): + """Plot a set of 3D points.""" + x, y, z = pts.T + tr = go.Scatter3d( + x=x, + y=y, + z=z, + mode="markers", + name=name, + legendgroup=name, + marker=dict(size=ps, color=color, line_width=0.0, colorscale=colorscale), + ) + fig.add_trace(tr) + + +def plot_camera( + fig: go.Figure, + R: np.ndarray, + t: np.ndarray, + K: np.ndarray, + color: str = "rgb(0, 0, 255)", + name: Optional[str] = None, + legendgroup: Optional[str] = None, + fill: bool = False, + size: float = 1.0, + text: Optional[str] = None, +): + """Plot a camera frustum from pose and intrinsic matrix.""" + W, H = K[0, 2] * 2, K[1, 2] * 2 + corners = np.array([[0, 0], [W, 0], [W, H], [0, H], [0, 0]]) + if size is not None: + image_extent = max(size * W / 1024.0, size * H / 1024.0) + world_extent = max(W, H) / (K[0, 0] + K[1, 1]) / 0.5 + scale = 0.5 * image_extent / world_extent + else: + scale = 1.0 + corners = to_homogeneous(corners) @ np.linalg.inv(K).T + corners = (corners / 2 * scale) @ R.T + t + legendgroup = legendgroup if legendgroup is not None else name + + x, y, z = np.concatenate(([t], corners)).T + i = [0, 0, 0, 0] + j = [1, 2, 3, 4] + k = [2, 3, 4, 1] + + if fill: + pyramid = go.Mesh3d( + x=x, + y=y, + z=z, + color=color, + i=i, + j=j, + k=k, + legendgroup=legendgroup, + name=name, + showlegend=False, + hovertemplate=text.replace("\n", "
"), + ) + fig.add_trace(pyramid) + + triangles = np.vstack((i, j, k)).T + vertices = np.concatenate(([t], corners)) + tri_points = np.array([vertices[i] for i in triangles.reshape(-1)]) + x, y, z = tri_points.T + + pyramid = go.Scatter3d( + x=x, + y=y, + z=z, + mode="lines", + legendgroup=legendgroup, + name=name, + line=dict(color=color, width=1), + showlegend=False, + hovertemplate=text.replace("\n", "
"), + ) + fig.add_trace(pyramid) + + +def plot_camera_colmap( + fig: go.Figure, + image: pycolmap.Image, + camera: pycolmap.Camera, + name: Optional[str] = None, + **kwargs +): + """Plot a camera frustum from PyCOLMAP objects""" + world_t_camera = image.cam_from_world.inverse() + plot_camera( + fig, + world_t_camera.rotation.matrix(), + world_t_camera.translation, + camera.calibration_matrix(), + name=name or str(image.image_id), + text=str(image), + **kwargs + ) + + +def plot_cameras(fig: go.Figure, reconstruction: pycolmap.Reconstruction, **kwargs): + """Plot a camera as a cone with camera frustum.""" + for image_id, image in reconstruction.images.items(): + plot_camera_colmap( + fig, image, reconstruction.cameras[image.camera_id], **kwargs + ) + + +def plot_reconstruction( + fig: go.Figure, + rec: pycolmap.Reconstruction, + max_reproj_error: float = 6.0, + color: str = "rgb(0, 0, 255)", + name: Optional[str] = None, + min_track_length: int = 2, + points: bool = True, + cameras: bool = True, + points_rgb: bool = True, + cs: float = 1.0, +): + # Filter outliers + bbs = rec.compute_bounding_box(0.001, 0.999) + # Filter points, use original reproj error here + p3Ds = [ + p3D + for _, p3D in rec.points3D.items() + if ( + (p3D.xyz >= bbs[0]).all() + and (p3D.xyz <= bbs[1]).all() + and p3D.error <= max_reproj_error + and p3D.track.length() >= min_track_length + ) + ] + xyzs = [p3D.xyz for p3D in p3Ds] + if points_rgb: + pcolor = [p3D.color for p3D in p3Ds] + else: + pcolor = color + if points: + plot_points(fig, np.array(xyzs), color=pcolor, ps=1, name=name) + if cameras: + plot_cameras(fig, rec, color=color, legendgroup=name, size=cs) diff --git a/hloc/visualization.py b/hloc/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..77369efb3ca9485bf3d60c4837c934d86191d15d --- /dev/null +++ b/hloc/visualization.py @@ -0,0 +1,182 @@ +import pickle +import random + +import numpy as np +import pycolmap +from matplotlib import cm + +from .utils.io import read_image +from .utils.viz import ( + add_text, + cm_RdGn, + plot_images, + plot_keypoints, + plot_matches, +) + + +def visualize_sfm_2d( + reconstruction, + image_dir, + color_by="visibility", + selected=[], + n=1, + seed=0, + dpi=75, +): + assert image_dir.exists() + if not isinstance(reconstruction, pycolmap.Reconstruction): + reconstruction = pycolmap.Reconstruction(reconstruction) + + if not selected: + image_ids = reconstruction.reg_image_ids() + selected = random.Random(seed).sample(image_ids, min(n, len(image_ids))) + + for i in selected: + image = reconstruction.images[i] + keypoints = np.array([p.xy for p in image.points2D]) + visible = np.array([p.has_point3D() for p in image.points2D]) + + if color_by == "visibility": + color = [(0, 0, 1) if v else (1, 0, 0) for v in visible] + text = f"visible: {np.count_nonzero(visible)}/{len(visible)}" + elif color_by == "track_length": + tl = np.array( + [ + ( + reconstruction.points3D[p.point3D_id].track.length() + if p.has_point3D() + else 1 + ) + for p in image.points2D + ] + ) + max_, med_ = np.max(tl), np.median(tl[tl > 1]) + tl = np.log(tl) + color = cm.jet(tl / tl.max()).tolist() + text = f"max/median track length: {max_}/{med_}" + elif color_by == "depth": + p3ids = [p.point3D_id for p in image.points2D if p.has_point3D()] + z = np.array( + [ + (image.cam_from_world * reconstruction.points3D[j].xyz)[-1] + for j in p3ids + ] + ) + z -= z.min() + color = cm.jet(z / np.percentile(z, 99.9)) + text = f"visible: {np.count_nonzero(visible)}/{len(visible)}" + keypoints = keypoints[visible] + else: + raise NotImplementedError(f"Coloring not implemented: {color_by}.") + + name = image.name + fig = plot_images([read_image(image_dir / name)], dpi=dpi) + plot_keypoints([keypoints], colors=[color], ps=4) + add_text(0, text) + add_text(0, name, pos=(0.01, 0.01), fs=5, lcolor=None, va="bottom") + return fig + + +def visualize_loc( + results, + image_dir, + reconstruction=None, + db_image_dir=None, + selected=[], + n=1, + seed=0, + prefix=None, + **kwargs, +): + assert image_dir.exists() + + with open(str(results) + "_logs.pkl", "rb") as f: + logs = pickle.load(f) + + if not selected: + queries = list(logs["loc"].keys()) + if prefix: + queries = [q for q in queries if q.startswith(prefix)] + selected = random.Random(seed).sample(queries, min(n, len(queries))) + + if reconstruction is not None: + if not isinstance(reconstruction, pycolmap.Reconstruction): + reconstruction = pycolmap.Reconstruction(reconstruction) + + for qname in selected: + loc = logs["loc"][qname] + visualize_loc_from_log( + image_dir, qname, loc, reconstruction, db_image_dir, **kwargs + ) + + +def visualize_loc_from_log( + image_dir, + query_name, + loc, + reconstruction=None, + db_image_dir=None, + top_k_db=2, + dpi=75, +): + q_image = read_image(image_dir / query_name) + if loc.get("covisibility_clustering", False): + # select the first, largest cluster if the localization failed + loc = loc["log_clusters"][loc["best_cluster"] or 0] + + inliers = np.array(loc["PnP_ret"]["inliers"]) + mkp_q = loc["keypoints_query"] + n = len(loc["db"]) + if reconstruction is not None: + # for each pair of query keypoint and its matched 3D point, + # we need to find its corresponding keypoint in each database image + # that observes it. We also count the number of inliers in each. + kp_idxs, kp_to_3D_to_db = loc["keypoint_index_to_db"] + counts = np.zeros(n) + dbs_kp_q_db = [[] for _ in range(n)] + inliers_dbs = [[] for _ in range(n)] + for i, (inl, (p3D_id, db_idxs)) in enumerate( + zip(inliers, kp_to_3D_to_db) + ): + track = reconstruction.points3D[p3D_id].track + track = {el.image_id: el.point2D_idx for el in track.elements} + for db_idx in db_idxs: + counts[db_idx] += inl + kp_db = track[loc["db"][db_idx]] + dbs_kp_q_db[db_idx].append((i, kp_db)) + inliers_dbs[db_idx].append(inl) + else: + # for inloc the database keypoints are already in the logs + assert "keypoints_db" in loc + assert "indices_db" in loc + counts = np.array( + [np.sum(loc["indices_db"][inliers] == i) for i in range(n)] + ) + + # display the database images with the most inlier matches + db_sort = np.argsort(-counts) + for db_idx in db_sort[:top_k_db]: + if reconstruction is not None: + db = reconstruction.images[loc["db"][db_idx]] + db_name = db.name + db_kp_q_db = np.array(dbs_kp_q_db[db_idx]) + kp_q = mkp_q[db_kp_q_db[:, 0]] + kp_db = np.array([db.points2D[i].xy for i in db_kp_q_db[:, 1]]) + inliers_db = inliers_dbs[db_idx] + else: + db_name = loc["db"][db_idx] + kp_q = mkp_q[loc["indices_db"] == db_idx] + kp_db = loc["keypoints_db"][loc["indices_db"] == db_idx] + inliers_db = inliers[loc["indices_db"] == db_idx] + + db_image = read_image((db_image_dir or image_dir) / db_name) + color = cm_RdGn(inliers_db).tolist() + text = f"inliers: {sum(inliers_db)}/{len(inliers_db)}" + + plot_images([q_image, db_image], dpi=dpi) + plot_matches(kp_q, kp_db, color, a=0.1) + add_text(0, text) + opts = dict(pos=(0.01, 0.01), fs=5, lcolor=None, va="bottom") + add_text(0, query_name, **opts) + add_text(1, db_name, **opts) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..2d4c90b4c4dd31de3ec69e2a0c7194a41de5485b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "ImageMatchingWebui" +description = "Image Matching Webui: A tool for matching images using sota algorithms with a Gradio UI" +version = "1.0" +authors = [ + {name = "vincentqyw"}, +] +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +urls = {Repository = "https://github.com/Vincentqyw/image-matching-webui"} +dynamic = ["dependencies"] + +[project.optional-dependencies] +dev = ["black", "flake8", "isort"] + +[tool.setuptools.packages.find] +include = ["hloc*", "ui",] + +[tool.setuptools.package-data] +ui = ["*.yaml"] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + +[tool.black] +line-length = 80 + +[tool.isort] +profile = "black" +line_length = 80 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..99d3d921dada752f12babe57f0e6830a3176e5b7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,38 @@ +e2cnn +einops +easydict +gdown +# gradio==4.44.1 +h5py +huggingface_hub +imageio +Jinja2 +kornia +loguru +matplotlib +numpy==1.23.5 +onnxruntime +omegaconf +opencv-python +opencv-contrib-python +pandas +psutil +plotly +protobuf +poselib +pycolmap==0.6.1 +pytlsd +PyYAML +pytorch-lightning==1.4.9 +scikit-image +scikit-learn +scipy +seaborn +shapely +tensorboardX==2.6.1 +torchmetrics==0.6.0 +torchvision==0.19.0 +roma #dust3r +tqdm +yacs +fastapi diff --git a/test_app_cli.py b/test_app_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..f75ff6b1de6dc8227d3f18ee7bdcd84979a41434 --- /dev/null +++ b/test_app_cli.py @@ -0,0 +1,112 @@ +import sys +from pathlib import Path + +import cv2 + +from hloc import logger +from ui.utils import DEVICE, ROOT, get_matcher_zoo, load_config + +sys.path.append(str(Path(__file__).parents[1])) +from api.server import ImageMatchingAPI + + +def test_all(config: dict = None): + img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg" + img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg" + image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB + image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB + + matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"]) + for k, v in matcher_zoo_restored.items(): + if image0 is None or image1 is None: + logger.error("Error: No images found! Please upload two images.") + enable = config["matcher_zoo"][k].get("enable", True) + skip_ci = config["matcher_zoo"][k].get("skip_ci", False) + if enable and not skip_ci: + logger.info(f"Testing {k} ...") + api = ImageMatchingAPI(conf=v, device=DEVICE) + api(image0, image1) + log_path = ROOT / "experiments" / "all" + log_path.mkdir(exist_ok=True, parents=True) + api.visualize(log_path=log_path) + else: + logger.info(f"Skipping {k} ...") + return 0 + + +def test_one(): + img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg" + img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg" + image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB + image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB + # sparse + conf = { + "feature": { + "output": "feats-superpoint-n4096-rmax1600", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "matcher": { + "output": "matches-NN-mutual", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "match_threshold": 0.2, + }, + }, + "dense": False, + } + api = ImageMatchingAPI(conf=conf, device=DEVICE) + api(image0, image1) + log_path = ROOT / "experiments" / "one" + log_path.mkdir(exist_ok=True, parents=True) + api.visualize(log_path=log_path) + + # dense + conf = { + "matcher": { + "output": "matches-loftr", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, + "cell_size": 1, + }, + "dense": True, + } + + api = ImageMatchingAPI(conf=conf, device=DEVICE) + api(image0, image1) + log_path = ROOT / "experiments" / "one" + log_path.mkdir(exist_ok=True, parents=True) + api.visualize(log_path=log_path) + return 0 + + +if __name__ == "__main__": + config = load_config(ROOT / "ui/config.yaml") + test_one() + test_all(config) diff --git a/ui/__init__.py b/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6ccf52978e85f5abaca55d6559c74a6b2bd169 --- /dev/null +++ b/ui/__init__.py @@ -0,0 +1,5 @@ +__version__ = "1.0.1" + + +def get_version(): + return __version__ diff --git a/ui/app_class.py b/ui/app_class.py new file mode 100644 index 0000000000000000000000000000000000000000..628a9a71d4f13193c3573398bdba9b380e216a04 --- /dev/null +++ b/ui/app_class.py @@ -0,0 +1,849 @@ +import sys +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import gradio as gr +import numpy as np +from easydict import EasyDict as edict +from omegaconf import OmegaConf + +sys.path.append(str(Path(__file__).parents[1])) + +from ui.sfm import SfmEngine +from ui.utils import ( + GRADIO_VERSION, + gen_examples, + generate_warp_images, + get_matcher_zoo, + load_config, + ransac_zoo, + run_matching, + run_ransac, + send_to_match, +) + +DESCRIPTION = """ +# Image Matching WebUI +This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue! +
+🔎 For more details about supported local features and matchers, please refer to https://github.com/Vincentqyw/image-matching-webui + +🚀 All algorithms run on CPU for inference, causing slow speeds and high latency. For faster inference, please download the [source code](https://github.com/Vincentqyw/image-matching-webui) for local deployment. + +🐛 Your feedback is valuable to me. Please do not hesitate to report any bugs [here](https://github.com/Vincentqyw/image-matching-webui/issues). +""" + +CSS = """ +#warning {background-color: #FFCCCB} +.logs_class textarea {font-size: 12px !important} +""" + + +class ImageMatchingApp: + def __init__(self, server_name="0.0.0.0", server_port=7860, **kwargs): + self.server_name = server_name + self.server_port = server_port + self.config_path = kwargs.get( + "config", Path(__file__).parent / "config.yaml" + ) + self.cfg = load_config(self.config_path) + self.matcher_zoo = get_matcher_zoo(self.cfg["matcher_zoo"]) + self.app = None + self.init_interface() + # print all the keys + + def init_matcher_dropdown(self): + algos = [] + for k, v in self.cfg["matcher_zoo"].items(): + if v.get("enable", True): + algos.append(k) + return algos + + def init_interface(self): + with gr.Blocks(css=CSS) as self.app: + with gr.Tab("Image Matching"): + with gr.Row(): + with gr.Column(scale=1): + gr.Image( + str( + Path(__file__).parent.parent + / "assets/logo.webp" + ), + elem_id="logo-img", + show_label=False, + show_share_button=False, + show_download_button=False, + ) + with gr.Column(scale=3): + gr.Markdown(DESCRIPTION) + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + matcher_list = gr.Dropdown( + choices=self.init_matcher_dropdown(), + value="disk+lightglue", + label="Matching Model", + interactive=True, + ) + match_image_src = gr.Radio( + ( + ["upload", "webcam", "clipboard"] + if GRADIO_VERSION > "3" + else ["upload", "webcam", "canvas"] + ), + label="Image Source", + value="upload", + ) + with gr.Row(): + input_image0 = gr.Image( + label="Image 0", + type="numpy", + image_mode="RGB", + height=300 if GRADIO_VERSION > "3" else None, + interactive=True, + ) + input_image1 = gr.Image( + label="Image 1", + type="numpy", + image_mode="RGB", + height=300 if GRADIO_VERSION > "3" else None, + interactive=True, + ) + + with gr.Row(): + button_reset = gr.Button(value="Reset") + button_run = gr.Button( + value="Run Match", variant="primary" + ) + + with gr.Accordion("Advanced Setting", open=False): + with gr.Accordion("Image Setting", open=True): + with gr.Row(): + image_force_resize_cb = gr.Checkbox( + label="Force Resize", + value=False, + interactive=True, + ) + image_setting_height = gr.Slider( + minimum=48, + maximum=2048, + step=16, + label="Image Height", + value=480, + visible=False, + ) + image_setting_width = gr.Slider( + minimum=64, + maximum=2048, + step=16, + label="Image Width", + value=640, + visible=False, + ) + with gr.Accordion("Matching Setting", open=True): + with gr.Row(): + match_setting_threshold = gr.Slider( + minimum=0.0, + maximum=1, + step=0.001, + label="Match threshold", + value=0.1, + ) + match_setting_max_keypoints = gr.Slider( + minimum=10, + maximum=10000, + step=10, + label="Max features", + value=1000, + ) + # TODO: add line settings + with gr.Row(): + detect_keypoints_threshold = gr.Slider( + minimum=0, + maximum=1, + step=0.001, + label="Keypoint threshold", + value=0.015, + ) + detect_line_threshold = ( # noqa: F841 + gr.Slider( + minimum=0.1, + maximum=1, + step=0.01, + label="Line threshold", + value=0.2, + ) + ) + # matcher_lists = gr.Radio( + # ["NN-mutual", "Dual-Softmax"], + # label="Matcher mode", + # value="NN-mutual", + # ) + with gr.Accordion("RANSAC Setting", open=True): + with gr.Row(equal_height=False): + ransac_method = gr.Dropdown( + choices=ransac_zoo.keys(), + value=self.cfg["defaults"][ + "ransac_method" + ], + label="RANSAC Method", + interactive=True, + ) + ransac_reproj_threshold = gr.Slider( + minimum=0.0, + maximum=12, + step=0.01, + label="Ransac Reproj threshold", + value=8.0, + ) + ransac_confidence = gr.Slider( + minimum=0.0, + maximum=1, + step=0.00001, + label="Ransac Confidence", + value=self.cfg["defaults"][ + "ransac_confidence" + ], + ) + ransac_max_iter = gr.Slider( + minimum=0.0, + maximum=100000, + step=100, + label="Ransac Iterations", + value=self.cfg["defaults"][ + "ransac_max_iter" + ], + ) + button_ransac = gr.Button( + value="Rerun RANSAC", variant="primary" + ) + with gr.Accordion("Geometry Setting", open=False): + with gr.Row(equal_height=False): + choice_geometry_type = gr.Radio( + ["Fundamental", "Homography"], + label="Reconstruct Geometry", + value=self.cfg["defaults"][ + "setting_geometry" + ], + ) + # image resize + image_force_resize_cb.select( + fn=self._on_select_force_resize, + inputs=image_force_resize_cb, + outputs=[image_setting_width, image_setting_height], + ) + # collect inputs + state_cache = gr.State({}) + inputs = [ + input_image0, + input_image1, + match_setting_threshold, + match_setting_max_keypoints, + detect_keypoints_threshold, + matcher_list, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + choice_geometry_type, + gr.State(self.matcher_zoo), + image_force_resize_cb, + image_setting_width, + image_setting_height, + ] + + # Add some examples + with gr.Row(): + # Example inputs + with gr.Accordion( + "Open for More: Examples", open=True + ): + gr.Examples( + examples=gen_examples(), + inputs=inputs, + outputs=[], + fn=run_matching, + cache_examples=False, + label=( + "Examples (click one of the images below to Run" + " Match). Thx: WxBS" + ), + ) + with gr.Accordion("Supported Algorithms", open=False): + # add a table of supported algorithms + self.display_supported_algorithms() + + with gr.Column(): + + with gr.Accordion( + "Open for More: Keypoints", open=True + ): + output_keypoints = gr.Image( + label="Keypoints", type="numpy" + ) + with gr.Accordion( + ( + "Open for More: Raw Matches" + " (Green for good matches, Red for bad)" + ), + open=False, + ): + output_matches_raw = gr.Image( + label="Raw Matches", + type="numpy", + ) + with gr.Accordion( + ( + "Open for More: Ransac Matches" + " (Green for good matches, Red for bad)" + ), + open=True, + ): + output_matches_ransac = gr.Image( + label="Ransac Matches", type="numpy" + ) + with gr.Accordion( + "Open for More: Matches Statistics", open=False + ): + output_pred = gr.File( + label="Outputs", elem_id="download" + ) + matches_result_info = gr.JSON( + label="Matches Statistics" + ) + matcher_info = gr.JSON(label="Match info") + + with gr.Accordion( + "Open for More: Warped Image", open=True + ): + output_wrapped = gr.Image( + label="Wrapped Pair", type="numpy" + ) + # send to input + button_rerun = gr.Button( + value="Send to Input Match Pair", + variant="primary", + ) + with gr.Accordion( + "Open for More: Geometry info", open=False + ): + geometry_result = gr.JSON( + label="Reconstructed Geometry" + ) + + # callbacks + match_image_src.change( + fn=self.ui_change_imagebox, + inputs=match_image_src, + outputs=input_image0, + ) + match_image_src.change( + fn=self.ui_change_imagebox, + inputs=match_image_src, + outputs=input_image1, + ) + # collect outputs + outputs = [ + output_keypoints, + output_matches_raw, + output_matches_ransac, + matches_result_info, + matcher_info, + geometry_result, + output_wrapped, + state_cache, + output_pred, + ] + # button callbacks + button_run.click( + fn=run_matching, inputs=inputs, outputs=outputs + ) + # Reset images + reset_outputs = [ + input_image0, + input_image1, + match_setting_threshold, + match_setting_max_keypoints, + detect_keypoints_threshold, + matcher_list, + input_image0, + input_image1, + match_image_src, + output_keypoints, + output_matches_raw, + output_matches_ransac, + matches_result_info, + matcher_info, + output_wrapped, + geometry_result, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + choice_geometry_type, + output_pred, + image_force_resize_cb, + ] + button_reset.click( + fn=self.ui_reset_state, + inputs=None, + outputs=reset_outputs, + ) + + # run ransac button action + button_ransac.click( + fn=run_ransac, + inputs=[ + state_cache, + choice_geometry_type, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + ], + outputs=[ + output_matches_ransac, + matches_result_info, + output_wrapped, + output_pred, + ], + ) + + # send warped image to match + button_rerun.click( + fn=send_to_match, + inputs=[state_cache], + outputs=[input_image0, input_image1], + ) + + # estimate geo + choice_geometry_type.change( + fn=generate_warp_images, + inputs=[ + input_image0, + input_image1, + geometry_result, + choice_geometry_type, + ], + outputs=[output_wrapped, geometry_result], + ) + with gr.Tab("Structure from Motion(under-dev)"): + sfm_ui = AppSfmUI( # noqa: F841 + { + **self.cfg, + "matcher_zoo": self.matcher_zoo, + "outputs": "experiments/sfm", + } + ) + sfm_ui.call_empty() + + def run(self): + self.app.queue().launch( + server_name=self.server_name, + server_port=self.server_port, + share=False, + ) + + def ui_change_imagebox(self, choice): + """ + Updates the image box with the given choice. + + Args: + choice (list): The list of image sources to be displayed in the image box. + + Returns: + dict: A dictionary containing the updated value, sources, and type for the image box. + """ + ret_dict = { + "value": None, # The updated value of the image box + "__type__": "update", # The type of update for the image box + } + if GRADIO_VERSION > "3": + return { + **ret_dict, + "sources": choice, # The list of image sources to be displayed + } + else: + return { + **ret_dict, + "source": choice, # The list of image sources to be displayed + } + + def _on_select_force_resize(self, visible: bool = False): + return gr.update(visible=visible), gr.update(visible=visible) + + def ui_reset_state( + self, + *args: Any, + ) -> Tuple[ + Optional[np.ndarray], + Optional[np.ndarray], + float, + int, + float, + str, + Dict[str, Any], + Dict[str, Any], + str, + Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + Dict[str, Any], + Dict[str, Any], + Optional[np.ndarray], + Dict[str, Any], + str, + int, + float, + int, + bool, + ]: + """ + Reset the state of the UI. + + Returns: + tuple: A tuple containing the initial values for the UI state. + """ + key: str = list(self.matcher_zoo.keys())[ + 0 + ] # Get the first key from matcher_zoo + # flush_logs() + return ( + None, # image0: Optional[np.ndarray] + None, # image1: Optional[np.ndarray] + self.cfg["defaults"][ + "match_threshold" + ], # matching_threshold: float + self.cfg["defaults"]["max_keypoints"], # max_keypoints: int + self.cfg["defaults"][ + "keypoint_threshold" + ], # keypoint_threshold: float + key, # matcher: str + self.ui_change_imagebox("upload"), # input image0: Dict[str, Any] + self.ui_change_imagebox("upload"), # input image1: Dict[str, Any] + "upload", # match_image_src: str + None, # keypoints: Optional[np.ndarray] + None, # raw matches: Optional[np.ndarray] + None, # ransac matches: Optional[np.ndarray] + {}, # matches result info: Dict[str, Any] + {}, # matcher config: Dict[str, Any] + None, # warped image: Optional[np.ndarray] + {}, # geometry result: Dict[str, Any] + self.cfg["defaults"]["ransac_method"], # ransac_method: str + self.cfg["defaults"][ + "ransac_reproj_threshold" + ], # ransac_reproj_threshold: float + self.cfg["defaults"][ + "ransac_confidence" + ], # ransac_confidence: float + self.cfg["defaults"]["ransac_max_iter"], # ransac_max_iter: int + self.cfg["defaults"]["setting_geometry"], # geometry: str + None, # predictions + False, + ) + + def display_supported_algorithms(self, style="tab"): + def get_link(link, tag="Link"): + return "[{}]({})".format(tag, link) if link is not None else "None" + + data = [] + cfg = self.cfg["matcher_zoo"] + if style == "md": + markdown_table = "| Algo. | Conference | Code | Project | Paper |\n" + markdown_table += ( + "| ----- | ---------- | ---- | ------- | ----- |\n" + ) + + for k, v in cfg.items(): + if not v["info"]["display"]: + continue + github_link = get_link(v["info"]["github"]) + project_link = get_link(v["info"]["project"]) + paper_link = get_link( + v["info"]["paper"], + ( + Path(v["info"]["paper"]).name[-10:] + if v["info"]["paper"] is not None + else "Link" + ), + ) + + markdown_table += "{}|{}|{}|{}|{}\n".format( + v["info"]["name"], # display name + v["info"]["source"], + github_link, + project_link, + paper_link, + ) + return gr.Markdown(markdown_table) + elif style == "tab": + for k, v in cfg.items(): + if not v["info"].get("display", True): + continue + data.append( + [ + v["info"]["name"], + v["info"]["source"], + v["info"]["github"], + v["info"]["paper"], + v["info"]["project"], + ] + ) + tab = gr.Dataframe( + headers=["Algo.", "Conference", "Code", "Paper", "Project"], + datatype=["str", "str", "str", "str", "str"], + col_count=(5, "fixed"), + value=data, + # wrap=True, + # min_width = 1000, + # height=1000, + ) + return tab + + +class AppBaseUI: + def __init__(self, cfg: Dict[str, Any] = {}): + self.cfg = OmegaConf.create(cfg) + self.inputs = edict({}) + self.outputs = edict({}) + self.ui = edict({}) + + def _init_ui(self): + NotImplemented + + def call(self, **kwargs): + NotImplemented + + def info(self): + gr.Info("SFM is under construction.") + + +class AppSfmUI(AppBaseUI): + def __init__(self, cfg: Dict[str, Any] = None): + super().__init__(cfg) + assert "matcher_zoo" in self.cfg + self.matcher_zoo = self.cfg["matcher_zoo"] + self.sfm_engine = SfmEngine(cfg) + self._init_ui() + + def init_retrieval_dropdown(self): + algos = [] + for k, v in self.cfg["retrieval_zoo"].items(): + if v.get("enable", True): + algos.append(k) + return algos + + def _update_options(self, option): + if option == "sparse": + return gr.Textbox("sparse", visible=True) + elif option == "dense": + return gr.Textbox("dense", visible=True) + else: + return gr.Textbox("not set", visible=True) + + def _on_select_custom_params(self, value: bool = False): + return gr.update(visible=value) + + def _init_ui(self): + with gr.Row(): + # data settting and camera settings + with gr.Column(): + self.inputs.input_images = gr.File( + label="SfM", + interactive=True, + file_count="multiple", + min_width=300, + ) + # camera setting + with gr.Accordion("Camera Settings", open=True): + with gr.Column(): + with gr.Row(): + with gr.Column(): + self.inputs.camera_model = gr.Dropdown( + choices=[ + "PINHOLE", + "SIMPLE_RADIAL", + "OPENCV", + ], + value="PINHOLE", + label="Camera Model", + interactive=True, + ) + with gr.Column(): + gr.Checkbox( + label="Shared Params", + value=True, + interactive=True, + ) + camera_custom_params_cb = gr.Checkbox( + label="Custom Params", + value=False, + interactive=True, + ) + with gr.Row(): + self.inputs.camera_params = gr.Textbox( + label="Camera Params", + value="0,0,0,0", + interactive=False, + visible=False, + ) + camera_custom_params_cb.select( + fn=self._on_select_custom_params, + inputs=camera_custom_params_cb, + outputs=self.inputs.camera_params, + ) + + with gr.Accordion("Matching Settings", open=True): + # feature extraction and matching setting + with gr.Row(): + # matcher setting + self.inputs.matcher_key = gr.Dropdown( + choices=self.matcher_zoo.keys(), + value="disk+lightglue", + label="Matching Model", + interactive=True, + ) + with gr.Row(): + with gr.Accordion("Advanced Settings", open=False): + with gr.Column(): + with gr.Row(): + # matching setting + self.inputs.max_keypoints = gr.Slider( + label="Max Keypoints", + minimum=100, + maximum=10000, + value=1000, + interactive=True, + ) + self.inputs.keypoint_threshold = gr.Slider( + label="Keypoint Threshold", + minimum=0, + maximum=1, + value=0.01, + ) + with gr.Row(): + self.inputs.match_threshold = gr.Slider( + label="Match Threshold", + minimum=0.01, + maximum=12.0, + value=0.2, + ) + self.inputs.ransac_threshold = gr.Slider( + label="Ransac Threshold", + minimum=0.01, + maximum=12.0, + value=4.0, + step=0.01, + interactive=True, + ) + + with gr.Row(): + self.inputs.ransac_confidence = gr.Slider( + label="Ransac Confidence", + minimum=0.01, + maximum=1.0, + value=0.9999, + step=0.0001, + interactive=True, + ) + self.inputs.ransac_max_iter = gr.Slider( + label="Ransac Max Iter", + minimum=1, + maximum=100, + value=100, + step=1, + interactive=True, + ) + with gr.Accordion("Scene Graph Settings", open=True): + # mapping setting + self.inputs.scene_graph = gr.Dropdown( + choices=["all", "swin", "oneref"], + value="all", + label="Scene Graph", + interactive=True, + ) + + # global feature setting + self.inputs.global_feature = gr.Dropdown( + choices=self.init_retrieval_dropdown(), + value="netvlad", + label="Global features", + interactive=True, + ) + self.inputs.top_k = gr.Slider( + label="Number of Images per Image to Match", + minimum=1, + maximum=100, + value=10, + step=1, + ) + # button_match = gr.Button("Run Matching", variant="primary") + + # mapping setting + with gr.Column(): + with gr.Accordion("Mapping Settings", open=True): + with gr.Row(): + with gr.Accordion("Buddle Settings", open=True): + with gr.Row(): + self.inputs.mapper_refine_focal_length = ( + gr.Checkbox( + label="Refine Focal Length", + value=False, + interactive=True, + ) + ) + self.inputs.mapper_refine_principle_points = ( + gr.Checkbox( + label="Refine Principle Points", + value=False, + interactive=True, + ) + ) + self.inputs.mapper_refine_extra_params = ( + gr.Checkbox( + label="Refine Extra Params", + value=False, + interactive=True, + ) + ) + with gr.Accordion("Retriangluation Settings", open=True): + gr.Textbox( + label="Retriangluation Details", + ) + self.ui.button_sfm = gr.Button("Run SFM", variant="primary") + self.outputs.model_3d = gr.Model3D( + interactive=True, + ) + self.outputs.output_image = gr.Image( + label="SFM Visualize", + type="numpy", + image_mode="RGB", + interactive=False, + ) + + def call_empty(self): + self.ui.button_sfm.click(fn=self.info, inputs=[], outputs=[]) + + def call(self): + self.ui.button_sfm.click( + fn=self.sfm_engine.call, + inputs=[ + self.inputs.matcher_key, + self.inputs.input_images, # images + self.inputs.camera_model, + self.inputs.camera_params, + self.inputs.max_keypoints, + self.inputs.keypoint_threshold, + self.inputs.match_threshold, + self.inputs.ransac_threshold, + self.inputs.ransac_confidence, + self.inputs.ransac_max_iter, + self.inputs.scene_graph, + self.inputs.global_feature, + self.inputs.top_k, + self.inputs.mapper_refine_focal_length, + self.inputs.mapper_refine_principle_points, + self.inputs.mapper_refine_extra_params, + ], + outputs=[self.outputs.model_3d, self.outputs.output_image], + ) diff --git a/ui/config.yaml b/ui/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f4d455fa7d81e4f8c343754f0bca593e08dbd23 --- /dev/null +++ b/ui/config.yaml @@ -0,0 +1,436 @@ +server: + name: "0.0.0.0" + port: 7861 + +defaults: + setting_threshold: 0.1 + max_keypoints: 2000 + keypoint_threshold: 0.05 + enable_ransac: true + ransac_method: CV2_USAC_MAGSAC + ransac_reproj_threshold: 8 + ransac_confidence: 0.999 + ransac_max_iter: 10000 + ransac_num_samples: 4 + match_threshold: 0.2 + setting_geometry: Homography + +matcher_zoo: + omniglue: + enable: true + matcher: omniglue + dense: true + info: + name: OmniGlue + source: "CVPR 2024" + github: https://github.com/Vincentqyw/omniglue-onnx + paper: https://arxiv.org/abs/2405.12979 + project: https://hwjiang1510.github.io/OmniGlue + display: true + Mast3R: + enable: false + matcher: mast3r + dense: true + info: + name: Mast3R #dispaly name + source: "CVPR 2024" + github: https://github.com/naver/mast3r + paper: https://arxiv.org/abs/2406.09756 + project: https://dust3r.europe.naverlabs.com + display: true + DUSt3R: + # TODO: duster is under development + enable: true + # skip_ci: true + matcher: duster + dense: true + info: + name: DUSt3R #dispaly name + source: "CVPR 2024" + github: https://github.com/naver/dust3r + paper: https://arxiv.org/abs/2312.14132 + project: https://dust3r.europe.naverlabs.com + display: true + GIM(dkm): + enable: true + # skip_ci: true + matcher: gim(dkm) + dense: true + info: + name: GIM(DKM) #dispaly name + source: "ICLR 2024" + github: https://github.com/xuelunshen/gim + paper: https://arxiv.org/abs/2402.11095 + project: https://xuelunshen.com/gim + display: true + RoMa: + matcher: roma + skip_ci: true + dense: true + info: + name: RoMa #dispaly name + source: "CVPR 2024" + github: https://github.com/Parskatt/RoMa + paper: https://arxiv.org/abs/2305.15404 + project: https://parskatt.github.io/RoMa + display: true + dkm: + matcher: dkm + skip_ci: true + dense: true + info: + name: DKM #dispaly name + source: "CVPR 2023" + github: https://github.com/Parskatt/DKM + paper: https://arxiv.org/abs/2202.00667 + project: https://parskatt.github.io/DKM + display: true + loftr: + matcher: loftr + dense: true + info: + name: LoFTR #dispaly name + source: "CVPR 2021" + github: https://github.com/zju3dv/LoFTR + paper: https://arxiv.org/pdf/2104.00680 + project: https://zju3dv.github.io/loftr + display: true + eloftr: + matcher: eloftr + dense: true + info: + name: Efficient LoFTR #dispaly name + source: "CVPR 2024" + github: https://github.com/zju3dv/efficientloftr + paper: https://zju3dv.github.io/efficientloftr/files/EfficientLoFTR.pdf + project: https://zju3dv.github.io/efficientloftr + display: true + cotr: + enable: false + skip_ci: true + matcher: cotr + dense: true + info: + name: CoTR #dispaly name + source: "ICCV 2021" + github: https://github.com/ubc-vision/COTR + paper: https://arxiv.org/abs/2103.14167 + project: null + display: true + topicfm: + matcher: topicfm + dense: true + info: + name: TopicFM #dispaly name + source: "AAAI 2023" + github: https://github.com/TruongKhang/TopicFM + paper: https://arxiv.org/abs/2307.00485 + project: null + display: true + aspanformer: + matcher: aspanformer + dense: true + info: + name: ASpanformer #dispaly name + source: "ECCV 2022" + github: https://github.com/Vincentqyw/ml-aspanformer + paper: https://arxiv.org/abs/2208.14201 + project: null + display: true + xfeat+lightglue: + enable: true + matcher: xfeat_lightglue + dense: true + info: + name: xfeat+lightglue + source: "CVPR 2024" + github: https://github.com/Vincentqyw/omniglue-onnx + paper: https://arxiv.org/abs/2405.12979 + project: https://hwjiang1510.github.io/OmniGlue + display: true + xfeat(sparse): + matcher: NN-mutual + feature: xfeat + dense: false + info: + name: XFeat #dispaly name + source: "CVPR 2024" + github: https://github.com/verlab/accelerated_features + paper: https://arxiv.org/abs/2404.19174 + project: null + display: true + xfeat(dense): + matcher: xfeat_dense + dense: true + info: + name: XFeat #dispaly name + source: "CVPR 2024" + github: https://github.com/verlab/accelerated_features + paper: https://arxiv.org/abs/2404.19174 + project: null + display: false + dedode: + matcher: Dual-Softmax + feature: dedode + dense: false + info: + name: DeDoDe #dispaly name + source: "3DV 2024" + github: https://github.com/Parskatt/DeDoDe + paper: https://arxiv.org/abs/2308.08479 + project: null + display: true + superpoint+superglue: + matcher: superglue + feature: superpoint_max + dense: false + info: + name: SuperGlue #dispaly name + source: "CVPR 2020" + github: https://github.com/magicleap/SuperGluePretrainedNetwork + paper: https://arxiv.org/abs/1911.11763 + project: null + display: true + superpoint+lightglue: + matcher: superpoint-lightglue + feature: superpoint_max + dense: false + info: + name: LightGlue #dispaly name + source: "ICCV 2023" + github: https://github.com/cvg/LightGlue + paper: https://arxiv.org/pdf/2306.13643 + project: null + display: true + disk: + matcher: NN-mutual + feature: disk + dense: false + info: + name: DISK + source: "NeurIPS 2020" + github: https://github.com/cvlab-epfl/disk + paper: https://arxiv.org/abs/2006.13566 + project: null + display: true + disk+dualsoftmax: + matcher: Dual-Softmax + feature: disk + dense: false + info: + name: DISK + source: "NeurIPS 2020" + github: https://github.com/cvlab-epfl/disk + paper: https://arxiv.org/abs/2006.13566 + project: null + display: false + superpoint+dualsoftmax: + matcher: Dual-Softmax + feature: superpoint_max + dense: false + info: + name: SuperPoint + source: "CVPRW 2018" + github: https://github.com/magicleap/SuperPointPretrainedNetwork + paper: https://arxiv.org/abs/1712.07629 + project: null + display: false + sift+lightglue: + matcher: sift-lightglue + feature: sift + dense: false + info: + name: LightGlue #dispaly name + source: "ICCV 2023" + github: https://github.com/cvg/LightGlue + paper: https://arxiv.org/pdf/2306.13643 + project: null + display: true + disk+lightglue: + matcher: disk-lightglue + feature: disk + dense: false + info: + name: LightGlue + source: "ICCV 2023" + github: https://github.com/cvg/LightGlue + paper: https://arxiv.org/pdf/2306.13643 + project: null + display: true + superpoint+mnn: + matcher: NN-mutual + feature: superpoint_max + dense: false + info: + name: SuperPoint #dispaly name + source: "CVPRW 2018" + github: https://github.com/magicleap/SuperPointPretrainedNetwork + paper: https://arxiv.org/abs/1712.07629 + project: null + display: true + sift+sgmnet: + matcher: sgmnet + feature: sift + dense: false + info: + name: SGMNet #dispaly name + source: "ICCV 2021" + github: https://github.com/vdvchen/SGMNet + paper: https://arxiv.org/abs/2108.08771 + project: null + display: true + sosnet: + matcher: NN-mutual + feature: sosnet + dense: false + info: + name: SOSNet #dispaly name + source: "CVPR 2019" + github: https://github.com/scape-research/SOSNet + paper: https://arxiv.org/abs/1904.05019 + project: https://research.scape.io/sosnet + display: true + hardnet: + matcher: NN-mutual + feature: hardnet + dense: false + info: + name: HardNet #dispaly name + source: "NeurIPS 2017" + github: https://github.com/DagnyT/hardnet + paper: https://arxiv.org/abs/1705.10872 + project: null + display: true + d2net: + matcher: NN-mutual + feature: d2net-ss + dense: false + info: + name: D2Net #dispaly name + source: "CVPR 2019" + github: https://github.com/Vincentqyw/d2-net + paper: https://arxiv.org/abs/1905.03561 + project: https://dusmanu.com/publications/d2-net.html + display: true + rord: + matcher: NN-mutual + feature: rord + dense: false + info: + name: RoRD #dispaly name + source: "IROS 2021" + github: https://github.com/UditSinghParihar/RoRD + paper: https://arxiv.org/abs/2103.08573 + project: https://uditsinghparihar.github.io/RoRD + display: true + alike: + matcher: NN-mutual + feature: alike + dense: false + info: + name: ALIKE #dispaly name + source: "TMM 2022" + github: https://github.com/Shiaoming/ALIKE + paper: https://arxiv.org/abs/2112.02906 + project: null + display: true + lanet: + matcher: NN-mutual + feature: lanet + dense: false + info: + name: LANet #dispaly name + source: "ACCV 2022" + github: https://github.com/wangch-g/lanet + paper: https://openaccess.thecvf.com/content/ACCV2022/papers/Wang_Rethinking_Low-level_Features_for_Interest_Point_Detection_and_Description_ACCV_2022_paper.pdf + project: null + display: true + r2d2: + matcher: NN-mutual + feature: r2d2 + dense: false + info: + name: R2D2 #dispaly name + source: "NeurIPS 2019" + github: https://github.com/naver/r2d2 + paper: https://arxiv.org/abs/1906.06195 + project: null + display: true + darkfeat: + matcher: NN-mutual + feature: darkfeat + dense: false + info: + name: DarkFeat #dispaly name + source: "AAAI 2023" + github: https://github.com/THU-LYJ-Lab/DarkFeat + paper: null + project: null + display: true + sift: + matcher: NN-mutual + feature: sift + dense: false + info: + name: SIFT #dispaly name + source: "IJCV 2004" + github: null + paper: https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf + project: null + display: true + gluestick: + enable: false + matcher: gluestick + dense: true + info: + name: GlueStick #dispaly name + source: "ICCV 2023" + github: https://github.com/cvg/GlueStick + paper: https://arxiv.org/abs/2304.02008 + project: https://iago-suarez.com/gluestick + display: true + sold2: + enable: false + matcher: sold2 + dense: true + info: + name: SOLD2 #dispaly name + source: "CVPR 2021" + github: https://github.com/cvg/SOLD2 + paper: https://arxiv.org/abs/2104.03362 + project: null + display: true + + sfd2+imp: + enable: true + matcher: imp + feature: sfd2 + dense: false + info: + name: SFD2+IMP #dispaly name + source: "CVPR 2023" + github: https://github.com/feixue94/imp-release + paper: https://arxiv.org/pdf/2304.14837 + project: https://feixue94.github.io/ + display: true + + sfd2+mnn: + enable: true + matcher: NN-mutual + feature: sfd2 + dense: false + info: + name: SFD2+MNN #dispaly name + source: "CVPR 2023" + github: https://github.com/feixue94/sfd2 + paper: https://arxiv.org/abs/2304.14845 + project: https://feixue94.github.io/ + display: true + +retrieval_zoo: + netvlad: + enable: true + openibl: + enable: true + cosplace: + enable: true diff --git a/ui/sfm.py b/ui/sfm.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd90bd07891cb9e7492fe538b1b2a591a138ce2 --- /dev/null +++ b/ui/sfm.py @@ -0,0 +1,170 @@ +import shutil +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, List + +sys.path.append(str(Path(__file__).parents[1])) + +from hloc import ( + extract_features, + logger, + match_features, + pairs_from_retrieval, + reconstruction, + visualization, +) + +try: + import pycolmap +except ImportError: + logger.warning("pycolmap not installed, some features may not work") + +from ui.viz import fig2im + + +class SfmEngine: + def __init__(self, cfg: Dict[str, Any] = None): + self.cfg = cfg + if "outputs" in cfg and Path(cfg["outputs"]): + outputs = Path(cfg["outputs"]) + outputs.mkdir(parents=True, exist_ok=True) + else: + outputs = tempfile.mkdtemp() + self.outputs = Path(outputs) + + def call( + self, + key: str, + images: Path, + camera_model: str, + camera_params: List[float], + max_keypoints: int, + keypoint_threshold: float, + match_threshold: float, + ransac_threshold: int, + ransac_confidence: float, + ransac_max_iter: int, + scene_graph: bool, + global_feature: str, + top_k: int = 10, + mapper_refine_focal_length: bool = False, + mapper_refine_principle_points: bool = False, + mapper_refine_extra_params: bool = False, + ): + """ + Call a list of functions to perform feature extraction, matching, and reconstruction. + + Args: + key (str): The key to retrieve the matcher and feature models. + images (Path): The directory containing the images. + outputs (Path): The directory to store the outputs. + camera_model (str): The camera model. + camera_params (List[float]): The camera parameters. + max_keypoints (int): The maximum number of features. + match_threshold (float): The match threshold. + ransac_threshold (int): The RANSAC threshold. + ransac_confidence (float): The RANSAC confidence. + ransac_max_iter (int): The maximum number of RANSAC iterations. + scene_graph (bool): Whether to compute the scene graph. + global_feature (str): Whether to compute the global feature. + top_k (int): The number of image-pair to use. + mapper_refine_focal_length (bool): Whether to refine the focal length. + mapper_refine_principle_points (bool): Whether to refine the principle points. + mapper_refine_extra_params (bool): Whether to refine the extra parameters. + + Returns: + Path: The directory containing the SfM results. + """ + if len(images) == 0: + logger.error(f"{images} does not exist.") + + temp_images = Path(tempfile.mkdtemp()) + # copy images + logger.info(f"Copying images to {temp_images}.") + for image in images: + shutil.copy(image, temp_images) + + matcher_zoo = self.cfg["matcher_zoo"] + model = matcher_zoo[key] + match_conf = model["matcher"] + match_conf["model"]["max_keypoints"] = max_keypoints + match_conf["model"]["match_threshold"] = match_threshold + + feature_conf = model["feature"] + feature_conf["model"]["max_keypoints"] = max_keypoints + feature_conf["model"]["keypoint_threshold"] = keypoint_threshold + + # retrieval + retrieval_name = self.cfg.get("retrieval_name", "netvlad") + retrieval_conf = extract_features.confs[retrieval_name] + + mapper_options = { + "ba_refine_extra_params": mapper_refine_extra_params, + "ba_refine_focal_length": mapper_refine_focal_length, + "ba_refine_principal_point": mapper_refine_principle_points, + "ba_local_max_num_iterations": 40, + "ba_local_max_refinements": 3, + "ba_global_max_num_iterations": 100, + # below 3 options are for individual/video data, for internet photos, they should be left + # default + "min_focal_length_ratio": 0.1, + "max_focal_length_ratio": 10, + "max_extra_param": 1e15, + } + + sfm_dir = self.outputs / "sfm_{}".format(key) + sfm_pairs = self.outputs / "pairs-sfm.txt" + sfm_dir.mkdir(exist_ok=True, parents=True) + + # extract features + retrieval_path = extract_features.main( + retrieval_conf, temp_images, self.outputs + ) + pairs_from_retrieval.main(retrieval_path, sfm_pairs, num_matched=top_k) + + feature_path = extract_features.main( + feature_conf, temp_images, self.outputs + ) + # match features + match_path = match_features.main( + match_conf, sfm_pairs, feature_conf["output"], self.outputs + ) + # reconstruction + already_sfm = False + if sfm_dir.exists(): + try: + model = pycolmap.Reconstruction(str(sfm_dir)) + already_sfm = True + except ValueError: + logger.info(f"sfm_dir not exists model: {sfm_dir}") + if not already_sfm: + model = reconstruction.main( + sfm_dir, + temp_images, + sfm_pairs, + feature_path, + match_path, + mapper_options=mapper_options, + ) + + vertices = [] + for point3D_id, point3D in model.points3D.items(): + vertices.append([point3D.xyz, point3D.color]) + + model_3d = sfm_dir / "points3D.obj" + with open(model_3d, "w") as f: + for p, c in vertices: + # Write vertex position + f.write("v {} {} {}\n".format(p[0], p[1], p[2])) + # Write vertex normal (color) + f.write( + "vn {} {} {}\n".format( + c[0] / 255.0, c[1] / 255.0, c[2] / 255.0 + ) + ) + viz_2d = visualization.visualize_sfm_2d( + model, temp_images, color_by="visibility", n=2, dpi=300 + ) + + return model_3d, fig2im(viz_2d) / 255.0 diff --git a/ui/utils.py b/ui/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd935e39dd774b179b8811cc2902d536ac5785f --- /dev/null +++ b/ui/utils.py @@ -0,0 +1,1081 @@ +import os +import pickle +import random +import shutil +import sys +import time +import warnings +from itertools import combinations +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import gradio as gr +import matplotlib.pyplot as plt +import numpy as np +import poselib +import psutil +from PIL import Image + +sys.path.append(str(Path(__file__).parents[1])) + +from hloc import ( + DEVICE, + extract_features, + extractors, + logger, + match_dense, + match_features, + matchers, +) +from hloc.utils.base_model import dynamic_load +from ui.viz import display_keypoints, display_matches, fig2im, plot_images + +warnings.simplefilter("ignore") + +ROOT = Path(__file__).parent.parent +# some default values +DEFAULT_SETTING_THRESHOLD = 0.1 +DEFAULT_SETTING_MAX_FEATURES = 2000 +DEFAULT_DEFAULT_KEYPOINT_THRESHOLD = 0.01 +DEFAULT_ENABLE_RANSAC = True +DEFAULT_RANSAC_METHOD = "CV2_USAC_MAGSAC" +DEFAULT_RANSAC_REPROJ_THRESHOLD = 8 +DEFAULT_RANSAC_CONFIDENCE = 0.999 +DEFAULT_RANSAC_MAX_ITER = 10000 +DEFAULT_MIN_NUM_MATCHES = 4 +DEFAULT_MATCHING_THRESHOLD = 0.2 +DEFAULT_SETTING_GEOMETRY = "Homography" +GRADIO_VERSION = gr.__version__.split(".")[0] +MATCHER_ZOO = None + + +class ModelCache: + def __init__(self, max_memory_size: int = 8): + self.max_memory_size = max_memory_size + self.current_memory_size = 0 + self.model_dict = {} + self.model_timestamps = [] + + def cache_model(self, model_key, model_loader_func, model_conf): + if model_key in self.model_dict: + self.model_timestamps.remove(model_key) + self.model_timestamps.append(model_key) + logger.info(f"Load cached {model_key}") + return self.model_dict[model_key] + + model = self._load_model_from_disk(model_loader_func, model_conf) + while self._calculate_model_memory() > self.max_memory_size: + if len(self.model_timestamps) == 0: + logger.warn( + "RAM: {}GB, MAX RAM: {}GB".format( + self._calculate_model_memory(), self.max_memory_size + ) + ) + break + oldest_model_key = self.model_timestamps.pop(0) + self.current_memory_size = self._calculate_model_memory() + logger.info(f"Del cached {oldest_model_key}") + del self.model_dict[oldest_model_key] + + self.model_dict[model_key] = model + self.model_timestamps.append(model_key) + + self.print_memory_usage() + logger.info(f"Total cached {list(self.model_dict.keys())}") + + return model + + def _load_model_from_disk(self, model_loader_func, model_conf): + return model_loader_func(model_conf) + + def _calculate_model_memory(self, verbose=False): + host_colocation = int(os.environ.get("HOST_COLOCATION", "1")) + vm = psutil.virtual_memory() + du = shutil.disk_usage(".") + if verbose: + logger.info( + f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB" + ) + logger.info( + f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB" + ) + return vm.used / 1e9 + + def print_memory_usage(self): + self._calculate_model_memory(verbose=True) + + +model_cache = ModelCache() + + +def load_config(config_name: str) -> Dict[str, Any]: + """ + Load a YAML configuration file. + + Args: + config_name: The path to the YAML configuration file. + + Returns: + The configuration dictionary, with string keys and arbitrary values. + """ + import yaml + + with open(config_name, "r") as stream: + try: + config: Dict[str, Any] = yaml.safe_load(stream) + except yaml.YAMLError as exc: + logger.error(exc) + return config + + +def get_matcher_zoo( + matcher_zoo: Dict[str, Dict[str, Union[str, bool]]] +) -> Dict[str, Dict[str, Union[Callable, bool]]]: + """ + Restore matcher configurations from a dictionary. + + Args: + matcher_zoo: A dictionary with the matcher configurations, + where the configuration is a dictionary as loaded from a YAML file. + + Returns: + A dictionary with the matcher configurations, where the configuration is + a function or a function instead of a string. + """ + matcher_zoo_restored = {} + for k, v in matcher_zoo.items(): + matcher_zoo_restored[k] = parse_match_config(v) + return matcher_zoo_restored + + +def parse_match_config(conf): + if conf["dense"]: + return { + "matcher": match_dense.confs.get(conf["matcher"]), + "dense": True, + } + else: + return { + "feature": extract_features.confs.get(conf["feature"]), + "matcher": match_features.confs.get(conf["matcher"]), + "dense": False, + } + + +def get_model(match_conf: Dict[str, Any]): + """ + Load a matcher model from the provided configuration. + + Args: + match_conf: A dictionary containing the model configuration. + + Returns: + A matcher model instance. + """ + Model = dynamic_load(matchers, match_conf["model"]["name"]) + model = Model(match_conf["model"]).eval().to(DEVICE) + return model + + +def get_feature_model(conf: Dict[str, Dict[str, Any]]): + """ + Load a feature extraction model from the provided configuration. + + Args: + conf: A dictionary containing the model configuration. + + Returns: + A feature extraction model instance. + """ + Model = dynamic_load(extractors, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(DEVICE) + return model + + +def gen_examples(): + random.seed(1) + example_matchers = [ + "disk+lightglue", + "xfeat(sparse)", + "dedode", + "loftr", + "disk", + "RoMa", + "d2net", + "aspanformer", + "topicfm", + "superpoint+superglue", + "superpoint+lightglue", + "superpoint+mnn", + "disk", + ] + + def distribute_elements(A, B): + new_B = np.array(B, copy=True).flatten() + np.random.shuffle(new_B) + new_B = np.resize(new_B, len(A)) + np.random.shuffle(new_B) + return new_B.tolist() + + # normal examples + def gen_images_pairs(count: int = 5): + path = str(ROOT / "datasets/sacre_coeur/mapping") + imgs_list = [ + os.path.join(path, file) + for file in os.listdir(path) + if file.lower().endswith((".jpg", ".jpeg", ".png")) + ] + pairs = list(combinations(imgs_list, 2)) + if len(pairs) < count: + count = len(pairs) + selected = random.sample(range(len(pairs)), count) + return [pairs[i] for i in selected] + + # rotated examples + def gen_rot_image_pairs(count: int = 5): + path = ROOT / "datasets/sacre_coeur/mapping" + path_rot = ROOT / "datasets/sacre_coeur/mapping_rot" + rot_list = [45, 180, 90, 225, 270] + pairs = [] + for file in os.listdir(path): + if file.lower().endswith((".jpg", ".jpeg", ".png")): + for rot in rot_list: + file_rot = "{}_rot{}.jpg".format(Path(file).stem, rot) + if (path_rot / file_rot).exists(): + pairs.append( + [ + path / file, + path_rot / file_rot, + ] + ) + if len(pairs) < count: + count = len(pairs) + selected = random.sample(range(len(pairs)), count) + return [pairs[i] for i in selected] + + def gen_scale_image_pairs(count: int = 5): + path = ROOT / "datasets/sacre_coeur/mapping" + path_scale = ROOT / "datasets/sacre_coeur/mapping_scale" + scale_list = [0.3, 0.5] + pairs = [] + for file in os.listdir(path): + if file.lower().endswith((".jpg", ".jpeg", ".png")): + for scale in scale_list: + file_scale = "{}_scale{}.jpg".format(Path(file).stem, scale) + if (path_scale / file_scale).exists(): + pairs.append( + [ + path / file, + path_scale / file_scale, + ] + ) + if len(pairs) < count: + count = len(pairs) + selected = random.sample(range(len(pairs)), count) + return [pairs[i] for i in selected] + + # extramely hard examples + def gen_image_pairs_wxbs(count: int = None): + prefix = "datasets/wxbs_benchmark/.WxBS/v1.1" + wxbs_path = ROOT / prefix + pairs = [] + for catg in os.listdir(wxbs_path): + catg_path = wxbs_path / catg + if not catg_path.is_dir(): + continue + for scene in os.listdir(catg_path): + scene_path = catg_path / scene + if not scene_path.is_dir(): + continue + img1_path = scene_path / "01.png" + img2_path = scene_path / "02.png" + if img1_path.exists() and img2_path.exists(): + pairs.append([str(img1_path), str(img2_path)]) + return pairs + + # image pair path + pairs = gen_images_pairs() + pairs += gen_rot_image_pairs() + pairs += gen_scale_image_pairs() + pairs += gen_image_pairs_wxbs() + + match_setting_threshold = DEFAULT_SETTING_THRESHOLD + match_setting_max_features = DEFAULT_SETTING_MAX_FEATURES + detect_keypoints_threshold = DEFAULT_DEFAULT_KEYPOINT_THRESHOLD + ransac_method = DEFAULT_RANSAC_METHOD + ransac_reproj_threshold = DEFAULT_RANSAC_REPROJ_THRESHOLD + ransac_confidence = DEFAULT_RANSAC_CONFIDENCE + ransac_max_iter = DEFAULT_RANSAC_MAX_ITER + input_lists = [] + dist_examples = distribute_elements(pairs, example_matchers) + for pair, mt in zip(pairs, dist_examples): + input_lists.append( + [ + pair[0], + pair[1], + match_setting_threshold, + match_setting_max_features, + detect_keypoints_threshold, + mt, + # enable_ransac, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + ] + ) + return input_lists + + +def set_null_pred(feature_type: str, pred: dict): + if feature_type == "KEYPOINT": + pred["mmkeypoints0_orig"] = np.array([]) + pred["mmkeypoints1_orig"] = np.array([]) + pred["mmconf"] = np.array([]) + elif feature_type == "LINE": + pred["mline_keypoints0_orig"] = np.array([]) + pred["mline_keypoints1_orig"] = np.array([]) + pred["H"] = None + pred["geom_info"] = {} + return pred + + +def _filter_matches_opencv( + kp0: np.ndarray, + kp1: np.ndarray, + method: int = cv2.RANSAC, + reproj_threshold: float = 3.0, + confidence: float = 0.99, + max_iter: int = 2000, + geometry_type: str = "Homography", +) -> Tuple[np.ndarray, np.ndarray]: + """ + Filters matches between two sets of keypoints using OpenCV's findHomography. + + Args: + kp0 (np.ndarray): Array of keypoints from the first image. + kp1 (np.ndarray): Array of keypoints from the second image. + method (int, optional): RANSAC method. Defaults to "cv2.RANSAC". + reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3.0. + confidence (float, optional): RANSAC confidence. Defaults to 0.99. + max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000. + geometry_type (str, optional): Type of geometry. Defaults to "Homography". + + Returns: + Tuple[np.ndarray, np.ndarray]: Homography matrix and mask. + """ + if geometry_type == "Homography": + M, mask = cv2.findHomography( + kp0, + kp1, + method=method, + ransacReprojThreshold=reproj_threshold, + confidence=confidence, + maxIters=max_iter, + ) + elif geometry_type == "Fundamental": + M, mask = cv2.findFundamentalMat( + kp0, + kp1, + method=method, + ransacReprojThreshold=reproj_threshold, + confidence=confidence, + maxIters=max_iter, + ) + mask = np.array(mask.ravel().astype("bool"), dtype="bool") + return M, mask + + +def _filter_matches_poselib( + kp0: np.ndarray, + kp1: np.ndarray, + method: int = None, # not used + reproj_threshold: float = 3, + confidence: float = 0.99, + max_iter: int = 2000, + geometry_type: str = "Homography", +) -> dict: + """ + Filters matches between two sets of keypoints using the poselib library. + + Args: + kp0 (np.ndarray): Array of keypoints from the first image. + kp1 (np.ndarray): Array of keypoints from the second image. + method (str, optional): RANSAC method. Defaults to "RANSAC". + reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3. + confidence (float, optional): RANSAC confidence. Defaults to 0.99. + max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000. + geometry_type (str, optional): Type of geometry. Defaults to "Homography". + + Returns: + dict: Information about the homography estimation. + """ + ransac_options = { + "max_iterations": max_iter, + # "min_iterations": min_iter, + "success_prob": confidence, + "max_reproj_error": reproj_threshold, + # "progressive_sampling": args.sampler.lower() == 'prosac' + } + + if geometry_type == "Homography": + M, info = poselib.estimate_homography(kp0, kp1, ransac_options) + elif geometry_type == "Fundamental": + M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options) + else: + raise NotImplementedError + + return M, np.array(info["inliers"]) + + +def proc_ransac_matches( + mkpts0: np.ndarray, + mkpts1: np.ndarray, + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: float = 3.0, + ransac_confidence: float = 0.99, + ransac_max_iter: int = 2000, + geometry_type: str = "Homography", +): + if ransac_method.startswith("CV2"): + logger.info( + f"ransac_method: {ransac_method}, geometry_type: {geometry_type}" + ) + return _filter_matches_opencv( + mkpts0, + mkpts1, + ransac_zoo[ransac_method], + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + geometry_type, + ) + elif ransac_method.startswith("POSELIB"): + logger.info( + f"ransac_method: {ransac_method}, geometry_type: {geometry_type}" + ) + return _filter_matches_poselib( + mkpts0, + mkpts1, + None, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + geometry_type, + ) + else: + raise NotImplementedError + + +def filter_matches( + pred: Dict[str, Any], + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD, + ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE, + ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER, + ransac_estimator: str = None, +): + """ + Filter matches using RANSAC. If keypoints are available, filter by keypoints. + If lines are available, filter by lines. If both keypoints and lines are + available, filter by keypoints. + + Args: + pred (Dict[str, Any]): dict of matches, including original keypoints. + ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD. + ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD. + ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE. + ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER. + + Returns: + Dict[str, Any]: filtered matches. + """ + mkpts0: Optional[np.ndarray] = None + mkpts1: Optional[np.ndarray] = None + feature_type: Optional[str] = None + if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys(): + mkpts0 = pred["mkeypoints0_orig"] + mkpts1 = pred["mkeypoints1_orig"] + feature_type = "KEYPOINT" + elif ( + "line_keypoints0_orig" in pred.keys() + and "line_keypoints1_orig" in pred.keys() + ): + mkpts0 = pred["line_keypoints0_orig"] + mkpts1 = pred["line_keypoints1_orig"] + feature_type = "LINE" + else: + return set_null_pred(feature_type, pred) + if mkpts0 is None or mkpts0 is None: + return set_null_pred(feature_type, pred) + if ransac_method not in ransac_zoo.keys(): + ransac_method = DEFAULT_RANSAC_METHOD + + if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES: + return set_null_pred(feature_type, pred) + + geom_info = compute_geometry( + pred, + ransac_method=ransac_method, + ransac_reproj_threshold=ransac_reproj_threshold, + ransac_confidence=ransac_confidence, + ransac_max_iter=ransac_max_iter, + ) + + if "Homography" in geom_info.keys(): + mask = geom_info["mask_h"] + if feature_type == "KEYPOINT": + pred["mmkeypoints0_orig"] = mkpts0[mask] + pred["mmkeypoints1_orig"] = mkpts1[mask] + pred["mmconf"] = pred["mconf"][mask] + elif feature_type == "LINE": + pred["mline_keypoints0_orig"] = mkpts0[mask] + pred["mline_keypoints1_orig"] = mkpts1[mask] + pred["H"] = np.array(geom_info["Homography"]) + else: + set_null_pred(feature_type, pred) + # do not show mask + geom_info.pop("mask_h", None) + geom_info.pop("mask_f", None) + pred["geom_info"] = geom_info + return pred + + +def compute_geometry( + pred: Dict[str, Any], + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD, + ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE, + ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER, +) -> Dict[str, List[float]]: + """ + Compute geometric information of matches, including Fundamental matrix, + Homography matrix, and rectification matrices (if available). + + Args: + pred (Dict[str, Any]): dict of matches, including original keypoints. + ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD. + ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD. + ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE. + ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER. + + Returns: + Dict[str, List[float]]: geometric information in form of a dict. + """ + mkpts0: Optional[np.ndarray] = None + mkpts1: Optional[np.ndarray] = None + + if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys(): + mkpts0 = pred["mkeypoints0_orig"] + mkpts1 = pred["mkeypoints1_orig"] + elif ( + "line_keypoints0_orig" in pred.keys() + and "line_keypoints1_orig" in pred.keys() + ): + mkpts0 = pred["line_keypoints0_orig"] + mkpts1 = pred["line_keypoints1_orig"] + + if mkpts0 is not None and mkpts1 is not None: + if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES: + return {} + geo_info: Dict[str, List[float]] = {} + + F, mask_f = proc_ransac_matches( + mkpts0, + mkpts1, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + geometry_type="Fundamental", + ) + + if F is not None: + geo_info["Fundamental"] = F.tolist() + geo_info["mask_f"] = mask_f + H, mask_h = proc_ransac_matches( + mkpts1, + mkpts0, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + geometry_type="Homography", + ) + + h0, w0, _ = pred["image0_orig"].shape + if H is not None: + geo_info["Homography"] = H.tolist() + geo_info["mask_h"] = mask_h + try: + _, H1, H2 = cv2.stereoRectifyUncalibrated( + mkpts0.reshape(-1, 2), + mkpts1.reshape(-1, 2), + F, + imgSize=(w0, h0), + ) + geo_info["H1"] = H1.tolist() + geo_info["H2"] = H2.tolist() + except cv2.error as e: + logger.error( + f"StereoRectifyUncalibrated failed, skip! error: {e}" + ) + return geo_info + else: + return {} + + +def wrap_images( + img0: np.ndarray, + img1: np.ndarray, + geo_info: Optional[Dict[str, List[float]]], + geom_type: str, +) -> Tuple[Optional[str], Optional[Dict[str, List[float]]]]: + """ + Wraps the images based on the geometric transformation used to align them. + + Args: + img0: numpy array representing the first image. + img1: numpy array representing the second image. + geo_info: dictionary containing the geometric transformation information. + geom_type: type of geometric transformation used to align the images. + + Returns: + A tuple containing a base64 encoded image string and a dictionary with the transformation matrix. + """ + h0, w0, _ = img0.shape + h1, w1, _ = img1.shape + if geo_info is not None and len(geo_info) != 0: + rectified_image0 = img0 + rectified_image1 = None + if "Homography" not in geo_info: + logger.warning(f"{geom_type} not exist, maybe too less matches") + return None, None + + H = np.array(geo_info["Homography"]) + + title: List[str] = [] + if geom_type == "Homography": + rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0)) + title = ["Image 0", "Image 1 - warped"] + elif geom_type == "Fundamental": + if geom_type not in geo_info: + logger.warning(f"{geom_type} not exist, maybe too less matches") + return None, None + else: + H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"]) + rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0)) + rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1)) + title = ["Image 0 - warped", "Image 1 - warped"] + else: + print("Error: Unknown geometry type") + fig = plot_images( + [rectified_image0.squeeze(), rectified_image1.squeeze()], + title, + dpi=300, + ) + return fig2im(fig), rectified_image1 + else: + return None, None + + +def generate_warp_images( + input_image0: np.ndarray, + input_image1: np.ndarray, + matches_info: Dict[str, Any], + choice: str, +) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + """ + Changes the estimate of the geometric transformation used to align the images. + + Args: + input_image0: First input image. + input_image1: Second input image. + matches_info: Dictionary containing information about the matches. + choice: Type of geometric transformation to use ('Homography' or 'Fundamental') or 'No' to disable. + + Returns: + A tuple containing the updated images and the warpped images. + """ + if ( + matches_info is None + or len(matches_info) < 1 + or "geom_info" not in matches_info.keys() + ): + return None, None + geom_info = matches_info["geom_info"] + warped_image = None + if choice != "No": + wrapped_image_pair, warped_image = wrap_images( + input_image0, input_image1, geom_info, choice + ) + return wrapped_image_pair, warped_image + else: + return None, None + + +def send_to_match(state_cache: Dict[str, Any]): + """ + Send the state cache to the match function. + + Args: + state_cache (Dict[str, Any]): Current state of the app. + + Returns: + None + """ + if state_cache: + return ( + state_cache["image0_orig"], + state_cache["wrapped_image"], + ) + else: + return None, None + + +def run_ransac( + state_cache: Dict[str, Any], + choice_geometry_type: str, + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD, + ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE, + ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER, +) -> Tuple[Optional[np.ndarray], Optional[Dict[str, int]]]: + """ + Run RANSAC matches and return the output images and the number of matches. + + Args: + state_cache (Dict[str, Any]): Current state of the app, including the matches. + ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD. + ransac_reproj_threshold (int, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD. + ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE. + ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER. + + Returns: + Tuple[Optional[np.ndarray], Optional[Dict[str, int]]]: Tuple containing the output images and the number of matches. + """ + if not state_cache: + logger.info("Run Match first before Rerun RANSAC") + gr.Warning("Run Match first before Rerun RANSAC") + return None, None + t1 = time.time() + logger.info( + f"Run RANSAC matches using: {ransac_method} with threshold: {ransac_reproj_threshold}" + ) + logger.info( + f"Run RANSAC matches using: {ransac_confidence} with iter: {ransac_max_iter}" + ) + # if enable_ransac: + filter_matches( + state_cache, + ransac_method=ransac_method, + ransac_reproj_threshold=ransac_reproj_threshold, + ransac_confidence=ransac_confidence, + ransac_max_iter=ransac_max_iter, + ) + logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # plot images with ransac matches + titles = [ + "Image 0 - Ransac matched keypoints", + "Image 1 - Ransac matched keypoints", + ] + output_matches_ransac, num_matches_ransac = display_matches( + state_cache, titles=titles, tag="KPTS_RANSAC" + ) + logger.info(f"Display matches done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # compute warp images + output_wrapped, warped_image = generate_warp_images( + state_cache["image0_orig"], + state_cache["image1_orig"], + state_cache, + choice_geometry_type, + ) + plt.close("all") + + num_matches_raw = state_cache["num_matches_raw"] + state_cache["wrapped_image"] = warped_image + + # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) + tmp_state_cache = "output.pkl" + with open(tmp_state_cache, "wb") as f: + pickle.dump(state_cache, f) + + logger.info("Dump results done!") + + return ( + output_matches_ransac, + { + "num_matches_raw": num_matches_raw, + "num_matches_ransac": num_matches_ransac, + }, + output_wrapped, + tmp_state_cache, + ) + + +def run_matching( + image0: np.ndarray, + image1: np.ndarray, + match_threshold: float, + extract_max_keypoints: int, + keypoint_threshold: float, + key: str, + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD, + ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE, + ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER, + choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY, + matcher_zoo: Dict[str, Any] = None, + force_resize: bool = False, + image_width: int = 640, + image_height: int = 480, + use_cached_model: bool = False, +) -> Tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + Dict[str, int], + Dict[str, Dict[str, Any]], + Dict[str, Dict[str, float]], + np.ndarray, +]: + """Match two images using the given parameters. + + Args: + image0 (np.ndarray): RGB image 0. + image1 (np.ndarray): RGB image 1. + match_threshold (float): match threshold. + extract_max_keypoints (int): number of keypoints to extract. + keypoint_threshold (float): keypoint threshold. + key (str): key of the model to use. + ransac_method (str, optional): RANSAC method to use. + ransac_reproj_threshold (int, optional): RANSAC reprojection threshold. + ransac_confidence (float, optional): RANSAC confidence level. + ransac_max_iter (int, optional): RANSAC maximum number of iterations. + choice_geometry_type (str, optional): setting of geometry estimation. + matcher_zoo (Dict[str, Any], optional): matcher zoo. Defaults to None. + force_resize (bool, optional): force resize. Defaults to False. + image_width (int, optional): image width. Defaults to 640. + image_height (int, optional): image height. Defaults to 480. + use_cached_model (bool, optional): use cached model. Defaults to False. + + Returns: + tuple: + - output_keypoints (np.ndarray): image with keypoints. + - output_matches_raw (np.ndarray): image with raw matches. + - output_matches_ransac (np.ndarray): image with RANSAC matches. + - num_matches (Dict[str, int]): number of raw and RANSAC matches. + - configs (Dict[str, Dict[str, Any]]): match and feature extraction configs. + - geom_info (Dict[str, Dict[str, float]]): geometry information. + - output_wrapped (np.ndarray): wrapped images. + """ + # image0 and image1 is RGB mode + if image0 is None or image1 is None: + logger.error( + "Error: No images found! Please upload two images or select an example." + ) + raise gr.Error( + "Error: No images found! Please upload two images or select an example." + ) + # init output + output_keypoints = None + output_matches_raw = None + output_matches_ransac = None + + # super slow! + if "roma" in key.lower() and DEVICE == "cpu": + gr.Info( + f"Success! Please be patient and allow for about 2-3 minutes." + f" Due to CPU inference, {key} is quiet slow." + ) + t0 = time.time() + model = matcher_zoo[key] + match_conf = model["matcher"] + # update match config + match_conf["model"]["match_threshold"] = match_threshold + match_conf["model"]["max_keypoints"] = extract_max_keypoints + cache_key = "{}_{}".format(key, match_conf["model"]["name"]) + if use_cached_model: + # because of the model cache, we need to update the config + matcher = model_cache.cache_model(cache_key, get_model, match_conf) + matcher.conf["max_keypoints"] = extract_max_keypoints + matcher.conf["match_threshold"] = match_threshold + logger.info(f"Loaded cached model {cache_key}") + else: + matcher = get_model(match_conf) + logger.info(f"Loading model using: {time.time()-t0:.3f}s") + t1 = time.time() + + if model["dense"]: + if not match_conf["preprocessing"].get("force_resize", False): + match_conf["preprocessing"]["force_resize"] = force_resize + else: + logger.info("preprocessing is already resized") + if force_resize: + match_conf["preprocessing"]["height"] = image_height + match_conf["preprocessing"]["width"] = image_width + logger.info(f"Force resize to {image_width}x{image_height}") + + pred = match_dense.match_images( + matcher, image0, image1, match_conf["preprocessing"], device=DEVICE + ) + del matcher + extract_conf = None + else: + extract_conf = model["feature"] + # update extract config + extract_conf["model"]["max_keypoints"] = extract_max_keypoints + extract_conf["model"]["keypoint_threshold"] = keypoint_threshold + cache_key = "{}_{}".format(key, extract_conf["model"]["name"]) + + if use_cached_model: + extractor = model_cache.cache_model( + cache_key, get_feature_model, extract_conf + ) + # because of the model cache, we need to update the config + extractor.conf["max_keypoints"] = extract_max_keypoints + extractor.conf["keypoint_threshold"] = keypoint_threshold + logger.info(f"Loaded cached model {cache_key}") + else: + extractor = get_feature_model(extract_conf) + + if not extract_conf["preprocessing"].get("force_resize", False): + extract_conf["preprocessing"]["force_resize"] = force_resize + else: + logger.info("preprocessing is already resized") + if force_resize: + extract_conf["preprocessing"]["height"] = image_height + extract_conf["preprocessing"]["width"] = image_width + logger.info(f"Force resize to {image_width}x{image_height}") + + pred0 = extract_features.extract( + extractor, image0, extract_conf["preprocessing"] + ) + pred1 = extract_features.extract( + extractor, image1, extract_conf["preprocessing"] + ) + pred = match_features.match_images(matcher, pred0, pred1) + del extractor + # gr.Info( + # f"Matching images done using: {time.time()-t1:.3f}s", + # ) + logger.info(f"Matching images done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # plot images with keypoints + titles = [ + "Image 0 - Keypoints", + "Image 1 - Keypoints", + ] + output_keypoints = display_keypoints(pred, titles=titles) + + # plot images with raw matches + titles = [ + "Image 0 - Raw matched keypoints", + "Image 1 - Raw matched keypoints", + ] + output_matches_raw, num_matches_raw = display_matches(pred, titles=titles) + + # if enable_ransac: + filter_matches( + pred, + ransac_method=ransac_method, + ransac_reproj_threshold=ransac_reproj_threshold, + ransac_confidence=ransac_confidence, + ransac_max_iter=ransac_max_iter, + ) + + # gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s") + logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # plot images with ransac matches + titles = [ + "Image 0 - Ransac matched keypoints", + "Image 1 - Ransac matched keypoints", + ] + output_matches_ransac, num_matches_ransac = display_matches( + pred, titles=titles, tag="KPTS_RANSAC" + ) + # gr.Info(f"Display matches done using: {time.time()-t1:.3f}s") + logger.info(f"Display matches done using: {time.time()-t1:.3f}s") + + t1 = time.time() + # plot wrapped images + output_wrapped, warped_image = generate_warp_images( + pred["image0_orig"], + pred["image1_orig"], + pred, + choice_geometry_type, + ) + plt.close("all") + # gr.Info(f"In summary, total time: {time.time()-t0:.3f}s") + logger.info(f"TOTAL time: {time.time()-t0:.3f}s") + + state_cache = pred + state_cache["num_matches_raw"] = num_matches_raw + state_cache["num_matches_ransac"] = num_matches_ransac + state_cache["wrapped_image"] = warped_image + + # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) + tmp_state_cache = "output.pkl" + with open(tmp_state_cache, "wb") as f: + pickle.dump(state_cache, f) + logger.info("Dump results done!") + return ( + output_keypoints, + output_matches_raw, + output_matches_ransac, + { + "num_raw_matches": num_matches_raw, + "num_ransac_matches": num_matches_ransac, + }, + { + "match_conf": match_conf, + "extractor_conf": extract_conf, + }, + { + "geom_info": pred.get("geom_info", {}), + }, + output_wrapped, + state_cache, + tmp_state_cache, + ) + + +# @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html +# AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs +ransac_zoo = { + "POSELIB": "LO-RANSAC", + "CV2_RANSAC": cv2.RANSAC, + "CV2_USAC_MAGSAC": cv2.USAC_MAGSAC, + "CV2_USAC_DEFAULT": cv2.USAC_DEFAULT, + "CV2_USAC_FM_8PTS": cv2.USAC_FM_8PTS, + "CV2_USAC_PROSAC": cv2.USAC_PROSAC, + "CV2_USAC_FAST": cv2.USAC_FAST, + "CV2_USAC_ACCURATE": cv2.USAC_ACCURATE, + "CV2_USAC_PARALLEL": cv2.USAC_PARALLEL, +} + + +def rotate_image(input_path, degrees, output_path): + img = Image.open(input_path) + img_rotated = img.rotate(-degrees) + img_rotated.save(output_path) + + +def scale_image(input_path, scale_factor, output_path): + img = Image.open(input_path) + width, height = img.size + new_width = int(width * scale_factor) + new_height = int(height * scale_factor) + new_img = Image.new("RGB", (width, height), (0, 0, 0)) + img_resized = img.resize((new_width, new_height)) + position = ((width - new_width) // 2, (height - new_height) // 2) + new_img.paste(img_resized, position) + new_img.save(output_path) diff --git a/ui/viz.py b/ui/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..6533f0b03aec86775552da951a9e57e7eeb33164 --- /dev/null +++ b/ui/viz.py @@ -0,0 +1,498 @@ +import sys +import typing +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns + +sys.path.append(str(Path(__file__).parents[1])) + +from hloc.utils.viz import add_text, plot_keypoints + +np.random.seed(1995) +color_map = np.arange(100) +np.random.shuffle(color_map) + + +def plot_images( + imgs: List[np.ndarray], + titles: Optional[List[str]] = None, + cmaps: Union[str, List[str]] = "gray", + dpi: int = 100, + size: Optional[int] = 5, + pad: float = 0.5, +) -> plt.Figure: + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. If a single string is given, + it is used for all images. + dpi: DPI of the figure. + size: figure size in inches (width). If not provided, the figure + size is determined automatically. + pad: padding between subplots, in inches. + Returns: + The created figure. + """ + n = len(imgs) + if not isinstance(cmaps, list): + cmaps = [cmaps] * n + figsize = (size * n, size * 6 / 5) if size is not None else None + fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) + + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + fig.tight_layout(pad=pad) + return fig + + +def plot_color_line_matches( + lines: List[np.ndarray], + correct_matches: Optional[np.ndarray] = None, + lw: float = 2.0, + indices: Tuple[int, int] = (0, 1), +) -> matplotlib.figure.Figure: + """Plot line matches for existing images with multiple colors. + + Args: + lines: List of ndarrays of size (N, 2, 2) representing line segments. + correct_matches: Optional bool array of size (N,) indicating correct + matches. If not None, display wrong matches with a low alpha. + lw: Line width as float pixels. + indices: Indices of the images to draw the matches on. + + Returns: + The modified matplotlib figure. + """ + n_lines = lines[0].shape[0] + colors = sns.color_palette("husl", n_colors=n_lines) + np.random.shuffle(colors) + alphas = np.ones(n_lines) + if correct_matches is not None: + alphas[~np.array(correct_matches)] = 0.2 + + fig = plt.gcf() + ax = typing.cast(List[matplotlib.axes.Axes], fig.axes) + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l in zip(axes, lines): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [ + matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=colors[i], + alpha=alphas[i], + linewidth=lw, + ) + for i in range(n_lines) + ] + + return fig + + +def make_matching_figure( + img0: np.ndarray, + img1: np.ndarray, + mkpts0: np.ndarray, + mkpts1: np.ndarray, + color: np.ndarray, + titles: Optional[List[str]] = None, + kpts0: Optional[np.ndarray] = None, + kpts1: Optional[np.ndarray] = None, + text: List[str] = [], + dpi: int = 75, + path: Optional[Path] = None, + pad: float = 0.0, +) -> Optional[plt.Figure]: + """Draw image pair with matches. + + Args: + img0: image0 as HxWx3 numpy array. + img1: image1 as HxWx3 numpy array. + mkpts0: matched points in image0 as Nx2 numpy array. + mkpts1: matched points in image1 as Nx2 numpy array. + color: colors for the matches as Nx4 numpy array. + titles: titles for the two subplots. + kpts0: keypoints in image0 as Kx2 numpy array. + kpts1: keypoints in image1 as Kx2 numpy array. + text: list of strings to display in the top-left corner of the image. + dpi: dots per inch of the saved figure. + path: if not None, save the figure to this path. + pad: padding around the image as a fraction of the image size. + + Returns: + The matplotlib Figure object if path is None. + """ + # draw image pair + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0) # , cmap='gray') + axes[1].imshow(img1) # , cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + if titles is not None: + axes[i].set_title(titles[i]) + + plt.tight_layout(pad=pad) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5) + + # draw matches + if ( + mkpts0.shape[0] != 0 + and mkpts1.shape[0] != 0 + and mkpts0.shape == mkpts1.shape + ): + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, + c=color[i], + linewidth=2, + ) + for i in range(len(mkpts0)) + ] + + # freeze the axes to prevent the transform to change + axes[0].autoscale(enable=False) + axes[1].autoscale(enable=False) + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4) + + # put txts + txt_color = "k" if img0[:100, :200].mean() > 200 else "w" + fig.text( + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) + plt.close() + else: + return fig + + +def error_colormap( + err: np.ndarray, thr: float, alpha: float = 1.0 +) -> np.ndarray: + """ + Create a colormap based on the error values. + + Args: + err: Error values as a numpy array of shape (N,). + thr: Threshold value for the error. + alpha: Alpha value for the colormap, between 0 and 1. + + Returns: + Colormap as a numpy array of shape (N, 4) with values in [0, 1]. + """ + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack( + [2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1 + ), + 0, + 1, + ) + + +def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray: + """ + Convert a matplotlib figure to a numpy array with RGB values. + + Args: + fig: A matplotlib figure. + + Returns: + A numpy array with shape (height, width, 3) and dtype uint8 containing + the RGB values of the figure. + """ + fig.canvas.draw() + (width, height) = fig.canvas.get_width_height() + buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1") + return buf_ndarray.reshape(height, width, 3) + + +def draw_matches_core( + mkpts0: List[np.ndarray], + mkpts1: List[np.ndarray], + img0: np.ndarray, + img1: np.ndarray, + conf: np.ndarray, + titles: Optional[List[str]] = None, + texts: Optional[List[str]] = None, + dpi: int = 150, + path: Optional[str] = None, + pad: float = 0.5, +) -> np.ndarray: + """ + Draw matches between two images. + + Args: + mkpts0: List of matches from the first image, with shape (N, 2) + mkpts1: List of matches from the second image, with shape (N, 2) + img0: First image, with shape (H, W, 3) + img1: Second image, with shape (H, W, 3) + conf: Confidence values for the matches, with shape (N,) + titles: Optional list of title strings for the plot + dpi: DPI for the saved image + path: Optional path to save the image to. If None, the image is not saved. + pad: Padding between subplots + + Returns: + The figure as a numpy array with shape (height, width, 3) and dtype uint8 + containing the RGB values of the figure. + """ + thr = 0.5 + color = error_colormap(1 - conf, thr, alpha=0.1) + text = [ + # "image name", + f"#Matches: {len(mkpts0)}", + ] + if path: + fig2im( + make_matching_figure( + img0, + img1, + mkpts0, + mkpts1, + color, + titles=titles, + text=text, + path=path, + dpi=dpi, + pad=pad, + ) + ) + else: + return fig2im( + make_matching_figure( + img0, + img1, + mkpts0, + mkpts1, + color, + titles=titles, + text=text, + pad=pad, + dpi=dpi, + ) + ) + + +def draw_image_pairs( + img0: np.ndarray, + img1: np.ndarray, + text: List[str] = [], + dpi: int = 75, + path: Optional[str] = None, + pad: float = 0.5, +) -> np.ndarray: + """Draw image pair horizontally. + + Args: + img0: First image, with shape (H, W, 3) + img1: Second image, with shape (H, W, 3) + text: List of strings to print. Each string is a new line. + dpi: DPI of the figure. + path: Path to save the image to. If None, the image is not saved and + the function returns the figure as a numpy array with shape + (height, width, 3) and dtype uint8 containing the RGB values of the + figure. + pad: Padding between subplots + + Returns: + The figure as a numpy array with shape (height, width, 3) and dtype uint8 + containing the RGB values of the figure, or None if path is not None. + """ + # draw image pair + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0) # , cmap='gray') + axes[1].imshow(img1) # , cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=pad) + + # put txts + txt_color = "k" if img0[:100, :200].mean() > 200 else "w" + fig.text( + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) + plt.close() + else: + return fig2im(fig) + + +def display_keypoints(pred: dict, titles: List[str] = []): + img0 = pred["image0_orig"] + img1 = pred["image1_orig"] + output_keypoints = plot_images([img0, img1], titles=titles, dpi=300) + if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys(): + plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]]) + text = ( + f"# keypoints0: {len(pred['keypoints0_orig'])} \n" + + f"# keypoints1: {len(pred['keypoints1_orig'])}" + ) + add_text(0, text, fs=15) + output_keypoints = fig2im(output_keypoints) + return output_keypoints + + +def display_matches( + pred: Dict[str, np.ndarray], + titles: List[str] = [], + texts: List[str] = [], + dpi: int = 300, + tag: str = "KPTS_RAW", # KPTS_RAW, KPTS_RANSAC, LINES_RAW, LINES_RANSAC, +) -> Tuple[np.ndarray, int]: + """ + Displays the matches between two images. + + Args: + pred: Dictionary containing the original images and the matches. + titles: Optional titles for the plot. + dpi: Resolution of the plot. + + Returns: + The resulting concatenated plot and the number of inliers. + """ + img0 = pred["image0_orig"] + img1 = pred["image1_orig"] + num_inliers = 0 + KPTS0_KEY = None + KPTS1_KEY = None + confid = None + if tag == "KPTS_RAW": + KPTS0_KEY = "mkeypoints0_orig" + KPTS1_KEY = "mkeypoints1_orig" + if "mconf" in pred: + confid = pred["mconf"] + elif tag == "KPTS_RANSAC": + KPTS0_KEY = "mmkeypoints0_orig" + KPTS1_KEY = "mmkeypoints1_orig" + if "mmconf" in pred: + confid = pred["mmconf"] + else: + # TODO: LINES_RAW, LINES_RANSAC + raise ValueError(f"Unknown tag: {tag}") + # draw raw matches + if ( + KPTS0_KEY in pred + and KPTS1_KEY in pred + and pred[KPTS0_KEY] is not None + and pred[KPTS1_KEY] is not None + ): # draw ransac matches + mkpts0 = pred[KPTS0_KEY] + mkpts1 = pred[KPTS1_KEY] + num_inliers = len(mkpts0) + if confid is None: + confid = np.ones(len(mkpts0)) + fig_mkpts = draw_matches_core( + mkpts0, + mkpts1, + img0, + img1, + confid, + dpi=dpi, + titles=titles, + texts=texts, + ) + fig = fig_mkpts + # TODO: draw lines + if ( + "line0_orig" in pred + and "line1_orig" in pred + and pred["line0_orig"] is not None + and pred["line1_orig"] is not None + and (tag == "LINES_RAW" or tag == "LINES_RANSAC") + ): + # lines + mtlines0 = pred["line0_orig"] + mtlines1 = pred["line1_orig"] + num_inliers = len(mtlines0) + fig_lines = plot_images( + [img0.squeeze(), img1.squeeze()], + ["Image 0 - matched lines", "Image 1 - matched lines"], + dpi=300, + ) + fig_lines = plot_color_line_matches([mtlines0, mtlines1], lw=2) + fig_lines = fig2im(fig_lines) + + # keypoints + mkpts0 = pred.get("line_keypoints0_orig") + mkpts1 = pred.get("line_keypoints1_orig") + fig = None + breakpoint() + if mkpts0 is not None and mkpts1 is not None: + num_inliers = len(mkpts0) + if "mconf" in pred: + mconf = pred["mconf"] + else: + mconf = np.ones(len(mkpts0)) + fig_mkpts = draw_matches_core( + mkpts0, mkpts1, img0, img1, mconf, dpi=300 + ) + fig_lines = cv2.resize( + fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0]) + ) + fig = np.concatenate([fig_mkpts, fig_lines], axis=0) + else: + fig = fig_lines + return fig, num_inliers