Vincentqyw commited on
Commit
dbf8b7e
·
1 Parent(s): 10b4a5f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. third_party/DKM/.gitignore +3 -0
  2. third_party/DKM/LICENSE +25 -0
  3. third_party/DKM/README.md +117 -0
  4. third_party/DKM/assets/ams_hom_A.jpg +3 -0
  5. third_party/DKM/assets/ams_hom_B.jpg +3 -0
  6. third_party/DKM/assets/dkmv3_warp.jpg +3 -0
  7. third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz +3 -0
  8. third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz +3 -0
  9. third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz +3 -0
  10. third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz +3 -0
  11. third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz +3 -0
  12. third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz +3 -0
  13. third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz +3 -0
  14. third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz +3 -0
  15. third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz +3 -0
  16. third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz +3 -0
  17. third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz +3 -0
  18. third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz +3 -0
  19. third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz +3 -0
  20. third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz +3 -0
  21. third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz +3 -0
  22. third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz +3 -0
  23. third_party/DKM/assets/mount_rushmore.mp4 +0 -0
  24. third_party/DKM/assets/sacre_coeur_A.jpg +3 -0
  25. third_party/DKM/assets/sacre_coeur_B.jpg +3 -0
  26. third_party/DKM/data/.gitignore +2 -0
  27. third_party/DKM/demo/.gitignore +1 -0
  28. third_party/DKM/demo/demo_fundamental.py +37 -0
  29. third_party/DKM/demo/demo_match.py +48 -0
  30. third_party/DKM/dkm/__init__.py +4 -0
  31. third_party/DKM/dkm/benchmarks/__init__.py +4 -0
  32. third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py +100 -0
  33. third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py +119 -0
  34. third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py +114 -0
  35. third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py +124 -0
  36. third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py +86 -0
  37. third_party/DKM/dkm/benchmarks/scannet_benchmark.py +143 -0
  38. third_party/DKM/dkm/checkpointing/__init__.py +1 -0
  39. third_party/DKM/dkm/checkpointing/checkpoint.py +31 -0
  40. third_party/DKM/dkm/datasets/__init__.py +1 -0
  41. third_party/DKM/dkm/datasets/megadepth.py +177 -0
  42. third_party/DKM/dkm/datasets/scannet.py +151 -0
  43. third_party/DKM/dkm/losses/__init__.py +1 -0
  44. third_party/DKM/dkm/losses/depth_match_regression_loss.py +128 -0
  45. third_party/DKM/dkm/models/__init__.py +4 -0
  46. third_party/DKM/dkm/models/deprecated/build_model.py +787 -0
  47. third_party/DKM/dkm/models/deprecated/corr_channels.py +34 -0
  48. third_party/DKM/dkm/models/deprecated/local_corr.py +630 -0
  49. third_party/DKM/dkm/models/dkm.py +759 -0
  50. third_party/DKM/dkm/models/encoders.py +147 -0
