diff --git a/third_party/DKM/.gitignore b/third_party/DKM/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..07442492a552f5e0f3feadd0b992d92792ddb3bb
--- /dev/null
+++ b/third_party/DKM/.gitignore
@@ -0,0 +1,3 @@
+*.egg-info*
+*.vscode*
+*__pycache__*
\ No newline at end of file
diff --git a/third_party/DKM/LICENSE b/third_party/DKM/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..1625fcb9c1046af4180f55b58acff245814a2c2e
--- /dev/null
+++ b/third_party/DKM/LICENSE
@@ -0,0 +1,25 @@
+NOTE! Models trained on our synthetic dataset uses datasets which are licensed under non-commercial licenses.
+Hence we cannot provide them under the MIT license. However, MegaDepth is under MIT license, hence we provide those models under MIT license, see below.
+
+
+License for Models Trained on MegaDepth ONLY below:
+
+Copyright (c) 2022 Johan Edstedt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/third_party/DKM/README.md b/third_party/DKM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c68fbfaf848e5c2a64adc308351aad38f94b0830
--- /dev/null
+++ b/third_party/DKM/README.md
@@ -0,0 +1,117 @@
+# DKM: Dense Kernelized Feature Matching for Geometry Estimation
+### [Project Page](https://parskatt.github.io/DKM) | [Paper](https://arxiv.org/abs/2202.00667)
+
+
+> DKM: Dense Kernelized Feature Matching for Geometry Estimation
+> [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Ioannis Athanasiadis](https://scholar.google.com/citations?user=RCAtJgUAAAAJ), [Mårten Wadenbäck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ)
+> CVPR 2023
+
+## How to Use?
+
+Our model produces a dense (for all pixels) warp and certainty.
+
+Warp: [B,H,W,4] for all images in batch of size B, for each pixel HxW, we ouput the input and matching coordinate in the normalized grids [-1,1]x[-1,1].
+
+Certainty: [B,H,W] a number in each pixel indicating the matchability of the pixel.
+
+See [demo](dkm/demo/) for two demos of DKM.
+
+See [api.md](docs/api.md) for API.
+
+
+## Qualitative Results
+
+
+https://user-images.githubusercontent.com/22053118/223748279-0f0c21b4-376a-440a-81f5-7f9a5d87483f.mp4
+
+
+https://user-images.githubusercontent.com/22053118/223748512-1bca4a17-cffa-491d-a448-96aac1353ce9.mp4
+
+
+
+https://user-images.githubusercontent.com/22053118/223748518-4d475d9f-a933-4581-97ed-6e9413c4caca.mp4
+
+
+
+https://user-images.githubusercontent.com/22053118/223748522-39c20631-aa16-4954-9c27-95763b38f2ce.mp4
+
+
+
+
+
+
+## Benchmark Results
+
+
+
+### Megadepth1500
+
+| | @5 | @10 | @20 |
+|-------|-------|------|------|
+| DKMv1 | 54.5 | 70.7 | 82.3 |
+| DKMv2 | *56.8* | *72.3* | *83.2* |
+| DKMv3 (paper) | **60.5** | **74.9** | **85.1** |
+| DKMv3 (this repo) | **60.0** | **74.6** | **84.9** |
+
+### Megadepth 8 Scenes
+| | @5 | @10 | @20 |
+|-------|-------|------|------|
+| DKMv3 (paper) | **60.5** | **74.5** | **84.2** |
+| DKMv3 (this repo) | **60.4** | **74.6** | **84.3** |
+
+
+### ScanNet1500
+| | @5 | @10 | @20 |
+|-------|-------|------|------|
+| DKMv1 | 24.8 | 44.4 | 61.9 |
+| DKMv2 | *28.2* | *49.2* | *66.6* |
+| DKMv3 (paper) | **29.4** | **50.7** | **68.3** |
+| DKMv3 (this repo) | **29.8** | **50.8** | **68.3** |
+
+
+
+## Navigating the Code
+* Code for models can be found in [dkm/models](dkm/models/)
+* Code for benchmarks can be found in [dkm/benchmarks](dkm/benchmarks/)
+* Code for reproducing experiments from our paper can be found in [experiments/](experiments/)
+
+## Install
+Run ``pip install -e .``
+
+## Demo
+
+A demonstration of our method can be run by:
+``` bash
+python demo_match.py
+```
+This runs our model trained on mega on two images taken from Sacre Coeur.
+
+## Benchmarks
+See [Benchmarks](docs/benchmarks.md) for details.
+## Training
+See [Training](docs/training.md) for details.
+## Reproducing Results
+Given that the required benchmark or training dataset has been downloaded and unpacked, results can be reproduced by running the experiments in the experiments folder.
+
+## Using DKM matches for estimation
+We recommend using the excellent Graph-Cut RANSAC algorithm: https://github.com/danini/graph-cut-ransac
+
+| | @5 | @10 | @20 |
+|-------|-------|------|------|
+| DKMv3 (RANSAC) | *60.5* | *74.9* | *85.1* |
+| DKMv3 (GC-RANSAC) | **65.5** | **78.0** | **86.7** |
+
+
+## Acknowledgements
+We have used code and been inspired by https://github.com/PruneTruong/DenseMatching, https://github.com/zju3dv/LoFTR, and https://github.com/GrumpyZhou/patch2pix. We additionally thank the authors of ECO-TR for providing their benchmark.
+
+## BibTeX
+If you find our models useful, please consider citing our paper!
+```
+@inproceedings{edstedt2023dkm,
+title={{DKM}: Dense Kernelized Feature Matching for Geometry Estimation},
+author={Edstedt, Johan and Athanasiadis, Ioannis and Wadenbäck, Mårten and Felsberg, Michael},
+booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+year={2023}
+}
+```
diff --git a/third_party/DKM/assets/ams_hom_A.jpg b/third_party/DKM/assets/ams_hom_A.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b2c1c35a5316d823dd88dcac6247b7fc952feb21
--- /dev/null
+++ b/third_party/DKM/assets/ams_hom_A.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:271a19f0b29fc88d8f88d1136f001078ca6bf5105ff95355f89a18e787c50e3a
+size 1194880
diff --git a/third_party/DKM/assets/ams_hom_B.jpg b/third_party/DKM/assets/ams_hom_B.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..008ed1fbaf39135a1ded50a502a86701d9d900ca
--- /dev/null
+++ b/third_party/DKM/assets/ams_hom_B.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d84ced12e607f5ac5f7628151694fbaa2300caa091ac168e0aedad2ebaf491d6
+size 1208132
diff --git a/third_party/DKM/assets/dkmv3_warp.jpg b/third_party/DKM/assets/dkmv3_warp.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f8251c324f0bb713215b53b73d68330b15f87550
--- /dev/null
+++ b/third_party/DKM/assets/dkmv3_warp.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04c46e39d5ea68e9e116d4ae71c038a459beaf3eed89e8b7b87ccafd01d3bf85
+size 571179
diff --git a/third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1fd5b337e1d27ca569230c80902e762a503b75d0
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c902547181fc9b370fdd16272140be6803fe983aea978c68683db803ac70dd57
+size 906160
diff --git a/third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..464f56196138cd064182a33a4bf0171bb0df62e1
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65ce02bd248988b42363ccd257abaa9b99a00d569d2779597b36ba6c4da35021
+size 906160
diff --git a/third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..bd4ccc895c4aa9900a610468fb6976e07e2dc362
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6104feb8807a4ebdd1266160e67b3c507c550012f54c23292d0ebf99b88753f
+size 368192
diff --git a/third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..189a6ab6c4be9c58970b8ec93bcc1d91f1e33b17
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9600ba5c24d414f63728bf5ee7550a3b035d7c615461e357590890ae0e0f042e
+size 368192
diff --git a/third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..408847e2613016fd0e1a34cef376fb94925011b5
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de89e9ccf10515cc4196ba1e7172ec98b2fb92ff9f85d90db5df1af5b6503313
+size 167528
diff --git a/third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..f85a48ca4d1843959c61eefd2d200d0ffc31c87d
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94c97b57beb10411b3b98a1e88c9e1e2f9db51994dce04580d2b7cfc8919dab3
+size 167528
diff --git a/third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..29974c7f4dda288c93dbb3551342486581079d62
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f14a66dbbd7fa8f31756dd496bfabe4c3ea115c6914acad9365dd02e46ae674
+size 63909
diff --git a/third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..dcdb08c233530b2e9a88aedb1ad1bc496a3075cb
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfaee333beccd1da0d920777cdc8f17d584b21ba20f675b39222c5a205acf72a
+size 63909
diff --git a/third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..9e8ae03ed59aeba49b3223c72169d77e935a7e33
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b446ca3cc2073c8a3963cf68cc450ef2ebf73d2b956b1f5ae6b37621bc67cce4
+size 200371
diff --git a/third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..08c57fa1e419e7e1802b1ddfb00bd30a1c27d785
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df1969fd94032562b5e8d916467101a878168d586b795770e64c108bab250c9e
+size 200371
diff --git a/third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..16b72b06ae0afd00af2ec626d9d492107ee64f1b
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37cbafa0b0f981f5d69aba202ddd37c5892bda0fa13d053a3ad27d6ddad51c16
+size 642823
diff --git a/third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..bbef8bff2ed410e893d34423941b32be4977a794
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:962b3fadc7c94ea8e4a1bb5e168e72b0b6cc474ae56b0aee70ba4e517553fbcf
+size 642823
diff --git a/third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..fb1e5decaa554507d0f1eea82b1bba8bfa933b05
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50ed6b02dff2fa719e4e9ca216b4704b82cbbefd127355d3ba7120828407e723
+size 228647
diff --git a/third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..ef3248afb930afa46efc27054819fa313f6bf286
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:edc97129d000f0478020495f646f2fa7667408247ccd11054e02efbbb38d1444
+size 228647
diff --git a/third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz b/third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..b092500a792cd274bc1e1b91f488d237b01ce3b5
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04b0b6c6adff812e12b66476f7ca2a6ed2564cdd8208ec0c775f7b922f160103
+size 177063
diff --git a/third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz b/third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..24c4448c682c2f3c103dd5d4813784dadebd10fa
--- /dev/null
+++ b/third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae931f8cac1b2168f699c70efe42c215eaff27d3f0617d59afb3db183c9b1848
+size 177063
diff --git a/third_party/DKM/assets/mount_rushmore.mp4 b/third_party/DKM/assets/mount_rushmore.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3beebd798c74749f2034e73f068d919d307d46b9
Binary files /dev/null and b/third_party/DKM/assets/mount_rushmore.mp4 differ
diff --git a/third_party/DKM/assets/sacre_coeur_A.jpg b/third_party/DKM/assets/sacre_coeur_A.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6e441dad34cf13d8a29d7c6a1519f4263c40058c
--- /dev/null
+++ b/third_party/DKM/assets/sacre_coeur_A.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90d9c5f5a4d76425624989215120fba6f2899190a1d5654b88fa380c64cf6b2c
+size 117985
diff --git a/third_party/DKM/assets/sacre_coeur_B.jpg b/third_party/DKM/assets/sacre_coeur_B.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..27a239a8fa7581d909104872754ecda79422e7b6
--- /dev/null
+++ b/third_party/DKM/assets/sacre_coeur_B.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f1eb9bdd4d80e480f672d6a729689ac77f9fd5c8deb90f59b377590f3ca4799
+size 152515
diff --git a/third_party/DKM/data/.gitignore b/third_party/DKM/data/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3
--- /dev/null
+++ b/third_party/DKM/data/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
\ No newline at end of file
diff --git a/third_party/DKM/demo/.gitignore b/third_party/DKM/demo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..76ce7fcf6f600a9db91639ce9e445bec98ff1671
--- /dev/null
+++ b/third_party/DKM/demo/.gitignore
@@ -0,0 +1 @@
+*.jpg
diff --git a/third_party/DKM/demo/demo_fundamental.py b/third_party/DKM/demo/demo_fundamental.py
new file mode 100644
index 0000000000000000000000000000000000000000..e19766d5d3ce1abf0d18483cbbce71b2696983be
--- /dev/null
+++ b/third_party/DKM/demo/demo_fundamental.py
@@ -0,0 +1,37 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from dkm.utils.utils import tensor_to_pil
+import cv2
+from dkm import DKMv3_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+
+ # Create model
+ dkm_model = DKMv3_outdoor(device=device)
+
+
+ W_A, H_A = Image.open(im1_path).size
+ W_B, H_B = Image.open(im2_path).size
+
+ # Match
+ warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
+ # Sample matches for estimation
+ matches, certainty = dkm_model.sample(warp, certainty)
+ kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ F, mask = cv2.findFundamentalMat(
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+ )
+ # TODO: some better visualization
\ No newline at end of file
diff --git a/third_party/DKM/demo/demo_match.py b/third_party/DKM/demo/demo_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb901894d8654a884819162d3b9bb8094529e034
--- /dev/null
+++ b/third_party/DKM/demo/demo_match.py
@@ -0,0 +1,48 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from dkm.utils.utils import tensor_to_pil
+
+from dkm import DKMv3_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+ parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+ save_path = args.save_path
+
+ # Create model
+ dkm_model = DKMv3_outdoor(device=device)
+
+ H, W = 864, 1152
+
+ im1 = Image.open(im1_path).resize((W, H))
+ im2 = Image.open(im2_path).resize((W, H))
+
+ # Match
+ warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
+ dkm_model.sample(warp, certainty)
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+ im2_transfer_rgb = F.grid_sample(
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+ )[0]
+ im1_transfer_rgb = F.grid_sample(
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+ )[0]
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+ white_im = torch.ones((H,2*W),device=device)
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
diff --git a/third_party/DKM/dkm/__init__.py b/third_party/DKM/dkm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b47632780acc7762bcccc348e2025fe99f3726
--- /dev/null
+++ b/third_party/DKM/dkm/__init__.py
@@ -0,0 +1,4 @@
+from .models import (
+ DKMv3_outdoor,
+ DKMv3_indoor,
+ )
diff --git a/third_party/DKM/dkm/benchmarks/__init__.py b/third_party/DKM/dkm/benchmarks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..57643fd314a2301138aecdc804a5877d0ce9274e
--- /dev/null
+++ b/third_party/DKM/dkm/benchmarks/__init__.py
@@ -0,0 +1,4 @@
+from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
+from .scannet_benchmark import ScanNetBenchmark
+from .megadepth1500_benchmark import Megadepth1500Benchmark
+from .megadepth_dense_benchmark import MegadepthDenseBenchmark
diff --git a/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py b/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..079622fdaf77c75aeadd675629f2512c45d04c2d
--- /dev/null
+++ b/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py
@@ -0,0 +1,100 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+import torch
+from tqdm import tqdm
+
+from dkm.utils import *
+
+
+class HpatchesDenseBenchmark:
+ """WARNING: HPATCHES grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+ def __init__(self, dataset_path) -> None:
+ seqs_dir = "hpatches-sequences-release"
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
+ self.seq_names = sorted(os.listdir(self.seqs_path))
+
+ def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
+ # Get matches in output format on the grid [0, n] where the center of the top-left coordinate is [0.5, 0.5]
+ offset = (
+ 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0]
+ )
+ query_coords = (
+ torch.stack(
+ (
+ wq * (query_coords[..., 0] + 1) / 2,
+ hq * (query_coords[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ query_to_support = (
+ torch.stack(
+ (
+ wsup * (query_to_support[..., 0] + 1) / 2,
+ hsup * (query_to_support[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ return query_coords, query_to_support
+
+ def inside_image(self, x, w, h):
+ return torch.logical_and(
+ x[:, 0] < (w - 1),
+ torch.logical_and(x[:, 1] < (h - 1), (x > 0).prod(dim=-1)),
+ )
+
+ def benchmark(self, model):
+ use_cuda = torch.cuda.is_available()
+ device = torch.device("cuda:0" if use_cuda else "cpu")
+ aepes = []
+ pcks = []
+ for seq_idx, seq_name in tqdm(
+ enumerate(self.seq_names), total=len(self.seq_names)
+ ):
+ if seq_name[0] == "i":
+ continue
+ im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+ im1 = Image.open(im1_path)
+ w1, h1 = im1.size
+ for im_idx in range(2, 7):
+ im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+ im2 = Image.open(im2_path)
+ w2, h2 = im2.size
+ matches, certainty = model.match(im2, im1, do_pred_in_og_res=True)
+ matches, certainty = matches.reshape(-1, 4), certainty.reshape(-1)
+ inv_homography = torch.from_numpy(
+ np.loadtxt(
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+ )
+ ).to(device)
+ homography = torch.linalg.inv(inv_homography)
+ pos_a, pos_b = self.convert_coordinates(
+ matches[:, :2], matches[:, 2:], w2, h2, w1, h1
+ )
+ pos_a, pos_b = pos_a.double(), pos_b.double()
+ pos_a_h = torch.cat(
+ [pos_a, torch.ones([pos_a.shape[0], 1], device=device)], dim=1
+ )
+ pos_b_proj_h = (homography @ pos_a_h.t()).t()
+ pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
+ mask = self.inside_image(pos_b_proj, w1, h1)
+ residual = pos_b - pos_b_proj
+ dist = (residual**2).sum(dim=1).sqrt()[mask]
+ aepes.append(torch.mean(dist).item())
+ pck1 = (dist < 1.0).float().mean().item()
+ pck3 = (dist < 3.0).float().mean().item()
+ pck5 = (dist < 5.0).float().mean().item()
+ pcks.append([pck1, pck3, pck5])
+ m_pcks = np.mean(np.array(pcks), axis=0)
+ return {
+ "hp_pck1": m_pcks[0],
+ "hp_pck3": m_pcks[1],
+ "hp_pck5": m_pcks[2],
+ }
diff --git a/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py b/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..781a291f0c358cbd435790b0a639f2a2510145b2
--- /dev/null
+++ b/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py
@@ -0,0 +1,119 @@
+import pickle
+import h5py
+import numpy as np
+import torch
+from dkm.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class Yfcc100mBenchmark:
+ def __init__(self, data_root="data/yfcc100m_test") -> None:
+ self.scenes = [
+ "buckingham_palace",
+ "notre_dame_front_facade",
+ "reichstag",
+ "sacre_coeur",
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, model, r=2):
+ model.train(False)
+ with torch.no_grad():
+ data_root = self.data_root
+ meta_info = open(
+ f"{data_root}/yfcc_test_pairs_with_gt.txt", "r"
+ ).readlines()
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ for scene_ind in range(len(self.scenes)):
+ scene = self.scenes[scene_ind]
+ pairs = np.array(
+ pickle.load(
+ open(f"{data_root}/pairs/{scene}-te-1000-pairs.pkl", "rb")
+ )
+ )
+ scene_dir = f"{data_root}/yfcc100m/{scene}/test/"
+ calibs = open(scene_dir + "calibration.txt", "r").read().split("\n")
+ images = open(scene_dir + "images.txt", "r").read().split("\n")
+ pair_inds = np.random.choice(
+ range(len(pairs)), size=len(pairs), replace=False
+ )
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind]
+ params = meta_info[1000 * scene_ind + pairind].split()
+ rot1, rot2 = int(params[2]), int(params[3])
+ calib1 = h5py.File(scene_dir + calibs[idx1], "r")
+ K1, R1, t1, _, _ = get_pose(calib1)
+ calib2 = h5py.File(scene_dir + calibs[idx2], "r")
+ K2, R2, t2, _, _ = get_pose(calib2)
+
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ im1 = images[idx1]
+ im2 = images[idx2]
+ im1 = Image.open(scene_dir + im1).rotate(rot1 * 90, expand=True)
+ w1, h1 = im1.size
+ im2 = Image.open(scene_dir + im2).rotate(rot2 * 90, expand=True)
+ w2, h2 = im2.size
+ K1 = rotate_intrinsic(K1, rot1)
+ K2 = rotate_intrinsic(K2, rot2)
+
+ dense_matches, dense_certainty = model.match(im1, im2)
+ dense_certainty = dense_certainty ** (1 / r)
+ sparse_matches, sparse_confidence = model.sample(
+ dense_matches, dense_certainty, 10000
+ )
+ scale1 = 480 / min(w1, h1)
+ scale2 = 480 / min(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1 = K1 * scale1
+ K2 = K2 * scale2
+
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = np.stack(
+ (w1 * kpts1[:, 0] / 2, h1 * kpts1[:, 1] / 2), axis=-1
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = np.stack(
+ (w2 * kpts2[:, 0] / 2, h2 * kpts2[:, 1] / 2), axis=-1
+ )
+ try:
+ threshold = 1.0
+ norm_threshold = threshold / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
+ )
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1[:2, :2],
+ K2[:2, :2],
+ norm_threshold,
+ conf=0.9999999,
+ )
+ T1_to_2 = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2, R, t)
+ e_pose = max(e_t, e_R)
+ except:
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3febe5ca9e3a683bc7122cec635c4f54b66f7c
--- /dev/null
+++ b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
@@ -0,0 +1,114 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+from tqdm import tqdm
+from dkm.utils import pose_auc
+import cv2
+
+
+class HpatchesHomogBenchmark:
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+ def __init__(self, dataset_path) -> None:
+ seqs_dir = "hpatches-sequences-release"
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
+ self.seq_names = sorted(os.listdir(self.seqs_path))
+ # Ignore seqs is same as LoFTR.
+ self.ignore_seqs = set(
+ [
+ "i_contruction",
+ "i_crownnight",
+ "i_dc",
+ "i_pencils",
+ "i_whitebuilding",
+ "v_artisans",
+ "v_astronautis",
+ "v_talent",
+ ]
+ )
+
+ def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
+ query_coords = (
+ np.stack(
+ (
+ wq * (query_coords[..., 0] + 1) / 2,
+ hq * (query_coords[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ query_to_support = (
+ np.stack(
+ (
+ wsup * (query_to_support[..., 0] + 1) / 2,
+ hsup * (query_to_support[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ return query_coords, query_to_support
+
+ def benchmark(self, model, model_name = None):
+ n_matches = []
+ homog_dists = []
+ for seq_idx, seq_name in tqdm(
+ enumerate(self.seq_names), total=len(self.seq_names)
+ ):
+ if seq_name in self.ignore_seqs:
+ continue
+ im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+ im1 = Image.open(im1_path)
+ w1, h1 = im1.size
+ for im_idx in range(2, 7):
+ im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+ im2 = Image.open(im2_path)
+ w2, h2 = im2.size
+ H = np.loadtxt(
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+ )
+ dense_matches, dense_certainty = model.match(
+ im1_path, im2_path
+ )
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
+ pos_a, pos_b = self.convert_coordinates(
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
+ )
+ try:
+ H_pred, inliers = cv2.findHomography(
+ pos_a,
+ pos_b,
+ method = cv2.RANSAC,
+ confidence = 0.99999,
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
+ )
+ except:
+ H_pred = None
+ if H_pred is None:
+ H_pred = np.zeros((3, 3))
+ H_pred[2, 2] = 1.0
+ corners = np.array(
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
+ )
+ real_warped_corners = np.dot(corners, np.transpose(H))
+ real_warped_corners = (
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
+ )
+ warped_corners = np.dot(corners, np.transpose(H_pred))
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
+ mean_dist = np.mean(
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
+ ) / (min(w2, h2) / 480.0)
+ homog_dists.append(mean_dist)
+ n_matches = np.array(n_matches)
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ auc = pose_auc(np.array(homog_dists), thresholds)
+ return {
+ "hpatches_homog_auc_3": auc[2],
+ "hpatches_homog_auc_5": auc[4],
+ "hpatches_homog_auc_10": auc[9],
+ }
diff --git a/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b1193745ff18d239165aeb3376642fb17033874
--- /dev/null
+++ b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
@@ -0,0 +1,124 @@
+import numpy as np
+import torch
+from dkm.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+
+class Megadepth1500Benchmark:
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+ if scene_names is None:
+ self.scene_names = [
+ "0015_0.1_0.3.npz",
+ "0015_0.3_0.5.npz",
+ "0022_0.1_0.3.npz",
+ "0022_0.3_0.5.npz",
+ "0022_0.5_0.7.npz",
+ ]
+ else:
+ self.scene_names = scene_names
+ self.scenes = [
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
+ for scene in self.scene_names
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, model):
+ with torch.no_grad():
+ data_root = self.data_root
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ for scene_ind in range(len(self.scenes)):
+ scene = self.scenes[scene_ind]
+ pairs = scene["pair_infos"]
+ intrinsics = scene["intrinsics"]
+ poses = scene["poses"]
+ im_paths = scene["image_paths"]
+ pair_inds = range(len(pairs))
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind][0]
+ K1 = intrinsics[idx1].copy()
+ T1 = poses[idx1].copy()
+ R1, t1 = T1[:3, :3], T1[:3, 3]
+ K2 = intrinsics[idx2].copy()
+ T2 = poses[idx2].copy()
+ R2, t2 = T2[:3, :3], T2[:3, 3]
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ im1_path = f"{data_root}/{im_paths[idx1]}"
+ im2_path = f"{data_root}/{im_paths[idx2]}"
+ im1 = Image.open(im1_path)
+ w1, h1 = im1.size
+ im2 = Image.open(im2_path)
+ w2, h2 = im2.size
+ scale1 = 1200 / max(w1, h1)
+ scale2 = 1200 / max(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1[:2] = K1[:2] * scale1
+ K2[:2] = K2[:2] * scale2
+ dense_matches, dense_certainty = model.match(im1_path, im2_path)
+ sparse_matches,_ = model.sample(
+ dense_matches, dense_certainty, 5000
+ )
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = (
+ torch.stack(
+ (
+ w1 * (kpts1[:, 0] + 1) / 2,
+ h1 * (kpts1[:, 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = (
+ torch.stack(
+ (
+ w2 * (kpts2[:, 0] + 1) / 2,
+ h2 * (kpts2[:, 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ norm_threshold = 0.5 / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1.cpu().numpy(),
+ kpts2.cpu().numpy(),
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b370644497efd62563105e68e692e10ff339669
--- /dev/null
+++ b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
@@ -0,0 +1,86 @@
+import torch
+import numpy as np
+import tqdm
+from dkm.datasets import MegadepthBuilder
+from dkm.utils import warp_kpts
+from torch.utils.data import ConcatDataset
+
+
+class MegadepthDenseBenchmark:
+ def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None:
+ mega = MegadepthBuilder(data_root=data_root)
+ self.dataset = ConcatDataset(
+ mega.build_scenes(split="test_loftr", ht=h, wt=w)
+ ) # fixed resolution of 384,512
+ self.num_samples = num_samples
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.device = device
+
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
+ b, h1, w1, d = dense_matches.shape
+ with torch.no_grad():
+ x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
+ # x1 = torch.stack((2*x1[...,0]/w1-1,2*x1[...,1]/h1-1),dim=-1)
+ mask, x2 = warp_kpts(
+ x1.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ )
+ x2 = torch.stack(
+ (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
+ )
+ prob = mask.float().reshape(b, h1, w1)
+ x2_hat = dense_matches[..., 2:]
+ x2_hat = torch.stack(
+ (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
+ )
+ gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
+ gd = gd[prob == 1]
+ pck_1 = (gd < 1.0).float().mean()
+ pck_3 = (gd < 3.0).float().mean()
+ pck_5 = (gd < 5.0).float().mean()
+ gd = gd.mean()
+ return gd, pck_1, pck_3, pck_5
+
+ def benchmark(self, model, batch_size=8):
+ model.train(False)
+ with torch.no_grad():
+ gd_tot = 0.0
+ pck_1_tot = 0.0
+ pck_3_tot = 0.0
+ pck_5_tot = 0.0
+ sampler = torch.utils.data.WeightedRandomSampler(
+ torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+ )
+ dataloader = torch.utils.data.DataLoader(
+ self.dataset, batch_size=8, num_workers=batch_size, sampler=sampler
+ )
+ for data in tqdm.tqdm(dataloader):
+ im1, im2, depth1, depth2, T_1to2, K1, K2 = (
+ data["query"],
+ data["support"],
+ data["query_depth"].to(self.device),
+ data["support_depth"].to(self.device),
+ data["T_1to2"].to(self.device),
+ data["K1"].to(self.device),
+ data["K2"].to(self.device),
+ )
+ matches, certainty = model.match(im1, im2, batched=True)
+ gd, pck_1, pck_3, pck_5 = self.geometric_dist(
+ depth1, depth2, T_1to2, K1, K2, matches
+ )
+ gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
+ gd_tot + gd,
+ pck_1_tot + pck_1,
+ pck_3_tot + pck_3,
+ pck_5_tot + pck_5,
+ )
+ return {
+ "mega_pck_1": pck_1_tot.item() / len(dataloader),
+ "mega_pck_3": pck_3_tot.item() / len(dataloader),
+ "mega_pck_5": pck_5_tot.item() / len(dataloader),
+ }
diff --git a/third_party/DKM/dkm/benchmarks/scannet_benchmark.py b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca938cb462c351845ce035f8be0714cf81214452
--- /dev/null
+++ b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py
@@ -0,0 +1,143 @@
+import os.path as osp
+import numpy as np
+import torch
+from dkm.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class ScanNetBenchmark:
+ def __init__(self, data_root="data/scannet") -> None:
+ self.data_root = data_root
+
+ def benchmark(self, model, model_name = None):
+ model.train(False)
+ with torch.no_grad():
+ data_root = self.data_root
+ tmp = np.load(osp.join(data_root, "test.npz"))
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ pair_inds = np.random.choice(
+ range(len(pairs)), size=len(pairs), replace=False
+ )
+ for pairind in tqdm(pair_inds, smoothing=0.9):
+ scene = pairs[pairind]
+ scene_name = f"scene0{scene[0]}_00"
+ im1_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[2]}.jpg",
+ )
+ im1 = Image.open(im1_path)
+ im2_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[3]}.jpg",
+ )
+ im2 = Image.open(im2_path)
+ T_gt = rel_pose[pairind].reshape(3, 4)
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
+ K = np.stack(
+ [
+ np.array([float(i) for i in r.split()])
+ for r in open(
+ osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "intrinsic",
+ "intrinsic_color.txt",
+ ),
+ "r",
+ )
+ .read()
+ .split("\n")
+ if r
+ ]
+ )
+ w1, h1 = im1.size
+ w2, h2 = im2.size
+ K1 = K.copy()
+ K2 = K.copy()
+ dense_matches, dense_certainty = model.match(im1_path, im2_path)
+ sparse_matches, sparse_certainty = model.sample(
+ dense_matches, dense_certainty, 5000
+ )
+ scale1 = 480 / min(w1, h1)
+ scale2 = 480 / min(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1 = K1 * scale1
+ K2 = K2 * scale2
+
+ offset = 0.5
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = (
+ np.stack(
+ (
+ w1 * (kpts1[:, 0] + 1) / 2 - offset,
+ h1 * (kpts1[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = (
+ np.stack(
+ (
+ w2 * (kpts2[:, 0] + 1) / 2 - offset,
+ h2 * (kpts2[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ norm_threshold = 0.5 / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/third_party/DKM/dkm/checkpointing/__init__.py b/third_party/DKM/dkm/checkpointing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577
--- /dev/null
+++ b/third_party/DKM/dkm/checkpointing/__init__.py
@@ -0,0 +1 @@
+from .checkpoint import CheckPoint
diff --git a/third_party/DKM/dkm/checkpointing/checkpoint.py b/third_party/DKM/dkm/checkpointing/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..715eeb587ebb87ed0d1bcf9940e048adbe35cde2
--- /dev/null
+++ b/third_party/DKM/dkm/checkpointing/checkpoint.py
@@ -0,0 +1,31 @@
+import os
+import torch
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+from loguru import logger
+
+
+class CheckPoint:
+ def __init__(self, dir=None, name="tmp"):
+ self.name = name
+ self.dir = dir
+ os.makedirs(self.dir, exist_ok=True)
+
+ def __call__(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ assert model is not None
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
+ model = model.module
+ states = {
+ "model": model.state_dict(),
+ "n": n,
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(states, self.dir + self.name + f"_latest.pth")
+ logger.info(f"Saved states {list(states.keys())}, at step {n}")
diff --git a/third_party/DKM/dkm/datasets/__init__.py b/third_party/DKM/dkm/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b81083212edaf345c30f0cb1116c5f9de284ce6
--- /dev/null
+++ b/third_party/DKM/dkm/datasets/__init__.py
@@ -0,0 +1 @@
+from .megadepth import MegadepthBuilder
diff --git a/third_party/DKM/dkm/datasets/megadepth.py b/third_party/DKM/dkm/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..c580607e910ce1926b7711b5473aa82b20865369
--- /dev/null
+++ b/third_party/DKM/dkm/datasets/megadepth.py
@@ -0,0 +1,177 @@
+import os
+import random
+from PIL import Image
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader, ConcatDataset
+
+from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+import torchvision.transforms.functional as tvf
+from dkm.utils.transforms import GeometricSequential
+import kornia.augmentation as K
+
+
+class MegadepthScene:
+ def __init__(
+ self,
+ data_root,
+ scene_info,
+ ht=384,
+ wt=512,
+ min_overlap=0.0,
+ shake_t=0,
+ rot_prob=0.0,
+ normalize=True,
+ ) -> None:
+ self.data_root = data_root
+ self.image_paths = scene_info["image_paths"]
+ self.depth_paths = scene_info["depth_paths"]
+ self.intrinsics = scene_info["intrinsics"]
+ self.poses = scene_info["poses"]
+ self.pairs = scene_info["pairs"]
+ self.overlaps = scene_info["overlaps"]
+ threshold = self.overlaps > min_overlap
+ self.pairs = self.pairs[threshold]
+ self.overlaps = self.overlaps[threshold]
+ if len(self.pairs) > 100000:
+ pairinds = np.random.choice(
+ np.arange(0, len(self.pairs)), 100000, replace=False
+ )
+ self.pairs = self.pairs[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ # counts, bins = np.histogram(self.overlaps,20)
+ # print(counts)
+ self.im_transform_ops = get_tuple_transform_ops(
+ resize=(ht, wt), normalize=normalize
+ )
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
+ resize=(ht, wt), normalize=False
+ )
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+
+ def load_im(self, im_ref, crop=None):
+ im = Image.open(im_ref)
+ return im
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
+ return torch.from_numpy(depth)
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+ return sK @ K
+
+ def rand_shake(self, *things):
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
+ return [
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
+ for thing in things
+ ], t
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ idx1, idx2 = self.pairs[pair_idx]
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T1 = self.poses[idx1]
+ T2 = self.poses[idx2]
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+ :4, :4
+ ] # (4, 4)
+
+ # Load positive pair data
+ im1, im2 = self.image_paths[idx1], self.image_paths[idx2]
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
+ im_src_ref = os.path.join(self.data_root, im1)
+ im_pos_ref = os.path.join(self.data_root, im2)
+ depth_src_ref = os.path.join(self.data_root, depth1)
+ depth_pos_ref = os.path.join(self.data_root, depth2)
+ # return torch.randn((1000,1000))
+ im_src = self.load_im(im_src_ref)
+ im_pos = self.load_im(im_pos_ref)
+ depth_src = self.load_depth(depth_src_ref)
+ depth_pos = self.load_depth(depth_pos_ref)
+
+ # Recompute camera intrinsic matrix due to the resize
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
+ # Process images
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
+ depth_src, depth_pos = self.depth_transform_ops(
+ (depth_src[None, None], depth_pos[None, None])
+ )
+ [im_src, im_pos, depth_src, depth_pos], t = self.rand_shake(
+ im_src, im_pos, depth_src, depth_pos
+ )
+ im_src, Hq = self.H_generator(im_src[None])
+ depth_src = self.H_generator.apply_transform(depth_src, Hq)
+ K1[:2, 2] += t
+ K2[:2, 2] += t
+ K1 = Hq[0] @ K1
+ data_dict = {
+ "query": im_src[0],
+ "query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+ "support": im_pos,
+ "support_identifier": self.image_paths[idx2]
+ .split("/")[-1]
+ .split(".jpg")[0],
+ "query_depth": depth_src[0, 0],
+ "support_depth": depth_pos[0, 0],
+ "K1": K1,
+ "K2": K2,
+ "T_1to2": T_1to2,
+ }
+ return data_dict
+
+
+class MegadepthBuilder:
+ def __init__(self, data_root="data/megadepth") -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
+ self.all_scenes = os.listdir(self.scene_info_root)
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
+
+ def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
+ if split == "train":
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
+ elif split == "train_loftr":
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
+ elif split == "test":
+ scene_names = self.test_scenes
+ elif split == "test_loftr":
+ scene_names = self.test_scenes_loftr
+ else:
+ raise ValueError(f"Split {split} not available")
+ scenes = []
+ for scene_name in scene_names:
+ scene_info = np.load(
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+ ).item()
+ scenes.append(
+ MegadepthScene(
+ self.data_root, scene_info, min_overlap=min_overlap, **kwargs
+ )
+ )
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=0.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
+ return ws
+
+
+if __name__ == "__main__":
+ mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train"))
+ mega_test[0]
diff --git a/third_party/DKM/dkm/datasets/scannet.py b/third_party/DKM/dkm/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac39b41480f7585c4755cc30e0677ef74ed5e0c
--- /dev/null
+++ b/third_party/DKM/dkm/datasets/scannet.py
@@ -0,0 +1,151 @@
+import os
+import random
+from PIL import Image
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset)
+
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+import os.path as osp
+import matplotlib.pyplot as plt
+from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+from dkm.utils.transforms import GeometricSequential
+
+from tqdm import tqdm
+
+class ScanNetScene:
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None:
+ self.scene_root = osp.join(data_root,"scans","scans_train")
+ self.data_names = scene_info['name']
+ self.overlaps = scene_info['score']
+ # Only sample 10s
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+ self.overlaps = self.overlaps[valid]
+ self.data_names = self.data_names[valid]
+ if len(self.data_names) > 10000:
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+ self.data_names = self.data_names[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+
+ def load_im(self, im_ref, crop=None):
+ im = Image.open(im_ref)
+ return im
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1]])
+ return sK@K
+
+ def read_scannet_pose(self,path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = np.linalg.inv(cam2world)
+ return world2cam
+
+
+ def read_scannet_intrinsic(self,path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return intrinsic[:-1, :-1]
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ data_name = self.data_names[pair_idx]
+ scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the intrinsic of depthmap
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
+ scene_name,
+ 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+ # read and compute relative poses
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_1}.txt'))
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_2}.txt'))
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
+
+ # Load positive pair data
+ im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
+ im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
+ depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
+ depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+
+ im_src = self.load_im(im_src_ref)
+ im_pos = self.load_im(im_pos_ref)
+ depth_src = self.load_depth(depth_src_ref)
+ depth_pos = self.load_depth(depth_pos_ref)
+
+ # Recompute camera intrinsic matrix due to the resize
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
+ # Process images
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
+ depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None]))
+
+ data_dict = {'query': im_src,
+ 'support': im_pos,
+ 'query_depth': depth_src[0,0],
+ 'support_depth': depth_pos[0,0],
+ 'K1': K1,
+ 'K2': K2,
+ 'T_1to2':T_1to2,
+ }
+ return data_dict
+
+
+class ScanNetBuilder:
+ def __init__(self, data_root = 'data/scannet') -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
+ self.all_scenes = os.listdir(self.scene_info_root)
+
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+ # Note: split doesn't matter here as we always use same scannet_train scenes
+ scene_names = self.all_scenes
+ scenes = []
+ for scene_name in tqdm(scene_names):
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+ return ws
+
+
+if __name__ == "__main__":
+ mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train'))
+ mega_test[0]
\ No newline at end of file
diff --git a/third_party/DKM/dkm/losses/__init__.py b/third_party/DKM/dkm/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..71914f50d891079d204a07c57367159888f892de
--- /dev/null
+++ b/third_party/DKM/dkm/losses/__init__.py
@@ -0,0 +1 @@
+from .depth_match_regression_loss import DepthRegressionLoss
diff --git a/third_party/DKM/dkm/losses/depth_match_regression_loss.py b/third_party/DKM/dkm/losses/depth_match_regression_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..80da70347b4b4addc721e2a14ed489f8683fd48a
--- /dev/null
+++ b/third_party/DKM/dkm/losses/depth_match_regression_loss.py
@@ -0,0 +1,128 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dkm.utils.utils import warp_kpts
+
+
+class DepthRegressionLoss(nn.Module):
+ def __init__(
+ self,
+ robust=True,
+ center_coords=False,
+ scale_normalize=False,
+ ce_weight=0.01,
+ local_loss=True,
+ local_dist=4.0,
+ local_largest_scale=8,
+ ):
+ super().__init__()
+ self.robust = robust # measured in pixels
+ self.center_coords = center_coords
+ self.scale_normalize = scale_normalize
+ self.ce_weight = ce_weight
+ self.local_loss = local_loss
+ self.local_dist = local_dist
+ self.local_largest_scale = local_largest_scale
+
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
+ """[summary]
+
+ Args:
+ H ([type]): [description]
+ scale ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+ b, h1, w1, d = dense_matches.shape
+ with torch.no_grad():
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=dense_matches.device
+ )
+ for n in (b, h1, w1)
+ ]
+ )
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(b, h1 * w1, 2)
+ mask, x2 = warp_kpts(
+ x1_n.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ )
+ prob = mask.float().reshape(b, h1, w1)
+ gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1) # *scale?
+ return gd, prob
+
+ def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
+ """[summary]
+
+ Args:
+ dense_certainty ([type]): [description]
+ prob ([type]): [description]
+ eps ([type], optional): [description]. Defaults to 1e-8.
+
+ Returns:
+ [type]: [description]
+ """
+ smooth_prob = prob
+ ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], smooth_prob)
+ depth_loss = gd[prob > 0]
+ if not torch.any(prob > 0).item():
+ depth_loss = (gd * 0.0).mean() # Prevent issues where prob is 0 everywhere
+ return {
+ f"ce_loss_{scale}": ce_loss.mean(),
+ f"depth_loss_{scale}": depth_loss.mean(),
+ }
+
+ def forward(self, dense_corresps, batch):
+ """[summary]
+
+ Args:
+ out ([type]): [description]
+ batch ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+ scales = list(dense_corresps.keys())
+ tot_loss = 0.0
+ prev_gd = 0.0
+ for scale in scales:
+ dense_scale_corresps = dense_corresps[scale]
+ dense_scale_certainty, dense_scale_coords = (
+ dense_scale_corresps["dense_certainty"],
+ dense_scale_corresps["dense_flow"],
+ )
+ dense_scale_coords = rearrange(dense_scale_coords, "b d h w -> b h w d")
+ b, h, w, d = dense_scale_coords.shape
+ gd, prob = self.geometric_dist(
+ batch["query_depth"],
+ batch["support_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],
+ dense_scale_coords,
+ scale,
+ )
+ if (
+ scale <= self.local_largest_scale and self.local_loss
+ ): # Thought here is that fine matching loss should not be punished by coarse mistakes, but should identify wrong matching
+ prob = prob * (
+ F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[:, 0]
+ < (2 / 512) * (self.local_dist * scale)
+ )
+ depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
+ scale_loss = (
+ self.ce_weight * depth_losses[f"ce_loss_{scale}"]
+ + depth_losses[f"depth_loss_{scale}"]
+ ) # scale ce loss for coarser scales
+ if self.scale_normalize:
+ scale_loss = scale_loss * 1 / scale
+ tot_loss = tot_loss + scale_loss
+ prev_gd = gd.detach()
+ return tot_loss
diff --git a/third_party/DKM/dkm/models/__init__.py b/third_party/DKM/dkm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fc321ec70fd116beca23e94248cb6bbe771523
--- /dev/null
+++ b/third_party/DKM/dkm/models/__init__.py
@@ -0,0 +1,4 @@
+from .model_zoo import (
+ DKMv3_outdoor,
+ DKMv3_indoor,
+)
diff --git a/third_party/DKM/dkm/models/deprecated/build_model.py b/third_party/DKM/dkm/models/deprecated/build_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd28335f3e348ab6c90b26ba91b95e864b0bbbb9
--- /dev/null
+++ b/third_party/DKM/dkm/models/deprecated/build_model.py
@@ -0,0 +1,787 @@
+import torch
+import torch.nn as nn
+from dkm import *
+from .local_corr import LocalCorr
+from .corr_channels import NormedCorr
+from torchvision.models import resnet as tv_resnet
+
+dkm_pretrained_urls = {
+ "DKM": {
+ "mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth",
+ "mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth",
+ },
+ "DKMv2":{
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth",
+ "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth",
+ }
+}
+
+
+def DKM(pretrained=True, version="mega_synthetic", device=None):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": ConvRefiner(
+ 2 * 256,
+ 512,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": ConvRefiner(
+ 2 * 64,
+ 128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained),
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["DKM"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ displacement_emb = "linear"
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512+128,
+ 1024+128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=128,
+ ),
+ "8": ConvRefiner(
+ 2 * 512+64,
+ 1024+64,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=64,
+ ),
+ "4": ConvRefiner(
+ 2 * 256+32,
+ 512+32,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=32,
+ ),
+ "2": ConvRefiner(
+ 2 * 64+16,
+ 128+16,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=16,
+ ),
+ "1": ConvRefiner(
+ 2 * 3+6,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=6,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ if resolution == "low":
+ h, w = 384, 512
+ elif resolution == "high":
+ h, w = 480, 640
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained),
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device)
+ if pretrained:
+ try:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["DKMv2"][version]
+ )
+ except:
+ weights = torch.load(
+ dkm_pretrained_urls["DKMv2"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def local_corr(pretrained=True, version="mega_synthetic"):
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": LocalCorr(
+ 81,
+ 81 * 6,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": LocalCorr(
+ 81,
+ 81,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["local_corr"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def corr_channels(pretrained=True, version="mega_synthetic"):
+ h, w = 384, 512
+ gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim[0] + feat_dim, dfn_dim),
+ "16": RRB(gp_dim[1] + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": ConvRefiner(
+ 2 * 256,
+ 512,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": ConvRefiner(
+ 2 * 64,
+ 128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ gp32 = NormedCorr()
+ gp16 = NormedCorr()
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["corr_channels"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def baseline(pretrained=True, version="mega_synthetic"):
+ h, w = 384, 512
+ gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim[0] + feat_dim, dfn_dim),
+ "16": RRB(gp_dim[1] + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": LocalCorr(
+ 81,
+ 81 * 6,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": LocalCorr(
+ 81,
+ 81,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ gp32 = NormedCorr()
+ gp16 = NormedCorr()
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["baseline"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def linear(pretrained=True, version="mega_synthetic"):
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": ConvRefiner(
+ 2 * 256,
+ 512,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": ConvRefiner(
+ 2 * 64,
+ 128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "linear"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["linear"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
diff --git a/third_party/DKM/dkm/models/deprecated/corr_channels.py b/third_party/DKM/dkm/models/deprecated/corr_channels.py
new file mode 100644
index 0000000000000000000000000000000000000000..8713b0d8c7a0ce91da4d2105ba29097a4969a037
--- /dev/null
+++ b/third_party/DKM/dkm/models/deprecated/corr_channels.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class NormedCorrelationKernel(nn.Module): # similar to softmax kernel
+ def __init__(self):
+ super().__init__()
+
+ def __call__(self, x, y, eps=1e-6):
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
+ )
+ return c
+
+
+class NormedCorr(nn.Module):
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ self.corr = NormedCorrelationKernel()
+
+ def reshape(self, x):
+ return rearrange(x, "b d h w -> b (h w) d")
+
+ def forward(self, x, y, **kwargs):
+ b, c, h, w = y.shape
+ assert x.shape == y.shape
+ x, y = self.reshape(x), self.reshape(y)
+ corr_xy = self.corr(x, y)
+ corr_xy_flat = rearrange(corr_xy, "b (h w) c -> b c h w", h=h, w=w)
+ return corr_xy_flat
diff --git a/third_party/DKM/dkm/models/deprecated/local_corr.py b/third_party/DKM/dkm/models/deprecated/local_corr.py
new file mode 100644
index 0000000000000000000000000000000000000000..681fe4c0079561fa7a4c44e82a8879a4a27273a1
--- /dev/null
+++ b/third_party/DKM/dkm/models/deprecated/local_corr.py
@@ -0,0 +1,630 @@
+import torch
+import torch.nn.functional as F
+
+try:
+ import cupy
+except:
+ print("Cupy not found, local correlation will not work")
+import re
+from ..dkm import ConvRefiner
+
+
+class Stream:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if device == 'cuda':
+ stream = torch.cuda.current_stream(device=device).cuda_stream
+ else:
+ stream = None
+
+
+kernel_Correlation_rearrange = """
+ extern "C" __global__ void kernel_Correlation_rearrange(
+ const int n,
+ const float* input,
+ float* output
+ ) {
+ int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
+ if (intIndex >= n) {
+ return;
+ }
+ int intSample = blockIdx.z;
+ int intChannel = blockIdx.y;
+ float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
+ __syncthreads();
+ int intPaddedY = (intIndex / SIZE_3(input)) + 4;
+ int intPaddedX = (intIndex % SIZE_3(input)) + 4;
+ int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
+ output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue;
+ }
+"""
+
+kernel_Correlation_updateOutput = """
+ extern "C" __global__ void kernel_Correlation_updateOutput(
+ const int n,
+ const float* rbot0,
+ const float* rbot1,
+ float* top
+ ) {
+ extern __shared__ char patch_data_char[];
+ float *patch_data = (float *)patch_data_char;
+ // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
+ int x1 = blockIdx.x + 4;
+ int y1 = blockIdx.y + 4;
+ int item = blockIdx.z;
+ int ch_off = threadIdx.x;
+ // Load 3D patch into shared shared memory
+ for (int j = 0; j < 1; j++) { // HEIGHT
+ for (int i = 0; i < 1; i++) { // WIDTH
+ int ji_off = (j + i) * SIZE_3(rbot0);
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
+ int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
+ int idxPatchData = ji_off + ch;
+ patch_data[idxPatchData] = rbot0[idx1];
+ }
+ }
+ }
+ __syncthreads();
+ __shared__ float sum[32];
+ // Compute correlation
+ for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
+ sum[ch_off] = 0;
+ int s2o = top_channel % 9 - 4;
+ int s2p = top_channel / 9 - 4;
+ for (int j = 0; j < 1; j++) { // HEIGHT
+ for (int i = 0; i < 1; i++) { // WIDTH
+ int ji_off = (j + i) * SIZE_3(rbot0);
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
+ int x2 = x1 + s2o;
+ int y2 = y1 + s2p;
+ int idxPatchData = ji_off + ch;
+ int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
+ sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
+ }
+ }
+ }
+ __syncthreads();
+ if (ch_off == 0) {
+ float total_sum = 0;
+ for (int idx = 0; idx < 32; idx++) {
+ total_sum += sum[idx];
+ }
+ const int sumelems = SIZE_3(rbot0);
+ const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
+ top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
+ }
+ }
+ }
+"""
+
+kernel_Correlation_updateGradFirst = """
+ #define ROUND_OFF 50000
+ extern "C" __global__ void kernel_Correlation_updateGradFirst(
+ const int n,
+ const int intSample,
+ const float* rbot0,
+ const float* rbot1,
+ const float* gradOutput,
+ float* gradFirst,
+ float* gradSecond
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ int n = intIndex % SIZE_1(gradFirst); // channels
+ int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
+ int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = round_off;
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
+ int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
+ int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
+ // Same here:
+ int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
+ int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
+ float sum = 0;
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
+ xmin = max(0,xmin);
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
+ ymin = max(0,ymin);
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
+ for (int p = -4; p <= 4; p++) {
+ for (int o = -4; o <= 4; o++) {
+ // Get rbot1 data:
+ int s2o = o;
+ int s2p = p;
+ int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
+ float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
+ // Index offset for gradOutput in following loops:
+ int op = (p+4) * 9 + (o+4); // index[o,p]
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
+ for (int y = ymin; y <= ymax; y++) {
+ for (int x = xmin; x <= xmax; x++) {
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
+ sum += gradOutput[idxgradOutput] * bot1tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = SIZE_1(gradFirst);
+ const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
+ gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
+ } }
+"""
+
+kernel_Correlation_updateGradSecond = """
+ #define ROUND_OFF 50000
+ extern "C" __global__ void kernel_Correlation_updateGradSecond(
+ const int n,
+ const int intSample,
+ const float* rbot0,
+ const float* rbot1,
+ const float* gradOutput,
+ float* gradFirst,
+ float* gradSecond
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ int n = intIndex % SIZE_1(gradSecond); // channels
+ int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
+ int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = round_off;
+ float sum = 0;
+ for (int p = -4; p <= 4; p++) {
+ for (int o = -4; o <= 4; o++) {
+ int s2o = o;
+ int s2p = p;
+ //Get X,Y ranges and clamp
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
+ int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
+ int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
+ // Same here:
+ int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
+ int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
+ xmin = max(0,xmin);
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
+ ymin = max(0,ymin);
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
+ // Get rbot0 data:
+ int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
+ float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
+ // Index offset for gradOutput in following loops:
+ int op = (p+4) * 9 + (o+4); // index[o,p]
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
+ for (int y = ymin; y <= ymax; y++) {
+ for (int x = xmin; x <= xmax; x++) {
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
+ sum += gradOutput[idxgradOutput] * bot0tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = SIZE_1(gradSecond);
+ const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
+ gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
+ } }
+"""
+
+
+def cupy_kernel(strFunction, objectVariables):
+ strKernel = globals()[strFunction]
+
+ while True:
+ objectMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
+
+ if objectMatch is None:
+ break
+
+ intArg = int(objectMatch.group(2))
+
+ strTensor = objectMatch.group(4)
+ intSizes = objectVariables[strTensor].size()
+
+ strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
+
+ while True:
+ objectMatch = re.search(r"(VALUE_)([0-4])(\()([^\)]+)(\))", strKernel)
+
+ if objectMatch is None:
+ break
+
+ intArgs = int(objectMatch.group(2))
+ strArgs = objectMatch.group(4).split(",")
+
+ strTensor = strArgs[0]
+ intStrides = objectVariables[strTensor].stride()
+ strIndex = [
+ "(("
+ + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
+ + ")*"
+ + str(intStrides[intArg])
+ + ")"
+ for intArg in range(intArgs)
+ ]
+
+ strKernel = strKernel.replace(
+ objectMatch.group(0), strTensor + "[" + str.join("+", strIndex) + "]"
+ )
+
+ return strKernel
+
+
+try:
+
+ @cupy.memoize(for_each_device=True)
+ def cupy_launch(strFunction, strKernel):
+ return cupy.RawModule(code=strKernel).get_function(strFunction)
+
+except:
+ pass
+
+
+class _FunctionCorrelation(torch.autograd.Function):
+ @staticmethod
+ def forward(self, first, second):
+ rbot0 = first.new_zeros(
+ [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
+ )
+ rbot1 = first.new_zeros(
+ [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
+ )
+
+ self.save_for_backward(first, second, rbot0, rbot1)
+
+ first = first.contiguous()
+ second = second.contiguous()
+
+ output = first.new_zeros([first.size(0), 81, first.size(2), first.size(3)])
+
+ if first.is_cuda == True:
+ n = first.size(2) * first.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange", {"input": first, "output": rbot0}
+ ),
+ )(
+ grid=tuple([int((n + 16 - 1) / 16), first.size(1), first.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, first.data_ptr(), rbot0.data_ptr()],
+ stream=Stream,
+ )
+
+ n = second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange", {"input": second, "output": rbot1}
+ ),
+ )(
+ grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, second.data_ptr(), rbot1.data_ptr()],
+ stream=Stream,
+ )
+
+ n = output.size(1) * output.size(2) * output.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateOutput",
+ cupy_kernel(
+ "kernel_Correlation_updateOutput",
+ {"rbot0": rbot0, "rbot1": rbot1, "top": output},
+ ),
+ )(
+ grid=tuple([output.size(3), output.size(2), output.size(0)]),
+ block=tuple([32, 1, 1]),
+ shared_mem=first.size(1) * 4,
+ args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()],
+ stream=Stream,
+ )
+
+ elif first.is_cuda == False:
+ raise NotImplementedError()
+
+ return output
+
+ @staticmethod
+ def backward(self, gradOutput):
+ first, second, rbot0, rbot1 = self.saved_tensors
+
+ gradOutput = gradOutput.contiguous()
+
+ assert gradOutput.is_contiguous() == True
+
+ gradFirst = (
+ first.new_zeros(
+ [first.size(0), first.size(1), first.size(2), first.size(3)]
+ )
+ if self.needs_input_grad[0] == True
+ else None
+ )
+ gradSecond = (
+ first.new_zeros(
+ [first.size(0), first.size(1), first.size(2), first.size(3)]
+ )
+ if self.needs_input_grad[1] == True
+ else None
+ )
+
+ if first.is_cuda == True:
+ if gradFirst is not None:
+ for intSample in range(first.size(0)):
+ n = first.size(1) * first.size(2) * first.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradFirst",
+ cupy_kernel(
+ "kernel_Correlation_updateGradFirst",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": gradOutput,
+ "gradFirst": gradFirst,
+ "gradSecond": None,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ gradOutput.data_ptr(),
+ gradFirst.data_ptr(),
+ None,
+ ],
+ stream=Stream,
+ )
+
+ if gradSecond is not None:
+ for intSample in range(first.size(0)):
+ n = first.size(1) * first.size(2) * first.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradSecond",
+ cupy_kernel(
+ "kernel_Correlation_updateGradSecond",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": gradOutput,
+ "gradFirst": None,
+ "gradSecond": gradSecond,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ gradOutput.data_ptr(),
+ None,
+ gradSecond.data_ptr(),
+ ],
+ stream=Stream,
+ )
+
+ elif first.is_cuda == False:
+ raise NotImplementedError()
+
+ return gradFirst, gradSecond
+
+
+class _FunctionCorrelationTranspose(torch.autograd.Function):
+ @staticmethod
+ def forward(self, input, second):
+ rbot0 = second.new_zeros(
+ [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
+ )
+ rbot1 = second.new_zeros(
+ [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
+ )
+
+ self.save_for_backward(input, second, rbot0, rbot1)
+
+ input = input.contiguous()
+ second = second.contiguous()
+
+ output = second.new_zeros(
+ [second.size(0), second.size(1), second.size(2), second.size(3)]
+ )
+
+ if second.is_cuda == True:
+ n = second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange", {"input": second, "output": rbot1}
+ ),
+ )(
+ grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, second.data_ptr(), rbot1.data_ptr()],
+ stream=Stream,
+ )
+
+ for intSample in range(second.size(0)):
+ n = second.size(1) * second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradFirst",
+ cupy_kernel(
+ "kernel_Correlation_updateGradFirst",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": input,
+ "gradFirst": output,
+ "gradSecond": None,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ input.data_ptr(),
+ output.data_ptr(),
+ None,
+ ],
+ stream=Stream,
+ )
+
+ elif second.is_cuda == False:
+ raise NotImplementedError()
+
+ return output
+
+ @staticmethod
+ def backward(self, gradOutput):
+ input, second, rbot0, rbot1 = self.saved_tensors
+
+ gradOutput = gradOutput.contiguous()
+
+ gradInput = (
+ input.new_zeros(
+ [input.size(0), input.size(1), input.size(2), input.size(3)]
+ )
+ if self.needs_input_grad[0] == True
+ else None
+ )
+ gradSecond = (
+ second.new_zeros(
+ [second.size(0), second.size(1), second.size(2), second.size(3)]
+ )
+ if self.needs_input_grad[1] == True
+ else None
+ )
+
+ if second.is_cuda == True:
+ if gradInput is not None or gradSecond is not None:
+ n = second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange",
+ {"input": gradOutput, "output": rbot0},
+ ),
+ )(
+ grid=tuple(
+ [int((n + 16 - 1) / 16), gradOutput.size(1), gradOutput.size(0)]
+ ),
+ block=tuple([16, 1, 1]),
+ args=[n, gradOutput.data_ptr(), rbot0.data_ptr()],
+ stream=Stream,
+ )
+
+ if gradInput is not None:
+ n = gradInput.size(1) * gradInput.size(2) * gradInput.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateOutput",
+ cupy_kernel(
+ "kernel_Correlation_updateOutput",
+ {"rbot0": rbot0, "rbot1": rbot1, "top": gradInput},
+ ),
+ )(
+ grid=tuple(
+ [gradInput.size(3), gradInput.size(2), gradInput.size(0)]
+ ),
+ block=tuple([32, 1, 1]),
+ shared_mem=gradOutput.size(1) * 4,
+ args=[n, rbot0.data_ptr(), rbot1.data_ptr(), gradInput.data_ptr()],
+ stream=Stream,
+ )
+
+ if gradSecond is not None:
+ for intSample in range(second.size(0)):
+ n = second.size(1) * second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradSecond",
+ cupy_kernel(
+ "kernel_Correlation_updateGradSecond",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": input,
+ "gradFirst": None,
+ "gradSecond": gradSecond,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ input.data_ptr(),
+ None,
+ gradSecond.data_ptr(),
+ ],
+ stream=Stream,
+ )
+
+ elif second.is_cuda == False:
+ raise NotImplementedError()
+
+ return gradInput, gradSecond
+
+
+def FunctionCorrelation(reference_features, query_features):
+ return _FunctionCorrelation.apply(reference_features, query_features)
+
+
+class ModuleCorrelation(torch.nn.Module):
+ def __init__(self):
+ super(ModuleCorrelation, self).__init__()
+
+ def forward(self, tensorFirst, tensorSecond):
+ return _FunctionCorrelation.apply(tensorFirst, tensorSecond)
+
+
+def FunctionCorrelationTranspose(reference_features, query_features):
+ return _FunctionCorrelationTranspose.apply(reference_features, query_features)
+
+
+class ModuleCorrelationTranspose(torch.nn.Module):
+ def __init__(self):
+ super(ModuleCorrelationTranspose, self).__init__()
+
+ def forward(self, tensorFirst, tensorSecond):
+ return _FunctionCorrelationTranspose.apply(tensorFirst, tensorSecond)
+
+
+class LocalCorr(ConvRefiner):
+ def forward(self, x, y, flow):
+ """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
+
+ Args:
+ x ([type]): [description]
+ y ([type]): [description]
+ flow ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+ with torch.no_grad():
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
+ corr = FunctionCorrelation(x, x_hat)
+ d = self.block1(corr)
+ d = self.hidden_blocks(d)
+ d = self.out_conv(d)
+ certainty, displacement = d[:, :-2], d[:, -2:]
+ return certainty, displacement
+
+
+if __name__ == "__main__":
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ x = torch.randn(2, 128, 32, 32).to(device)
+ y = torch.randn(2, 128, 32, 32).to(device)
+ local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4)
+ z = local_corr(x, y)
+ print("hej")
diff --git a/third_party/DKM/dkm/models/dkm.py b/third_party/DKM/dkm/models/dkm.py
new file mode 100644
index 0000000000000000000000000000000000000000..27c3f6d59ad3a8e976e3d719868908ddf443883e
--- /dev/null
+++ b/third_party/DKM/dkm/models/dkm.py
@@ -0,0 +1,759 @@
+import math
+import os
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..utils import get_tuple_transform_ops
+from einops import rearrange
+from ..utils.local_correlation import local_correlation
+
+
+class ConvRefiner(nn.Module):
+ def __init__(
+ self,
+ in_dim=6,
+ hidden_dim=16,
+ out_dim=2,
+ dw=False,
+ kernel_size=5,
+ hidden_blocks=3,
+ displacement_emb = None,
+ displacement_emb_dim = None,
+ local_corr_radius = None,
+ corr_in_other = None,
+ no_support_fm = False,
+ ):
+ super().__init__()
+ self.block1 = self.create_block(
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
+ )
+ self.hidden_blocks = nn.Sequential(
+ *[
+ self.create_block(
+ hidden_dim,
+ hidden_dim,
+ dw=dw,
+ kernel_size=kernel_size,
+ )
+ for hb in range(hidden_blocks)
+ ]
+ )
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
+ if displacement_emb:
+ self.has_displacement_emb = True
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
+ else:
+ self.has_displacement_emb = False
+ self.local_corr_radius = local_corr_radius
+ self.corr_in_other = corr_in_other
+ self.no_support_fm = no_support_fm
+ def create_block(
+ self,
+ in_dim,
+ out_dim,
+ dw=False,
+ kernel_size=5,
+ ):
+ num_groups = 1 if not dw else in_dim
+ if dw:
+ assert (
+ out_dim % in_dim == 0
+ ), "outdim must be divisible by indim for depthwise"
+ conv1 = nn.Conv2d(
+ in_dim,
+ out_dim,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ groups=num_groups,
+ )
+ norm = nn.BatchNorm2d(out_dim)
+ relu = nn.ReLU(inplace=True)
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
+ return nn.Sequential(conv1, norm, relu, conv2)
+
+ def forward(self, x, y, flow):
+ """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
+
+ Args:
+ x ([type]): [description]
+ y ([type]): [description]
+ flow ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+ device = x.device
+ b,c,hs,ws = x.shape
+ with torch.no_grad():
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
+ if self.has_displacement_emb:
+ query_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+ )
+ )
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
+ in_displacement = flow-query_coords
+ emb_in_displacement = self.disp_emb(in_displacement)
+ if self.local_corr_radius:
+ #TODO: should corr have gradient?
+ if self.corr_in_other:
+ # Corr in other means take a kxk grid around the predicted coordinate in other image
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow)
+ else:
+ # Otherwise we use the warp to sample in the first image
+ # This is actually different operations, especially for large viewpoint changes
+ local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,)
+ if self.no_support_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
+ else:
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
+ else:
+ if self.no_support_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat), dim=1)
+ d = self.block1(d)
+ d = self.hidden_blocks(d)
+ d = self.out_conv(d)
+ certainty, displacement = d[:, :-2], d[:, -2:]
+ return certainty, displacement
+
+
+class CosKernel(nn.Module): # similar to softmax kernel
+ def __init__(self, T, learn_temperature=False):
+ super().__init__()
+ self.learn_temperature = learn_temperature
+ if self.learn_temperature:
+ self.T = nn.Parameter(torch.tensor(T))
+ else:
+ self.T = T
+
+ def __call__(self, x, y, eps=1e-6):
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
+ )
+ if self.learn_temperature:
+ T = self.T.abs() + 0.01
+ else:
+ T = torch.tensor(self.T, device=c.device)
+ K = ((c - 1.0) / T).exp()
+ return K
+
+
+class CAB(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(CAB, self).__init__()
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.sigmod = nn.Sigmoid()
+
+ def forward(self, x):
+ x1, x2 = x # high, low (old, new)
+ x = torch.cat([x1, x2], dim=1)
+ x = self.global_pooling(x)
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.sigmod(x)
+ x2 = x * x2
+ res = x2 + x1
+ return res
+
+
+class RRB(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3):
+ super(RRB, self).__init__()
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.conv2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ self.relu = nn.ReLU()
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.conv3 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+
+ def forward(self, x):
+ x = self.conv1(x)
+ res = self.conv2(x)
+ res = self.bn(res)
+ res = self.relu(res)
+ res = self.conv3(res)
+ return self.relu(x + res)
+
+
+class DFN(nn.Module):
+ def __init__(
+ self,
+ internal_dim,
+ feat_input_modules,
+ pred_input_modules,
+ rrb_d_dict,
+ cab_dict,
+ rrb_u_dict,
+ use_global_context=False,
+ global_dim=None,
+ terminal_module=None,
+ upsample_mode="bilinear",
+ align_corners=False,
+ ):
+ super().__init__()
+ if use_global_context:
+ assert (
+ global_dim is not None
+ ), "Global dim must be provided when using global context"
+ self.align_corners = align_corners
+ self.internal_dim = internal_dim
+ self.feat_input_modules = feat_input_modules
+ self.pred_input_modules = pred_input_modules
+ self.rrb_d = rrb_d_dict
+ self.cab = cab_dict
+ self.rrb_u = rrb_u_dict
+ self.use_global_context = use_global_context
+ if use_global_context:
+ self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
+ self.terminal_module = (
+ terminal_module if terminal_module is not None else nn.Identity()
+ )
+ self.upsample_mode = upsample_mode
+ self._scales = [int(key) for key in self.terminal_module.keys()]
+
+ def scales(self):
+ return self._scales.copy()
+
+ def forward(self, embeddings, feats, context, key):
+ feats = self.feat_input_modules[str(key)](feats)
+ embeddings = torch.cat([feats, embeddings], dim=1)
+ embeddings = self.rrb_d[str(key)](embeddings)
+ context = self.cab[str(key)]([context, embeddings])
+ context = self.rrb_u[str(key)](context)
+ preds = self.terminal_module[str(key)](context)
+ pred_coord = preds[:, -2:]
+ pred_certainty = preds[:, :-2]
+ return pred_coord, pred_certainty, context
+
+
+class GP(nn.Module):
+ def __init__(
+ self,
+ kernel,
+ T=1,
+ learn_temperature=False,
+ only_attention=False,
+ gp_dim=64,
+ basis="fourier",
+ covar_size=5,
+ only_nearest_neighbour=False,
+ sigma_noise=0.1,
+ no_cov=False,
+ predict_features = False,
+ ):
+ super().__init__()
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
+ self.sigma_noise = sigma_noise
+ self.covar_size = covar_size
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
+ self.only_attention = only_attention
+ self.only_nearest_neighbour = only_nearest_neighbour
+ self.basis = basis
+ self.no_cov = no_cov
+ self.dim = gp_dim
+ self.predict_features = predict_features
+
+ def get_local_cov(self, cov):
+ K = self.covar_size
+ b, h, w, h, w = cov.shape
+ hw = h * w
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
+ delta = torch.stack(
+ torch.meshgrid(
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
+ ),
+ dim=-1,
+ )
+ positions = torch.stack(
+ torch.meshgrid(
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
+ ),
+ dim=-1,
+ )
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
+ :,
+ points.flatten(),
+ neighbours[..., 0].flatten(),
+ neighbours[..., 1].flatten(),
+ ].reshape(b, h, w, K**2)
+ return local_cov
+
+ def reshape(self, x):
+ return rearrange(x, "b d h w -> b (h w) d")
+
+ def project_to_basis(self, x):
+ if self.basis == "fourier":
+ return torch.cos(8 * math.pi * self.pos_conv(x))
+ elif self.basis == "linear":
+ return self.pos_conv(x)
+ else:
+ raise ValueError(
+ "No other bases other than fourier and linear currently supported in public release"
+ )
+
+ def get_pos_enc(self, y):
+ b, c, h, w = y.shape
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
+ )
+ )
+
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
+ return coarse_embedded_coords
+
+ def forward(self, x, y, **kwargs):
+ b, c, h1, w1 = x.shape
+ b, c, h2, w2 = y.shape
+ f = self.get_pos_enc(y)
+ if self.predict_features:
+ f = f + y[:,:self.dim] # Stupid way to predict features
+ b, d, h2, w2 = f.shape
+ #assert x.shape == y.shape
+ x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
+ K_xx = self.K(x, x)
+ K_yy = self.K(y, y)
+ K_xy = self.K(x, y)
+ K_yx = K_xy.permute(0, 2, 1)
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
+ # Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large
+ if len(K_yy[0]) > 2000:
+ K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)])
+ else:
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
+
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
+ if not self.no_cov:
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
+ local_cov_x = self.get_local_cov(cov_x)
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
+ else:
+ gp_feats = mu_x
+ return gp_feats
+
+
+class Encoder(nn.Module):
+ def __init__(self, resnet):
+ super().__init__()
+ self.resnet = resnet
+ def forward(self, x):
+ x0 = x
+ b, c, h, w = x.shape
+ x = self.resnet.conv1(x)
+ x = self.resnet.bn1(x)
+ x1 = self.resnet.relu(x)
+
+ x = self.resnet.maxpool(x1)
+ x2 = self.resnet.layer1(x)
+
+ x3 = self.resnet.layer2(x2)
+
+ x4 = self.resnet.layer3(x3)
+
+ x5 = self.resnet.layer4(x4)
+ feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None,
+ ):
+ super().__init__()
+ self.embedding_decoder = embedding_decoder
+ self.gps = gps
+ self.proj = proj
+ self.conv_refiner = conv_refiner
+ self.detach = detach
+ if scales == "all":
+ self.scales = ["32", "16", "8", "4", "2", "1"]
+ else:
+ self.scales = scales
+
+ def upsample_preds(self, flow, certainty, query, support):
+ b, hs, ws, d = flow.shape
+ b, c, h, w = query.shape
+ flow = flow.permute(0, 3, 1, 2)
+ certainty = F.interpolate(
+ certainty, size=(h, w), align_corners=False, mode="bilinear"
+ )
+ flow = F.interpolate(
+ flow, size=(h, w), align_corners=False, mode="bilinear"
+ )
+ delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
+ flow = torch.stack(
+ (
+ flow[:, 0] + delta_flow[:, 0] / (4 * w),
+ flow[:, 1] + delta_flow[:, 1] / (4 * h),
+ ),
+ dim=1,
+ )
+ flow = flow.permute(0, 2, 3, 1)
+ certainty = certainty + delta_certainty
+ return flow, certainty
+
+ def get_placeholder_flow(self, b, h, w, device):
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ )
+ )
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ return coarse_coords
+
+
+ def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None):
+ coarse_scales = self.embedding_decoder.scales()
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
+ h, w = sizes[1]
+ b = f1[1].shape[0]
+ device = f1[1].device
+ coarsest_scale = int(all_scales[0])
+ old_stuff = torch.zeros(
+ b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
+ )
+ dense_corresps = {}
+ if not upsample:
+ dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
+ dense_certainty = 0.0
+ else:
+ dense_flow = F.interpolate(
+ dense_flow,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ dense_certainty = F.interpolate(
+ dense_certainty,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ for new_scale in all_scales:
+ ins = int(new_scale)
+ f1_s, f2_s = f1[ins], f2[ins]
+ if new_scale in self.proj:
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
+ b, c, hs, ws = f1_s.shape
+ if ins in coarse_scales:
+ old_stuff = F.interpolate(
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
+ )
+ new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
+ dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
+ new_stuff, f1_s, old_stuff, new_scale
+ )
+
+ if new_scale in self.conv_refiner:
+ delta_certainty, displacement = self.conv_refiner[new_scale](
+ f1_s, f2_s, dense_flow
+ )
+ dense_flow = torch.stack(
+ (
+ dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w),
+ dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h),
+ ),
+ dim=1,
+ )
+ dense_certainty = (
+ dense_certainty + delta_certainty
+ ) # predict both certainty and displacement
+
+ dense_corresps[ins] = {
+ "dense_flow": dense_flow,
+ "dense_certainty": dense_certainty,
+ }
+
+ if new_scale != "1":
+ dense_flow = F.interpolate(
+ dense_flow,
+ size=sizes[ins // 2],
+ align_corners=False,
+ mode="bilinear",
+ )
+
+ dense_certainty = F.interpolate(
+ dense_certainty,
+ size=sizes[ins // 2],
+ align_corners=False,
+ mode="bilinear",
+ )
+ if self.detach:
+ dense_flow = dense_flow.detach()
+ dense_certainty = dense_certainty.detach()
+ return dense_corresps
+
+
+class RegressionMatcher(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ h=384,
+ w=512,
+ use_contrastive_loss = False,
+ alpha = 1,
+ beta = 0,
+ sample_mode = "threshold",
+ upsample_preds = False,
+ symmetric = False,
+ name = None,
+ use_soft_mutual_nearest_neighbours = False,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.w_resized = w
+ self.h_resized = h
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
+ self.use_contrastive_loss = use_contrastive_loss
+ self.alpha = alpha
+ self.beta = beta
+ self.sample_mode = sample_mode
+ self.upsample_preds = upsample_preds
+ self.symmetric = symmetric
+ self.name = name
+ self.sample_thresh = 0.05
+ self.upsample_res = (864,1152)
+ if use_soft_mutual_nearest_neighbours:
+ assert symmetric, "MNS requires symmetric inference"
+ self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours
+
+ def extract_backbone_features(self, batch, batched = True, upsample = True):
+ #TODO: only extract stride [1,2,4,8] for upsample = True
+ x_q = batch["query"]
+ x_s = batch["support"]
+ if batched:
+ X = torch.cat((x_q, x_s))
+ feature_pyramid = self.encoder(X)
+ else:
+ feature_pyramid = self.encoder(x_q), self.encoder(x_s)
+ return feature_pyramid
+
+ def sample(
+ self,
+ dense_matches,
+ dense_certainty,
+ num=10000,
+ ):
+ if "threshold" in self.sample_mode:
+ upper_thresh = self.sample_thresh
+ dense_certainty = dense_certainty.clone()
+ dense_certainty[dense_certainty > upper_thresh] = 1
+ elif "pow" in self.sample_mode:
+ dense_certainty = dense_certainty**(1/3)
+ elif "naive" in self.sample_mode:
+ dense_certainty = torch.ones_like(dense_certainty)
+ matches, certainty = (
+ dense_matches.reshape(-1, 4),
+ dense_certainty.reshape(-1),
+ )
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
+ good_samples = torch.multinomial(certainty,
+ num_samples = min(expansion_factor*num, len(certainty)),
+ replacement=False)
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
+ if "balanced" not in self.sample_mode:
+ return good_matches, good_certainty
+
+ from ..utils.kde import kde
+ density = kde(good_matches, std=0.1)
+ p = 1 / (density+1)
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+ balanced_samples = torch.multinomial(p,
+ num_samples = min(num,len(good_certainty)),
+ replacement=False)
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
+
+ def forward(self, batch, batched = True):
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched)
+ if batched:
+ f_q_pyramid = {
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
+ }
+ f_s_pyramid = {
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
+ }
+ else:
+ f_q_pyramid, f_s_pyramid = feature_pyramid
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
+ if self.training and self.use_contrastive_loss:
+ return dense_corresps, (f_q_pyramid, f_s_pyramid)
+ else:
+ return dense_corresps
+
+ def forward_symmetric(self, batch, upsample = False, batched = True):
+ feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched)
+ f_q_pyramid = feature_pyramid
+ f_s_pyramid = {
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
+ for scale, f_scale in feature_pyramid.items()
+ }
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {}))
+ return dense_corresps
+
+ def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
+ kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+ kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
+ kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
+ return kpts_A, kpts_B
+
+ def match(
+ self,
+ im1_path,
+ im2_path,
+ *args,
+ batched=False,
+ device = None
+ ):
+ assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
+ if isinstance(im1_path, (str, os.PathLike)):
+ im1, im2 = Image.open(im1_path), Image.open(im2_path)
+ else: # assume it is a PIL Image
+ im1, im2 = im1_path, im2_path
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ symmetric = self.symmetric
+ self.train(False)
+ with torch.no_grad():
+ if not batched:
+ b = 1
+ w, h = im1.size
+ w2, h2 = im2.size
+ # Get images in good format
+ ws = self.w_resized
+ hs = self.h_resized
+
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True
+ )
+ query, support = test_transform((im1, im2))
+ batch = {"query": query[None].to(device), "support": support[None].to(device)}
+ else:
+ b, c, h, w = im1.shape
+ b, c, h2, w2 = im2.shape
+ assert w == w2 and h == h2, "For batched images we assume same size"
+ batch = {"query": im1.to(device), "support": im2.to(device)}
+ hs, ws = self.h_resized, self.w_resized
+ finest_scale = 1
+ # Run matcher
+ if symmetric:
+ dense_corresps = self.forward_symmetric(batch, batched = True)
+ else:
+ dense_corresps = self.forward(batch, batched = True)
+
+ if self.upsample_preds:
+ hs, ws = self.upsample_res
+ low_res_certainty = F.interpolate(
+ dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ cert_clamp = 0
+ factor = 0.5
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+
+ if self.upsample_preds:
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True
+ )
+ query, support = test_transform((im1, im2))
+ query, support = query[None].to(device), support[None].to(device)
+ batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]}
+ if symmetric:
+ dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True)
+ else:
+ dense_corresps = self.forward(batch, batched = True, upsample=True)
+ query_to_support = dense_corresps[finest_scale]["dense_flow"]
+ dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
+
+ # Get certainty interpolation
+ dense_certainty = dense_certainty - low_res_certainty
+ query_to_support = query_to_support.permute(
+ 0, 2, 3, 1
+ )
+ # Create im1 meshgrid
+ query_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+ )
+ )
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
+ dense_certainty = dense_certainty.sigmoid() # logits -> probs
+ query_coords = query_coords.permute(0, 2, 3, 1)
+ if (query_to_support.abs() > 1).any() and True:
+ wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
+ dense_certainty[wrong[:,None]] = 0
+
+ query_to_support = torch.clamp(query_to_support, -1, 1)
+ if symmetric:
+ support_coords = query_coords
+ qts, stq = query_to_support.chunk(2)
+ q_warp = torch.cat((query_coords, qts), dim=-1)
+ s_warp = torch.cat((stq, support_coords), dim=-1)
+ warp = torch.cat((q_warp, s_warp),dim=2)
+ dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0]
+ else:
+ warp = torch.cat((query_coords, query_to_support), dim=-1)
+ if batched:
+ return (
+ warp,
+ dense_certainty
+ )
+ else:
+ return (
+ warp[0],
+ dense_certainty[0],
+ )
diff --git a/third_party/DKM/dkm/models/encoders.py b/third_party/DKM/dkm/models/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..29077e1797196611e9b59a753130a5b153e0aa05
--- /dev/null
+++ b/third_party/DKM/dkm/models/encoders.py
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as tvm
+
+class ResNet18(nn.Module):
+ def __init__(self, pretrained=False) -> None:
+ super().__init__()
+ self.net = tvm.resnet18(pretrained=pretrained)
+ def forward(self, x):
+ self = self.net
+ x1 = x
+ x = self.conv1(x1)
+ x = self.bn1(x)
+ x2 = self.relu(x)
+ x = self.maxpool(x2)
+ x4 = self.layer1(x)
+ x8 = self.layer2(x4)
+ x16 = self.layer3(x8)
+ x32 = self.layer4(x16)
+ return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1}
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+class ResNet50(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
+ super().__init__()
+ if dilation is None:
+ dilation = [False,False,False]
+ if anti_aliased:
+ pass
+ else:
+ if weights is not None:
+ self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+ else:
+ self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
+ self.high_res = high_res
+ self.freeze_bn = freeze_bn
+ def forward(self, x):
+ net = self.net
+ feats = {1:x}
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x
+ x = net.layer2(x)
+ feats[8] = x
+ x = net.layer3(x)
+ feats[16] = x
+ x = net.layer4(x)
+ feats[32] = x
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ if self.freeze_bn:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+
+
+
+class ResNet101(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
+ super().__init__()
+ if weights is not None:
+ self.net = tvm.resnet101(weights = weights)
+ else:
+ self.net = tvm.resnet101(pretrained=pretrained)
+ self.high_res = high_res
+ self.scale_factor = 1 if not high_res else 1.5
+ def forward(self, x):
+ net = self.net
+ feats = {1:x}
+ sf = self.scale_factor
+ if self.high_res:
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer2(x)
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer3(x)
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer4(x)
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+
+class WideResNet50(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
+ super().__init__()
+ if weights is not None:
+ self.net = tvm.wide_resnet50_2(weights = weights)
+ else:
+ self.net = tvm.wide_resnet50_2(pretrained=pretrained)
+ self.high_res = high_res
+ self.scale_factor = 1 if not high_res else 1.5
+ def forward(self, x):
+ net = self.net
+ feats = {1:x}
+ sf = self.scale_factor
+ if self.high_res:
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer2(x)
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer3(x)
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer4(x)
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
\ No newline at end of file
diff --git a/third_party/DKM/dkm/models/model_zoo/DKMv3.py b/third_party/DKM/dkm/models/model_zoo/DKMv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f4c9ede3863d778f679a033d8d2287b8776e894
--- /dev/null
+++ b/third_party/DKM/dkm/models/model_zoo/DKMv3.py
@@ -0,0 +1,150 @@
+import torch
+
+from torch import nn
+from ..dkm import *
+from ..encoders import *
+
+
+def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", device = None, **kwargs):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ displacement_emb = "linear"
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512+128+(2*7+1)**2,
+ 2 * 512+128+(2*7+1)**2,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=128,
+ local_corr_radius = 7,
+ corr_in_other = True,
+ ),
+ "8": ConvRefiner(
+ 2 * 512+64+(2*3+1)**2,
+ 2 * 512+64+(2*3+1)**2,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=64,
+ local_corr_radius = 3,
+ corr_in_other = True,
+ ),
+ "4": ConvRefiner(
+ 2 * 256+32+(2*2+1)**2,
+ 2 * 256+32+(2*2+1)**2,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=32,
+ local_corr_radius = 2,
+ corr_in_other = True,
+ ),
+ "2": ConvRefiner(
+ 2 * 64+16,
+ 128+16,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=16,
+ ),
+ "1": ConvRefiner(
+ 2 * 3+6,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=6,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+
+ encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs).to(device)
+ res = matcher.load_state_dict(weights)
+ return matcher
diff --git a/third_party/DKM/dkm/models/model_zoo/__init__.py b/third_party/DKM/dkm/models/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c85da2920c1acfac140ada2d87623203607d42ca
--- /dev/null
+++ b/third_party/DKM/dkm/models/model_zoo/__init__.py
@@ -0,0 +1,39 @@
+weight_urls = {
+ "DKMv3": {
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
+ "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
+ },
+}
+import torch
+from .DKMv3 import DKMv3
+
+
+def DKMv3_outdoor(path_to_weights = None, device=None):
+ """
+ Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default
+ resolution can be changed by setting model.h_resized, model.w_resized later.
+ Additionally upsamples preds to fixed resolution of (864, 1152),
+ can be turned off by model.upsample_preds = False
+ """
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if path_to_weights is not None:
+ weights = torch.load(path_to_weights, map_location='cpu')
+ else:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"],
+ map_location='cpu')
+ return DKMv3(weights, 540, 720, upsample_preds = True, device=device)
+
+def DKMv3_indoor(path_to_weights = None, device=None):
+ """
+ Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default
+ Resolution can be changed by setting model.h_resized, model.w_resized later.
+ """
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if path_to_weights is not None:
+ weights = torch.load(path_to_weights, map_location=device)
+ else:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"],
+ map_location=device)
+ return DKMv3(weights, 480, 640, upsample_preds = False, device=device)
diff --git a/third_party/DKM/dkm/train/__init__.py b/third_party/DKM/dkm/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433
--- /dev/null
+++ b/third_party/DKM/dkm/train/__init__.py
@@ -0,0 +1 @@
+from .train import train_k_epochs
diff --git a/third_party/DKM/dkm/train/train.py b/third_party/DKM/dkm/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..b580221f56a2667784836f0237955cc75131b88c
--- /dev/null
+++ b/third_party/DKM/dkm/train/train.py
@@ -0,0 +1,67 @@
+from tqdm import tqdm
+from dkm.utils.utils import to_cuda
+
+
+def train_step(train_batch, model, objective, optimizer, **kwargs):
+ optimizer.zero_grad()
+ out = model(train_batch)
+ l = objective(out, train_batch)
+ l.backward()
+ optimizer.step()
+ return {"train_out": out, "train_loss": l.item()}
+
+
+def train_k_steps(
+ n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True
+):
+ for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar):
+ batch = next(dataloader)
+ model.train(True)
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ n=n,
+ )
+ lr_scheduler.step()
+
+
+def train_epoch(
+ dataloader=None,
+ model=None,
+ objective=None,
+ optimizer=None,
+ lr_scheduler=None,
+ epoch=None,
+):
+ model.train(True)
+ print(f"At epoch {epoch}")
+ for batch in tqdm(dataloader, mininterval=5.0):
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch, model=model, objective=objective, optimizer=optimizer
+ )
+ lr_scheduler.step()
+ return {
+ "model": model,
+ "optimizer": optimizer,
+ "lr_scheduler": lr_scheduler,
+ "epoch": epoch,
+ }
+
+
+def train_k_epochs(
+ start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
+):
+ for epoch in range(start_epoch, end_epoch + 1):
+ train_epoch(
+ dataloader=dataloader,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ )
diff --git a/third_party/DKM/dkm/utils/__init__.py b/third_party/DKM/dkm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05367ac9521664992f587738caa231f32ae2e81c
--- /dev/null
+++ b/third_party/DKM/dkm/utils/__init__.py
@@ -0,0 +1,13 @@
+from .utils import (
+ pose_auc,
+ get_pose,
+ compute_relative_pose,
+ compute_pose_error,
+ estimate_pose,
+ rotate_intrinsic,
+ get_tuple_transform_ops,
+ get_depth_tuple_transform_ops,
+ warp_kpts,
+ numpy_to_pil,
+ tensor_to_pil,
+)
diff --git a/third_party/DKM/dkm/utils/kde.py b/third_party/DKM/dkm/utils/kde.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa392455e70fda4c9c77c28bda76bcb7ef9045b0
--- /dev/null
+++ b/third_party/DKM/dkm/utils/kde.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1):
+ raise NotImplementedError("WIP, use at your own risk.")
+ # Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid
+ x = x.permute(0,3,1,2)
+ B,C,H,W = x.shape
+ K = kernel_size ** 2
+ unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W)
+ scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp()
+ density = scores.sum(dim=1)
+ return density
+
+
+def kde(x, std = 0.1, device=None):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if isinstance(x, np.ndarray):
+ x = torch.from_numpy(x)
+ # use a gaussian kernel to estimate density
+ x = x.to(device)
+ scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+ density = scores.sum(dim=-1)
+ return density
diff --git a/third_party/DKM/dkm/utils/local_correlation.py b/third_party/DKM/dkm/utils/local_correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c1c06291d0b760376a2b2162bcf49d6eb1303c
--- /dev/null
+++ b/third_party/DKM/dkm/utils/local_correlation.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+
+
+def local_correlation(
+ feature0,
+ feature1,
+ local_radius,
+ padding_mode="zeros",
+ flow = None
+):
+ device = feature0.device
+ b, c, h, w = feature0.size()
+ if flow is None:
+ # If flow is None, assume feature0 and feature1 are aligned
+ coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ ))
+ coords = torch.stack((coords[1], coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ else:
+ coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+ r = local_radius
+ local_window = torch.meshgrid(
+ (
+ torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device),
+ torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device),
+ ))
+ local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
+ None
+ ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2)
+ coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2)
+ window_feature = F.grid_sample(
+ feature1, coords, padding_mode=padding_mode, align_corners=False
+ )[...,None].reshape(b,c,h,w,(2*r+1)**2)
+ corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5)
+ return corr
diff --git a/third_party/DKM/dkm/utils/transforms.py b/third_party/DKM/dkm/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..754d853fda4cbcf89d2111bed4f44b0ca84f0518
--- /dev/null
+++ b/third_party/DKM/dkm/utils/transforms.py
@@ -0,0 +1,104 @@
+from typing import Dict
+import numpy as np
+import torch
+import kornia.augmentation as K
+from kornia.geometry.transform import warp_perspective
+
+# Adapted from Kornia
+class GeometricSequential:
+ def __init__(self, *transforms, align_corners=True) -> None:
+ self.transforms = transforms
+ self.align_corners = align_corners
+
+ def __call__(self, x, mode="bilinear"):
+ b, c, h, w = x.shape
+ M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
+ for t in self.transforms:
+ if np.random.rand() < t.p:
+ M = M.matmul(
+ t.compute_transformation(x, t.generate_parameters((b, c, h, w)))
+ )
+ return (
+ warp_perspective(
+ x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
+ ),
+ M,
+ )
+
+ def apply_transform(self, x, M, mode="bilinear"):
+ b, c, h, w = x.shape
+ return warp_perspective(
+ x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
+ )
+
+
+class RandomPerspective(K.RandomPerspective):
+ def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
+ distortion_scale = torch.as_tensor(
+ self.distortion_scale, device=self._device, dtype=self._dtype
+ )
+ return self.random_perspective_generator(
+ batch_shape[0],
+ batch_shape[-2],
+ batch_shape[-1],
+ distortion_scale,
+ self.same_on_batch,
+ self.device,
+ self.dtype,
+ )
+
+ def random_perspective_generator(
+ self,
+ batch_size: int,
+ height: int,
+ width: int,
+ distortion_scale: torch.Tensor,
+ same_on_batch: bool = False,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.float32,
+ ) -> Dict[str, torch.Tensor]:
+ r"""Get parameters for ``perspective`` for a random perspective transform.
+
+ Args:
+ batch_size (int): the tensor batch size.
+ height (int) : height of the image.
+ width (int): width of the image.
+ distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
+ same_on_batch (bool): apply the same transformation across the batch. Default: False.
+ device (torch.device): the device on which the random numbers will be generated. Default: cpu.
+ dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
+
+ Returns:
+ params Dict[str, torch.Tensor]: parameters to be passed for transformation.
+ - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
+ - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
+
+ Note:
+ The generated random numbers are not reproducible across different devices and dtypes.
+ """
+ if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
+ raise AssertionError(
+ f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
+ )
+ if not (
+ type(height) is int and height > 0 and type(width) is int and width > 0
+ ):
+ raise AssertionError(
+ f"'height' and 'width' must be integers. Got {height}, {width}."
+ )
+
+ start_points: torch.Tensor = torch.tensor(
+ [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
+ device=distortion_scale.device,
+ dtype=distortion_scale.dtype,
+ ).expand(batch_size, -1, -1)
+
+ # generate random offset not larger than half of the image
+ fx = distortion_scale * width / 2
+ fy = distortion_scale * height / 2
+
+ factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
+ offset = (torch.rand_like(start_points) - 0.5) * 2
+ end_points = start_points + factor * offset
+
+ return dict(start_points=start_points, end_points=end_points)
diff --git a/third_party/DKM/dkm/utils/utils.py b/third_party/DKM/dkm/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..46bbe60260930aed184c6fa5907c837c0177b304
--- /dev/null
+++ b/third_party/DKM/dkm/utils/utils.py
@@ -0,0 +1,341 @@
+import numpy as np
+import cv2
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import torch.nn.functional as F
+from PIL import Image
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
+# --- GEOMETRY ---
+def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC
+ )
+
+ ret = None
+ if E is not None:
+ best_num_inliers = 0
+
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+
+def rotate_intrinsic(K, n):
+ base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+ rot = np.linalg.matrix_power(base_rot, n)
+ return rot @ K
+
+
+def rotate_pose_inplane(i_T_w, rot):
+ rotation_matrices = [
+ np.array(
+ [
+ [np.cos(r), -np.sin(r), 0.0, 0.0],
+ [np.sin(r), np.cos(r), 0.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ dtype=np.float32,
+ )
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+ ]
+ return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+ scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
+ return np.dot(scales, K)
+
+
+def to_homogeneous(points):
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
+
+
+def angle_error_mat(R1, R2):
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+ cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
+ return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t):
+ R_gt = T_0to1[:3, :3]
+ t_gt = T_0to1[:3, 3]
+ error_t = angle_error_vec(t.squeeze(), t_gt)
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_R = angle_error_mat(R, R_gt)
+ return error_t, error_R
+
+
+def pose_auc(errors, thresholds):
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0.0, errors]
+ recall = np.r_[0.0, recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e) / t)
+ return aucs
+
+
+# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
+def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
+ return TupleCompose(ops)
+
+
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize))
+ if normalize:
+ ops.append(TupleToTensorScaled())
+ ops.append(
+ TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ) # Imagenet mean/std
+ else:
+ if unscale:
+ ops.append(TupleToTensorUnscaled())
+ else:
+ ops.append(TupleToTensorScaled())
+ return TupleCompose(ops)
+
+
+class ToTensorScaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
+
+ def __call__(self, im):
+ if not isinstance(im, torch.Tensor):
+ im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
+ im /= 255.0
+ return torch.from_numpy(im)
+ else:
+ return im
+
+ def __repr__(self):
+ return "ToTensorScaled(./255)"
+
+
+class TupleToTensorScaled(object):
+ def __init__(self):
+ self.to_tensor = ToTensorScaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorScaled(./255)"
+
+
+class ToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __call__(self, im):
+ return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
+
+ def __repr__(self):
+ return "ToTensorUnscaled()"
+
+
+class TupleToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __init__(self):
+ self.to_tensor = ToTensorUnscaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorUnscaled()"
+
+
+class TupleResize(object):
+ def __init__(self, size, mode=InterpolationMode.BICUBIC):
+ self.size = size
+ self.resize = transforms.Resize(size, mode)
+
+ def __call__(self, im_tuple):
+ return [self.resize(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleResize(size={})".format(self.size)
+
+
+class TupleNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+ self.normalize = transforms.Normalize(mean=mean, std=std)
+
+ def __call__(self, im_tuple):
+ return [self.normalize(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
+
+
+class TupleCompose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, im_tuple):
+ for t in self.transforms:
+ im_tuple = t(im_tuple)
+ return im_tuple
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+ # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1)
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ (
+ n,
+ h,
+ w,
+ ) = depth0.shape
+ kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode="bilinear")[
+ :, 0, :, 0
+ ]
+ kpts0 = torch.stack(
+ (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ # Sample depth, get calculable_mask on depth != 0
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = (
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+ * kpts0_depth[..., None]
+ ) # (N, L, 3)
+ kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+ kpts0_cam = kpts0_n
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
+ w_kpts0_h[:, :, [2]] + 1e-4
+ ) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (
+ (w_kpts0[:, :, 0] > 0)
+ * (w_kpts0[:, :, 0] < w - 1)
+ * (w_kpts0[:, :, 1] > 0)
+ * (w_kpts0[:, :, 1] < h - 1)
+ )
+ w_kpts0 = torch.stack(
+ (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+ # w_kpts0[~covisible_mask, :] = -5 # xd
+
+ w_kpts0_depth = F.grid_sample(
+ depth1[:, None], w_kpts0[:, :, None], mode="bilinear"
+ )[:, 0, :, 0]
+ consistent_mask = (
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+ ).abs() < 0.05
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
+
+
+imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
+imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+
+
+def numpy_to_pil(x: np.ndarray):
+ """
+ Args:
+ x: Assumed to be of shape (h,w,c)
+ """
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ if x.max() <= 1.01:
+ x *= 255
+ x = x.astype(np.uint8)
+ return Image.fromarray(x)
+
+
+def tensor_to_pil(x, unnormalize=False):
+ if unnormalize:
+ x = x * imagenet_std[:, None, None] + imagenet_mean[:, None, None]
+ x = x.detach().permute(1, 2, 0).cpu().numpy()
+ x = np.clip(x, 0.0, 1.0)
+ return numpy_to_pil(x)
+
+
+def to_cuda(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.to(device)
+ return batch
+
+
+def to_cpu(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cpu()
+ return batch
+
+
+def get_pose(calib):
+ w, h = np.array(calib["imsize"])[0]
+ return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
+
+
+def compute_relative_pose(R1, t1, R2, t2):
+ rots = R2 @ (R1.T)
+ trans = -rots @ t1 + t2
+ return rots, trans
diff --git a/third_party/DKM/docs/api.md b/third_party/DKM/docs/api.md
new file mode 100644
index 0000000000000000000000000000000000000000..d19e961a81f59ea6f33de1cc53bce16b4db9678c
--- /dev/null
+++ b/third_party/DKM/docs/api.md
@@ -0,0 +1,24 @@
+## Creating a model
+```python
+from dkm import DKMv3_outdoor, DKMv3_indoor
+DKMv3_outdoor() # creates an outdoor trained model
+DKMv3_indoor() # creates an indoor trained model
+```
+## Model settings
+Note: Non-exhaustive list
+```python
+model.upsample_preds = True/False # Whether to upsample the predictions to higher resolution
+model.upsample_res = (H_big, W_big) # Which resolution to use for upsampling
+model.symmetric = True/False # Whether to compute a bidirectional warp
+model.w_resized = W # width of image used
+model.h_resized = H # height of image used
+model.sample_mode = "threshold_balanced" # method for sampling matches. threshold_balanced is what was used in the paper
+model.sample_threshold = 0.05 # the threshold for sampling, 0.05 works well for megadepth, for IMC2022 we found 0.2 to work better.
+```
+## Running model
+```python
+warp, certainty = model.match(im_A, im_B) # produces a warp of shape [B,H,W,4] and certainty of shape [B,H,W]
+matches, certainty = model.sample(warp, certainty) # samples from the warp using the certainty
+kpts_A, kpts_B = model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) # convenience function to convert normalized matches to pixel coordinates
+```
+
diff --git a/third_party/DKM/docs/benchmarks.md b/third_party/DKM/docs/benchmarks.md
new file mode 100644
index 0000000000000000000000000000000000000000..30dd57af86ad4f85c621e430eef9e9c55ba9d2c5
--- /dev/null
+++ b/third_party/DKM/docs/benchmarks.md
@@ -0,0 +1,27 @@
+Benchmarking datasets for geometry estimation can be somewhat cumbersome to download. We provide instructions for the benchmarks we use below, and are happy to answer any questions.
+
+### HPatches
+First, make sure that the "data/hpatches" path exists, e.g. by
+
+`` ln -s place/where/your/datasets/are/stored/hpatches data/hpatches ``
+
+Then run (if you don't already have hpatches downloaded)
+
+`` bash scripts/download_hpatches.sh``
+
+### Megadepth-1500 (LoFTR Split)
+1. We use the split made by LoFTR, which can be downloaded here https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc. (You can also use the preprocessed megadepth dataset if you have it available)
+2. The images should be located in data/megadepth/Undistorted_SfM/0015 and 0022.
+3. The pair infos are provided here https://github.com/zju3dv/LoFTR/tree/master/assets/megadepth_test_1500_scene_info
+3. Put those files in data/megadepth/xxx
+
+### Megadepth-8-Scenes (DKM Split)
+1. The pair infos are provided in [assets](../assets/)
+2. Put those files in data/megadepth/xxx
+
+
+### Scannet-1500 (SuperGlue Split)
+We use the same split of scannet as superglue.
+1. LoFTR provides the split here: https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc
+2. Note that ScanNet requires you to sign a License agreement, which can be found http://kaldir.vc.in.tum.de/scannet/ScanNet_TOS.pdf
+3. This benchmark should be put in the data/scannet folder
diff --git a/third_party/DKM/docs/training.md b/third_party/DKM/docs/training.md
new file mode 100644
index 0000000000000000000000000000000000000000..37d17171ac45c95ff587b9bbd525c4175558ff8a
--- /dev/null
+++ b/third_party/DKM/docs/training.md
@@ -0,0 +1,21 @@
+Here we provide instructions for how to train our models, including download of datasets.
+
+### MegaDepth
+First the MegaDepth dataset needs to be downloaded and preprocessed. This can be done by the following steps:
+1. Download MegaDepth from here: https://www.cs.cornell.edu/projects/megadepth/
+2. Extract and preprocess: See https://github.com/mihaidusmanu/d2-net
+3. Download our prepared scene info from here: https://github.com/Parskatt/storage/releases/download/prep_scene_info/prep_scene_info.tar
+4. File structure should be data/megadepth/phoenix, data/megadepth/Undistorted_SfM, data/megadepth/prep_scene_info.
+Then run
+``` bash
+python experiments/dkmv3/train_DKMv3_outdoor.py --gpus 4
+```
+
+## Megadepth + Scannet
+First follow the steps outlined above.
+Then, see https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md
+
+Then run
+``` bash
+python experiments/dkmv3/train_DKMv3_indoor.py --gpus 4
+```
diff --git a/third_party/DKM/requirements.txt b/third_party/DKM/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..018696905e480072ebe7dd9a9010db8b9e77f1d8
--- /dev/null
+++ b/third_party/DKM/requirements.txt
@@ -0,0 +1,11 @@
+torch
+einops
+torchvision
+opencv-python
+kornia
+albumentations
+loguru
+tqdm
+matplotlib
+h5py
+wandb
\ No newline at end of file
diff --git a/third_party/DKM/scripts/download_hpatches.sh b/third_party/DKM/scripts/download_hpatches.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5cdc42f6c9304062773ea30852179f51580ea9e0
--- /dev/null
+++ b/third_party/DKM/scripts/download_hpatches.sh
@@ -0,0 +1,4 @@
+cd data/hpatches
+wget http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz
+tar -xvf hpatches-sequences-release.tar.gz -C .
+rm hpatches-sequences-release.tar.gz
\ No newline at end of file
diff --git a/third_party/DKM/setup.py b/third_party/DKM/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..73ae664126066249200d72c1ec3166f4f2d76b10
--- /dev/null
+++ b/third_party/DKM/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup, find_packages
+
+setup(
+ name="dkm",
+ packages=find_packages(include=("dkm*",)),
+ version="0.3.0",
+ author="Johan Edstedt",
+ install_requires=open("requirements.txt", "r").read().split("\n"),
+)
diff --git a/third_party/Roma/.gitignore b/third_party/Roma/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ff4633046059da69ab4b6e222909614ccda82ac4
--- /dev/null
+++ b/third_party/Roma/.gitignore
@@ -0,0 +1,5 @@
+*.egg-info*
+*.vscode*
+*__pycache__*
+vis*
+workspace*
\ No newline at end of file
diff --git a/third_party/Roma/LICENSE b/third_party/Roma/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a115f899f8d09ef3b1def4a16c7bae1a0bd50fbe
--- /dev/null
+++ b/third_party/Roma/LICENSE
@@ -0,0 +1,400 @@
+
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/third_party/Roma/README.md b/third_party/Roma/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5e984366c8f7af37615d7666f34cd82a90073fee
--- /dev/null
+++ b/third_party/Roma/README.md
@@ -0,0 +1,63 @@
+# RoMa: Revisiting Robust Losses for Dense Feature Matching
+### [Project Page (TODO)](https://parskatt.github.io/RoMa) | [Paper](https://arxiv.org/abs/2305.15404)
+
+
+> RoMa: Revisiting Robust Lossses for Dense Feature Matching
+> [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Qiyu Sun](https://scholar.google.com/citations?user=HS2WuHkAAAAJ), [Georg Bökman](https://scholar.google.com/citations?user=FUE3Wd0AAAAJ), [Mårten Wadenbäck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ)
+> Arxiv 2023
+
+**NOTE!!! Very early code, there might be bugs**
+
+The codebase is in the [roma folder](roma).
+
+## Setup/Install
+In your python environment (tested on Linux python 3.10), run:
+```bash
+pip install -e .
+```
+## Demo / How to Use
+We provide two demos in the [demos folder](demo).
+Here's the gist of it:
+```python
+from roma import roma_outdoor
+roma_model = roma_outdoor(device=device)
+# Match
+warp, certainty = roma_model.match(imA_path, imB_path, device=device)
+# Sample matches for estimation
+matches, certainty = roma_model.sample(warp, certainty)
+# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
+kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+# Find a fundamental matrix (or anything else of interest)
+F, mask = cv2.findFundamentalMat(
+ kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+)
+```
+## Reproducing Results
+The experiments in the paper are provided in the [experiments folder](experiments).
+
+### Training
+1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
+2. Run the relevant experiment, e.g.,
+```bash
+torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
+```
+### Testing
+```bash
+python experiments/roma_outdoor.py --only_test --benchmark mega-1500
+```
+## License
+Due to our dependency on [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE), the license is sadly non-commercial only for the moment.
+
+## Acknowledgement
+Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
+
+## BibTeX
+If you find our models useful, please consider citing our paper!
+```
+@article{edstedt2023roma,
+title={{RoMa}: Revisiting Robust Lossses for Dense Feature Matching},
+author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
+journal={arXiv preprint arXiv:2305.15404},
+year={2023}
+}
+```
diff --git a/third_party/Roma/assets/sacre_coeur_A.jpg b/third_party/Roma/assets/sacre_coeur_A.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6e441dad34cf13d8a29d7c6a1519f4263c40058c
--- /dev/null
+++ b/third_party/Roma/assets/sacre_coeur_A.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90d9c5f5a4d76425624989215120fba6f2899190a1d5654b88fa380c64cf6b2c
+size 117985
diff --git a/third_party/Roma/assets/sacre_coeur_B.jpg b/third_party/Roma/assets/sacre_coeur_B.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..27a239a8fa7581d909104872754ecda79422e7b6
--- /dev/null
+++ b/third_party/Roma/assets/sacre_coeur_B.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f1eb9bdd4d80e480f672d6a729689ac77f9fd5c8deb90f59b377590f3ca4799
+size 152515
diff --git a/third_party/Roma/data/.gitignore b/third_party/Roma/data/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3
--- /dev/null
+++ b/third_party/Roma/data/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
\ No newline at end of file
diff --git a/third_party/Roma/demo/demo_fundamental.py b/third_party/Roma/demo/demo_fundamental.py
new file mode 100644
index 0000000000000000000000000000000000000000..31618d4b06cd56fdd4be9065fb00b826a19e10f9
--- /dev/null
+++ b/third_party/Roma/demo/demo_fundamental.py
@@ -0,0 +1,33 @@
+from PIL import Image
+import torch
+import cv2
+from roma import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+
+ # Create model
+ roma_model = roma_outdoor(device=device)
+
+
+ W_A, H_A = Image.open(im1_path).size
+ W_B, H_B = Image.open(im2_path).size
+
+ # Match
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+ # Sample matches for estimation
+ matches, certainty = roma_model.sample(warp, certainty)
+ kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ F, mask = cv2.findFundamentalMat(
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+ )
\ No newline at end of file
diff --git a/third_party/Roma/demo/demo_match.py b/third_party/Roma/demo/demo_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..46413bb2b336e2ef2c0bc48315821e4de0fcb982
--- /dev/null
+++ b/third_party/Roma/demo/demo_match.py
@@ -0,0 +1,47 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from roma.utils.utils import tensor_to_pil
+
+from roma import roma_indoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+ parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+ save_path = args.save_path
+
+ # Create model
+ roma_model = roma_indoor(device=device)
+
+ H, W = roma_model.get_output_resolution()
+
+ im1 = Image.open(im1_path).resize((W, H))
+ im2 = Image.open(im2_path).resize((W, H))
+
+ # Match
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+ im2_transfer_rgb = F.grid_sample(
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+ )[0]
+ im1_transfer_rgb = F.grid_sample(
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+ )[0]
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+ white_im = torch.ones((H,2*W),device=device)
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
\ No newline at end of file
diff --git a/third_party/Roma/requirements.txt b/third_party/Roma/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..12addf0d0eb74e6cac0da6bca704eac0b28990d7
--- /dev/null
+++ b/third_party/Roma/requirements.txt
@@ -0,0 +1,13 @@
+torch
+einops
+torchvision
+opencv-python
+kornia
+albumentations
+loguru
+tqdm
+matplotlib
+h5py
+wandb
+timm
+xformers # Optional, used for memefficient attention
\ No newline at end of file
diff --git a/third_party/Roma/roma/__init__.py b/third_party/Roma/roma/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c96481e0a808b68c7b3054a3e34fa0b5c45ab9
--- /dev/null
+++ b/third_party/Roma/roma/__init__.py
@@ -0,0 +1,8 @@
+import os
+from .models import roma_outdoor, roma_indoor
+
+DEBUG_MODE = False
+RANK = int(os.environ.get('RANK', default = 0))
+GLOBAL_STEP = 0
+STEP_SIZE = 1
+LOCAL_RANK = -1
\ No newline at end of file
diff --git a/third_party/Roma/roma/benchmarks/__init__.py b/third_party/Roma/roma/benchmarks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..de7b841a5a6ab2ba91297a181a79dfaa91c9e104
--- /dev/null
+++ b/third_party/Roma/roma/benchmarks/__init__.py
@@ -0,0 +1,4 @@
+from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
+from .scannet_benchmark import ScanNetBenchmark
+from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
+from .megadepth_dense_benchmark import MegadepthDenseBenchmark
diff --git a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..2154a471c73d9e883c3ba8ed1b90d708f4950a63
--- /dev/null
+++ b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py
@@ -0,0 +1,113 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+from tqdm import tqdm
+from roma.utils import pose_auc
+import cv2
+
+
+class HpatchesHomogBenchmark:
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+ def __init__(self, dataset_path) -> None:
+ seqs_dir = "hpatches-sequences-release"
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
+ self.seq_names = sorted(os.listdir(self.seqs_path))
+ # Ignore seqs is same as LoFTR.
+ self.ignore_seqs = set(
+ [
+ "i_contruction",
+ "i_crownnight",
+ "i_dc",
+ "i_pencils",
+ "i_whitebuilding",
+ "v_artisans",
+ "v_astronautis",
+ "v_talent",
+ ]
+ )
+
+ def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
+ im_A_coords = (
+ np.stack(
+ (
+ wq * (im_A_coords[..., 0] + 1) / 2,
+ hq * (im_A_coords[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ im_A_to_im_B = (
+ np.stack(
+ (
+ wsup * (im_A_to_im_B[..., 0] + 1) / 2,
+ hsup * (im_A_to_im_B[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ return im_A_coords, im_A_to_im_B
+
+ def benchmark(self, model, model_name = None):
+ n_matches = []
+ homog_dists = []
+ for seq_idx, seq_name in tqdm(
+ enumerate(self.seq_names), total=len(self.seq_names)
+ ):
+ im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+ im_A = Image.open(im_A_path)
+ w1, h1 = im_A.size
+ for im_idx in range(2, 7):
+ im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+ im_B = Image.open(im_B_path)
+ w2, h2 = im_B.size
+ H = np.loadtxt(
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+ )
+ dense_matches, dense_certainty = model.match(
+ im_A_path, im_B_path
+ )
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
+ pos_a, pos_b = self.convert_coordinates(
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
+ )
+ try:
+ H_pred, inliers = cv2.findHomography(
+ pos_a,
+ pos_b,
+ method = cv2.RANSAC,
+ confidence = 0.99999,
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
+ )
+ except:
+ H_pred = None
+ if H_pred is None:
+ H_pred = np.zeros((3, 3))
+ H_pred[2, 2] = 1.0
+ corners = np.array(
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
+ )
+ real_warped_corners = np.dot(corners, np.transpose(H))
+ real_warped_corners = (
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
+ )
+ warped_corners = np.dot(corners, np.transpose(H_pred))
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
+ mean_dist = np.mean(
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
+ ) / (min(w2, h2) / 480.0)
+ homog_dists.append(mean_dist)
+
+ n_matches = np.array(n_matches)
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ auc = pose_auc(np.array(homog_dists), thresholds)
+ return {
+ "hpatches_homog_auc_3": auc[2],
+ "hpatches_homog_auc_5": auc[4],
+ "hpatches_homog_auc_10": auc[9],
+ }
diff --git a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..0600d354b1d0dfa7f8e2b0f8882a4cc08fafeed9
--- /dev/null
+++ b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py
@@ -0,0 +1,106 @@
+import torch
+import numpy as np
+import tqdm
+from roma.datasets import MegadepthBuilder
+from roma.utils import warp_kpts
+from torch.utils.data import ConcatDataset
+import roma
+
+class MegadepthDenseBenchmark:
+ def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
+ mega = MegadepthBuilder(data_root=data_root)
+ self.dataset = ConcatDataset(
+ mega.build_scenes(split="test_loftr", ht=h, wt=w)
+ ) # fixed resolution of 384,512
+ self.num_samples = num_samples
+
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
+ b, h1, w1, d = dense_matches.shape
+ with torch.no_grad():
+ x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
+ mask, x2 = warp_kpts(
+ x1.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ )
+ x2 = torch.stack(
+ (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
+ )
+ prob = mask.float().reshape(b, h1, w1)
+ x2_hat = dense_matches[..., 2:]
+ x2_hat = torch.stack(
+ (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
+ )
+ gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
+ gd = gd[prob == 1]
+ pck_1 = (gd < 1.0).float().mean()
+ pck_3 = (gd < 3.0).float().mean()
+ pck_5 = (gd < 5.0).float().mean()
+ return gd, pck_1, pck_3, pck_5, prob
+
+ def benchmark(self, model, batch_size=8):
+ model.train(False)
+ with torch.no_grad():
+ gd_tot = 0.0
+ pck_1_tot = 0.0
+ pck_3_tot = 0.0
+ pck_5_tot = 0.0
+ sampler = torch.utils.data.WeightedRandomSampler(
+ torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+ )
+ B = batch_size
+ dataloader = torch.utils.data.DataLoader(
+ self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
+ )
+ for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0):
+ im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
+ data["im_A"],
+ data["im_B"],
+ data["im_A_depth"].cuda(),
+ data["im_B_depth"].cuda(),
+ data["T_1to2"].cuda(),
+ data["K1"].cuda(),
+ data["K2"].cuda(),
+ )
+ matches, certainty = model.match(im_A, im_B, batched=True)
+ gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
+ depth1, depth2, T_1to2, K1, K2, matches
+ )
+ if roma.DEBUG_MODE:
+ from roma.utils.utils import tensor_to_pil
+ import torch.nn.functional as F
+ path = "vis"
+ H, W = model.get_output_resolution()
+ white_im = torch.ones((B,1,H,W),device="cuda")
+ im_B_transfer_rgb = F.grid_sample(
+ im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
+ )
+ warp_im = im_B_transfer_rgb
+ c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
+ vis_im = c_b * warp_im + (1 - c_b) * white_im
+ for b in range(B):
+ import os
+ os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
+ tensor_to_pil(vis_im[b], unnormalize=True).save(
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
+ tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
+ tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
+
+
+ gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
+ gd_tot + gd.mean(),
+ pck_1_tot + pck_1,
+ pck_3_tot + pck_3,
+ pck_5_tot + pck_5,
+ )
+ return {
+ "epe": gd_tot.item() / len(dataloader),
+ "mega_pck_1": pck_1_tot.item() / len(dataloader),
+ "mega_pck_3": pck_3_tot.item() / len(dataloader),
+ "mega_pck_5": pck_5_tot.item() / len(dataloader),
+ }
diff --git a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..8007fe8ecad09c33401450ad6b7af1f3dad043d2
--- /dev/null
+++ b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py
@@ -0,0 +1,140 @@
+import numpy as np
+import torch
+from roma.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+import roma
+import kornia.geometry.epipolar as kepi
+
+class MegaDepthPoseEstimationBenchmark:
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+ if scene_names is None:
+ self.scene_names = [
+ "0015_0.1_0.3.npz",
+ "0015_0.3_0.5.npz",
+ "0022_0.1_0.3.npz",
+ "0022_0.3_0.5.npz",
+ "0022_0.5_0.7.npz",
+ ]
+ else:
+ self.scene_names = scene_names
+ self.scenes = [
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
+ for scene in self.scene_names
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True):
+ H,W = model.get_output_resolution()
+ with torch.no_grad():
+ data_root = self.data_root
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ thresholds = [5, 10, 20]
+ for scene_ind in range(len(self.scenes)):
+ import os
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+ scene = self.scenes[scene_ind]
+ pairs = scene["pair_infos"]
+ intrinsics = scene["intrinsics"]
+ poses = scene["poses"]
+ im_paths = scene["image_paths"]
+ pair_inds = range(len(pairs))
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind][0]
+ K1 = intrinsics[idx1].copy()
+ T1 = poses[idx1].copy()
+ R1, t1 = T1[:3, :3], T1[:3, 3]
+ K2 = intrinsics[idx2].copy()
+ T2 = poses[idx2].copy()
+ R2, t2 = T2[:3, :3], T2[:3, 3]
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
+ dense_matches, dense_certainty = model.match(
+ im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
+ )
+ sparse_matches,_ = model.sample(
+ dense_matches, dense_certainty, 5000
+ )
+
+ im_A = Image.open(im_A_path)
+ w1, h1 = im_A.size
+ im_B = Image.open(im_B_path)
+ w2, h2 = im_B.size
+
+ if scale_intrinsics:
+ scale1 = 1200 / max(w1, h1)
+ scale2 = 1200 / max(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1, K2 = K1.copy(), K2.copy()
+ K1[:2] = K1[:2] * scale1
+ K2[:2] = K2[:2] * scale2
+
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = (
+ np.stack(
+ (
+ w1 * (kpts1[:, 0] + 1) / 2,
+ h1 * (kpts1[:, 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = (
+ np.stack(
+ (
+ w2 * (kpts2[:, 0] + 1) / 2,
+ h2 * (kpts2[:, 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ threshold = 0.5
+ if calibrated:
+ norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ print(f"{model_name} auc: {auc}")
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/third_party/Roma/roma/benchmarks/scannet_benchmark.py b/third_party/Roma/roma/benchmarks/scannet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..853af0d0ebef4dfefe2632eb49e4156ea791ee76
--- /dev/null
+++ b/third_party/Roma/roma/benchmarks/scannet_benchmark.py
@@ -0,0 +1,143 @@
+import os.path as osp
+import numpy as np
+import torch
+from roma.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class ScanNetBenchmark:
+ def __init__(self, data_root="data/scannet") -> None:
+ self.data_root = data_root
+
+ def benchmark(self, model, model_name = None):
+ model.train(False)
+ with torch.no_grad():
+ data_root = self.data_root
+ tmp = np.load(osp.join(data_root, "test.npz"))
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ pair_inds = np.random.choice(
+ range(len(pairs)), size=len(pairs), replace=False
+ )
+ for pairind in tqdm(pair_inds, smoothing=0.9):
+ scene = pairs[pairind]
+ scene_name = f"scene0{scene[0]}_00"
+ im_A_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[2]}.jpg",
+ )
+ im_A = Image.open(im_A_path)
+ im_B_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[3]}.jpg",
+ )
+ im_B = Image.open(im_B_path)
+ T_gt = rel_pose[pairind].reshape(3, 4)
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
+ K = np.stack(
+ [
+ np.array([float(i) for i in r.split()])
+ for r in open(
+ osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "intrinsic",
+ "intrinsic_color.txt",
+ ),
+ "r",
+ )
+ .read()
+ .split("\n")
+ if r
+ ]
+ )
+ w1, h1 = im_A.size
+ w2, h2 = im_B.size
+ K1 = K.copy()
+ K2 = K.copy()
+ dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
+ sparse_matches, sparse_certainty = model.sample(
+ dense_matches, dense_certainty, 5000
+ )
+ scale1 = 480 / min(w1, h1)
+ scale2 = 480 / min(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1 = K1 * scale1
+ K2 = K2 * scale2
+
+ offset = 0.5
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = (
+ np.stack(
+ (
+ w1 * (kpts1[:, 0] + 1) / 2 - offset,
+ h1 * (kpts1[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = (
+ np.stack(
+ (
+ w2 * (kpts2[:, 0] + 1) / 2 - offset,
+ h2 * (kpts2[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ norm_threshold = 0.5 / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/third_party/Roma/roma/checkpointing/__init__.py b/third_party/Roma/roma/checkpointing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577
--- /dev/null
+++ b/third_party/Roma/roma/checkpointing/__init__.py
@@ -0,0 +1 @@
+from .checkpoint import CheckPoint
diff --git a/third_party/Roma/roma/checkpointing/checkpoint.py b/third_party/Roma/roma/checkpointing/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..8995efeb54f4d558127ea63423fa958c64e9088f
--- /dev/null
+++ b/third_party/Roma/roma/checkpointing/checkpoint.py
@@ -0,0 +1,60 @@
+import os
+import torch
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+from loguru import logger
+import gc
+
+import roma
+
+class CheckPoint:
+ def __init__(self, dir=None, name="tmp"):
+ self.name = name
+ self.dir = dir
+ os.makedirs(self.dir, exist_ok=True)
+
+ def save(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ if roma.RANK == 0:
+ assert model is not None
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
+ model = model.module
+ states = {
+ "model": model.state_dict(),
+ "n": n,
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(states, self.dir + self.name + f"_latest.pth")
+ logger.info(f"Saved states {list(states.keys())}, at step {n}")
+
+ def load(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0:
+ states = torch.load(self.dir + self.name + f"_latest.pth")
+ if "model" in states:
+ model.load_state_dict(states["model"])
+ if "n" in states:
+ n = states["n"] if states["n"] else n
+ if "optimizer" in states:
+ try:
+ optimizer.load_state_dict(states["optimizer"])
+ except Exception as e:
+ print(f"Failed to load states for optimizer, with error {e}")
+ if "lr_scheduler" in states:
+ lr_scheduler.load_state_dict(states["lr_scheduler"])
+ print(f"Loaded states {list(states.keys())}, at step {n}")
+ del states
+ gc.collect()
+ torch.cuda.empty_cache()
+ return model, optimizer, lr_scheduler, n
\ No newline at end of file
diff --git a/third_party/Roma/roma/datasets/__init__.py b/third_party/Roma/roma/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60c709926a4a7bd019b73eac10879063a996c90
--- /dev/null
+++ b/third_party/Roma/roma/datasets/__init__.py
@@ -0,0 +1,2 @@
+from .megadepth import MegadepthBuilder
+from .scannet import ScanNetBuilder
\ No newline at end of file
diff --git a/third_party/Roma/roma/datasets/megadepth.py b/third_party/Roma/roma/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..5deee5ac30c439a9f300c0ad2271f141931020c0
--- /dev/null
+++ b/third_party/Roma/roma/datasets/megadepth.py
@@ -0,0 +1,230 @@
+import os
+from PIL import Image
+import h5py
+import numpy as np
+import torch
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+import roma
+from roma.utils import *
+import math
+
+class MegadepthScene:
+ def __init__(
+ self,
+ data_root,
+ scene_info,
+ ht=384,
+ wt=512,
+ min_overlap=0.0,
+ max_overlap=1.0,
+ shake_t=0,
+ rot_prob=0.0,
+ normalize=True,
+ max_num_pairs = 100_000,
+ scene_name = None,
+ use_horizontal_flip_aug = False,
+ use_single_horizontal_flip_aug = False,
+ colorjiggle_params = None,
+ random_eraser = None,
+ use_randaug = False,
+ randaug_params = None,
+ randomize_size = False,
+ ) -> None:
+ self.data_root = data_root
+ self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
+ self.image_paths = scene_info["image_paths"]
+ self.depth_paths = scene_info["depth_paths"]
+ self.intrinsics = scene_info["intrinsics"]
+ self.poses = scene_info["poses"]
+ self.pairs = scene_info["pairs"]
+ self.overlaps = scene_info["overlaps"]
+ threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
+ self.pairs = self.pairs[threshold]
+ self.overlaps = self.overlaps[threshold]
+ if len(self.pairs) > max_num_pairs:
+ pairinds = np.random.choice(
+ np.arange(0, len(self.pairs)), max_num_pairs, replace=False
+ )
+ self.pairs = self.pairs[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ if randomize_size:
+ area = ht * wt
+ s = int(16 * (math.sqrt(area)//16))
+ sizes = ((ht,wt), (s,s), (wt,ht))
+ choice = roma.RANK % 3
+ ht, wt = sizes[choice]
+ # counts, bins = np.histogram(self.overlaps,20)
+ # print(counts)
+ self.im_transform_ops = get_tuple_transform_ops(
+ resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
+ )
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
+ resize=(ht, wt)
+ )
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.random_eraser = random_eraser
+ if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
+ raise ValueError("Can't both flip both images and only flip one")
+ self.use_horizontal_flip_aug = use_horizontal_flip_aug
+ self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
+ self.use_randaug = use_randaug
+
+ def load_im(self, im_path):
+ im = Image.open(im_path)
+ return im
+
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+ im_A = im_A.flip(-1)
+ im_B = im_B.flip(-1)
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+ K_A = flip_mat@K_A
+ K_B = flip_mat@K_B
+
+ return im_A, im_B, depth_A, depth_B, K_A, K_B
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
+ return torch.from_numpy(depth)
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+ return sK @ K
+
+ def rand_shake(self, *things):
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
+ return [
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
+ for thing in things
+ ], t
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ idx1, idx2 = self.pairs[pair_idx]
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T1 = self.poses[idx1]
+ T2 = self.poses[idx2]
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+ :4, :4
+ ] # (4, 4)
+
+ # Load positive pair data
+ im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
+ im_A_ref = os.path.join(self.data_root, im_A)
+ im_B_ref = os.path.join(self.data_root, im_B)
+ depth_A_ref = os.path.join(self.data_root, depth1)
+ depth_B_ref = os.path.join(self.data_root, depth2)
+ im_A = self.load_im(im_A_ref)
+ im_B = self.load_im(im_B_ref)
+ K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+ K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+
+ if self.use_randaug:
+ im_A, im_B = self.rand_augment(im_A, im_B)
+
+ depth_A = self.load_depth(depth_A_ref)
+ depth_B = self.load_depth(depth_B_ref)
+ # Process images
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
+ depth_A, depth_B = self.depth_transform_ops(
+ (depth_A[None, None], depth_B[None, None])
+ )
+
+ [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
+ K1[:2, 2] += t
+ K2[:2, 2] += t
+
+ im_A, im_B = im_A[None], im_B[None]
+ if self.random_eraser is not None:
+ im_A, depth_A = self.random_eraser(im_A, depth_A)
+ im_B, depth_B = self.random_eraser(im_B, depth_B)
+
+ if self.use_horizontal_flip_aug:
+ if np.random.rand() > 0.5:
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+ if self.use_single_horizontal_flip_aug:
+ if np.random.rand() > 0.5:
+ im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
+
+ if roma.DEBUG_MODE:
+ tensor_to_pil(im_A[0], unnormalize=True).save(
+ f"vis/im_A.jpg")
+ tensor_to_pil(im_B[0], unnormalize=True).save(
+ f"vis/im_B.jpg")
+
+ data_dict = {
+ "im_A": im_A[0],
+ "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+ "im_B": im_B[0],
+ "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
+ "im_A_depth": depth_A[0, 0],
+ "im_B_depth": depth_B[0, 0],
+ "K1": K1,
+ "K2": K2,
+ "T_1to2": T_1to2,
+ "im_A_path": im_A_ref,
+ "im_B_path": im_B_ref,
+
+ }
+ return data_dict
+
+
+class MegadepthBuilder:
+ def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
+ self.all_scenes = os.listdir(self.scene_info_root)
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
+ # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
+ self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
+ self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
+ self.loftr_ignore = loftr_ignore
+ self.imc21_ignore = imc21_ignore
+
+ def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
+ if split == "train":
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
+ elif split == "train_loftr":
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
+ elif split == "test":
+ scene_names = self.test_scenes
+ elif split == "test_loftr":
+ scene_names = self.test_scenes_loftr
+ elif split == "custom":
+ scene_names = scene_names
+ else:
+ raise ValueError(f"Split {split} not available")
+ scenes = []
+ for scene_name in scene_names:
+ if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
+ continue
+ if self.imc21_ignore and scene_name in self.imc21_scenes:
+ continue
+ scene_info = np.load(
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+ ).item()
+ scenes.append(
+ MegadepthScene(
+ self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
+ )
+ )
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=0.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
+ return ws
diff --git a/third_party/Roma/roma/datasets/scannet.py b/third_party/Roma/roma/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..704ea57259afdfbbca627ad143bee97a0a79d41c
--- /dev/null
+++ b/third_party/Roma/roma/datasets/scannet.py
@@ -0,0 +1,160 @@
+import os
+import random
+from PIL import Image
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset)
+
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+import os.path as osp
+import matplotlib.pyplot as plt
+import roma
+from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+from roma.utils.transforms import GeometricSequential
+from tqdm import tqdm
+
+class ScanNetScene:
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
+) -> None:
+ self.scene_root = osp.join(data_root,"scans","scans_train")
+ self.data_names = scene_info['name']
+ self.overlaps = scene_info['score']
+ # Only sample 10s
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+ self.overlaps = self.overlaps[valid]
+ self.data_names = self.data_names[valid]
+ if len(self.data_names) > 10000:
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+ self.data_names = self.data_names[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+ self.use_horizontal_flip_aug = use_horizontal_flip_aug
+
+ def load_im(self, im_B, crop=None):
+ im = Image.open(im_B)
+ return im
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1]])
+ return sK@K
+
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+ im_A = im_A.flip(-1)
+ im_B = im_B.flip(-1)
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+ K_A = flip_mat@K_A
+ K_B = flip_mat@K_B
+
+ return im_A, im_B, depth_A, depth_B, K_A, K_B
+ def read_scannet_pose(self,path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = np.linalg.inv(cam2world)
+ return world2cam
+
+
+ def read_scannet_intrinsic(self,path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ data_name = self.data_names[pair_idx]
+ scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the intrinsic of depthmap
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
+ scene_name,
+ 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+ # read and compute relative poses
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_1}.txt'))
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_2}.txt'))
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
+
+ # Load positive pair data
+ im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
+ im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
+ depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
+ depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+
+ im_A = self.load_im(im_A_ref)
+ im_B = self.load_im(im_B_ref)
+ depth_A = self.load_depth(depth_A_ref)
+ depth_B = self.load_depth(depth_B_ref)
+
+ # Recompute camera intrinsic matrix due to the resize
+ K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+ K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+ # Process images
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
+ depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
+ if self.use_horizontal_flip_aug:
+ if np.random.rand() > 0.5:
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+
+ data_dict = {'im_A': im_A,
+ 'im_B': im_B,
+ 'im_A_depth': depth_A[0,0],
+ 'im_B_depth': depth_B[0,0],
+ 'K1': K1,
+ 'K2': K2,
+ 'T_1to2':T_1to2,
+ }
+ return data_dict
+
+
+class ScanNetBuilder:
+ def __init__(self, data_root = 'data/scannet') -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
+ self.all_scenes = os.listdir(self.scene_info_root)
+
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+ # Note: split doesn't matter here as we always use same scannet_train scenes
+ scene_names = self.all_scenes
+ scenes = []
+ for scene_name in tqdm(scene_names, disable = roma.RANK > 0):
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+ return ws
diff --git a/third_party/Roma/roma/losses/__init__.py b/third_party/Roma/roma/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e08abacfc0f83d7de0f2ddc0583766a80bf53cf
--- /dev/null
+++ b/third_party/Roma/roma/losses/__init__.py
@@ -0,0 +1 @@
+from .robust_loss import RobustLosses
\ No newline at end of file
diff --git a/third_party/Roma/roma/losses/robust_loss.py b/third_party/Roma/roma/losses/robust_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b932b2706f619c083485e1be0d86eec44ead83ef
--- /dev/null
+++ b/third_party/Roma/roma/losses/robust_loss.py
@@ -0,0 +1,157 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from roma.utils.utils import get_gt_warp
+import wandb
+import roma
+import math
+
+class RobustLosses(nn.Module):
+ def __init__(
+ self,
+ robust=False,
+ center_coords=False,
+ scale_normalize=False,
+ ce_weight=0.01,
+ local_loss=True,
+ local_dist=4.0,
+ local_largest_scale=8,
+ smooth_mask = False,
+ depth_interpolation_mode = "bilinear",
+ mask_depth_loss = False,
+ relative_depth_error_threshold = 0.05,
+ alpha = 1.,
+ c = 1e-3,
+ ):
+ super().__init__()
+ self.robust = robust # measured in pixels
+ self.center_coords = center_coords
+ self.scale_normalize = scale_normalize
+ self.ce_weight = ce_weight
+ self.local_loss = local_loss
+ self.local_dist = local_dist
+ self.local_largest_scale = local_largest_scale
+ self.smooth_mask = smooth_mask
+ self.depth_interpolation_mode = depth_interpolation_mode
+ self.mask_depth_loss = mask_depth_loss
+ self.relative_depth_error_threshold = relative_depth_error_threshold
+ self.avg_overlap = dict()
+ self.alpha = alpha
+ self.c = c
+
+ def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
+ with torch.no_grad():
+ B, C, H, W = scale_gm_cls.shape
+ device = x2.device
+ cls_res = round(math.sqrt(C))
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
+ GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
+ cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
+ if not torch.any(cls_loss):
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
+
+ certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
+ losses = {
+ f"gm_certainty_loss_{scale}": certainty_loss.mean(),
+ f"gm_cls_loss_{scale}": cls_loss.mean(),
+ }
+ wandb.log(losses, step = roma.GLOBAL_STEP)
+ return losses
+
+ def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
+ with torch.no_grad():
+ B, C, H, W = delta_cls.shape
+ device = x2.device
+ cls_res = round(math.sqrt(C))
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
+ GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
+ cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
+ if not torch.any(cls_loss):
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
+ certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
+ losses = {
+ f"delta_certainty_loss_{scale}": certainty_loss.mean(),
+ f"delta_cls_loss_{scale}": cls_loss.mean(),
+ }
+ wandb.log(losses, step = roma.GLOBAL_STEP)
+ return losses
+
+ def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
+ epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+ if scale == 1:
+ pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
+ wandb.log({"train_pck_05": pck_05}, step = roma.GLOBAL_STEP)
+
+ ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
+ a = self.alpha
+ cs = self.c * scale
+ x = epe[prob > 0.99]
+ reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
+ if not torch.any(reg_loss):
+ reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
+ losses = {
+ f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
+ f"{mode}_regression_loss_{scale}": reg_loss.mean(),
+ }
+ wandb.log(losses, step = roma.GLOBAL_STEP)
+ return losses
+
+ def forward(self, corresps, batch):
+ scales = list(corresps.keys())
+ tot_loss = 0.0
+ # scale_weights due to differences in scale for regression gradients and classification gradients
+ scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
+ for scale in scales:
+ scale_corresps = corresps[scale]
+ scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
+ scale_corresps["certainty"],
+ scale_corresps["flow_pre_delta"],
+ scale_corresps.get("delta_cls"),
+ scale_corresps.get("offset_scale"),
+ scale_corresps.get("gm_cls"),
+ scale_corresps.get("gm_certainty"),
+ scale_corresps["flow"],
+ scale_corresps.get("gm_flow"),
+
+ )
+ flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
+ b, h, w, d = flow_pre_delta.shape
+ gt_warp, gt_prob = get_gt_warp(
+ batch["im_A_depth"],
+ batch["im_B_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],
+ H=h,
+ W=w,
+ )
+ x2 = gt_warp.float()
+ prob = gt_prob
+
+ if self.local_largest_scale >= scale:
+ prob = prob * (
+ F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
+ < (2 / 512) * (self.local_dist[scale] * scale))
+
+ if scale_gm_cls is not None:
+ gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
+ gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * gm_loss
+ elif scale_gm_flow is not None:
+ gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
+ gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * gm_loss
+
+ if delta_cls is not None:
+ delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
+ delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
+ else:
+ delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
+ reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * reg_loss
+ prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
+ return tot_loss
diff --git a/third_party/Roma/roma/models/__init__.py b/third_party/Roma/roma/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f20461e2f3a1722e558cefab94c5164be8842c3
--- /dev/null
+++ b/third_party/Roma/roma/models/__init__.py
@@ -0,0 +1 @@
+from .model_zoo import roma_outdoor, roma_indoor
\ No newline at end of file
diff --git a/third_party/Roma/roma/models/encoders.py b/third_party/Roma/roma/models/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..69b488743b91905aca6adc3e4d3439421d492051
--- /dev/null
+++ b/third_party/Roma/roma/models/encoders.py
@@ -0,0 +1,118 @@
+from typing import Optional, Union
+import torch
+from torch import device
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as tvm
+import gc
+
+
+class ResNet50(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
+ super().__init__()
+ if dilation is None:
+ dilation = [False,False,False]
+ if anti_aliased:
+ pass
+ else:
+ if weights is not None:
+ self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+ else:
+ self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
+ self.high_res = high_res
+ self.freeze_bn = freeze_bn
+ self.early_exit = early_exit
+ self.amp = amp
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+ def forward(self, x, **kwargs):
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+ net = self.net
+ feats = {1:x}
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x
+ x = net.layer2(x)
+ feats[8] = x
+ if self.early_exit:
+ return feats
+ x = net.layer3(x)
+ feats[16] = x
+ x = net.layer4(x)
+ feats[32] = x
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ if self.freeze_bn:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+class VGG19(nn.Module):
+ def __init__(self, pretrained=False, amp = False) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+ self.amp = amp
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+ def forward(self, x, **kwargs):
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+ feats = {}
+ scale = 1
+ for layer in self.layers:
+ if isinstance(layer, nn.MaxPool2d):
+ feats[scale] = x
+ scale = scale*2
+ x = layer(x)
+ return feats
+
+class CNNandDinov2(nn.Module):
+ def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
+ super().__init__()
+ if dinov2_weights is None:
+ dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
+ from .transformer import vit_large
+ vit_kwargs = dict(img_size= 518,
+ patch_size= 14,
+ init_values = 1.0,
+ ffn_layer = "mlp",
+ block_chunks = 0,
+ )
+
+ dinov2_vitl14 = vit_large(**vit_kwargs).eval()
+ dinov2_vitl14.load_state_dict(dinov2_weights)
+ cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
+ if not use_vgg:
+ self.cnn = ResNet50(**cnn_kwargs)
+ else:
+ self.cnn = VGG19(**cnn_kwargs)
+ self.amp = amp
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ if self.amp:
+ dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+
+
+ def train(self, mode: bool = True):
+ return self.cnn.train(mode)
+
+ def forward(self, x, upsample = False):
+ B,C,H,W = x.shape
+ feature_pyramid = self.cnn(x)
+
+ if not upsample:
+ with torch.no_grad():
+ if self.dinov2_vitl14[0].device != x.device:
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
+ del dinov2_features_16
+ feature_pyramid[16] = features_16
+ return feature_pyramid
\ No newline at end of file
diff --git a/third_party/Roma/roma/models/matcher.py b/third_party/Roma/roma/models/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..c06e1ba3aebe8dec7ee9f1800a6f4ba55ac8f0d9
--- /dev/null
+++ b/third_party/Roma/roma/models/matcher.py
@@ -0,0 +1,649 @@
+import os
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import warnings
+from warnings import warn
+
+import roma
+from roma.utils import get_tuple_transform_ops
+from roma.utils.local_correlation import local_correlation
+from roma.utils.utils import cls_to_flow_refine
+from roma.utils.kde import kde
+
+class ConvRefiner(nn.Module):
+ def __init__(
+ self,
+ in_dim=6,
+ hidden_dim=16,
+ out_dim=2,
+ dw=False,
+ kernel_size=5,
+ hidden_blocks=3,
+ displacement_emb = None,
+ displacement_emb_dim = None,
+ local_corr_radius = None,
+ corr_in_other = None,
+ no_im_B_fm = False,
+ amp = False,
+ concat_logits = False,
+ use_bias_block_1 = True,
+ use_cosine_corr = False,
+ disable_local_corr_grad = False,
+ is_classifier = False,
+ sample_mode = "bilinear",
+ norm_type = nn.BatchNorm2d,
+ bn_momentum = 0.1,
+ ):
+ super().__init__()
+ self.bn_momentum = bn_momentum
+ self.block1 = self.create_block(
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
+ )
+ self.hidden_blocks = nn.Sequential(
+ *[
+ self.create_block(
+ hidden_dim,
+ hidden_dim,
+ dw=dw,
+ kernel_size=kernel_size,
+ norm_type=norm_type,
+ )
+ for hb in range(hidden_blocks)
+ ]
+ )
+ self.hidden_blocks = self.hidden_blocks
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
+ if displacement_emb:
+ self.has_displacement_emb = True
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
+ else:
+ self.has_displacement_emb = False
+ self.local_corr_radius = local_corr_radius
+ self.corr_in_other = corr_in_other
+ self.no_im_B_fm = no_im_B_fm
+ self.amp = amp
+ self.concat_logits = concat_logits
+ self.use_cosine_corr = use_cosine_corr
+ self.disable_local_corr_grad = disable_local_corr_grad
+ self.is_classifier = is_classifier
+ self.sample_mode = sample_mode
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+ def create_block(
+ self,
+ in_dim,
+ out_dim,
+ dw=False,
+ kernel_size=5,
+ bias = True,
+ norm_type = nn.BatchNorm2d,
+ ):
+ num_groups = 1 if not dw else in_dim
+ if dw:
+ assert (
+ out_dim % in_dim == 0
+ ), "outdim must be divisible by indim for depthwise"
+ conv1 = nn.Conv2d(
+ in_dim,
+ out_dim,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ groups=num_groups,
+ bias=bias,
+ )
+ norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
+ relu = nn.ReLU(inplace=True)
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
+ return nn.Sequential(conv1, norm, relu, conv2)
+
+ def forward(self, x, y, flow, scale_factor = 1, logits = None):
+ b,c,hs,ws = x.shape
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+ with torch.no_grad():
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
+ if self.has_displacement_emb:
+ im_A_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+ )
+ )
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
+ in_displacement = flow-im_A_coords
+ emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
+ if self.local_corr_radius:
+ if self.corr_in_other:
+ # Corr in other means take a kxk grid around the predicted coordinate in other image
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
+ sample_mode = self.sample_mode)
+ else:
+ raise NotImplementedError("Local corr in own frame should not be used.")
+ if self.no_im_B_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
+ else:
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
+ else:
+ if self.no_im_B_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat), dim=1)
+ if self.concat_logits:
+ d = torch.cat((d, logits), dim=1)
+ d = self.block1(d)
+ d = self.hidden_blocks(d)
+ d = self.out_conv(d.float())
+ displacement, certainty = d[:, :-1], d[:, -1:]
+ return displacement, certainty
+
+class CosKernel(nn.Module): # similar to softmax kernel
+ def __init__(self, T, learn_temperature=False):
+ super().__init__()
+ self.learn_temperature = learn_temperature
+ if self.learn_temperature:
+ self.T = nn.Parameter(torch.tensor(T))
+ else:
+ self.T = T
+
+ def __call__(self, x, y, eps=1e-6):
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
+ )
+ if self.learn_temperature:
+ T = self.T.abs() + 0.01
+ else:
+ T = torch.tensor(self.T, device=c.device)
+ K = ((c - 1.0) / T).exp()
+ return K
+
+class GP(nn.Module):
+ def __init__(
+ self,
+ kernel,
+ T=1,
+ learn_temperature=False,
+ only_attention=False,
+ gp_dim=64,
+ basis="fourier",
+ covar_size=5,
+ only_nearest_neighbour=False,
+ sigma_noise=0.1,
+ no_cov=False,
+ predict_features = False,
+ ):
+ super().__init__()
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
+ self.sigma_noise = sigma_noise
+ self.covar_size = covar_size
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
+ self.only_attention = only_attention
+ self.only_nearest_neighbour = only_nearest_neighbour
+ self.basis = basis
+ self.no_cov = no_cov
+ self.dim = gp_dim
+ self.predict_features = predict_features
+
+ def get_local_cov(self, cov):
+ K = self.covar_size
+ b, h, w, h, w = cov.shape
+ hw = h * w
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
+ delta = torch.stack(
+ torch.meshgrid(
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
+ ),
+ dim=-1,
+ )
+ positions = torch.stack(
+ torch.meshgrid(
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
+ ),
+ dim=-1,
+ )
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
+ :,
+ points.flatten(),
+ neighbours[..., 0].flatten(),
+ neighbours[..., 1].flatten(),
+ ].reshape(b, h, w, K**2)
+ return local_cov
+
+ def reshape(self, x):
+ return rearrange(x, "b d h w -> b (h w) d")
+
+ def project_to_basis(self, x):
+ if self.basis == "fourier":
+ return torch.cos(8 * math.pi * self.pos_conv(x))
+ elif self.basis == "linear":
+ return self.pos_conv(x)
+ else:
+ raise ValueError(
+ "No other bases other than fourier and linear currently im_Bed in public release"
+ )
+
+ def get_pos_enc(self, y):
+ b, c, h, w = y.shape
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
+ )
+ )
+
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
+ return coarse_embedded_coords
+
+ def forward(self, x, y, **kwargs):
+ b, c, h1, w1 = x.shape
+ b, c, h2, w2 = y.shape
+ f = self.get_pos_enc(y)
+ b, d, h2, w2 = f.shape
+ x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
+ K_xx = self.K(x, x)
+ K_yy = self.K(y, y)
+ K_xy = self.K(x, y)
+ K_yx = K_xy.permute(0, 2, 1)
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
+ with warnings.catch_warnings():
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
+
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
+ if not self.no_cov:
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
+ local_cov_x = self.get_local_cov(cov_x)
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
+ else:
+ gp_feats = mu_x
+ return gp_feats
+
+class Decoder(nn.Module):
+ def __init__(
+ self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
+ num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
+ flow_upsample_mode = "bilinear"
+ ):
+ super().__init__()
+ self.embedding_decoder = embedding_decoder
+ self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
+ self.gps = gps
+ self.proj = proj
+ self.conv_refiner = conv_refiner
+ self.detach = detach
+ if pos_embeddings is None:
+ self.pos_embeddings = {}
+ else:
+ self.pos_embeddings = pos_embeddings
+ if scales == "all":
+ self.scales = ["32", "16", "8", "4", "2", "1"]
+ else:
+ self.scales = scales
+ self.warp_noise_std = warp_noise_std
+ self.refine_init = 4
+ self.displacement_dropout_p = displacement_dropout_p
+ self.gm_warp_dropout_p = gm_warp_dropout_p
+ self.flow_upsample_mode = flow_upsample_mode
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+ def get_placeholder_flow(self, b, h, w, device):
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ )
+ )
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ return coarse_coords
+
+ def get_positional_embedding(self, b, h ,w, device):
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ )
+ )
+
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ coarse_embedded_coords = self.pos_embedding(coarse_coords)
+ return coarse_embedded_coords
+
+ def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
+ coarse_scales = self.embedding_decoder.scales()
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
+ h, w = sizes[1]
+ b = f1[1].shape[0]
+ device = f1[1].device
+ coarsest_scale = int(all_scales[0])
+ old_stuff = torch.zeros(
+ b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
+ )
+ corresps = {}
+ if not upsample:
+ flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
+ certainty = 0.0
+ else:
+ flow = F.interpolate(
+ flow,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ certainty = F.interpolate(
+ certainty,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ displacement = 0.0
+ for new_scale in all_scales:
+ ins = int(new_scale)
+ corresps[ins] = {}
+ f1_s, f2_s = f1[ins], f2[ins]
+ if new_scale in self.proj:
+ with torch.autocast("cuda", self.amp_dtype):
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
+
+ if ins in coarse_scales:
+ old_stuff = F.interpolate(
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
+ )
+ gp_posterior = self.gps[new_scale](f1_s, f2_s)
+ gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
+ gp_posterior, f1_s, old_stuff, new_scale
+ )
+
+ if self.embedding_decoder.is_classifier:
+ flow = cls_to_flow_refine(
+ gm_warp_or_cls,
+ ).permute(0,3,1,2)
+ corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+ else:
+ corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+ flow = gm_warp_or_cls.detach()
+
+ if new_scale in self.conv_refiner:
+ corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
+ delta_flow, delta_certainty = self.conv_refiner[new_scale](
+ f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
+ )
+ corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
+ displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
+ delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
+ flow = flow + displacement
+ certainty = (
+ certainty + delta_certainty
+ ) # predict both certainty and displacement
+ corresps[ins].update({
+ "certainty": certainty,
+ "flow": flow,
+ })
+ if new_scale != "1":
+ flow = F.interpolate(
+ flow,
+ size=sizes[ins // 2],
+ mode=self.flow_upsample_mode,
+ )
+ certainty = F.interpolate(
+ certainty,
+ size=sizes[ins // 2],
+ mode=self.flow_upsample_mode,
+ )
+ if self.detach:
+ flow = flow.detach()
+ certainty = certainty.detach()
+ #torch.cuda.empty_cache()
+ return corresps
+
+
+class RegressionMatcher(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ h=448,
+ w=448,
+ sample_mode = "threshold",
+ upsample_preds = False,
+ symmetric = False,
+ name = None,
+ attenuate_cert = None,
+ ):
+ super().__init__()
+ self.attenuate_cert = attenuate_cert
+ self.encoder = encoder
+ self.decoder = decoder
+ self.name = name
+ self.w_resized = w
+ self.h_resized = h
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
+ self.sample_mode = sample_mode
+ self.upsample_preds = upsample_preds
+ self.upsample_res = (14*16*6, 14*16*6)
+ self.symmetric = symmetric
+ self.sample_thresh = 0.05
+
+ def get_output_resolution(self):
+ if not self.upsample_preds:
+ return self.h_resized, self.w_resized
+ else:
+ return self.upsample_res
+
+ def extract_backbone_features(self, batch, batched = True, upsample = False):
+ x_q = batch["im_A"]
+ x_s = batch["im_B"]
+ if batched:
+ X = torch.cat((x_q, x_s), dim = 0)
+ feature_pyramid = self.encoder(X, upsample = upsample)
+ else:
+ feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
+ return feature_pyramid
+
+ def sample(
+ self,
+ matches,
+ certainty,
+ num=10000,
+ ):
+ if "threshold" in self.sample_mode:
+ upper_thresh = self.sample_thresh
+ certainty = certainty.clone()
+ certainty[certainty > upper_thresh] = 1
+ matches, certainty = (
+ matches.reshape(-1, 4),
+ certainty.reshape(-1),
+ )
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
+ good_samples = torch.multinomial(certainty,
+ num_samples = min(expansion_factor*num, len(certainty)),
+ replacement=False)
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
+ if "balanced" not in self.sample_mode:
+ return good_matches, good_certainty
+ density = kde(good_matches, std=0.1)
+ p = 1 / (density+1)
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+ balanced_samples = torch.multinomial(p,
+ num_samples = min(num,len(good_certainty)),
+ replacement=False)
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
+
+ def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
+ if batched:
+ f_q_pyramid = {
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
+ }
+ f_s_pyramid = {
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
+ }
+ else:
+ f_q_pyramid, f_s_pyramid = feature_pyramid
+ corresps = self.decoder(f_q_pyramid,
+ f_s_pyramid,
+ upsample = upsample,
+ **(batch["corresps"] if "corresps" in batch else {}),
+ scale_factor=scale_factor)
+
+ return corresps
+
+ def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
+ feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
+ f_q_pyramid = feature_pyramid
+ f_s_pyramid = {
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
+ for scale, f_scale in feature_pyramid.items()
+ }
+ corresps = self.decoder(f_q_pyramid,
+ f_s_pyramid,
+ upsample = upsample,
+ **(batch["corresps"] if "corresps" in batch else {}),
+ scale_factor=scale_factor)
+ return corresps
+
+ def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
+ kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+ kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
+ kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
+ return kpts_A, kpts_B
+
+ def match(
+ self,
+ im_A_path,
+ im_B_path,
+ *args,
+ batched=False,
+ device = None,
+ ):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ from PIL import Image
+ if isinstance(im_A_path, (str, os.PathLike)):
+ im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+ else:
+ # Assume its not a path
+ im_A, im_B = im_A_path, im_B_path
+ symmetric = self.symmetric
+ self.train(False)
+ with torch.no_grad():
+ if not batched:
+ b = 1
+ w, h = im_A.size
+ w2, h2 = im_B.size
+ # Get images in good format
+ ws = self.w_resized
+ hs = self.h_resized
+
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True, clahe = False
+ )
+ im_A, im_B = test_transform((im_A, im_B))
+ batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
+ else:
+ b, c, h, w = im_A.shape
+ b, c, h2, w2 = im_B.shape
+ assert w == w2 and h == h2, "For batched images we assume same size"
+ batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
+ if h != self.h_resized or self.w_resized != w:
+ warn("Model resolution and batch resolution differ, may produce unexpected results")
+ hs, ws = h, w
+ finest_scale = 1
+ # Run matcher
+ if symmetric:
+ corresps = self.forward_symmetric(batch)
+ else:
+ corresps = self.forward(batch, batched = True)
+
+ if self.upsample_preds:
+ hs, ws = self.upsample_res
+
+ if self.attenuate_cert:
+ low_res_certainty = F.interpolate(
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ cert_clamp = 0
+ factor = 0.5
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+
+ if self.upsample_preds:
+ finest_corresps = corresps[finest_scale]
+ torch.cuda.empty_cache()
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True
+ )
+ im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+ im_A, im_B = test_transform((im_A, im_B))
+ im_A, im_B = im_A[None].to(device), im_B[None].to(device)
+ scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
+ batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
+ if symmetric:
+ corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
+ else:
+ corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
+
+ im_A_to_im_B = corresps[finest_scale]["flow"]
+ certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
+ if finest_scale != 1:
+ im_A_to_im_B = F.interpolate(
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ certainty = F.interpolate(
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ im_A_to_im_B = im_A_to_im_B.permute(
+ 0, 2, 3, 1
+ )
+ # Create im_A meshgrid
+ im_A_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+ )
+ )
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
+ certainty = certainty.sigmoid() # logits -> probs
+ im_A_coords = im_A_coords.permute(0, 2, 3, 1)
+ if (im_A_to_im_B.abs() > 1).any() and True:
+ wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
+ certainty[wrong[:,None]] = 0
+ im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
+ if symmetric:
+ A_to_B, B_to_A = im_A_to_im_B.chunk(2)
+ q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
+ im_B_coords = im_A_coords
+ s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
+ warp = torch.cat((q_warp, s_warp),dim=2)
+ certainty = torch.cat(certainty.chunk(2), dim=3)
+ else:
+ warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
+ if batched:
+ return (
+ warp,
+ certainty[:, 0]
+ )
+ else:
+ return (
+ warp[0],
+ certainty[0, 0],
+ )
+
diff --git a/third_party/Roma/roma/models/model_zoo/__init__.py b/third_party/Roma/roma/models/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..91edd4e69f2b39f18d62545a95f2774324ff404b
--- /dev/null
+++ b/third_party/Roma/roma/models/model_zoo/__init__.py
@@ -0,0 +1,30 @@
+import torch
+from .roma_models import roma_model
+
+weight_urls = {
+ "roma": {
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
+ "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
+ },
+ "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
+}
+
+def roma_outdoor(device, weights=None, dinov2_weights=None):
+ if weights is None:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"],
+ map_location=device)
+ if dinov2_weights is None:
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+ map_location=device)
+ return roma_model(resolution=(14*8*6,14*8*6), upsample_preds=True,
+ weights=weights,dinov2_weights = dinov2_weights,device=device)
+
+def roma_indoor(device, weights=None, dinov2_weights=None):
+ if weights is None:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"],
+ map_location=device)
+ if dinov2_weights is None:
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+ map_location=device)
+ return roma_model(resolution=(14*8*5,14*8*5), upsample_preds=False,
+ weights=weights,dinov2_weights = dinov2_weights,device=device)
diff --git a/third_party/Roma/roma/models/model_zoo/roma_models.py b/third_party/Roma/roma/models/model_zoo/roma_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfb0ff7264880d25f0feb0802e582bf29c84b051
--- /dev/null
+++ b/third_party/Roma/roma/models/model_zoo/roma_models.py
@@ -0,0 +1,157 @@
+import warnings
+import torch.nn as nn
+from roma.models.matcher import *
+from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
+from roma.models.encoders import *
+
+def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, **kwargs):
+ # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+ gp_dim = 512
+ feat_dim = 512
+ decoder_dim = gp_dim + feat_dim
+ cls_to_coord_res = 64
+ coordinate_decoder = TransformerDecoder(
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
+ decoder_dim,
+ cls_to_coord_res**2 + 1,
+ is_classifier=True,
+ amp = True,
+ pos_enc = False,)
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ displacement_emb = "linear"
+ disable_local_corr_grad = True
+
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512+128+(2*7+1)**2,
+ 2 * 512+128+(2*7+1)**2,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=128,
+ local_corr_radius = 7,
+ corr_in_other = True,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "8": ConvRefiner(
+ 2 * 512+64+(2*3+1)**2,
+ 2 * 512+64+(2*3+1)**2,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=64,
+ local_corr_radius = 3,
+ corr_in_other = True,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "4": ConvRefiner(
+ 2 * 256+32+(2*2+1)**2,
+ 2 * 256+32+(2*2+1)**2,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=32,
+ local_corr_radius = 2,
+ corr_in_other = True,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "2": ConvRefiner(
+ 2 * 64+16,
+ 128+16,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=16,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "1": ConvRefiner(
+ 2 * 9 + 6,
+ 24,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks = hidden_blocks,
+ displacement_emb = displacement_emb,
+ displacement_emb_dim = 6,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"16": gp16})
+ proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
+ proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
+ proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
+ proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
+ proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
+ proj = nn.ModuleDict({
+ "16": proj16,
+ "8": proj8,
+ "4": proj4,
+ "2": proj2,
+ "1": proj1,
+ })
+ displacement_dropout_p = 0.0
+ gm_warp_dropout_p = 0.0
+ decoder = Decoder(coordinate_decoder,
+ gps,
+ proj,
+ conv_refiner,
+ detach=True,
+ scales=["16", "8", "4", "2", "1"],
+ displacement_dropout_p = displacement_dropout_p,
+ gm_warp_dropout_p = gm_warp_dropout_p)
+
+ encoder = CNNandDinov2(
+ cnn_kwargs = dict(
+ pretrained=False,
+ amp = True),
+ amp = True,
+ use_vgg = True,
+ dinov2_weights = dinov2_weights
+ )
+ h,w = resolution
+ symmetric = True
+ attenuate_cert = True
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
+ symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device)
+ matcher.load_state_dict(weights)
+ return matcher
diff --git a/third_party/Roma/roma/models/transformer/__init__.py b/third_party/Roma/roma/models/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4770ebb19f111df14f1539fa3696553d96d4e48b
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/__init__.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from roma.utils.utils import get_grid
+from .layers.block import Block
+from .layers.attention import MemEffAttention
+from .dinov2 import vit_large
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
+ amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.blocks = blocks
+ self.to_out = nn.Linear(hidden_dim, out_dim)
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+ self._scales = [16]
+ self.is_classifier = is_classifier
+ self.amp = amp
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ self.pos_enc = pos_enc
+ self.learned_embeddings = learned_embeddings
+ if self.learned_embeddings:
+ self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
+
+ def scales(self):
+ return self._scales.copy()
+
+ def forward(self, gp_posterior, features, old_stuff, new_scale):
+ with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
+ B,C,H,W = gp_posterior.shape
+ x = torch.cat((gp_posterior, features), dim = 1)
+ B,C,H,W = x.shape
+ grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
+ if self.learned_embeddings:
+ pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
+ else:
+ pos_enc = 0
+ tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
+ z = self.blocks(tokens)
+ out = self.to_out(z)
+ out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
+ warp, certainty = out[:, :-1], out[:, -1:]
+ return warp, certainty, None
+
+
diff --git a/third_party/Roma/roma/models/transformer/dinov2.py b/third_party/Roma/roma/models/transformer/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b556c63096d17239c8603d5fe626c331963099fd
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/dinov2.py
@@ -0,0 +1,359 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @property
+ def device(self):
+ return self.cls_token.device
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
\ No newline at end of file
diff --git a/third_party/Roma/roma/models/transformer/layers/__init__.py b/third_party/Roma/roma/models/transformer/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/third_party/Roma/roma/models/transformer/layers/attention.py b/third_party/Roma/roma/models/transformer/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9b0c94b40967dfdff4f261c127cbd21328c905
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/attention.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/third_party/Roma/roma/models/transformer/layers/block.py b/third_party/Roma/roma/models/transformer/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/third_party/Roma/roma/models/transformer/layers/dino_head.py b/third_party/Roma/roma/models/transformer/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7212db92a4fd8d4c7230e284e551a0234e9d8623
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/dino_head.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/third_party/Roma/roma/models/transformer/layers/drop_path.py b/third_party/Roma/roma/models/transformer/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/third_party/Roma/roma/models/transformer/layers/layer_scale.py b/third_party/Roma/roma/models/transformer/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/third_party/Roma/roma/models/transformer/layers/mlp.py b/third_party/Roma/roma/models/transformer/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/third_party/Roma/roma/models/transformer/layers/patch_embed.py b/third_party/Roma/roma/models/transformer/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/patch_embed.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py b/third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/third_party/Roma/roma/train/__init__.py b/third_party/Roma/roma/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433
--- /dev/null
+++ b/third_party/Roma/roma/train/__init__.py
@@ -0,0 +1 @@
+from .train import train_k_epochs
diff --git a/third_party/Roma/roma/train/train.py b/third_party/Roma/roma/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5556f7ebf9b6378e1395c125dde093f5e55e7141
--- /dev/null
+++ b/third_party/Roma/roma/train/train.py
@@ -0,0 +1,102 @@
+from tqdm import tqdm
+from roma.utils.utils import to_cuda
+import roma
+import torch
+import wandb
+
+def log_param_statistics(named_parameters, norm_type = 2):
+ named_parameters = list(named_parameters)
+ grads = [p.grad for n, p in named_parameters if p.grad is not None]
+ weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
+ names = [n for n,p in named_parameters if p.grad is not None]
+ param_norm = torch.stack(weight_norms).norm(p=norm_type)
+ device = grads[0].device
+ grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
+ nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
+ nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
+ total_grad_norm = torch.norm(grad_norms, norm_type)
+ if torch.any(nans_or_infs):
+ print(f"These params have nan or inf grads: {nan_inf_names}")
+ wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP)
+ wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP)
+
+def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
+ optimizer.zero_grad()
+ out = model(train_batch)
+ l = objective(out, train_batch)
+ grad_scaler.scale(l).backward()
+ grad_scaler.unscale_(optimizer)
+ log_param_statistics(model.named_parameters())
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
+ grad_scaler.step(optimizer)
+ grad_scaler.update()
+ wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP)
+ if grad_scaler._scale < 1.:
+ grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
+ roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step
+ return {"train_out": out, "train_loss": l.item()}
+
+
+def train_k_steps(
+ n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None,
+):
+ for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
+ batch = next(dataloader)
+ model.train(True)
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ grad_scaler=grad_scaler,
+ n=n,
+ grad_clip_norm = grad_clip_norm,
+ )
+ if ema_model is not None:
+ ema_model.update()
+ if warmup is not None:
+ with warmup.dampening():
+ lr_scheduler.step()
+ else:
+ lr_scheduler.step()
+ [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
+
+
+def train_epoch(
+ dataloader=None,
+ model=None,
+ objective=None,
+ optimizer=None,
+ lr_scheduler=None,
+ epoch=None,
+):
+ model.train(True)
+ print(f"At epoch {epoch}")
+ for batch in tqdm(dataloader, mininterval=5.0):
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch, model=model, objective=objective, optimizer=optimizer
+ )
+ lr_scheduler.step()
+ return {
+ "model": model,
+ "optimizer": optimizer,
+ "lr_scheduler": lr_scheduler,
+ "epoch": epoch,
+ }
+
+
+def train_k_epochs(
+ start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
+):
+ for epoch in range(start_epoch, end_epoch + 1):
+ train_epoch(
+ dataloader=dataloader,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ )
diff --git a/third_party/Roma/roma/utils/__init__.py b/third_party/Roma/roma/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2709f5e586150289085a4e2cbd458bc443fab7f3
--- /dev/null
+++ b/third_party/Roma/roma/utils/__init__.py
@@ -0,0 +1,16 @@
+from .utils import (
+ pose_auc,
+ get_pose,
+ compute_relative_pose,
+ compute_pose_error,
+ estimate_pose,
+ estimate_pose_uncalibrated,
+ rotate_intrinsic,
+ get_tuple_transform_ops,
+ get_depth_tuple_transform_ops,
+ warp_kpts,
+ numpy_to_pil,
+ tensor_to_pil,
+ recover_pose,
+ signed_left_to_right_epipolar_distance,
+)
diff --git a/third_party/Roma/roma/utils/kde.py b/third_party/Roma/roma/utils/kde.py
new file mode 100644
index 0000000000000000000000000000000000000000..90a058fb68253cfe23c2a7f21b213bea8e06cfe3
--- /dev/null
+++ b/third_party/Roma/roma/utils/kde.py
@@ -0,0 +1,8 @@
+import torch
+
+def kde(x, std = 0.1):
+ # use a gaussian kernel to estimate density
+ x = x.half() # Do it in half precision
+ scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+ density = scores.sum(dim=-1)
+ return density
\ No newline at end of file
diff --git a/third_party/Roma/roma/utils/local_correlation.py b/third_party/Roma/roma/utils/local_correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..586eef5f154a95968b253ad9701933b55b3a4dd6
--- /dev/null
+++ b/third_party/Roma/roma/utils/local_correlation.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn.functional as F
+
+def local_correlation(
+ feature0,
+ feature1,
+ local_radius,
+ padding_mode="zeros",
+ flow = None,
+ sample_mode = "bilinear",
+):
+ r = local_radius
+ K = (2*r+1)**2
+ B, c, h, w = feature0.size()
+ feature0 = feature0.half()
+ feature1 = feature1.half()
+ corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
+ if flow is None:
+ # If flow is None, assume feature0 and feature1 are aligned
+ coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
+ ))
+ coords = torch.stack((coords[1], coords[0]), dim=-1)[
+ None
+ ].expand(B, h, w, 2)
+ else:
+ coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+ local_window = torch.meshgrid(
+ (
+ torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"),
+ torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"),
+ ))
+ local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
+ None
+ ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
+ for _ in range(B):
+ with torch.no_grad():
+ local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2).float()
+ window_feature = F.grid_sample(
+ feature1[_:_+1].float(), local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
+ )
+ window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
+ corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
+ torch.cuda.empty_cache()
+ return corr
\ No newline at end of file
diff --git a/third_party/Roma/roma/utils/transforms.py b/third_party/Roma/roma/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea6476bd816a31df36f7d1b5417853637b65474b
--- /dev/null
+++ b/third_party/Roma/roma/utils/transforms.py
@@ -0,0 +1,118 @@
+from typing import Dict
+import numpy as np
+import torch
+import kornia.augmentation as K
+from kornia.geometry.transform import warp_perspective
+
+# Adapted from Kornia
+class GeometricSequential:
+ def __init__(self, *transforms, align_corners=True) -> None:
+ self.transforms = transforms
+ self.align_corners = align_corners
+
+ def __call__(self, x, mode="bilinear"):
+ b, c, h, w = x.shape
+ M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
+ for t in self.transforms:
+ if np.random.rand() < t.p:
+ M = M.matmul(
+ t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
+ )
+ return (
+ warp_perspective(
+ x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
+ ),
+ M,
+ )
+
+ def apply_transform(self, x, M, mode="bilinear"):
+ b, c, h, w = x.shape
+ return warp_perspective(
+ x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
+ )
+
+
+class RandomPerspective(K.RandomPerspective):
+ def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
+ distortion_scale = torch.as_tensor(
+ self.distortion_scale, device=self._device, dtype=self._dtype
+ )
+ return self.random_perspective_generator(
+ batch_shape[0],
+ batch_shape[-2],
+ batch_shape[-1],
+ distortion_scale,
+ self.same_on_batch,
+ self.device,
+ self.dtype,
+ )
+
+ def random_perspective_generator(
+ self,
+ batch_size: int,
+ height: int,
+ width: int,
+ distortion_scale: torch.Tensor,
+ same_on_batch: bool = False,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.float32,
+ ) -> Dict[str, torch.Tensor]:
+ r"""Get parameters for ``perspective`` for a random perspective transform.
+
+ Args:
+ batch_size (int): the tensor batch size.
+ height (int) : height of the image.
+ width (int): width of the image.
+ distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
+ same_on_batch (bool): apply the same transformation across the batch. Default: False.
+ device (torch.device): the device on which the random numbers will be generated. Default: cpu.
+ dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
+
+ Returns:
+ params Dict[str, torch.Tensor]: parameters to be passed for transformation.
+ - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
+ - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
+
+ Note:
+ The generated random numbers are not reproducible across different devices and dtypes.
+ """
+ if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
+ raise AssertionError(
+ f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
+ )
+ if not (
+ type(height) is int and height > 0 and type(width) is int and width > 0
+ ):
+ raise AssertionError(
+ f"'height' and 'width' must be integers. Got {height}, {width}."
+ )
+
+ start_points: torch.Tensor = torch.tensor(
+ [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
+ device=distortion_scale.device,
+ dtype=distortion_scale.dtype,
+ ).expand(batch_size, -1, -1)
+
+ # generate random offset not larger than half of the image
+ fx = distortion_scale * width / 2
+ fy = distortion_scale * height / 2
+
+ factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
+ offset = (torch.rand_like(start_points) - 0.5) * 2
+ end_points = start_points + factor * offset
+
+ return dict(start_points=start_points, end_points=end_points)
+
+
+
+class RandomErasing:
+ def __init__(self, p = 0., scale = 0.) -> None:
+ self.p = p
+ self.scale = scale
+ self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p)
+ def __call__(self, image, depth):
+ if self.p > 0:
+ image = self.random_eraser(image)
+ depth = self.random_eraser(depth, params=self.random_eraser._params)
+ return image, depth
+
\ No newline at end of file
diff --git a/third_party/Roma/roma/utils/utils.py b/third_party/Roma/roma/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d673f679823c833688e2548dd40bf50943796a71
--- /dev/null
+++ b/third_party/Roma/roma/utils/utils.py
@@ -0,0 +1,622 @@
+import warnings
+import numpy as np
+import cv2
+import math
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import torch.nn.functional as F
+from PIL import Image
+import kornia
+
+def recover_pose(E, kpts0, kpts1, K0, K1, mask):
+ best_num_inliers = 0
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+
+
+# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
+# --- GEOMETRY ---
+def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
+ )
+
+ ret = None
+ if E is not None:
+ best_num_inliers = 0
+
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ method = cv2.USAC_ACCURATE
+ F, mask = cv2.findFundamentalMat(
+ kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
+ )
+ E = K1.T@F@K0
+ ret = None
+ if E is not None:
+ best_num_inliers = 0
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+def unnormalize_coords(x_n,h,w):
+ x = torch.stack(
+ (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ return x
+
+
+def rotate_intrinsic(K, n):
+ base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+ rot = np.linalg.matrix_power(base_rot, n)
+ return rot @ K
+
+
+def rotate_pose_inplane(i_T_w, rot):
+ rotation_matrices = [
+ np.array(
+ [
+ [np.cos(r), -np.sin(r), 0.0, 0.0],
+ [np.sin(r), np.cos(r), 0.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ dtype=np.float32,
+ )
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+ ]
+ return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+ scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
+ return np.dot(scales, K)
+
+
+def to_homogeneous(points):
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
+
+
+def angle_error_mat(R1, R2):
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+ cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
+ return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t):
+ R_gt = T_0to1[:3, :3]
+ t_gt = T_0to1[:3, 3]
+ error_t = angle_error_vec(t.squeeze(), t_gt)
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_R = angle_error_mat(R, R_gt)
+ return error_t, error_R
+
+
+def pose_auc(errors, thresholds):
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0.0, errors]
+ recall = np.r_[0.0, recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e) / t)
+ return aucs
+
+
+# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
+def get_depth_tuple_transform_ops_nearest_exact(resize=None):
+ ops = []
+ if resize:
+ ops.append(TupleResizeNearestExact(resize))
+ return TupleCompose(ops)
+
+def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
+ return TupleCompose(ops)
+
+
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize))
+ ops.append(TupleToTensorScaled())
+ if normalize:
+ ops.append(
+ TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ) # Imagenet mean/std
+ return TupleCompose(ops)
+
+class ToTensorScaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
+
+ def __call__(self, im):
+ if not isinstance(im, torch.Tensor):
+ im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
+ im /= 255.0
+ return torch.from_numpy(im)
+ else:
+ return im
+
+ def __repr__(self):
+ return "ToTensorScaled(./255)"
+
+
+class TupleToTensorScaled(object):
+ def __init__(self):
+ self.to_tensor = ToTensorScaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorScaled(./255)"
+
+
+class ToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __call__(self, im):
+ return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
+
+ def __repr__(self):
+ return "ToTensorUnscaled()"
+
+
+class TupleToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __init__(self):
+ self.to_tensor = ToTensorUnscaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorUnscaled()"
+
+class TupleResizeNearestExact:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, im_tuple):
+ return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleResizeNearestExact(size={})".format(self.size)
+
+
+class TupleResize(object):
+ def __init__(self, size, mode=InterpolationMode.BICUBIC):
+ self.size = size
+ self.resize = transforms.Resize(size, mode)
+ def __call__(self, im_tuple):
+ return [self.resize(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleResize(size={})".format(self.size)
+
+class Normalize:
+ def __call__(self,im):
+ mean = im.mean(dim=(1,2), keepdims=True)
+ std = im.std(dim=(1,2), keepdims=True)
+ return (im-mean)/std
+
+
+class TupleNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+ self.normalize = transforms.Normalize(mean=mean, std=std)
+
+ def __call__(self, im_tuple):
+ c,h,w = im_tuple[0].shape
+ if c > 3:
+ warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
+ return [self.normalize(im[:3]) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
+
+
+class TupleCompose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, im_tuple):
+ for t in self.transforms:
+ im_tuple = t(im_tuple)
+ return im_tuple
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
+@torch.no_grad()
+def cls_to_flow(cls, deterministic_sampling = True):
+ B,C,H,W = cls.shape
+ device = cls.device
+ res = round(math.sqrt(C))
+ G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+ G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+ if deterministic_sampling:
+ sampled_cls = cls.max(dim=1).indices
+ else:
+ sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
+ flow = G[sampled_cls]
+ return flow
+
+@torch.no_grad()
+def cls_to_flow_refine(cls):
+ B,C,H,W = cls.shape
+ device = cls.device
+ res = round(math.sqrt(C))
+ G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+ G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+ cls = cls.softmax(dim=1)
+ mode = cls.max(dim=1).indices
+
+ index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
+ neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
+ flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
+ tot_prob = neighbours.sum(dim=1)
+ flow = flow / tot_prob
+ return flow
+
+
+def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
+
+ if H is None:
+ B,H,W = depth1.shape
+ else:
+ B = depth1.shape[0]
+ with torch.no_grad():
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
+ )
+ for n in (B, H, W)
+ ]
+ )
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
+ mask, x2 = warp_kpts(
+ x1_n.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ depth_interpolation_mode = depth_interpolation_mode,
+ relative_depth_error_threshold = relative_depth_error_threshold,
+ )
+ prob = mask.float().reshape(B, H, W)
+ x2 = x2.reshape(B, H, W, 2)
+ return x2, prob
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+ # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1)
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ (
+ n,
+ h,
+ w,
+ ) = depth0.shape
+ if depth_interpolation_mode == "combined":
+ # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
+ if smooth_mask:
+ raise NotImplementedError("Combined bilinear and NN warp not implemented")
+ valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "bilinear",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "nearest-exact",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
+ warp = warp_bilinear.clone()
+ warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
+ valid = valid_bilinear | valid_nearest
+ return valid, warp
+
+
+ kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
+ :, 0, :, 0
+ ]
+ kpts0 = torch.stack(
+ (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ # Sample depth, get calculable_mask on depth != 0
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = (
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+ * kpts0_depth[..., None]
+ ) # (N, L, 3)
+ kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+ kpts0_cam = kpts0_n
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
+ w_kpts0_h[:, :, [2]] + 1e-4
+ ) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (
+ (w_kpts0[:, :, 0] > 0)
+ * (w_kpts0[:, :, 0] < w - 1)
+ * (w_kpts0[:, :, 1] > 0)
+ * (w_kpts0[:, :, 1] < h - 1)
+ )
+ w_kpts0 = torch.stack(
+ (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+ # w_kpts0[~covisible_mask, :] = -5 # xd
+
+ w_kpts0_depth = F.grid_sample(
+ depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
+ )[:, 0, :, 0]
+
+ relative_depth_error = (
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+ ).abs()
+ if not smooth_mask:
+ consistent_mask = relative_depth_error < relative_depth_error_threshold
+ else:
+ consistent_mask = (-relative_depth_error/smooth_mask).exp()
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+ if return_relative_depth_error:
+ return relative_depth_error, w_kpts0
+ else:
+ return valid_mask, w_kpts0
+
+imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
+imagenet_std = torch.tensor([0.229, 0.224, 0.225])
+
+
+def numpy_to_pil(x: np.ndarray):
+ """
+ Args:
+ x: Assumed to be of shape (h,w,c)
+ """
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ if x.max() <= 1.01:
+ x *= 255
+ x = x.astype(np.uint8)
+ return Image.fromarray(x)
+
+
+def tensor_to_pil(x, unnormalize=False):
+ if unnormalize:
+ x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
+ x = x.detach().permute(1, 2, 0).cpu().numpy()
+ x = np.clip(x, 0.0, 1.0)
+ return numpy_to_pil(x)
+
+
+def to_cuda(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cuda()
+ return batch
+
+
+def to_cpu(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cpu()
+ return batch
+
+
+def get_pose(calib):
+ w, h = np.array(calib["imsize"])[0]
+ return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
+
+
+def compute_relative_pose(R1, t1, R2, t2):
+ rots = R2 @ (R1.T)
+ trans = -rots @ t1 + t2
+ return rots, trans
+
+@torch.no_grad()
+def reset_opt(opt):
+ for group in opt.param_groups:
+ for p in group['params']:
+ if p.requires_grad:
+ state = opt.state[p]
+ # State initialization
+
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ # Exponential moving average of gradient difference
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+
+def flow_to_pixel_coords(flow, h1, w1):
+ flow = (
+ torch.stack(
+ (
+ w1 * (flow[..., 0] + 1) / 2,
+ h1 * (flow[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ return flow
+
+def flow_to_normalized_coords(flow, h1, w1):
+ flow = (
+ torch.stack(
+ (
+ 2 * (flow[..., 0]) / w1 - 1,
+ 2 * (flow[..., 1]) / h1 - 1,
+ ),
+ axis=-1,
+ )
+ )
+ return flow
+
+
+def warp_to_pixel_coords(warp, h1, w1, h2, w2):
+ warp1 = warp[..., :2]
+ warp1 = (
+ torch.stack(
+ (
+ w1 * (warp1[..., 0] + 1) / 2,
+ h1 * (warp1[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ warp2 = warp[..., 2:]
+ warp2 = (
+ torch.stack(
+ (
+ w2 * (warp2[..., 0] + 1) / 2,
+ h2 * (warp2[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ return torch.cat((warp1,warp2), dim=-1)
+
+
+
+def signed_point_line_distance(point, line, eps: float = 1e-9):
+ r"""Return the distance from points to lines.
+
+ Args:
+ point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
+ line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
+ eps: Small constant for safe sqrt.
+
+ Returns:
+ the computed distance with shape :math:`(*, N)`.
+ """
+
+ if not point.shape[-1] in (2, 3):
+ raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
+
+ if not line.shape[-1] == 3:
+ raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
+
+ numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
+ denominator = line[..., :2].norm(dim=-1)
+
+ return numerator / (denominator + eps)
+
+
+def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
+ r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.
+
+ This method measures the distance from points in the right images to the epilines
+ of the corresponding points in the left images as they reflect in the right images.
+
+ Args:
+ pts1: correspondences from the left images with shape
+ :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
+ pts2: correspondences from the right images with shape
+ :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
+ Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
+ avoid ambiguity with torch.nn.functional.
+
+ Returns:
+ the computed Symmetrical distance with shape :math:`(*, N)`.
+ """
+ import kornia
+ if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
+ raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
+
+ if pts1.shape[-1] == 2:
+ pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)
+
+ F_t = Fm.transpose(dim0=-2, dim1=-1)
+ line1_in_2 = pts1 @ F_t
+
+ return signed_point_line_distance(pts2, line1_in_2)
+
+def get_grid(b, h, w, device):
+ grid = torch.meshgrid(
+ *[
+ torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
+ for n in (b, h, w)
+ ]
+ )
+ grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
+ return grid
diff --git a/third_party/Roma/setup.py b/third_party/Roma/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae777c0e5a41f0e4b03a838d19bc9a2bb04d4617
--- /dev/null
+++ b/third_party/Roma/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(
+ name="roma",
+ packages=["roma"],
+ version="0.0.1",
+ author="Johan Edstedt",
+ install_requires=open("requirements.txt", "r").read().split("\n"),
+)