diff --git a/third_party/DKM/.gitignore b/third_party/DKM/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..07442492a552f5e0f3feadd0b992d92792ddb3bb --- /dev/null +++ b/third_party/DKM/.gitignore @@ -0,0 +1,3 @@ +*.egg-info* +*.vscode* +*__pycache__* \ No newline at end of file diff --git a/third_party/DKM/LICENSE b/third_party/DKM/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1625fcb9c1046af4180f55b58acff245814a2c2e --- /dev/null +++ b/third_party/DKM/LICENSE @@ -0,0 +1,25 @@ +NOTE! Models trained on our synthetic dataset uses datasets which are licensed under non-commercial licenses. +Hence we cannot provide them under the MIT license. However, MegaDepth is under MIT license, hence we provide those models under MIT license, see below. + + +License for Models Trained on MegaDepth ONLY below: + +Copyright (c) 2022 Johan Edstedt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/DKM/README.md b/third_party/DKM/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c68fbfaf848e5c2a64adc308351aad38f94b0830 --- /dev/null +++ b/third_party/DKM/README.md @@ -0,0 +1,117 @@ +# DKM: Dense Kernelized Feature Matching for Geometry Estimation +### [Project Page](https://parskatt.github.io/DKM) | [Paper](https://arxiv.org/abs/2202.00667) +
+ +> DKM: Dense Kernelized Feature Matching for Geometry Estimation +> [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Ioannis Athanasiadis](https://scholar.google.com/citations?user=RCAtJgUAAAAJ), [Mårten Wadenbäck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ) +> CVPR 2023 + +## How to Use? +
+Our model produces a dense (for all pixels) warp and certainty. + +Warp: [B,H,W,4] for all images in batch of size B, for each pixel HxW, we ouput the input and matching coordinate in the normalized grids [-1,1]x[-1,1]. + +Certainty: [B,H,W] a number in each pixel indicating the matchability of the pixel. + +See [demo](dkm/demo/) for two demos of DKM. + +See [api.md](docs/api.md) for API. +
+ +## Qualitative Results +
+ +https://user-images.githubusercontent.com/22053118/223748279-0f0c21b4-376a-440a-81f5-7f9a5d87483f.mp4 + + +https://user-images.githubusercontent.com/22053118/223748512-1bca4a17-cffa-491d-a448-96aac1353ce9.mp4 + + + +https://user-images.githubusercontent.com/22053118/223748518-4d475d9f-a933-4581-97ed-6e9413c4caca.mp4 + + + +https://user-images.githubusercontent.com/22053118/223748522-39c20631-aa16-4954-9c27-95763b38f2ce.mp4 + + +
+ + + +## Benchmark Results + +
+ +### Megadepth1500 + +| | @5 | @10 | @20 | +|-------|-------|------|------| +| DKMv1 | 54.5 | 70.7 | 82.3 | +| DKMv2 | *56.8* | *72.3* | *83.2* | +| DKMv3 (paper) | **60.5** | **74.9** | **85.1** | +| DKMv3 (this repo) | **60.0** | **74.6** | **84.9** | + +### Megadepth 8 Scenes +| | @5 | @10 | @20 | +|-------|-------|------|------| +| DKMv3 (paper) | **60.5** | **74.5** | **84.2** | +| DKMv3 (this repo) | **60.4** | **74.6** | **84.3** | + + +### ScanNet1500 +| | @5 | @10 | @20 | +|-------|-------|------|------| +| DKMv1 | 24.8 | 44.4 | 61.9 | +| DKMv2 | *28.2* | *49.2* | *66.6* | +| DKMv3 (paper) | **29.4** | **50.7** | **68.3** | +| DKMv3 (this repo) | **29.8** | **50.8** | **68.3** | + +
+ +## Navigating the Code +* Code for models can be found in [dkm/models](dkm/models/) +* Code for benchmarks can be found in [dkm/benchmarks](dkm/benchmarks/) +* Code for reproducing experiments from our paper can be found in [experiments/](experiments/) + +## Install +Run ``pip install -e .`` + +## Demo + +A demonstration of our method can be run by: +``` bash +python demo_match.py +``` +This runs our model trained on mega on two images taken from Sacre Coeur. + +## Benchmarks +See [Benchmarks](docs/benchmarks.md) for details. +## Training +See [Training](docs/training.md) for details. +## Reproducing Results +Given that the required benchmark or training dataset has been downloaded and unpacked, results can be reproduced by running the experiments in the experiments folder. + +## Using DKM matches for estimation +We recommend using the excellent Graph-Cut RANSAC algorithm: https://github.com/danini/graph-cut-ransac + +| | @5 | @10 | @20 | +|-------|-------|------|------| +| DKMv3 (RANSAC) | *60.5* | *74.9* | *85.1* | +| DKMv3 (GC-RANSAC) | **65.5** | **78.0** | **86.7** | + + +## Acknowledgements +We have used code and been inspired by https://github.com/PruneTruong/DenseMatching, https://github.com/zju3dv/LoFTR, and https://github.com/GrumpyZhou/patch2pix. We additionally thank the authors of ECO-TR for providing their benchmark. + +## BibTeX +If you find our models useful, please consider citing our paper! +``` +@inproceedings{edstedt2023dkm, +title={{DKM}: Dense Kernelized Feature Matching for Geometry Estimation}, +author={Edstedt, Johan and Athanasiadis, Ioannis and Wadenbäck, Mårten and Felsberg, Michael}, +booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, +year={2023} +} +``` diff --git a/third_party/DKM/assets/ams_hom_A.jpg b/third_party/DKM/assets/ams_hom_A.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b2c1c35a5316d823dd88dcac6247b7fc952feb21 --- /dev/null +++ b/third_party/DKM/assets/ams_hom_A.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:271a19f0b29fc88d8f88d1136f001078ca6bf5105ff95355f89a18e787c50e3a +size 1194880 diff --git a/third_party/DKM/assets/ams_hom_B.jpg b/third_party/DKM/assets/ams_hom_B.jpg new file mode 100644 index 0000000000000000000000000000000000000000..008ed1fbaf39135a1ded50a502a86701d9d900ca --- /dev/null +++ b/third_party/DKM/assets/ams_hom_B.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d84ced12e607f5ac5f7628151694fbaa2300caa091ac168e0aedad2ebaf491d6 +size 1208132 diff --git a/third_party/DKM/assets/dkmv3_warp.jpg b/third_party/DKM/assets/dkmv3_warp.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f8251c324f0bb713215b53b73d68330b15f87550 --- /dev/null +++ b/third_party/DKM/assets/dkmv3_warp.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04c46e39d5ea68e9e116d4ae71c038a459beaf3eed89e8b7b87ccafd01d3bf85 +size 571179 diff --git a/third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..1fd5b337e1d27ca569230c80902e762a503b75d0 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c902547181fc9b370fdd16272140be6803fe983aea978c68683db803ac70dd57 +size 906160 diff --git a/third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..464f56196138cd064182a33a4bf0171bb0df62e1 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ce02bd248988b42363ccd257abaa9b99a00d569d2779597b36ba6c4da35021 +size 906160 diff --git a/third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..bd4ccc895c4aa9900a610468fb6976e07e2dc362 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6104feb8807a4ebdd1266160e67b3c507c550012f54c23292d0ebf99b88753f +size 368192 diff --git a/third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..189a6ab6c4be9c58970b8ec93bcc1d91f1e33b17 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9600ba5c24d414f63728bf5ee7550a3b035d7c615461e357590890ae0e0f042e +size 368192 diff --git a/third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..408847e2613016fd0e1a34cef376fb94925011b5 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de89e9ccf10515cc4196ba1e7172ec98b2fb92ff9f85d90db5df1af5b6503313 +size 167528 diff --git a/third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..f85a48ca4d1843959c61eefd2d200d0ffc31c87d --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94c97b57beb10411b3b98a1e88c9e1e2f9db51994dce04580d2b7cfc8919dab3 +size 167528 diff --git a/third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..29974c7f4dda288c93dbb3551342486581079d62 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f14a66dbbd7fa8f31756dd496bfabe4c3ea115c6914acad9365dd02e46ae674 +size 63909 diff --git a/third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..dcdb08c233530b2e9a88aedb1ad1bc496a3075cb --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfaee333beccd1da0d920777cdc8f17d584b21ba20f675b39222c5a205acf72a +size 63909 diff --git a/third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..9e8ae03ed59aeba49b3223c72169d77e935a7e33 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b446ca3cc2073c8a3963cf68cc450ef2ebf73d2b956b1f5ae6b37621bc67cce4 +size 200371 diff --git a/third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..08c57fa1e419e7e1802b1ddfb00bd30a1c27d785 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df1969fd94032562b5e8d916467101a878168d586b795770e64c108bab250c9e +size 200371 diff --git a/third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..16b72b06ae0afd00af2ec626d9d492107ee64f1b --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37cbafa0b0f981f5d69aba202ddd37c5892bda0fa13d053a3ad27d6ddad51c16 +size 642823 diff --git a/third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..bbef8bff2ed410e893d34423941b32be4977a794 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:962b3fadc7c94ea8e4a1bb5e168e72b0b6cc474ae56b0aee70ba4e517553fbcf +size 642823 diff --git a/third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..fb1e5decaa554507d0f1eea82b1bba8bfa933b05 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50ed6b02dff2fa719e4e9ca216b4704b82cbbefd127355d3ba7120828407e723 +size 228647 diff --git a/third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..ef3248afb930afa46efc27054819fa313f6bf286 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edc97129d000f0478020495f646f2fa7667408247ccd11054e02efbbb38d1444 +size 228647 diff --git a/third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz new file mode 100644 index 0000000000000000000000000000000000000000..b092500a792cd274bc1e1b91f488d237b01ce3b5 --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04b0b6c6adff812e12b66476f7ca2a6ed2564cdd8208ec0c775f7b922f160103 +size 177063 diff --git a/third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz new file mode 100644 index 0000000000000000000000000000000000000000..24c4448c682c2f3c103dd5d4813784dadebd10fa --- /dev/null +++ b/third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae931f8cac1b2168f699c70efe42c215eaff27d3f0617d59afb3db183c9b1848 +size 177063 diff --git a/third_party/DKM/assets/mount_rushmore.mp4 b/third_party/DKM/assets/mount_rushmore.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3beebd798c74749f2034e73f068d919d307d46b9 Binary files /dev/null and b/third_party/DKM/assets/mount_rushmore.mp4 differ diff --git a/third_party/DKM/assets/sacre_coeur_A.jpg b/third_party/DKM/assets/sacre_coeur_A.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6e441dad34cf13d8a29d7c6a1519f4263c40058c --- /dev/null +++ b/third_party/DKM/assets/sacre_coeur_A.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90d9c5f5a4d76425624989215120fba6f2899190a1d5654b88fa380c64cf6b2c +size 117985 diff --git a/third_party/DKM/assets/sacre_coeur_B.jpg b/third_party/DKM/assets/sacre_coeur_B.jpg new file mode 100644 index 0000000000000000000000000000000000000000..27a239a8fa7581d909104872754ecda79422e7b6 --- /dev/null +++ b/third_party/DKM/assets/sacre_coeur_B.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f1eb9bdd4d80e480f672d6a729689ac77f9fd5c8deb90f59b377590f3ca4799 +size 152515 diff --git a/third_party/DKM/data/.gitignore b/third_party/DKM/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3 --- /dev/null +++ b/third_party/DKM/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/third_party/DKM/demo/.gitignore b/third_party/DKM/demo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..76ce7fcf6f600a9db91639ce9e445bec98ff1671 --- /dev/null +++ b/third_party/DKM/demo/.gitignore @@ -0,0 +1 @@ +*.jpg diff --git a/third_party/DKM/demo/demo_fundamental.py b/third_party/DKM/demo/demo_fundamental.py new file mode 100644 index 0000000000000000000000000000000000000000..e19766d5d3ce1abf0d18483cbbce71b2696983be --- /dev/null +++ b/third_party/DKM/demo/demo_fundamental.py @@ -0,0 +1,37 @@ +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np +from dkm.utils.utils import tensor_to_pil +import cv2 +from dkm import DKMv3_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + + # Create model + dkm_model = DKMv3_outdoor(device=device) + + + W_A, H_A = Image.open(im1_path).size + W_B, H_B = Image.open(im2_path).size + + # Match + warp, certainty = dkm_model.match(im1_path, im2_path, device=device) + # Sample matches for estimation + matches, certainty = dkm_model.sample(warp, certainty) + kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + F, mask = cv2.findFundamentalMat( + kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 + ) + # TODO: some better visualization \ No newline at end of file diff --git a/third_party/DKM/demo/demo_match.py b/third_party/DKM/demo/demo_match.py new file mode 100644 index 0000000000000000000000000000000000000000..fb901894d8654a884819162d3b9bb8094529e034 --- /dev/null +++ b/third_party/DKM/demo/demo_match.py @@ -0,0 +1,48 @@ +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np +from dkm.utils.utils import tensor_to_pil + +from dkm import DKMv3_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + save_path = args.save_path + + # Create model + dkm_model = DKMv3_outdoor(device=device) + + H, W = 864, 1152 + + im1 = Image.open(im1_path).resize((W, H)) + im2 = Image.open(im2_path).resize((W, H)) + + # Match + warp, certainty = dkm_model.match(im1_path, im2_path, device=device) + # Sampling not needed, but can be done with model.sample(warp, certainty) + dkm_model.sample(warp, certainty) + x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1) + x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) + + im2_transfer_rgb = F.grid_sample( + x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + )[0] + im1_transfer_rgb = F.grid_sample( + x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False + )[0] + warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2) + white_im = torch.ones((H,2*W),device=device) + vis_im = certainty * warp_im + (1 - certainty) * white_im + tensor_to_pil(vis_im, unnormalize=False).save(save_path) diff --git a/third_party/DKM/dkm/__init__.py b/third_party/DKM/dkm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b47632780acc7762bcccc348e2025fe99f3726 --- /dev/null +++ b/third_party/DKM/dkm/__init__.py @@ -0,0 +1,4 @@ +from .models import ( + DKMv3_outdoor, + DKMv3_indoor, + ) diff --git a/third_party/DKM/dkm/benchmarks/__init__.py b/third_party/DKM/dkm/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57643fd314a2301138aecdc804a5877d0ce9274e --- /dev/null +++ b/third_party/DKM/dkm/benchmarks/__init__.py @@ -0,0 +1,4 @@ +from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark +from .scannet_benchmark import ScanNetBenchmark +from .megadepth1500_benchmark import Megadepth1500Benchmark +from .megadepth_dense_benchmark import MegadepthDenseBenchmark diff --git a/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py b/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..079622fdaf77c75aeadd675629f2512c45d04c2d --- /dev/null +++ b/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py @@ -0,0 +1,100 @@ +from PIL import Image +import numpy as np + +import os + +import torch +from tqdm import tqdm + +from dkm.utils import * + + +class HpatchesDenseBenchmark: + """WARNING: HPATCHES grid goes from [0,n-1] instead of [0.5,n-0.5]""" + + def __init__(self, dataset_path) -> None: + seqs_dir = "hpatches-sequences-release" + self.seqs_path = os.path.join(dataset_path, seqs_dir) + self.seq_names = sorted(os.listdir(self.seqs_path)) + + def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup): + # Get matches in output format on the grid [0, n] where the center of the top-left coordinate is [0.5, 0.5] + offset = ( + 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] + ) + query_coords = ( + torch.stack( + ( + wq * (query_coords[..., 0] + 1) / 2, + hq * (query_coords[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + query_to_support = ( + torch.stack( + ( + wsup * (query_to_support[..., 0] + 1) / 2, + hsup * (query_to_support[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + return query_coords, query_to_support + + def inside_image(self, x, w, h): + return torch.logical_and( + x[:, 0] < (w - 1), + torch.logical_and(x[:, 1] < (h - 1), (x > 0).prod(dim=-1)), + ) + + def benchmark(self, model): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if use_cuda else "cpu") + aepes = [] + pcks = [] + for seq_idx, seq_name in tqdm( + enumerate(self.seq_names), total=len(self.seq_names) + ): + if seq_name[0] == "i": + continue + im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm") + im1 = Image.open(im1_path) + w1, h1 = im1.size + for im_idx in range(2, 7): + im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm") + im2 = Image.open(im2_path) + w2, h2 = im2.size + matches, certainty = model.match(im2, im1, do_pred_in_og_res=True) + matches, certainty = matches.reshape(-1, 4), certainty.reshape(-1) + inv_homography = torch.from_numpy( + np.loadtxt( + os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) + ) + ).to(device) + homography = torch.linalg.inv(inv_homography) + pos_a, pos_b = self.convert_coordinates( + matches[:, :2], matches[:, 2:], w2, h2, w1, h1 + ) + pos_a, pos_b = pos_a.double(), pos_b.double() + pos_a_h = torch.cat( + [pos_a, torch.ones([pos_a.shape[0], 1], device=device)], dim=1 + ) + pos_b_proj_h = (homography @ pos_a_h.t()).t() + pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:] + mask = self.inside_image(pos_b_proj, w1, h1) + residual = pos_b - pos_b_proj + dist = (residual**2).sum(dim=1).sqrt()[mask] + aepes.append(torch.mean(dist).item()) + pck1 = (dist < 1.0).float().mean().item() + pck3 = (dist < 3.0).float().mean().item() + pck5 = (dist < 5.0).float().mean().item() + pcks.append([pck1, pck3, pck5]) + m_pcks = np.mean(np.array(pcks), axis=0) + return { + "hp_pck1": m_pcks[0], + "hp_pck3": m_pcks[1], + "hp_pck5": m_pcks[2], + } diff --git a/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py b/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..781a291f0c358cbd435790b0a639f2a2510145b2 --- /dev/null +++ b/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py @@ -0,0 +1,119 @@ +import pickle +import h5py +import numpy as np +import torch +from dkm.utils import * +from PIL import Image +from tqdm import tqdm + + +class Yfcc100mBenchmark: + def __init__(self, data_root="data/yfcc100m_test") -> None: + self.scenes = [ + "buckingham_palace", + "notre_dame_front_facade", + "reichstag", + "sacre_coeur", + ] + self.data_root = data_root + + def benchmark(self, model, r=2): + model.train(False) + with torch.no_grad(): + data_root = self.data_root + meta_info = open( + f"{data_root}/yfcc_test_pairs_with_gt.txt", "r" + ).readlines() + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + for scene_ind in range(len(self.scenes)): + scene = self.scenes[scene_ind] + pairs = np.array( + pickle.load( + open(f"{data_root}/pairs/{scene}-te-1000-pairs.pkl", "rb") + ) + ) + scene_dir = f"{data_root}/yfcc100m/{scene}/test/" + calibs = open(scene_dir + "calibration.txt", "r").read().split("\n") + images = open(scene_dir + "images.txt", "r").read().split("\n") + pair_inds = np.random.choice( + range(len(pairs)), size=len(pairs), replace=False + ) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind] + params = meta_info[1000 * scene_ind + pairind].split() + rot1, rot2 = int(params[2]), int(params[3]) + calib1 = h5py.File(scene_dir + calibs[idx1], "r") + K1, R1, t1, _, _ = get_pose(calib1) + calib2 = h5py.File(scene_dir + calibs[idx2], "r") + K2, R2, t2, _, _ = get_pose(calib2) + + R, t = compute_relative_pose(R1, t1, R2, t2) + im1 = images[idx1] + im2 = images[idx2] + im1 = Image.open(scene_dir + im1).rotate(rot1 * 90, expand=True) + w1, h1 = im1.size + im2 = Image.open(scene_dir + im2).rotate(rot2 * 90, expand=True) + w2, h2 = im2.size + K1 = rotate_intrinsic(K1, rot1) + K2 = rotate_intrinsic(K2, rot2) + + dense_matches, dense_certainty = model.match(im1, im2) + dense_certainty = dense_certainty ** (1 / r) + sparse_matches, sparse_confidence = model.sample( + dense_matches, dense_certainty, 10000 + ) + scale1 = 480 / min(w1, h1) + scale2 = 480 / min(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1 = K1 * scale1 + K2 = K2 * scale2 + + kpts1 = sparse_matches[:, :2] + kpts1 = np.stack( + (w1 * kpts1[:, 0] / 2, h1 * kpts1[:, 1] / 2), axis=-1 + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = np.stack( + (w2 * kpts2[:, 0] / 2, h2 * kpts2[:, 1] / 2), axis=-1 + ) + try: + threshold = 1.0 + norm_threshold = threshold / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])) + ) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1[:2, :2], + K2[:2, :2], + norm_threshold, + conf=0.9999999, + ) + T1_to_2 = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2, R, t) + e_pose = max(e_t, e_R) + except: + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3febe5ca9e3a683bc7122cec635c4f54b66f7c --- /dev/null +++ b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py @@ -0,0 +1,114 @@ +from PIL import Image +import numpy as np + +import os + +from tqdm import tqdm +from dkm.utils import pose_auc +import cv2 + + +class HpatchesHomogBenchmark: + """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]""" + + def __init__(self, dataset_path) -> None: + seqs_dir = "hpatches-sequences-release" + self.seqs_path = os.path.join(dataset_path, seqs_dir) + self.seq_names = sorted(os.listdir(self.seqs_path)) + # Ignore seqs is same as LoFTR. + self.ignore_seqs = set( + [ + "i_contruction", + "i_crownnight", + "i_dc", + "i_pencils", + "i_whitebuilding", + "v_artisans", + "v_astronautis", + "v_talent", + ] + ) + + def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup): + offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think) + query_coords = ( + np.stack( + ( + wq * (query_coords[..., 0] + 1) / 2, + hq * (query_coords[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + query_to_support = ( + np.stack( + ( + wsup * (query_to_support[..., 0] + 1) / 2, + hsup * (query_to_support[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + return query_coords, query_to_support + + def benchmark(self, model, model_name = None): + n_matches = [] + homog_dists = [] + for seq_idx, seq_name in tqdm( + enumerate(self.seq_names), total=len(self.seq_names) + ): + if seq_name in self.ignore_seqs: + continue + im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm") + im1 = Image.open(im1_path) + w1, h1 = im1.size + for im_idx in range(2, 7): + im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm") + im2 = Image.open(im2_path) + w2, h2 = im2.size + H = np.loadtxt( + os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) + ) + dense_matches, dense_certainty = model.match( + im1_path, im2_path + ) + good_matches, _ = model.sample(dense_matches, dense_certainty, 5000) + pos_a, pos_b = self.convert_coordinates( + good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2 + ) + try: + H_pred, inliers = cv2.findHomography( + pos_a, + pos_b, + method = cv2.RANSAC, + confidence = 0.99999, + ransacReprojThreshold = 3 * min(w2, h2) / 480, + ) + except: + H_pred = None + if H_pred is None: + H_pred = np.zeros((3, 3)) + H_pred[2, 2] = 1.0 + corners = np.array( + [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]] + ) + real_warped_corners = np.dot(corners, np.transpose(H)) + real_warped_corners = ( + real_warped_corners[:, :2] / real_warped_corners[:, 2:] + ) + warped_corners = np.dot(corners, np.transpose(H_pred)) + warped_corners = warped_corners[:, :2] / warped_corners[:, 2:] + mean_dist = np.mean( + np.linalg.norm(real_warped_corners - warped_corners, axis=1) + ) / (min(w2, h2) / 480.0) + homog_dists.append(mean_dist) + n_matches = np.array(n_matches) + thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + auc = pose_auc(np.array(homog_dists), thresholds) + return { + "hpatches_homog_auc_3": auc[2], + "hpatches_homog_auc_5": auc[4], + "hpatches_homog_auc_10": auc[9], + } diff --git a/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1193745ff18d239165aeb3376642fb17033874 --- /dev/null +++ b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py @@ -0,0 +1,124 @@ +import numpy as np +import torch +from dkm.utils import * +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F + +class Megadepth1500Benchmark: + def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, model): + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + for scene_ind in range(len(self.scenes)): + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs)) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind][0] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + K2 = intrinsics[idx2].copy() + T2 = poses[idx2].copy() + R2, t2 = T2[:3, :3], T2[:3, 3] + R, t = compute_relative_pose(R1, t1, R2, t2) + im1_path = f"{data_root}/{im_paths[idx1]}" + im2_path = f"{data_root}/{im_paths[idx2]}" + im1 = Image.open(im1_path) + w1, h1 = im1.size + im2 = Image.open(im2_path) + w2, h2 = im2.size + scale1 = 1200 / max(w1, h1) + scale2 = 1200 / max(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1[:2] = K1[:2] * scale1 + K2[:2] = K2[:2] * scale2 + dense_matches, dense_certainty = model.match(im1_path, im2_path) + sparse_matches,_ = model.sample( + dense_matches, dense_certainty, 5000 + ) + kpts1 = sparse_matches[:, :2] + kpts1 = ( + torch.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2, + h1 * (kpts1[:, 1] + 1) / 2, + ), + axis=-1, + ) + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = ( + torch.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2, + h2 * (kpts2[:, 1] + 1) / 2, + ), + axis=-1, + ) + ) + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + norm_threshold = 0.5 / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1.cpu().numpy(), + kpts2.cpu().numpy(), + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..0b370644497efd62563105e68e692e10ff339669 --- /dev/null +++ b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py @@ -0,0 +1,86 @@ +import torch +import numpy as np +import tqdm +from dkm.datasets import MegadepthBuilder +from dkm.utils import warp_kpts +from torch.utils.data import ConcatDataset + + +class MegadepthDenseBenchmark: + def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None: + mega = MegadepthBuilder(data_root=data_root) + self.dataset = ConcatDataset( + mega.build_scenes(split="test_loftr", ht=h, wt=w) + ) # fixed resolution of 384,512 + self.num_samples = num_samples + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device + + def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches): + b, h1, w1, d = dense_matches.shape + with torch.no_grad(): + x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2) + # x1 = torch.stack((2*x1[...,0]/w1-1,2*x1[...,1]/h1-1),dim=-1) + mask, x2 = warp_kpts( + x1.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + ) + x2 = torch.stack( + (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1 + ) + prob = mask.float().reshape(b, h1, w1) + x2_hat = dense_matches[..., 2:] + x2_hat = torch.stack( + (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1 + ) + gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1) + gd = gd[prob == 1] + pck_1 = (gd < 1.0).float().mean() + pck_3 = (gd < 3.0).float().mean() + pck_5 = (gd < 5.0).float().mean() + gd = gd.mean() + return gd, pck_1, pck_3, pck_5 + + def benchmark(self, model, batch_size=8): + model.train(False) + with torch.no_grad(): + gd_tot = 0.0 + pck_1_tot = 0.0 + pck_3_tot = 0.0 + pck_5_tot = 0.0 + sampler = torch.utils.data.WeightedRandomSampler( + torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples + ) + dataloader = torch.utils.data.DataLoader( + self.dataset, batch_size=8, num_workers=batch_size, sampler=sampler + ) + for data in tqdm.tqdm(dataloader): + im1, im2, depth1, depth2, T_1to2, K1, K2 = ( + data["query"], + data["support"], + data["query_depth"].to(self.device), + data["support_depth"].to(self.device), + data["T_1to2"].to(self.device), + data["K1"].to(self.device), + data["K2"].to(self.device), + ) + matches, certainty = model.match(im1, im2, batched=True) + gd, pck_1, pck_3, pck_5 = self.geometric_dist( + depth1, depth2, T_1to2, K1, K2, matches + ) + gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = ( + gd_tot + gd, + pck_1_tot + pck_1, + pck_3_tot + pck_3, + pck_5_tot + pck_5, + ) + return { + "mega_pck_1": pck_1_tot.item() / len(dataloader), + "mega_pck_3": pck_3_tot.item() / len(dataloader), + "mega_pck_5": pck_5_tot.item() / len(dataloader), + } diff --git a/third_party/DKM/dkm/benchmarks/scannet_benchmark.py b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..ca938cb462c351845ce035f8be0714cf81214452 --- /dev/null +++ b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py @@ -0,0 +1,143 @@ +import os.path as osp +import numpy as np +import torch +from dkm.utils import * +from PIL import Image +from tqdm import tqdm + + +class ScanNetBenchmark: + def __init__(self, data_root="data/scannet") -> None: + self.data_root = data_root + + def benchmark(self, model, model_name = None): + model.train(False) + with torch.no_grad(): + data_root = self.data_root + tmp = np.load(osp.join(data_root, "test.npz")) + pairs, rel_pose = tmp["name"], tmp["rel_pose"] + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + pair_inds = np.random.choice( + range(len(pairs)), size=len(pairs), replace=False + ) + for pairind in tqdm(pair_inds, smoothing=0.9): + scene = pairs[pairind] + scene_name = f"scene0{scene[0]}_00" + im1_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[2]}.jpg", + ) + im1 = Image.open(im1_path) + im2_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[3]}.jpg", + ) + im2 = Image.open(im2_path) + T_gt = rel_pose[pairind].reshape(3, 4) + R, t = T_gt[:3, :3], T_gt[:3, 3] + K = np.stack( + [ + np.array([float(i) for i in r.split()]) + for r in open( + osp.join( + self.data_root, + "scans_test", + scene_name, + "intrinsic", + "intrinsic_color.txt", + ), + "r", + ) + .read() + .split("\n") + if r + ] + ) + w1, h1 = im1.size + w2, h2 = im2.size + K1 = K.copy() + K2 = K.copy() + dense_matches, dense_certainty = model.match(im1_path, im2_path) + sparse_matches, sparse_certainty = model.sample( + dense_matches, dense_certainty, 5000 + ) + scale1 = 480 / min(w1, h1) + scale2 = 480 / min(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1 = K1 * scale1 + K2 = K2 * scale2 + + offset = 0.5 + kpts1 = sparse_matches[:, :2] + kpts1 = ( + np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2 - offset, + h1 * (kpts1[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = ( + np.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2 - offset, + h2 * (kpts2[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + norm_threshold = 0.5 / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/third_party/DKM/dkm/checkpointing/__init__.py b/third_party/DKM/dkm/checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577 --- /dev/null +++ b/third_party/DKM/dkm/checkpointing/__init__.py @@ -0,0 +1 @@ +from .checkpoint import CheckPoint diff --git a/third_party/DKM/dkm/checkpointing/checkpoint.py b/third_party/DKM/dkm/checkpointing/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..715eeb587ebb87ed0d1bcf9940e048adbe35cde2 --- /dev/null +++ b/third_party/DKM/dkm/checkpointing/checkpoint.py @@ -0,0 +1,31 @@ +import os +import torch +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel +from loguru import logger + + +class CheckPoint: + def __init__(self, dir=None, name="tmp"): + self.name = name + self.dir = dir + os.makedirs(self.dir, exist_ok=True) + + def __call__( + self, + model, + optimizer, + lr_scheduler, + n, + ): + assert model is not None + if isinstance(model, (DataParallel, DistributedDataParallel)): + model = model.module + states = { + "model": model.state_dict(), + "n": n, + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + } + torch.save(states, self.dir + self.name + f"_latest.pth") + logger.info(f"Saved states {list(states.keys())}, at step {n}") diff --git a/third_party/DKM/dkm/datasets/__init__.py b/third_party/DKM/dkm/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b81083212edaf345c30f0cb1116c5f9de284ce6 --- /dev/null +++ b/third_party/DKM/dkm/datasets/__init__.py @@ -0,0 +1 @@ +from .megadepth import MegadepthBuilder diff --git a/third_party/DKM/dkm/datasets/megadepth.py b/third_party/DKM/dkm/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..c580607e910ce1926b7711b5473aa82b20865369 --- /dev/null +++ b/third_party/DKM/dkm/datasets/megadepth.py @@ -0,0 +1,177 @@ +import os +import random +from PIL import Image +import h5py +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader, ConcatDataset + +from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +import torchvision.transforms.functional as tvf +from dkm.utils.transforms import GeometricSequential +import kornia.augmentation as K + + +class MegadepthScene: + def __init__( + self, + data_root, + scene_info, + ht=384, + wt=512, + min_overlap=0.0, + shake_t=0, + rot_prob=0.0, + normalize=True, + ) -> None: + self.data_root = data_root + self.image_paths = scene_info["image_paths"] + self.depth_paths = scene_info["depth_paths"] + self.intrinsics = scene_info["intrinsics"] + self.poses = scene_info["poses"] + self.pairs = scene_info["pairs"] + self.overlaps = scene_info["overlaps"] + threshold = self.overlaps > min_overlap + self.pairs = self.pairs[threshold] + self.overlaps = self.overlaps[threshold] + if len(self.pairs) > 100000: + pairinds = np.random.choice( + np.arange(0, len(self.pairs)), 100000, replace=False + ) + self.pairs = self.pairs[pairinds] + self.overlaps = self.overlaps[pairinds] + # counts, bins = np.histogram(self.overlaps,20) + # print(counts) + self.im_transform_ops = get_tuple_transform_ops( + resize=(ht, wt), normalize=normalize + ) + self.depth_transform_ops = get_depth_tuple_transform_ops( + resize=(ht, wt), normalize=False + ) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) + + def load_im(self, im_ref, crop=None): + im = Image.open(im_ref) + return im + + def load_depth(self, depth_ref, crop=None): + depth = np.array(h5py.File(depth_ref, "r")["depth"]) + return torch.from_numpy(depth) + + def __len__(self): + return len(self.pairs) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) + return sK @ K + + def rand_shake(self, *things): + t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2) + return [ + tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0]) + for thing in things + ], t + + def __getitem__(self, pair_idx): + # read intrinsics of original size + idx1, idx2 = self.pairs[pair_idx] + K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3) + K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T1 = self.poses[idx1] + T2 = self.poses[idx2] + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) + + # Load positive pair data + im1, im2 = self.image_paths[idx1], self.image_paths[idx2] + depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2] + im_src_ref = os.path.join(self.data_root, im1) + im_pos_ref = os.path.join(self.data_root, im2) + depth_src_ref = os.path.join(self.data_root, depth1) + depth_pos_ref = os.path.join(self.data_root, depth2) + # return torch.randn((1000,1000)) + im_src = self.load_im(im_src_ref) + im_pos = self.load_im(im_pos_ref) + depth_src = self.load_depth(depth_src_ref) + depth_pos = self.load_depth(depth_pos_ref) + + # Recompute camera intrinsic matrix due to the resize + K1 = self.scale_intrinsic(K1, im_src.width, im_src.height) + K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height) + # Process images + im_src, im_pos = self.im_transform_ops((im_src, im_pos)) + depth_src, depth_pos = self.depth_transform_ops( + (depth_src[None, None], depth_pos[None, None]) + ) + [im_src, im_pos, depth_src, depth_pos], t = self.rand_shake( + im_src, im_pos, depth_src, depth_pos + ) + im_src, Hq = self.H_generator(im_src[None]) + depth_src = self.H_generator.apply_transform(depth_src, Hq) + K1[:2, 2] += t + K2[:2, 2] += t + K1 = Hq[0] @ K1 + data_dict = { + "query": im_src[0], + "query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0], + "support": im_pos, + "support_identifier": self.image_paths[idx2] + .split("/")[-1] + .split(".jpg")[0], + "query_depth": depth_src[0, 0], + "support_depth": depth_pos[0, 0], + "K1": K1, + "K2": K2, + "T_1to2": T_1to2, + } + return data_dict + + +class MegadepthBuilder: + def __init__(self, data_root="data/megadepth") -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root, "prep_scene_info") + self.all_scenes = os.listdir(self.scene_info_root) + self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"] + self.test_scenes_loftr = ["0015.npy", "0022.npy"] + + def build_scenes(self, split="train", min_overlap=0.0, **kwargs): + if split == "train": + scene_names = set(self.all_scenes) - set(self.test_scenes) + elif split == "train_loftr": + scene_names = set(self.all_scenes) - set(self.test_scenes_loftr) + elif split == "test": + scene_names = self.test_scenes + elif split == "test_loftr": + scene_names = self.test_scenes_loftr + else: + raise ValueError(f"Split {split} not available") + scenes = [] + for scene_name in scene_names: + scene_info = np.load( + os.path.join(self.scene_info_root, scene_name), allow_pickle=True + ).item() + scenes.append( + MegadepthScene( + self.data_root, scene_info, min_overlap=min_overlap, **kwargs + ) + ) + return scenes + + def weight_scenes(self, concat_dataset, alpha=0.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n) / n**alpha for n in ns]) + return ws + + +if __name__ == "__main__": + mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train")) + mega_test[0] diff --git a/third_party/DKM/dkm/datasets/scannet.py b/third_party/DKM/dkm/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac39b41480f7585c4755cc30e0677ef74ed5e0c --- /dev/null +++ b/third_party/DKM/dkm/datasets/scannet.py @@ -0,0 +1,151 @@ +import os +import random +from PIL import Image +import cv2 +import h5py +import numpy as np +import torch +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset) + +import torchvision.transforms.functional as tvf +import kornia.augmentation as K +import os.path as osp +import matplotlib.pyplot as plt +from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +from dkm.utils.transforms import GeometricSequential + +from tqdm import tqdm + +class ScanNetScene: + def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None: + self.scene_root = osp.join(data_root,"scans","scans_train") + self.data_names = scene_info['name'] + self.overlaps = scene_info['score'] + # Only sample 10s + valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0 + self.overlaps = self.overlaps[valid] + self.data_names = self.data_names[valid] + if len(self.data_names) > 10000: + pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False) + self.data_names = self.data_names[pairinds] + self.overlaps = self.overlaps[pairinds] + self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True) + self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) + + def load_im(self, im_ref, crop=None): + im = Image.open(im_ref) + return im + + def load_depth(self, depth_ref, crop=None): + depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + def __len__(self): + return len(self.data_names) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], + [0, sy, 0], + [0, 0, 1]]) + return sK@K + + def read_scannet_pose(self,path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = np.linalg.inv(cam2world) + return world2cam + + + def read_scannet_intrinsic(self,path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. + """ + intrinsic = np.loadtxt(path, delimiter=' ') + return intrinsic[:-1, :-1] + + def __getitem__(self, pair_idx): + # read intrinsics of original size + data_name = self.data_names[pair_idx] + scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the intrinsic of depthmap + K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root, + scene_name, + 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter + # read and compute relative poses + T1 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_1}.txt')) + T2 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_2}.txt')) + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4) + + # Load positive pair data + im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg') + im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg') + depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png') + depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png') + + im_src = self.load_im(im_src_ref) + im_pos = self.load_im(im_pos_ref) + depth_src = self.load_depth(depth_src_ref) + depth_pos = self.load_depth(depth_pos_ref) + + # Recompute camera intrinsic matrix due to the resize + K1 = self.scale_intrinsic(K1, im_src.width, im_src.height) + K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height) + # Process images + im_src, im_pos = self.im_transform_ops((im_src, im_pos)) + depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None])) + + data_dict = {'query': im_src, + 'support': im_pos, + 'query_depth': depth_src[0,0], + 'support_depth': depth_pos[0,0], + 'K1': K1, + 'K2': K2, + 'T_1to2':T_1to2, + } + return data_dict + + +class ScanNetBuilder: + def __init__(self, data_root = 'data/scannet') -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root,'scannet_indices') + self.all_scenes = os.listdir(self.scene_info_root) + + def build_scenes(self, split = 'train', min_overlap=0., **kwargs): + # Note: split doesn't matter here as we always use same scannet_train scenes + scene_names = self.all_scenes + scenes = [] + for scene_name in tqdm(scene_names): + scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True) + scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs)) + return scenes + + def weight_scenes(self, concat_dataset, alpha=.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n)/n**alpha for n in ns]) + return ws + + +if __name__ == "__main__": + mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train')) + mega_test[0] \ No newline at end of file diff --git a/third_party/DKM/dkm/losses/__init__.py b/third_party/DKM/dkm/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71914f50d891079d204a07c57367159888f892de --- /dev/null +++ b/third_party/DKM/dkm/losses/__init__.py @@ -0,0 +1 @@ +from .depth_match_regression_loss import DepthRegressionLoss diff --git a/third_party/DKM/dkm/losses/depth_match_regression_loss.py b/third_party/DKM/dkm/losses/depth_match_regression_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..80da70347b4b4addc721e2a14ed489f8683fd48a --- /dev/null +++ b/third_party/DKM/dkm/losses/depth_match_regression_loss.py @@ -0,0 +1,128 @@ +from einops.einops import rearrange +import torch +import torch.nn as nn +import torch.nn.functional as F +from dkm.utils.utils import warp_kpts + + +class DepthRegressionLoss(nn.Module): + def __init__( + self, + robust=True, + center_coords=False, + scale_normalize=False, + ce_weight=0.01, + local_loss=True, + local_dist=4.0, + local_largest_scale=8, + ): + super().__init__() + self.robust = robust # measured in pixels + self.center_coords = center_coords + self.scale_normalize = scale_normalize + self.ce_weight = ce_weight + self.local_loss = local_loss + self.local_dist = local_dist + self.local_largest_scale = local_largest_scale + + def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale): + """[summary] + + Args: + H ([type]): [description] + scale ([type]): [description] + + Returns: + [type]: [description] + """ + b, h1, w1, d = dense_matches.shape + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=dense_matches.device + ) + for n in (b, h1, w1) + ] + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(b, h1 * w1, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + ) + prob = mask.float().reshape(b, h1, w1) + gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1) # *scale? + return gd, prob + + def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8): + """[summary] + + Args: + dense_certainty ([type]): [description] + prob ([type]): [description] + eps ([type], optional): [description]. Defaults to 1e-8. + + Returns: + [type]: [description] + """ + smooth_prob = prob + ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], smooth_prob) + depth_loss = gd[prob > 0] + if not torch.any(prob > 0).item(): + depth_loss = (gd * 0.0).mean() # Prevent issues where prob is 0 everywhere + return { + f"ce_loss_{scale}": ce_loss.mean(), + f"depth_loss_{scale}": depth_loss.mean(), + } + + def forward(self, dense_corresps, batch): + """[summary] + + Args: + out ([type]): [description] + batch ([type]): [description] + + Returns: + [type]: [description] + """ + scales = list(dense_corresps.keys()) + tot_loss = 0.0 + prev_gd = 0.0 + for scale in scales: + dense_scale_corresps = dense_corresps[scale] + dense_scale_certainty, dense_scale_coords = ( + dense_scale_corresps["dense_certainty"], + dense_scale_corresps["dense_flow"], + ) + dense_scale_coords = rearrange(dense_scale_coords, "b d h w -> b h w d") + b, h, w, d = dense_scale_coords.shape + gd, prob = self.geometric_dist( + batch["query_depth"], + batch["support_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + dense_scale_coords, + scale, + ) + if ( + scale <= self.local_largest_scale and self.local_loss + ): # Thought here is that fine matching loss should not be punished by coarse mistakes, but should identify wrong matching + prob = prob * ( + F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[:, 0] + < (2 / 512) * (self.local_dist * scale) + ) + depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale) + scale_loss = ( + self.ce_weight * depth_losses[f"ce_loss_{scale}"] + + depth_losses[f"depth_loss_{scale}"] + ) # scale ce loss for coarser scales + if self.scale_normalize: + scale_loss = scale_loss * 1 / scale + tot_loss = tot_loss + scale_loss + prev_gd = gd.detach() + return tot_loss diff --git a/third_party/DKM/dkm/models/__init__.py b/third_party/DKM/dkm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4fc321ec70fd116beca23e94248cb6bbe771523 --- /dev/null +++ b/third_party/DKM/dkm/models/__init__.py @@ -0,0 +1,4 @@ +from .model_zoo import ( + DKMv3_outdoor, + DKMv3_indoor, +) diff --git a/third_party/DKM/dkm/models/deprecated/build_model.py b/third_party/DKM/dkm/models/deprecated/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dd28335f3e348ab6c90b26ba91b95e864b0bbbb9 --- /dev/null +++ b/third_party/DKM/dkm/models/deprecated/build_model.py @@ -0,0 +1,787 @@ +import torch +import torch.nn as nn +from dkm import * +from .local_corr import LocalCorr +from .corr_channels import NormedCorr +from torchvision.models import resnet as tv_resnet + +dkm_pretrained_urls = { + "DKM": { + "mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth", + "mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth", + }, + "DKMv2":{ + "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth", + } +} + + +def DKM(pretrained=True, version="mega_synthetic", device=None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": ConvRefiner( + 2 * 256, + 512, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": ConvRefiner( + 2 * 64, + 128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained), + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["DKM"][version] + ) + matcher.load_state_dict(weights) + return matcher + +def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128, + 1024+128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + ), + "8": ConvRefiner( + 2 * 512+64, + 1024+64, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + ), + "4": ConvRefiner( + 2 * 256+32, + 512+32, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + ), + "1": ConvRefiner( + 2 * 3+6, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=6, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + if resolution == "low": + h, w = 384, 512 + elif resolution == "high": + h, w = 480, 640 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained), + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device) + if pretrained: + try: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["DKMv2"][version] + ) + except: + weights = torch.load( + dkm_pretrained_urls["DKMv2"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def local_corr(pretrained=True, version="mega_synthetic"): + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": LocalCorr( + 81, + 81 * 6, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": LocalCorr( + 81, + 81, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["local_corr"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def corr_channels(pretrained=True, version="mega_synthetic"): + h, w = 384, 512 + gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16) + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim[0] + feat_dim, dfn_dim), + "16": RRB(gp_dim[1] + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": ConvRefiner( + 2 * 256, + 512, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": ConvRefiner( + 2 * 64, + 128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + gp32 = NormedCorr() + gp16 = NormedCorr() + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["corr_channels"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def baseline(pretrained=True, version="mega_synthetic"): + h, w = 384, 512 + gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16) + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim[0] + feat_dim, dfn_dim), + "16": RRB(gp_dim[1] + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": LocalCorr( + 81, + 81 * 6, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": LocalCorr( + 81, + 81, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + gp32 = NormedCorr() + gp16 = NormedCorr() + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["baseline"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def linear(pretrained=True, version="mega_synthetic"): + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": ConvRefiner( + 2 * 256, + 512, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": ConvRefiner( + 2 * 64, + 128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "linear" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["linear"][version] + ) + matcher.load_state_dict(weights) + return matcher diff --git a/third_party/DKM/dkm/models/deprecated/corr_channels.py b/third_party/DKM/dkm/models/deprecated/corr_channels.py new file mode 100644 index 0000000000000000000000000000000000000000..8713b0d8c7a0ce91da4d2105ba29097a4969a037 --- /dev/null +++ b/third_party/DKM/dkm/models/deprecated/corr_channels.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class NormedCorrelationKernel(nn.Module): # similar to softmax kernel + def __init__(self): + super().__init__() + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + return c + + +class NormedCorr(nn.Module): + def __init__( + self, + ): + super().__init__() + self.corr = NormedCorrelationKernel() + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def forward(self, x, y, **kwargs): + b, c, h, w = y.shape + assert x.shape == y.shape + x, y = self.reshape(x), self.reshape(y) + corr_xy = self.corr(x, y) + corr_xy_flat = rearrange(corr_xy, "b (h w) c -> b c h w", h=h, w=w) + return corr_xy_flat diff --git a/third_party/DKM/dkm/models/deprecated/local_corr.py b/third_party/DKM/dkm/models/deprecated/local_corr.py new file mode 100644 index 0000000000000000000000000000000000000000..681fe4c0079561fa7a4c44e82a8879a4a27273a1 --- /dev/null +++ b/third_party/DKM/dkm/models/deprecated/local_corr.py @@ -0,0 +1,630 @@ +import torch +import torch.nn.functional as F + +try: + import cupy +except: + print("Cupy not found, local correlation will not work") +import re +from ..dkm import ConvRefiner + + +class Stream: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device == 'cuda': + stream = torch.cuda.current_stream(device=device).cuda_stream + else: + stream = None + + +kernel_Correlation_rearrange = """ + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + if (intIndex >= n) { + return; + } + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + __syncthreads(); + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue; + } +""" + +kernel_Correlation_updateOutput = """ + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + float *patch_data = (float *)patch_data_char; + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + __syncthreads(); + __shared__ float sum[32]; + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + __syncthreads(); + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +""" + +kernel_Correlation_updateGradFirst = """ + #define ROUND_OFF 50000 + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +""" + +kernel_Correlation_updateGradSecond = """ + #define ROUND_OFF 50000 + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +""" + + +def cupy_kernel(strFunction, objectVariables): + strKernel = globals()[strFunction] + + while True: + objectMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel) + + if objectMatch is None: + break + + intArg = int(objectMatch.group(2)) + + strTensor = objectMatch.group(4) + intSizes = objectVariables[strTensor].size() + + strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg])) + + while True: + objectMatch = re.search(r"(VALUE_)([0-4])(\()([^\)]+)(\))", strKernel) + + if objectMatch is None: + break + + intArgs = int(objectMatch.group(2)) + strArgs = objectMatch.group(4).split(",") + + strTensor = strArgs[0] + intStrides = objectVariables[strTensor].stride() + strIndex = [ + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str(intStrides[intArg]) + + ")" + for intArg in range(intArgs) + ] + + strKernel = strKernel.replace( + objectMatch.group(0), strTensor + "[" + str.join("+", strIndex) + "]" + ) + + return strKernel + + +try: + + @cupy.memoize(for_each_device=True) + def cupy_launch(strFunction, strKernel): + return cupy.RawModule(code=strKernel).get_function(strFunction) + +except: + pass + + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros( + [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)] + ) + rbot1 = first.new_zeros( + [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)] + ) + + self.save_for_backward(first, second, rbot0, rbot1) + + first = first.contiguous() + second = second.contiguous() + + output = first.new_zeros([first.size(0), 81, first.size(2), first.size(3)]) + + if first.is_cuda == True: + n = first.size(2) * first.size(3) + cupy_launch( + "kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", {"input": first, "output": rbot0} + ), + )( + grid=tuple([int((n + 16 - 1) / 16), first.size(1), first.size(0)]), + block=tuple([16, 1, 1]), + args=[n, first.data_ptr(), rbot0.data_ptr()], + stream=Stream, + ) + + n = second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", {"input": second, "output": rbot1} + ), + )( + grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()], + stream=Stream, + ) + + n = output.size(1) * output.size(2) * output.size(3) + cupy_launch( + "kernel_Correlation_updateOutput", + cupy_kernel( + "kernel_Correlation_updateOutput", + {"rbot0": rbot0, "rbot1": rbot1, "top": output}, + ), + )( + grid=tuple([output.size(3), output.size(2), output.size(0)]), + block=tuple([32, 1, 1]), + shared_mem=first.size(1) * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()], + stream=Stream, + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + return output + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + gradOutput = gradOutput.contiguous() + + assert gradOutput.is_contiguous() == True + + gradFirst = ( + first.new_zeros( + [first.size(0), first.size(1), first.size(2), first.size(3)] + ) + if self.needs_input_grad[0] == True + else None + ) + gradSecond = ( + first.new_zeros( + [first.size(0), first.size(1), first.size(2), first.size(3)] + ) + if self.needs_input_grad[1] == True + else None + ) + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.size(0)): + n = first.size(1) * first.size(2) * first.size(3) + cupy_launch( + "kernel_Correlation_updateGradFirst", + cupy_kernel( + "kernel_Correlation_updateGradFirst", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": gradOutput, + "gradFirst": gradFirst, + "gradSecond": None, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + gradOutput.data_ptr(), + gradFirst.data_ptr(), + None, + ], + stream=Stream, + ) + + if gradSecond is not None: + for intSample in range(first.size(0)): + n = first.size(1) * first.size(2) * first.size(3) + cupy_launch( + "kernel_Correlation_updateGradSecond", + cupy_kernel( + "kernel_Correlation_updateGradSecond", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": gradOutput, + "gradFirst": None, + "gradSecond": gradSecond, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + gradOutput.data_ptr(), + None, + gradSecond.data_ptr(), + ], + stream=Stream, + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + return gradFirst, gradSecond + + +class _FunctionCorrelationTranspose(torch.autograd.Function): + @staticmethod + def forward(self, input, second): + rbot0 = second.new_zeros( + [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)] + ) + rbot1 = second.new_zeros( + [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)] + ) + + self.save_for_backward(input, second, rbot0, rbot1) + + input = input.contiguous() + second = second.contiguous() + + output = second.new_zeros( + [second.size(0), second.size(1), second.size(2), second.size(3)] + ) + + if second.is_cuda == True: + n = second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", {"input": second, "output": rbot1} + ), + )( + grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()], + stream=Stream, + ) + + for intSample in range(second.size(0)): + n = second.size(1) * second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_updateGradFirst", + cupy_kernel( + "kernel_Correlation_updateGradFirst", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": input, + "gradFirst": output, + "gradSecond": None, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + input.data_ptr(), + output.data_ptr(), + None, + ], + stream=Stream, + ) + + elif second.is_cuda == False: + raise NotImplementedError() + + return output + + @staticmethod + def backward(self, gradOutput): + input, second, rbot0, rbot1 = self.saved_tensors + + gradOutput = gradOutput.contiguous() + + gradInput = ( + input.new_zeros( + [input.size(0), input.size(1), input.size(2), input.size(3)] + ) + if self.needs_input_grad[0] == True + else None + ) + gradSecond = ( + second.new_zeros( + [second.size(0), second.size(1), second.size(2), second.size(3)] + ) + if self.needs_input_grad[1] == True + else None + ) + + if second.is_cuda == True: + if gradInput is not None or gradSecond is not None: + n = second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", + {"input": gradOutput, "output": rbot0}, + ), + )( + grid=tuple( + [int((n + 16 - 1) / 16), gradOutput.size(1), gradOutput.size(0)] + ), + block=tuple([16, 1, 1]), + args=[n, gradOutput.data_ptr(), rbot0.data_ptr()], + stream=Stream, + ) + + if gradInput is not None: + n = gradInput.size(1) * gradInput.size(2) * gradInput.size(3) + cupy_launch( + "kernel_Correlation_updateOutput", + cupy_kernel( + "kernel_Correlation_updateOutput", + {"rbot0": rbot0, "rbot1": rbot1, "top": gradInput}, + ), + )( + grid=tuple( + [gradInput.size(3), gradInput.size(2), gradInput.size(0)] + ), + block=tuple([32, 1, 1]), + shared_mem=gradOutput.size(1) * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), gradInput.data_ptr()], + stream=Stream, + ) + + if gradSecond is not None: + for intSample in range(second.size(0)): + n = second.size(1) * second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_updateGradSecond", + cupy_kernel( + "kernel_Correlation_updateGradSecond", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": input, + "gradFirst": None, + "gradSecond": gradSecond, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + input.data_ptr(), + None, + gradSecond.data_ptr(), + ], + stream=Stream, + ) + + elif second.is_cuda == False: + raise NotImplementedError() + + return gradInput, gradSecond + + +def FunctionCorrelation(reference_features, query_features): + return _FunctionCorrelation.apply(reference_features, query_features) + + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + + def forward(self, tensorFirst, tensorSecond): + return _FunctionCorrelation.apply(tensorFirst, tensorSecond) + + +def FunctionCorrelationTranspose(reference_features, query_features): + return _FunctionCorrelationTranspose.apply(reference_features, query_features) + + +class ModuleCorrelationTranspose(torch.nn.Module): + def __init__(self): + super(ModuleCorrelationTranspose, self).__init__() + + def forward(self, tensorFirst, tensorSecond): + return _FunctionCorrelationTranspose.apply(tensorFirst, tensorSecond) + + +class LocalCorr(ConvRefiner): + def forward(self, x, y, flow): + """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them + + Args: + x ([type]): [description] + y ([type]): [description] + flow ([type]): [description] + + Returns: + [type]: [description] + """ + with torch.no_grad(): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False) + corr = FunctionCorrelation(x, x_hat) + d = self.block1(corr) + d = self.hidden_blocks(d) + d = self.out_conv(d) + certainty, displacement = d[:, :-2], d[:, -2:] + return certainty, displacement + + +if __name__ == "__main__": + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + x = torch.randn(2, 128, 32, 32).to(device) + y = torch.randn(2, 128, 32, 32).to(device) + local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4) + z = local_corr(x, y) + print("hej") diff --git a/third_party/DKM/dkm/models/dkm.py b/third_party/DKM/dkm/models/dkm.py new file mode 100644 index 0000000000000000000000000000000000000000..27c3f6d59ad3a8e976e3d719868908ddf443883e --- /dev/null +++ b/third_party/DKM/dkm/models/dkm.py @@ -0,0 +1,759 @@ +import math +import os +import numpy as np +from PIL import Image +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..utils import get_tuple_transform_ops +from einops import rearrange +from ..utils.local_correlation import local_correlation + + +class ConvRefiner(nn.Module): + def __init__( + self, + in_dim=6, + hidden_dim=16, + out_dim=2, + dw=False, + kernel_size=5, + hidden_blocks=3, + displacement_emb = None, + displacement_emb_dim = None, + local_corr_radius = None, + corr_in_other = None, + no_support_fm = False, + ): + super().__init__() + self.block1 = self.create_block( + in_dim, hidden_dim, dw=dw, kernel_size=kernel_size + ) + self.hidden_blocks = nn.Sequential( + *[ + self.create_block( + hidden_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + ) + for hb in range(hidden_blocks) + ] + ) + self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) + if displacement_emb: + self.has_displacement_emb = True + self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + else: + self.has_displacement_emb = False + self.local_corr_radius = local_corr_radius + self.corr_in_other = corr_in_other + self.no_support_fm = no_support_fm + def create_block( + self, + in_dim, + out_dim, + dw=False, + kernel_size=5, + ): + num_groups = 1 if not dw else in_dim + if dw: + assert ( + out_dim % in_dim == 0 + ), "outdim must be divisible by indim for depthwise" + conv1 = nn.Conv2d( + in_dim, + out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=num_groups, + ) + norm = nn.BatchNorm2d(out_dim) + relu = nn.ReLU(inplace=True) + conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) + return nn.Sequential(conv1, norm, relu, conv2) + + def forward(self, x, y, flow): + """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them + + Args: + x ([type]): [description] + y ([type]): [description] + flow ([type]): [description] + + Returns: + [type]: [description] + """ + device = x.device + b,c,hs,ws = x.shape + with torch.no_grad(): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False) + if self.has_displacement_emb: + query_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ) + ) + query_coords = torch.stack((query_coords[1], query_coords[0])) + query_coords = query_coords[None].expand(b, 2, hs, ws) + in_displacement = flow-query_coords + emb_in_displacement = self.disp_emb(in_displacement) + if self.local_corr_radius: + #TODO: should corr have gradient? + if self.corr_in_other: + # Corr in other means take a kxk grid around the predicted coordinate in other image + local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow) + else: + # Otherwise we use the warp to sample in the first image + # This is actually different operations, especially for large viewpoint changes + local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,) + if self.no_support_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) + else: + d = torch.cat((x, x_hat, emb_in_displacement), dim=1) + else: + if self.no_support_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat), dim=1) + d = self.block1(d) + d = self.hidden_blocks(d) + d = self.out_conv(d) + certainty, displacement = d[:, :-2], d[:, -2:] + return certainty, displacement + + +class CosKernel(nn.Module): # similar to softmax kernel + def __init__(self, T, learn_temperature=False): + super().__init__() + self.learn_temperature = learn_temperature + if self.learn_temperature: + self.T = nn.Parameter(torch.tensor(T)) + else: + self.T = T + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + if self.learn_temperature: + T = self.T.abs() + 0.01 + else: + T = torch.tensor(self.T, device=c.device) + K = ((c - 1.0) / T).exp() + return K + + +class CAB(nn.Module): + def __init__(self, in_channels, out_channels): + super(CAB, self).__init__() + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.sigmod = nn.Sigmoid() + + def forward(self, x): + x1, x2 = x # high, low (old, new) + x = torch.cat([x1, x2], dim=1) + x = self.global_pooling(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.sigmod(x) + x2 = x * x2 + res = x2 + x1 + return res + + +class RRB(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(RRB, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + ) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(out_channels) + self.conv3 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + ) + + def forward(self, x): + x = self.conv1(x) + res = self.conv2(x) + res = self.bn(res) + res = self.relu(res) + res = self.conv3(res) + return self.relu(x + res) + + +class DFN(nn.Module): + def __init__( + self, + internal_dim, + feat_input_modules, + pred_input_modules, + rrb_d_dict, + cab_dict, + rrb_u_dict, + use_global_context=False, + global_dim=None, + terminal_module=None, + upsample_mode="bilinear", + align_corners=False, + ): + super().__init__() + if use_global_context: + assert ( + global_dim is not None + ), "Global dim must be provided when using global context" + self.align_corners = align_corners + self.internal_dim = internal_dim + self.feat_input_modules = feat_input_modules + self.pred_input_modules = pred_input_modules + self.rrb_d = rrb_d_dict + self.cab = cab_dict + self.rrb_u = rrb_u_dict + self.use_global_context = use_global_context + if use_global_context: + self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.terminal_module = ( + terminal_module if terminal_module is not None else nn.Identity() + ) + self.upsample_mode = upsample_mode + self._scales = [int(key) for key in self.terminal_module.keys()] + + def scales(self): + return self._scales.copy() + + def forward(self, embeddings, feats, context, key): + feats = self.feat_input_modules[str(key)](feats) + embeddings = torch.cat([feats, embeddings], dim=1) + embeddings = self.rrb_d[str(key)](embeddings) + context = self.cab[str(key)]([context, embeddings]) + context = self.rrb_u[str(key)](context) + preds = self.terminal_module[str(key)](context) + pred_coord = preds[:, -2:] + pred_certainty = preds[:, :-2] + return pred_coord, pred_certainty, context + + +class GP(nn.Module): + def __init__( + self, + kernel, + T=1, + learn_temperature=False, + only_attention=False, + gp_dim=64, + basis="fourier", + covar_size=5, + only_nearest_neighbour=False, + sigma_noise=0.1, + no_cov=False, + predict_features = False, + ): + super().__init__() + self.K = kernel(T=T, learn_temperature=learn_temperature) + self.sigma_noise = sigma_noise + self.covar_size = covar_size + self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1) + self.only_attention = only_attention + self.only_nearest_neighbour = only_nearest_neighbour + self.basis = basis + self.no_cov = no_cov + self.dim = gp_dim + self.predict_features = predict_features + + def get_local_cov(self, cov): + K = self.covar_size + b, h, w, h, w = cov.shape + hw = h * w + cov = F.pad(cov, 4 * (K // 2,)) # pad v_q + delta = torch.stack( + torch.meshgrid( + torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1) + ), + dim=-1, + ) + positions = torch.stack( + torch.meshgrid( + torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2) + ), + dim=-1, + ) + neighbours = positions[:, :, None, None, :] + delta[None, :, :] + points = torch.arange(hw)[:, None].expand(hw, K**2) + local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[ + :, + points.flatten(), + neighbours[..., 0].flatten(), + neighbours[..., 1].flatten(), + ].reshape(b, h, w, K**2) + return local_cov + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def project_to_basis(self, x): + if self.basis == "fourier": + return torch.cos(8 * math.pi * self.pos_conv(x)) + elif self.basis == "linear": + return self.pos_conv(x) + else: + raise ValueError( + "No other bases other than fourier and linear currently supported in public release" + ) + + def get_pos_enc(self, y): + b, c, h, w = y.shape + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device), + ) + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.project_to_basis(coarse_coords) + return coarse_embedded_coords + + def forward(self, x, y, **kwargs): + b, c, h1, w1 = x.shape + b, c, h2, w2 = y.shape + f = self.get_pos_enc(y) + if self.predict_features: + f = f + y[:,:self.dim] # Stupid way to predict features + b, d, h2, w2 = f.shape + #assert x.shape == y.shape + x, y, f = self.reshape(x), self.reshape(y), self.reshape(f) + K_xx = self.K(x, x) + K_yy = self.K(y, y) + K_xy = self.K(x, y) + K_yx = K_xy.permute(0, 2, 1) + sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :] + # Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large + if len(K_yy[0]) > 2000: + K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)]) + else: + K_yy_inv = torch.linalg.inv(K_yy + sigma_noise) + + mu_x = K_xy.matmul(K_yy_inv.matmul(f)) + mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1) + if not self.no_cov: + cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) + cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + local_cov_x = self.get_local_cov(cov_x) + local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") + gp_feats = torch.cat((mu_x, local_cov_x), dim=1) + else: + gp_feats = mu_x + return gp_feats + + +class Encoder(nn.Module): + def __init__(self, resnet): + super().__init__() + self.resnet = resnet + def forward(self, x): + x0 = x + b, c, h, w = x.shape + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x1 = self.resnet.relu(x) + + x = self.resnet.maxpool(x1) + x2 = self.resnet.layer1(x) + + x3 = self.resnet.layer2(x2) + + x4 = self.resnet.layer3(x3) + + x5 = self.resnet.layer4(x4) + feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0} + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + +class Decoder(nn.Module): + def __init__( + self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None, + ): + super().__init__() + self.embedding_decoder = embedding_decoder + self.gps = gps + self.proj = proj + self.conv_refiner = conv_refiner + self.detach = detach + if scales == "all": + self.scales = ["32", "16", "8", "4", "2", "1"] + else: + self.scales = scales + + def upsample_preds(self, flow, certainty, query, support): + b, hs, ws, d = flow.shape + b, c, h, w = query.shape + flow = flow.permute(0, 3, 1, 2) + certainty = F.interpolate( + certainty, size=(h, w), align_corners=False, mode="bilinear" + ) + flow = F.interpolate( + flow, size=(h, w), align_corners=False, mode="bilinear" + ) + delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow) + flow = torch.stack( + ( + flow[:, 0] + delta_flow[:, 0] / (4 * w), + flow[:, 1] + delta_flow[:, 1] / (4 * h), + ), + dim=1, + ) + flow = flow.permute(0, 2, 3, 1) + certainty = certainty + delta_certainty + return flow, certainty + + def get_placeholder_flow(self, b, h, w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ) + ) + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + return coarse_coords + + + def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None): + coarse_scales = self.embedding_decoder.scales() + all_scales = self.scales if not upsample else ["8", "4", "2", "1"] + sizes = {scale: f1[scale].shape[-2:] for scale in f1} + h, w = sizes[1] + b = f1[1].shape[0] + device = f1[1].device + coarsest_scale = int(all_scales[0]) + old_stuff = torch.zeros( + b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + ) + dense_corresps = {} + if not upsample: + dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device) + dense_certainty = 0.0 + else: + dense_flow = F.interpolate( + dense_flow, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + dense_certainty = F.interpolate( + dense_certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + for new_scale in all_scales: + ins = int(new_scale) + f1_s, f2_s = f1[ins], f2[ins] + if new_scale in self.proj: + f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s) + b, c, hs, ws = f1_s.shape + if ins in coarse_scales: + old_stuff = F.interpolate( + old_stuff, size=sizes[ins], mode="bilinear", align_corners=False + ) + new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow) + dense_flow, dense_certainty, old_stuff = self.embedding_decoder( + new_stuff, f1_s, old_stuff, new_scale + ) + + if new_scale in self.conv_refiner: + delta_certainty, displacement = self.conv_refiner[new_scale]( + f1_s, f2_s, dense_flow + ) + dense_flow = torch.stack( + ( + dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w), + dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h), + ), + dim=1, + ) + dense_certainty = ( + dense_certainty + delta_certainty + ) # predict both certainty and displacement + + dense_corresps[ins] = { + "dense_flow": dense_flow, + "dense_certainty": dense_certainty, + } + + if new_scale != "1": + dense_flow = F.interpolate( + dense_flow, + size=sizes[ins // 2], + align_corners=False, + mode="bilinear", + ) + + dense_certainty = F.interpolate( + dense_certainty, + size=sizes[ins // 2], + align_corners=False, + mode="bilinear", + ) + if self.detach: + dense_flow = dense_flow.detach() + dense_certainty = dense_certainty.detach() + return dense_corresps + + +class RegressionMatcher(nn.Module): + def __init__( + self, + encoder, + decoder, + h=384, + w=512, + use_contrastive_loss = False, + alpha = 1, + beta = 0, + sample_mode = "threshold", + upsample_preds = False, + symmetric = False, + name = None, + use_soft_mutual_nearest_neighbours = False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.w_resized = w + self.h_resized = h + self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True) + self.use_contrastive_loss = use_contrastive_loss + self.alpha = alpha + self.beta = beta + self.sample_mode = sample_mode + self.upsample_preds = upsample_preds + self.symmetric = symmetric + self.name = name + self.sample_thresh = 0.05 + self.upsample_res = (864,1152) + if use_soft_mutual_nearest_neighbours: + assert symmetric, "MNS requires symmetric inference" + self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours + + def extract_backbone_features(self, batch, batched = True, upsample = True): + #TODO: only extract stride [1,2,4,8] for upsample = True + x_q = batch["query"] + x_s = batch["support"] + if batched: + X = torch.cat((x_q, x_s)) + feature_pyramid = self.encoder(X) + else: + feature_pyramid = self.encoder(x_q), self.encoder(x_s) + return feature_pyramid + + def sample( + self, + dense_matches, + dense_certainty, + num=10000, + ): + if "threshold" in self.sample_mode: + upper_thresh = self.sample_thresh + dense_certainty = dense_certainty.clone() + dense_certainty[dense_certainty > upper_thresh] = 1 + elif "pow" in self.sample_mode: + dense_certainty = dense_certainty**(1/3) + elif "naive" in self.sample_mode: + dense_certainty = torch.ones_like(dense_certainty) + matches, certainty = ( + dense_matches.reshape(-1, 4), + dense_certainty.reshape(-1), + ) + expansion_factor = 4 if "balanced" in self.sample_mode else 1 + good_samples = torch.multinomial(certainty, + num_samples = min(expansion_factor*num, len(certainty)), + replacement=False) + good_matches, good_certainty = matches[good_samples], certainty[good_samples] + if "balanced" not in self.sample_mode: + return good_matches, good_certainty + + from ..utils.kde import kde + density = kde(good_matches, std=0.1) + p = 1 / (density+1) + p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial(p, + num_samples = min(num,len(good_certainty)), + replacement=False) + return good_matches[balanced_samples], good_certainty[balanced_samples] + + def forward(self, batch, batched = True): + feature_pyramid = self.extract_backbone_features(batch, batched=batched) + if batched: + f_q_pyramid = { + scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items() + } + f_s_pyramid = { + scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items() + } + else: + f_q_pyramid, f_s_pyramid = feature_pyramid + dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid) + if self.training and self.use_contrastive_loss: + return dense_corresps, (f_q_pyramid, f_s_pyramid) + else: + return dense_corresps + + def forward_symmetric(self, batch, upsample = False, batched = True): + feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched) + f_q_pyramid = feature_pyramid + f_s_pyramid = { + scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0])) + for scale, f_scale in feature_pyramid.items() + } + dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {})) + return dense_corresps + + def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B): + kpts_A, kpts_B = matches[...,:2], matches[...,2:] + kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1) + kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1) + return kpts_A, kpts_B + + def match( + self, + im1_path, + im2_path, + *args, + batched=False, + device = None + ): + assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False " + if isinstance(im1_path, (str, os.PathLike)): + im1, im2 = Image.open(im1_path), Image.open(im2_path) + else: # assume it is a PIL Image + im1, im2 = im1_path, im2_path + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + symmetric = self.symmetric + self.train(False) + with torch.no_grad(): + if not batched: + b = 1 + w, h = im1.size + w2, h2 = im2.size + # Get images in good format + ws = self.w_resized + hs = self.h_resized + + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True + ) + query, support = test_transform((im1, im2)) + batch = {"query": query[None].to(device), "support": support[None].to(device)} + else: + b, c, h, w = im1.shape + b, c, h2, w2 = im2.shape + assert w == w2 and h == h2, "For batched images we assume same size" + batch = {"query": im1.to(device), "support": im2.to(device)} + hs, ws = self.h_resized, self.w_resized + finest_scale = 1 + # Run matcher + if symmetric: + dense_corresps = self.forward_symmetric(batch, batched = True) + else: + dense_corresps = self.forward(batch, batched = True) + + if self.upsample_preds: + hs, ws = self.upsample_res + low_res_certainty = F.interpolate( + dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + cert_clamp = 0 + factor = 0.5 + low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + + if self.upsample_preds: + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True + ) + query, support = test_transform((im1, im2)) + query, support = query[None].to(device), support[None].to(device) + batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]} + if symmetric: + dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True) + else: + dense_corresps = self.forward(batch, batched = True, upsample=True) + query_to_support = dense_corresps[finest_scale]["dense_flow"] + dense_certainty = dense_corresps[finest_scale]["dense_certainty"] + + # Get certainty interpolation + dense_certainty = dense_certainty - low_res_certainty + query_to_support = query_to_support.permute( + 0, 2, 3, 1 + ) + # Create im1 meshgrid + query_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ) + ) + query_coords = torch.stack((query_coords[1], query_coords[0])) + query_coords = query_coords[None].expand(b, 2, hs, ws) + dense_certainty = dense_certainty.sigmoid() # logits -> probs + query_coords = query_coords.permute(0, 2, 3, 1) + if (query_to_support.abs() > 1).any() and True: + wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0 + dense_certainty[wrong[:,None]] = 0 + + query_to_support = torch.clamp(query_to_support, -1, 1) + if symmetric: + support_coords = query_coords + qts, stq = query_to_support.chunk(2) + q_warp = torch.cat((query_coords, qts), dim=-1) + s_warp = torch.cat((stq, support_coords), dim=-1) + warp = torch.cat((q_warp, s_warp),dim=2) + dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0] + else: + warp = torch.cat((query_coords, query_to_support), dim=-1) + if batched: + return ( + warp, + dense_certainty + ) + else: + return ( + warp[0], + dense_certainty[0], + ) diff --git a/third_party/DKM/dkm/models/encoders.py b/third_party/DKM/dkm/models/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..29077e1797196611e9b59a753130a5b153e0aa05 --- /dev/null +++ b/third_party/DKM/dkm/models/encoders.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as tvm + +class ResNet18(nn.Module): + def __init__(self, pretrained=False) -> None: + super().__init__() + self.net = tvm.resnet18(pretrained=pretrained) + def forward(self, x): + self = self.net + x1 = x + x = self.conv1(x1) + x = self.bn1(x) + x2 = self.relu(x) + x = self.maxpool(x2) + x4 = self.layer1(x) + x8 = self.layer2(x4) + x16 = self.layer3(x8) + x32 = self.layer4(x16) + return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1} + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + +class ResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None: + super().__init__() + if dilation is None: + dilation = [False,False,False] + if anti_aliased: + pass + else: + if weights is not None: + self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + else: + self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) + + self.high_res = high_res + self.freeze_bn = freeze_bn + def forward(self, x): + net = self.net + feats = {1:x} + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x + x = net.layer2(x) + feats[8] = x + x = net.layer3(x) + feats[16] = x + x = net.layer4(x) + feats[32] = x + return feats + + def train(self, mode=True): + super().train(mode) + if self.freeze_bn: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + + + +class ResNet101(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None) -> None: + super().__init__() + if weights is not None: + self.net = tvm.resnet101(weights = weights) + else: + self.net = tvm.resnet101(pretrained=pretrained) + self.high_res = high_res + self.scale_factor = 1 if not high_res else 1.5 + def forward(self, x): + net = self.net + feats = {1:x} + sf = self.scale_factor + if self.high_res: + x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer2(x) + feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer3(x) + feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer4(x) + feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + +class WideResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None) -> None: + super().__init__() + if weights is not None: + self.net = tvm.wide_resnet50_2(weights = weights) + else: + self.net = tvm.wide_resnet50_2(pretrained=pretrained) + self.high_res = high_res + self.scale_factor = 1 if not high_res else 1.5 + def forward(self, x): + net = self.net + feats = {1:x} + sf = self.scale_factor + if self.high_res: + x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer2(x) + feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer3(x) + feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer4(x) + feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass \ No newline at end of file diff --git a/third_party/DKM/dkm/models/model_zoo/DKMv3.py b/third_party/DKM/dkm/models/model_zoo/DKMv3.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4c9ede3863d778f679a033d8d2287b8776e894 --- /dev/null +++ b/third_party/DKM/dkm/models/model_zoo/DKMv3.py @@ -0,0 +1,150 @@ +import torch + +from torch import nn +from ..dkm import * +from ..encoders import * + + +def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", device = None, **kwargs): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128+(2*7+1)**2, + 2 * 512+128+(2*7+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + ), + "8": ConvRefiner( + 2 * 512+64+(2*3+1)**2, + 2 * 512+64+(2*3+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + ), + "4": ConvRefiner( + 2 * 256+32+(2*2+1)**2, + 2 * 256+32+(2*2+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + ), + "1": ConvRefiner( + 2 * 3+6, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=6, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + + encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs).to(device) + res = matcher.load_state_dict(weights) + return matcher diff --git a/third_party/DKM/dkm/models/model_zoo/__init__.py b/third_party/DKM/dkm/models/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85da2920c1acfac140ada2d87623203607d42ca --- /dev/null +++ b/third_party/DKM/dkm/models/model_zoo/__init__.py @@ -0,0 +1,39 @@ +weight_urls = { + "DKMv3": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth", + }, +} +import torch +from .DKMv3 import DKMv3 + + +def DKMv3_outdoor(path_to_weights = None, device=None): + """ + Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default + resolution can be changed by setting model.h_resized, model.w_resized later. + Additionally upsamples preds to fixed resolution of (864, 1152), + can be turned off by model.upsample_preds = False + """ + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if path_to_weights is not None: + weights = torch.load(path_to_weights, map_location='cpu') + else: + weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"], + map_location='cpu') + return DKMv3(weights, 540, 720, upsample_preds = True, device=device) + +def DKMv3_indoor(path_to_weights = None, device=None): + """ + Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default + Resolution can be changed by setting model.h_resized, model.w_resized later. + """ + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if path_to_weights is not None: + weights = torch.load(path_to_weights, map_location=device) + else: + weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"], + map_location=device) + return DKMv3(weights, 480, 640, upsample_preds = False, device=device) diff --git a/third_party/DKM/dkm/train/__init__.py b/third_party/DKM/dkm/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433 --- /dev/null +++ b/third_party/DKM/dkm/train/__init__.py @@ -0,0 +1 @@ +from .train import train_k_epochs diff --git a/third_party/DKM/dkm/train/train.py b/third_party/DKM/dkm/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b580221f56a2667784836f0237955cc75131b88c --- /dev/null +++ b/third_party/DKM/dkm/train/train.py @@ -0,0 +1,67 @@ +from tqdm import tqdm +from dkm.utils.utils import to_cuda + + +def train_step(train_batch, model, objective, optimizer, **kwargs): + optimizer.zero_grad() + out = model(train_batch) + l = objective(out, train_batch) + l.backward() + optimizer.step() + return {"train_out": out, "train_loss": l.item()} + + +def train_k_steps( + n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True +): + for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar): + batch = next(dataloader) + model.train(True) + batch = to_cuda(batch) + train_step( + train_batch=batch, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + n=n, + ) + lr_scheduler.step() + + +def train_epoch( + dataloader=None, + model=None, + objective=None, + optimizer=None, + lr_scheduler=None, + epoch=None, +): + model.train(True) + print(f"At epoch {epoch}") + for batch in tqdm(dataloader, mininterval=5.0): + batch = to_cuda(batch) + train_step( + train_batch=batch, model=model, objective=objective, optimizer=optimizer + ) + lr_scheduler.step() + return { + "model": model, + "optimizer": optimizer, + "lr_scheduler": lr_scheduler, + "epoch": epoch, + } + + +def train_k_epochs( + start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler +): + for epoch in range(start_epoch, end_epoch + 1): + train_epoch( + dataloader=dataloader, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + ) diff --git a/third_party/DKM/dkm/utils/__init__.py b/third_party/DKM/dkm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05367ac9521664992f587738caa231f32ae2e81c --- /dev/null +++ b/third_party/DKM/dkm/utils/__init__.py @@ -0,0 +1,13 @@ +from .utils import ( + pose_auc, + get_pose, + compute_relative_pose, + compute_pose_error, + estimate_pose, + rotate_intrinsic, + get_tuple_transform_ops, + get_depth_tuple_transform_ops, + warp_kpts, + numpy_to_pil, + tensor_to_pil, +) diff --git a/third_party/DKM/dkm/utils/kde.py b/third_party/DKM/dkm/utils/kde.py new file mode 100644 index 0000000000000000000000000000000000000000..fa392455e70fda4c9c77c28bda76bcb7ef9045b0 --- /dev/null +++ b/third_party/DKM/dkm/utils/kde.py @@ -0,0 +1,26 @@ +import torch +import torch.nn.functional as F +import numpy as np + +def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1): + raise NotImplementedError("WIP, use at your own risk.") + # Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid + x = x.permute(0,3,1,2) + B,C,H,W = x.shape + K = kernel_size ** 2 + unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W) + scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp() + density = scores.sum(dim=1) + return density + + +def kde(x, std = 0.1, device=None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + # use a gaussian kernel to estimate density + x = x.to(device) + scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() + density = scores.sum(dim=-1) + return density diff --git a/third_party/DKM/dkm/utils/local_correlation.py b/third_party/DKM/dkm/utils/local_correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c1c06291d0b760376a2b2162bcf49d6eb1303c --- /dev/null +++ b/third_party/DKM/dkm/utils/local_correlation.py @@ -0,0 +1,40 @@ +import torch +import torch.nn.functional as F + + +def local_correlation( + feature0, + feature1, + local_radius, + padding_mode="zeros", + flow = None +): + device = feature0.device + b, c, h, w = feature0.size() + if flow is None: + # If flow is None, assume feature0 and feature1 are aligned + coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + )) + coords = torch.stack((coords[1], coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + else: + coords = flow.permute(0,2,3,1) # If using flow, sample around flow target. + r = local_radius + local_window = torch.meshgrid( + ( + torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device), + torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device), + )) + local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ + None + ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2) + coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2) + window_feature = F.grid_sample( + feature1, coords, padding_mode=padding_mode, align_corners=False + )[...,None].reshape(b,c,h,w,(2*r+1)**2) + corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5) + return corr diff --git a/third_party/DKM/dkm/utils/transforms.py b/third_party/DKM/dkm/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..754d853fda4cbcf89d2111bed4f44b0ca84f0518 --- /dev/null +++ b/third_party/DKM/dkm/utils/transforms.py @@ -0,0 +1,104 @@ +from typing import Dict +import numpy as np +import torch +import kornia.augmentation as K +from kornia.geometry.transform import warp_perspective + +# Adapted from Kornia +class GeometricSequential: + def __init__(self, *transforms, align_corners=True) -> None: + self.transforms = transforms + self.align_corners = align_corners + + def __call__(self, x, mode="bilinear"): + b, c, h, w = x.shape + M = torch.eye(3, device=x.device)[None].expand(b, 3, 3) + for t in self.transforms: + if np.random.rand() < t.p: + M = M.matmul( + t.compute_transformation(x, t.generate_parameters((b, c, h, w))) + ) + return ( + warp_perspective( + x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners + ), + M, + ) + + def apply_transform(self, x, M, mode="bilinear"): + b, c, h, w = x.shape + return warp_perspective( + x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode + ) + + +class RandomPerspective(K.RandomPerspective): + def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]: + distortion_scale = torch.as_tensor( + self.distortion_scale, device=self._device, dtype=self._dtype + ) + return self.random_perspective_generator( + batch_shape[0], + batch_shape[-2], + batch_shape[-1], + distortion_scale, + self.same_on_batch, + self.device, + self.dtype, + ) + + def random_perspective_generator( + self, + batch_size: int, + height: int, + width: int, + distortion_scale: torch.Tensor, + same_on_batch: bool = False, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> Dict[str, torch.Tensor]: + r"""Get parameters for ``perspective`` for a random perspective transform. + + Args: + batch_size (int): the tensor batch size. + height (int) : height of the image. + width (int): width of the image. + distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1. + same_on_batch (bool): apply the same transformation across the batch. Default: False. + device (torch.device): the device on which the random numbers will be generated. Default: cpu. + dtype (torch.dtype): the data type of the generated random numbers. Default: float32. + + Returns: + params Dict[str, torch.Tensor]: parameters to be passed for transformation. + - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2). + - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2). + + Note: + The generated random numbers are not reproducible across different devices and dtypes. + """ + if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1): + raise AssertionError( + f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}." + ) + if not ( + type(height) is int and height > 0 and type(width) is int and width > 0 + ): + raise AssertionError( + f"'height' and 'width' must be integers. Got {height}, {width}." + ) + + start_points: torch.Tensor = torch.tensor( + [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]], + device=distortion_scale.device, + dtype=distortion_scale.dtype, + ).expand(batch_size, -1, -1) + + # generate random offset not larger than half of the image + fx = distortion_scale * width / 2 + fy = distortion_scale * height / 2 + + factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2) + offset = (torch.rand_like(start_points) - 0.5) * 2 + end_points = start_points + factor * offset + + return dict(start_points=start_points, end_points=end_points) diff --git a/third_party/DKM/dkm/utils/utils.py b/third_party/DKM/dkm/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..46bbe60260930aed184c6fa5907c837c0177b304 --- /dev/null +++ b/third_party/DKM/dkm/utils/utils.py @@ -0,0 +1,341 @@ +import numpy as np +import cv2 +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import torch.nn.functional as F +from PIL import Image + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py +# --- GEOMETRY --- +def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC + ) + + ret = None + if E is not None: + best_num_inliers = 0 + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + + +def rotate_intrinsic(K, n): + base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + rot = np.linalg.matrix_power(base_rot, n) + return rot @ K + + +def rotate_pose_inplane(i_T_w, rot): + rotation_matrices = [ + np.array( + [ + [np.cos(r), -np.sin(r), 0.0, 0.0], + [np.sin(r), np.cos(r), 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] + ] + return np.dot(rotation_matrices[rot], i_T_w) + + +def scale_intrinsics(K, scales): + scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) + return np.dot(scales, K) + + +def to_homogeneous(points): + return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) + + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = angle_error_vec(t.squeeze(), t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e) / t) + return aucs + + +# From Patch2Pix https://github.com/GrumpyZhou/patch2pix +def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR)) + return TupleCompose(ops) + + +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize)) + if normalize: + ops.append(TupleToTensorScaled()) + ops.append( + TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) # Imagenet mean/std + else: + if unscale: + ops.append(TupleToTensorUnscaled()) + else: + ops.append(TupleToTensorScaled()) + return TupleCompose(ops) + + +class ToTensorScaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" + + def __call__(self, im): + if not isinstance(im, torch.Tensor): + im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) + im /= 255.0 + return torch.from_numpy(im) + else: + return im + + def __repr__(self): + return "ToTensorScaled(./255)" + + +class TupleToTensorScaled(object): + def __init__(self): + self.to_tensor = ToTensorScaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorScaled(./255)" + + +class ToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __call__(self, im): + return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) + + def __repr__(self): + return "ToTensorUnscaled()" + + +class TupleToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __init__(self): + self.to_tensor = ToTensorUnscaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorUnscaled()" + + +class TupleResize(object): + def __init__(self, size, mode=InterpolationMode.BICUBIC): + self.size = size + self.resize = transforms.Resize(size, mode) + + def __call__(self, im_tuple): + return [self.resize(im) for im in im_tuple] + + def __repr__(self): + return "TupleResize(size={})".format(self.size) + + +class TupleNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + self.normalize = transforms.Normalize(mean=mean, std=std) + + def __call__(self, im_tuple): + return [self.normalize(im) for im in im_tuple] + + def __repr__(self): + return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) + + +class TupleCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, im_tuple): + for t in self.transforms: + im_tuple = t(im_tuple) + return im_tuple + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode="bilinear")[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode="bilinear" + )[:, 0, :, 0] + consistent_mask = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() < 0.05 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 + + +imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) +imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + + +def numpy_to_pil(x: np.ndarray): + """ + Args: + x: Assumed to be of shape (h,w,c) + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if x.max() <= 1.01: + x *= 255 + x = x.astype(np.uint8) + return Image.fromarray(x) + + +def tensor_to_pil(x, unnormalize=False): + if unnormalize: + x = x * imagenet_std[:, None, None] + imagenet_mean[:, None, None] + x = x.detach().permute(1, 2, 0).cpu().numpy() + x = np.clip(x, 0.0, 1.0) + return numpy_to_pil(x) + + +def to_cuda(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + return batch + + +def to_cpu(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cpu() + return batch + + +def get_pose(calib): + w, h = np.array(calib["imsize"])[0] + return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w + + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans diff --git a/third_party/DKM/docs/api.md b/third_party/DKM/docs/api.md new file mode 100644 index 0000000000000000000000000000000000000000..d19e961a81f59ea6f33de1cc53bce16b4db9678c --- /dev/null +++ b/third_party/DKM/docs/api.md @@ -0,0 +1,24 @@ +## Creating a model +```python +from dkm import DKMv3_outdoor, DKMv3_indoor +DKMv3_outdoor() # creates an outdoor trained model +DKMv3_indoor() # creates an indoor trained model +``` +## Model settings +Note: Non-exhaustive list +```python +model.upsample_preds = True/False # Whether to upsample the predictions to higher resolution +model.upsample_res = (H_big, W_big) # Which resolution to use for upsampling +model.symmetric = True/False # Whether to compute a bidirectional warp +model.w_resized = W # width of image used +model.h_resized = H # height of image used +model.sample_mode = "threshold_balanced" # method for sampling matches. threshold_balanced is what was used in the paper +model.sample_threshold = 0.05 # the threshold for sampling, 0.05 works well for megadepth, for IMC2022 we found 0.2 to work better. +``` +## Running model +```python +warp, certainty = model.match(im_A, im_B) # produces a warp of shape [B,H,W,4] and certainty of shape [B,H,W] +matches, certainty = model.sample(warp, certainty) # samples from the warp using the certainty +kpts_A, kpts_B = model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) # convenience function to convert normalized matches to pixel coordinates +``` + diff --git a/third_party/DKM/docs/benchmarks.md b/third_party/DKM/docs/benchmarks.md new file mode 100644 index 0000000000000000000000000000000000000000..30dd57af86ad4f85c621e430eef9e9c55ba9d2c5 --- /dev/null +++ b/third_party/DKM/docs/benchmarks.md @@ -0,0 +1,27 @@ +Benchmarking datasets for geometry estimation can be somewhat cumbersome to download. We provide instructions for the benchmarks we use below, and are happy to answer any questions. + +### HPatches +First, make sure that the "data/hpatches" path exists, e.g. by + +`` ln -s place/where/your/datasets/are/stored/hpatches data/hpatches `` + +Then run (if you don't already have hpatches downloaded) + +`` bash scripts/download_hpatches.sh`` + +### Megadepth-1500 (LoFTR Split) +1. We use the split made by LoFTR, which can be downloaded here https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc. (You can also use the preprocessed megadepth dataset if you have it available) +2. The images should be located in data/megadepth/Undistorted_SfM/0015 and 0022. +3. The pair infos are provided here https://github.com/zju3dv/LoFTR/tree/master/assets/megadepth_test_1500_scene_info +3. Put those files in data/megadepth/xxx + +### Megadepth-8-Scenes (DKM Split) +1. The pair infos are provided in [assets](../assets/) +2. Put those files in data/megadepth/xxx + + +### Scannet-1500 (SuperGlue Split) +We use the same split of scannet as superglue. +1. LoFTR provides the split here: https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc +2. Note that ScanNet requires you to sign a License agreement, which can be found http://kaldir.vc.in.tum.de/scannet/ScanNet_TOS.pdf +3. This benchmark should be put in the data/scannet folder diff --git a/third_party/DKM/docs/training.md b/third_party/DKM/docs/training.md new file mode 100644 index 0000000000000000000000000000000000000000..37d17171ac45c95ff587b9bbd525c4175558ff8a --- /dev/null +++ b/third_party/DKM/docs/training.md @@ -0,0 +1,21 @@ +Here we provide instructions for how to train our models, including download of datasets. + +### MegaDepth +First the MegaDepth dataset needs to be downloaded and preprocessed. This can be done by the following steps: +1. Download MegaDepth from here: https://www.cs.cornell.edu/projects/megadepth/ +2. Extract and preprocess: See https://github.com/mihaidusmanu/d2-net +3. Download our prepared scene info from here: https://github.com/Parskatt/storage/releases/download/prep_scene_info/prep_scene_info.tar +4. File structure should be data/megadepth/phoenix, data/megadepth/Undistorted_SfM, data/megadepth/prep_scene_info. +Then run +``` bash +python experiments/dkmv3/train_DKMv3_outdoor.py --gpus 4 +``` + +## Megadepth + Scannet +First follow the steps outlined above. +Then, see https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md + +Then run +``` bash +python experiments/dkmv3/train_DKMv3_indoor.py --gpus 4 +``` diff --git a/third_party/DKM/requirements.txt b/third_party/DKM/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..018696905e480072ebe7dd9a9010db8b9e77f1d8 --- /dev/null +++ b/third_party/DKM/requirements.txt @@ -0,0 +1,11 @@ +torch +einops +torchvision +opencv-python +kornia +albumentations +loguru +tqdm +matplotlib +h5py +wandb \ No newline at end of file diff --git a/third_party/DKM/scripts/download_hpatches.sh b/third_party/DKM/scripts/download_hpatches.sh new file mode 100644 index 0000000000000000000000000000000000000000..5cdc42f6c9304062773ea30852179f51580ea9e0 --- /dev/null +++ b/third_party/DKM/scripts/download_hpatches.sh @@ -0,0 +1,4 @@ +cd data/hpatches +wget http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz +tar -xvf hpatches-sequences-release.tar.gz -C . +rm hpatches-sequences-release.tar.gz \ No newline at end of file diff --git a/third_party/DKM/setup.py b/third_party/DKM/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..73ae664126066249200d72c1ec3166f4f2d76b10 --- /dev/null +++ b/third_party/DKM/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup, find_packages + +setup( + name="dkm", + packages=find_packages(include=("dkm*",)), + version="0.3.0", + author="Johan Edstedt", + install_requires=open("requirements.txt", "r").read().split("\n"), +) diff --git a/third_party/Roma/.gitignore b/third_party/Roma/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ff4633046059da69ab4b6e222909614ccda82ac4 --- /dev/null +++ b/third_party/Roma/.gitignore @@ -0,0 +1,5 @@ +*.egg-info* +*.vscode* +*__pycache__* +vis* +workspace* \ No newline at end of file diff --git a/third_party/Roma/LICENSE b/third_party/Roma/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a115f899f8d09ef3b1def4a16c7bae1a0bd50fbe --- /dev/null +++ b/third_party/Roma/LICENSE @@ -0,0 +1,400 @@ + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/third_party/Roma/README.md b/third_party/Roma/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5e984366c8f7af37615d7666f34cd82a90073fee --- /dev/null +++ b/third_party/Roma/README.md @@ -0,0 +1,63 @@ +# RoMa: Revisiting Robust Losses for Dense Feature Matching +### [Project Page (TODO)](https://parskatt.github.io/RoMa) | [Paper](https://arxiv.org/abs/2305.15404) +
+ +> RoMa: Revisiting Robust Lossses for Dense Feature Matching +> [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Qiyu Sun](https://scholar.google.com/citations?user=HS2WuHkAAAAJ), [Georg Bökman](https://scholar.google.com/citations?user=FUE3Wd0AAAAJ), [Mårten Wadenbäck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ) +> Arxiv 2023 + +**NOTE!!! Very early code, there might be bugs** + +The codebase is in the [roma folder](roma). + +## Setup/Install +In your python environment (tested on Linux python 3.10), run: +```bash +pip install -e . +``` +## Demo / How to Use +We provide two demos in the [demos folder](demo). +Here's the gist of it: +```python +from roma import roma_outdoor +roma_model = roma_outdoor(device=device) +# Match +warp, certainty = roma_model.match(imA_path, imB_path, device=device) +# Sample matches for estimation +matches, certainty = roma_model.sample(warp, certainty) +# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1]) +kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) +# Find a fundamental matrix (or anything else of interest) +F, mask = cv2.findFundamentalMat( + kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 +) +``` +## Reproducing Results +The experiments in the paper are provided in the [experiments folder](experiments). + +### Training +1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets. +2. Run the relevant experiment, e.g., +```bash +torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py +``` +### Testing +```bash +python experiments/roma_outdoor.py --only_test --benchmark mega-1500 +``` +## License +Due to our dependency on [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE), the license is sadly non-commercial only for the moment. + +## Acknowledgement +Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM). + +## BibTeX +If you find our models useful, please consider citing our paper! +``` +@article{edstedt2023roma, +title={{RoMa}: Revisiting Robust Lossses for Dense Feature Matching}, +author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael}, +journal={arXiv preprint arXiv:2305.15404}, +year={2023} +} +``` diff --git a/third_party/Roma/assets/sacre_coeur_A.jpg b/third_party/Roma/assets/sacre_coeur_A.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6e441dad34cf13d8a29d7c6a1519f4263c40058c --- /dev/null +++ b/third_party/Roma/assets/sacre_coeur_A.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90d9c5f5a4d76425624989215120fba6f2899190a1d5654b88fa380c64cf6b2c +size 117985 diff --git a/third_party/Roma/assets/sacre_coeur_B.jpg b/third_party/Roma/assets/sacre_coeur_B.jpg new file mode 100644 index 0000000000000000000000000000000000000000..27a239a8fa7581d909104872754ecda79422e7b6 --- /dev/null +++ b/third_party/Roma/assets/sacre_coeur_B.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f1eb9bdd4d80e480f672d6a729689ac77f9fd5c8deb90f59b377590f3ca4799 +size 152515 diff --git a/third_party/Roma/data/.gitignore b/third_party/Roma/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3 --- /dev/null +++ b/third_party/Roma/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/third_party/Roma/demo/demo_fundamental.py b/third_party/Roma/demo/demo_fundamental.py new file mode 100644 index 0000000000000000000000000000000000000000..31618d4b06cd56fdd4be9065fb00b826a19e10f9 --- /dev/null +++ b/third_party/Roma/demo/demo_fundamental.py @@ -0,0 +1,33 @@ +from PIL import Image +import torch +import cv2 +from roma import roma_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + + # Create model + roma_model = roma_outdoor(device=device) + + + W_A, H_A = Image.open(im1_path).size + W_B, H_B = Image.open(im2_path).size + + # Match + warp, certainty = roma_model.match(im1_path, im2_path, device=device) + # Sample matches for estimation + matches, certainty = roma_model.sample(warp, certainty) + kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + F, mask = cv2.findFundamentalMat( + kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 + ) \ No newline at end of file diff --git a/third_party/Roma/demo/demo_match.py b/third_party/Roma/demo/demo_match.py new file mode 100644 index 0000000000000000000000000000000000000000..46413bb2b336e2ef2c0bc48315821e4de0fcb982 --- /dev/null +++ b/third_party/Roma/demo/demo_match.py @@ -0,0 +1,47 @@ +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np +from roma.utils.utils import tensor_to_pil + +from roma import roma_indoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + save_path = args.save_path + + # Create model + roma_model = roma_indoor(device=device) + + H, W = roma_model.get_output_resolution() + + im1 = Image.open(im1_path).resize((W, H)) + im2 = Image.open(im2_path).resize((W, H)) + + # Match + warp, certainty = roma_model.match(im1_path, im2_path, device=device) + # Sampling not needed, but can be done with model.sample(warp, certainty) + x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1) + x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) + + im2_transfer_rgb = F.grid_sample( + x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + )[0] + im1_transfer_rgb = F.grid_sample( + x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False + )[0] + warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2) + white_im = torch.ones((H,2*W),device=device) + vis_im = certainty * warp_im + (1 - certainty) * white_im + tensor_to_pil(vis_im, unnormalize=False).save(save_path) \ No newline at end of file diff --git a/third_party/Roma/requirements.txt b/third_party/Roma/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..12addf0d0eb74e6cac0da6bca704eac0b28990d7 --- /dev/null +++ b/third_party/Roma/requirements.txt @@ -0,0 +1,13 @@ +torch +einops +torchvision +opencv-python +kornia +albumentations +loguru +tqdm +matplotlib +h5py +wandb +timm +xformers # Optional, used for memefficient attention \ No newline at end of file diff --git a/third_party/Roma/roma/__init__.py b/third_party/Roma/roma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c96481e0a808b68c7b3054a3e34fa0b5c45ab9 --- /dev/null +++ b/third_party/Roma/roma/__init__.py @@ -0,0 +1,8 @@ +import os +from .models import roma_outdoor, roma_indoor + +DEBUG_MODE = False +RANK = int(os.environ.get('RANK', default = 0)) +GLOBAL_STEP = 0 +STEP_SIZE = 1 +LOCAL_RANK = -1 \ No newline at end of file diff --git a/third_party/Roma/roma/benchmarks/__init__.py b/third_party/Roma/roma/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de7b841a5a6ab2ba91297a181a79dfaa91c9e104 --- /dev/null +++ b/third_party/Roma/roma/benchmarks/__init__.py @@ -0,0 +1,4 @@ +from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark +from .scannet_benchmark import ScanNetBenchmark +from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark +from .megadepth_dense_benchmark import MegadepthDenseBenchmark diff --git a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..2154a471c73d9e883c3ba8ed1b90d708f4950a63 --- /dev/null +++ b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py @@ -0,0 +1,113 @@ +from PIL import Image +import numpy as np + +import os + +from tqdm import tqdm +from roma.utils import pose_auc +import cv2 + + +class HpatchesHomogBenchmark: + """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]""" + + def __init__(self, dataset_path) -> None: + seqs_dir = "hpatches-sequences-release" + self.seqs_path = os.path.join(dataset_path, seqs_dir) + self.seq_names = sorted(os.listdir(self.seqs_path)) + # Ignore seqs is same as LoFTR. + self.ignore_seqs = set( + [ + "i_contruction", + "i_crownnight", + "i_dc", + "i_pencils", + "i_whitebuilding", + "v_artisans", + "v_astronautis", + "v_talent", + ] + ) + + def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup): + offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think) + im_A_coords = ( + np.stack( + ( + wq * (im_A_coords[..., 0] + 1) / 2, + hq * (im_A_coords[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + im_A_to_im_B = ( + np.stack( + ( + wsup * (im_A_to_im_B[..., 0] + 1) / 2, + hsup * (im_A_to_im_B[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + return im_A_coords, im_A_to_im_B + + def benchmark(self, model, model_name = None): + n_matches = [] + homog_dists = [] + for seq_idx, seq_name in tqdm( + enumerate(self.seq_names), total=len(self.seq_names) + ): + im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm") + im_A = Image.open(im_A_path) + w1, h1 = im_A.size + for im_idx in range(2, 7): + im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm") + im_B = Image.open(im_B_path) + w2, h2 = im_B.size + H = np.loadtxt( + os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) + ) + dense_matches, dense_certainty = model.match( + im_A_path, im_B_path + ) + good_matches, _ = model.sample(dense_matches, dense_certainty, 5000) + pos_a, pos_b = self.convert_coordinates( + good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2 + ) + try: + H_pred, inliers = cv2.findHomography( + pos_a, + pos_b, + method = cv2.RANSAC, + confidence = 0.99999, + ransacReprojThreshold = 3 * min(w2, h2) / 480, + ) + except: + H_pred = None + if H_pred is None: + H_pred = np.zeros((3, 3)) + H_pred[2, 2] = 1.0 + corners = np.array( + [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]] + ) + real_warped_corners = np.dot(corners, np.transpose(H)) + real_warped_corners = ( + real_warped_corners[:, :2] / real_warped_corners[:, 2:] + ) + warped_corners = np.dot(corners, np.transpose(H_pred)) + warped_corners = warped_corners[:, :2] / warped_corners[:, 2:] + mean_dist = np.mean( + np.linalg.norm(real_warped_corners - warped_corners, axis=1) + ) / (min(w2, h2) / 480.0) + homog_dists.append(mean_dist) + + n_matches = np.array(n_matches) + thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + auc = pose_auc(np.array(homog_dists), thresholds) + return { + "hpatches_homog_auc_3": auc[2], + "hpatches_homog_auc_5": auc[4], + "hpatches_homog_auc_10": auc[9], + } diff --git a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..0600d354b1d0dfa7f8e2b0f8882a4cc08fafeed9 --- /dev/null +++ b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py @@ -0,0 +1,106 @@ +import torch +import numpy as np +import tqdm +from roma.datasets import MegadepthBuilder +from roma.utils import warp_kpts +from torch.utils.data import ConcatDataset +import roma + +class MegadepthDenseBenchmark: + def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None: + mega = MegadepthBuilder(data_root=data_root) + self.dataset = ConcatDataset( + mega.build_scenes(split="test_loftr", ht=h, wt=w) + ) # fixed resolution of 384,512 + self.num_samples = num_samples + + def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches): + b, h1, w1, d = dense_matches.shape + with torch.no_grad(): + x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2) + mask, x2 = warp_kpts( + x1.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + ) + x2 = torch.stack( + (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1 + ) + prob = mask.float().reshape(b, h1, w1) + x2_hat = dense_matches[..., 2:] + x2_hat = torch.stack( + (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1 + ) + gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1) + gd = gd[prob == 1] + pck_1 = (gd < 1.0).float().mean() + pck_3 = (gd < 3.0).float().mean() + pck_5 = (gd < 5.0).float().mean() + return gd, pck_1, pck_3, pck_5, prob + + def benchmark(self, model, batch_size=8): + model.train(False) + with torch.no_grad(): + gd_tot = 0.0 + pck_1_tot = 0.0 + pck_3_tot = 0.0 + pck_5_tot = 0.0 + sampler = torch.utils.data.WeightedRandomSampler( + torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples + ) + B = batch_size + dataloader = torch.utils.data.DataLoader( + self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler + ) + for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0): + im_A, im_B, depth1, depth2, T_1to2, K1, K2 = ( + data["im_A"], + data["im_B"], + data["im_A_depth"].cuda(), + data["im_B_depth"].cuda(), + data["T_1to2"].cuda(), + data["K1"].cuda(), + data["K2"].cuda(), + ) + matches, certainty = model.match(im_A, im_B, batched=True) + gd, pck_1, pck_3, pck_5, prob = self.geometric_dist( + depth1, depth2, T_1to2, K1, K2, matches + ) + if roma.DEBUG_MODE: + from roma.utils.utils import tensor_to_pil + import torch.nn.functional as F + path = "vis" + H, W = model.get_output_resolution() + white_im = torch.ones((B,1,H,W),device="cuda") + im_B_transfer_rgb = F.grid_sample( + im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False + ) + warp_im = im_B_transfer_rgb + c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None] + vis_im = c_b * warp_im + (1 - c_b) * white_im + for b in range(B): + import os + os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True) + tensor_to_pil(vis_im[b], unnormalize=True).save( + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg") + tensor_to_pil(im_A[b].cuda(), unnormalize=True).save( + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg") + tensor_to_pil(im_B[b].cuda(), unnormalize=True).save( + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg") + + + gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = ( + gd_tot + gd.mean(), + pck_1_tot + pck_1, + pck_3_tot + pck_3, + pck_5_tot + pck_5, + ) + return { + "epe": gd_tot.item() / len(dataloader), + "mega_pck_1": pck_1_tot.item() / len(dataloader), + "mega_pck_3": pck_3_tot.item() / len(dataloader), + "mega_pck_5": pck_5_tot.item() / len(dataloader), + } diff --git a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..8007fe8ecad09c33401450ad6b7af1f3dad043d2 --- /dev/null +++ b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py @@ -0,0 +1,140 @@ +import numpy as np +import torch +from roma.utils import * +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F +import roma +import kornia.geometry.epipolar as kepi + +class MegaDepthPoseEstimationBenchmark: + def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True): + H,W = model.get_output_resolution() + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs)) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind][0] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + K2 = intrinsics[idx2].copy() + T2 = poses[idx2].copy() + R2, t2 = T2[:3, :3], T2[:3, 3] + R, t = compute_relative_pose(R1, t1, R2, t2) + T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/{im_paths[idx1]}" + im_B_path = f"{data_root}/{im_paths[idx2]}" + dense_matches, dense_certainty = model.match( + im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy() + ) + sparse_matches,_ = model.sample( + dense_matches, dense_certainty, 5000 + ) + + im_A = Image.open(im_A_path) + w1, h1 = im_A.size + im_B = Image.open(im_B_path) + w2, h2 = im_B.size + + if scale_intrinsics: + scale1 = 1200 / max(w1, h1) + scale2 = 1200 / max(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1, K2 = K1.copy(), K2.copy() + K1[:2] = K1[:2] * scale1 + K2[:2] = K2[:2] * scale2 + + kpts1 = sparse_matches[:, :2] + kpts1 = ( + np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2, + h1 * (kpts1[:, 1] + 1) / 2, + ), + axis=-1, + ) + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = ( + np.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2, + h2 * (kpts2[:, 1] + 1) / 2, + ), + axis=-1, + ) + ) + + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + threshold = 0.5 + if calibrated: + norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/third_party/Roma/roma/benchmarks/scannet_benchmark.py b/third_party/Roma/roma/benchmarks/scannet_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..853af0d0ebef4dfefe2632eb49e4156ea791ee76 --- /dev/null +++ b/third_party/Roma/roma/benchmarks/scannet_benchmark.py @@ -0,0 +1,143 @@ +import os.path as osp +import numpy as np +import torch +from roma.utils import * +from PIL import Image +from tqdm import tqdm + + +class ScanNetBenchmark: + def __init__(self, data_root="data/scannet") -> None: + self.data_root = data_root + + def benchmark(self, model, model_name = None): + model.train(False) + with torch.no_grad(): + data_root = self.data_root + tmp = np.load(osp.join(data_root, "test.npz")) + pairs, rel_pose = tmp["name"], tmp["rel_pose"] + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + pair_inds = np.random.choice( + range(len(pairs)), size=len(pairs), replace=False + ) + for pairind in tqdm(pair_inds, smoothing=0.9): + scene = pairs[pairind] + scene_name = f"scene0{scene[0]}_00" + im_A_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[2]}.jpg", + ) + im_A = Image.open(im_A_path) + im_B_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[3]}.jpg", + ) + im_B = Image.open(im_B_path) + T_gt = rel_pose[pairind].reshape(3, 4) + R, t = T_gt[:3, :3], T_gt[:3, 3] + K = np.stack( + [ + np.array([float(i) for i in r.split()]) + for r in open( + osp.join( + self.data_root, + "scans_test", + scene_name, + "intrinsic", + "intrinsic_color.txt", + ), + "r", + ) + .read() + .split("\n") + if r + ] + ) + w1, h1 = im_A.size + w2, h2 = im_B.size + K1 = K.copy() + K2 = K.copy() + dense_matches, dense_certainty = model.match(im_A_path, im_B_path) + sparse_matches, sparse_certainty = model.sample( + dense_matches, dense_certainty, 5000 + ) + scale1 = 480 / min(w1, h1) + scale2 = 480 / min(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1 = K1 * scale1 + K2 = K2 * scale2 + + offset = 0.5 + kpts1 = sparse_matches[:, :2] + kpts1 = ( + np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2 - offset, + h1 * (kpts1[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = ( + np.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2 - offset, + h2 * (kpts2[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + norm_threshold = 0.5 / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/third_party/Roma/roma/checkpointing/__init__.py b/third_party/Roma/roma/checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577 --- /dev/null +++ b/third_party/Roma/roma/checkpointing/__init__.py @@ -0,0 +1 @@ +from .checkpoint import CheckPoint diff --git a/third_party/Roma/roma/checkpointing/checkpoint.py b/third_party/Roma/roma/checkpointing/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8995efeb54f4d558127ea63423fa958c64e9088f --- /dev/null +++ b/third_party/Roma/roma/checkpointing/checkpoint.py @@ -0,0 +1,60 @@ +import os +import torch +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel +from loguru import logger +import gc + +import roma + +class CheckPoint: + def __init__(self, dir=None, name="tmp"): + self.name = name + self.dir = dir + os.makedirs(self.dir, exist_ok=True) + + def save( + self, + model, + optimizer, + lr_scheduler, + n, + ): + if roma.RANK == 0: + assert model is not None + if isinstance(model, (DataParallel, DistributedDataParallel)): + model = model.module + states = { + "model": model.state_dict(), + "n": n, + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + } + torch.save(states, self.dir + self.name + f"_latest.pth") + logger.info(f"Saved states {list(states.keys())}, at step {n}") + + def load( + self, + model, + optimizer, + lr_scheduler, + n, + ): + if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0: + states = torch.load(self.dir + self.name + f"_latest.pth") + if "model" in states: + model.load_state_dict(states["model"]) + if "n" in states: + n = states["n"] if states["n"] else n + if "optimizer" in states: + try: + optimizer.load_state_dict(states["optimizer"]) + except Exception as e: + print(f"Failed to load states for optimizer, with error {e}") + if "lr_scheduler" in states: + lr_scheduler.load_state_dict(states["lr_scheduler"]) + print(f"Loaded states {list(states.keys())}, at step {n}") + del states + gc.collect() + torch.cuda.empty_cache() + return model, optimizer, lr_scheduler, n \ No newline at end of file diff --git a/third_party/Roma/roma/datasets/__init__.py b/third_party/Roma/roma/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b60c709926a4a7bd019b73eac10879063a996c90 --- /dev/null +++ b/third_party/Roma/roma/datasets/__init__.py @@ -0,0 +1,2 @@ +from .megadepth import MegadepthBuilder +from .scannet import ScanNetBuilder \ No newline at end of file diff --git a/third_party/Roma/roma/datasets/megadepth.py b/third_party/Roma/roma/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..5deee5ac30c439a9f300c0ad2271f141931020c0 --- /dev/null +++ b/third_party/Roma/roma/datasets/megadepth.py @@ -0,0 +1,230 @@ +import os +from PIL import Image +import h5py +import numpy as np +import torch +import torchvision.transforms.functional as tvf +import kornia.augmentation as K +from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +import roma +from roma.utils import * +import math + +class MegadepthScene: + def __init__( + self, + data_root, + scene_info, + ht=384, + wt=512, + min_overlap=0.0, + max_overlap=1.0, + shake_t=0, + rot_prob=0.0, + normalize=True, + max_num_pairs = 100_000, + scene_name = None, + use_horizontal_flip_aug = False, + use_single_horizontal_flip_aug = False, + colorjiggle_params = None, + random_eraser = None, + use_randaug = False, + randaug_params = None, + randomize_size = False, + ) -> None: + self.data_root = data_root + self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}" + self.image_paths = scene_info["image_paths"] + self.depth_paths = scene_info["depth_paths"] + self.intrinsics = scene_info["intrinsics"] + self.poses = scene_info["poses"] + self.pairs = scene_info["pairs"] + self.overlaps = scene_info["overlaps"] + threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap) + self.pairs = self.pairs[threshold] + self.overlaps = self.overlaps[threshold] + if len(self.pairs) > max_num_pairs: + pairinds = np.random.choice( + np.arange(0, len(self.pairs)), max_num_pairs, replace=False + ) + self.pairs = self.pairs[pairinds] + self.overlaps = self.overlaps[pairinds] + if randomize_size: + area = ht * wt + s = int(16 * (math.sqrt(area)//16)) + sizes = ((ht,wt), (s,s), (wt,ht)) + choice = roma.RANK % 3 + ht, wt = sizes[choice] + # counts, bins = np.histogram(self.overlaps,20) + # print(counts) + self.im_transform_ops = get_tuple_transform_ops( + resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params, + ) + self.depth_transform_ops = get_depth_tuple_transform_ops( + resize=(ht, wt) + ) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.random_eraser = random_eraser + if use_horizontal_flip_aug and use_single_horizontal_flip_aug: + raise ValueError("Can't both flip both images and only flip one") + self.use_horizontal_flip_aug = use_horizontal_flip_aug + self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug + self.use_randaug = use_randaug + + def load_im(self, im_path): + im = Image.open(im_path) + return im + + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): + im_A = im_A.flip(-1) + im_B = im_B.flip(-1) + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) + K_A = flip_mat@K_A + K_B = flip_mat@K_B + + return im_A, im_B, depth_A, depth_B, K_A, K_B + + def load_depth(self, depth_ref, crop=None): + depth = np.array(h5py.File(depth_ref, "r")["depth"]) + return torch.from_numpy(depth) + + def __len__(self): + return len(self.pairs) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) + return sK @ K + + def rand_shake(self, *things): + t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2) + return [ + tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0]) + for thing in things + ], t + + def __getitem__(self, pair_idx): + # read intrinsics of original size + idx1, idx2 = self.pairs[pair_idx] + K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3) + K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T1 = self.poses[idx1] + T2 = self.poses[idx2] + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) + + # Load positive pair data + im_A, im_B = self.image_paths[idx1], self.image_paths[idx2] + depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2] + im_A_ref = os.path.join(self.data_root, im_A) + im_B_ref = os.path.join(self.data_root, im_B) + depth_A_ref = os.path.join(self.data_root, depth1) + depth_B_ref = os.path.join(self.data_root, depth2) + im_A = self.load_im(im_A_ref) + im_B = self.load_im(im_B_ref) + K1 = self.scale_intrinsic(K1, im_A.width, im_A.height) + K2 = self.scale_intrinsic(K2, im_B.width, im_B.height) + + if self.use_randaug: + im_A, im_B = self.rand_augment(im_A, im_B) + + depth_A = self.load_depth(depth_A_ref) + depth_B = self.load_depth(depth_B_ref) + # Process images + im_A, im_B = self.im_transform_ops((im_A, im_B)) + depth_A, depth_B = self.depth_transform_ops( + (depth_A[None, None], depth_B[None, None]) + ) + + [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B) + K1[:2, 2] += t + K2[:2, 2] += t + + im_A, im_B = im_A[None], im_B[None] + if self.random_eraser is not None: + im_A, depth_A = self.random_eraser(im_A, depth_A) + im_B, depth_B = self.random_eraser(im_B, depth_B) + + if self.use_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) + if self.use_single_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2) + + if roma.DEBUG_MODE: + tensor_to_pil(im_A[0], unnormalize=True).save( + f"vis/im_A.jpg") + tensor_to_pil(im_B[0], unnormalize=True).save( + f"vis/im_B.jpg") + + data_dict = { + "im_A": im_A[0], + "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0], + "im_B": im_B[0], + "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0], + "im_A_depth": depth_A[0, 0], + "im_B_depth": depth_B[0, 0], + "K1": K1, + "K2": K2, + "T_1to2": T_1to2, + "im_A_path": im_A_ref, + "im_B_path": im_B_ref, + + } + return data_dict + + +class MegadepthBuilder: + def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root, "prep_scene_info") + self.all_scenes = os.listdir(self.scene_info_root) + self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"] + # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those + self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy']) + self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy']) + self.test_scenes_loftr = ["0015.npy", "0022.npy"] + self.loftr_ignore = loftr_ignore + self.imc21_ignore = imc21_ignore + + def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs): + if split == "train": + scene_names = set(self.all_scenes) - set(self.test_scenes) + elif split == "train_loftr": + scene_names = set(self.all_scenes) - set(self.test_scenes_loftr) + elif split == "test": + scene_names = self.test_scenes + elif split == "test_loftr": + scene_names = self.test_scenes_loftr + elif split == "custom": + scene_names = scene_names + else: + raise ValueError(f"Split {split} not available") + scenes = [] + for scene_name in scene_names: + if self.loftr_ignore and scene_name in self.loftr_ignore_scenes: + continue + if self.imc21_ignore and scene_name in self.imc21_scenes: + continue + scene_info = np.load( + os.path.join(self.scene_info_root, scene_name), allow_pickle=True + ).item() + scenes.append( + MegadepthScene( + self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs + ) + ) + return scenes + + def weight_scenes(self, concat_dataset, alpha=0.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n) / n**alpha for n in ns]) + return ws diff --git a/third_party/Roma/roma/datasets/scannet.py b/third_party/Roma/roma/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..704ea57259afdfbbca627ad143bee97a0a79d41c --- /dev/null +++ b/third_party/Roma/roma/datasets/scannet.py @@ -0,0 +1,160 @@ +import os +import random +from PIL import Image +import cv2 +import h5py +import numpy as np +import torch +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset) + +import torchvision.transforms.functional as tvf +import kornia.augmentation as K +import os.path as osp +import matplotlib.pyplot as plt +import roma +from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +from roma.utils.transforms import GeometricSequential +from tqdm import tqdm + +class ScanNetScene: + def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False, +) -> None: + self.scene_root = osp.join(data_root,"scans","scans_train") + self.data_names = scene_info['name'] + self.overlaps = scene_info['score'] + # Only sample 10s + valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0 + self.overlaps = self.overlaps[valid] + self.data_names = self.data_names[valid] + if len(self.data_names) > 10000: + pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False) + self.data_names = self.data_names[pairinds] + self.overlaps = self.overlaps[pairinds] + self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True) + self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) + self.use_horizontal_flip_aug = use_horizontal_flip_aug + + def load_im(self, im_B, crop=None): + im = Image.open(im_B) + return im + + def load_depth(self, depth_ref, crop=None): + depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + def __len__(self): + return len(self.data_names) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], + [0, sy, 0], + [0, 0, 1]]) + return sK@K + + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): + im_A = im_A.flip(-1) + im_B = im_B.flip(-1) + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) + K_A = flip_mat@K_A + K_B = flip_mat@K_B + + return im_A, im_B, depth_A, depth_B, K_A, K_B + def read_scannet_pose(self,path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = np.linalg.inv(cam2world) + return world2cam + + + def read_scannet_intrinsic(self,path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. + """ + intrinsic = np.loadtxt(path, delimiter=' ') + return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float) + + def __getitem__(self, pair_idx): + # read intrinsics of original size + data_name = self.data_names[pair_idx] + scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the intrinsic of depthmap + K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root, + scene_name, + 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter + # read and compute relative poses + T1 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_1}.txt')) + T2 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_2}.txt')) + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4) + + # Load positive pair data + im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg') + im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg') + depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png') + depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png') + + im_A = self.load_im(im_A_ref) + im_B = self.load_im(im_B_ref) + depth_A = self.load_depth(depth_A_ref) + depth_B = self.load_depth(depth_B_ref) + + # Recompute camera intrinsic matrix due to the resize + K1 = self.scale_intrinsic(K1, im_A.width, im_A.height) + K2 = self.scale_intrinsic(K2, im_B.width, im_B.height) + # Process images + im_A, im_B = self.im_transform_ops((im_A, im_B)) + depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None])) + if self.use_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) + + data_dict = {'im_A': im_A, + 'im_B': im_B, + 'im_A_depth': depth_A[0,0], + 'im_B_depth': depth_B[0,0], + 'K1': K1, + 'K2': K2, + 'T_1to2':T_1to2, + } + return data_dict + + +class ScanNetBuilder: + def __init__(self, data_root = 'data/scannet') -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root,'scannet_indices') + self.all_scenes = os.listdir(self.scene_info_root) + + def build_scenes(self, split = 'train', min_overlap=0., **kwargs): + # Note: split doesn't matter here as we always use same scannet_train scenes + scene_names = self.all_scenes + scenes = [] + for scene_name in tqdm(scene_names, disable = roma.RANK > 0): + scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True) + scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs)) + return scenes + + def weight_scenes(self, concat_dataset, alpha=.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n)/n**alpha for n in ns]) + return ws diff --git a/third_party/Roma/roma/losses/__init__.py b/third_party/Roma/roma/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e08abacfc0f83d7de0f2ddc0583766a80bf53cf --- /dev/null +++ b/third_party/Roma/roma/losses/__init__.py @@ -0,0 +1 @@ +from .robust_loss import RobustLosses \ No newline at end of file diff --git a/third_party/Roma/roma/losses/robust_loss.py b/third_party/Roma/roma/losses/robust_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b932b2706f619c083485e1be0d86eec44ead83ef --- /dev/null +++ b/third_party/Roma/roma/losses/robust_loss.py @@ -0,0 +1,157 @@ +from einops.einops import rearrange +import torch +import torch.nn as nn +import torch.nn.functional as F +from roma.utils.utils import get_gt_warp +import wandb +import roma +import math + +class RobustLosses(nn.Module): + def __init__( + self, + robust=False, + center_coords=False, + scale_normalize=False, + ce_weight=0.01, + local_loss=True, + local_dist=4.0, + local_largest_scale=8, + smooth_mask = False, + depth_interpolation_mode = "bilinear", + mask_depth_loss = False, + relative_depth_error_threshold = 0.05, + alpha = 1., + c = 1e-3, + ): + super().__init__() + self.robust = robust # measured in pixels + self.center_coords = center_coords + self.scale_normalize = scale_normalize + self.ce_weight = ce_weight + self.local_loss = local_loss + self.local_dist = local_dist + self.local_largest_scale = local_largest_scale + self.smooth_mask = smooth_mask + self.depth_interpolation_mode = depth_interpolation_mode + self.mask_depth_loss = mask_depth_loss + self.relative_depth_error_threshold = relative_depth_error_threshold + self.avg_overlap = dict() + self.alpha = alpha + self.c = c + + def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale): + with torch.no_grad(): + B, C, H, W = scale_gm_cls.shape + device = x2.device + cls_res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)]) + G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) + GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices + cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99] + if not torch.any(cls_loss): + cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere + + certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob) + losses = { + f"gm_certainty_loss_{scale}": certainty_loss.mean(), + f"gm_cls_loss_{scale}": cls_loss.mean(), + } + wandb.log(losses, step = roma.GLOBAL_STEP) + return losses + + def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale): + with torch.no_grad(): + B, C, H, W = delta_cls.shape + device = x2.device + cls_res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)]) + G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale + GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices + cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99] + if not torch.any(cls_loss): + cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere + certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob) + losses = { + f"delta_certainty_loss_{scale}": certainty_loss.mean(), + f"delta_cls_loss_{scale}": cls_loss.mean(), + } + wandb.log(losses, step = roma.GLOBAL_STEP) + return losses + + def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"): + epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1) + if scale == 1: + pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean() + wandb.log({"train_pck_05": pck_05}, step = roma.GLOBAL_STEP) + + ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob) + a = self.alpha + cs = self.c * scale + x = epe[prob > 0.99] + reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2) + if not torch.any(reg_loss): + reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere + losses = { + f"{mode}_certainty_loss_{scale}": ce_loss.mean(), + f"{mode}_regression_loss_{scale}": reg_loss.mean(), + } + wandb.log(losses, step = roma.GLOBAL_STEP) + return losses + + def forward(self, corresps, batch): + scales = list(corresps.keys()) + tot_loss = 0.0 + # scale_weights due to differences in scale for regression gradients and classification gradients + scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1} + for scale in scales: + scale_corresps = corresps[scale] + scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = ( + scale_corresps["certainty"], + scale_corresps["flow_pre_delta"], + scale_corresps.get("delta_cls"), + scale_corresps.get("offset_scale"), + scale_corresps.get("gm_cls"), + scale_corresps.get("gm_certainty"), + scale_corresps["flow"], + scale_corresps.get("gm_flow"), + + ) + flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d") + b, h, w, d = flow_pre_delta.shape + gt_warp, gt_prob = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=h, + W=w, + ) + x2 = gt_warp.float() + prob = gt_prob + + if self.local_largest_scale >= scale: + prob = prob * ( + F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0] + < (2 / 512) * (self.local_dist[scale] * scale)) + + if scale_gm_cls is not None: + gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale) + gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * gm_loss + elif scale_gm_flow is not None: + gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm") + gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * gm_loss + + if delta_cls is not None: + delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale) + delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss + else: + delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale) + reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * reg_loss + prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach() + return tot_loss diff --git a/third_party/Roma/roma/models/__init__.py b/third_party/Roma/roma/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f20461e2f3a1722e558cefab94c5164be8842c3 --- /dev/null +++ b/third_party/Roma/roma/models/__init__.py @@ -0,0 +1 @@ +from .model_zoo import roma_outdoor, roma_indoor \ No newline at end of file diff --git a/third_party/Roma/roma/models/encoders.py b/third_party/Roma/roma/models/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..69b488743b91905aca6adc3e4d3439421d492051 --- /dev/null +++ b/third_party/Roma/roma/models/encoders.py @@ -0,0 +1,118 @@ +from typing import Optional, Union +import torch +from torch import device +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as tvm +import gc + + +class ResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None: + super().__init__() + if dilation is None: + dilation = [False,False,False] + if anti_aliased: + pass + else: + if weights is not None: + self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + else: + self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) + + self.high_res = high_res + self.freeze_bn = freeze_bn + self.early_exit = early_exit + self.amp = amp + self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + net = self.net + feats = {1:x} + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x + x = net.layer2(x) + feats[8] = x + if self.early_exit: + return feats + x = net.layer3(x) + feats[16] = x + x = net.layer4(x) + feats[32] = x + return feats + + def train(self, mode=True): + super().train(mode) + if self.freeze_bn: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + +class VGG19(nn.Module): + def __init__(self, pretrained=False, amp = False) -> None: + super().__init__() + self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + self.amp = amp + self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + feats = {} + scale = 1 + for layer in self.layers: + if isinstance(layer, nn.MaxPool2d): + feats[scale] = x + scale = scale*2 + x = layer(x) + return feats + +class CNNandDinov2(nn.Module): + def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None): + super().__init__() + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu") + from .transformer import vit_large + vit_kwargs = dict(img_size= 518, + patch_size= 14, + init_values = 1.0, + ffn_layer = "mlp", + block_chunks = 0, + ) + + dinov2_vitl14 = vit_large(**vit_kwargs).eval() + dinov2_vitl14.load_state_dict(dinov2_weights) + cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} + if not use_vgg: + self.cnn = ResNet50(**cnn_kwargs) + else: + self.cnn = VGG19(**cnn_kwargs) + self.amp = amp + self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if self.amp: + dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) + self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP + + + def train(self, mode: bool = True): + return self.cnn.train(mode) + + def forward(self, x, upsample = False): + B,C,H,W = x.shape + feature_pyramid = self.cnn(x) + + if not upsample: + with torch.no_grad(): + if self.dinov2_vitl14[0].device != x.device: + self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) + dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype)) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14) + del dinov2_features_16 + feature_pyramid[16] = features_16 + return feature_pyramid \ No newline at end of file diff --git a/third_party/Roma/roma/models/matcher.py b/third_party/Roma/roma/models/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..c06e1ba3aebe8dec7ee9f1800a6f4ba55ac8f0d9 --- /dev/null +++ b/third_party/Roma/roma/models/matcher.py @@ -0,0 +1,649 @@ +import os +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import warnings +from warnings import warn + +import roma +from roma.utils import get_tuple_transform_ops +from roma.utils.local_correlation import local_correlation +from roma.utils.utils import cls_to_flow_refine +from roma.utils.kde import kde + +class ConvRefiner(nn.Module): + def __init__( + self, + in_dim=6, + hidden_dim=16, + out_dim=2, + dw=False, + kernel_size=5, + hidden_blocks=3, + displacement_emb = None, + displacement_emb_dim = None, + local_corr_radius = None, + corr_in_other = None, + no_im_B_fm = False, + amp = False, + concat_logits = False, + use_bias_block_1 = True, + use_cosine_corr = False, + disable_local_corr_grad = False, + is_classifier = False, + sample_mode = "bilinear", + norm_type = nn.BatchNorm2d, + bn_momentum = 0.1, + ): + super().__init__() + self.bn_momentum = bn_momentum + self.block1 = self.create_block( + in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1, + ) + self.hidden_blocks = nn.Sequential( + *[ + self.create_block( + hidden_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + norm_type=norm_type, + ) + for hb in range(hidden_blocks) + ] + ) + self.hidden_blocks = self.hidden_blocks + self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) + if displacement_emb: + self.has_displacement_emb = True + self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + else: + self.has_displacement_emb = False + self.local_corr_radius = local_corr_radius + self.corr_in_other = corr_in_other + self.no_im_B_fm = no_im_B_fm + self.amp = amp + self.concat_logits = concat_logits + self.use_cosine_corr = use_cosine_corr + self.disable_local_corr_grad = disable_local_corr_grad + self.is_classifier = is_classifier + self.sample_mode = sample_mode + self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + def create_block( + self, + in_dim, + out_dim, + dw=False, + kernel_size=5, + bias = True, + norm_type = nn.BatchNorm2d, + ): + num_groups = 1 if not dw else in_dim + if dw: + assert ( + out_dim % in_dim == 0 + ), "outdim must be divisible by indim for depthwise" + conv1 = nn.Conv2d( + in_dim, + out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=num_groups, + bias=bias, + ) + norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim) + relu = nn.ReLU(inplace=True) + conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) + return nn.Sequential(conv1, norm, relu, conv2) + + def forward(self, x, y, flow, scale_factor = 1, logits = None): + b,c,hs,ws = x.shape + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + with torch.no_grad(): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode) + if self.has_displacement_emb: + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"), + ) + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) + in_displacement = flow-im_A_coords + emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement) + if self.local_corr_radius: + if self.corr_in_other: + # Corr in other means take a kxk grid around the predicted coordinate in other image + local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow, + sample_mode = self.sample_mode) + else: + raise NotImplementedError("Local corr in own frame should not be used.") + if self.no_im_B_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) + else: + d = torch.cat((x, x_hat, emb_in_displacement), dim=1) + else: + if self.no_im_B_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat), dim=1) + if self.concat_logits: + d = torch.cat((d, logits), dim=1) + d = self.block1(d) + d = self.hidden_blocks(d) + d = self.out_conv(d.float()) + displacement, certainty = d[:, :-1], d[:, -1:] + return displacement, certainty + +class CosKernel(nn.Module): # similar to softmax kernel + def __init__(self, T, learn_temperature=False): + super().__init__() + self.learn_temperature = learn_temperature + if self.learn_temperature: + self.T = nn.Parameter(torch.tensor(T)) + else: + self.T = T + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + if self.learn_temperature: + T = self.T.abs() + 0.01 + else: + T = torch.tensor(self.T, device=c.device) + K = ((c - 1.0) / T).exp() + return K + +class GP(nn.Module): + def __init__( + self, + kernel, + T=1, + learn_temperature=False, + only_attention=False, + gp_dim=64, + basis="fourier", + covar_size=5, + only_nearest_neighbour=False, + sigma_noise=0.1, + no_cov=False, + predict_features = False, + ): + super().__init__() + self.K = kernel(T=T, learn_temperature=learn_temperature) + self.sigma_noise = sigma_noise + self.covar_size = covar_size + self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1) + self.only_attention = only_attention + self.only_nearest_neighbour = only_nearest_neighbour + self.basis = basis + self.no_cov = no_cov + self.dim = gp_dim + self.predict_features = predict_features + + def get_local_cov(self, cov): + K = self.covar_size + b, h, w, h, w = cov.shape + hw = h * w + cov = F.pad(cov, 4 * (K // 2,)) # pad v_q + delta = torch.stack( + torch.meshgrid( + torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1) + ), + dim=-1, + ) + positions = torch.stack( + torch.meshgrid( + torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2) + ), + dim=-1, + ) + neighbours = positions[:, :, None, None, :] + delta[None, :, :] + points = torch.arange(hw)[:, None].expand(hw, K**2) + local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[ + :, + points.flatten(), + neighbours[..., 0].flatten(), + neighbours[..., 1].flatten(), + ].reshape(b, h, w, K**2) + return local_cov + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def project_to_basis(self, x): + if self.basis == "fourier": + return torch.cos(8 * math.pi * self.pos_conv(x)) + elif self.basis == "linear": + return self.pos_conv(x) + else: + raise ValueError( + "No other bases other than fourier and linear currently im_Bed in public release" + ) + + def get_pos_enc(self, y): + b, c, h, w = y.shape + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device), + ) + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.project_to_basis(coarse_coords) + return coarse_embedded_coords + + def forward(self, x, y, **kwargs): + b, c, h1, w1 = x.shape + b, c, h2, w2 = y.shape + f = self.get_pos_enc(y) + b, d, h2, w2 = f.shape + x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f) + K_xx = self.K(x, x) + K_yy = self.K(y, y) + K_xy = self.K(x, y) + K_yx = K_xy.permute(0, 2, 1) + sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :] + with warnings.catch_warnings(): + K_yy_inv = torch.linalg.inv(K_yy + sigma_noise) + + mu_x = K_xy.matmul(K_yy_inv.matmul(f)) + mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1) + if not self.no_cov: + cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) + cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + local_cov_x = self.get_local_cov(cov_x) + local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") + gp_feats = torch.cat((mu_x, local_cov_x), dim=1) + else: + gp_feats = mu_x + return gp_feats + +class Decoder(nn.Module): + def __init__( + self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None, + num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0, + flow_upsample_mode = "bilinear" + ): + super().__init__() + self.embedding_decoder = embedding_decoder + self.num_refinement_steps_per_scale = num_refinement_steps_per_scale + self.gps = gps + self.proj = proj + self.conv_refiner = conv_refiner + self.detach = detach + if pos_embeddings is None: + self.pos_embeddings = {} + else: + self.pos_embeddings = pos_embeddings + if scales == "all": + self.scales = ["32", "16", "8", "4", "2", "1"] + else: + self.scales = scales + self.warp_noise_std = warp_noise_std + self.refine_init = 4 + self.displacement_dropout_p = displacement_dropout_p + self.gm_warp_dropout_p = gm_warp_dropout_p + self.flow_upsample_mode = flow_upsample_mode + self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + def get_placeholder_flow(self, b, h, w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ) + ) + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + return coarse_coords + + def get_positional_embedding(self, b, h ,w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ) + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.pos_embedding(coarse_coords) + return coarse_embedded_coords + + def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1): + coarse_scales = self.embedding_decoder.scales() + all_scales = self.scales if not upsample else ["8", "4", "2", "1"] + sizes = {scale: f1[scale].shape[-2:] for scale in f1} + h, w = sizes[1] + b = f1[1].shape[0] + device = f1[1].device + coarsest_scale = int(all_scales[0]) + old_stuff = torch.zeros( + b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + ) + corresps = {} + if not upsample: + flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device) + certainty = 0.0 + else: + flow = F.interpolate( + flow, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + certainty = F.interpolate( + certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + displacement = 0.0 + for new_scale in all_scales: + ins = int(new_scale) + corresps[ins] = {} + f1_s, f2_s = f1[ins], f2[ins] + if new_scale in self.proj: + with torch.autocast("cuda", self.amp_dtype): + f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s) + + if ins in coarse_scales: + old_stuff = F.interpolate( + old_stuff, size=sizes[ins], mode="bilinear", align_corners=False + ) + gp_posterior = self.gps[new_scale](f1_s, f2_s) + gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder( + gp_posterior, f1_s, old_stuff, new_scale + ) + + if self.embedding_decoder.is_classifier: + flow = cls_to_flow_refine( + gm_warp_or_cls, + ).permute(0,3,1,2) + corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None + else: + corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None + flow = gm_warp_or_cls.detach() + + if new_scale in self.conv_refiner: + corresps[ins].update({"flow_pre_delta": flow}) if self.training else None + delta_flow, delta_certainty = self.conv_refiner[new_scale]( + f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty, + ) + corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None + displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w), + delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,) + flow = flow + displacement + certainty = ( + certainty + delta_certainty + ) # predict both certainty and displacement + corresps[ins].update({ + "certainty": certainty, + "flow": flow, + }) + if new_scale != "1": + flow = F.interpolate( + flow, + size=sizes[ins // 2], + mode=self.flow_upsample_mode, + ) + certainty = F.interpolate( + certainty, + size=sizes[ins // 2], + mode=self.flow_upsample_mode, + ) + if self.detach: + flow = flow.detach() + certainty = certainty.detach() + #torch.cuda.empty_cache() + return corresps + + +class RegressionMatcher(nn.Module): + def __init__( + self, + encoder, + decoder, + h=448, + w=448, + sample_mode = "threshold", + upsample_preds = False, + symmetric = False, + name = None, + attenuate_cert = None, + ): + super().__init__() + self.attenuate_cert = attenuate_cert + self.encoder = encoder + self.decoder = decoder + self.name = name + self.w_resized = w + self.h_resized = h + self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True) + self.sample_mode = sample_mode + self.upsample_preds = upsample_preds + self.upsample_res = (14*16*6, 14*16*6) + self.symmetric = symmetric + self.sample_thresh = 0.05 + + def get_output_resolution(self): + if not self.upsample_preds: + return self.h_resized, self.w_resized + else: + return self.upsample_res + + def extract_backbone_features(self, batch, batched = True, upsample = False): + x_q = batch["im_A"] + x_s = batch["im_B"] + if batched: + X = torch.cat((x_q, x_s), dim = 0) + feature_pyramid = self.encoder(X, upsample = upsample) + else: + feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample) + return feature_pyramid + + def sample( + self, + matches, + certainty, + num=10000, + ): + if "threshold" in self.sample_mode: + upper_thresh = self.sample_thresh + certainty = certainty.clone() + certainty[certainty > upper_thresh] = 1 + matches, certainty = ( + matches.reshape(-1, 4), + certainty.reshape(-1), + ) + expansion_factor = 4 if "balanced" in self.sample_mode else 1 + good_samples = torch.multinomial(certainty, + num_samples = min(expansion_factor*num, len(certainty)), + replacement=False) + good_matches, good_certainty = matches[good_samples], certainty[good_samples] + if "balanced" not in self.sample_mode: + return good_matches, good_certainty + density = kde(good_matches, std=0.1) + p = 1 / (density+1) + p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial(p, + num_samples = min(num,len(good_certainty)), + replacement=False) + return good_matches[balanced_samples], good_certainty[balanced_samples] + + def forward(self, batch, batched = True, upsample = False, scale_factor = 1): + feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample) + if batched: + f_q_pyramid = { + scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items() + } + f_s_pyramid = { + scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items() + } + else: + f_q_pyramid, f_s_pyramid = feature_pyramid + corresps = self.decoder(f_q_pyramid, + f_s_pyramid, + upsample = upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor) + + return corresps + + def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1): + feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample) + f_q_pyramid = feature_pyramid + f_s_pyramid = { + scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0) + for scale, f_scale in feature_pyramid.items() + } + corresps = self.decoder(f_q_pyramid, + f_s_pyramid, + upsample = upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor) + return corresps + + def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B): + kpts_A, kpts_B = matches[...,:2], matches[...,2:] + kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1) + kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1) + return kpts_A, kpts_B + + def match( + self, + im_A_path, + im_B_path, + *args, + batched=False, + device = None, + ): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + from PIL import Image + if isinstance(im_A_path, (str, os.PathLike)): + im_A, im_B = Image.open(im_A_path), Image.open(im_B_path) + else: + # Assume its not a path + im_A, im_B = im_A_path, im_B_path + symmetric = self.symmetric + self.train(False) + with torch.no_grad(): + if not batched: + b = 1 + w, h = im_A.size + w2, h2 = im_B.size + # Get images in good format + ws = self.w_resized + hs = self.h_resized + + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True, clahe = False + ) + im_A, im_B = test_transform((im_A, im_B)) + batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)} + else: + b, c, h, w = im_A.shape + b, c, h2, w2 = im_B.shape + assert w == w2 and h == h2, "For batched images we assume same size" + batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)} + if h != self.h_resized or self.w_resized != w: + warn("Model resolution and batch resolution differ, may produce unexpected results") + hs, ws = h, w + finest_scale = 1 + # Run matcher + if symmetric: + corresps = self.forward_symmetric(batch) + else: + corresps = self.forward(batch, batched = True) + + if self.upsample_preds: + hs, ws = self.upsample_res + + if self.attenuate_cert: + low_res_certainty = F.interpolate( + corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + cert_clamp = 0 + factor = 0.5 + low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + + if self.upsample_preds: + finest_corresps = corresps[finest_scale] + torch.cuda.empty_cache() + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True + ) + im_A, im_B = Image.open(im_A_path), Image.open(im_B_path) + im_A, im_B = test_transform((im_A, im_B)) + im_A, im_B = im_A[None].to(device), im_B[None].to(device) + scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized)) + batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps} + if symmetric: + corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor) + else: + corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor) + + im_A_to_im_B = corresps[finest_scale]["flow"] + certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0) + if finest_scale != 1: + im_A_to_im_B = F.interpolate( + im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" + ) + certainty = F.interpolate( + certainty, size=(hs, ws), align_corners=False, mode="bilinear" + ) + im_A_to_im_B = im_A_to_im_B.permute( + 0, 2, 3, 1 + ) + # Create im_A meshgrid + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"), + ) + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) + certainty = certainty.sigmoid() # logits -> probs + im_A_coords = im_A_coords.permute(0, 2, 3, 1) + if (im_A_to_im_B.abs() > 1).any() and True: + wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0 + certainty[wrong[:,None]] = 0 + im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1) + if symmetric: + A_to_B, B_to_A = im_A_to_im_B.chunk(2) + q_warp = torch.cat((im_A_coords, A_to_B), dim=-1) + im_B_coords = im_A_coords + s_warp = torch.cat((B_to_A, im_B_coords), dim=-1) + warp = torch.cat((q_warp, s_warp),dim=2) + certainty = torch.cat(certainty.chunk(2), dim=3) + else: + warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1) + if batched: + return ( + warp, + certainty[:, 0] + ) + else: + return ( + warp[0], + certainty[0, 0], + ) + diff --git a/third_party/Roma/roma/models/model_zoo/__init__.py b/third_party/Roma/roma/models/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91edd4e69f2b39f18d62545a95f2774324ff404b --- /dev/null +++ b/third_party/Roma/roma/models/model_zoo/__init__.py @@ -0,0 +1,30 @@ +import torch +from .roma_models import roma_model + +weight_urls = { + "roma": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", + }, + "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D +} + +def roma_outdoor(device, weights=None, dinov2_weights=None): + if weights is None: + weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"], + map_location=device) + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], + map_location=device) + return roma_model(resolution=(14*8*6,14*8*6), upsample_preds=True, + weights=weights,dinov2_weights = dinov2_weights,device=device) + +def roma_indoor(device, weights=None, dinov2_weights=None): + if weights is None: + weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"], + map_location=device) + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], + map_location=device) + return roma_model(resolution=(14*8*5,14*8*5), upsample_preds=False, + weights=weights,dinov2_weights = dinov2_weights,device=device) diff --git a/third_party/Roma/roma/models/model_zoo/roma_models.py b/third_party/Roma/roma/models/model_zoo/roma_models.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb0ff7264880d25f0feb0802e582bf29c84b051 --- /dev/null +++ b/third_party/Roma/roma/models/model_zoo/roma_models.py @@ -0,0 +1,157 @@ +import warnings +import torch.nn as nn +from roma.models.matcher import * +from roma.models.transformer import Block, TransformerDecoder, MemEffAttention +from roma.models.encoders import * + +def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, **kwargs): + # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') + gp_dim = 512 + feat_dim = 512 + decoder_dim = gp_dim + feat_dim + cls_to_coord_res = 64 + coordinate_decoder = TransformerDecoder( + nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), + decoder_dim, + cls_to_coord_res**2 + 1, + is_classifier=True, + amp = True, + pos_enc = False,) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + disable_local_corr_grad = True + + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128+(2*7+1)**2, + 2 * 512+128+(2*7+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "8": ConvRefiner( + 2 * 512+64+(2*3+1)**2, + 2 * 512+64+(2*3+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "4": ConvRefiner( + 2 * 256+32+(2*2+1)**2, + 2 * 256+32+(2*2+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "1": ConvRefiner( + 2 * 9 + 6, + 24, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks = hidden_blocks, + displacement_emb = displacement_emb, + displacement_emb_dim = 6, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"16": gp16}) + proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512)) + proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512)) + proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256)) + proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64)) + proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9)) + proj = nn.ModuleDict({ + "16": proj16, + "8": proj8, + "4": proj4, + "2": proj2, + "1": proj1, + }) + displacement_dropout_p = 0.0 + gm_warp_dropout_p = 0.0 + decoder = Decoder(coordinate_decoder, + gps, + proj, + conv_refiner, + detach=True, + scales=["16", "8", "4", "2", "1"], + displacement_dropout_p = displacement_dropout_p, + gm_warp_dropout_p = gm_warp_dropout_p) + + encoder = CNNandDinov2( + cnn_kwargs = dict( + pretrained=False, + amp = True), + amp = True, + use_vgg = True, + dinov2_weights = dinov2_weights + ) + h,w = resolution + symmetric = True + attenuate_cert = True + matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, + symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device) + matcher.load_state_dict(weights) + return matcher diff --git a/third_party/Roma/roma/models/transformer/__init__.py b/third_party/Roma/roma/models/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4770ebb19f111df14f1539fa3696553d96d4e48b --- /dev/null +++ b/third_party/Roma/roma/models/transformer/__init__.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from roma.utils.utils import get_grid +from .layers.block import Block +from .layers.attention import MemEffAttention +from .dinov2 import vit_large + +class TransformerDecoder(nn.Module): + def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, + amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.blocks = blocks + self.to_out = nn.Linear(hidden_dim, out_dim) + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self._scales = [16] + self.is_classifier = is_classifier + self.amp = amp + self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + self.pos_enc = pos_enc + self.learned_embeddings = learned_embeddings + if self.learned_embeddings: + self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim)))) + + def scales(self): + return self._scales.copy() + + def forward(self, gp_posterior, features, old_stuff, new_scale): + with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp): + B,C,H,W = gp_posterior.shape + x = torch.cat((gp_posterior, features), dim = 1) + B,C,H,W = x.shape + grid = get_grid(B, H, W, x.device).reshape(B,H*W,2) + if self.learned_embeddings: + pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C) + else: + pos_enc = 0 + tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc + z = self.blocks(tokens) + out = self.to_out(z) + out = out.permute(0,2,1).reshape(B, self.out_dim, H, W) + warp, certainty = out[:, :-1], out[:, -1:] + return warp, certainty, None + + diff --git a/third_party/Roma/roma/models/transformer/dinov2.py b/third_party/Roma/roma/models/transformer/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..b556c63096d17239c8603d5fe626c331963099fd --- /dev/null +++ b/third_party/Roma/roma/models/transformer/dinov2.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + for param in self.parameters(): + param.requires_grad = False + + @property + def device(self): + return self.cls_token.device + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model \ No newline at end of file diff --git a/third_party/Roma/roma/models/transformer/layers/__init__.py b/third_party/Roma/roma/models/transformer/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/third_party/Roma/roma/models/transformer/layers/attention.py b/third_party/Roma/roma/models/transformer/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9b0c94b40967dfdff4f261c127cbd21328c905 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/attention.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/third_party/Roma/roma/models/transformer/layers/block.py b/third_party/Roma/roma/models/transformer/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/third_party/Roma/roma/models/transformer/layers/dino_head.py b/third_party/Roma/roma/models/transformer/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7212db92a4fd8d4c7230e284e551a0234e9d8623 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/dino_head.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/third_party/Roma/roma/models/transformer/layers/drop_path.py b/third_party/Roma/roma/models/transformer/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/third_party/Roma/roma/models/transformer/layers/layer_scale.py b/third_party/Roma/roma/models/transformer/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/third_party/Roma/roma/models/transformer/layers/mlp.py b/third_party/Roma/roma/models/transformer/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/third_party/Roma/roma/models/transformer/layers/patch_embed.py b/third_party/Roma/roma/models/transformer/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py b/third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/third_party/Roma/roma/train/__init__.py b/third_party/Roma/roma/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433 --- /dev/null +++ b/third_party/Roma/roma/train/__init__.py @@ -0,0 +1 @@ +from .train import train_k_epochs diff --git a/third_party/Roma/roma/train/train.py b/third_party/Roma/roma/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5556f7ebf9b6378e1395c125dde093f5e55e7141 --- /dev/null +++ b/third_party/Roma/roma/train/train.py @@ -0,0 +1,102 @@ +from tqdm import tqdm +from roma.utils.utils import to_cuda +import roma +import torch +import wandb + +def log_param_statistics(named_parameters, norm_type = 2): + named_parameters = list(named_parameters) + grads = [p.grad for n, p in named_parameters if p.grad is not None] + weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None] + names = [n for n,p in named_parameters if p.grad is not None] + param_norm = torch.stack(weight_norms).norm(p=norm_type) + device = grads[0].device + grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]) + nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms) + nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf] + total_grad_norm = torch.norm(grad_norms, norm_type) + if torch.any(nans_or_infs): + print(f"These params have nan or inf grads: {nan_inf_names}") + wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP) + wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP) + +def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs): + optimizer.zero_grad() + out = model(train_batch) + l = objective(out, train_batch) + grad_scaler.scale(l).backward() + grad_scaler.unscale_(optimizer) + log_param_statistics(model.named_parameters()) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be? + grad_scaler.step(optimizer) + grad_scaler.update() + wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP) + if grad_scaler._scale < 1.: + grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale) + roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step + return {"train_out": out, "train_loss": l.item()} + + +def train_k_steps( + n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None, +): + for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0): + batch = next(dataloader) + model.train(True) + batch = to_cuda(batch) + train_step( + train_batch=batch, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + grad_scaler=grad_scaler, + n=n, + grad_clip_norm = grad_clip_norm, + ) + if ema_model is not None: + ema_model.update() + if warmup is not None: + with warmup.dampening(): + lr_scheduler.step() + else: + lr_scheduler.step() + [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())] + + +def train_epoch( + dataloader=None, + model=None, + objective=None, + optimizer=None, + lr_scheduler=None, + epoch=None, +): + model.train(True) + print(f"At epoch {epoch}") + for batch in tqdm(dataloader, mininterval=5.0): + batch = to_cuda(batch) + train_step( + train_batch=batch, model=model, objective=objective, optimizer=optimizer + ) + lr_scheduler.step() + return { + "model": model, + "optimizer": optimizer, + "lr_scheduler": lr_scheduler, + "epoch": epoch, + } + + +def train_k_epochs( + start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler +): + for epoch in range(start_epoch, end_epoch + 1): + train_epoch( + dataloader=dataloader, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + ) diff --git a/third_party/Roma/roma/utils/__init__.py b/third_party/Roma/roma/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2709f5e586150289085a4e2cbd458bc443fab7f3 --- /dev/null +++ b/third_party/Roma/roma/utils/__init__.py @@ -0,0 +1,16 @@ +from .utils import ( + pose_auc, + get_pose, + compute_relative_pose, + compute_pose_error, + estimate_pose, + estimate_pose_uncalibrated, + rotate_intrinsic, + get_tuple_transform_ops, + get_depth_tuple_transform_ops, + warp_kpts, + numpy_to_pil, + tensor_to_pil, + recover_pose, + signed_left_to_right_epipolar_distance, +) diff --git a/third_party/Roma/roma/utils/kde.py b/third_party/Roma/roma/utils/kde.py new file mode 100644 index 0000000000000000000000000000000000000000..90a058fb68253cfe23c2a7f21b213bea8e06cfe3 --- /dev/null +++ b/third_party/Roma/roma/utils/kde.py @@ -0,0 +1,8 @@ +import torch + +def kde(x, std = 0.1): + # use a gaussian kernel to estimate density + x = x.half() # Do it in half precision + scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() + density = scores.sum(dim=-1) + return density \ No newline at end of file diff --git a/third_party/Roma/roma/utils/local_correlation.py b/third_party/Roma/roma/utils/local_correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..586eef5f154a95968b253ad9701933b55b3a4dd6 --- /dev/null +++ b/third_party/Roma/roma/utils/local_correlation.py @@ -0,0 +1,47 @@ +import torch +import torch.nn.functional as F + +def local_correlation( + feature0, + feature1, + local_radius, + padding_mode="zeros", + flow = None, + sample_mode = "bilinear", +): + r = local_radius + K = (2*r+1)**2 + B, c, h, w = feature0.size() + feature0 = feature0.half() + feature1 = feature1.half() + corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype) + if flow is None: + # If flow is None, assume feature0 and feature1 are aligned + coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"), + )) + coords = torch.stack((coords[1], coords[0]), dim=-1)[ + None + ].expand(B, h, w, 2) + else: + coords = flow.permute(0,2,3,1) # If using flow, sample around flow target. + local_window = torch.meshgrid( + ( + torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"), + torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"), + )) + local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ + None + ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2) + for _ in range(B): + with torch.no_grad(): + local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2).float() + window_feature = F.grid_sample( + feature1[_:_+1].float(), local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, # + ) + window_feature = window_feature.reshape(c,h,w,(2*r+1)**2) + corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1) + torch.cuda.empty_cache() + return corr \ No newline at end of file diff --git a/third_party/Roma/roma/utils/transforms.py b/third_party/Roma/roma/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6476bd816a31df36f7d1b5417853637b65474b --- /dev/null +++ b/third_party/Roma/roma/utils/transforms.py @@ -0,0 +1,118 @@ +from typing import Dict +import numpy as np +import torch +import kornia.augmentation as K +from kornia.geometry.transform import warp_perspective + +# Adapted from Kornia +class GeometricSequential: + def __init__(self, *transforms, align_corners=True) -> None: + self.transforms = transforms + self.align_corners = align_corners + + def __call__(self, x, mode="bilinear"): + b, c, h, w = x.shape + M = torch.eye(3, device=x.device)[None].expand(b, 3, 3) + for t in self.transforms: + if np.random.rand() < t.p: + M = M.matmul( + t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None) + ) + return ( + warp_perspective( + x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners + ), + M, + ) + + def apply_transform(self, x, M, mode="bilinear"): + b, c, h, w = x.shape + return warp_perspective( + x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode + ) + + +class RandomPerspective(K.RandomPerspective): + def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]: + distortion_scale = torch.as_tensor( + self.distortion_scale, device=self._device, dtype=self._dtype + ) + return self.random_perspective_generator( + batch_shape[0], + batch_shape[-2], + batch_shape[-1], + distortion_scale, + self.same_on_batch, + self.device, + self.dtype, + ) + + def random_perspective_generator( + self, + batch_size: int, + height: int, + width: int, + distortion_scale: torch.Tensor, + same_on_batch: bool = False, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> Dict[str, torch.Tensor]: + r"""Get parameters for ``perspective`` for a random perspective transform. + + Args: + batch_size (int): the tensor batch size. + height (int) : height of the image. + width (int): width of the image. + distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1. + same_on_batch (bool): apply the same transformation across the batch. Default: False. + device (torch.device): the device on which the random numbers will be generated. Default: cpu. + dtype (torch.dtype): the data type of the generated random numbers. Default: float32. + + Returns: + params Dict[str, torch.Tensor]: parameters to be passed for transformation. + - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2). + - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2). + + Note: + The generated random numbers are not reproducible across different devices and dtypes. + """ + if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1): + raise AssertionError( + f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}." + ) + if not ( + type(height) is int and height > 0 and type(width) is int and width > 0 + ): + raise AssertionError( + f"'height' and 'width' must be integers. Got {height}, {width}." + ) + + start_points: torch.Tensor = torch.tensor( + [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]], + device=distortion_scale.device, + dtype=distortion_scale.dtype, + ).expand(batch_size, -1, -1) + + # generate random offset not larger than half of the image + fx = distortion_scale * width / 2 + fy = distortion_scale * height / 2 + + factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2) + offset = (torch.rand_like(start_points) - 0.5) * 2 + end_points = start_points + factor * offset + + return dict(start_points=start_points, end_points=end_points) + + + +class RandomErasing: + def __init__(self, p = 0., scale = 0.) -> None: + self.p = p + self.scale = scale + self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p) + def __call__(self, image, depth): + if self.p > 0: + image = self.random_eraser(image) + depth = self.random_eraser(depth, params=self.random_eraser._params) + return image, depth + \ No newline at end of file diff --git a/third_party/Roma/roma/utils/utils.py b/third_party/Roma/roma/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d673f679823c833688e2548dd40bf50943796a71 --- /dev/null +++ b/third_party/Roma/roma/utils/utils.py @@ -0,0 +1,622 @@ +import warnings +import numpy as np +import cv2 +import math +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import torch.nn.functional as F +from PIL import Image +import kornia + +def recover_pose(E, kpts0, kpts1, K0, K1, mask): + best_num_inliers = 0 + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + + + +# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py +# --- GEOMETRY --- +def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf + ) + + ret = None + if E is not None: + best_num_inliers = 0 + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + +def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + method = cv2.USAC_ACCURATE + F, mask = cv2.findFundamentalMat( + kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000 + ) + E = K1.T@F@K0 + ret = None + if E is not None: + best_num_inliers = 0 + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + +def unnormalize_coords(x_n,h,w): + x = torch.stack( + (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + return x + + +def rotate_intrinsic(K, n): + base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + rot = np.linalg.matrix_power(base_rot, n) + return rot @ K + + +def rotate_pose_inplane(i_T_w, rot): + rotation_matrices = [ + np.array( + [ + [np.cos(r), -np.sin(r), 0.0, 0.0], + [np.sin(r), np.cos(r), 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] + ] + return np.dot(rotation_matrices[rot], i_T_w) + + +def scale_intrinsics(K, scales): + scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) + return np.dot(scales, K) + + +def to_homogeneous(points): + return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) + + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = angle_error_vec(t.squeeze(), t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e) / t) + return aucs + + +# From Patch2Pix https://github.com/GrumpyZhou/patch2pix +def get_depth_tuple_transform_ops_nearest_exact(resize=None): + ops = [] + if resize: + ops.append(TupleResizeNearestExact(resize)) + return TupleCompose(ops) + +def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR)) + return TupleCompose(ops) + + +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None): + ops = [] + if resize: + ops.append(TupleResize(resize)) + ops.append(TupleToTensorScaled()) + if normalize: + ops.append( + TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) # Imagenet mean/std + return TupleCompose(ops) + +class ToTensorScaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" + + def __call__(self, im): + if not isinstance(im, torch.Tensor): + im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) + im /= 255.0 + return torch.from_numpy(im) + else: + return im + + def __repr__(self): + return "ToTensorScaled(./255)" + + +class TupleToTensorScaled(object): + def __init__(self): + self.to_tensor = ToTensorScaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorScaled(./255)" + + +class ToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __call__(self, im): + return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) + + def __repr__(self): + return "ToTensorUnscaled()" + + +class TupleToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __init__(self): + self.to_tensor = ToTensorUnscaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorUnscaled()" + +class TupleResizeNearestExact: + def __init__(self, size): + self.size = size + def __call__(self, im_tuple): + return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple] + + def __repr__(self): + return "TupleResizeNearestExact(size={})".format(self.size) + + +class TupleResize(object): + def __init__(self, size, mode=InterpolationMode.BICUBIC): + self.size = size + self.resize = transforms.Resize(size, mode) + def __call__(self, im_tuple): + return [self.resize(im) for im in im_tuple] + + def __repr__(self): + return "TupleResize(size={})".format(self.size) + +class Normalize: + def __call__(self,im): + mean = im.mean(dim=(1,2), keepdims=True) + std = im.std(dim=(1,2), keepdims=True) + return (im-mean)/std + + +class TupleNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + self.normalize = transforms.Normalize(mean=mean, std=std) + + def __call__(self, im_tuple): + c,h,w = im_tuple[0].shape + if c > 3: + warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb") + return [self.normalize(im[:3]) for im in im_tuple] + + def __repr__(self): + return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) + + +class TupleCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, im_tuple): + for t in self.transforms: + im_tuple = t(im_tuple) + return im_tuple + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + +@torch.no_grad() +def cls_to_flow(cls, deterministic_sampling = True): + B,C,H,W = cls.shape + device = cls.device + res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)]) + G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + if deterministic_sampling: + sampled_cls = cls.max(dim=1).indices + else: + sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W) + flow = G[sampled_cls] + return flow + +@torch.no_grad() +def cls_to_flow_refine(cls): + B,C,H,W = cls.shape + device = cls.device + res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)]) + G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + cls = cls.softmax(dim=1) + mode = cls.max(dim=1).indices + + index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long() + neighbours = torch.gather(cls, dim = 1, index = index)[...,None] + flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]] + tot_prob = neighbours.sum(dim=1) + flow = flow / tot_prob + return flow + + +def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): + + if H is None: + B,H,W = depth1.shape + else: + B = depth1.shape[0] + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=depth1.device + ) + for n in (B, H, W) + ] + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + depth_interpolation_mode = depth_interpolation_mode, + relative_depth_error_threshold = relative_depth_error_threshold, + ) + prob = mask.float().reshape(B, H, W) + x2 = x2.reshape(B, H, W, 2) + return x2, prob + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + if depth_interpolation_mode == "combined": + # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation + if smooth_mask: + raise NotImplementedError("Combined bilinear and NN warp not implemented") + valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "bilinear", + relative_depth_error_threshold = relative_depth_error_threshold) + valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "nearest-exact", + relative_depth_error_threshold = relative_depth_error_threshold) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + warp = warp_bilinear.clone() + warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + valid = valid_bilinear | valid_nearest + return valid, warp + + + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + )[:, 0, :, 0] + + relative_depth_error = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() + if not smooth_mask: + consistent_mask = relative_depth_error < relative_depth_error_threshold + else: + consistent_mask = (-relative_depth_error/smooth_mask).exp() + valid_mask = nonzero_mask * covisible_mask * consistent_mask + if return_relative_depth_error: + return relative_depth_error, w_kpts0 + else: + return valid_mask, w_kpts0 + +imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) +imagenet_std = torch.tensor([0.229, 0.224, 0.225]) + + +def numpy_to_pil(x: np.ndarray): + """ + Args: + x: Assumed to be of shape (h,w,c) + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if x.max() <= 1.01: + x *= 255 + x = x.astype(np.uint8) + return Image.fromarray(x) + + +def tensor_to_pil(x, unnormalize=False): + if unnormalize: + x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device)) + x = x.detach().permute(1, 2, 0).cpu().numpy() + x = np.clip(x, 0.0, 1.0) + return numpy_to_pil(x) + + +def to_cuda(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cuda() + return batch + + +def to_cpu(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cpu() + return batch + + +def get_pose(calib): + w, h = np.array(calib["imsize"])[0] + return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w + + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans + +@torch.no_grad() +def reset_opt(opt): + for group in opt.param_groups: + for p in group['params']: + if p.requires_grad: + state = opt.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + +def flow_to_pixel_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + w1 * (flow[..., 0] + 1) / 2, + h1 * (flow[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return flow + +def flow_to_normalized_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + 2 * (flow[..., 0]) / w1 - 1, + 2 * (flow[..., 1]) / h1 - 1, + ), + axis=-1, + ) + ) + return flow + + +def warp_to_pixel_coords(warp, h1, w1, h2, w2): + warp1 = warp[..., :2] + warp1 = ( + torch.stack( + ( + w1 * (warp1[..., 0] + 1) / 2, + h1 * (warp1[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + warp2 = warp[..., 2:] + warp2 = ( + torch.stack( + ( + w2 * (warp2[..., 0] + 1) / 2, + h2 * (warp2[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return torch.cat((warp1,warp2), dim=-1) + + + +def signed_point_line_distance(point, line, eps: float = 1e-9): + r"""Return the distance from points to lines. + + Args: + point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`. + line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`. + eps: Small constant for safe sqrt. + + Returns: + the computed distance with shape :math:`(*, N)`. + """ + + if not point.shape[-1] in (2, 3): + raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}") + + if not line.shape[-1] == 3: + raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}") + + numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]) + denominator = line[..., :2].norm(dim=-1) + + return numerator / (denominator + eps) + + +def signed_left_to_right_epipolar_distance(pts1, pts2, Fm): + r"""Return one-sided epipolar distance for correspondences given the fundamental matrix. + + This method measures the distance from points in the right images to the epilines + of the corresponding points in the left images as they reflect in the right images. + + Args: + pts1: correspondences from the left images with shape + :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. + pts2: correspondences from the right images with shape + :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. + Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to + avoid ambiguity with torch.nn.functional. + + Returns: + the computed Symmetrical distance with shape :math:`(*, N)`. + """ + import kornia + if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3): + raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}") + + if pts1.shape[-1] == 2: + pts1 = kornia.geometry.convert_points_to_homogeneous(pts1) + + F_t = Fm.transpose(dim0=-2, dim1=-1) + line1_in_2 = pts1 @ F_t + + return signed_point_line_distance(pts2, line1_in_2) + +def get_grid(b, h, w, device): + grid = torch.meshgrid( + *[ + torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) + for n in (b, h, w) + ] + ) + grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2) + return grid diff --git a/third_party/Roma/setup.py b/third_party/Roma/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..ae777c0e5a41f0e4b03a838d19bc9a2bb04d4617 --- /dev/null +++ b/third_party/Roma/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup( + name="roma", + packages=["roma"], + version="0.0.1", + author="Johan Edstedt", + install_requires=open("requirements.txt", "r").read().split("\n"), +)