third_party/DKM/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.egg-info*
2
+ *.vscode*
3
+ *__pycache__*
third_party/DKM/LICENSE ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ NOTE! Models trained on our synthetic dataset uses datasets which are licensed under non-commercial licenses.
2
+ Hence we cannot provide them under the MIT license. However, MegaDepth is under MIT license, hence we provide those models under MIT license, see below.
3
+
4
+
5
+ License for Models Trained on MegaDepth ONLY below:
6
+
7
+ Copyright (c) 2022 Johan Edstedt
8
+
9
+ Permission is hereby granted, free of charge, to any person obtaining a copy
10
+ of this software and associated documentation files (the "Software"), to deal
11
+ in the Software without restriction, including without limitation the rights
12
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13
+ copies of the Software, and to permit persons to whom the Software is
14
+ furnished to do so, subject to the following conditions:
15
+
16
+ The above copyright notice and this permission notice shall be included in all
17
+ copies or substantial portions of the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25
+ SOFTWARE.
third_party/DKM/README.md ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DKM: Dense Kernelized Feature Matching for Geometry Estimation
2
+ ### [Project Page](https://parskatt.github.io/DKM) | [Paper](https://arxiv.org/abs/2202.00667)
3
+ <br/>
4
+
5
+ > DKM: Dense Kernelized Feature Matching for Geometry Estimation
6
+ > [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Ioannis Athanasiadis](https://scholar.google.com/citations?user=RCAtJgUAAAAJ), [Mårten Wadenbäck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ)
7
+ > CVPR 2023
8
+
9
+ ## How to Use?
10
+ <details>
11
+ Our model produces a dense (for all pixels) warp and certainty.
12
+
13
+ Warp: [B,H,W,4] for all images in batch of size B, for each pixel HxW, we ouput the input and matching coordinate in the normalized grids [-1,1]x[-1,1].
14
+
15
+ Certainty: [B,H,W] a number in each pixel indicating the matchability of the pixel.
16
+
17
+ See [demo](dkm/demo/) for two demos of DKM.
18
+
19
+ See [api.md](docs/api.md) for API.
20
+ </details>
21
+
22
+ ## Qualitative Results
23
+ <details>
24
+
25
+ https://user-images.githubusercontent.com/22053118/223748279-0f0c21b4-376a-440a-81f5-7f9a5d87483f.mp4
26
+
27
+
28
+ https://user-images.githubusercontent.com/22053118/223748512-1bca4a17-cffa-491d-a448-96aac1353ce9.mp4
29
+
30
+
31
+
32
+ https://user-images.githubusercontent.com/22053118/223748518-4d475d9f-a933-4581-97ed-6e9413c4caca.mp4
33
+
34
+
35
+
36
+ https://user-images.githubusercontent.com/22053118/223748522-39c20631-aa16-4954-9c27-95763b38f2ce.mp4
37
+
38
+
39
+ </details>
40
+
41
+
42
+
43
+ ## Benchmark Results
44
+
45
+ <details>
46
+
47
+ ### Megadepth1500
48
+
49
+ | | @5 | @10 | @20 |
50
+ |-------|-------|------|------|
51
+ | DKMv1 | 54.5 | 70.7 | 82.3 |
52
+ | DKMv2 | *56.8* | *72.3* | *83.2* |
53
+ | DKMv3 (paper) | **60.5** | **74.9** | **85.1** |
54
+ | DKMv3 (this repo) | **60.0** | **74.6** | **84.9** |
55
+
56
+ ### Megadepth 8 Scenes
57
+ | | @5 | @10 | @20 |
58
+ |-------|-------|------|------|
59
+ | DKMv3 (paper) | **60.5** | **74.5** | **84.2** |
60
+ | DKMv3 (this repo) | **60.4** | **74.6** | **84.3** |
61
+
62
+
63
+ ### ScanNet1500
64
+ | | @5 | @10 | @20 |
65
+ |-------|-------|------|------|
66
+ | DKMv1 | 24.8 | 44.4 | 61.9 |
67
+ | DKMv2 | *28.2* | *49.2* | *66.6* |
68
+ | DKMv3 (paper) | **29.4** | **50.7** | **68.3** |
69
+ | DKMv3 (this repo) | **29.8** | **50.8** | **68.3** |
70
+
71
+ </details>
72
+
73
+ ## Navigating the Code
74
+ * Code for models can be found in [dkm/models](dkm/models/)
75
+ * Code for benchmarks can be found in [dkm/benchmarks](dkm/benchmarks/)
76
+ * Code for reproducing experiments from our paper can be found in [experiments/](experiments/)
77
+
78
+ ## Install
79
+ Run ``pip install -e .``
80
+
81
+ ## Demo
82
+
83
+ A demonstration of our method can be run by:
84
+ ``` bash
85
+ python demo_match.py
86
+ ```
87
+ This runs our model trained on mega on two images taken from Sacre Coeur.
88
+
89
+ ## Benchmarks
90
+ See [Benchmarks](docs/benchmarks.md) for details.
91
+ ## Training
92
+ See [Training](docs/training.md) for details.
93
+ ## Reproducing Results
94
+ Given that the required benchmark or training dataset has been downloaded and unpacked, results can be reproduced by running the experiments in the experiments folder.
95
+
96
+ ## Using DKM matches for estimation
97
+ We recommend using the excellent Graph-Cut RANSAC algorithm: https://github.com/danini/graph-cut-ransac
98
+
99
+ | | @5 | @10 | @20 |
100
+ |-------|-------|------|------|
101
+ | DKMv3 (RANSAC) | *60.5* | *74.9* | *85.1* |
102
+ | DKMv3 (GC-RANSAC) | **65.5** | **78.0** | **86.7** |
103
+
104
+
105
+ ## Acknowledgements
106
+ We have used code and been inspired by https://github.com/PruneTruong/DenseMatching, https://github.com/zju3dv/LoFTR, and https://github.com/GrumpyZhou/patch2pix. We additionally thank the authors of ECO-TR for providing their benchmark.
107
+
108
+ ## BibTeX
109
+ If you find our models useful, please consider citing our paper!
110
+ ```
111
+ @inproceedings{edstedt2023dkm,
112
+ title={{DKM}: Dense Kernelized Feature Matching for Geometry Estimation},
113
+ author={Edstedt, Johan and Athanasiadis, Ioannis and Wadenbäck, Mårten and Felsberg, Michael},
114
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
115
+ year={2023}
116
+ }
117
+ ```
third_party/DKM/assets/ams_hom_A.jpg ADDED

Git LFS Details

  • SHA256: 271a19f0b29fc88d8f88d1136f001078ca6bf5105ff95355f89a18e787c50e3a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
third_party/DKM/assets/ams_hom_B.jpg ADDED

Git LFS Details

  • SHA256: d84ced12e607f5ac5f7628151694fbaa2300caa091ac168e0aedad2ebaf491d6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
third_party/DKM/assets/dkmv3_warp.jpg ADDED

Git LFS Details

  • SHA256: 04c46e39d5ea68e9e116d4ae71c038a459beaf3eed89e8b7b87ccafd01d3bf85
  • Pointer size: 131 Bytes
  • Size of remote file: 571 kB
third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c902547181fc9b370fdd16272140be6803fe983aea978c68683db803ac70dd57
3
+ size 906160
third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65ce02bd248988b42363ccd257abaa9b99a00d569d2779597b36ba6c4da35021
3
+ size 906160
third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6104feb8807a4ebdd1266160e67b3c507c550012f54c23292d0ebf99b88753f
3
+ size 368192
third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9600ba5c24d414f63728bf5ee7550a3b035d7c615461e357590890ae0e0f042e
3
+ size 368192
third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de89e9ccf10515cc4196ba1e7172ec98b2fb92ff9f85d90db5df1af5b6503313
3
+ size 167528
third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94c97b57beb10411b3b98a1e88c9e1e2f9db51994dce04580d2b7cfc8919dab3
3
+ size 167528
third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f14a66dbbd7fa8f31756dd496bfabe4c3ea115c6914acad9365dd02e46ae674
3
+ size 63909
third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfaee333beccd1da0d920777cdc8f17d584b21ba20f675b39222c5a205acf72a
3
+ size 63909
third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b446ca3cc2073c8a3963cf68cc450ef2ebf73d2b956b1f5ae6b37621bc67cce4
3
+ size 200371
third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df1969fd94032562b5e8d916467101a878168d586b795770e64c108bab250c9e
3
+ size 200371
third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37cbafa0b0f981f5d69aba202ddd37c5892bda0fa13d053a3ad27d6ddad51c16
3
+ size 642823
third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:962b3fadc7c94ea8e4a1bb5e168e72b0b6cc474ae56b0aee70ba4e517553fbcf
3
+ size 642823
third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50ed6b02dff2fa719e4e9ca216b4704b82cbbefd127355d3ba7120828407e723
3
+ size 228647
third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edc97129d000f0478020495f646f2fa7667408247ccd11054e02efbbb38d1444
3
+ size 228647
third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04b0b6c6adff812e12b66476f7ca2a6ed2564cdd8208ec0c775f7b922f160103
3
+ size 177063
third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae931f8cac1b2168f699c70efe42c215eaff27d3f0617d59afb3db183c9b1848
3
+ size 177063
third_party/DKM/assets/mount_rushmore.mp4 ADDED
Binary file (986 kB). View file
 
third_party/DKM/assets/sacre_coeur_A.jpg ADDED

Git LFS Details

  • SHA256: 90d9c5f5a4d76425624989215120fba6f2899190a1d5654b88fa380c64cf6b2c
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
third_party/DKM/assets/sacre_coeur_B.jpg ADDED

Git LFS Details

  • SHA256: 2f1eb9bdd4d80e480f672d6a729689ac77f9fd5c8deb90f59b377590f3ca4799
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
third_party/DKM/data/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
third_party/DKM/demo/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.jpg
third_party/DKM/demo/demo_fundamental.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from dkm.utils.utils import tensor_to_pil
6
+ import cv2
7
+ from dkm import DKMv3_outdoor
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+
12
+ if __name__ == "__main__":
13
+ from argparse import ArgumentParser
14
+ parser = ArgumentParser()
15
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
16
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
17
+
18
+ args, _ = parser.parse_known_args()
19
+ im1_path = args.im_A_path
20
+ im2_path = args.im_B_path
21
+
22
+ # Create model
23
+ dkm_model = DKMv3_outdoor(device=device)
24
+
25
+
26
+ W_A, H_A = Image.open(im1_path).size
27
+ W_B, H_B = Image.open(im2_path).size
28
+
29
+ # Match
30
+ warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
31
+ # Sample matches for estimation
32
+ matches, certainty = dkm_model.sample(warp, certainty)
33
+ kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
34
+ F, mask = cv2.findFundamentalMat(
35
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
36
+ )
37
+ # TODO: some better visualization
third_party/DKM/demo/demo_match.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from dkm.utils.utils import tensor_to_pil
6
+
7
+ from dkm import DKMv3_outdoor
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+
12
+ if __name__ == "__main__":
13
+ from argparse import ArgumentParser
14
+ parser = ArgumentParser()
15
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
16
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
17
+ parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
18
+
19
+ args, _ = parser.parse_known_args()
20
+ im1_path = args.im_A_path
21
+ im2_path = args.im_B_path
22
+ save_path = args.save_path
23
+
24
+ # Create model
25
+ dkm_model = DKMv3_outdoor(device=device)
26
+
27
+ H, W = 864, 1152
28
+
29
+ im1 = Image.open(im1_path).resize((W, H))
30
+ im2 = Image.open(im2_path).resize((W, H))
31
+
32
+ # Match
33
+ warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
34
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
35
+ dkm_model.sample(warp, certainty)
36
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
37
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
38
+
39
+ im2_transfer_rgb = F.grid_sample(
40
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
41
+ )[0]
42
+ im1_transfer_rgb = F.grid_sample(
43
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
44
+ )[0]
45
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
46
+ white_im = torch.ones((H,2*W),device=device)
47
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
48
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
third_party/DKM/dkm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .models import (
2
+ DKMv3_outdoor,
3
+ DKMv3_indoor,
4
+ )
third_party/DKM/dkm/benchmarks/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
2
+ from .scannet_benchmark import ScanNetBenchmark
3
+ from .megadepth1500_benchmark import Megadepth1500Benchmark
4
+ from .megadepth_dense_benchmark import MegadepthDenseBenchmark
third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import os
5
+
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dkm.utils import *
10
+
11
+
12
+ class HpatchesDenseBenchmark:
13
+ """WARNING: HPATCHES grid goes from [0,n-1] instead of [0.5,n-0.5]"""
14
+
15
+ def __init__(self, dataset_path) -> None:
16
+ seqs_dir = "hpatches-sequences-release"
17
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
18
+ self.seq_names = sorted(os.listdir(self.seqs_path))
19
+
20
+ def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
21
+ # Get matches in output format on the grid [0, n] where the center of the top-left coordinate is [0.5, 0.5]
22
+ offset = (
23
+ 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0]
24
+ )
25
+ query_coords = (
26
+ torch.stack(
27
+ (
28
+ wq * (query_coords[..., 0] + 1) / 2,
29
+ hq * (query_coords[..., 1] + 1) / 2,
30
+ ),
31
+ axis=-1,
32
+ )
33
+ - offset
34
+ )
35
+ query_to_support = (
36
+ torch.stack(
37
+ (
38
+ wsup * (query_to_support[..., 0] + 1) / 2,
39
+ hsup * (query_to_support[..., 1] + 1) / 2,
40
+ ),
41
+ axis=-1,
42
+ )
43
+ - offset
44
+ )
45
+ return query_coords, query_to_support
46
+
47
+ def inside_image(self, x, w, h):
48
+ return torch.logical_and(
49
+ x[:, 0] < (w - 1),
50
+ torch.logical_and(x[:, 1] < (h - 1), (x > 0).prod(dim=-1)),
51
+ )
52
+
53
+ def benchmark(self, model):
54
+ use_cuda = torch.cuda.is_available()
55
+ device = torch.device("cuda:0" if use_cuda else "cpu")
56
+ aepes = []
57
+ pcks = []
58
+ for seq_idx, seq_name in tqdm(
59
+ enumerate(self.seq_names), total=len(self.seq_names)
60
+ ):
61
+ if seq_name[0] == "i":
62
+ continue
63
+ im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
64
+ im1 = Image.open(im1_path)
65
+ w1, h1 = im1.size
66
+ for im_idx in range(2, 7):
67
+ im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
68
+ im2 = Image.open(im2_path)
69
+ w2, h2 = im2.size
70
+ matches, certainty = model.match(im2, im1, do_pred_in_og_res=True)
71
+ matches, certainty = matches.reshape(-1, 4), certainty.reshape(-1)
72
+ inv_homography = torch.from_numpy(
73
+ np.loadtxt(
74
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
75
+ )
76
+ ).to(device)
77
+ homography = torch.linalg.inv(inv_homography)
78
+ pos_a, pos_b = self.convert_coordinates(
79
+ matches[:, :2], matches[:, 2:], w2, h2, w1, h1
80
+ )
81
+ pos_a, pos_b = pos_a.double(), pos_b.double()
82
+ pos_a_h = torch.cat(
83
+ [pos_a, torch.ones([pos_a.shape[0], 1], device=device)], dim=1
84
+ )
85
+ pos_b_proj_h = (homography @ pos_a_h.t()).t()
86
+ pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
87
+ mask = self.inside_image(pos_b_proj, w1, h1)
88
+ residual = pos_b - pos_b_proj
89
+ dist = (residual**2).sum(dim=1).sqrt()[mask]
90
+ aepes.append(torch.mean(dist).item())
91
+ pck1 = (dist < 1.0).float().mean().item()
92
+ pck3 = (dist < 3.0).float().mean().item()
93
+ pck5 = (dist < 5.0).float().mean().item()
94
+ pcks.append([pck1, pck3, pck5])
95
+ m_pcks = np.mean(np.array(pcks), axis=0)
96
+ return {
97
+ "hp_pck1": m_pcks[0],
98
+ "hp_pck3": m_pcks[1],
99
+ "hp_pck5": m_pcks[2],
100
+ }
third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import h5py
3
+ import numpy as np
4
+ import torch
5
+ from dkm.utils import *
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+
10
+ class Yfcc100mBenchmark:
11
+ def __init__(self, data_root="data/yfcc100m_test") -> None:
12
+ self.scenes = [
13
+ "buckingham_palace",
14
+ "notre_dame_front_facade",
15
+ "reichstag",
16
+ "sacre_coeur",
17
+ ]
18
+ self.data_root = data_root
19
+
20
+ def benchmark(self, model, r=2):
21
+ model.train(False)
22
+ with torch.no_grad():
23
+ data_root = self.data_root
24
+ meta_info = open(
25
+ f"{data_root}/yfcc_test_pairs_with_gt.txt", "r"
26
+ ).readlines()
27
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
28
+ for scene_ind in range(len(self.scenes)):
29
+ scene = self.scenes[scene_ind]
30
+ pairs = np.array(
31
+ pickle.load(
32
+ open(f"{data_root}/pairs/{scene}-te-1000-pairs.pkl", "rb")
33
+ )
34
+ )
35
+ scene_dir = f"{data_root}/yfcc100m/{scene}/test/"
36
+ calibs = open(scene_dir + "calibration.txt", "r").read().split("\n")
37
+ images = open(scene_dir + "images.txt", "r").read().split("\n")
38
+ pair_inds = np.random.choice(
39
+ range(len(pairs)), size=len(pairs), replace=False
40
+ )
41
+ for pairind in tqdm(pair_inds):
42
+ idx1, idx2 = pairs[pairind]
43
+ params = meta_info[1000 * scene_ind + pairind].split()
44
+ rot1, rot2 = int(params[2]), int(params[3])
45
+ calib1 = h5py.File(scene_dir + calibs[idx1], "r")
46
+ K1, R1, t1, _, _ = get_pose(calib1)
47
+ calib2 = h5py.File(scene_dir + calibs[idx2], "r")
48
+ K2, R2, t2, _, _ = get_pose(calib2)
49
+
50
+ R, t = compute_relative_pose(R1, t1, R2, t2)
51
+ im1 = images[idx1]
52
+ im2 = images[idx2]
53
+ im1 = Image.open(scene_dir + im1).rotate(rot1 * 90, expand=True)
54
+ w1, h1 = im1.size
55
+ im2 = Image.open(scene_dir + im2).rotate(rot2 * 90, expand=True)
56
+ w2, h2 = im2.size
57
+ K1 = rotate_intrinsic(K1, rot1)
58
+ K2 = rotate_intrinsic(K2, rot2)
59
+
60
+ dense_matches, dense_certainty = model.match(im1, im2)
61
+ dense_certainty = dense_certainty ** (1 / r)
62
+ sparse_matches, sparse_confidence = model.sample(
63
+ dense_matches, dense_certainty, 10000
64
+ )
65
+ scale1 = 480 / min(w1, h1)
66
+ scale2 = 480 / min(w2, h2)
67
+ w1, h1 = scale1 * w1, scale1 * h1
68
+ w2, h2 = scale2 * w2, scale2 * h2
69
+ K1 = K1 * scale1
70
+ K2 = K2 * scale2
71
+
72
+ kpts1 = sparse_matches[:, :2]
73
+ kpts1 = np.stack(
74
+ (w1 * kpts1[:, 0] / 2, h1 * kpts1[:, 1] / 2), axis=-1
75
+ )
76
+ kpts2 = sparse_matches[:, 2:]
77
+ kpts2 = np.stack(
78
+ (w2 * kpts2[:, 0] / 2, h2 * kpts2[:, 1] / 2), axis=-1
79
+ )
80
+ try:
81
+ threshold = 1.0
82
+ norm_threshold = threshold / (
83
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
84
+ )
85
+ R_est, t_est, mask = estimate_pose(
86
+ kpts1,
87
+ kpts2,
88
+ K1[:2, :2],
89
+ K2[:2, :2],
90
+ norm_threshold,
91
+ conf=0.9999999,
92
+ )
93
+ T1_to_2 = np.concatenate((R_est, t_est), axis=-1) #
94
+ e_t, e_R = compute_pose_error(T1_to_2, R, t)
95
+ e_pose = max(e_t, e_R)
96
+ except:
97
+ e_t, e_R = 90, 90
98
+ e_pose = max(e_t, e_R)
99
+ tot_e_t.append(e_t)
100
+ tot_e_R.append(e_R)
101
+ tot_e_pose.append(e_pose)
102
+ tot_e_pose = np.array(tot_e_pose)
103
+ thresholds = [5, 10, 20]
104
+ auc = pose_auc(tot_e_pose, thresholds)
105
+ acc_5 = (tot_e_pose < 5).mean()
106
+ acc_10 = (tot_e_pose < 10).mean()
107
+ acc_15 = (tot_e_pose < 15).mean()
108
+ acc_20 = (tot_e_pose < 20).mean()
109
+ map_5 = acc_5
110
+ map_10 = np.mean([acc_5, acc_10])
111
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
112
+ return {
113
+ "auc_5": auc[0],
114
+ "auc_10": auc[1],
115
+ "auc_20": auc[2],
116
+ "map_5": map_5,
117
+ "map_10": map_10,
118
+ "map_20": map_20,
119
+ }
third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import os
5
+
6
+ from tqdm import tqdm
7
+ from dkm.utils import pose_auc
8
+ import cv2
9
+
10
+
11
+ class HpatchesHomogBenchmark:
12
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
13
+
14
+ def __init__(self, dataset_path) -> None:
15
+ seqs_dir = "hpatches-sequences-release"
16
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
17
+ self.seq_names = sorted(os.listdir(self.seqs_path))
18
+ # Ignore seqs is same as LoFTR.
19
+ self.ignore_seqs = set(
20
+ [
21
+ "i_contruction",
22
+ "i_crownnight",
23
+ "i_dc",
24
+ "i_pencils",
25
+ "i_whitebuilding",
26
+ "v_artisans",
27
+ "v_astronautis",
28
+ "v_talent",
29
+ ]
30
+ )
31
+
32
+ def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
33
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
34
+ query_coords = (
35
+ np.stack(
36
+ (
37
+ wq * (query_coords[..., 0] + 1) / 2,
38
+ hq * (query_coords[..., 1] + 1) / 2,
39
+ ),
40
+ axis=-1,
41
+ )
42
+ - offset
43
+ )
44
+ query_to_support = (
45
+ np.stack(
46
+ (
47
+ wsup * (query_to_support[..., 0] + 1) / 2,
48
+ hsup * (query_to_support[..., 1] + 1) / 2,
49
+ ),
50
+ axis=-1,
51
+ )
52
+ - offset
53
+ )
54
+ return query_coords, query_to_support
55
+
56
+ def benchmark(self, model, model_name = None):
57
+ n_matches = []
58
+ homog_dists = []
59
+ for seq_idx, seq_name in tqdm(
60
+ enumerate(self.seq_names), total=len(self.seq_names)
61
+ ):
62
+ if seq_name in self.ignore_seqs:
63
+ continue
64
+ im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
65
+ im1 = Image.open(im1_path)
66
+ w1, h1 = im1.size
67
+ for im_idx in range(2, 7):
68
+ im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
69
+ im2 = Image.open(im2_path)
70
+ w2, h2 = im2.size
71
+ H = np.loadtxt(
72
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
73
+ )
74
+ dense_matches, dense_certainty = model.match(
75
+ im1_path, im2_path
76
+ )
77
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
78
+ pos_a, pos_b = self.convert_coordinates(
79
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
80
+ )
81
+ try:
82
+ H_pred, inliers = cv2.findHomography(
83
+ pos_a,
84
+ pos_b,
85
+ method = cv2.RANSAC,
86
+ confidence = 0.99999,
87
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
88
+ )
89
+ except:
90
+ H_pred = None
91
+ if H_pred is None:
92
+ H_pred = np.zeros((3, 3))
93
+ H_pred[2, 2] = 1.0
94
+ corners = np.array(
95
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
96
+ )
97
+ real_warped_corners = np.dot(corners, np.transpose(H))
98
+ real_warped_corners = (
99
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
100
+ )
101
+ warped_corners = np.dot(corners, np.transpose(H_pred))
102
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
103
+ mean_dist = np.mean(
104
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
105
+ ) / (min(w2, h2) / 480.0)
106
+ homog_dists.append(mean_dist)
107
+ n_matches = np.array(n_matches)
108
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
109
+ auc = pose_auc(np.array(homog_dists), thresholds)
110
+ return {
111
+ "hpatches_homog_auc_3": auc[2],
112
+ "hpatches_homog_auc_5": auc[4],
113
+ "hpatches_homog_auc_10": auc[9],
114
+ }
third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from dkm.utils import *
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import torch.nn.functional as F
7
+
8
+ class Megadepth1500Benchmark:
9
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
10
+ if scene_names is None:
11
+ self.scene_names = [
12
+ "0015_0.1_0.3.npz",
13
+ "0015_0.3_0.5.npz",
14
+ "0022_0.1_0.3.npz",
15
+ "0022_0.3_0.5.npz",
16
+ "0022_0.5_0.7.npz",
17
+ ]
18
+ else:
19
+ self.scene_names = scene_names
20
+ self.scenes = [
21
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
22
+ for scene in self.scene_names
23
+ ]
24
+ self.data_root = data_root
25
+
26
+ def benchmark(self, model):
27
+ with torch.no_grad():
28
+ data_root = self.data_root
29
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
30
+ for scene_ind in range(len(self.scenes)):
31
+ scene = self.scenes[scene_ind]
32
+ pairs = scene["pair_infos"]
33
+ intrinsics = scene["intrinsics"]
34
+ poses = scene["poses"]
35
+ im_paths = scene["image_paths"]
36
+ pair_inds = range(len(pairs))
37
+ for pairind in tqdm(pair_inds):
38
+ idx1, idx2 = pairs[pairind][0]
39
+ K1 = intrinsics[idx1].copy()
40
+ T1 = poses[idx1].copy()
41
+ R1, t1 = T1[:3, :3], T1[:3, 3]
42
+ K2 = intrinsics[idx2].copy()
43
+ T2 = poses[idx2].copy()
44
+ R2, t2 = T2[:3, :3], T2[:3, 3]
45
+ R, t = compute_relative_pose(R1, t1, R2, t2)
46
+ im1_path = f"{data_root}/{im_paths[idx1]}"
47
+ im2_path = f"{data_root}/{im_paths[idx2]}"
48
+ im1 = Image.open(im1_path)
49
+ w1, h1 = im1.size
50
+ im2 = Image.open(im2_path)
51
+ w2, h2 = im2.size
52
+ scale1 = 1200 / max(w1, h1)
53
+ scale2 = 1200 / max(w2, h2)
54
+ w1, h1 = scale1 * w1, scale1 * h1
55
+ w2, h2 = scale2 * w2, scale2 * h2
56
+ K1[:2] = K1[:2] * scale1
57
+ K2[:2] = K2[:2] * scale2
58
+ dense_matches, dense_certainty = model.match(im1_path, im2_path)
59
+ sparse_matches,_ = model.sample(
60
+ dense_matches, dense_certainty, 5000
61
+ )
62
+ kpts1 = sparse_matches[:, :2]
63
+ kpts1 = (
64
+ torch.stack(
65
+ (
66
+ w1 * (kpts1[:, 0] + 1) / 2,
67
+ h1 * (kpts1[:, 1] + 1) / 2,
68
+ ),
69
+ axis=-1,
70
+ )
71
+ )
72
+ kpts2 = sparse_matches[:, 2:]
73
+ kpts2 = (
74
+ torch.stack(
75
+ (
76
+ w2 * (kpts2[:, 0] + 1) / 2,
77
+ h2 * (kpts2[:, 1] + 1) / 2,
78
+ ),
79
+ axis=-1,
80
+ )
81
+ )
82
+ for _ in range(5):
83
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
84
+ kpts1 = kpts1[shuffling]
85
+ kpts2 = kpts2[shuffling]
86
+ try:
87
+ norm_threshold = 0.5 / (
88
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
89
+ R_est, t_est, mask = estimate_pose(
90
+ kpts1.cpu().numpy(),
91
+ kpts2.cpu().numpy(),
92
+ K1,
93
+ K2,
94
+ norm_threshold,
95
+ conf=0.99999,
96
+ )
97
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
98
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
99
+ e_pose = max(e_t, e_R)
100
+ except Exception as e:
101
+ print(repr(e))
102
+ e_t, e_R = 90, 90
103
+ e_pose = max(e_t, e_R)
104
+ tot_e_t.append(e_t)
105
+ tot_e_R.append(e_R)
106
+ tot_e_pose.append(e_pose)
107
+ tot_e_pose = np.array(tot_e_pose)
108
+ thresholds = [5, 10, 20]
109
+ auc = pose_auc(tot_e_pose, thresholds)
110
+ acc_5 = (tot_e_pose < 5).mean()
111
+ acc_10 = (tot_e_pose < 10).mean()
112
+ acc_15 = (tot_e_pose < 15).mean()
113
+ acc_20 = (tot_e_pose < 20).mean()
114
+ map_5 = acc_5
115
+ map_10 = np.mean([acc_5, acc_10])
116
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
117
+ return {
118
+ "auc_5": auc[0],
119
+ "auc_10": auc[1],
120
+ "auc_20": auc[2],
121
+ "map_5": map_5,
122
+ "map_10": map_10,
123
+ "map_20": map_20,
124
+ }
third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import tqdm
4
+ from dkm.datasets import MegadepthBuilder
5
+ from dkm.utils import warp_kpts
6
+ from torch.utils.data import ConcatDataset
7
+
8
+
9
+ class MegadepthDenseBenchmark:
10
+ def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None:
11
+ mega = MegadepthBuilder(data_root=data_root)
12
+ self.dataset = ConcatDataset(
13
+ mega.build_scenes(split="test_loftr", ht=h, wt=w)
14
+ ) # fixed resolution of 384,512
15
+ self.num_samples = num_samples
16
+ if device is None:
17
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ self.device = device
19
+
20
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
21
+ b, h1, w1, d = dense_matches.shape
22
+ with torch.no_grad():
23
+ x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
24
+ # x1 = torch.stack((2*x1[...,0]/w1-1,2*x1[...,1]/h1-1),dim=-1)
25
+ mask, x2 = warp_kpts(
26
+ x1.double(),
27
+ depth1.double(),
28
+ depth2.double(),
29
+ T_1to2.double(),
30
+ K1.double(),
31
+ K2.double(),
32
+ )
33
+ x2 = torch.stack(
34
+ (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
35
+ )
36
+ prob = mask.float().reshape(b, h1, w1)
37
+ x2_hat = dense_matches[..., 2:]
38
+ x2_hat = torch.stack(
39
+ (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
40
+ )
41
+ gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
42
+ gd = gd[prob == 1]
43
+ pck_1 = (gd < 1.0).float().mean()
44
+ pck_3 = (gd < 3.0).float().mean()
45
+ pck_5 = (gd < 5.0).float().mean()
46
+ gd = gd.mean()
47
+ return gd, pck_1, pck_3, pck_5
48
+
49
+ def benchmark(self, model, batch_size=8):
50
+ model.train(False)
51
+ with torch.no_grad():
52
+ gd_tot = 0.0
53
+ pck_1_tot = 0.0
54
+ pck_3_tot = 0.0
55
+ pck_5_tot = 0.0
56
+ sampler = torch.utils.data.WeightedRandomSampler(
57
+ torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
58
+ )
59
+ dataloader = torch.utils.data.DataLoader(
60
+ self.dataset, batch_size=8, num_workers=batch_size, sampler=sampler
61
+ )
62
+ for data in tqdm.tqdm(dataloader):
63
+ im1, im2, depth1, depth2, T_1to2, K1, K2 = (
64
+ data["query"],
65
+ data["support"],
66
+ data["query_depth"].to(self.device),
67
+ data["support_depth"].to(self.device),
68
+ data["T_1to2"].to(self.device),
69
+ data["K1"].to(self.device),
70
+ data["K2"].to(self.device),
71
+ )
72
+ matches, certainty = model.match(im1, im2, batched=True)
73
+ gd, pck_1, pck_3, pck_5 = self.geometric_dist(
74
+ depth1, depth2, T_1to2, K1, K2, matches
75
+ )
76
+ gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
77
+ gd_tot + gd,
78
+ pck_1_tot + pck_1,
79
+ pck_3_tot + pck_3,
80
+ pck_5_tot + pck_5,
81
+ )
82
+ return {
83
+ "mega_pck_1": pck_1_tot.item() / len(dataloader),
84
+ "mega_pck_3": pck_3_tot.item() / len(dataloader),
85
+ "mega_pck_5": pck_5_tot.item() / len(dataloader),
86
+ }
third_party/DKM/dkm/benchmarks/scannet_benchmark.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import torch
4
+ from dkm.utils import *
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+
8
+
9
+ class ScanNetBenchmark:
10
+ def __init__(self, data_root="data/scannet") -> None:
11
+ self.data_root = data_root
12
+
13
+ def benchmark(self, model, model_name = None):
14
+ model.train(False)
15
+ with torch.no_grad():
16
+ data_root = self.data_root
17
+ tmp = np.load(osp.join(data_root, "test.npz"))
18
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
19
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
20
+ pair_inds = np.random.choice(
21
+ range(len(pairs)), size=len(pairs), replace=False
22
+ )
23
+ for pairind in tqdm(pair_inds, smoothing=0.9):
24
+ scene = pairs[pairind]
25
+ scene_name = f"scene0{scene[0]}_00"
26
+ im1_path = osp.join(
27
+ self.data_root,
28
+ "scans_test",
29
+ scene_name,
30
+ "color",
31
+ f"{scene[2]}.jpg",
32
+ )
33
+ im1 = Image.open(im1_path)
34
+ im2_path = osp.join(
35
+ self.data_root,
36
+ "scans_test",
37
+ scene_name,
38
+ "color",
39
+ f"{scene[3]}.jpg",
40
+ )
41
+ im2 = Image.open(im2_path)
42
+ T_gt = rel_pose[pairind].reshape(3, 4)
43
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
44
+ K = np.stack(
45
+ [
46
+ np.array([float(i) for i in r.split()])
47
+ for r in open(
48
+ osp.join(
49
+ self.data_root,
50
+ "scans_test",
51
+ scene_name,
52
+ "intrinsic",
53
+ "intrinsic_color.txt",
54
+ ),
55
+ "r",
56
+ )
57
+ .read()
58
+ .split("\n")
59
+ if r
60
+ ]
61
+ )
62
+ w1, h1 = im1.size
63
+ w2, h2 = im2.size
64
+ K1 = K.copy()
65
+ K2 = K.copy()
66
+ dense_matches, dense_certainty = model.match(im1_path, im2_path)
67
+ sparse_matches, sparse_certainty = model.sample(
68
+ dense_matches, dense_certainty, 5000
69
+ )
70
+ scale1 = 480 / min(w1, h1)
71
+ scale2 = 480 / min(w2, h2)
72
+ w1, h1 = scale1 * w1, scale1 * h1
73
+ w2, h2 = scale2 * w2, scale2 * h2
74
+ K1 = K1 * scale1
75
+ K2 = K2 * scale2
76
+
77
+ offset = 0.5
78
+ kpts1 = sparse_matches[:, :2]
79
+ kpts1 = (
80
+ np.stack(
81
+ (
82
+ w1 * (kpts1[:, 0] + 1) / 2 - offset,
83
+ h1 * (kpts1[:, 1] + 1) / 2 - offset,
84
+ ),
85
+ axis=-1,
86
+ )
87
+ )
88
+ kpts2 = sparse_matches[:, 2:]
89
+ kpts2 = (
90
+ np.stack(
91
+ (
92
+ w2 * (kpts2[:, 0] + 1) / 2 - offset,
93
+ h2 * (kpts2[:, 1] + 1) / 2 - offset,
94
+ ),
95
+ axis=-1,
96
+ )
97
+ )
98
+ for _ in range(5):
99
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
100
+ kpts1 = kpts1[shuffling]
101
+ kpts2 = kpts2[shuffling]
102
+ try:
103
+ norm_threshold = 0.5 / (
104
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
105
+ R_est, t_est, mask = estimate_pose(
106
+ kpts1,
107
+ kpts2,
108
+ K1,
109
+ K2,
110
+ norm_threshold,
111
+ conf=0.99999,
112
+ )
113
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
114
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
115
+ e_pose = max(e_t, e_R)
116
+ except Exception as e:
117
+ print(repr(e))
118
+ e_t, e_R = 90, 90
119
+ e_pose = max(e_t, e_R)
120
+ tot_e_t.append(e_t)
121
+ tot_e_R.append(e_R)
122
+ tot_e_pose.append(e_pose)
123
+ tot_e_t.append(e_t)
124
+ tot_e_R.append(e_R)
125
+ tot_e_pose.append(e_pose)
126
+ tot_e_pose = np.array(tot_e_pose)
127
+ thresholds = [5, 10, 20]
128
+ auc = pose_auc(tot_e_pose, thresholds)
129
+ acc_5 = (tot_e_pose < 5).mean()
130
+ acc_10 = (tot_e_pose < 10).mean()
131
+ acc_15 = (tot_e_pose < 15).mean()
132
+ acc_20 = (tot_e_pose < 20).mean()
133
+ map_5 = acc_5
134
+ map_10 = np.mean([acc_5, acc_10])
135
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
136
+ return {
137
+ "auc_5": auc[0],
138
+ "auc_10": auc[1],
139
+ "auc_20": auc[2],
140
+ "map_5": map_5,
141
+ "map_10": map_10,
142
+ "map_20": map_20,
143
+ }
third_party/DKM/dkm/checkpointing/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .checkpoint import CheckPoint
third_party/DKM/dkm/checkpointing/checkpoint.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.nn.parallel.data_parallel import DataParallel
4
+ from torch.nn.parallel.distributed import DistributedDataParallel
5
+ from loguru import logger
6
+
7
+
8
+ class CheckPoint:
9
+ def __init__(self, dir=None, name="tmp"):
10
+ self.name = name
11
+ self.dir = dir
12
+ os.makedirs(self.dir, exist_ok=True)
13
+
14
+ def __call__(
15
+ self,
16
+ model,
17
+ optimizer,
18
+ lr_scheduler,
19
+ n,
20
+ ):
21
+ assert model is not None
22
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
23
+ model = model.module
24
+ states = {
25
+ "model": model.state_dict(),
26
+ "n": n,
27
+ "optimizer": optimizer.state_dict(),
28
+ "lr_scheduler": lr_scheduler.state_dict(),
29
+ }
30
+ torch.save(states, self.dir + self.name + f"_latest.pth")
31
+ logger.info(f"Saved states {list(states.keys())}, at step {n}")
third_party/DKM/dkm/datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .megadepth import MegadepthBuilder
third_party/DKM/dkm/datasets/megadepth.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from PIL import Image
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader, ConcatDataset
8
+
9
+ from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
10
+ import torchvision.transforms.functional as tvf
11
+ from dkm.utils.transforms import GeometricSequential
12
+ import kornia.augmentation as K
13
+
14
+
15
+ class MegadepthScene:
16
+ def __init__(
17
+ self,
18
+ data_root,
19
+ scene_info,
20
+ ht=384,
21
+ wt=512,
22
+ min_overlap=0.0,
23
+ shake_t=0,
24
+ rot_prob=0.0,
25
+ normalize=True,
26
+ ) -> None:
27
+ self.data_root = data_root
28
+ self.image_paths = scene_info["image_paths"]
29
+ self.depth_paths = scene_info["depth_paths"]
30
+ self.intrinsics = scene_info["intrinsics"]
31
+ self.poses = scene_info["poses"]
32
+ self.pairs = scene_info["pairs"]
33
+ self.overlaps = scene_info["overlaps"]
34
+ threshold = self.overlaps > min_overlap
35
+ self.pairs = self.pairs[threshold]
36
+ self.overlaps = self.overlaps[threshold]
37
+ if len(self.pairs) > 100000:
38
+ pairinds = np.random.choice(
39
+ np.arange(0, len(self.pairs)), 100000, replace=False
40
+ )
41
+ self.pairs = self.pairs[pairinds]
42
+ self.overlaps = self.overlaps[pairinds]
43
+ # counts, bins = np.histogram(self.overlaps,20)
44
+ # print(counts)
45
+ self.im_transform_ops = get_tuple_transform_ops(
46
+ resize=(ht, wt), normalize=normalize
47
+ )
48
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
49
+ resize=(ht, wt), normalize=False
50
+ )
51
+ self.wt, self.ht = wt, ht
52
+ self.shake_t = shake_t
53
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
54
+
55
+ def load_im(self, im_ref, crop=None):
56
+ im = Image.open(im_ref)
57
+ return im
58
+
59
+ def load_depth(self, depth_ref, crop=None):
60
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
61
+ return torch.from_numpy(depth)
62
+
63
+ def __len__(self):
64
+ return len(self.pairs)
65
+
66
+ def scale_intrinsic(self, K, wi, hi):
67
+ sx, sy = self.wt / wi, self.ht / hi
68
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
69
+ return sK @ K
70
+
71
+ def rand_shake(self, *things):
72
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
73
+ return [
74
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
75
+ for thing in things
76
+ ], t
77
+
78
+ def __getitem__(self, pair_idx):
79
+ # read intrinsics of original size
80
+ idx1, idx2 = self.pairs[pair_idx]
81
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
82
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
83
+
84
+ # read and compute relative poses
85
+ T1 = self.poses[idx1]
86
+ T2 = self.poses[idx2]
87
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
88
+ :4, :4
89
+ ] # (4, 4)
90
+
91
+ # Load positive pair data
92
+ im1, im2 = self.image_paths[idx1], self.image_paths[idx2]
93
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
94
+ im_src_ref = os.path.join(self.data_root, im1)
95
+ im_pos_ref = os.path.join(self.data_root, im2)
96
+ depth_src_ref = os.path.join(self.data_root, depth1)
97
+ depth_pos_ref = os.path.join(self.data_root, depth2)
98
+ # return torch.randn((1000,1000))
99
+ im_src = self.load_im(im_src_ref)
100
+ im_pos = self.load_im(im_pos_ref)
101
+ depth_src = self.load_depth(depth_src_ref)
102
+ depth_pos = self.load_depth(depth_pos_ref)
103
+
104
+ # Recompute camera intrinsic matrix due to the resize
105
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
106
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
107
+ # Process images
108
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
109
+ depth_src, depth_pos = self.depth_transform_ops(
110
+ (depth_src[None, None], depth_pos[None, None])
111
+ )
112
+ [im_src, im_pos, depth_src, depth_pos], t = self.rand_shake(
113
+ im_src, im_pos, depth_src, depth_pos
114
+ )
115
+ im_src, Hq = self.H_generator(im_src[None])
116
+ depth_src = self.H_generator.apply_transform(depth_src, Hq)
117
+ K1[:2, 2] += t
118
+ K2[:2, 2] += t
119
+ K1 = Hq[0] @ K1
120
+ data_dict = {
121
+ "query": im_src[0],
122
+ "query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
123
+ "support": im_pos,
124
+ "support_identifier": self.image_paths[idx2]
125
+ .split("/")[-1]
126
+ .split(".jpg")[0],
127
+ "query_depth": depth_src[0, 0],
128
+ "support_depth": depth_pos[0, 0],
129
+ "K1": K1,
130
+ "K2": K2,
131
+ "T_1to2": T_1to2,
132
+ }
133
+ return data_dict
134
+
135
+
136
+ class MegadepthBuilder:
137
+ def __init__(self, data_root="data/megadepth") -> None:
138
+ self.data_root = data_root
139
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
140
+ self.all_scenes = os.listdir(self.scene_info_root)
141
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
142
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
143
+
144
+ def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
145
+ if split == "train":
146
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
147
+ elif split == "train_loftr":
148
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
149
+ elif split == "test":
150
+ scene_names = self.test_scenes
151
+ elif split == "test_loftr":
152
+ scene_names = self.test_scenes_loftr
153
+ else:
154
+ raise ValueError(f"Split {split} not available")
155
+ scenes = []
156
+ for scene_name in scene_names:
157
+ scene_info = np.load(
158
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
159
+ ).item()
160
+ scenes.append(
161
+ MegadepthScene(
162
+ self.data_root, scene_info, min_overlap=min_overlap, **kwargs
163
+ )
164
+ )
165
+ return scenes
166
+
167
+ def weight_scenes(self, concat_dataset, alpha=0.5):
168
+ ns = []
169
+ for d in concat_dataset.datasets:
170
+ ns.append(len(d))
171
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
172
+ return ws
173
+
174
+
175
+ if __name__ == "__main__":
176
+ mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train"))
177
+ mega_test[0]
third_party/DKM/dkm/datasets/scannet.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from PIL import Image
4
+ import cv2
5
+ import h5py
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import (
9
+ Dataset,
10
+ DataLoader,
11
+ ConcatDataset)
12
+
13
+ import torchvision.transforms.functional as tvf
14
+ import kornia.augmentation as K
15
+ import os.path as osp
16
+ import matplotlib.pyplot as plt
17
+ from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
18
+ from dkm.utils.transforms import GeometricSequential
19
+
20
+ from tqdm import tqdm
21
+
22
+ class ScanNetScene:
23
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None:
24
+ self.scene_root = osp.join(data_root,"scans","scans_train")
25
+ self.data_names = scene_info['name']
26
+ self.overlaps = scene_info['score']
27
+ # Only sample 10s
28
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
29
+ self.overlaps = self.overlaps[valid]
30
+ self.data_names = self.data_names[valid]
31
+ if len(self.data_names) > 10000:
32
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
33
+ self.data_names = self.data_names[pairinds]
34
+ self.overlaps = self.overlaps[pairinds]
35
+ self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
36
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
37
+ self.wt, self.ht = wt, ht
38
+ self.shake_t = shake_t
39
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
40
+
41
+ def load_im(self, im_ref, crop=None):
42
+ im = Image.open(im_ref)
43
+ return im
44
+
45
+ def load_depth(self, depth_ref, crop=None):
46
+ depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
47
+ depth = depth / 1000
48
+ depth = torch.from_numpy(depth).float() # (h, w)
49
+ return depth
50
+
51
+ def __len__(self):
52
+ return len(self.data_names)
53
+
54
+ def scale_intrinsic(self, K, wi, hi):
55
+ sx, sy = self.wt / wi, self.ht / hi
56
+ sK = torch.tensor([[sx, 0, 0],
57
+ [0, sy, 0],
58
+ [0, 0, 1]])
59
+ return sK@K
60
+
61
+ def read_scannet_pose(self,path):
62
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
63
+
64
+ Returns:
65
+ pose_w2c (np.ndarray): (4, 4)
66
+ """
67
+ cam2world = np.loadtxt(path, delimiter=' ')
68
+ world2cam = np.linalg.inv(cam2world)
69
+ return world2cam
70
+
71
+
72
+ def read_scannet_intrinsic(self,path):
73
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
74
+ """
75
+ intrinsic = np.loadtxt(path, delimiter=' ')
76
+ return intrinsic[:-1, :-1]
77
+
78
+ def __getitem__(self, pair_idx):
79
+ # read intrinsics of original size
80
+ data_name = self.data_names[pair_idx]
81
+ scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
82
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
83
+
84
+ # read the intrinsic of depthmap
85
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
86
+ scene_name,
87
+ 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
88
+ # read and compute relative poses
89
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
90
+ scene_name,
91
+ 'pose', f'{stem_name_1}.txt'))
92
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
93
+ scene_name,
94
+ 'pose', f'{stem_name_2}.txt'))
95
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
96
+
97
+ # Load positive pair data
98
+ im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
99
+ im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
100
+ depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
101
+ depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
102
+
103
+ im_src = self.load_im(im_src_ref)
104
+ im_pos = self.load_im(im_pos_ref)
105
+ depth_src = self.load_depth(depth_src_ref)
106
+ depth_pos = self.load_depth(depth_pos_ref)
107
+
108
+ # Recompute camera intrinsic matrix due to the resize
109
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
110
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
111
+ # Process images
112
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
113
+ depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None]))
114
+
115
+ data_dict = {'query': im_src,
116
+ 'support': im_pos,
117
+ 'query_depth': depth_src[0,0],
118
+ 'support_depth': depth_pos[0,0],
119
+ 'K1': K1,
120
+ 'K2': K2,
121
+ 'T_1to2':T_1to2,
122
+ }
123
+ return data_dict
124
+
125
+
126
+ class ScanNetBuilder:
127
+ def __init__(self, data_root = 'data/scannet') -> None:
128
+ self.data_root = data_root
129
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
130
+ self.all_scenes = os.listdir(self.scene_info_root)
131
+
132
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
133
+ # Note: split doesn't matter here as we always use same scannet_train scenes
134
+ scene_names = self.all_scenes
135
+ scenes = []
136
+ for scene_name in tqdm(scene_names):
137
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
138
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
139
+ return scenes
140
+
141
+ def weight_scenes(self, concat_dataset, alpha=.5):
142
+ ns = []
143
+ for d in concat_dataset.datasets:
144
+ ns.append(len(d))
145
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
146
+ return ws
147
+
148
+
149
+ if __name__ == "__main__":
150
+ mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train'))
151
+ mega_test[0]
third_party/DKM/dkm/losses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .depth_match_regression_loss import DepthRegressionLoss
third_party/DKM/dkm/losses/depth_match_regression_loss.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops.einops import rearrange
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from dkm.utils.utils import warp_kpts
6
+
7
+
8
+ class DepthRegressionLoss(nn.Module):
9
+ def __init__(
10
+ self,
11
+ robust=True,
12
+ center_coords=False,
13
+ scale_normalize=False,
14
+ ce_weight=0.01,
15
+ local_loss=True,
16
+ local_dist=4.0,
17
+ local_largest_scale=8,
18
+ ):
19
+ super().__init__()
20
+ self.robust = robust # measured in pixels
21
+ self.center_coords = center_coords
22
+ self.scale_normalize = scale_normalize
23
+ self.ce_weight = ce_weight
24
+ self.local_loss = local_loss
25
+ self.local_dist = local_dist
26
+ self.local_largest_scale = local_largest_scale
27
+
28
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
29
+ """[summary]
30
+
31
+ Args:
32
+ H ([type]): [description]
33
+ scale ([type]): [description]
34
+
35
+ Returns:
36
+ [type]: [description]
37
+ """
38
+ b, h1, w1, d = dense_matches.shape
39
+ with torch.no_grad():
40
+ x1_n = torch.meshgrid(
41
+ *[
42
+ torch.linspace(
43
+ -1 + 1 / n, 1 - 1 / n, n, device=dense_matches.device
44
+ )
45
+ for n in (b, h1, w1)
46
+ ]
47
+ )
48
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(b, h1 * w1, 2)
49
+ mask, x2 = warp_kpts(
50
+ x1_n.double(),
51
+ depth1.double(),
52
+ depth2.double(),
53
+ T_1to2.double(),
54
+ K1.double(),
55
+ K2.double(),
56
+ )
57
+ prob = mask.float().reshape(b, h1, w1)
58
+ gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1) # *scale?
59
+ return gd, prob
60
+
61
+ def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
62
+ """[summary]
63
+
64
+ Args:
65
+ dense_certainty ([type]): [description]
66
+ prob ([type]): [description]
67
+ eps ([type], optional): [description]. Defaults to 1e-8.
68
+
69
+ Returns:
70
+ [type]: [description]
71
+ """
72
+ smooth_prob = prob
73
+ ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], smooth_prob)
74
+ depth_loss = gd[prob > 0]
75
+ if not torch.any(prob > 0).item():
76
+ depth_loss = (gd * 0.0).mean() # Prevent issues where prob is 0 everywhere
77
+ return {
78
+ f"ce_loss_{scale}": ce_loss.mean(),
79
+ f"depth_loss_{scale}": depth_loss.mean(),
80
+ }
81
+
82
+ def forward(self, dense_corresps, batch):
83
+ """[summary]
84
+
85
+ Args:
86
+ out ([type]): [description]
87
+ batch ([type]): [description]
88
+
89
+ Returns:
90
+ [type]: [description]
91
+ """
92
+ scales = list(dense_corresps.keys())
93
+ tot_loss = 0.0
94
+ prev_gd = 0.0
95
+ for scale in scales:
96
+ dense_scale_corresps = dense_corresps[scale]
97
+ dense_scale_certainty, dense_scale_coords = (
98
+ dense_scale_corresps["dense_certainty"],
99
+ dense_scale_corresps["dense_flow"],
100
+ )
101
+ dense_scale_coords = rearrange(dense_scale_coords, "b d h w -> b h w d")
102
+ b, h, w, d = dense_scale_coords.shape
103
+ gd, prob = self.geometric_dist(
104
+ batch["query_depth"],
105
+ batch["support_depth"],
106
+ batch["T_1to2"],
107
+ batch["K1"],
108
+ batch["K2"],
109
+ dense_scale_coords,
110
+ scale,
111
+ )
112
+ if (
113
+ scale <= self.local_largest_scale and self.local_loss
114
+ ): # Thought here is that fine matching loss should not be punished by coarse mistakes, but should identify wrong matching
115
+ prob = prob * (
116
+ F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[:, 0]
117
+ < (2 / 512) * (self.local_dist * scale)
118
+ )
119
+ depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
120
+ scale_loss = (
121
+ self.ce_weight * depth_losses[f"ce_loss_{scale}"]
122
+ + depth_losses[f"depth_loss_{scale}"]
123
+ ) # scale ce loss for coarser scales
124
+ if self.scale_normalize:
125
+ scale_loss = scale_loss * 1 / scale
126
+ tot_loss = tot_loss + scale_loss
127
+ prev_gd = gd.detach()
128
+ return tot_loss
third_party/DKM/dkm/models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .model_zoo import (
2
+ DKMv3_outdoor,
3
+ DKMv3_indoor,
4
+ )
third_party/DKM/dkm/models/deprecated/build_model.py ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from dkm import *
4
+ from .local_corr import LocalCorr
5
+ from .corr_channels import NormedCorr
6
+ from torchvision.models import resnet as tv_resnet
7
+
8
+ dkm_pretrained_urls = {
9
+ "DKM": {
10
+ "mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth",
11
+ "mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth",
12
+ },
13
+ "DKMv2":{
14
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth",
15
+ "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth",
16
+ }
17
+ }
18
+
19
+
20
+ def DKM(pretrained=True, version="mega_synthetic", device=None):
21
+ if device is None:
22
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ gp_dim = 256
24
+ dfn_dim = 384
25
+ feat_dim = 256
26
+ coordinate_decoder = DFN(
27
+ internal_dim=dfn_dim,
28
+ feat_input_modules=nn.ModuleDict(
29
+ {
30
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
31
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
32
+ }
33
+ ),
34
+ pred_input_modules=nn.ModuleDict(
35
+ {
36
+ "32": nn.Identity(),
37
+ "16": nn.Identity(),
38
+ }
39
+ ),
40
+ rrb_d_dict=nn.ModuleDict(
41
+ {
42
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
43
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
44
+ }
45
+ ),
46
+ cab_dict=nn.ModuleDict(
47
+ {
48
+ "32": CAB(2 * dfn_dim, dfn_dim),
49
+ "16": CAB(2 * dfn_dim, dfn_dim),
50
+ }
51
+ ),
52
+ rrb_u_dict=nn.ModuleDict(
53
+ {
54
+ "32": RRB(dfn_dim, dfn_dim),
55
+ "16": RRB(dfn_dim, dfn_dim),
56
+ }
57
+ ),
58
+ terminal_module=nn.ModuleDict(
59
+ {
60
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
61
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
62
+ }
63
+ ),
64
+ )
65
+ dw = True
66
+ hidden_blocks = 8
67
+ kernel_size = 5
68
+ conv_refiner = nn.ModuleDict(
69
+ {
70
+ "16": ConvRefiner(
71
+ 2 * 512,
72
+ 1024,
73
+ 3,
74
+ kernel_size=kernel_size,
75
+ dw=dw,
76
+ hidden_blocks=hidden_blocks,
77
+ ),
78
+ "8": ConvRefiner(
79
+ 2 * 512,
80
+ 1024,
81
+ 3,
82
+ kernel_size=kernel_size,
83
+ dw=dw,
84
+ hidden_blocks=hidden_blocks,
85
+ ),
86
+ "4": ConvRefiner(
87
+ 2 * 256,
88
+ 512,
89
+ 3,
90
+ kernel_size=kernel_size,
91
+ dw=dw,
92
+ hidden_blocks=hidden_blocks,
93
+ ),
94
+ "2": ConvRefiner(
95
+ 2 * 64,
96
+ 128,
97
+ 3,
98
+ kernel_size=kernel_size,
99
+ dw=dw,
100
+ hidden_blocks=hidden_blocks,
101
+ ),
102
+ "1": ConvRefiner(
103
+ 2 * 3,
104
+ 24,
105
+ 3,
106
+ kernel_size=kernel_size,
107
+ dw=dw,
108
+ hidden_blocks=hidden_blocks,
109
+ ),
110
+ }
111
+ )
112
+ kernel_temperature = 0.2
113
+ learn_temperature = False
114
+ no_cov = True
115
+ kernel = CosKernel
116
+ only_attention = False
117
+ basis = "fourier"
118
+ gp32 = GP(
119
+ kernel,
120
+ T=kernel_temperature,
121
+ learn_temperature=learn_temperature,
122
+ only_attention=only_attention,
123
+ gp_dim=gp_dim,
124
+ basis=basis,
125
+ no_cov=no_cov,
126
+ )
127
+ gp16 = GP(
128
+ kernel,
129
+ T=kernel_temperature,
130
+ learn_temperature=learn_temperature,
131
+ only_attention=only_attention,
132
+ gp_dim=gp_dim,
133
+ basis=basis,
134
+ no_cov=no_cov,
135
+ )
136
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
137
+ proj = nn.ModuleDict(
138
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
139
+ )
140
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
141
+ h, w = 384, 512
142
+ encoder = Encoder(
143
+ tv_resnet.resnet50(pretrained=not pretrained),
144
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
145
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
146
+ if pretrained:
147
+ weights = torch.hub.load_state_dict_from_url(
148
+ dkm_pretrained_urls["DKM"][version]
149
+ )
150
+ matcher.load_state_dict(weights)
151
+ return matcher
152
+
153
+ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
154
+ gp_dim = 256
155
+ dfn_dim = 384
156
+ feat_dim = 256
157
+ coordinate_decoder = DFN(
158
+ internal_dim=dfn_dim,
159
+ feat_input_modules=nn.ModuleDict(
160
+ {
161
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
162
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
163
+ }
164
+ ),
165
+ pred_input_modules=nn.ModuleDict(
166
+ {
167
+ "32": nn.Identity(),
168
+ "16": nn.Identity(),
169
+ }
170
+ ),
171
+ rrb_d_dict=nn.ModuleDict(
172
+ {
173
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
174
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
175
+ }
176
+ ),
177
+ cab_dict=nn.ModuleDict(
178
+ {
179
+ "32": CAB(2 * dfn_dim, dfn_dim),
180
+ "16": CAB(2 * dfn_dim, dfn_dim),
181
+ }
182
+ ),
183
+ rrb_u_dict=nn.ModuleDict(
184
+ {
185
+ "32": RRB(dfn_dim, dfn_dim),
186
+ "16": RRB(dfn_dim, dfn_dim),
187
+ }
188
+ ),
189
+ terminal_module=nn.ModuleDict(
190
+ {
191
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
192
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
193
+ }
194
+ ),
195
+ )
196
+ dw = True
197
+ hidden_blocks = 8
198
+ kernel_size = 5
199
+ displacement_emb = "linear"
200
+ conv_refiner = nn.ModuleDict(
201
+ {
202
+ "16": ConvRefiner(
203
+ 2 * 512+128,
204
+ 1024+128,
205
+ 3,
206
+ kernel_size=kernel_size,
207
+ dw=dw,
208
+ hidden_blocks=hidden_blocks,
209
+ displacement_emb=displacement_emb,
210
+ displacement_emb_dim=128,
211
+ ),
212
+ "8": ConvRefiner(
213
+ 2 * 512+64,
214
+ 1024+64,
215
+ 3,
216
+ kernel_size=kernel_size,
217
+ dw=dw,
218
+ hidden_blocks=hidden_blocks,
219
+ displacement_emb=displacement_emb,
220
+ displacement_emb_dim=64,
221
+ ),
222
+ "4": ConvRefiner(
223
+ 2 * 256+32,
224
+ 512+32,
225
+ 3,
226
+ kernel_size=kernel_size,
227
+ dw=dw,
228
+ hidden_blocks=hidden_blocks,
229
+ displacement_emb=displacement_emb,
230
+ displacement_emb_dim=32,
231
+ ),
232
+ "2": ConvRefiner(
233
+ 2 * 64+16,
234
+ 128+16,
235
+ 3,
236
+ kernel_size=kernel_size,
237
+ dw=dw,
238
+ hidden_blocks=hidden_blocks,
239
+ displacement_emb=displacement_emb,
240
+ displacement_emb_dim=16,
241
+ ),
242
+ "1": ConvRefiner(
243
+ 2 * 3+6,
244
+ 24,
245
+ 3,
246
+ kernel_size=kernel_size,
247
+ dw=dw,
248
+ hidden_blocks=hidden_blocks,
249
+ displacement_emb=displacement_emb,
250
+ displacement_emb_dim=6,
251
+ ),
252
+ }
253
+ )
254
+ kernel_temperature = 0.2
255
+ learn_temperature = False
256
+ no_cov = True
257
+ kernel = CosKernel
258
+ only_attention = False
259
+ basis = "fourier"
260
+ gp32 = GP(
261
+ kernel,
262
+ T=kernel_temperature,
263
+ learn_temperature=learn_temperature,
264
+ only_attention=only_attention,
265
+ gp_dim=gp_dim,
266
+ basis=basis,
267
+ no_cov=no_cov,
268
+ )
269
+ gp16 = GP(
270
+ kernel,
271
+ T=kernel_temperature,
272
+ learn_temperature=learn_temperature,
273
+ only_attention=only_attention,
274
+ gp_dim=gp_dim,
275
+ basis=basis,
276
+ no_cov=no_cov,
277
+ )
278
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
279
+ proj = nn.ModuleDict(
280
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
281
+ )
282
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
283
+ if resolution == "low":
284
+ h, w = 384, 512
285
+ elif resolution == "high":
286
+ h, w = 480, 640
287
+ encoder = Encoder(
288
+ tv_resnet.resnet50(pretrained=not pretrained),
289
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
290
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device)
291
+ if pretrained:
292
+ try:
293
+ weights = torch.hub.load_state_dict_from_url(
294
+ dkm_pretrained_urls["DKMv2"][version]
295
+ )
296
+ except:
297
+ weights = torch.load(
298
+ dkm_pretrained_urls["DKMv2"][version]
299
+ )
300
+ matcher.load_state_dict(weights)
301
+ return matcher
302
+
303
+
304
+ def local_corr(pretrained=True, version="mega_synthetic"):
305
+ gp_dim = 256
306
+ dfn_dim = 384
307
+ feat_dim = 256
308
+ coordinate_decoder = DFN(
309
+ internal_dim=dfn_dim,
310
+ feat_input_modules=nn.ModuleDict(
311
+ {
312
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
313
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
314
+ }
315
+ ),
316
+ pred_input_modules=nn.ModuleDict(
317
+ {
318
+ "32": nn.Identity(),
319
+ "16": nn.Identity(),
320
+ }
321
+ ),
322
+ rrb_d_dict=nn.ModuleDict(
323
+ {
324
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
325
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
326
+ }
327
+ ),
328
+ cab_dict=nn.ModuleDict(
329
+ {
330
+ "32": CAB(2 * dfn_dim, dfn_dim),
331
+ "16": CAB(2 * dfn_dim, dfn_dim),
332
+ }
333
+ ),
334
+ rrb_u_dict=nn.ModuleDict(
335
+ {
336
+ "32": RRB(dfn_dim, dfn_dim),
337
+ "16": RRB(dfn_dim, dfn_dim),
338
+ }
339
+ ),
340
+ terminal_module=nn.ModuleDict(
341
+ {
342
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
343
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
344
+ }
345
+ ),
346
+ )
347
+ dw = True
348
+ hidden_blocks = 8
349
+ kernel_size = 5
350
+ conv_refiner = nn.ModuleDict(
351
+ {
352
+ "16": LocalCorr(
353
+ 81,
354
+ 81 * 12,
355
+ 3,
356
+ kernel_size=kernel_size,
357
+ dw=dw,
358
+ hidden_blocks=hidden_blocks,
359
+ ),
360
+ "8": LocalCorr(
361
+ 81,
362
+ 81 * 12,
363
+ 3,
364
+ kernel_size=kernel_size,
365
+ dw=dw,
366
+ hidden_blocks=hidden_blocks,
367
+ ),
368
+ "4": LocalCorr(
369
+ 81,
370
+ 81 * 6,
371
+ 3,
372
+ kernel_size=kernel_size,
373
+ dw=dw,
374
+ hidden_blocks=hidden_blocks,
375
+ ),
376
+ "2": LocalCorr(
377
+ 81,
378
+ 81,
379
+ 3,
380
+ kernel_size=kernel_size,
381
+ dw=dw,
382
+ hidden_blocks=hidden_blocks,
383
+ ),
384
+ "1": ConvRefiner(
385
+ 2 * 3,
386
+ 24,
387
+ 3,
388
+ kernel_size=kernel_size,
389
+ dw=dw,
390
+ hidden_blocks=hidden_blocks,
391
+ ),
392
+ }
393
+ )
394
+ kernel_temperature = 0.2
395
+ learn_temperature = False
396
+ no_cov = True
397
+ kernel = CosKernel
398
+ only_attention = False
399
+ basis = "fourier"
400
+ gp32 = GP(
401
+ kernel,
402
+ T=kernel_temperature,
403
+ learn_temperature=learn_temperature,
404
+ only_attention=only_attention,
405
+ gp_dim=gp_dim,
406
+ basis=basis,
407
+ no_cov=no_cov,
408
+ )
409
+ gp16 = GP(
410
+ kernel,
411
+ T=kernel_temperature,
412
+ learn_temperature=learn_temperature,
413
+ only_attention=only_attention,
414
+ gp_dim=gp_dim,
415
+ basis=basis,
416
+ no_cov=no_cov,
417
+ )
418
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
419
+ proj = nn.ModuleDict(
420
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
421
+ )
422
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
423
+ h, w = 384, 512
424
+ encoder = Encoder(
425
+ tv_resnet.resnet50(pretrained=not pretrained)
426
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
427
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
428
+ if pretrained:
429
+ weights = torch.hub.load_state_dict_from_url(
430
+ dkm_pretrained_urls["local_corr"][version]
431
+ )
432
+ matcher.load_state_dict(weights)
433
+ return matcher
434
+
435
+
436
+ def corr_channels(pretrained=True, version="mega_synthetic"):
437
+ h, w = 384, 512
438
+ gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
439
+ dfn_dim = 384
440
+ feat_dim = 256
441
+ coordinate_decoder = DFN(
442
+ internal_dim=dfn_dim,
443
+ feat_input_modules=nn.ModuleDict(
444
+ {
445
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
446
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
447
+ }
448
+ ),
449
+ pred_input_modules=nn.ModuleDict(
450
+ {
451
+ "32": nn.Identity(),
452
+ "16": nn.Identity(),
453
+ }
454
+ ),
455
+ rrb_d_dict=nn.ModuleDict(
456
+ {
457
+ "32": RRB(gp_dim[0] + feat_dim, dfn_dim),
458
+ "16": RRB(gp_dim[1] + feat_dim, dfn_dim),
459
+ }
460
+ ),
461
+ cab_dict=nn.ModuleDict(
462
+ {
463
+ "32": CAB(2 * dfn_dim, dfn_dim),
464
+ "16": CAB(2 * dfn_dim, dfn_dim),
465
+ }
466
+ ),
467
+ rrb_u_dict=nn.ModuleDict(
468
+ {
469
+ "32": RRB(dfn_dim, dfn_dim),
470
+ "16": RRB(dfn_dim, dfn_dim),
471
+ }
472
+ ),
473
+ terminal_module=nn.ModuleDict(
474
+ {
475
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
476
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
477
+ }
478
+ ),
479
+ )
480
+ dw = True
481
+ hidden_blocks = 8
482
+ kernel_size = 5
483
+ conv_refiner = nn.ModuleDict(
484
+ {
485
+ "16": ConvRefiner(
486
+ 2 * 512,
487
+ 1024,
488
+ 3,
489
+ kernel_size=kernel_size,
490
+ dw=dw,
491
+ hidden_blocks=hidden_blocks,
492
+ ),
493
+ "8": ConvRefiner(
494
+ 2 * 512,
495
+ 1024,
496
+ 3,
497
+ kernel_size=kernel_size,
498
+ dw=dw,
499
+ hidden_blocks=hidden_blocks,
500
+ ),
501
+ "4": ConvRefiner(
502
+ 2 * 256,
503
+ 512,
504
+ 3,
505
+ kernel_size=kernel_size,
506
+ dw=dw,
507
+ hidden_blocks=hidden_blocks,
508
+ ),
509
+ "2": ConvRefiner(
510
+ 2 * 64,
511
+ 128,
512
+ 3,
513
+ kernel_size=kernel_size,
514
+ dw=dw,
515
+ hidden_blocks=hidden_blocks,
516
+ ),
517
+ "1": ConvRefiner(
518
+ 2 * 3,
519
+ 24,
520
+ 3,
521
+ kernel_size=kernel_size,
522
+ dw=dw,
523
+ hidden_blocks=hidden_blocks,
524
+ ),
525
+ }
526
+ )
527
+ gp32 = NormedCorr()
528
+ gp16 = NormedCorr()
529
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
530
+ proj = nn.ModuleDict(
531
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
532
+ )
533
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
534
+ h, w = 384, 512
535
+ encoder = Encoder(
536
+ tv_resnet.resnet50(pretrained=not pretrained)
537
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
538
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
539
+ if pretrained:
540
+ weights = torch.hub.load_state_dict_from_url(
541
+ dkm_pretrained_urls["corr_channels"][version]
542
+ )
543
+ matcher.load_state_dict(weights)
544
+ return matcher
545
+
546
+
547
+ def baseline(pretrained=True, version="mega_synthetic"):
548
+ h, w = 384, 512
549
+ gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
550
+ dfn_dim = 384
551
+ feat_dim = 256
552
+ coordinate_decoder = DFN(
553
+ internal_dim=dfn_dim,
554
+ feat_input_modules=nn.ModuleDict(
555
+ {
556
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
557
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
558
+ }
559
+ ),
560
+ pred_input_modules=nn.ModuleDict(
561
+ {
562
+ "32": nn.Identity(),
563
+ "16": nn.Identity(),
564
+ }
565
+ ),
566
+ rrb_d_dict=nn.ModuleDict(
567
+ {
568
+ "32": RRB(gp_dim[0] + feat_dim, dfn_dim),
569
+ "16": RRB(gp_dim[1] + feat_dim, dfn_dim),
570
+ }
571
+ ),
572
+ cab_dict=nn.ModuleDict(
573
+ {
574
+ "32": CAB(2 * dfn_dim, dfn_dim),
575
+ "16": CAB(2 * dfn_dim, dfn_dim),
576
+ }
577
+ ),
578
+ rrb_u_dict=nn.ModuleDict(
579
+ {
580
+ "32": RRB(dfn_dim, dfn_dim),
581
+ "16": RRB(dfn_dim, dfn_dim),
582
+ }
583
+ ),
584
+ terminal_module=nn.ModuleDict(
585
+ {
586
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
587
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
588
+ }
589
+ ),
590
+ )
591
+ dw = True
592
+ hidden_blocks = 8
593
+ kernel_size = 5
594
+ conv_refiner = nn.ModuleDict(
595
+ {
596
+ "16": LocalCorr(
597
+ 81,
598
+ 81 * 12,
599
+ 3,
600
+ kernel_size=kernel_size,
601
+ dw=dw,
602
+ hidden_blocks=hidden_blocks,
603
+ ),
604
+ "8": LocalCorr(
605
+ 81,
606
+ 81 * 12,
607
+ 3,
608
+ kernel_size=kernel_size,
609
+ dw=dw,
610
+ hidden_blocks=hidden_blocks,
611
+ ),
612
+ "4": LocalCorr(
613
+ 81,
614
+ 81 * 6,
615
+ 3,
616
+ kernel_size=kernel_size,
617
+ dw=dw,
618
+ hidden_blocks=hidden_blocks,
619
+ ),
620
+ "2": LocalCorr(
621
+ 81,
622
+ 81,
623
+ 3,
624
+ kernel_size=kernel_size,
625
+ dw=dw,
626
+ hidden_blocks=hidden_blocks,
627
+ ),
628
+ "1": ConvRefiner(
629
+ 2 * 3,
630
+ 24,
631
+ 3,
632
+ kernel_size=kernel_size,
633
+ dw=dw,
634
+ hidden_blocks=hidden_blocks,
635
+ ),
636
+ }
637
+ )
638
+ gp32 = NormedCorr()
639
+ gp16 = NormedCorr()
640
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
641
+ proj = nn.ModuleDict(
642
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
643
+ )
644
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
645
+ h, w = 384, 512
646
+ encoder = Encoder(
647
+ tv_resnet.resnet50(pretrained=not pretrained)
648
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
649
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
650
+ if pretrained:
651
+ weights = torch.hub.load_state_dict_from_url(
652
+ dkm_pretrained_urls["baseline"][version]
653
+ )
654
+ matcher.load_state_dict(weights)
655
+ return matcher
656
+
657
+
658
+ def linear(pretrained=True, version="mega_synthetic"):
659
+ gp_dim = 256
660
+ dfn_dim = 384
661
+ feat_dim = 256
662
+ coordinate_decoder = DFN(
663
+ internal_dim=dfn_dim,
664
+ feat_input_modules=nn.ModuleDict(
665
+ {
666
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
667
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
668
+ }
669
+ ),
670
+ pred_input_modules=nn.ModuleDict(
671
+ {
672
+ "32": nn.Identity(),
673
+ "16": nn.Identity(),
674
+ }
675
+ ),
676
+ rrb_d_dict=nn.ModuleDict(
677
+ {
678
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
679
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
680
+ }
681
+ ),
682
+ cab_dict=nn.ModuleDict(
683
+ {
684
+ "32": CAB(2 * dfn_dim, dfn_dim),
685
+ "16": CAB(2 * dfn_dim, dfn_dim),
686
+ }
687
+ ),
688
+ rrb_u_dict=nn.ModuleDict(
689
+ {
690
+ "32": RRB(dfn_dim, dfn_dim),
691
+ "16": RRB(dfn_dim, dfn_dim),
692
+ }
693
+ ),
694
+ terminal_module=nn.ModuleDict(
695
+ {
696
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
697
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
698
+ }
699
+ ),
700
+ )
701
+ dw = True
702
+ hidden_blocks = 8
703
+ kernel_size = 5
704
+ conv_refiner = nn.ModuleDict(
705
+ {
706
+ "16": ConvRefiner(
707
+ 2 * 512,
708
+ 1024,
709
+ 3,
710
+ kernel_size=kernel_size,
711
+ dw=dw,
712
+ hidden_blocks=hidden_blocks,
713
+ ),
714
+ "8": ConvRefiner(
715
+ 2 * 512,
716
+ 1024,
717
+ 3,
718
+ kernel_size=kernel_size,
719
+ dw=dw,
720
+ hidden_blocks=hidden_blocks,
721
+ ),
722
+ "4": ConvRefiner(
723
+ 2 * 256,
724
+ 512,
725
+ 3,
726
+ kernel_size=kernel_size,
727
+ dw=dw,
728
+ hidden_blocks=hidden_blocks,
729
+ ),
730
+ "2": ConvRefiner(
731
+ 2 * 64,
732
+ 128,
733
+ 3,
734
+ kernel_size=kernel_size,
735
+ dw=dw,
736
+ hidden_blocks=hidden_blocks,
737
+ ),
738
+ "1": ConvRefiner(
739
+ 2 * 3,
740
+ 24,
741
+ 3,
742
+ kernel_size=kernel_size,
743
+ dw=dw,
744
+ hidden_blocks=hidden_blocks,
745
+ ),
746
+ }
747
+ )
748
+ kernel_temperature = 0.2
749
+ learn_temperature = False
750
+ no_cov = True
751
+ kernel = CosKernel
752
+ only_attention = False
753
+ basis = "linear"
754
+ gp32 = GP(
755
+ kernel,
756
+ T=kernel_temperature,
757
+ learn_temperature=learn_temperature,
758
+ only_attention=only_attention,
759
+ gp_dim=gp_dim,
760
+ basis=basis,
761
+ no_cov=no_cov,
762
+ )
763
+ gp16 = GP(
764
+ kernel,
765
+ T=kernel_temperature,
766
+ learn_temperature=learn_temperature,
767
+ only_attention=only_attention,
768
+ gp_dim=gp_dim,
769
+ basis=basis,
770
+ no_cov=no_cov,
771
+ )
772
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
773
+ proj = nn.ModuleDict(
774
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
775
+ )
776
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
777
+ h, w = 384, 512
778
+ encoder = Encoder(
779
+ tv_resnet.resnet50(pretrained=not pretrained)
780
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
781
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
782
+ if pretrained:
783
+ weights = torch.hub.load_state_dict_from_url(
784
+ dkm_pretrained_urls["linear"][version]
785
+ )
786
+ matcher.load_state_dict(weights)
787
+ return matcher
third_party/DKM/dkm/models/deprecated/corr_channels.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+
7
+ class NormedCorrelationKernel(nn.Module): # similar to softmax kernel
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def __call__(self, x, y, eps=1e-6):
12
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
13
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
14
+ )
15
+ return c
16
+
17
+
18
+ class NormedCorr(nn.Module):
19
+ def __init__(
20
+ self,
21
+ ):
22
+ super().__init__()
23
+ self.corr = NormedCorrelationKernel()
24
+
25
+ def reshape(self, x):
26
+ return rearrange(x, "b d h w -> b (h w) d")
27
+
28
+ def forward(self, x, y, **kwargs):
29
+ b, c, h, w = y.shape
30
+ assert x.shape == y.shape
31
+ x, y = self.reshape(x), self.reshape(y)
32
+ corr_xy = self.corr(x, y)
33
+ corr_xy_flat = rearrange(corr_xy, "b (h w) c -> b c h w", h=h, w=w)
34
+ return corr_xy_flat
third_party/DKM/dkm/models/deprecated/local_corr.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ try:
5
+ import cupy
6
+ except:
7
+ print("Cupy not found, local correlation will not work")
8
+ import re
9
+ from ..dkm import ConvRefiner
10
+
11
+
12
+ class Stream:
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ if device == 'cuda':
15
+ stream = torch.cuda.current_stream(device=device).cuda_stream
16
+ else:
17
+ stream = None
18
+
19
+
20
+ kernel_Correlation_rearrange = """
21
+ extern "C" __global__ void kernel_Correlation_rearrange(
22
+ const int n,
23
+ const float* input,
24
+ float* output
25
+ ) {
26
+ int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
27
+ if (intIndex >= n) {
28
+ return;
29
+ }
30
+ int intSample = blockIdx.z;
31
+ int intChannel = blockIdx.y;
32
+ float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
33
+ __syncthreads();
34
+ int intPaddedY = (intIndex / SIZE_3(input)) + 4;
35
+ int intPaddedX = (intIndex % SIZE_3(input)) + 4;
36
+ int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
37
+ output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue;
38
+ }
39
+ """
40
+
41
+ kernel_Correlation_updateOutput = """
42
+ extern "C" __global__ void kernel_Correlation_updateOutput(
43
+ const int n,
44
+ const float* rbot0,
45
+ const float* rbot1,
46
+ float* top
47
+ ) {
48
+ extern __shared__ char patch_data_char[];
49
+ float *patch_data = (float *)patch_data_char;
50
+ // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
51
+ int x1 = blockIdx.x + 4;
52
+ int y1 = blockIdx.y + 4;
53
+ int item = blockIdx.z;
54
+ int ch_off = threadIdx.x;
55
+ // Load 3D patch into shared shared memory
56
+ for (int j = 0; j < 1; j++) { // HEIGHT
57
+ for (int i = 0; i < 1; i++) { // WIDTH
58
+ int ji_off = (j + i) * SIZE_3(rbot0);
59
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
60
+ int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
61
+ int idxPatchData = ji_off + ch;
62
+ patch_data[idxPatchData] = rbot0[idx1];
63
+ }
64
+ }
65
+ }
66
+ __syncthreads();
67
+ __shared__ float sum[32];
68
+ // Compute correlation
69
+ for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
70
+ sum[ch_off] = 0;
71
+ int s2o = top_channel % 9 - 4;
72
+ int s2p = top_channel / 9 - 4;
73
+ for (int j = 0; j < 1; j++) { // HEIGHT
74
+ for (int i = 0; i < 1; i++) { // WIDTH
75
+ int ji_off = (j + i) * SIZE_3(rbot0);
76
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
77
+ int x2 = x1 + s2o;
78
+ int y2 = y1 + s2p;
79
+ int idxPatchData = ji_off + ch;
80
+ int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
81
+ sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
82
+ }
83
+ }
84
+ }
85
+ __syncthreads();
86
+ if (ch_off == 0) {
87
+ float total_sum = 0;
88
+ for (int idx = 0; idx < 32; idx++) {
89
+ total_sum += sum[idx];
90
+ }
91
+ const int sumelems = SIZE_3(rbot0);
92
+ const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
93
+ top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
94
+ }
95
+ }
96
+ }
97
+ """
98
+
99
+ kernel_Correlation_updateGradFirst = """
100
+ #define ROUND_OFF 50000
101
+ extern "C" __global__ void kernel_Correlation_updateGradFirst(
102
+ const int n,
103
+ const int intSample,
104
+ const float* rbot0,
105
+ const float* rbot1,
106
+ const float* gradOutput,
107
+ float* gradFirst,
108
+ float* gradSecond
109
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
110
+ int n = intIndex % SIZE_1(gradFirst); // channels
111
+ int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
112
+ int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
113
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
114
+ // We use a large offset, for the inner part not to become negative.
115
+ const int round_off = ROUND_OFF;
116
+ const int round_off_s1 = round_off;
117
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
118
+ int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
119
+ int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
120
+ // Same here:
121
+ int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
122
+ int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
123
+ float sum = 0;
124
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
125
+ xmin = max(0,xmin);
126
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
127
+ ymin = max(0,ymin);
128
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
129
+ for (int p = -4; p <= 4; p++) {
130
+ for (int o = -4; o <= 4; o++) {
131
+ // Get rbot1 data:
132
+ int s2o = o;
133
+ int s2p = p;
134
+ int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
135
+ float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
136
+ // Index offset for gradOutput in following loops:
137
+ int op = (p+4) * 9 + (o+4); // index[o,p]
138
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
139
+ for (int y = ymin; y <= ymax; y++) {
140
+ for (int x = xmin; x <= xmax; x++) {
141
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
142
+ sum += gradOutput[idxgradOutput] * bot1tmp;
143
+ }
144
+ }
145
+ }
146
+ }
147
+ }
148
+ const int sumelems = SIZE_1(gradFirst);
149
+ const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
150
+ gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
151
+ } }
152
+ """
153
+
154
+ kernel_Correlation_updateGradSecond = """
155
+ #define ROUND_OFF 50000
156
+ extern "C" __global__ void kernel_Correlation_updateGradSecond(
157
+ const int n,
158
+ const int intSample,
159
+ const float* rbot0,
160
+ const float* rbot1,
161
+ const float* gradOutput,
162
+ float* gradFirst,
163
+ float* gradSecond
164
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
165
+ int n = intIndex % SIZE_1(gradSecond); // channels
166
+ int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
167
+ int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
168
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
169
+ // We use a large offset, for the inner part not to become negative.
170
+ const int round_off = ROUND_OFF;
171
+ const int round_off_s1 = round_off;
172
+ float sum = 0;
173
+ for (int p = -4; p <= 4; p++) {
174
+ for (int o = -4; o <= 4; o++) {
175
+ int s2o = o;
176
+ int s2p = p;
177
+ //Get X,Y ranges and clamp
178
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
179
+ int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
180
+ int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
181
+ // Same here:
182
+ int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
183
+ int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
184
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
185
+ xmin = max(0,xmin);
186
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
187
+ ymin = max(0,ymin);
188
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
189
+ // Get rbot0 data:
190
+ int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
191
+ float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
192
+ // Index offset for gradOutput in following loops:
193
+ int op = (p+4) * 9 + (o+4); // index[o,p]
194
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
195
+ for (int y = ymin; y <= ymax; y++) {
196
+ for (int x = xmin; x <= xmax; x++) {
197
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
198
+ sum += gradOutput[idxgradOutput] * bot0tmp;
199
+ }
200
+ }
201
+ }
202
+ }
203
+ }
204
+ const int sumelems = SIZE_1(gradSecond);
205
+ const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
206
+ gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
207
+ } }
208
+ """
209
+
210
+
211
+ def cupy_kernel(strFunction, objectVariables):
212
+ strKernel = globals()[strFunction]
213
+
214
+ while True:
215
+ objectMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
216
+
217
+ if objectMatch is None:
218
+ break
219
+
220
+ intArg = int(objectMatch.group(2))
221
+
222
+ strTensor = objectMatch.group(4)
223
+ intSizes = objectVariables[strTensor].size()
224
+
225
+ strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
226
+
227
+ while True:
228
+ objectMatch = re.search(r"(VALUE_)([0-4])(\()([^\)]+)(\))", strKernel)
229
+
230
+ if objectMatch is None:
231
+ break
232
+
233
+ intArgs = int(objectMatch.group(2))
234
+ strArgs = objectMatch.group(4).split(",")
235
+
236
+ strTensor = strArgs[0]
237
+ intStrides = objectVariables[strTensor].stride()
238
+ strIndex = [
239
+ "(("
240
+ + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
241
+ + ")*"
242
+ + str(intStrides[intArg])
243
+ + ")"
244
+ for intArg in range(intArgs)
245
+ ]
246
+
247
+ strKernel = strKernel.replace(
248
+ objectMatch.group(0), strTensor + "[" + str.join("+", strIndex) + "]"
249
+ )
250
+
251
+ return strKernel
252
+
253
+
254
+ try:
255
+
256
+ @cupy.memoize(for_each_device=True)
257
+ def cupy_launch(strFunction, strKernel):
258
+ return cupy.RawModule(code=strKernel).get_function(strFunction)
259
+
260
+ except:
261
+ pass
262
+
263
+
264
+ class _FunctionCorrelation(torch.autograd.Function):
265
+ @staticmethod
266
+ def forward(self, first, second):
267
+ rbot0 = first.new_zeros(
268
+ [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
269
+ )
270
+ rbot1 = first.new_zeros(
271
+ [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
272
+ )
273
+
274
+ self.save_for_backward(first, second, rbot0, rbot1)
275
+
276
+ first = first.contiguous()
277
+ second = second.contiguous()
278
+
279
+ output = first.new_zeros([first.size(0), 81, first.size(2), first.size(3)])
280
+
281
+ if first.is_cuda == True:
282
+ n = first.size(2) * first.size(3)
283
+ cupy_launch(
284
+ "kernel_Correlation_rearrange",
285
+ cupy_kernel(
286
+ "kernel_Correlation_rearrange", {"input": first, "output": rbot0}
287
+ ),
288
+ )(
289
+ grid=tuple([int((n + 16 - 1) / 16), first.size(1), first.size(0)]),
290
+ block=tuple([16, 1, 1]),
291
+ args=[n, first.data_ptr(), rbot0.data_ptr()],
292
+ stream=Stream,
293
+ )
294
+
295
+ n = second.size(2) * second.size(3)
296
+ cupy_launch(
297
+ "kernel_Correlation_rearrange",
298
+ cupy_kernel(
299
+ "kernel_Correlation_rearrange", {"input": second, "output": rbot1}
300
+ ),
301
+ )(
302
+ grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
303
+ block=tuple([16, 1, 1]),
304
+ args=[n, second.data_ptr(), rbot1.data_ptr()],
305
+ stream=Stream,
306
+ )
307
+
308
+ n = output.size(1) * output.size(2) * output.size(3)
309
+ cupy_launch(
310
+ "kernel_Correlation_updateOutput",
311
+ cupy_kernel(
312
+ "kernel_Correlation_updateOutput",
313
+ {"rbot0": rbot0, "rbot1": rbot1, "top": output},
314
+ ),
315
+ )(
316
+ grid=tuple([output.size(3), output.size(2), output.size(0)]),
317
+ block=tuple([32, 1, 1]),
318
+ shared_mem=first.size(1) * 4,
319
+ args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()],
320
+ stream=Stream,
321
+ )
322
+
323
+ elif first.is_cuda == False:
324
+ raise NotImplementedError()
325
+
326
+ return output
327
+
328
+ @staticmethod
329
+ def backward(self, gradOutput):
330
+ first, second, rbot0, rbot1 = self.saved_tensors
331
+
332
+ gradOutput = gradOutput.contiguous()
333
+
334
+ assert gradOutput.is_contiguous() == True
335
+
336
+ gradFirst = (
337
+ first.new_zeros(
338
+ [first.size(0), first.size(1), first.size(2), first.size(3)]
339
+ )
340
+ if self.needs_input_grad[0] == True
341
+ else None
342
+ )
343
+ gradSecond = (
344
+ first.new_zeros(
345
+ [first.size(0), first.size(1), first.size(2), first.size(3)]
346
+ )
347
+ if self.needs_input_grad[1] == True
348
+ else None
349
+ )
350
+
351
+ if first.is_cuda == True:
352
+ if gradFirst is not None:
353
+ for intSample in range(first.size(0)):
354
+ n = first.size(1) * first.size(2) * first.size(3)
355
+ cupy_launch(
356
+ "kernel_Correlation_updateGradFirst",
357
+ cupy_kernel(
358
+ "kernel_Correlation_updateGradFirst",
359
+ {
360
+ "rbot0": rbot0,
361
+ "rbot1": rbot1,
362
+ "gradOutput": gradOutput,
363
+ "gradFirst": gradFirst,
364
+ "gradSecond": None,
365
+ },
366
+ ),
367
+ )(
368
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
369
+ block=tuple([512, 1, 1]),
370
+ args=[
371
+ n,
372
+ intSample,
373
+ rbot0.data_ptr(),
374
+ rbot1.data_ptr(),
375
+ gradOutput.data_ptr(),
376
+ gradFirst.data_ptr(),
377
+ None,
378
+ ],
379
+ stream=Stream,
380
+ )
381
+
382
+ if gradSecond is not None:
383
+ for intSample in range(first.size(0)):
384
+ n = first.size(1) * first.size(2) * first.size(3)
385
+ cupy_launch(
386
+ "kernel_Correlation_updateGradSecond",
387
+ cupy_kernel(
388
+ "kernel_Correlation_updateGradSecond",
389
+ {
390
+ "rbot0": rbot0,
391
+ "rbot1": rbot1,
392
+ "gradOutput": gradOutput,
393
+ "gradFirst": None,
394
+ "gradSecond": gradSecond,
395
+ },
396
+ ),
397
+ )(
398
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
399
+ block=tuple([512, 1, 1]),
400
+ args=[
401
+ n,
402
+ intSample,
403
+ rbot0.data_ptr(),
404
+ rbot1.data_ptr(),
405
+ gradOutput.data_ptr(),
406
+ None,
407
+ gradSecond.data_ptr(),
408
+ ],
409
+ stream=Stream,
410
+ )
411
+
412
+ elif first.is_cuda == False:
413
+ raise NotImplementedError()
414
+
415
+ return gradFirst, gradSecond
416
+
417
+
418
+ class _FunctionCorrelationTranspose(torch.autograd.Function):
419
+ @staticmethod
420
+ def forward(self, input, second):
421
+ rbot0 = second.new_zeros(
422
+ [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
423
+ )
424
+ rbot1 = second.new_zeros(
425
+ [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
426
+ )
427
+
428
+ self.save_for_backward(input, second, rbot0, rbot1)
429
+
430
+ input = input.contiguous()
431
+ second = second.contiguous()
432
+
433
+ output = second.new_zeros(
434
+ [second.size(0), second.size(1), second.size(2), second.size(3)]
435
+ )
436
+
437
+ if second.is_cuda == True:
438
+ n = second.size(2) * second.size(3)
439
+ cupy_launch(
440
+ "kernel_Correlation_rearrange",
441
+ cupy_kernel(
442
+ "kernel_Correlation_rearrange", {"input": second, "output": rbot1}
443
+ ),
444
+ )(
445
+ grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
446
+ block=tuple([16, 1, 1]),
447
+ args=[n, second.data_ptr(), rbot1.data_ptr()],
448
+ stream=Stream,
449
+ )
450
+
451
+ for intSample in range(second.size(0)):
452
+ n = second.size(1) * second.size(2) * second.size(3)
453
+ cupy_launch(
454
+ "kernel_Correlation_updateGradFirst",
455
+ cupy_kernel(
456
+ "kernel_Correlation_updateGradFirst",
457
+ {
458
+ "rbot0": rbot0,
459
+ "rbot1": rbot1,
460
+ "gradOutput": input,
461
+ "gradFirst": output,
462
+ "gradSecond": None,
463
+ },
464
+ ),
465
+ )(
466
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
467
+ block=tuple([512, 1, 1]),
468
+ args=[
469
+ n,
470
+ intSample,
471
+ rbot0.data_ptr(),
472
+ rbot1.data_ptr(),
473
+ input.data_ptr(),
474
+ output.data_ptr(),
475
+ None,
476
+ ],
477
+ stream=Stream,
478
+ )
479
+
480
+ elif second.is_cuda == False:
481
+ raise NotImplementedError()
482
+
483
+ return output
484
+
485
+ @staticmethod
486
+ def backward(self, gradOutput):
487
+ input, second, rbot0, rbot1 = self.saved_tensors
488
+
489
+ gradOutput = gradOutput.contiguous()
490
+
491
+ gradInput = (
492
+ input.new_zeros(
493
+ [input.size(0), input.size(1), input.size(2), input.size(3)]
494
+ )
495
+ if self.needs_input_grad[0] == True
496
+ else None
497
+ )
498
+ gradSecond = (
499
+ second.new_zeros(
500
+ [second.size(0), second.size(1), second.size(2), second.size(3)]
501
+ )
502
+ if self.needs_input_grad[1] == True
503
+ else None
504
+ )
505
+
506
+ if second.is_cuda == True:
507
+ if gradInput is not None or gradSecond is not None:
508
+ n = second.size(2) * second.size(3)
509
+ cupy_launch(
510
+ "kernel_Correlation_rearrange",
511
+ cupy_kernel(
512
+ "kernel_Correlation_rearrange",
513
+ {"input": gradOutput, "output": rbot0},
514
+ ),
515
+ )(
516
+ grid=tuple(
517
+ [int((n + 16 - 1) / 16), gradOutput.size(1), gradOutput.size(0)]
518
+ ),
519
+ block=tuple([16, 1, 1]),
520
+ args=[n, gradOutput.data_ptr(), rbot0.data_ptr()],
521
+ stream=Stream,
522
+ )
523
+
524
+ if gradInput is not None:
525
+ n = gradInput.size(1) * gradInput.size(2) * gradInput.size(3)
526
+ cupy_launch(
527
+ "kernel_Correlation_updateOutput",
528
+ cupy_kernel(
529
+ "kernel_Correlation_updateOutput",
530
+ {"rbot0": rbot0, "rbot1": rbot1, "top": gradInput},
531
+ ),
532
+ )(
533
+ grid=tuple(
534
+ [gradInput.size(3), gradInput.size(2), gradInput.size(0)]
535
+ ),
536
+ block=tuple([32, 1, 1]),
537
+ shared_mem=gradOutput.size(1) * 4,
538
+ args=[n, rbot0.data_ptr(), rbot1.data_ptr(), gradInput.data_ptr()],
539
+ stream=Stream,
540
+ )
541
+
542
+ if gradSecond is not None:
543
+ for intSample in range(second.size(0)):
544
+ n = second.size(1) * second.size(2) * second.size(3)
545
+ cupy_launch(
546
+ "kernel_Correlation_updateGradSecond",
547
+ cupy_kernel(
548
+ "kernel_Correlation_updateGradSecond",
549
+ {
550
+ "rbot0": rbot0,
551
+ "rbot1": rbot1,
552
+ "gradOutput": input,
553
+ "gradFirst": None,
554
+ "gradSecond": gradSecond,
555
+ },
556
+ ),
557
+ )(
558
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
559
+ block=tuple([512, 1, 1]),
560
+ args=[
561
+ n,
562
+ intSample,
563
+ rbot0.data_ptr(),
564
+ rbot1.data_ptr(),
565
+ input.data_ptr(),
566
+ None,
567
+ gradSecond.data_ptr(),
568
+ ],
569
+ stream=Stream,
570
+ )
571
+
572
+ elif second.is_cuda == False:
573
+ raise NotImplementedError()
574
+
575
+ return gradInput, gradSecond
576
+
577
+
578
+ def FunctionCorrelation(reference_features, query_features):
579
+ return _FunctionCorrelation.apply(reference_features, query_features)
580
+
581
+
582
+ class ModuleCorrelation(torch.nn.Module):
583
+ def __init__(self):
584
+ super(ModuleCorrelation, self).__init__()
585
+
586
+ def forward(self, tensorFirst, tensorSecond):
587
+ return _FunctionCorrelation.apply(tensorFirst, tensorSecond)
588
+
589
+
590
+ def FunctionCorrelationTranspose(reference_features, query_features):
591
+ return _FunctionCorrelationTranspose.apply(reference_features, query_features)
592
+
593
+
594
+ class ModuleCorrelationTranspose(torch.nn.Module):
595
+ def __init__(self):
596
+ super(ModuleCorrelationTranspose, self).__init__()
597
+
598
+ def forward(self, tensorFirst, tensorSecond):
599
+ return _FunctionCorrelationTranspose.apply(tensorFirst, tensorSecond)
600
+
601
+
602
+ class LocalCorr(ConvRefiner):
603
+ def forward(self, x, y, flow):
604
+ """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
605
+
606
+ Args:
607
+ x ([type]): [description]
608
+ y ([type]): [description]
609
+ flow ([type]): [description]
610
+
611
+ Returns:
612
+ [type]: [description]
613
+ """
614
+ with torch.no_grad():
615
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
616
+ corr = FunctionCorrelation(x, x_hat)
617
+ d = self.block1(corr)
618
+ d = self.hidden_blocks(d)
619
+ d = self.out_conv(d)
620
+ certainty, displacement = d[:, :-2], d[:, -2:]
621
+ return certainty, displacement
622
+
623
+
624
+ if __name__ == "__main__":
625
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
626
+ x = torch.randn(2, 128, 32, 32).to(device)
627
+ y = torch.randn(2, 128, 32, 32).to(device)
628
+ local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4)
629
+ z = local_corr(x, y)
630
+ print("hej")
third_party/DKM/dkm/models/dkm.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from ..utils import get_tuple_transform_ops
9
+ from einops import rearrange
10
+ from ..utils.local_correlation import local_correlation
11
+
12
+
13
+ class ConvRefiner(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_dim=6,
17
+ hidden_dim=16,
18
+ out_dim=2,
19
+ dw=False,
20
+ kernel_size=5,
21
+ hidden_blocks=3,
22
+ displacement_emb = None,
23
+ displacement_emb_dim = None,
24
+ local_corr_radius = None,
25
+ corr_in_other = None,
26
+ no_support_fm = False,
27
+ ):
28
+ super().__init__()
29
+ self.block1 = self.create_block(
30
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
31
+ )
32
+ self.hidden_blocks = nn.Sequential(
33
+ *[
34
+ self.create_block(
35
+ hidden_dim,
36
+ hidden_dim,
37
+ dw=dw,
38
+ kernel_size=kernel_size,
39
+ )
40
+ for hb in range(hidden_blocks)
41
+ ]
42
+ )
43
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
44
+ if displacement_emb:
45
+ self.has_displacement_emb = True
46
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
47
+ else:
48
+ self.has_displacement_emb = False
49
+ self.local_corr_radius = local_corr_radius
50
+ self.corr_in_other = corr_in_other
51
+ self.no_support_fm = no_support_fm
52
+ def create_block(
53
+ self,
54
+ in_dim,
55
+ out_dim,
56
+ dw=False,
57
+ kernel_size=5,
58
+ ):
59
+ num_groups = 1 if not dw else in_dim
60
+ if dw:
61
+ assert (
62
+ out_dim % in_dim == 0
63
+ ), "outdim must be divisible by indim for depthwise"
64
+ conv1 = nn.Conv2d(
65
+ in_dim,
66
+ out_dim,
67
+ kernel_size=kernel_size,
68
+ stride=1,
69
+ padding=kernel_size // 2,
70
+ groups=num_groups,
71
+ )
72
+ norm = nn.BatchNorm2d(out_dim)
73
+ relu = nn.ReLU(inplace=True)
74
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
75
+ return nn.Sequential(conv1, norm, relu, conv2)
76
+
77
+ def forward(self, x, y, flow):
78
+ """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
79
+
80
+ Args:
81
+ x ([type]): [description]
82
+ y ([type]): [description]
83
+ flow ([type]): [description]
84
+
85
+ Returns:
86
+ [type]: [description]
87
+ """
88
+ device = x.device
89
+ b,c,hs,ws = x.shape
90
+ with torch.no_grad():
91
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
92
+ if self.has_displacement_emb:
93
+ query_coords = torch.meshgrid(
94
+ (
95
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
96
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
97
+ )
98
+ )
99
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
100
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
101
+ in_displacement = flow-query_coords
102
+ emb_in_displacement = self.disp_emb(in_displacement)
103
+ if self.local_corr_radius:
104
+ #TODO: should corr have gradient?
105
+ if self.corr_in_other:
106
+ # Corr in other means take a kxk grid around the predicted coordinate in other image
107
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow)
108
+ else:
109
+ # Otherwise we use the warp to sample in the first image
110
+ # This is actually different operations, especially for large viewpoint changes
111
+ local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,)
112
+ if self.no_support_fm:
113
+ x_hat = torch.zeros_like(x)
114
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
115
+ else:
116
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
117
+ else:
118
+ if self.no_support_fm:
119
+ x_hat = torch.zeros_like(x)
120
+ d = torch.cat((x, x_hat), dim=1)
121
+ d = self.block1(d)
122
+ d = self.hidden_blocks(d)
123
+ d = self.out_conv(d)
124
+ certainty, displacement = d[:, :-2], d[:, -2:]
125
+ return certainty, displacement
126
+
127
+
128
+ class CosKernel(nn.Module): # similar to softmax kernel
129
+ def __init__(self, T, learn_temperature=False):
130
+ super().__init__()
131
+ self.learn_temperature = learn_temperature
132
+ if self.learn_temperature:
133
+ self.T = nn.Parameter(torch.tensor(T))
134
+ else:
135
+ self.T = T
136
+
137
+ def __call__(self, x, y, eps=1e-6):
138
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
139
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
140
+ )
141
+ if self.learn_temperature:
142
+ T = self.T.abs() + 0.01
143
+ else:
144
+ T = torch.tensor(self.T, device=c.device)
145
+ K = ((c - 1.0) / T).exp()
146
+ return K
147
+
148
+
149
+ class CAB(nn.Module):
150
+ def __init__(self, in_channels, out_channels):
151
+ super(CAB, self).__init__()
152
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
153
+ self.conv1 = nn.Conv2d(
154
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
155
+ )
156
+ self.relu = nn.ReLU()
157
+ self.conv2 = nn.Conv2d(
158
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
159
+ )
160
+ self.sigmod = nn.Sigmoid()
161
+
162
+ def forward(self, x):
163
+ x1, x2 = x # high, low (old, new)
164
+ x = torch.cat([x1, x2], dim=1)
165
+ x = self.global_pooling(x)
166
+ x = self.conv1(x)
167
+ x = self.relu(x)
168
+ x = self.conv2(x)
169
+ x = self.sigmod(x)
170
+ x2 = x * x2
171
+ res = x2 + x1
172
+ return res
173
+
174
+
175
+ class RRB(nn.Module):
176
+ def __init__(self, in_channels, out_channels, kernel_size=3):
177
+ super(RRB, self).__init__()
178
+ self.conv1 = nn.Conv2d(
179
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
180
+ )
181
+ self.conv2 = nn.Conv2d(
182
+ out_channels,
183
+ out_channels,
184
+ kernel_size=kernel_size,
185
+ stride=1,
186
+ padding=kernel_size // 2,
187
+ )
188
+ self.relu = nn.ReLU()
189
+ self.bn = nn.BatchNorm2d(out_channels)
190
+ self.conv3 = nn.Conv2d(
191
+ out_channels,
192
+ out_channels,
193
+ kernel_size=kernel_size,
194
+ stride=1,
195
+ padding=kernel_size // 2,
196
+ )
197
+
198
+ def forward(self, x):
199
+ x = self.conv1(x)
200
+ res = self.conv2(x)
201
+ res = self.bn(res)
202
+ res = self.relu(res)
203
+ res = self.conv3(res)
204
+ return self.relu(x + res)
205
+
206
+
207
+ class DFN(nn.Module):
208
+ def __init__(
209
+ self,
210
+ internal_dim,
211
+ feat_input_modules,
212
+ pred_input_modules,
213
+ rrb_d_dict,
214
+ cab_dict,
215
+ rrb_u_dict,
216
+ use_global_context=False,
217
+ global_dim=None,
218
+ terminal_module=None,
219
+ upsample_mode="bilinear",
220
+ align_corners=False,
221
+ ):
222
+ super().__init__()
223
+ if use_global_context:
224
+ assert (
225
+ global_dim is not None
226
+ ), "Global dim must be provided when using global context"
227
+ self.align_corners = align_corners
228
+ self.internal_dim = internal_dim
229
+ self.feat_input_modules = feat_input_modules
230
+ self.pred_input_modules = pred_input_modules
231
+ self.rrb_d = rrb_d_dict
232
+ self.cab = cab_dict
233
+ self.rrb_u = rrb_u_dict
234
+ self.use_global_context = use_global_context
235
+ if use_global_context:
236
+ self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
237
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
238
+ self.terminal_module = (
239
+ terminal_module if terminal_module is not None else nn.Identity()
240
+ )
241
+ self.upsample_mode = upsample_mode
242
+ self._scales = [int(key) for key in self.terminal_module.keys()]
243
+
244
+ def scales(self):
245
+ return self._scales.copy()
246
+
247
+ def forward(self, embeddings, feats, context, key):
248
+ feats = self.feat_input_modules[str(key)](feats)
249
+ embeddings = torch.cat([feats, embeddings], dim=1)
250
+ embeddings = self.rrb_d[str(key)](embeddings)
251
+ context = self.cab[str(key)]([context, embeddings])
252
+ context = self.rrb_u[str(key)](context)
253
+ preds = self.terminal_module[str(key)](context)
254
+ pred_coord = preds[:, -2:]
255
+ pred_certainty = preds[:, :-2]
256
+ return pred_coord, pred_certainty, context
257
+
258
+
259
+ class GP(nn.Module):
260
+ def __init__(
261
+ self,
262
+ kernel,
263
+ T=1,
264
+ learn_temperature=False,
265
+ only_attention=False,
266
+ gp_dim=64,
267
+ basis="fourier",
268
+ covar_size=5,
269
+ only_nearest_neighbour=False,
270
+ sigma_noise=0.1,
271
+ no_cov=False,
272
+ predict_features = False,
273
+ ):
274
+ super().__init__()
275
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
276
+ self.sigma_noise = sigma_noise
277
+ self.covar_size = covar_size
278
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
279
+ self.only_attention = only_attention
280
+ self.only_nearest_neighbour = only_nearest_neighbour
281
+ self.basis = basis
282
+ self.no_cov = no_cov
283
+ self.dim = gp_dim
284
+ self.predict_features = predict_features
285
+
286
+ def get_local_cov(self, cov):
287
+ K = self.covar_size
288
+ b, h, w, h, w = cov.shape
289
+ hw = h * w
290
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
291
+ delta = torch.stack(
292
+ torch.meshgrid(
293
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
294
+ ),
295
+ dim=-1,
296
+ )
297
+ positions = torch.stack(
298
+ torch.meshgrid(
299
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
300
+ ),
301
+ dim=-1,
302
+ )
303
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
304
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
305
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
306
+ :,
307
+ points.flatten(),
308
+ neighbours[..., 0].flatten(),
309
+ neighbours[..., 1].flatten(),
310
+ ].reshape(b, h, w, K**2)
311
+ return local_cov
312
+
313
+ def reshape(self, x):
314
+ return rearrange(x, "b d h w -> b (h w) d")
315
+
316
+ def project_to_basis(self, x):
317
+ if self.basis == "fourier":
318
+ return torch.cos(8 * math.pi * self.pos_conv(x))
319
+ elif self.basis == "linear":
320
+ return self.pos_conv(x)
321
+ else:
322
+ raise ValueError(
323
+ "No other bases other than fourier and linear currently supported in public release"
324
+ )
325
+
326
+ def get_pos_enc(self, y):
327
+ b, c, h, w = y.shape
328
+ coarse_coords = torch.meshgrid(
329
+ (
330
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
331
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
332
+ )
333
+ )
334
+
335
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
336
+ None
337
+ ].expand(b, h, w, 2)
338
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
339
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
340
+ return coarse_embedded_coords
341
+
342
+ def forward(self, x, y, **kwargs):
343
+ b, c, h1, w1 = x.shape
344
+ b, c, h2, w2 = y.shape
345
+ f = self.get_pos_enc(y)
346
+ if self.predict_features:
347
+ f = f + y[:,:self.dim] # Stupid way to predict features
348
+ b, d, h2, w2 = f.shape
349
+ #assert x.shape == y.shape
350
+ x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
351
+ K_xx = self.K(x, x)
352
+ K_yy = self.K(y, y)
353
+ K_xy = self.K(x, y)
354
+ K_yx = K_xy.permute(0, 2, 1)
355
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
356
+ # Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large
357
+ if len(K_yy[0]) > 2000:
358
+ K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)])
359
+ else:
360
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
361
+
362
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
363
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
364
+ if not self.no_cov:
365
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
366
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
367
+ local_cov_x = self.get_local_cov(cov_x)
368
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
369
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
370
+ else:
371
+ gp_feats = mu_x
372
+ return gp_feats
373
+
374
+
375
+ class Encoder(nn.Module):
376
+ def __init__(self, resnet):
377
+ super().__init__()
378
+ self.resnet = resnet
379
+ def forward(self, x):
380
+ x0 = x
381
+ b, c, h, w = x.shape
382
+ x = self.resnet.conv1(x)
383
+ x = self.resnet.bn1(x)
384
+ x1 = self.resnet.relu(x)
385
+
386
+ x = self.resnet.maxpool(x1)
387
+ x2 = self.resnet.layer1(x)
388
+
389
+ x3 = self.resnet.layer2(x2)
390
+
391
+ x4 = self.resnet.layer3(x3)
392
+
393
+ x5 = self.resnet.layer4(x4)
394
+ feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}
395
+ return feats
396
+
397
+ def train(self, mode=True):
398
+ super().train(mode)
399
+ for m in self.modules():
400
+ if isinstance(m, nn.BatchNorm2d):
401
+ m.eval()
402
+ pass
403
+
404
+
405
+ class Decoder(nn.Module):
406
+ def __init__(
407
+ self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None,
408
+ ):
409
+ super().__init__()
410
+ self.embedding_decoder = embedding_decoder
411
+ self.gps = gps
412
+ self.proj = proj
413
+ self.conv_refiner = conv_refiner
414
+ self.detach = detach
415
+ if scales == "all":
416
+ self.scales = ["32", "16", "8", "4", "2", "1"]
417
+ else:
418
+ self.scales = scales
419
+
420
+ def upsample_preds(self, flow, certainty, query, support):
421
+ b, hs, ws, d = flow.shape
422
+ b, c, h, w = query.shape
423
+ flow = flow.permute(0, 3, 1, 2)
424
+ certainty = F.interpolate(
425
+ certainty, size=(h, w), align_corners=False, mode="bilinear"
426
+ )
427
+ flow = F.interpolate(
428
+ flow, size=(h, w), align_corners=False, mode="bilinear"
429
+ )
430
+ delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
431
+ flow = torch.stack(
432
+ (
433
+ flow[:, 0] + delta_flow[:, 0] / (4 * w),
434
+ flow[:, 1] + delta_flow[:, 1] / (4 * h),
435
+ ),
436
+ dim=1,
437
+ )
438
+ flow = flow.permute(0, 2, 3, 1)
439
+ certainty = certainty + delta_certainty
440
+ return flow, certainty
441
+
442
+ def get_placeholder_flow(self, b, h, w, device):
443
+ coarse_coords = torch.meshgrid(
444
+ (
445
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
446
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
447
+ )
448
+ )
449
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
450
+ None
451
+ ].expand(b, h, w, 2)
452
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
453
+ return coarse_coords
454
+
455
+
456
+ def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None):
457
+ coarse_scales = self.embedding_decoder.scales()
458
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
459
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
460
+ h, w = sizes[1]
461
+ b = f1[1].shape[0]
462
+ device = f1[1].device
463
+ coarsest_scale = int(all_scales[0])
464
+ old_stuff = torch.zeros(
465
+ b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
466
+ )
467
+ dense_corresps = {}
468
+ if not upsample:
469
+ dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
470
+ dense_certainty = 0.0
471
+ else:
472
+ dense_flow = F.interpolate(
473
+ dense_flow,
474
+ size=sizes[coarsest_scale],
475
+ align_corners=False,
476
+ mode="bilinear",
477
+ )
478
+ dense_certainty = F.interpolate(
479
+ dense_certainty,
480
+ size=sizes[coarsest_scale],
481
+ align_corners=False,
482
+ mode="bilinear",
483
+ )
484
+ for new_scale in all_scales:
485
+ ins = int(new_scale)
486
+ f1_s, f2_s = f1[ins], f2[ins]
487
+ if new_scale in self.proj:
488
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
489
+ b, c, hs, ws = f1_s.shape
490
+ if ins in coarse_scales:
491
+ old_stuff = F.interpolate(
492
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
493
+ )
494
+ new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
495
+ dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
496
+ new_stuff, f1_s, old_stuff, new_scale
497
+ )
498
+
499
+ if new_scale in self.conv_refiner:
500
+ delta_certainty, displacement = self.conv_refiner[new_scale](
501
+ f1_s, f2_s, dense_flow
502
+ )
503
+ dense_flow = torch.stack(
504
+ (
505
+ dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w),
506
+ dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h),
507
+ ),
508
+ dim=1,
509
+ )
510
+ dense_certainty = (
511
+ dense_certainty + delta_certainty
512
+ ) # predict both certainty and displacement
513
+
514
+ dense_corresps[ins] = {
515
+ "dense_flow": dense_flow,
516
+ "dense_certainty": dense_certainty,
517
+ }
518
+
519
+ if new_scale != "1":
520
+ dense_flow = F.interpolate(
521
+ dense_flow,
522
+ size=sizes[ins // 2],
523
+ align_corners=False,
524
+ mode="bilinear",
525
+ )
526
+
527
+ dense_certainty = F.interpolate(
528
+ dense_certainty,
529
+ size=sizes[ins // 2],
530
+ align_corners=False,
531
+ mode="bilinear",
532
+ )
533
+ if self.detach:
534
+ dense_flow = dense_flow.detach()
535
+ dense_certainty = dense_certainty.detach()
536
+ return dense_corresps
537
+
538
+
539
+ class RegressionMatcher(nn.Module):
540
+ def __init__(
541
+ self,
542
+ encoder,
543
+ decoder,
544
+ h=384,
545
+ w=512,
546
+ use_contrastive_loss = False,
547
+ alpha = 1,
548
+ beta = 0,
549
+ sample_mode = "threshold",
550
+ upsample_preds = False,
551
+ symmetric = False,
552
+ name = None,
553
+ use_soft_mutual_nearest_neighbours = False,
554
+ ):
555
+ super().__init__()
556
+ self.encoder = encoder
557
+ self.decoder = decoder
558
+ self.w_resized = w
559
+ self.h_resized = h
560
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
561
+ self.use_contrastive_loss = use_contrastive_loss
562
+ self.alpha = alpha
563
+ self.beta = beta
564
+ self.sample_mode = sample_mode
565
+ self.upsample_preds = upsample_preds
566
+ self.symmetric = symmetric
567
+ self.name = name
568
+ self.sample_thresh = 0.05
569
+ self.upsample_res = (864,1152)
570
+ if use_soft_mutual_nearest_neighbours:
571
+ assert symmetric, "MNS requires symmetric inference"
572
+ self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours
573
+
574
+ def extract_backbone_features(self, batch, batched = True, upsample = True):
575
+ #TODO: only extract stride [1,2,4,8] for upsample = True
576
+ x_q = batch["query"]
577
+ x_s = batch["support"]
578
+ if batched:
579
+ X = torch.cat((x_q, x_s))
580
+ feature_pyramid = self.encoder(X)
581
+ else:
582
+ feature_pyramid = self.encoder(x_q), self.encoder(x_s)
583
+ return feature_pyramid
584
+
585
+ def sample(
586
+ self,
587
+ dense_matches,
588
+ dense_certainty,
589
+ num=10000,
590
+ ):
591
+ if "threshold" in self.sample_mode:
592
+ upper_thresh = self.sample_thresh
593
+ dense_certainty = dense_certainty.clone()
594
+ dense_certainty[dense_certainty > upper_thresh] = 1
595
+ elif "pow" in self.sample_mode:
596
+ dense_certainty = dense_certainty**(1/3)
597
+ elif "naive" in self.sample_mode:
598
+ dense_certainty = torch.ones_like(dense_certainty)
599
+ matches, certainty = (
600
+ dense_matches.reshape(-1, 4),
601
+ dense_certainty.reshape(-1),
602
+ )
603
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
604
+ good_samples = torch.multinomial(certainty,
605
+ num_samples = min(expansion_factor*num, len(certainty)),
606
+ replacement=False)
607
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
608
+ if "balanced" not in self.sample_mode:
609
+ return good_matches, good_certainty
610
+
611
+ from ..utils.kde import kde
612
+ density = kde(good_matches, std=0.1)
613
+ p = 1 / (density+1)
614
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
615
+ balanced_samples = torch.multinomial(p,
616
+ num_samples = min(num,len(good_certainty)),
617
+ replacement=False)
618
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
619
+
620
+ def forward(self, batch, batched = True):
621
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched)
622
+ if batched:
623
+ f_q_pyramid = {
624
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
625
+ }
626
+ f_s_pyramid = {
627
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
628
+ }
629
+ else:
630
+ f_q_pyramid, f_s_pyramid = feature_pyramid
631
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
632
+ if self.training and self.use_contrastive_loss:
633
+ return dense_corresps, (f_q_pyramid, f_s_pyramid)
634
+ else:
635
+ return dense_corresps
636
+
637
+ def forward_symmetric(self, batch, upsample = False, batched = True):
638
+ feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched)
639
+ f_q_pyramid = feature_pyramid
640
+ f_s_pyramid = {
641
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
642
+ for scale, f_scale in feature_pyramid.items()
643
+ }
644
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {}))
645
+ return dense_corresps
646
+
647
+ def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
648
+ kpts_A, kpts_B = matches[...,:2], matches[...,2:]
649
+ kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
650
+ kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
651
+ return kpts_A, kpts_B
652
+
653
+ def match(
654
+ self,
655
+ im1_path,
656
+ im2_path,
657
+ *args,
658
+ batched=False,
659
+ device = None
660
+ ):
661
+ assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
662
+ if isinstance(im1_path, (str, os.PathLike)):
663
+ im1, im2 = Image.open(im1_path), Image.open(im2_path)
664
+ else: # assume it is a PIL Image
665
+ im1, im2 = im1_path, im2_path
666
+ if device is None:
667
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
668
+ symmetric = self.symmetric
669
+ self.train(False)
670
+ with torch.no_grad():
671
+ if not batched:
672
+ b = 1
673
+ w, h = im1.size
674
+ w2, h2 = im2.size
675
+ # Get images in good format
676
+ ws = self.w_resized
677
+ hs = self.h_resized
678
+
679
+ test_transform = get_tuple_transform_ops(
680
+ resize=(hs, ws), normalize=True
681
+ )
682
+ query, support = test_transform((im1, im2))
683
+ batch = {"query": query[None].to(device), "support": support[None].to(device)}
684
+ else:
685
+ b, c, h, w = im1.shape
686
+ b, c, h2, w2 = im2.shape
687
+ assert w == w2 and h == h2, "For batched images we assume same size"
688
+ batch = {"query": im1.to(device), "support": im2.to(device)}
689
+ hs, ws = self.h_resized, self.w_resized
690
+ finest_scale = 1
691
+ # Run matcher
692
+ if symmetric:
693
+ dense_corresps = self.forward_symmetric(batch, batched = True)
694
+ else:
695
+ dense_corresps = self.forward(batch, batched = True)
696
+
697
+ if self.upsample_preds:
698
+ hs, ws = self.upsample_res
699
+ low_res_certainty = F.interpolate(
700
+ dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
701
+ )
702
+ cert_clamp = 0
703
+ factor = 0.5
704
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
705
+
706
+ if self.upsample_preds:
707
+ test_transform = get_tuple_transform_ops(
708
+ resize=(hs, ws), normalize=True
709
+ )
710
+ query, support = test_transform((im1, im2))
711
+ query, support = query[None].to(device), support[None].to(device)
712
+ batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]}
713
+ if symmetric:
714
+ dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True)
715
+ else:
716
+ dense_corresps = self.forward(batch, batched = True, upsample=True)
717
+ query_to_support = dense_corresps[finest_scale]["dense_flow"]
718
+ dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
719
+
720
+ # Get certainty interpolation
721
+ dense_certainty = dense_certainty - low_res_certainty
722
+ query_to_support = query_to_support.permute(
723
+ 0, 2, 3, 1
724
+ )
725
+ # Create im1 meshgrid
726
+ query_coords = torch.meshgrid(
727
+ (
728
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
729
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
730
+ )
731
+ )
732
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
733
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
734
+ dense_certainty = dense_certainty.sigmoid() # logits -> probs
735
+ query_coords = query_coords.permute(0, 2, 3, 1)
736
+ if (query_to_support.abs() > 1).any() and True:
737
+ wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
738
+ dense_certainty[wrong[:,None]] = 0
739
+
740
+ query_to_support = torch.clamp(query_to_support, -1, 1)
741
+ if symmetric:
742
+ support_coords = query_coords
743
+ qts, stq = query_to_support.chunk(2)
744
+ q_warp = torch.cat((query_coords, qts), dim=-1)
745
+ s_warp = torch.cat((stq, support_coords), dim=-1)
746
+ warp = torch.cat((q_warp, s_warp),dim=2)
747
+ dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0]
748
+ else:
749
+ warp = torch.cat((query_coords, query_to_support), dim=-1)
750
+ if batched:
751
+ return (
752
+ warp,
753
+ dense_certainty
754
+ )
755
+ else:
756
+ return (
757
+ warp[0],
758
+ dense_certainty[0],
759
+ )
third_party/DKM/dkm/models/encoders.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as tvm
5
+
6
+ class ResNet18(nn.Module):
7
+ def __init__(self, pretrained=False) -> None:
8
+ super().__init__()
9
+ self.net = tvm.resnet18(pretrained=pretrained)
10
+ def forward(self, x):
11
+ self = self.net
12
+ x1 = x
13
+ x = self.conv1(x1)
14
+ x = self.bn1(x)
15
+ x2 = self.relu(x)
16
+ x = self.maxpool(x2)
17
+ x4 = self.layer1(x)
18
+ x8 = self.layer2(x4)
19
+ x16 = self.layer3(x8)
20
+ x32 = self.layer4(x16)
21
+ return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1}
22
+
23
+ def train(self, mode=True):
24
+ super().train(mode)
25
+ for m in self.modules():
26
+ if isinstance(m, nn.BatchNorm2d):
27
+ m.eval()
28
+ pass
29
+
30
+ class ResNet50(nn.Module):
31
+ def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
32
+ super().__init__()
33
+ if dilation is None:
34
+ dilation = [False,False,False]
35
+ if anti_aliased:
36
+ pass
37
+ else:
38
+ if weights is not None:
39
+ self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
40
+ else:
41
+ self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
42
+
43
+ self.high_res = high_res
44
+ self.freeze_bn = freeze_bn
45
+ def forward(self, x):
46
+ net = self.net
47
+ feats = {1:x}
48
+ x = net.conv1(x)
49
+ x = net.bn1(x)
50
+ x = net.relu(x)
51
+ feats[2] = x
52
+ x = net.maxpool(x)
53
+ x = net.layer1(x)
54
+ feats[4] = x
55
+ x = net.layer2(x)
56
+ feats[8] = x
57
+ x = net.layer3(x)
58
+ feats[16] = x
59
+ x = net.layer4(x)
60
+ feats[32] = x
61
+ return feats
62
+
63
+ def train(self, mode=True):
64
+ super().train(mode)
65
+ if self.freeze_bn:
66
+ for m in self.modules():
67
+ if isinstance(m, nn.BatchNorm2d):
68
+ m.eval()
69
+ pass
70
+
71
+
72
+
73
+
74
+ class ResNet101(nn.Module):
75
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
76
+ super().__init__()
77
+ if weights is not None:
78
+ self.net = tvm.resnet101(weights = weights)
79
+ else:
80
+ self.net = tvm.resnet101(pretrained=pretrained)
81
+ self.high_res = high_res
82
+ self.scale_factor = 1 if not high_res else 1.5
83
+ def forward(self, x):
84
+ net = self.net
85
+ feats = {1:x}
86
+ sf = self.scale_factor
87
+ if self.high_res:
88
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
89
+ x = net.conv1(x)
90
+ x = net.bn1(x)
91
+ x = net.relu(x)
92
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
93
+ x = net.maxpool(x)
94
+ x = net.layer1(x)
95
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
96
+ x = net.layer2(x)
97
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
98
+ x = net.layer3(x)
99
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
100
+ x = net.layer4(x)
101
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
102
+ return feats
103
+
104
+ def train(self, mode=True):
105
+ super().train(mode)
106
+ for m in self.modules():
107
+ if isinstance(m, nn.BatchNorm2d):
108
+ m.eval()
109
+ pass
110
+
111
+
112
+ class WideResNet50(nn.Module):
113
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
114
+ super().__init__()
115
+ if weights is not None:
116
+ self.net = tvm.wide_resnet50_2(weights = weights)
117
+ else:
118
+ self.net = tvm.wide_resnet50_2(pretrained=pretrained)
119
+ self.high_res = high_res
120
+ self.scale_factor = 1 if not high_res else 1.5
121
+ def forward(self, x):
122
+ net = self.net
123
+ feats = {1:x}
124
+ sf = self.scale_factor
125
+ if self.high_res:
126
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
127
+ x = net.conv1(x)
128
+ x = net.bn1(x)
129
+ x = net.relu(x)
130
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
131
+ x = net.maxpool(x)
132
+ x = net.layer1(x)
133
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
134
+ x = net.layer2(x)
135
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
136
+ x = net.layer3(x)
137
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
138
+ x = net.layer4(x)
139
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
140
+ return feats
141
+
142
+ def train(self, mode=True):
143
+ super().train(mode)
144
+ for m in self.modules():
145
+ if isinstance(m, nn.BatchNorm2d):
146
+ m.eval()
147
+ pass