Spaces:
Running
Running
Vincentqyw
commited on
Commit
·
dbf8b7e
1
Parent(s):
10b4a5f
add: roma
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- third_party/DKM/.gitignore +3 -0
- third_party/DKM/LICENSE +25 -0
- third_party/DKM/README.md +117 -0
- third_party/DKM/assets/ams_hom_A.jpg +3 -0
- third_party/DKM/assets/ams_hom_B.jpg +3 -0
- third_party/DKM/assets/dkmv3_warp.jpg +3 -0
- third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz +3 -0
- third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz +3 -0
- third_party/DKM/assets/mount_rushmore.mp4 +0 -0
- third_party/DKM/assets/sacre_coeur_A.jpg +3 -0
- third_party/DKM/assets/sacre_coeur_B.jpg +3 -0
- third_party/DKM/data/.gitignore +2 -0
- third_party/DKM/demo/.gitignore +1 -0
- third_party/DKM/demo/demo_fundamental.py +37 -0
- third_party/DKM/demo/demo_match.py +48 -0
- third_party/DKM/dkm/__init__.py +4 -0
- third_party/DKM/dkm/benchmarks/__init__.py +4 -0
- third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py +100 -0
- third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py +119 -0
- third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py +114 -0
- third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py +124 -0
- third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py +86 -0
- third_party/DKM/dkm/benchmarks/scannet_benchmark.py +143 -0
- third_party/DKM/dkm/checkpointing/__init__.py +1 -0
- third_party/DKM/dkm/checkpointing/checkpoint.py +31 -0
- third_party/DKM/dkm/datasets/__init__.py +1 -0
- third_party/DKM/dkm/datasets/megadepth.py +177 -0
- third_party/DKM/dkm/datasets/scannet.py +151 -0
- third_party/DKM/dkm/losses/__init__.py +1 -0
- third_party/DKM/dkm/losses/depth_match_regression_loss.py +128 -0
- third_party/DKM/dkm/models/__init__.py +4 -0
- third_party/DKM/dkm/models/deprecated/build_model.py +787 -0
- third_party/DKM/dkm/models/deprecated/corr_channels.py +34 -0
- third_party/DKM/dkm/models/deprecated/local_corr.py +630 -0
- third_party/DKM/dkm/models/dkm.py +759 -0
- third_party/DKM/dkm/models/encoders.py +147 -0
third_party/DKM/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
*.egg-info*
|
2 |
+
*.vscode*
|
3 |
+
*__pycache__*
|
third_party/DKM/LICENSE
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NOTE! Models trained on our synthetic dataset uses datasets which are licensed under non-commercial licenses.
|
2 |
+
Hence we cannot provide them under the MIT license. However, MegaDepth is under MIT license, hence we provide those models under MIT license, see below.
|
3 |
+
|
4 |
+
|
5 |
+
License for Models Trained on MegaDepth ONLY below:
|
6 |
+
|
7 |
+
Copyright (c) 2022 Johan Edstedt
|
8 |
+
|
9 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
10 |
+
of this software and associated documentation files (the "Software"), to deal
|
11 |
+
in the Software without restriction, including without limitation the rights
|
12 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
13 |
+
copies of the Software, and to permit persons to whom the Software is
|
14 |
+
furnished to do so, subject to the following conditions:
|
15 |
+
|
16 |
+
The above copyright notice and this permission notice shall be included in all
|
17 |
+
copies or substantial portions of the Software.
|
18 |
+
|
19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
20 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
21 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
22 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
23 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
24 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
25 |
+
SOFTWARE.
|
third_party/DKM/README.md
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DKM: Dense Kernelized Feature Matching for Geometry Estimation
|
2 |
+
### [Project Page](https://parskatt.github.io/DKM) | [Paper](https://arxiv.org/abs/2202.00667)
|
3 |
+
<br/>
|
4 |
+
|
5 |
+
> DKM: Dense Kernelized Feature Matching for Geometry Estimation
|
6 |
+
> [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Ioannis Athanasiadis](https://scholar.google.com/citations?user=RCAtJgUAAAAJ), [Mårten Wadenbäck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ)
|
7 |
+
> CVPR 2023
|
8 |
+
|
9 |
+
## How to Use?
|
10 |
+
<details>
|
11 |
+
Our model produces a dense (for all pixels) warp and certainty.
|
12 |
+
|
13 |
+
Warp: [B,H,W,4] for all images in batch of size B, for each pixel HxW, we ouput the input and matching coordinate in the normalized grids [-1,1]x[-1,1].
|
14 |
+
|
15 |
+
Certainty: [B,H,W] a number in each pixel indicating the matchability of the pixel.
|
16 |
+
|
17 |
+
See [demo](dkm/demo/) for two demos of DKM.
|
18 |
+
|
19 |
+
See [api.md](docs/api.md) for API.
|
20 |
+
</details>
|
21 |
+
|
22 |
+
## Qualitative Results
|
23 |
+
<details>
|
24 |
+
|
25 |
+
https://user-images.githubusercontent.com/22053118/223748279-0f0c21b4-376a-440a-81f5-7f9a5d87483f.mp4
|
26 |
+
|
27 |
+
|
28 |
+
https://user-images.githubusercontent.com/22053118/223748512-1bca4a17-cffa-491d-a448-96aac1353ce9.mp4
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
https://user-images.githubusercontent.com/22053118/223748518-4d475d9f-a933-4581-97ed-6e9413c4caca.mp4
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
https://user-images.githubusercontent.com/22053118/223748522-39c20631-aa16-4954-9c27-95763b38f2ce.mp4
|
37 |
+
|
38 |
+
|
39 |
+
</details>
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
## Benchmark Results
|
44 |
+
|
45 |
+
<details>
|
46 |
+
|
47 |
+
### Megadepth1500
|
48 |
+
|
49 |
+
| | @5 | @10 | @20 |
|
50 |
+
|-------|-------|------|------|
|
51 |
+
| DKMv1 | 54.5 | 70.7 | 82.3 |
|
52 |
+
| DKMv2 | *56.8* | *72.3* | *83.2* |
|
53 |
+
| DKMv3 (paper) | **60.5** | **74.9** | **85.1** |
|
54 |
+
| DKMv3 (this repo) | **60.0** | **74.6** | **84.9** |
|
55 |
+
|
56 |
+
### Megadepth 8 Scenes
|
57 |
+
| | @5 | @10 | @20 |
|
58 |
+
|-------|-------|------|------|
|
59 |
+
| DKMv3 (paper) | **60.5** | **74.5** | **84.2** |
|
60 |
+
| DKMv3 (this repo) | **60.4** | **74.6** | **84.3** |
|
61 |
+
|
62 |
+
|
63 |
+
### ScanNet1500
|
64 |
+
| | @5 | @10 | @20 |
|
65 |
+
|-------|-------|------|------|
|
66 |
+
| DKMv1 | 24.8 | 44.4 | 61.9 |
|
67 |
+
| DKMv2 | *28.2* | *49.2* | *66.6* |
|
68 |
+
| DKMv3 (paper) | **29.4** | **50.7** | **68.3** |
|
69 |
+
| DKMv3 (this repo) | **29.8** | **50.8** | **68.3** |
|
70 |
+
|
71 |
+
</details>
|
72 |
+
|
73 |
+
## Navigating the Code
|
74 |
+
* Code for models can be found in [dkm/models](dkm/models/)
|
75 |
+
* Code for benchmarks can be found in [dkm/benchmarks](dkm/benchmarks/)
|
76 |
+
* Code for reproducing experiments from our paper can be found in [experiments/](experiments/)
|
77 |
+
|
78 |
+
## Install
|
79 |
+
Run ``pip install -e .``
|
80 |
+
|
81 |
+
## Demo
|
82 |
+
|
83 |
+
A demonstration of our method can be run by:
|
84 |
+
``` bash
|
85 |
+
python demo_match.py
|
86 |
+
```
|
87 |
+
This runs our model trained on mega on two images taken from Sacre Coeur.
|
88 |
+
|
89 |
+
## Benchmarks
|
90 |
+
See [Benchmarks](docs/benchmarks.md) for details.
|
91 |
+
## Training
|
92 |
+
See [Training](docs/training.md) for details.
|
93 |
+
## Reproducing Results
|
94 |
+
Given that the required benchmark or training dataset has been downloaded and unpacked, results can be reproduced by running the experiments in the experiments folder.
|
95 |
+
|
96 |
+
## Using DKM matches for estimation
|
97 |
+
We recommend using the excellent Graph-Cut RANSAC algorithm: https://github.com/danini/graph-cut-ransac
|
98 |
+
|
99 |
+
| | @5 | @10 | @20 |
|
100 |
+
|-------|-------|------|------|
|
101 |
+
| DKMv3 (RANSAC) | *60.5* | *74.9* | *85.1* |
|
102 |
+
| DKMv3 (GC-RANSAC) | **65.5** | **78.0** | **86.7** |
|
103 |
+
|
104 |
+
|
105 |
+
## Acknowledgements
|
106 |
+
We have used code and been inspired by https://github.com/PruneTruong/DenseMatching, https://github.com/zju3dv/LoFTR, and https://github.com/GrumpyZhou/patch2pix. We additionally thank the authors of ECO-TR for providing their benchmark.
|
107 |
+
|
108 |
+
## BibTeX
|
109 |
+
If you find our models useful, please consider citing our paper!
|
110 |
+
```
|
111 |
+
@inproceedings{edstedt2023dkm,
|
112 |
+
title={{DKM}: Dense Kernelized Feature Matching for Geometry Estimation},
|
113 |
+
author={Edstedt, Johan and Athanasiadis, Ioannis and Wadenbäck, Mårten and Felsberg, Michael},
|
114 |
+
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
115 |
+
year={2023}
|
116 |
+
}
|
117 |
+
```
|
third_party/DKM/assets/ams_hom_A.jpg
ADDED
Git LFS Details
|
third_party/DKM/assets/ams_hom_B.jpg
ADDED
Git LFS Details
|
third_party/DKM/assets/dkmv3_warp.jpg
ADDED
Git LFS Details
|
third_party/DKM/assets/mega_8_scenes_0008_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c902547181fc9b370fdd16272140be6803fe983aea978c68683db803ac70dd57
|
3 |
+
size 906160
|
third_party/DKM/assets/mega_8_scenes_0008_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65ce02bd248988b42363ccd257abaa9b99a00d569d2779597b36ba6c4da35021
|
3 |
+
size 906160
|
third_party/DKM/assets/mega_8_scenes_0019_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c6104feb8807a4ebdd1266160e67b3c507c550012f54c23292d0ebf99b88753f
|
3 |
+
size 368192
|
third_party/DKM/assets/mega_8_scenes_0019_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9600ba5c24d414f63728bf5ee7550a3b035d7c615461e357590890ae0e0f042e
|
3 |
+
size 368192
|
third_party/DKM/assets/mega_8_scenes_0021_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de89e9ccf10515cc4196ba1e7172ec98b2fb92ff9f85d90db5df1af5b6503313
|
3 |
+
size 167528
|
third_party/DKM/assets/mega_8_scenes_0021_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:94c97b57beb10411b3b98a1e88c9e1e2f9db51994dce04580d2b7cfc8919dab3
|
3 |
+
size 167528
|
third_party/DKM/assets/mega_8_scenes_0024_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f14a66dbbd7fa8f31756dd496bfabe4c3ea115c6914acad9365dd02e46ae674
|
3 |
+
size 63909
|
third_party/DKM/assets/mega_8_scenes_0024_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dfaee333beccd1da0d920777cdc8f17d584b21ba20f675b39222c5a205acf72a
|
3 |
+
size 63909
|
third_party/DKM/assets/mega_8_scenes_0025_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b446ca3cc2073c8a3963cf68cc450ef2ebf73d2b956b1f5ae6b37621bc67cce4
|
3 |
+
size 200371
|
third_party/DKM/assets/mega_8_scenes_0025_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df1969fd94032562b5e8d916467101a878168d586b795770e64c108bab250c9e
|
3 |
+
size 200371
|
third_party/DKM/assets/mega_8_scenes_0032_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37cbafa0b0f981f5d69aba202ddd37c5892bda0fa13d053a3ad27d6ddad51c16
|
3 |
+
size 642823
|
third_party/DKM/assets/mega_8_scenes_0032_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:962b3fadc7c94ea8e4a1bb5e168e72b0b6cc474ae56b0aee70ba4e517553fbcf
|
3 |
+
size 642823
|
third_party/DKM/assets/mega_8_scenes_0063_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:50ed6b02dff2fa719e4e9ca216b4704b82cbbefd127355d3ba7120828407e723
|
3 |
+
size 228647
|
third_party/DKM/assets/mega_8_scenes_0063_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:edc97129d000f0478020495f646f2fa7667408247ccd11054e02efbbb38d1444
|
3 |
+
size 228647
|
third_party/DKM/assets/mega_8_scenes_1589_0.1_0.3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04b0b6c6adff812e12b66476f7ca2a6ed2564cdd8208ec0c775f7b922f160103
|
3 |
+
size 177063
|
third_party/DKM/assets/mega_8_scenes_1589_0.3_0.5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ae931f8cac1b2168f699c70efe42c215eaff27d3f0617d59afb3db183c9b1848
|
3 |
+
size 177063
|
third_party/DKM/assets/mount_rushmore.mp4
ADDED
Binary file (986 kB). View file
|
|
third_party/DKM/assets/sacre_coeur_A.jpg
ADDED
Git LFS Details
|
third_party/DKM/assets/sacre_coeur_B.jpg
ADDED
Git LFS Details
|
third_party/DKM/data/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
third_party/DKM/demo/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.jpg
|
third_party/DKM/demo/demo_fundamental.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from dkm.utils.utils import tensor_to_pil
|
6 |
+
import cv2
|
7 |
+
from dkm import DKMv3_outdoor
|
8 |
+
|
9 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
+
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
from argparse import ArgumentParser
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
16 |
+
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
17 |
+
|
18 |
+
args, _ = parser.parse_known_args()
|
19 |
+
im1_path = args.im_A_path
|
20 |
+
im2_path = args.im_B_path
|
21 |
+
|
22 |
+
# Create model
|
23 |
+
dkm_model = DKMv3_outdoor(device=device)
|
24 |
+
|
25 |
+
|
26 |
+
W_A, H_A = Image.open(im1_path).size
|
27 |
+
W_B, H_B = Image.open(im2_path).size
|
28 |
+
|
29 |
+
# Match
|
30 |
+
warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
|
31 |
+
# Sample matches for estimation
|
32 |
+
matches, certainty = dkm_model.sample(warp, certainty)
|
33 |
+
kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
|
34 |
+
F, mask = cv2.findFundamentalMat(
|
35 |
+
kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
36 |
+
)
|
37 |
+
# TODO: some better visualization
|
third_party/DKM/demo/demo_match.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from dkm.utils.utils import tensor_to_pil
|
6 |
+
|
7 |
+
from dkm import DKMv3_outdoor
|
8 |
+
|
9 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
+
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
from argparse import ArgumentParser
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
16 |
+
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
17 |
+
parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
|
18 |
+
|
19 |
+
args, _ = parser.parse_known_args()
|
20 |
+
im1_path = args.im_A_path
|
21 |
+
im2_path = args.im_B_path
|
22 |
+
save_path = args.save_path
|
23 |
+
|
24 |
+
# Create model
|
25 |
+
dkm_model = DKMv3_outdoor(device=device)
|
26 |
+
|
27 |
+
H, W = 864, 1152
|
28 |
+
|
29 |
+
im1 = Image.open(im1_path).resize((W, H))
|
30 |
+
im2 = Image.open(im2_path).resize((W, H))
|
31 |
+
|
32 |
+
# Match
|
33 |
+
warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
|
34 |
+
# Sampling not needed, but can be done with model.sample(warp, certainty)
|
35 |
+
dkm_model.sample(warp, certainty)
|
36 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
37 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
38 |
+
|
39 |
+
im2_transfer_rgb = F.grid_sample(
|
40 |
+
x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
41 |
+
)[0]
|
42 |
+
im1_transfer_rgb = F.grid_sample(
|
43 |
+
x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
44 |
+
)[0]
|
45 |
+
warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
|
46 |
+
white_im = torch.ones((H,2*W),device=device)
|
47 |
+
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
48 |
+
tensor_to_pil(vis_im, unnormalize=False).save(save_path)
|
third_party/DKM/dkm/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import (
|
2 |
+
DKMv3_outdoor,
|
3 |
+
DKMv3_indoor,
|
4 |
+
)
|
third_party/DKM/dkm/benchmarks/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
|
2 |
+
from .scannet_benchmark import ScanNetBenchmark
|
3 |
+
from .megadepth1500_benchmark import Megadepth1500Benchmark
|
4 |
+
from .megadepth_dense_benchmark import MegadepthDenseBenchmark
|
third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import os
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from dkm.utils import *
|
10 |
+
|
11 |
+
|
12 |
+
class HpatchesDenseBenchmark:
|
13 |
+
"""WARNING: HPATCHES grid goes from [0,n-1] instead of [0.5,n-0.5]"""
|
14 |
+
|
15 |
+
def __init__(self, dataset_path) -> None:
|
16 |
+
seqs_dir = "hpatches-sequences-release"
|
17 |
+
self.seqs_path = os.path.join(dataset_path, seqs_dir)
|
18 |
+
self.seq_names = sorted(os.listdir(self.seqs_path))
|
19 |
+
|
20 |
+
def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
|
21 |
+
# Get matches in output format on the grid [0, n] where the center of the top-left coordinate is [0.5, 0.5]
|
22 |
+
offset = (
|
23 |
+
0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0]
|
24 |
+
)
|
25 |
+
query_coords = (
|
26 |
+
torch.stack(
|
27 |
+
(
|
28 |
+
wq * (query_coords[..., 0] + 1) / 2,
|
29 |
+
hq * (query_coords[..., 1] + 1) / 2,
|
30 |
+
),
|
31 |
+
axis=-1,
|
32 |
+
)
|
33 |
+
- offset
|
34 |
+
)
|
35 |
+
query_to_support = (
|
36 |
+
torch.stack(
|
37 |
+
(
|
38 |
+
wsup * (query_to_support[..., 0] + 1) / 2,
|
39 |
+
hsup * (query_to_support[..., 1] + 1) / 2,
|
40 |
+
),
|
41 |
+
axis=-1,
|
42 |
+
)
|
43 |
+
- offset
|
44 |
+
)
|
45 |
+
return query_coords, query_to_support
|
46 |
+
|
47 |
+
def inside_image(self, x, w, h):
|
48 |
+
return torch.logical_and(
|
49 |
+
x[:, 0] < (w - 1),
|
50 |
+
torch.logical_and(x[:, 1] < (h - 1), (x > 0).prod(dim=-1)),
|
51 |
+
)
|
52 |
+
|
53 |
+
def benchmark(self, model):
|
54 |
+
use_cuda = torch.cuda.is_available()
|
55 |
+
device = torch.device("cuda:0" if use_cuda else "cpu")
|
56 |
+
aepes = []
|
57 |
+
pcks = []
|
58 |
+
for seq_idx, seq_name in tqdm(
|
59 |
+
enumerate(self.seq_names), total=len(self.seq_names)
|
60 |
+
):
|
61 |
+
if seq_name[0] == "i":
|
62 |
+
continue
|
63 |
+
im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
|
64 |
+
im1 = Image.open(im1_path)
|
65 |
+
w1, h1 = im1.size
|
66 |
+
for im_idx in range(2, 7):
|
67 |
+
im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
|
68 |
+
im2 = Image.open(im2_path)
|
69 |
+
w2, h2 = im2.size
|
70 |
+
matches, certainty = model.match(im2, im1, do_pred_in_og_res=True)
|
71 |
+
matches, certainty = matches.reshape(-1, 4), certainty.reshape(-1)
|
72 |
+
inv_homography = torch.from_numpy(
|
73 |
+
np.loadtxt(
|
74 |
+
os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
|
75 |
+
)
|
76 |
+
).to(device)
|
77 |
+
homography = torch.linalg.inv(inv_homography)
|
78 |
+
pos_a, pos_b = self.convert_coordinates(
|
79 |
+
matches[:, :2], matches[:, 2:], w2, h2, w1, h1
|
80 |
+
)
|
81 |
+
pos_a, pos_b = pos_a.double(), pos_b.double()
|
82 |
+
pos_a_h = torch.cat(
|
83 |
+
[pos_a, torch.ones([pos_a.shape[0], 1], device=device)], dim=1
|
84 |
+
)
|
85 |
+
pos_b_proj_h = (homography @ pos_a_h.t()).t()
|
86 |
+
pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
|
87 |
+
mask = self.inside_image(pos_b_proj, w1, h1)
|
88 |
+
residual = pos_b - pos_b_proj
|
89 |
+
dist = (residual**2).sum(dim=1).sqrt()[mask]
|
90 |
+
aepes.append(torch.mean(dist).item())
|
91 |
+
pck1 = (dist < 1.0).float().mean().item()
|
92 |
+
pck3 = (dist < 3.0).float().mean().item()
|
93 |
+
pck5 = (dist < 5.0).float().mean().item()
|
94 |
+
pcks.append([pck1, pck3, pck5])
|
95 |
+
m_pcks = np.mean(np.array(pcks), axis=0)
|
96 |
+
return {
|
97 |
+
"hp_pck1": m_pcks[0],
|
98 |
+
"hp_pck3": m_pcks[1],
|
99 |
+
"hp_pck5": m_pcks[2],
|
100 |
+
}
|
third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import h5py
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from dkm.utils import *
|
6 |
+
from PIL import Image
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
class Yfcc100mBenchmark:
|
11 |
+
def __init__(self, data_root="data/yfcc100m_test") -> None:
|
12 |
+
self.scenes = [
|
13 |
+
"buckingham_palace",
|
14 |
+
"notre_dame_front_facade",
|
15 |
+
"reichstag",
|
16 |
+
"sacre_coeur",
|
17 |
+
]
|
18 |
+
self.data_root = data_root
|
19 |
+
|
20 |
+
def benchmark(self, model, r=2):
|
21 |
+
model.train(False)
|
22 |
+
with torch.no_grad():
|
23 |
+
data_root = self.data_root
|
24 |
+
meta_info = open(
|
25 |
+
f"{data_root}/yfcc_test_pairs_with_gt.txt", "r"
|
26 |
+
).readlines()
|
27 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
28 |
+
for scene_ind in range(len(self.scenes)):
|
29 |
+
scene = self.scenes[scene_ind]
|
30 |
+
pairs = np.array(
|
31 |
+
pickle.load(
|
32 |
+
open(f"{data_root}/pairs/{scene}-te-1000-pairs.pkl", "rb")
|
33 |
+
)
|
34 |
+
)
|
35 |
+
scene_dir = f"{data_root}/yfcc100m/{scene}/test/"
|
36 |
+
calibs = open(scene_dir + "calibration.txt", "r").read().split("\n")
|
37 |
+
images = open(scene_dir + "images.txt", "r").read().split("\n")
|
38 |
+
pair_inds = np.random.choice(
|
39 |
+
range(len(pairs)), size=len(pairs), replace=False
|
40 |
+
)
|
41 |
+
for pairind in tqdm(pair_inds):
|
42 |
+
idx1, idx2 = pairs[pairind]
|
43 |
+
params = meta_info[1000 * scene_ind + pairind].split()
|
44 |
+
rot1, rot2 = int(params[2]), int(params[3])
|
45 |
+
calib1 = h5py.File(scene_dir + calibs[idx1], "r")
|
46 |
+
K1, R1, t1, _, _ = get_pose(calib1)
|
47 |
+
calib2 = h5py.File(scene_dir + calibs[idx2], "r")
|
48 |
+
K2, R2, t2, _, _ = get_pose(calib2)
|
49 |
+
|
50 |
+
R, t = compute_relative_pose(R1, t1, R2, t2)
|
51 |
+
im1 = images[idx1]
|
52 |
+
im2 = images[idx2]
|
53 |
+
im1 = Image.open(scene_dir + im1).rotate(rot1 * 90, expand=True)
|
54 |
+
w1, h1 = im1.size
|
55 |
+
im2 = Image.open(scene_dir + im2).rotate(rot2 * 90, expand=True)
|
56 |
+
w2, h2 = im2.size
|
57 |
+
K1 = rotate_intrinsic(K1, rot1)
|
58 |
+
K2 = rotate_intrinsic(K2, rot2)
|
59 |
+
|
60 |
+
dense_matches, dense_certainty = model.match(im1, im2)
|
61 |
+
dense_certainty = dense_certainty ** (1 / r)
|
62 |
+
sparse_matches, sparse_confidence = model.sample(
|
63 |
+
dense_matches, dense_certainty, 10000
|
64 |
+
)
|
65 |
+
scale1 = 480 / min(w1, h1)
|
66 |
+
scale2 = 480 / min(w2, h2)
|
67 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
68 |
+
w2, h2 = scale2 * w2, scale2 * h2
|
69 |
+
K1 = K1 * scale1
|
70 |
+
K2 = K2 * scale2
|
71 |
+
|
72 |
+
kpts1 = sparse_matches[:, :2]
|
73 |
+
kpts1 = np.stack(
|
74 |
+
(w1 * kpts1[:, 0] / 2, h1 * kpts1[:, 1] / 2), axis=-1
|
75 |
+
)
|
76 |
+
kpts2 = sparse_matches[:, 2:]
|
77 |
+
kpts2 = np.stack(
|
78 |
+
(w2 * kpts2[:, 0] / 2, h2 * kpts2[:, 1] / 2), axis=-1
|
79 |
+
)
|
80 |
+
try:
|
81 |
+
threshold = 1.0
|
82 |
+
norm_threshold = threshold / (
|
83 |
+
np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
|
84 |
+
)
|
85 |
+
R_est, t_est, mask = estimate_pose(
|
86 |
+
kpts1,
|
87 |
+
kpts2,
|
88 |
+
K1[:2, :2],
|
89 |
+
K2[:2, :2],
|
90 |
+
norm_threshold,
|
91 |
+
conf=0.9999999,
|
92 |
+
)
|
93 |
+
T1_to_2 = np.concatenate((R_est, t_est), axis=-1) #
|
94 |
+
e_t, e_R = compute_pose_error(T1_to_2, R, t)
|
95 |
+
e_pose = max(e_t, e_R)
|
96 |
+
except:
|
97 |
+
e_t, e_R = 90, 90
|
98 |
+
e_pose = max(e_t, e_R)
|
99 |
+
tot_e_t.append(e_t)
|
100 |
+
tot_e_R.append(e_R)
|
101 |
+
tot_e_pose.append(e_pose)
|
102 |
+
tot_e_pose = np.array(tot_e_pose)
|
103 |
+
thresholds = [5, 10, 20]
|
104 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
105 |
+
acc_5 = (tot_e_pose < 5).mean()
|
106 |
+
acc_10 = (tot_e_pose < 10).mean()
|
107 |
+
acc_15 = (tot_e_pose < 15).mean()
|
108 |
+
acc_20 = (tot_e_pose < 20).mean()
|
109 |
+
map_5 = acc_5
|
110 |
+
map_10 = np.mean([acc_5, acc_10])
|
111 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
112 |
+
return {
|
113 |
+
"auc_5": auc[0],
|
114 |
+
"auc_10": auc[1],
|
115 |
+
"auc_20": auc[2],
|
116 |
+
"map_5": map_5,
|
117 |
+
"map_10": map_10,
|
118 |
+
"map_20": map_20,
|
119 |
+
}
|
third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import os
|
5 |
+
|
6 |
+
from tqdm import tqdm
|
7 |
+
from dkm.utils import pose_auc
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
class HpatchesHomogBenchmark:
|
12 |
+
"""Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
|
13 |
+
|
14 |
+
def __init__(self, dataset_path) -> None:
|
15 |
+
seqs_dir = "hpatches-sequences-release"
|
16 |
+
self.seqs_path = os.path.join(dataset_path, seqs_dir)
|
17 |
+
self.seq_names = sorted(os.listdir(self.seqs_path))
|
18 |
+
# Ignore seqs is same as LoFTR.
|
19 |
+
self.ignore_seqs = set(
|
20 |
+
[
|
21 |
+
"i_contruction",
|
22 |
+
"i_crownnight",
|
23 |
+
"i_dc",
|
24 |
+
"i_pencils",
|
25 |
+
"i_whitebuilding",
|
26 |
+
"v_artisans",
|
27 |
+
"v_astronautis",
|
28 |
+
"v_talent",
|
29 |
+
]
|
30 |
+
)
|
31 |
+
|
32 |
+
def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
|
33 |
+
offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
|
34 |
+
query_coords = (
|
35 |
+
np.stack(
|
36 |
+
(
|
37 |
+
wq * (query_coords[..., 0] + 1) / 2,
|
38 |
+
hq * (query_coords[..., 1] + 1) / 2,
|
39 |
+
),
|
40 |
+
axis=-1,
|
41 |
+
)
|
42 |
+
- offset
|
43 |
+
)
|
44 |
+
query_to_support = (
|
45 |
+
np.stack(
|
46 |
+
(
|
47 |
+
wsup * (query_to_support[..., 0] + 1) / 2,
|
48 |
+
hsup * (query_to_support[..., 1] + 1) / 2,
|
49 |
+
),
|
50 |
+
axis=-1,
|
51 |
+
)
|
52 |
+
- offset
|
53 |
+
)
|
54 |
+
return query_coords, query_to_support
|
55 |
+
|
56 |
+
def benchmark(self, model, model_name = None):
|
57 |
+
n_matches = []
|
58 |
+
homog_dists = []
|
59 |
+
for seq_idx, seq_name in tqdm(
|
60 |
+
enumerate(self.seq_names), total=len(self.seq_names)
|
61 |
+
):
|
62 |
+
if seq_name in self.ignore_seqs:
|
63 |
+
continue
|
64 |
+
im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
|
65 |
+
im1 = Image.open(im1_path)
|
66 |
+
w1, h1 = im1.size
|
67 |
+
for im_idx in range(2, 7):
|
68 |
+
im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
|
69 |
+
im2 = Image.open(im2_path)
|
70 |
+
w2, h2 = im2.size
|
71 |
+
H = np.loadtxt(
|
72 |
+
os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
|
73 |
+
)
|
74 |
+
dense_matches, dense_certainty = model.match(
|
75 |
+
im1_path, im2_path
|
76 |
+
)
|
77 |
+
good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
|
78 |
+
pos_a, pos_b = self.convert_coordinates(
|
79 |
+
good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
|
80 |
+
)
|
81 |
+
try:
|
82 |
+
H_pred, inliers = cv2.findHomography(
|
83 |
+
pos_a,
|
84 |
+
pos_b,
|
85 |
+
method = cv2.RANSAC,
|
86 |
+
confidence = 0.99999,
|
87 |
+
ransacReprojThreshold = 3 * min(w2, h2) / 480,
|
88 |
+
)
|
89 |
+
except:
|
90 |
+
H_pred = None
|
91 |
+
if H_pred is None:
|
92 |
+
H_pred = np.zeros((3, 3))
|
93 |
+
H_pred[2, 2] = 1.0
|
94 |
+
corners = np.array(
|
95 |
+
[[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
|
96 |
+
)
|
97 |
+
real_warped_corners = np.dot(corners, np.transpose(H))
|
98 |
+
real_warped_corners = (
|
99 |
+
real_warped_corners[:, :2] / real_warped_corners[:, 2:]
|
100 |
+
)
|
101 |
+
warped_corners = np.dot(corners, np.transpose(H_pred))
|
102 |
+
warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
|
103 |
+
mean_dist = np.mean(
|
104 |
+
np.linalg.norm(real_warped_corners - warped_corners, axis=1)
|
105 |
+
) / (min(w2, h2) / 480.0)
|
106 |
+
homog_dists.append(mean_dist)
|
107 |
+
n_matches = np.array(n_matches)
|
108 |
+
thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
109 |
+
auc = pose_auc(np.array(homog_dists), thresholds)
|
110 |
+
return {
|
111 |
+
"hpatches_homog_auc_3": auc[2],
|
112 |
+
"hpatches_homog_auc_5": auc[4],
|
113 |
+
"hpatches_homog_auc_10": auc[9],
|
114 |
+
}
|
third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from dkm.utils import *
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
class Megadepth1500Benchmark:
|
9 |
+
def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
|
10 |
+
if scene_names is None:
|
11 |
+
self.scene_names = [
|
12 |
+
"0015_0.1_0.3.npz",
|
13 |
+
"0015_0.3_0.5.npz",
|
14 |
+
"0022_0.1_0.3.npz",
|
15 |
+
"0022_0.3_0.5.npz",
|
16 |
+
"0022_0.5_0.7.npz",
|
17 |
+
]
|
18 |
+
else:
|
19 |
+
self.scene_names = scene_names
|
20 |
+
self.scenes = [
|
21 |
+
np.load(f"{data_root}/{scene}", allow_pickle=True)
|
22 |
+
for scene in self.scene_names
|
23 |
+
]
|
24 |
+
self.data_root = data_root
|
25 |
+
|
26 |
+
def benchmark(self, model):
|
27 |
+
with torch.no_grad():
|
28 |
+
data_root = self.data_root
|
29 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
30 |
+
for scene_ind in range(len(self.scenes)):
|
31 |
+
scene = self.scenes[scene_ind]
|
32 |
+
pairs = scene["pair_infos"]
|
33 |
+
intrinsics = scene["intrinsics"]
|
34 |
+
poses = scene["poses"]
|
35 |
+
im_paths = scene["image_paths"]
|
36 |
+
pair_inds = range(len(pairs))
|
37 |
+
for pairind in tqdm(pair_inds):
|
38 |
+
idx1, idx2 = pairs[pairind][0]
|
39 |
+
K1 = intrinsics[idx1].copy()
|
40 |
+
T1 = poses[idx1].copy()
|
41 |
+
R1, t1 = T1[:3, :3], T1[:3, 3]
|
42 |
+
K2 = intrinsics[idx2].copy()
|
43 |
+
T2 = poses[idx2].copy()
|
44 |
+
R2, t2 = T2[:3, :3], T2[:3, 3]
|
45 |
+
R, t = compute_relative_pose(R1, t1, R2, t2)
|
46 |
+
im1_path = f"{data_root}/{im_paths[idx1]}"
|
47 |
+
im2_path = f"{data_root}/{im_paths[idx2]}"
|
48 |
+
im1 = Image.open(im1_path)
|
49 |
+
w1, h1 = im1.size
|
50 |
+
im2 = Image.open(im2_path)
|
51 |
+
w2, h2 = im2.size
|
52 |
+
scale1 = 1200 / max(w1, h1)
|
53 |
+
scale2 = 1200 / max(w2, h2)
|
54 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
55 |
+
w2, h2 = scale2 * w2, scale2 * h2
|
56 |
+
K1[:2] = K1[:2] * scale1
|
57 |
+
K2[:2] = K2[:2] * scale2
|
58 |
+
dense_matches, dense_certainty = model.match(im1_path, im2_path)
|
59 |
+
sparse_matches,_ = model.sample(
|
60 |
+
dense_matches, dense_certainty, 5000
|
61 |
+
)
|
62 |
+
kpts1 = sparse_matches[:, :2]
|
63 |
+
kpts1 = (
|
64 |
+
torch.stack(
|
65 |
+
(
|
66 |
+
w1 * (kpts1[:, 0] + 1) / 2,
|
67 |
+
h1 * (kpts1[:, 1] + 1) / 2,
|
68 |
+
),
|
69 |
+
axis=-1,
|
70 |
+
)
|
71 |
+
)
|
72 |
+
kpts2 = sparse_matches[:, 2:]
|
73 |
+
kpts2 = (
|
74 |
+
torch.stack(
|
75 |
+
(
|
76 |
+
w2 * (kpts2[:, 0] + 1) / 2,
|
77 |
+
h2 * (kpts2[:, 1] + 1) / 2,
|
78 |
+
),
|
79 |
+
axis=-1,
|
80 |
+
)
|
81 |
+
)
|
82 |
+
for _ in range(5):
|
83 |
+
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
84 |
+
kpts1 = kpts1[shuffling]
|
85 |
+
kpts2 = kpts2[shuffling]
|
86 |
+
try:
|
87 |
+
norm_threshold = 0.5 / (
|
88 |
+
np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
|
89 |
+
R_est, t_est, mask = estimate_pose(
|
90 |
+
kpts1.cpu().numpy(),
|
91 |
+
kpts2.cpu().numpy(),
|
92 |
+
K1,
|
93 |
+
K2,
|
94 |
+
norm_threshold,
|
95 |
+
conf=0.99999,
|
96 |
+
)
|
97 |
+
T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
|
98 |
+
e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
|
99 |
+
e_pose = max(e_t, e_R)
|
100 |
+
except Exception as e:
|
101 |
+
print(repr(e))
|
102 |
+
e_t, e_R = 90, 90
|
103 |
+
e_pose = max(e_t, e_R)
|
104 |
+
tot_e_t.append(e_t)
|
105 |
+
tot_e_R.append(e_R)
|
106 |
+
tot_e_pose.append(e_pose)
|
107 |
+
tot_e_pose = np.array(tot_e_pose)
|
108 |
+
thresholds = [5, 10, 20]
|
109 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
110 |
+
acc_5 = (tot_e_pose < 5).mean()
|
111 |
+
acc_10 = (tot_e_pose < 10).mean()
|
112 |
+
acc_15 = (tot_e_pose < 15).mean()
|
113 |
+
acc_20 = (tot_e_pose < 20).mean()
|
114 |
+
map_5 = acc_5
|
115 |
+
map_10 = np.mean([acc_5, acc_10])
|
116 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
117 |
+
return {
|
118 |
+
"auc_5": auc[0],
|
119 |
+
"auc_10": auc[1],
|
120 |
+
"auc_20": auc[2],
|
121 |
+
"map_5": map_5,
|
122 |
+
"map_10": map_10,
|
123 |
+
"map_20": map_20,
|
124 |
+
}
|
third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import tqdm
|
4 |
+
from dkm.datasets import MegadepthBuilder
|
5 |
+
from dkm.utils import warp_kpts
|
6 |
+
from torch.utils.data import ConcatDataset
|
7 |
+
|
8 |
+
|
9 |
+
class MegadepthDenseBenchmark:
|
10 |
+
def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None:
|
11 |
+
mega = MegadepthBuilder(data_root=data_root)
|
12 |
+
self.dataset = ConcatDataset(
|
13 |
+
mega.build_scenes(split="test_loftr", ht=h, wt=w)
|
14 |
+
) # fixed resolution of 384,512
|
15 |
+
self.num_samples = num_samples
|
16 |
+
if device is None:
|
17 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
+
self.device = device
|
19 |
+
|
20 |
+
def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
|
21 |
+
b, h1, w1, d = dense_matches.shape
|
22 |
+
with torch.no_grad():
|
23 |
+
x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
|
24 |
+
# x1 = torch.stack((2*x1[...,0]/w1-1,2*x1[...,1]/h1-1),dim=-1)
|
25 |
+
mask, x2 = warp_kpts(
|
26 |
+
x1.double(),
|
27 |
+
depth1.double(),
|
28 |
+
depth2.double(),
|
29 |
+
T_1to2.double(),
|
30 |
+
K1.double(),
|
31 |
+
K2.double(),
|
32 |
+
)
|
33 |
+
x2 = torch.stack(
|
34 |
+
(w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
|
35 |
+
)
|
36 |
+
prob = mask.float().reshape(b, h1, w1)
|
37 |
+
x2_hat = dense_matches[..., 2:]
|
38 |
+
x2_hat = torch.stack(
|
39 |
+
(w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
|
40 |
+
)
|
41 |
+
gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
|
42 |
+
gd = gd[prob == 1]
|
43 |
+
pck_1 = (gd < 1.0).float().mean()
|
44 |
+
pck_3 = (gd < 3.0).float().mean()
|
45 |
+
pck_5 = (gd < 5.0).float().mean()
|
46 |
+
gd = gd.mean()
|
47 |
+
return gd, pck_1, pck_3, pck_5
|
48 |
+
|
49 |
+
def benchmark(self, model, batch_size=8):
|
50 |
+
model.train(False)
|
51 |
+
with torch.no_grad():
|
52 |
+
gd_tot = 0.0
|
53 |
+
pck_1_tot = 0.0
|
54 |
+
pck_3_tot = 0.0
|
55 |
+
pck_5_tot = 0.0
|
56 |
+
sampler = torch.utils.data.WeightedRandomSampler(
|
57 |
+
torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
|
58 |
+
)
|
59 |
+
dataloader = torch.utils.data.DataLoader(
|
60 |
+
self.dataset, batch_size=8, num_workers=batch_size, sampler=sampler
|
61 |
+
)
|
62 |
+
for data in tqdm.tqdm(dataloader):
|
63 |
+
im1, im2, depth1, depth2, T_1to2, K1, K2 = (
|
64 |
+
data["query"],
|
65 |
+
data["support"],
|
66 |
+
data["query_depth"].to(self.device),
|
67 |
+
data["support_depth"].to(self.device),
|
68 |
+
data["T_1to2"].to(self.device),
|
69 |
+
data["K1"].to(self.device),
|
70 |
+
data["K2"].to(self.device),
|
71 |
+
)
|
72 |
+
matches, certainty = model.match(im1, im2, batched=True)
|
73 |
+
gd, pck_1, pck_3, pck_5 = self.geometric_dist(
|
74 |
+
depth1, depth2, T_1to2, K1, K2, matches
|
75 |
+
)
|
76 |
+
gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
|
77 |
+
gd_tot + gd,
|
78 |
+
pck_1_tot + pck_1,
|
79 |
+
pck_3_tot + pck_3,
|
80 |
+
pck_5_tot + pck_5,
|
81 |
+
)
|
82 |
+
return {
|
83 |
+
"mega_pck_1": pck_1_tot.item() / len(dataloader),
|
84 |
+
"mega_pck_3": pck_3_tot.item() / len(dataloader),
|
85 |
+
"mega_pck_5": pck_5_tot.item() / len(dataloader),
|
86 |
+
}
|
third_party/DKM/dkm/benchmarks/scannet_benchmark.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from dkm.utils import *
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
class ScanNetBenchmark:
|
10 |
+
def __init__(self, data_root="data/scannet") -> None:
|
11 |
+
self.data_root = data_root
|
12 |
+
|
13 |
+
def benchmark(self, model, model_name = None):
|
14 |
+
model.train(False)
|
15 |
+
with torch.no_grad():
|
16 |
+
data_root = self.data_root
|
17 |
+
tmp = np.load(osp.join(data_root, "test.npz"))
|
18 |
+
pairs, rel_pose = tmp["name"], tmp["rel_pose"]
|
19 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
20 |
+
pair_inds = np.random.choice(
|
21 |
+
range(len(pairs)), size=len(pairs), replace=False
|
22 |
+
)
|
23 |
+
for pairind in tqdm(pair_inds, smoothing=0.9):
|
24 |
+
scene = pairs[pairind]
|
25 |
+
scene_name = f"scene0{scene[0]}_00"
|
26 |
+
im1_path = osp.join(
|
27 |
+
self.data_root,
|
28 |
+
"scans_test",
|
29 |
+
scene_name,
|
30 |
+
"color",
|
31 |
+
f"{scene[2]}.jpg",
|
32 |
+
)
|
33 |
+
im1 = Image.open(im1_path)
|
34 |
+
im2_path = osp.join(
|
35 |
+
self.data_root,
|
36 |
+
"scans_test",
|
37 |
+
scene_name,
|
38 |
+
"color",
|
39 |
+
f"{scene[3]}.jpg",
|
40 |
+
)
|
41 |
+
im2 = Image.open(im2_path)
|
42 |
+
T_gt = rel_pose[pairind].reshape(3, 4)
|
43 |
+
R, t = T_gt[:3, :3], T_gt[:3, 3]
|
44 |
+
K = np.stack(
|
45 |
+
[
|
46 |
+
np.array([float(i) for i in r.split()])
|
47 |
+
for r in open(
|
48 |
+
osp.join(
|
49 |
+
self.data_root,
|
50 |
+
"scans_test",
|
51 |
+
scene_name,
|
52 |
+
"intrinsic",
|
53 |
+
"intrinsic_color.txt",
|
54 |
+
),
|
55 |
+
"r",
|
56 |
+
)
|
57 |
+
.read()
|
58 |
+
.split("\n")
|
59 |
+
if r
|
60 |
+
]
|
61 |
+
)
|
62 |
+
w1, h1 = im1.size
|
63 |
+
w2, h2 = im2.size
|
64 |
+
K1 = K.copy()
|
65 |
+
K2 = K.copy()
|
66 |
+
dense_matches, dense_certainty = model.match(im1_path, im2_path)
|
67 |
+
sparse_matches, sparse_certainty = model.sample(
|
68 |
+
dense_matches, dense_certainty, 5000
|
69 |
+
)
|
70 |
+
scale1 = 480 / min(w1, h1)
|
71 |
+
scale2 = 480 / min(w2, h2)
|
72 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
73 |
+
w2, h2 = scale2 * w2, scale2 * h2
|
74 |
+
K1 = K1 * scale1
|
75 |
+
K2 = K2 * scale2
|
76 |
+
|
77 |
+
offset = 0.5
|
78 |
+
kpts1 = sparse_matches[:, :2]
|
79 |
+
kpts1 = (
|
80 |
+
np.stack(
|
81 |
+
(
|
82 |
+
w1 * (kpts1[:, 0] + 1) / 2 - offset,
|
83 |
+
h1 * (kpts1[:, 1] + 1) / 2 - offset,
|
84 |
+
),
|
85 |
+
axis=-1,
|
86 |
+
)
|
87 |
+
)
|
88 |
+
kpts2 = sparse_matches[:, 2:]
|
89 |
+
kpts2 = (
|
90 |
+
np.stack(
|
91 |
+
(
|
92 |
+
w2 * (kpts2[:, 0] + 1) / 2 - offset,
|
93 |
+
h2 * (kpts2[:, 1] + 1) / 2 - offset,
|
94 |
+
),
|
95 |
+
axis=-1,
|
96 |
+
)
|
97 |
+
)
|
98 |
+
for _ in range(5):
|
99 |
+
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
100 |
+
kpts1 = kpts1[shuffling]
|
101 |
+
kpts2 = kpts2[shuffling]
|
102 |
+
try:
|
103 |
+
norm_threshold = 0.5 / (
|
104 |
+
np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
|
105 |
+
R_est, t_est, mask = estimate_pose(
|
106 |
+
kpts1,
|
107 |
+
kpts2,
|
108 |
+
K1,
|
109 |
+
K2,
|
110 |
+
norm_threshold,
|
111 |
+
conf=0.99999,
|
112 |
+
)
|
113 |
+
T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
|
114 |
+
e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
|
115 |
+
e_pose = max(e_t, e_R)
|
116 |
+
except Exception as e:
|
117 |
+
print(repr(e))
|
118 |
+
e_t, e_R = 90, 90
|
119 |
+
e_pose = max(e_t, e_R)
|
120 |
+
tot_e_t.append(e_t)
|
121 |
+
tot_e_R.append(e_R)
|
122 |
+
tot_e_pose.append(e_pose)
|
123 |
+
tot_e_t.append(e_t)
|
124 |
+
tot_e_R.append(e_R)
|
125 |
+
tot_e_pose.append(e_pose)
|
126 |
+
tot_e_pose = np.array(tot_e_pose)
|
127 |
+
thresholds = [5, 10, 20]
|
128 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
129 |
+
acc_5 = (tot_e_pose < 5).mean()
|
130 |
+
acc_10 = (tot_e_pose < 10).mean()
|
131 |
+
acc_15 = (tot_e_pose < 15).mean()
|
132 |
+
acc_20 = (tot_e_pose < 20).mean()
|
133 |
+
map_5 = acc_5
|
134 |
+
map_10 = np.mean([acc_5, acc_10])
|
135 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
136 |
+
return {
|
137 |
+
"auc_5": auc[0],
|
138 |
+
"auc_10": auc[1],
|
139 |
+
"auc_20": auc[2],
|
140 |
+
"map_5": map_5,
|
141 |
+
"map_10": map_10,
|
142 |
+
"map_20": map_20,
|
143 |
+
}
|
third_party/DKM/dkm/checkpointing/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .checkpoint import CheckPoint
|
third_party/DKM/dkm/checkpointing/checkpoint.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
4 |
+
from torch.nn.parallel.distributed import DistributedDataParallel
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
|
8 |
+
class CheckPoint:
|
9 |
+
def __init__(self, dir=None, name="tmp"):
|
10 |
+
self.name = name
|
11 |
+
self.dir = dir
|
12 |
+
os.makedirs(self.dir, exist_ok=True)
|
13 |
+
|
14 |
+
def __call__(
|
15 |
+
self,
|
16 |
+
model,
|
17 |
+
optimizer,
|
18 |
+
lr_scheduler,
|
19 |
+
n,
|
20 |
+
):
|
21 |
+
assert model is not None
|
22 |
+
if isinstance(model, (DataParallel, DistributedDataParallel)):
|
23 |
+
model = model.module
|
24 |
+
states = {
|
25 |
+
"model": model.state_dict(),
|
26 |
+
"n": n,
|
27 |
+
"optimizer": optimizer.state_dict(),
|
28 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
29 |
+
}
|
30 |
+
torch.save(states, self.dir + self.name + f"_latest.pth")
|
31 |
+
logger.info(f"Saved states {list(states.keys())}, at step {n}")
|
third_party/DKM/dkm/datasets/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .megadepth import MegadepthBuilder
|
third_party/DKM/dkm/datasets/megadepth.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from PIL import Image
|
4 |
+
import h5py
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset, DataLoader, ConcatDataset
|
8 |
+
|
9 |
+
from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
|
10 |
+
import torchvision.transforms.functional as tvf
|
11 |
+
from dkm.utils.transforms import GeometricSequential
|
12 |
+
import kornia.augmentation as K
|
13 |
+
|
14 |
+
|
15 |
+
class MegadepthScene:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
data_root,
|
19 |
+
scene_info,
|
20 |
+
ht=384,
|
21 |
+
wt=512,
|
22 |
+
min_overlap=0.0,
|
23 |
+
shake_t=0,
|
24 |
+
rot_prob=0.0,
|
25 |
+
normalize=True,
|
26 |
+
) -> None:
|
27 |
+
self.data_root = data_root
|
28 |
+
self.image_paths = scene_info["image_paths"]
|
29 |
+
self.depth_paths = scene_info["depth_paths"]
|
30 |
+
self.intrinsics = scene_info["intrinsics"]
|
31 |
+
self.poses = scene_info["poses"]
|
32 |
+
self.pairs = scene_info["pairs"]
|
33 |
+
self.overlaps = scene_info["overlaps"]
|
34 |
+
threshold = self.overlaps > min_overlap
|
35 |
+
self.pairs = self.pairs[threshold]
|
36 |
+
self.overlaps = self.overlaps[threshold]
|
37 |
+
if len(self.pairs) > 100000:
|
38 |
+
pairinds = np.random.choice(
|
39 |
+
np.arange(0, len(self.pairs)), 100000, replace=False
|
40 |
+
)
|
41 |
+
self.pairs = self.pairs[pairinds]
|
42 |
+
self.overlaps = self.overlaps[pairinds]
|
43 |
+
# counts, bins = np.histogram(self.overlaps,20)
|
44 |
+
# print(counts)
|
45 |
+
self.im_transform_ops = get_tuple_transform_ops(
|
46 |
+
resize=(ht, wt), normalize=normalize
|
47 |
+
)
|
48 |
+
self.depth_transform_ops = get_depth_tuple_transform_ops(
|
49 |
+
resize=(ht, wt), normalize=False
|
50 |
+
)
|
51 |
+
self.wt, self.ht = wt, ht
|
52 |
+
self.shake_t = shake_t
|
53 |
+
self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
|
54 |
+
|
55 |
+
def load_im(self, im_ref, crop=None):
|
56 |
+
im = Image.open(im_ref)
|
57 |
+
return im
|
58 |
+
|
59 |
+
def load_depth(self, depth_ref, crop=None):
|
60 |
+
depth = np.array(h5py.File(depth_ref, "r")["depth"])
|
61 |
+
return torch.from_numpy(depth)
|
62 |
+
|
63 |
+
def __len__(self):
|
64 |
+
return len(self.pairs)
|
65 |
+
|
66 |
+
def scale_intrinsic(self, K, wi, hi):
|
67 |
+
sx, sy = self.wt / wi, self.ht / hi
|
68 |
+
sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
|
69 |
+
return sK @ K
|
70 |
+
|
71 |
+
def rand_shake(self, *things):
|
72 |
+
t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
|
73 |
+
return [
|
74 |
+
tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
|
75 |
+
for thing in things
|
76 |
+
], t
|
77 |
+
|
78 |
+
def __getitem__(self, pair_idx):
|
79 |
+
# read intrinsics of original size
|
80 |
+
idx1, idx2 = self.pairs[pair_idx]
|
81 |
+
K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
|
82 |
+
K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
|
83 |
+
|
84 |
+
# read and compute relative poses
|
85 |
+
T1 = self.poses[idx1]
|
86 |
+
T2 = self.poses[idx2]
|
87 |
+
T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
|
88 |
+
:4, :4
|
89 |
+
] # (4, 4)
|
90 |
+
|
91 |
+
# Load positive pair data
|
92 |
+
im1, im2 = self.image_paths[idx1], self.image_paths[idx2]
|
93 |
+
depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
|
94 |
+
im_src_ref = os.path.join(self.data_root, im1)
|
95 |
+
im_pos_ref = os.path.join(self.data_root, im2)
|
96 |
+
depth_src_ref = os.path.join(self.data_root, depth1)
|
97 |
+
depth_pos_ref = os.path.join(self.data_root, depth2)
|
98 |
+
# return torch.randn((1000,1000))
|
99 |
+
im_src = self.load_im(im_src_ref)
|
100 |
+
im_pos = self.load_im(im_pos_ref)
|
101 |
+
depth_src = self.load_depth(depth_src_ref)
|
102 |
+
depth_pos = self.load_depth(depth_pos_ref)
|
103 |
+
|
104 |
+
# Recompute camera intrinsic matrix due to the resize
|
105 |
+
K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
|
106 |
+
K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
|
107 |
+
# Process images
|
108 |
+
im_src, im_pos = self.im_transform_ops((im_src, im_pos))
|
109 |
+
depth_src, depth_pos = self.depth_transform_ops(
|
110 |
+
(depth_src[None, None], depth_pos[None, None])
|
111 |
+
)
|
112 |
+
[im_src, im_pos, depth_src, depth_pos], t = self.rand_shake(
|
113 |
+
im_src, im_pos, depth_src, depth_pos
|
114 |
+
)
|
115 |
+
im_src, Hq = self.H_generator(im_src[None])
|
116 |
+
depth_src = self.H_generator.apply_transform(depth_src, Hq)
|
117 |
+
K1[:2, 2] += t
|
118 |
+
K2[:2, 2] += t
|
119 |
+
K1 = Hq[0] @ K1
|
120 |
+
data_dict = {
|
121 |
+
"query": im_src[0],
|
122 |
+
"query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
|
123 |
+
"support": im_pos,
|
124 |
+
"support_identifier": self.image_paths[idx2]
|
125 |
+
.split("/")[-1]
|
126 |
+
.split(".jpg")[0],
|
127 |
+
"query_depth": depth_src[0, 0],
|
128 |
+
"support_depth": depth_pos[0, 0],
|
129 |
+
"K1": K1,
|
130 |
+
"K2": K2,
|
131 |
+
"T_1to2": T_1to2,
|
132 |
+
}
|
133 |
+
return data_dict
|
134 |
+
|
135 |
+
|
136 |
+
class MegadepthBuilder:
|
137 |
+
def __init__(self, data_root="data/megadepth") -> None:
|
138 |
+
self.data_root = data_root
|
139 |
+
self.scene_info_root = os.path.join(data_root, "prep_scene_info")
|
140 |
+
self.all_scenes = os.listdir(self.scene_info_root)
|
141 |
+
self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
|
142 |
+
self.test_scenes_loftr = ["0015.npy", "0022.npy"]
|
143 |
+
|
144 |
+
def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
|
145 |
+
if split == "train":
|
146 |
+
scene_names = set(self.all_scenes) - set(self.test_scenes)
|
147 |
+
elif split == "train_loftr":
|
148 |
+
scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
|
149 |
+
elif split == "test":
|
150 |
+
scene_names = self.test_scenes
|
151 |
+
elif split == "test_loftr":
|
152 |
+
scene_names = self.test_scenes_loftr
|
153 |
+
else:
|
154 |
+
raise ValueError(f"Split {split} not available")
|
155 |
+
scenes = []
|
156 |
+
for scene_name in scene_names:
|
157 |
+
scene_info = np.load(
|
158 |
+
os.path.join(self.scene_info_root, scene_name), allow_pickle=True
|
159 |
+
).item()
|
160 |
+
scenes.append(
|
161 |
+
MegadepthScene(
|
162 |
+
self.data_root, scene_info, min_overlap=min_overlap, **kwargs
|
163 |
+
)
|
164 |
+
)
|
165 |
+
return scenes
|
166 |
+
|
167 |
+
def weight_scenes(self, concat_dataset, alpha=0.5):
|
168 |
+
ns = []
|
169 |
+
for d in concat_dataset.datasets:
|
170 |
+
ns.append(len(d))
|
171 |
+
ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
|
172 |
+
return ws
|
173 |
+
|
174 |
+
|
175 |
+
if __name__ == "__main__":
|
176 |
+
mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train"))
|
177 |
+
mega_test[0]
|
third_party/DKM/dkm/datasets/scannet.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from PIL import Image
|
4 |
+
import cv2
|
5 |
+
import h5py
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import (
|
9 |
+
Dataset,
|
10 |
+
DataLoader,
|
11 |
+
ConcatDataset)
|
12 |
+
|
13 |
+
import torchvision.transforms.functional as tvf
|
14 |
+
import kornia.augmentation as K
|
15 |
+
import os.path as osp
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
|
18 |
+
from dkm.utils.transforms import GeometricSequential
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
class ScanNetScene:
|
23 |
+
def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None:
|
24 |
+
self.scene_root = osp.join(data_root,"scans","scans_train")
|
25 |
+
self.data_names = scene_info['name']
|
26 |
+
self.overlaps = scene_info['score']
|
27 |
+
# Only sample 10s
|
28 |
+
valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
|
29 |
+
self.overlaps = self.overlaps[valid]
|
30 |
+
self.data_names = self.data_names[valid]
|
31 |
+
if len(self.data_names) > 10000:
|
32 |
+
pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
|
33 |
+
self.data_names = self.data_names[pairinds]
|
34 |
+
self.overlaps = self.overlaps[pairinds]
|
35 |
+
self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
|
36 |
+
self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
|
37 |
+
self.wt, self.ht = wt, ht
|
38 |
+
self.shake_t = shake_t
|
39 |
+
self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
|
40 |
+
|
41 |
+
def load_im(self, im_ref, crop=None):
|
42 |
+
im = Image.open(im_ref)
|
43 |
+
return im
|
44 |
+
|
45 |
+
def load_depth(self, depth_ref, crop=None):
|
46 |
+
depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
|
47 |
+
depth = depth / 1000
|
48 |
+
depth = torch.from_numpy(depth).float() # (h, w)
|
49 |
+
return depth
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return len(self.data_names)
|
53 |
+
|
54 |
+
def scale_intrinsic(self, K, wi, hi):
|
55 |
+
sx, sy = self.wt / wi, self.ht / hi
|
56 |
+
sK = torch.tensor([[sx, 0, 0],
|
57 |
+
[0, sy, 0],
|
58 |
+
[0, 0, 1]])
|
59 |
+
return sK@K
|
60 |
+
|
61 |
+
def read_scannet_pose(self,path):
|
62 |
+
""" Read ScanNet's Camera2World pose and transform it to World2Camera.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
pose_w2c (np.ndarray): (4, 4)
|
66 |
+
"""
|
67 |
+
cam2world = np.loadtxt(path, delimiter=' ')
|
68 |
+
world2cam = np.linalg.inv(cam2world)
|
69 |
+
return world2cam
|
70 |
+
|
71 |
+
|
72 |
+
def read_scannet_intrinsic(self,path):
|
73 |
+
""" Read ScanNet's intrinsic matrix and return the 3x3 matrix.
|
74 |
+
"""
|
75 |
+
intrinsic = np.loadtxt(path, delimiter=' ')
|
76 |
+
return intrinsic[:-1, :-1]
|
77 |
+
|
78 |
+
def __getitem__(self, pair_idx):
|
79 |
+
# read intrinsics of original size
|
80 |
+
data_name = self.data_names[pair_idx]
|
81 |
+
scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
|
82 |
+
scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
|
83 |
+
|
84 |
+
# read the intrinsic of depthmap
|
85 |
+
K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
|
86 |
+
scene_name,
|
87 |
+
'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
|
88 |
+
# read and compute relative poses
|
89 |
+
T1 = self.read_scannet_pose(osp.join(self.scene_root,
|
90 |
+
scene_name,
|
91 |
+
'pose', f'{stem_name_1}.txt'))
|
92 |
+
T2 = self.read_scannet_pose(osp.join(self.scene_root,
|
93 |
+
scene_name,
|
94 |
+
'pose', f'{stem_name_2}.txt'))
|
95 |
+
T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
|
96 |
+
|
97 |
+
# Load positive pair data
|
98 |
+
im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
|
99 |
+
im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
|
100 |
+
depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
|
101 |
+
depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
|
102 |
+
|
103 |
+
im_src = self.load_im(im_src_ref)
|
104 |
+
im_pos = self.load_im(im_pos_ref)
|
105 |
+
depth_src = self.load_depth(depth_src_ref)
|
106 |
+
depth_pos = self.load_depth(depth_pos_ref)
|
107 |
+
|
108 |
+
# Recompute camera intrinsic matrix due to the resize
|
109 |
+
K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
|
110 |
+
K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
|
111 |
+
# Process images
|
112 |
+
im_src, im_pos = self.im_transform_ops((im_src, im_pos))
|
113 |
+
depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None]))
|
114 |
+
|
115 |
+
data_dict = {'query': im_src,
|
116 |
+
'support': im_pos,
|
117 |
+
'query_depth': depth_src[0,0],
|
118 |
+
'support_depth': depth_pos[0,0],
|
119 |
+
'K1': K1,
|
120 |
+
'K2': K2,
|
121 |
+
'T_1to2':T_1to2,
|
122 |
+
}
|
123 |
+
return data_dict
|
124 |
+
|
125 |
+
|
126 |
+
class ScanNetBuilder:
|
127 |
+
def __init__(self, data_root = 'data/scannet') -> None:
|
128 |
+
self.data_root = data_root
|
129 |
+
self.scene_info_root = os.path.join(data_root,'scannet_indices')
|
130 |
+
self.all_scenes = os.listdir(self.scene_info_root)
|
131 |
+
|
132 |
+
def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
|
133 |
+
# Note: split doesn't matter here as we always use same scannet_train scenes
|
134 |
+
scene_names = self.all_scenes
|
135 |
+
scenes = []
|
136 |
+
for scene_name in tqdm(scene_names):
|
137 |
+
scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
|
138 |
+
scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
|
139 |
+
return scenes
|
140 |
+
|
141 |
+
def weight_scenes(self, concat_dataset, alpha=.5):
|
142 |
+
ns = []
|
143 |
+
for d in concat_dataset.datasets:
|
144 |
+
ns.append(len(d))
|
145 |
+
ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
|
146 |
+
return ws
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train'))
|
151 |
+
mega_test[0]
|
third_party/DKM/dkm/losses/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .depth_match_regression_loss import DepthRegressionLoss
|
third_party/DKM/dkm/losses/depth_match_regression_loss.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from einops.einops import rearrange
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from dkm.utils.utils import warp_kpts
|
6 |
+
|
7 |
+
|
8 |
+
class DepthRegressionLoss(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
robust=True,
|
12 |
+
center_coords=False,
|
13 |
+
scale_normalize=False,
|
14 |
+
ce_weight=0.01,
|
15 |
+
local_loss=True,
|
16 |
+
local_dist=4.0,
|
17 |
+
local_largest_scale=8,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.robust = robust # measured in pixels
|
21 |
+
self.center_coords = center_coords
|
22 |
+
self.scale_normalize = scale_normalize
|
23 |
+
self.ce_weight = ce_weight
|
24 |
+
self.local_loss = local_loss
|
25 |
+
self.local_dist = local_dist
|
26 |
+
self.local_largest_scale = local_largest_scale
|
27 |
+
|
28 |
+
def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
|
29 |
+
"""[summary]
|
30 |
+
|
31 |
+
Args:
|
32 |
+
H ([type]): [description]
|
33 |
+
scale ([type]): [description]
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
[type]: [description]
|
37 |
+
"""
|
38 |
+
b, h1, w1, d = dense_matches.shape
|
39 |
+
with torch.no_grad():
|
40 |
+
x1_n = torch.meshgrid(
|
41 |
+
*[
|
42 |
+
torch.linspace(
|
43 |
+
-1 + 1 / n, 1 - 1 / n, n, device=dense_matches.device
|
44 |
+
)
|
45 |
+
for n in (b, h1, w1)
|
46 |
+
]
|
47 |
+
)
|
48 |
+
x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(b, h1 * w1, 2)
|
49 |
+
mask, x2 = warp_kpts(
|
50 |
+
x1_n.double(),
|
51 |
+
depth1.double(),
|
52 |
+
depth2.double(),
|
53 |
+
T_1to2.double(),
|
54 |
+
K1.double(),
|
55 |
+
K2.double(),
|
56 |
+
)
|
57 |
+
prob = mask.float().reshape(b, h1, w1)
|
58 |
+
gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1) # *scale?
|
59 |
+
return gd, prob
|
60 |
+
|
61 |
+
def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
|
62 |
+
"""[summary]
|
63 |
+
|
64 |
+
Args:
|
65 |
+
dense_certainty ([type]): [description]
|
66 |
+
prob ([type]): [description]
|
67 |
+
eps ([type], optional): [description]. Defaults to 1e-8.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
[type]: [description]
|
71 |
+
"""
|
72 |
+
smooth_prob = prob
|
73 |
+
ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], smooth_prob)
|
74 |
+
depth_loss = gd[prob > 0]
|
75 |
+
if not torch.any(prob > 0).item():
|
76 |
+
depth_loss = (gd * 0.0).mean() # Prevent issues where prob is 0 everywhere
|
77 |
+
return {
|
78 |
+
f"ce_loss_{scale}": ce_loss.mean(),
|
79 |
+
f"depth_loss_{scale}": depth_loss.mean(),
|
80 |
+
}
|
81 |
+
|
82 |
+
def forward(self, dense_corresps, batch):
|
83 |
+
"""[summary]
|
84 |
+
|
85 |
+
Args:
|
86 |
+
out ([type]): [description]
|
87 |
+
batch ([type]): [description]
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
[type]: [description]
|
91 |
+
"""
|
92 |
+
scales = list(dense_corresps.keys())
|
93 |
+
tot_loss = 0.0
|
94 |
+
prev_gd = 0.0
|
95 |
+
for scale in scales:
|
96 |
+
dense_scale_corresps = dense_corresps[scale]
|
97 |
+
dense_scale_certainty, dense_scale_coords = (
|
98 |
+
dense_scale_corresps["dense_certainty"],
|
99 |
+
dense_scale_corresps["dense_flow"],
|
100 |
+
)
|
101 |
+
dense_scale_coords = rearrange(dense_scale_coords, "b d h w -> b h w d")
|
102 |
+
b, h, w, d = dense_scale_coords.shape
|
103 |
+
gd, prob = self.geometric_dist(
|
104 |
+
batch["query_depth"],
|
105 |
+
batch["support_depth"],
|
106 |
+
batch["T_1to2"],
|
107 |
+
batch["K1"],
|
108 |
+
batch["K2"],
|
109 |
+
dense_scale_coords,
|
110 |
+
scale,
|
111 |
+
)
|
112 |
+
if (
|
113 |
+
scale <= self.local_largest_scale and self.local_loss
|
114 |
+
): # Thought here is that fine matching loss should not be punished by coarse mistakes, but should identify wrong matching
|
115 |
+
prob = prob * (
|
116 |
+
F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[:, 0]
|
117 |
+
< (2 / 512) * (self.local_dist * scale)
|
118 |
+
)
|
119 |
+
depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
|
120 |
+
scale_loss = (
|
121 |
+
self.ce_weight * depth_losses[f"ce_loss_{scale}"]
|
122 |
+
+ depth_losses[f"depth_loss_{scale}"]
|
123 |
+
) # scale ce loss for coarser scales
|
124 |
+
if self.scale_normalize:
|
125 |
+
scale_loss = scale_loss * 1 / scale
|
126 |
+
tot_loss = tot_loss + scale_loss
|
127 |
+
prev_gd = gd.detach()
|
128 |
+
return tot_loss
|
third_party/DKM/dkm/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model_zoo import (
|
2 |
+
DKMv3_outdoor,
|
3 |
+
DKMv3_indoor,
|
4 |
+
)
|
third_party/DKM/dkm/models/deprecated/build_model.py
ADDED
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from dkm import *
|
4 |
+
from .local_corr import LocalCorr
|
5 |
+
from .corr_channels import NormedCorr
|
6 |
+
from torchvision.models import resnet as tv_resnet
|
7 |
+
|
8 |
+
dkm_pretrained_urls = {
|
9 |
+
"DKM": {
|
10 |
+
"mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth",
|
11 |
+
"mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth",
|
12 |
+
},
|
13 |
+
"DKMv2":{
|
14 |
+
"outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth",
|
15 |
+
"indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth",
|
16 |
+
}
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
def DKM(pretrained=True, version="mega_synthetic", device=None):
|
21 |
+
if device is None:
|
22 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
23 |
+
gp_dim = 256
|
24 |
+
dfn_dim = 384
|
25 |
+
feat_dim = 256
|
26 |
+
coordinate_decoder = DFN(
|
27 |
+
internal_dim=dfn_dim,
|
28 |
+
feat_input_modules=nn.ModuleDict(
|
29 |
+
{
|
30 |
+
"32": nn.Conv2d(512, feat_dim, 1, 1),
|
31 |
+
"16": nn.Conv2d(512, feat_dim, 1, 1),
|
32 |
+
}
|
33 |
+
),
|
34 |
+
pred_input_modules=nn.ModuleDict(
|
35 |
+
{
|
36 |
+
"32": nn.Identity(),
|
37 |
+
"16": nn.Identity(),
|
38 |
+
}
|
39 |
+
),
|
40 |
+
rrb_d_dict=nn.ModuleDict(
|
41 |
+
{
|
42 |
+
"32": RRB(gp_dim + feat_dim, dfn_dim),
|
43 |
+
"16": RRB(gp_dim + feat_dim, dfn_dim),
|
44 |
+
}
|
45 |
+
),
|
46 |
+
cab_dict=nn.ModuleDict(
|
47 |
+
{
|
48 |
+
"32": CAB(2 * dfn_dim, dfn_dim),
|
49 |
+
"16": CAB(2 * dfn_dim, dfn_dim),
|
50 |
+
}
|
51 |
+
),
|
52 |
+
rrb_u_dict=nn.ModuleDict(
|
53 |
+
{
|
54 |
+
"32": RRB(dfn_dim, dfn_dim),
|
55 |
+
"16": RRB(dfn_dim, dfn_dim),
|
56 |
+
}
|
57 |
+
),
|
58 |
+
terminal_module=nn.ModuleDict(
|
59 |
+
{
|
60 |
+
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
61 |
+
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
62 |
+
}
|
63 |
+
),
|
64 |
+
)
|
65 |
+
dw = True
|
66 |
+
hidden_blocks = 8
|
67 |
+
kernel_size = 5
|
68 |
+
conv_refiner = nn.ModuleDict(
|
69 |
+
{
|
70 |
+
"16": ConvRefiner(
|
71 |
+
2 * 512,
|
72 |
+
1024,
|
73 |
+
3,
|
74 |
+
kernel_size=kernel_size,
|
75 |
+
dw=dw,
|
76 |
+
hidden_blocks=hidden_blocks,
|
77 |
+
),
|
78 |
+
"8": ConvRefiner(
|
79 |
+
2 * 512,
|
80 |
+
1024,
|
81 |
+
3,
|
82 |
+
kernel_size=kernel_size,
|
83 |
+
dw=dw,
|
84 |
+
hidden_blocks=hidden_blocks,
|
85 |
+
),
|
86 |
+
"4": ConvRefiner(
|
87 |
+
2 * 256,
|
88 |
+
512,
|
89 |
+
3,
|
90 |
+
kernel_size=kernel_size,
|
91 |
+
dw=dw,
|
92 |
+
hidden_blocks=hidden_blocks,
|
93 |
+
),
|
94 |
+
"2": ConvRefiner(
|
95 |
+
2 * 64,
|
96 |
+
128,
|
97 |
+
3,
|
98 |
+
kernel_size=kernel_size,
|
99 |
+
dw=dw,
|
100 |
+
hidden_blocks=hidden_blocks,
|
101 |
+
),
|
102 |
+
"1": ConvRefiner(
|
103 |
+
2 * 3,
|
104 |
+
24,
|
105 |
+
3,
|
106 |
+
kernel_size=kernel_size,
|
107 |
+
dw=dw,
|
108 |
+
hidden_blocks=hidden_blocks,
|
109 |
+
),
|
110 |
+
}
|
111 |
+
)
|
112 |
+
kernel_temperature = 0.2
|
113 |
+
learn_temperature = False
|
114 |
+
no_cov = True
|
115 |
+
kernel = CosKernel
|
116 |
+
only_attention = False
|
117 |
+
basis = "fourier"
|
118 |
+
gp32 = GP(
|
119 |
+
kernel,
|
120 |
+
T=kernel_temperature,
|
121 |
+
learn_temperature=learn_temperature,
|
122 |
+
only_attention=only_attention,
|
123 |
+
gp_dim=gp_dim,
|
124 |
+
basis=basis,
|
125 |
+
no_cov=no_cov,
|
126 |
+
)
|
127 |
+
gp16 = GP(
|
128 |
+
kernel,
|
129 |
+
T=kernel_temperature,
|
130 |
+
learn_temperature=learn_temperature,
|
131 |
+
only_attention=only_attention,
|
132 |
+
gp_dim=gp_dim,
|
133 |
+
basis=basis,
|
134 |
+
no_cov=no_cov,
|
135 |
+
)
|
136 |
+
gps = nn.ModuleDict({"32": gp32, "16": gp16})
|
137 |
+
proj = nn.ModuleDict(
|
138 |
+
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
|
139 |
+
)
|
140 |
+
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
|
141 |
+
h, w = 384, 512
|
142 |
+
encoder = Encoder(
|
143 |
+
tv_resnet.resnet50(pretrained=not pretrained),
|
144 |
+
) # only load pretrained weights if not loading a pretrained matcher ;)
|
145 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
|
146 |
+
if pretrained:
|
147 |
+
weights = torch.hub.load_state_dict_from_url(
|
148 |
+
dkm_pretrained_urls["DKM"][version]
|
149 |
+
)
|
150 |
+
matcher.load_state_dict(weights)
|
151 |
+
return matcher
|
152 |
+
|
153 |
+
def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
|
154 |
+
gp_dim = 256
|
155 |
+
dfn_dim = 384
|
156 |
+
feat_dim = 256
|
157 |
+
coordinate_decoder = DFN(
|
158 |
+
internal_dim=dfn_dim,
|
159 |
+
feat_input_modules=nn.ModuleDict(
|
160 |
+
{
|
161 |
+
"32": nn.Conv2d(512, feat_dim, 1, 1),
|
162 |
+
"16": nn.Conv2d(512, feat_dim, 1, 1),
|
163 |
+
}
|
164 |
+
),
|
165 |
+
pred_input_modules=nn.ModuleDict(
|
166 |
+
{
|
167 |
+
"32": nn.Identity(),
|
168 |
+
"16": nn.Identity(),
|
169 |
+
}
|
170 |
+
),
|
171 |
+
rrb_d_dict=nn.ModuleDict(
|
172 |
+
{
|
173 |
+
"32": RRB(gp_dim + feat_dim, dfn_dim),
|
174 |
+
"16": RRB(gp_dim + feat_dim, dfn_dim),
|
175 |
+
}
|
176 |
+
),
|
177 |
+
cab_dict=nn.ModuleDict(
|
178 |
+
{
|
179 |
+
"32": CAB(2 * dfn_dim, dfn_dim),
|
180 |
+
"16": CAB(2 * dfn_dim, dfn_dim),
|
181 |
+
}
|
182 |
+
),
|
183 |
+
rrb_u_dict=nn.ModuleDict(
|
184 |
+
{
|
185 |
+
"32": RRB(dfn_dim, dfn_dim),
|
186 |
+
"16": RRB(dfn_dim, dfn_dim),
|
187 |
+
}
|
188 |
+
),
|
189 |
+
terminal_module=nn.ModuleDict(
|
190 |
+
{
|
191 |
+
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
192 |
+
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
193 |
+
}
|
194 |
+
),
|
195 |
+
)
|
196 |
+
dw = True
|
197 |
+
hidden_blocks = 8
|
198 |
+
kernel_size = 5
|
199 |
+
displacement_emb = "linear"
|
200 |
+
conv_refiner = nn.ModuleDict(
|
201 |
+
{
|
202 |
+
"16": ConvRefiner(
|
203 |
+
2 * 512+128,
|
204 |
+
1024+128,
|
205 |
+
3,
|
206 |
+
kernel_size=kernel_size,
|
207 |
+
dw=dw,
|
208 |
+
hidden_blocks=hidden_blocks,
|
209 |
+
displacement_emb=displacement_emb,
|
210 |
+
displacement_emb_dim=128,
|
211 |
+
),
|
212 |
+
"8": ConvRefiner(
|
213 |
+
2 * 512+64,
|
214 |
+
1024+64,
|
215 |
+
3,
|
216 |
+
kernel_size=kernel_size,
|
217 |
+
dw=dw,
|
218 |
+
hidden_blocks=hidden_blocks,
|
219 |
+
displacement_emb=displacement_emb,
|
220 |
+
displacement_emb_dim=64,
|
221 |
+
),
|
222 |
+
"4": ConvRefiner(
|
223 |
+
2 * 256+32,
|
224 |
+
512+32,
|
225 |
+
3,
|
226 |
+
kernel_size=kernel_size,
|
227 |
+
dw=dw,
|
228 |
+
hidden_blocks=hidden_blocks,
|
229 |
+
displacement_emb=displacement_emb,
|
230 |
+
displacement_emb_dim=32,
|
231 |
+
),
|
232 |
+
"2": ConvRefiner(
|
233 |
+
2 * 64+16,
|
234 |
+
128+16,
|
235 |
+
3,
|
236 |
+
kernel_size=kernel_size,
|
237 |
+
dw=dw,
|
238 |
+
hidden_blocks=hidden_blocks,
|
239 |
+
displacement_emb=displacement_emb,
|
240 |
+
displacement_emb_dim=16,
|
241 |
+
),
|
242 |
+
"1": ConvRefiner(
|
243 |
+
2 * 3+6,
|
244 |
+
24,
|
245 |
+
3,
|
246 |
+
kernel_size=kernel_size,
|
247 |
+
dw=dw,
|
248 |
+
hidden_blocks=hidden_blocks,
|
249 |
+
displacement_emb=displacement_emb,
|
250 |
+
displacement_emb_dim=6,
|
251 |
+
),
|
252 |
+
}
|
253 |
+
)
|
254 |
+
kernel_temperature = 0.2
|
255 |
+
learn_temperature = False
|
256 |
+
no_cov = True
|
257 |
+
kernel = CosKernel
|
258 |
+
only_attention = False
|
259 |
+
basis = "fourier"
|
260 |
+
gp32 = GP(
|
261 |
+
kernel,
|
262 |
+
T=kernel_temperature,
|
263 |
+
learn_temperature=learn_temperature,
|
264 |
+
only_attention=only_attention,
|
265 |
+
gp_dim=gp_dim,
|
266 |
+
basis=basis,
|
267 |
+
no_cov=no_cov,
|
268 |
+
)
|
269 |
+
gp16 = GP(
|
270 |
+
kernel,
|
271 |
+
T=kernel_temperature,
|
272 |
+
learn_temperature=learn_temperature,
|
273 |
+
only_attention=only_attention,
|
274 |
+
gp_dim=gp_dim,
|
275 |
+
basis=basis,
|
276 |
+
no_cov=no_cov,
|
277 |
+
)
|
278 |
+
gps = nn.ModuleDict({"32": gp32, "16": gp16})
|
279 |
+
proj = nn.ModuleDict(
|
280 |
+
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
|
281 |
+
)
|
282 |
+
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
|
283 |
+
if resolution == "low":
|
284 |
+
h, w = 384, 512
|
285 |
+
elif resolution == "high":
|
286 |
+
h, w = 480, 640
|
287 |
+
encoder = Encoder(
|
288 |
+
tv_resnet.resnet50(pretrained=not pretrained),
|
289 |
+
) # only load pretrained weights if not loading a pretrained matcher ;)
|
290 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device)
|
291 |
+
if pretrained:
|
292 |
+
try:
|
293 |
+
weights = torch.hub.load_state_dict_from_url(
|
294 |
+
dkm_pretrained_urls["DKMv2"][version]
|
295 |
+
)
|
296 |
+
except:
|
297 |
+
weights = torch.load(
|
298 |
+
dkm_pretrained_urls["DKMv2"][version]
|
299 |
+
)
|
300 |
+
matcher.load_state_dict(weights)
|
301 |
+
return matcher
|
302 |
+
|
303 |
+
|
304 |
+
def local_corr(pretrained=True, version="mega_synthetic"):
|
305 |
+
gp_dim = 256
|
306 |
+
dfn_dim = 384
|
307 |
+
feat_dim = 256
|
308 |
+
coordinate_decoder = DFN(
|
309 |
+
internal_dim=dfn_dim,
|
310 |
+
feat_input_modules=nn.ModuleDict(
|
311 |
+
{
|
312 |
+
"32": nn.Conv2d(512, feat_dim, 1, 1),
|
313 |
+
"16": nn.Conv2d(512, feat_dim, 1, 1),
|
314 |
+
}
|
315 |
+
),
|
316 |
+
pred_input_modules=nn.ModuleDict(
|
317 |
+
{
|
318 |
+
"32": nn.Identity(),
|
319 |
+
"16": nn.Identity(),
|
320 |
+
}
|
321 |
+
),
|
322 |
+
rrb_d_dict=nn.ModuleDict(
|
323 |
+
{
|
324 |
+
"32": RRB(gp_dim + feat_dim, dfn_dim),
|
325 |
+
"16": RRB(gp_dim + feat_dim, dfn_dim),
|
326 |
+
}
|
327 |
+
),
|
328 |
+
cab_dict=nn.ModuleDict(
|
329 |
+
{
|
330 |
+
"32": CAB(2 * dfn_dim, dfn_dim),
|
331 |
+
"16": CAB(2 * dfn_dim, dfn_dim),
|
332 |
+
}
|
333 |
+
),
|
334 |
+
rrb_u_dict=nn.ModuleDict(
|
335 |
+
{
|
336 |
+
"32": RRB(dfn_dim, dfn_dim),
|
337 |
+
"16": RRB(dfn_dim, dfn_dim),
|
338 |
+
}
|
339 |
+
),
|
340 |
+
terminal_module=nn.ModuleDict(
|
341 |
+
{
|
342 |
+
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
343 |
+
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
344 |
+
}
|
345 |
+
),
|
346 |
+
)
|
347 |
+
dw = True
|
348 |
+
hidden_blocks = 8
|
349 |
+
kernel_size = 5
|
350 |
+
conv_refiner = nn.ModuleDict(
|
351 |
+
{
|
352 |
+
"16": LocalCorr(
|
353 |
+
81,
|
354 |
+
81 * 12,
|
355 |
+
3,
|
356 |
+
kernel_size=kernel_size,
|
357 |
+
dw=dw,
|
358 |
+
hidden_blocks=hidden_blocks,
|
359 |
+
),
|
360 |
+
"8": LocalCorr(
|
361 |
+
81,
|
362 |
+
81 * 12,
|
363 |
+
3,
|
364 |
+
kernel_size=kernel_size,
|
365 |
+
dw=dw,
|
366 |
+
hidden_blocks=hidden_blocks,
|
367 |
+
),
|
368 |
+
"4": LocalCorr(
|
369 |
+
81,
|
370 |
+
81 * 6,
|
371 |
+
3,
|
372 |
+
kernel_size=kernel_size,
|
373 |
+
dw=dw,
|
374 |
+
hidden_blocks=hidden_blocks,
|
375 |
+
),
|
376 |
+
"2": LocalCorr(
|
377 |
+
81,
|
378 |
+
81,
|
379 |
+
3,
|
380 |
+
kernel_size=kernel_size,
|
381 |
+
dw=dw,
|
382 |
+
hidden_blocks=hidden_blocks,
|
383 |
+
),
|
384 |
+
"1": ConvRefiner(
|
385 |
+
2 * 3,
|
386 |
+
24,
|
387 |
+
3,
|
388 |
+
kernel_size=kernel_size,
|
389 |
+
dw=dw,
|
390 |
+
hidden_blocks=hidden_blocks,
|
391 |
+
),
|
392 |
+
}
|
393 |
+
)
|
394 |
+
kernel_temperature = 0.2
|
395 |
+
learn_temperature = False
|
396 |
+
no_cov = True
|
397 |
+
kernel = CosKernel
|
398 |
+
only_attention = False
|
399 |
+
basis = "fourier"
|
400 |
+
gp32 = GP(
|
401 |
+
kernel,
|
402 |
+
T=kernel_temperature,
|
403 |
+
learn_temperature=learn_temperature,
|
404 |
+
only_attention=only_attention,
|
405 |
+
gp_dim=gp_dim,
|
406 |
+
basis=basis,
|
407 |
+
no_cov=no_cov,
|
408 |
+
)
|
409 |
+
gp16 = GP(
|
410 |
+
kernel,
|
411 |
+
T=kernel_temperature,
|
412 |
+
learn_temperature=learn_temperature,
|
413 |
+
only_attention=only_attention,
|
414 |
+
gp_dim=gp_dim,
|
415 |
+
basis=basis,
|
416 |
+
no_cov=no_cov,
|
417 |
+
)
|
418 |
+
gps = nn.ModuleDict({"32": gp32, "16": gp16})
|
419 |
+
proj = nn.ModuleDict(
|
420 |
+
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
|
421 |
+
)
|
422 |
+
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
|
423 |
+
h, w = 384, 512
|
424 |
+
encoder = Encoder(
|
425 |
+
tv_resnet.resnet50(pretrained=not pretrained)
|
426 |
+
) # only load pretrained weights if not loading a pretrained matcher ;)
|
427 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
|
428 |
+
if pretrained:
|
429 |
+
weights = torch.hub.load_state_dict_from_url(
|
430 |
+
dkm_pretrained_urls["local_corr"][version]
|
431 |
+
)
|
432 |
+
matcher.load_state_dict(weights)
|
433 |
+
return matcher
|
434 |
+
|
435 |
+
|
436 |
+
def corr_channels(pretrained=True, version="mega_synthetic"):
|
437 |
+
h, w = 384, 512
|
438 |
+
gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
|
439 |
+
dfn_dim = 384
|
440 |
+
feat_dim = 256
|
441 |
+
coordinate_decoder = DFN(
|
442 |
+
internal_dim=dfn_dim,
|
443 |
+
feat_input_modules=nn.ModuleDict(
|
444 |
+
{
|
445 |
+
"32": nn.Conv2d(512, feat_dim, 1, 1),
|
446 |
+
"16": nn.Conv2d(512, feat_dim, 1, 1),
|
447 |
+
}
|
448 |
+
),
|
449 |
+
pred_input_modules=nn.ModuleDict(
|
450 |
+
{
|
451 |
+
"32": nn.Identity(),
|
452 |
+
"16": nn.Identity(),
|
453 |
+
}
|
454 |
+
),
|
455 |
+
rrb_d_dict=nn.ModuleDict(
|
456 |
+
{
|
457 |
+
"32": RRB(gp_dim[0] + feat_dim, dfn_dim),
|
458 |
+
"16": RRB(gp_dim[1] + feat_dim, dfn_dim),
|
459 |
+
}
|
460 |
+
),
|
461 |
+
cab_dict=nn.ModuleDict(
|
462 |
+
{
|
463 |
+
"32": CAB(2 * dfn_dim, dfn_dim),
|
464 |
+
"16": CAB(2 * dfn_dim, dfn_dim),
|
465 |
+
}
|
466 |
+
),
|
467 |
+
rrb_u_dict=nn.ModuleDict(
|
468 |
+
{
|
469 |
+
"32": RRB(dfn_dim, dfn_dim),
|
470 |
+
"16": RRB(dfn_dim, dfn_dim),
|
471 |
+
}
|
472 |
+
),
|
473 |
+
terminal_module=nn.ModuleDict(
|
474 |
+
{
|
475 |
+
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
476 |
+
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
477 |
+
}
|
478 |
+
),
|
479 |
+
)
|
480 |
+
dw = True
|
481 |
+
hidden_blocks = 8
|
482 |
+
kernel_size = 5
|
483 |
+
conv_refiner = nn.ModuleDict(
|
484 |
+
{
|
485 |
+
"16": ConvRefiner(
|
486 |
+
2 * 512,
|
487 |
+
1024,
|
488 |
+
3,
|
489 |
+
kernel_size=kernel_size,
|
490 |
+
dw=dw,
|
491 |
+
hidden_blocks=hidden_blocks,
|
492 |
+
),
|
493 |
+
"8": ConvRefiner(
|
494 |
+
2 * 512,
|
495 |
+
1024,
|
496 |
+
3,
|
497 |
+
kernel_size=kernel_size,
|
498 |
+
dw=dw,
|
499 |
+
hidden_blocks=hidden_blocks,
|
500 |
+
),
|
501 |
+
"4": ConvRefiner(
|
502 |
+
2 * 256,
|
503 |
+
512,
|
504 |
+
3,
|
505 |
+
kernel_size=kernel_size,
|
506 |
+
dw=dw,
|
507 |
+
hidden_blocks=hidden_blocks,
|
508 |
+
),
|
509 |
+
"2": ConvRefiner(
|
510 |
+
2 * 64,
|
511 |
+
128,
|
512 |
+
3,
|
513 |
+
kernel_size=kernel_size,
|
514 |
+
dw=dw,
|
515 |
+
hidden_blocks=hidden_blocks,
|
516 |
+
),
|
517 |
+
"1": ConvRefiner(
|
518 |
+
2 * 3,
|
519 |
+
24,
|
520 |
+
3,
|
521 |
+
kernel_size=kernel_size,
|
522 |
+
dw=dw,
|
523 |
+
hidden_blocks=hidden_blocks,
|
524 |
+
),
|
525 |
+
}
|
526 |
+
)
|
527 |
+
gp32 = NormedCorr()
|
528 |
+
gp16 = NormedCorr()
|
529 |
+
gps = nn.ModuleDict({"32": gp32, "16": gp16})
|
530 |
+
proj = nn.ModuleDict(
|
531 |
+
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
|
532 |
+
)
|
533 |
+
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
|
534 |
+
h, w = 384, 512
|
535 |
+
encoder = Encoder(
|
536 |
+
tv_resnet.resnet50(pretrained=not pretrained)
|
537 |
+
) # only load pretrained weights if not loading a pretrained matcher ;)
|
538 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
|
539 |
+
if pretrained:
|
540 |
+
weights = torch.hub.load_state_dict_from_url(
|
541 |
+
dkm_pretrained_urls["corr_channels"][version]
|
542 |
+
)
|
543 |
+
matcher.load_state_dict(weights)
|
544 |
+
return matcher
|
545 |
+
|
546 |
+
|
547 |
+
def baseline(pretrained=True, version="mega_synthetic"):
|
548 |
+
h, w = 384, 512
|
549 |
+
gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
|
550 |
+
dfn_dim = 384
|
551 |
+
feat_dim = 256
|
552 |
+
coordinate_decoder = DFN(
|
553 |
+
internal_dim=dfn_dim,
|
554 |
+
feat_input_modules=nn.ModuleDict(
|
555 |
+
{
|
556 |
+
"32": nn.Conv2d(512, feat_dim, 1, 1),
|
557 |
+
"16": nn.Conv2d(512, feat_dim, 1, 1),
|
558 |
+
}
|
559 |
+
),
|
560 |
+
pred_input_modules=nn.ModuleDict(
|
561 |
+
{
|
562 |
+
"32": nn.Identity(),
|
563 |
+
"16": nn.Identity(),
|
564 |
+
}
|
565 |
+
),
|
566 |
+
rrb_d_dict=nn.ModuleDict(
|
567 |
+
{
|
568 |
+
"32": RRB(gp_dim[0] + feat_dim, dfn_dim),
|
569 |
+
"16": RRB(gp_dim[1] + feat_dim, dfn_dim),
|
570 |
+
}
|
571 |
+
),
|
572 |
+
cab_dict=nn.ModuleDict(
|
573 |
+
{
|
574 |
+
"32": CAB(2 * dfn_dim, dfn_dim),
|
575 |
+
"16": CAB(2 * dfn_dim, dfn_dim),
|
576 |
+
}
|
577 |
+
),
|
578 |
+
rrb_u_dict=nn.ModuleDict(
|
579 |
+
{
|
580 |
+
"32": RRB(dfn_dim, dfn_dim),
|
581 |
+
"16": RRB(dfn_dim, dfn_dim),
|
582 |
+
}
|
583 |
+
),
|
584 |
+
terminal_module=nn.ModuleDict(
|
585 |
+
{
|
586 |
+
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
587 |
+
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
588 |
+
}
|
589 |
+
),
|
590 |
+
)
|
591 |
+
dw = True
|
592 |
+
hidden_blocks = 8
|
593 |
+
kernel_size = 5
|
594 |
+
conv_refiner = nn.ModuleDict(
|
595 |
+
{
|
596 |
+
"16": LocalCorr(
|
597 |
+
81,
|
598 |
+
81 * 12,
|
599 |
+
3,
|
600 |
+
kernel_size=kernel_size,
|
601 |
+
dw=dw,
|
602 |
+
hidden_blocks=hidden_blocks,
|
603 |
+
),
|
604 |
+
"8": LocalCorr(
|
605 |
+
81,
|
606 |
+
81 * 12,
|
607 |
+
3,
|
608 |
+
kernel_size=kernel_size,
|
609 |
+
dw=dw,
|
610 |
+
hidden_blocks=hidden_blocks,
|
611 |
+
),
|
612 |
+
"4": LocalCorr(
|
613 |
+
81,
|
614 |
+
81 * 6,
|
615 |
+
3,
|
616 |
+
kernel_size=kernel_size,
|
617 |
+
dw=dw,
|
618 |
+
hidden_blocks=hidden_blocks,
|
619 |
+
),
|
620 |
+
"2": LocalCorr(
|
621 |
+
81,
|
622 |
+
81,
|
623 |
+
3,
|
624 |
+
kernel_size=kernel_size,
|
625 |
+
dw=dw,
|
626 |
+
hidden_blocks=hidden_blocks,
|
627 |
+
),
|
628 |
+
"1": ConvRefiner(
|
629 |
+
2 * 3,
|
630 |
+
24,
|
631 |
+
3,
|
632 |
+
kernel_size=kernel_size,
|
633 |
+
dw=dw,
|
634 |
+
hidden_blocks=hidden_blocks,
|
635 |
+
),
|
636 |
+
}
|
637 |
+
)
|
638 |
+
gp32 = NormedCorr()
|
639 |
+
gp16 = NormedCorr()
|
640 |
+
gps = nn.ModuleDict({"32": gp32, "16": gp16})
|
641 |
+
proj = nn.ModuleDict(
|
642 |
+
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
|
643 |
+
)
|
644 |
+
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
|
645 |
+
h, w = 384, 512
|
646 |
+
encoder = Encoder(
|
647 |
+
tv_resnet.resnet50(pretrained=not pretrained)
|
648 |
+
) # only load pretrained weights if not loading a pretrained matcher ;)
|
649 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
|
650 |
+
if pretrained:
|
651 |
+
weights = torch.hub.load_state_dict_from_url(
|
652 |
+
dkm_pretrained_urls["baseline"][version]
|
653 |
+
)
|
654 |
+
matcher.load_state_dict(weights)
|
655 |
+
return matcher
|
656 |
+
|
657 |
+
|
658 |
+
def linear(pretrained=True, version="mega_synthetic"):
|
659 |
+
gp_dim = 256
|
660 |
+
dfn_dim = 384
|
661 |
+
feat_dim = 256
|
662 |
+
coordinate_decoder = DFN(
|
663 |
+
internal_dim=dfn_dim,
|
664 |
+
feat_input_modules=nn.ModuleDict(
|
665 |
+
{
|
666 |
+
"32": nn.Conv2d(512, feat_dim, 1, 1),
|
667 |
+
"16": nn.Conv2d(512, feat_dim, 1, 1),
|
668 |
+
}
|
669 |
+
),
|
670 |
+
pred_input_modules=nn.ModuleDict(
|
671 |
+
{
|
672 |
+
"32": nn.Identity(),
|
673 |
+
"16": nn.Identity(),
|
674 |
+
}
|
675 |
+
),
|
676 |
+
rrb_d_dict=nn.ModuleDict(
|
677 |
+
{
|
678 |
+
"32": RRB(gp_dim + feat_dim, dfn_dim),
|
679 |
+
"16": RRB(gp_dim + feat_dim, dfn_dim),
|
680 |
+
}
|
681 |
+
),
|
682 |
+
cab_dict=nn.ModuleDict(
|
683 |
+
{
|
684 |
+
"32": CAB(2 * dfn_dim, dfn_dim),
|
685 |
+
"16": CAB(2 * dfn_dim, dfn_dim),
|
686 |
+
}
|
687 |
+
),
|
688 |
+
rrb_u_dict=nn.ModuleDict(
|
689 |
+
{
|
690 |
+
"32": RRB(dfn_dim, dfn_dim),
|
691 |
+
"16": RRB(dfn_dim, dfn_dim),
|
692 |
+
}
|
693 |
+
),
|
694 |
+
terminal_module=nn.ModuleDict(
|
695 |
+
{
|
696 |
+
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
697 |
+
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
698 |
+
}
|
699 |
+
),
|
700 |
+
)
|
701 |
+
dw = True
|
702 |
+
hidden_blocks = 8
|
703 |
+
kernel_size = 5
|
704 |
+
conv_refiner = nn.ModuleDict(
|
705 |
+
{
|
706 |
+
"16": ConvRefiner(
|
707 |
+
2 * 512,
|
708 |
+
1024,
|
709 |
+
3,
|
710 |
+
kernel_size=kernel_size,
|
711 |
+
dw=dw,
|
712 |
+
hidden_blocks=hidden_blocks,
|
713 |
+
),
|
714 |
+
"8": ConvRefiner(
|
715 |
+
2 * 512,
|
716 |
+
1024,
|
717 |
+
3,
|
718 |
+
kernel_size=kernel_size,
|
719 |
+
dw=dw,
|
720 |
+
hidden_blocks=hidden_blocks,
|
721 |
+
),
|
722 |
+
"4": ConvRefiner(
|
723 |
+
2 * 256,
|
724 |
+
512,
|
725 |
+
3,
|
726 |
+
kernel_size=kernel_size,
|
727 |
+
dw=dw,
|
728 |
+
hidden_blocks=hidden_blocks,
|
729 |
+
),
|
730 |
+
"2": ConvRefiner(
|
731 |
+
2 * 64,
|
732 |
+
128,
|
733 |
+
3,
|
734 |
+
kernel_size=kernel_size,
|
735 |
+
dw=dw,
|
736 |
+
hidden_blocks=hidden_blocks,
|
737 |
+
),
|
738 |
+
"1": ConvRefiner(
|
739 |
+
2 * 3,
|
740 |
+
24,
|
741 |
+
3,
|
742 |
+
kernel_size=kernel_size,
|
743 |
+
dw=dw,
|
744 |
+
hidden_blocks=hidden_blocks,
|
745 |
+
),
|
746 |
+
}
|
747 |
+
)
|
748 |
+
kernel_temperature = 0.2
|
749 |
+
learn_temperature = False
|
750 |
+
no_cov = True
|
751 |
+
kernel = CosKernel
|
752 |
+
only_attention = False
|
753 |
+
basis = "linear"
|
754 |
+
gp32 = GP(
|
755 |
+
kernel,
|
756 |
+
T=kernel_temperature,
|
757 |
+
learn_temperature=learn_temperature,
|
758 |
+
only_attention=only_attention,
|
759 |
+
gp_dim=gp_dim,
|
760 |
+
basis=basis,
|
761 |
+
no_cov=no_cov,
|
762 |
+
)
|
763 |
+
gp16 = GP(
|
764 |
+
kernel,
|
765 |
+
T=kernel_temperature,
|
766 |
+
learn_temperature=learn_temperature,
|
767 |
+
only_attention=only_attention,
|
768 |
+
gp_dim=gp_dim,
|
769 |
+
basis=basis,
|
770 |
+
no_cov=no_cov,
|
771 |
+
)
|
772 |
+
gps = nn.ModuleDict({"32": gp32, "16": gp16})
|
773 |
+
proj = nn.ModuleDict(
|
774 |
+
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
|
775 |
+
)
|
776 |
+
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
|
777 |
+
h, w = 384, 512
|
778 |
+
encoder = Encoder(
|
779 |
+
tv_resnet.resnet50(pretrained=not pretrained)
|
780 |
+
) # only load pretrained weights if not loading a pretrained matcher ;)
|
781 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
|
782 |
+
if pretrained:
|
783 |
+
weights = torch.hub.load_state_dict_from_url(
|
784 |
+
dkm_pretrained_urls["linear"][version]
|
785 |
+
)
|
786 |
+
matcher.load_state_dict(weights)
|
787 |
+
return matcher
|
third_party/DKM/dkm/models/deprecated/corr_channels.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
class NormedCorrelationKernel(nn.Module): # similar to softmax kernel
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
def __call__(self, x, y, eps=1e-6):
|
12 |
+
c = torch.einsum("bnd,bmd->bnm", x, y) / (
|
13 |
+
x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
|
14 |
+
)
|
15 |
+
return c
|
16 |
+
|
17 |
+
|
18 |
+
class NormedCorr(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
self.corr = NormedCorrelationKernel()
|
24 |
+
|
25 |
+
def reshape(self, x):
|
26 |
+
return rearrange(x, "b d h w -> b (h w) d")
|
27 |
+
|
28 |
+
def forward(self, x, y, **kwargs):
|
29 |
+
b, c, h, w = y.shape
|
30 |
+
assert x.shape == y.shape
|
31 |
+
x, y = self.reshape(x), self.reshape(y)
|
32 |
+
corr_xy = self.corr(x, y)
|
33 |
+
corr_xy_flat = rearrange(corr_xy, "b (h w) c -> b c h w", h=h, w=w)
|
34 |
+
return corr_xy_flat
|
third_party/DKM/dkm/models/deprecated/local_corr.py
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
try:
|
5 |
+
import cupy
|
6 |
+
except:
|
7 |
+
print("Cupy not found, local correlation will not work")
|
8 |
+
import re
|
9 |
+
from ..dkm import ConvRefiner
|
10 |
+
|
11 |
+
|
12 |
+
class Stream:
|
13 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
+
if device == 'cuda':
|
15 |
+
stream = torch.cuda.current_stream(device=device).cuda_stream
|
16 |
+
else:
|
17 |
+
stream = None
|
18 |
+
|
19 |
+
|
20 |
+
kernel_Correlation_rearrange = """
|
21 |
+
extern "C" __global__ void kernel_Correlation_rearrange(
|
22 |
+
const int n,
|
23 |
+
const float* input,
|
24 |
+
float* output
|
25 |
+
) {
|
26 |
+
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
|
27 |
+
if (intIndex >= n) {
|
28 |
+
return;
|
29 |
+
}
|
30 |
+
int intSample = blockIdx.z;
|
31 |
+
int intChannel = blockIdx.y;
|
32 |
+
float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
|
33 |
+
__syncthreads();
|
34 |
+
int intPaddedY = (intIndex / SIZE_3(input)) + 4;
|
35 |
+
int intPaddedX = (intIndex % SIZE_3(input)) + 4;
|
36 |
+
int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
|
37 |
+
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue;
|
38 |
+
}
|
39 |
+
"""
|
40 |
+
|
41 |
+
kernel_Correlation_updateOutput = """
|
42 |
+
extern "C" __global__ void kernel_Correlation_updateOutput(
|
43 |
+
const int n,
|
44 |
+
const float* rbot0,
|
45 |
+
const float* rbot1,
|
46 |
+
float* top
|
47 |
+
) {
|
48 |
+
extern __shared__ char patch_data_char[];
|
49 |
+
float *patch_data = (float *)patch_data_char;
|
50 |
+
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
|
51 |
+
int x1 = blockIdx.x + 4;
|
52 |
+
int y1 = blockIdx.y + 4;
|
53 |
+
int item = blockIdx.z;
|
54 |
+
int ch_off = threadIdx.x;
|
55 |
+
// Load 3D patch into shared shared memory
|
56 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
57 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
58 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
59 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
60 |
+
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
|
61 |
+
int idxPatchData = ji_off + ch;
|
62 |
+
patch_data[idxPatchData] = rbot0[idx1];
|
63 |
+
}
|
64 |
+
}
|
65 |
+
}
|
66 |
+
__syncthreads();
|
67 |
+
__shared__ float sum[32];
|
68 |
+
// Compute correlation
|
69 |
+
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
|
70 |
+
sum[ch_off] = 0;
|
71 |
+
int s2o = top_channel % 9 - 4;
|
72 |
+
int s2p = top_channel / 9 - 4;
|
73 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
74 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
75 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
76 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
77 |
+
int x2 = x1 + s2o;
|
78 |
+
int y2 = y1 + s2p;
|
79 |
+
int idxPatchData = ji_off + ch;
|
80 |
+
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
|
81 |
+
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
|
82 |
+
}
|
83 |
+
}
|
84 |
+
}
|
85 |
+
__syncthreads();
|
86 |
+
if (ch_off == 0) {
|
87 |
+
float total_sum = 0;
|
88 |
+
for (int idx = 0; idx < 32; idx++) {
|
89 |
+
total_sum += sum[idx];
|
90 |
+
}
|
91 |
+
const int sumelems = SIZE_3(rbot0);
|
92 |
+
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
|
93 |
+
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
|
94 |
+
}
|
95 |
+
}
|
96 |
+
}
|
97 |
+
"""
|
98 |
+
|
99 |
+
kernel_Correlation_updateGradFirst = """
|
100 |
+
#define ROUND_OFF 50000
|
101 |
+
extern "C" __global__ void kernel_Correlation_updateGradFirst(
|
102 |
+
const int n,
|
103 |
+
const int intSample,
|
104 |
+
const float* rbot0,
|
105 |
+
const float* rbot1,
|
106 |
+
const float* gradOutput,
|
107 |
+
float* gradFirst,
|
108 |
+
float* gradSecond
|
109 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
110 |
+
int n = intIndex % SIZE_1(gradFirst); // channels
|
111 |
+
int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
|
112 |
+
int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
|
113 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
114 |
+
// We use a large offset, for the inner part not to become negative.
|
115 |
+
const int round_off = ROUND_OFF;
|
116 |
+
const int round_off_s1 = round_off;
|
117 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
118 |
+
int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
119 |
+
int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
120 |
+
// Same here:
|
121 |
+
int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
|
122 |
+
int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
|
123 |
+
float sum = 0;
|
124 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
125 |
+
xmin = max(0,xmin);
|
126 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
127 |
+
ymin = max(0,ymin);
|
128 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
129 |
+
for (int p = -4; p <= 4; p++) {
|
130 |
+
for (int o = -4; o <= 4; o++) {
|
131 |
+
// Get rbot1 data:
|
132 |
+
int s2o = o;
|
133 |
+
int s2p = p;
|
134 |
+
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
|
135 |
+
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
|
136 |
+
// Index offset for gradOutput in following loops:
|
137 |
+
int op = (p+4) * 9 + (o+4); // index[o,p]
|
138 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
139 |
+
for (int y = ymin; y <= ymax; y++) {
|
140 |
+
for (int x = xmin; x <= xmax; x++) {
|
141 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
142 |
+
sum += gradOutput[idxgradOutput] * bot1tmp;
|
143 |
+
}
|
144 |
+
}
|
145 |
+
}
|
146 |
+
}
|
147 |
+
}
|
148 |
+
const int sumelems = SIZE_1(gradFirst);
|
149 |
+
const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
|
150 |
+
gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
|
151 |
+
} }
|
152 |
+
"""
|
153 |
+
|
154 |
+
kernel_Correlation_updateGradSecond = """
|
155 |
+
#define ROUND_OFF 50000
|
156 |
+
extern "C" __global__ void kernel_Correlation_updateGradSecond(
|
157 |
+
const int n,
|
158 |
+
const int intSample,
|
159 |
+
const float* rbot0,
|
160 |
+
const float* rbot1,
|
161 |
+
const float* gradOutput,
|
162 |
+
float* gradFirst,
|
163 |
+
float* gradSecond
|
164 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
165 |
+
int n = intIndex % SIZE_1(gradSecond); // channels
|
166 |
+
int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
|
167 |
+
int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
|
168 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
169 |
+
// We use a large offset, for the inner part not to become negative.
|
170 |
+
const int round_off = ROUND_OFF;
|
171 |
+
const int round_off_s1 = round_off;
|
172 |
+
float sum = 0;
|
173 |
+
for (int p = -4; p <= 4; p++) {
|
174 |
+
for (int o = -4; o <= 4; o++) {
|
175 |
+
int s2o = o;
|
176 |
+
int s2p = p;
|
177 |
+
//Get X,Y ranges and clamp
|
178 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
179 |
+
int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
|
180 |
+
int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
|
181 |
+
// Same here:
|
182 |
+
int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
|
183 |
+
int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
|
184 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
185 |
+
xmin = max(0,xmin);
|
186 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
187 |
+
ymin = max(0,ymin);
|
188 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
189 |
+
// Get rbot0 data:
|
190 |
+
int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
|
191 |
+
float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
|
192 |
+
// Index offset for gradOutput in following loops:
|
193 |
+
int op = (p+4) * 9 + (o+4); // index[o,p]
|
194 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
195 |
+
for (int y = ymin; y <= ymax; y++) {
|
196 |
+
for (int x = xmin; x <= xmax; x++) {
|
197 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
198 |
+
sum += gradOutput[idxgradOutput] * bot0tmp;
|
199 |
+
}
|
200 |
+
}
|
201 |
+
}
|
202 |
+
}
|
203 |
+
}
|
204 |
+
const int sumelems = SIZE_1(gradSecond);
|
205 |
+
const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
|
206 |
+
gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
|
207 |
+
} }
|
208 |
+
"""
|
209 |
+
|
210 |
+
|
211 |
+
def cupy_kernel(strFunction, objectVariables):
|
212 |
+
strKernel = globals()[strFunction]
|
213 |
+
|
214 |
+
while True:
|
215 |
+
objectMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
|
216 |
+
|
217 |
+
if objectMatch is None:
|
218 |
+
break
|
219 |
+
|
220 |
+
intArg = int(objectMatch.group(2))
|
221 |
+
|
222 |
+
strTensor = objectMatch.group(4)
|
223 |
+
intSizes = objectVariables[strTensor].size()
|
224 |
+
|
225 |
+
strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
|
226 |
+
|
227 |
+
while True:
|
228 |
+
objectMatch = re.search(r"(VALUE_)([0-4])(\()([^\)]+)(\))", strKernel)
|
229 |
+
|
230 |
+
if objectMatch is None:
|
231 |
+
break
|
232 |
+
|
233 |
+
intArgs = int(objectMatch.group(2))
|
234 |
+
strArgs = objectMatch.group(4).split(",")
|
235 |
+
|
236 |
+
strTensor = strArgs[0]
|
237 |
+
intStrides = objectVariables[strTensor].stride()
|
238 |
+
strIndex = [
|
239 |
+
"(("
|
240 |
+
+ strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
|
241 |
+
+ ")*"
|
242 |
+
+ str(intStrides[intArg])
|
243 |
+
+ ")"
|
244 |
+
for intArg in range(intArgs)
|
245 |
+
]
|
246 |
+
|
247 |
+
strKernel = strKernel.replace(
|
248 |
+
objectMatch.group(0), strTensor + "[" + str.join("+", strIndex) + "]"
|
249 |
+
)
|
250 |
+
|
251 |
+
return strKernel
|
252 |
+
|
253 |
+
|
254 |
+
try:
|
255 |
+
|
256 |
+
@cupy.memoize(for_each_device=True)
|
257 |
+
def cupy_launch(strFunction, strKernel):
|
258 |
+
return cupy.RawModule(code=strKernel).get_function(strFunction)
|
259 |
+
|
260 |
+
except:
|
261 |
+
pass
|
262 |
+
|
263 |
+
|
264 |
+
class _FunctionCorrelation(torch.autograd.Function):
|
265 |
+
@staticmethod
|
266 |
+
def forward(self, first, second):
|
267 |
+
rbot0 = first.new_zeros(
|
268 |
+
[first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
|
269 |
+
)
|
270 |
+
rbot1 = first.new_zeros(
|
271 |
+
[first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
|
272 |
+
)
|
273 |
+
|
274 |
+
self.save_for_backward(first, second, rbot0, rbot1)
|
275 |
+
|
276 |
+
first = first.contiguous()
|
277 |
+
second = second.contiguous()
|
278 |
+
|
279 |
+
output = first.new_zeros([first.size(0), 81, first.size(2), first.size(3)])
|
280 |
+
|
281 |
+
if first.is_cuda == True:
|
282 |
+
n = first.size(2) * first.size(3)
|
283 |
+
cupy_launch(
|
284 |
+
"kernel_Correlation_rearrange",
|
285 |
+
cupy_kernel(
|
286 |
+
"kernel_Correlation_rearrange", {"input": first, "output": rbot0}
|
287 |
+
),
|
288 |
+
)(
|
289 |
+
grid=tuple([int((n + 16 - 1) / 16), first.size(1), first.size(0)]),
|
290 |
+
block=tuple([16, 1, 1]),
|
291 |
+
args=[n, first.data_ptr(), rbot0.data_ptr()],
|
292 |
+
stream=Stream,
|
293 |
+
)
|
294 |
+
|
295 |
+
n = second.size(2) * second.size(3)
|
296 |
+
cupy_launch(
|
297 |
+
"kernel_Correlation_rearrange",
|
298 |
+
cupy_kernel(
|
299 |
+
"kernel_Correlation_rearrange", {"input": second, "output": rbot1}
|
300 |
+
),
|
301 |
+
)(
|
302 |
+
grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
|
303 |
+
block=tuple([16, 1, 1]),
|
304 |
+
args=[n, second.data_ptr(), rbot1.data_ptr()],
|
305 |
+
stream=Stream,
|
306 |
+
)
|
307 |
+
|
308 |
+
n = output.size(1) * output.size(2) * output.size(3)
|
309 |
+
cupy_launch(
|
310 |
+
"kernel_Correlation_updateOutput",
|
311 |
+
cupy_kernel(
|
312 |
+
"kernel_Correlation_updateOutput",
|
313 |
+
{"rbot0": rbot0, "rbot1": rbot1, "top": output},
|
314 |
+
),
|
315 |
+
)(
|
316 |
+
grid=tuple([output.size(3), output.size(2), output.size(0)]),
|
317 |
+
block=tuple([32, 1, 1]),
|
318 |
+
shared_mem=first.size(1) * 4,
|
319 |
+
args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()],
|
320 |
+
stream=Stream,
|
321 |
+
)
|
322 |
+
|
323 |
+
elif first.is_cuda == False:
|
324 |
+
raise NotImplementedError()
|
325 |
+
|
326 |
+
return output
|
327 |
+
|
328 |
+
@staticmethod
|
329 |
+
def backward(self, gradOutput):
|
330 |
+
first, second, rbot0, rbot1 = self.saved_tensors
|
331 |
+
|
332 |
+
gradOutput = gradOutput.contiguous()
|
333 |
+
|
334 |
+
assert gradOutput.is_contiguous() == True
|
335 |
+
|
336 |
+
gradFirst = (
|
337 |
+
first.new_zeros(
|
338 |
+
[first.size(0), first.size(1), first.size(2), first.size(3)]
|
339 |
+
)
|
340 |
+
if self.needs_input_grad[0] == True
|
341 |
+
else None
|
342 |
+
)
|
343 |
+
gradSecond = (
|
344 |
+
first.new_zeros(
|
345 |
+
[first.size(0), first.size(1), first.size(2), first.size(3)]
|
346 |
+
)
|
347 |
+
if self.needs_input_grad[1] == True
|
348 |
+
else None
|
349 |
+
)
|
350 |
+
|
351 |
+
if first.is_cuda == True:
|
352 |
+
if gradFirst is not None:
|
353 |
+
for intSample in range(first.size(0)):
|
354 |
+
n = first.size(1) * first.size(2) * first.size(3)
|
355 |
+
cupy_launch(
|
356 |
+
"kernel_Correlation_updateGradFirst",
|
357 |
+
cupy_kernel(
|
358 |
+
"kernel_Correlation_updateGradFirst",
|
359 |
+
{
|
360 |
+
"rbot0": rbot0,
|
361 |
+
"rbot1": rbot1,
|
362 |
+
"gradOutput": gradOutput,
|
363 |
+
"gradFirst": gradFirst,
|
364 |
+
"gradSecond": None,
|
365 |
+
},
|
366 |
+
),
|
367 |
+
)(
|
368 |
+
grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
|
369 |
+
block=tuple([512, 1, 1]),
|
370 |
+
args=[
|
371 |
+
n,
|
372 |
+
intSample,
|
373 |
+
rbot0.data_ptr(),
|
374 |
+
rbot1.data_ptr(),
|
375 |
+
gradOutput.data_ptr(),
|
376 |
+
gradFirst.data_ptr(),
|
377 |
+
None,
|
378 |
+
],
|
379 |
+
stream=Stream,
|
380 |
+
)
|
381 |
+
|
382 |
+
if gradSecond is not None:
|
383 |
+
for intSample in range(first.size(0)):
|
384 |
+
n = first.size(1) * first.size(2) * first.size(3)
|
385 |
+
cupy_launch(
|
386 |
+
"kernel_Correlation_updateGradSecond",
|
387 |
+
cupy_kernel(
|
388 |
+
"kernel_Correlation_updateGradSecond",
|
389 |
+
{
|
390 |
+
"rbot0": rbot0,
|
391 |
+
"rbot1": rbot1,
|
392 |
+
"gradOutput": gradOutput,
|
393 |
+
"gradFirst": None,
|
394 |
+
"gradSecond": gradSecond,
|
395 |
+
},
|
396 |
+
),
|
397 |
+
)(
|
398 |
+
grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
|
399 |
+
block=tuple([512, 1, 1]),
|
400 |
+
args=[
|
401 |
+
n,
|
402 |
+
intSample,
|
403 |
+
rbot0.data_ptr(),
|
404 |
+
rbot1.data_ptr(),
|
405 |
+
gradOutput.data_ptr(),
|
406 |
+
None,
|
407 |
+
gradSecond.data_ptr(),
|
408 |
+
],
|
409 |
+
stream=Stream,
|
410 |
+
)
|
411 |
+
|
412 |
+
elif first.is_cuda == False:
|
413 |
+
raise NotImplementedError()
|
414 |
+
|
415 |
+
return gradFirst, gradSecond
|
416 |
+
|
417 |
+
|
418 |
+
class _FunctionCorrelationTranspose(torch.autograd.Function):
|
419 |
+
@staticmethod
|
420 |
+
def forward(self, input, second):
|
421 |
+
rbot0 = second.new_zeros(
|
422 |
+
[second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
|
423 |
+
)
|
424 |
+
rbot1 = second.new_zeros(
|
425 |
+
[second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
|
426 |
+
)
|
427 |
+
|
428 |
+
self.save_for_backward(input, second, rbot0, rbot1)
|
429 |
+
|
430 |
+
input = input.contiguous()
|
431 |
+
second = second.contiguous()
|
432 |
+
|
433 |
+
output = second.new_zeros(
|
434 |
+
[second.size(0), second.size(1), second.size(2), second.size(3)]
|
435 |
+
)
|
436 |
+
|
437 |
+
if second.is_cuda == True:
|
438 |
+
n = second.size(2) * second.size(3)
|
439 |
+
cupy_launch(
|
440 |
+
"kernel_Correlation_rearrange",
|
441 |
+
cupy_kernel(
|
442 |
+
"kernel_Correlation_rearrange", {"input": second, "output": rbot1}
|
443 |
+
),
|
444 |
+
)(
|
445 |
+
grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
|
446 |
+
block=tuple([16, 1, 1]),
|
447 |
+
args=[n, second.data_ptr(), rbot1.data_ptr()],
|
448 |
+
stream=Stream,
|
449 |
+
)
|
450 |
+
|
451 |
+
for intSample in range(second.size(0)):
|
452 |
+
n = second.size(1) * second.size(2) * second.size(3)
|
453 |
+
cupy_launch(
|
454 |
+
"kernel_Correlation_updateGradFirst",
|
455 |
+
cupy_kernel(
|
456 |
+
"kernel_Correlation_updateGradFirst",
|
457 |
+
{
|
458 |
+
"rbot0": rbot0,
|
459 |
+
"rbot1": rbot1,
|
460 |
+
"gradOutput": input,
|
461 |
+
"gradFirst": output,
|
462 |
+
"gradSecond": None,
|
463 |
+
},
|
464 |
+
),
|
465 |
+
)(
|
466 |
+
grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
|
467 |
+
block=tuple([512, 1, 1]),
|
468 |
+
args=[
|
469 |
+
n,
|
470 |
+
intSample,
|
471 |
+
rbot0.data_ptr(),
|
472 |
+
rbot1.data_ptr(),
|
473 |
+
input.data_ptr(),
|
474 |
+
output.data_ptr(),
|
475 |
+
None,
|
476 |
+
],
|
477 |
+
stream=Stream,
|
478 |
+
)
|
479 |
+
|
480 |
+
elif second.is_cuda == False:
|
481 |
+
raise NotImplementedError()
|
482 |
+
|
483 |
+
return output
|
484 |
+
|
485 |
+
@staticmethod
|
486 |
+
def backward(self, gradOutput):
|
487 |
+
input, second, rbot0, rbot1 = self.saved_tensors
|
488 |
+
|
489 |
+
gradOutput = gradOutput.contiguous()
|
490 |
+
|
491 |
+
gradInput = (
|
492 |
+
input.new_zeros(
|
493 |
+
[input.size(0), input.size(1), input.size(2), input.size(3)]
|
494 |
+
)
|
495 |
+
if self.needs_input_grad[0] == True
|
496 |
+
else None
|
497 |
+
)
|
498 |
+
gradSecond = (
|
499 |
+
second.new_zeros(
|
500 |
+
[second.size(0), second.size(1), second.size(2), second.size(3)]
|
501 |
+
)
|
502 |
+
if self.needs_input_grad[1] == True
|
503 |
+
else None
|
504 |
+
)
|
505 |
+
|
506 |
+
if second.is_cuda == True:
|
507 |
+
if gradInput is not None or gradSecond is not None:
|
508 |
+
n = second.size(2) * second.size(3)
|
509 |
+
cupy_launch(
|
510 |
+
"kernel_Correlation_rearrange",
|
511 |
+
cupy_kernel(
|
512 |
+
"kernel_Correlation_rearrange",
|
513 |
+
{"input": gradOutput, "output": rbot0},
|
514 |
+
),
|
515 |
+
)(
|
516 |
+
grid=tuple(
|
517 |
+
[int((n + 16 - 1) / 16), gradOutput.size(1), gradOutput.size(0)]
|
518 |
+
),
|
519 |
+
block=tuple([16, 1, 1]),
|
520 |
+
args=[n, gradOutput.data_ptr(), rbot0.data_ptr()],
|
521 |
+
stream=Stream,
|
522 |
+
)
|
523 |
+
|
524 |
+
if gradInput is not None:
|
525 |
+
n = gradInput.size(1) * gradInput.size(2) * gradInput.size(3)
|
526 |
+
cupy_launch(
|
527 |
+
"kernel_Correlation_updateOutput",
|
528 |
+
cupy_kernel(
|
529 |
+
"kernel_Correlation_updateOutput",
|
530 |
+
{"rbot0": rbot0, "rbot1": rbot1, "top": gradInput},
|
531 |
+
),
|
532 |
+
)(
|
533 |
+
grid=tuple(
|
534 |
+
[gradInput.size(3), gradInput.size(2), gradInput.size(0)]
|
535 |
+
),
|
536 |
+
block=tuple([32, 1, 1]),
|
537 |
+
shared_mem=gradOutput.size(1) * 4,
|
538 |
+
args=[n, rbot0.data_ptr(), rbot1.data_ptr(), gradInput.data_ptr()],
|
539 |
+
stream=Stream,
|
540 |
+
)
|
541 |
+
|
542 |
+
if gradSecond is not None:
|
543 |
+
for intSample in range(second.size(0)):
|
544 |
+
n = second.size(1) * second.size(2) * second.size(3)
|
545 |
+
cupy_launch(
|
546 |
+
"kernel_Correlation_updateGradSecond",
|
547 |
+
cupy_kernel(
|
548 |
+
"kernel_Correlation_updateGradSecond",
|
549 |
+
{
|
550 |
+
"rbot0": rbot0,
|
551 |
+
"rbot1": rbot1,
|
552 |
+
"gradOutput": input,
|
553 |
+
"gradFirst": None,
|
554 |
+
"gradSecond": gradSecond,
|
555 |
+
},
|
556 |
+
),
|
557 |
+
)(
|
558 |
+
grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
|
559 |
+
block=tuple([512, 1, 1]),
|
560 |
+
args=[
|
561 |
+
n,
|
562 |
+
intSample,
|
563 |
+
rbot0.data_ptr(),
|
564 |
+
rbot1.data_ptr(),
|
565 |
+
input.data_ptr(),
|
566 |
+
None,
|
567 |
+
gradSecond.data_ptr(),
|
568 |
+
],
|
569 |
+
stream=Stream,
|
570 |
+
)
|
571 |
+
|
572 |
+
elif second.is_cuda == False:
|
573 |
+
raise NotImplementedError()
|
574 |
+
|
575 |
+
return gradInput, gradSecond
|
576 |
+
|
577 |
+
|
578 |
+
def FunctionCorrelation(reference_features, query_features):
|
579 |
+
return _FunctionCorrelation.apply(reference_features, query_features)
|
580 |
+
|
581 |
+
|
582 |
+
class ModuleCorrelation(torch.nn.Module):
|
583 |
+
def __init__(self):
|
584 |
+
super(ModuleCorrelation, self).__init__()
|
585 |
+
|
586 |
+
def forward(self, tensorFirst, tensorSecond):
|
587 |
+
return _FunctionCorrelation.apply(tensorFirst, tensorSecond)
|
588 |
+
|
589 |
+
|
590 |
+
def FunctionCorrelationTranspose(reference_features, query_features):
|
591 |
+
return _FunctionCorrelationTranspose.apply(reference_features, query_features)
|
592 |
+
|
593 |
+
|
594 |
+
class ModuleCorrelationTranspose(torch.nn.Module):
|
595 |
+
def __init__(self):
|
596 |
+
super(ModuleCorrelationTranspose, self).__init__()
|
597 |
+
|
598 |
+
def forward(self, tensorFirst, tensorSecond):
|
599 |
+
return _FunctionCorrelationTranspose.apply(tensorFirst, tensorSecond)
|
600 |
+
|
601 |
+
|
602 |
+
class LocalCorr(ConvRefiner):
|
603 |
+
def forward(self, x, y, flow):
|
604 |
+
"""Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
|
605 |
+
|
606 |
+
Args:
|
607 |
+
x ([type]): [description]
|
608 |
+
y ([type]): [description]
|
609 |
+
flow ([type]): [description]
|
610 |
+
|
611 |
+
Returns:
|
612 |
+
[type]: [description]
|
613 |
+
"""
|
614 |
+
with torch.no_grad():
|
615 |
+
x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
|
616 |
+
corr = FunctionCorrelation(x, x_hat)
|
617 |
+
d = self.block1(corr)
|
618 |
+
d = self.hidden_blocks(d)
|
619 |
+
d = self.out_conv(d)
|
620 |
+
certainty, displacement = d[:, :-2], d[:, -2:]
|
621 |
+
return certainty, displacement
|
622 |
+
|
623 |
+
|
624 |
+
if __name__ == "__main__":
|
625 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
626 |
+
x = torch.randn(2, 128, 32, 32).to(device)
|
627 |
+
y = torch.randn(2, 128, 32, 32).to(device)
|
628 |
+
local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4)
|
629 |
+
z = local_corr(x, y)
|
630 |
+
print("hej")
|
third_party/DKM/dkm/models/dkm.py
ADDED
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from ..utils import get_tuple_transform_ops
|
9 |
+
from einops import rearrange
|
10 |
+
from ..utils.local_correlation import local_correlation
|
11 |
+
|
12 |
+
|
13 |
+
class ConvRefiner(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
in_dim=6,
|
17 |
+
hidden_dim=16,
|
18 |
+
out_dim=2,
|
19 |
+
dw=False,
|
20 |
+
kernel_size=5,
|
21 |
+
hidden_blocks=3,
|
22 |
+
displacement_emb = None,
|
23 |
+
displacement_emb_dim = None,
|
24 |
+
local_corr_radius = None,
|
25 |
+
corr_in_other = None,
|
26 |
+
no_support_fm = False,
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
self.block1 = self.create_block(
|
30 |
+
in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
|
31 |
+
)
|
32 |
+
self.hidden_blocks = nn.Sequential(
|
33 |
+
*[
|
34 |
+
self.create_block(
|
35 |
+
hidden_dim,
|
36 |
+
hidden_dim,
|
37 |
+
dw=dw,
|
38 |
+
kernel_size=kernel_size,
|
39 |
+
)
|
40 |
+
for hb in range(hidden_blocks)
|
41 |
+
]
|
42 |
+
)
|
43 |
+
self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
|
44 |
+
if displacement_emb:
|
45 |
+
self.has_displacement_emb = True
|
46 |
+
self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
|
47 |
+
else:
|
48 |
+
self.has_displacement_emb = False
|
49 |
+
self.local_corr_radius = local_corr_radius
|
50 |
+
self.corr_in_other = corr_in_other
|
51 |
+
self.no_support_fm = no_support_fm
|
52 |
+
def create_block(
|
53 |
+
self,
|
54 |
+
in_dim,
|
55 |
+
out_dim,
|
56 |
+
dw=False,
|
57 |
+
kernel_size=5,
|
58 |
+
):
|
59 |
+
num_groups = 1 if not dw else in_dim
|
60 |
+
if dw:
|
61 |
+
assert (
|
62 |
+
out_dim % in_dim == 0
|
63 |
+
), "outdim must be divisible by indim for depthwise"
|
64 |
+
conv1 = nn.Conv2d(
|
65 |
+
in_dim,
|
66 |
+
out_dim,
|
67 |
+
kernel_size=kernel_size,
|
68 |
+
stride=1,
|
69 |
+
padding=kernel_size // 2,
|
70 |
+
groups=num_groups,
|
71 |
+
)
|
72 |
+
norm = nn.BatchNorm2d(out_dim)
|
73 |
+
relu = nn.ReLU(inplace=True)
|
74 |
+
conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
|
75 |
+
return nn.Sequential(conv1, norm, relu, conv2)
|
76 |
+
|
77 |
+
def forward(self, x, y, flow):
|
78 |
+
"""Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
|
79 |
+
|
80 |
+
Args:
|
81 |
+
x ([type]): [description]
|
82 |
+
y ([type]): [description]
|
83 |
+
flow ([type]): [description]
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
[type]: [description]
|
87 |
+
"""
|
88 |
+
device = x.device
|
89 |
+
b,c,hs,ws = x.shape
|
90 |
+
with torch.no_grad():
|
91 |
+
x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
|
92 |
+
if self.has_displacement_emb:
|
93 |
+
query_coords = torch.meshgrid(
|
94 |
+
(
|
95 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
96 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
97 |
+
)
|
98 |
+
)
|
99 |
+
query_coords = torch.stack((query_coords[1], query_coords[0]))
|
100 |
+
query_coords = query_coords[None].expand(b, 2, hs, ws)
|
101 |
+
in_displacement = flow-query_coords
|
102 |
+
emb_in_displacement = self.disp_emb(in_displacement)
|
103 |
+
if self.local_corr_radius:
|
104 |
+
#TODO: should corr have gradient?
|
105 |
+
if self.corr_in_other:
|
106 |
+
# Corr in other means take a kxk grid around the predicted coordinate in other image
|
107 |
+
local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow)
|
108 |
+
else:
|
109 |
+
# Otherwise we use the warp to sample in the first image
|
110 |
+
# This is actually different operations, especially for large viewpoint changes
|
111 |
+
local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,)
|
112 |
+
if self.no_support_fm:
|
113 |
+
x_hat = torch.zeros_like(x)
|
114 |
+
d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
|
115 |
+
else:
|
116 |
+
d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
|
117 |
+
else:
|
118 |
+
if self.no_support_fm:
|
119 |
+
x_hat = torch.zeros_like(x)
|
120 |
+
d = torch.cat((x, x_hat), dim=1)
|
121 |
+
d = self.block1(d)
|
122 |
+
d = self.hidden_blocks(d)
|
123 |
+
d = self.out_conv(d)
|
124 |
+
certainty, displacement = d[:, :-2], d[:, -2:]
|
125 |
+
return certainty, displacement
|
126 |
+
|
127 |
+
|
128 |
+
class CosKernel(nn.Module): # similar to softmax kernel
|
129 |
+
def __init__(self, T, learn_temperature=False):
|
130 |
+
super().__init__()
|
131 |
+
self.learn_temperature = learn_temperature
|
132 |
+
if self.learn_temperature:
|
133 |
+
self.T = nn.Parameter(torch.tensor(T))
|
134 |
+
else:
|
135 |
+
self.T = T
|
136 |
+
|
137 |
+
def __call__(self, x, y, eps=1e-6):
|
138 |
+
c = torch.einsum("bnd,bmd->bnm", x, y) / (
|
139 |
+
x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
|
140 |
+
)
|
141 |
+
if self.learn_temperature:
|
142 |
+
T = self.T.abs() + 0.01
|
143 |
+
else:
|
144 |
+
T = torch.tensor(self.T, device=c.device)
|
145 |
+
K = ((c - 1.0) / T).exp()
|
146 |
+
return K
|
147 |
+
|
148 |
+
|
149 |
+
class CAB(nn.Module):
|
150 |
+
def __init__(self, in_channels, out_channels):
|
151 |
+
super(CAB, self).__init__()
|
152 |
+
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
153 |
+
self.conv1 = nn.Conv2d(
|
154 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
155 |
+
)
|
156 |
+
self.relu = nn.ReLU()
|
157 |
+
self.conv2 = nn.Conv2d(
|
158 |
+
out_channels, out_channels, kernel_size=1, stride=1, padding=0
|
159 |
+
)
|
160 |
+
self.sigmod = nn.Sigmoid()
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
x1, x2 = x # high, low (old, new)
|
164 |
+
x = torch.cat([x1, x2], dim=1)
|
165 |
+
x = self.global_pooling(x)
|
166 |
+
x = self.conv1(x)
|
167 |
+
x = self.relu(x)
|
168 |
+
x = self.conv2(x)
|
169 |
+
x = self.sigmod(x)
|
170 |
+
x2 = x * x2
|
171 |
+
res = x2 + x1
|
172 |
+
return res
|
173 |
+
|
174 |
+
|
175 |
+
class RRB(nn.Module):
|
176 |
+
def __init__(self, in_channels, out_channels, kernel_size=3):
|
177 |
+
super(RRB, self).__init__()
|
178 |
+
self.conv1 = nn.Conv2d(
|
179 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
180 |
+
)
|
181 |
+
self.conv2 = nn.Conv2d(
|
182 |
+
out_channels,
|
183 |
+
out_channels,
|
184 |
+
kernel_size=kernel_size,
|
185 |
+
stride=1,
|
186 |
+
padding=kernel_size // 2,
|
187 |
+
)
|
188 |
+
self.relu = nn.ReLU()
|
189 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
190 |
+
self.conv3 = nn.Conv2d(
|
191 |
+
out_channels,
|
192 |
+
out_channels,
|
193 |
+
kernel_size=kernel_size,
|
194 |
+
stride=1,
|
195 |
+
padding=kernel_size // 2,
|
196 |
+
)
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
x = self.conv1(x)
|
200 |
+
res = self.conv2(x)
|
201 |
+
res = self.bn(res)
|
202 |
+
res = self.relu(res)
|
203 |
+
res = self.conv3(res)
|
204 |
+
return self.relu(x + res)
|
205 |
+
|
206 |
+
|
207 |
+
class DFN(nn.Module):
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
internal_dim,
|
211 |
+
feat_input_modules,
|
212 |
+
pred_input_modules,
|
213 |
+
rrb_d_dict,
|
214 |
+
cab_dict,
|
215 |
+
rrb_u_dict,
|
216 |
+
use_global_context=False,
|
217 |
+
global_dim=None,
|
218 |
+
terminal_module=None,
|
219 |
+
upsample_mode="bilinear",
|
220 |
+
align_corners=False,
|
221 |
+
):
|
222 |
+
super().__init__()
|
223 |
+
if use_global_context:
|
224 |
+
assert (
|
225 |
+
global_dim is not None
|
226 |
+
), "Global dim must be provided when using global context"
|
227 |
+
self.align_corners = align_corners
|
228 |
+
self.internal_dim = internal_dim
|
229 |
+
self.feat_input_modules = feat_input_modules
|
230 |
+
self.pred_input_modules = pred_input_modules
|
231 |
+
self.rrb_d = rrb_d_dict
|
232 |
+
self.cab = cab_dict
|
233 |
+
self.rrb_u = rrb_u_dict
|
234 |
+
self.use_global_context = use_global_context
|
235 |
+
if use_global_context:
|
236 |
+
self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
|
237 |
+
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
238 |
+
self.terminal_module = (
|
239 |
+
terminal_module if terminal_module is not None else nn.Identity()
|
240 |
+
)
|
241 |
+
self.upsample_mode = upsample_mode
|
242 |
+
self._scales = [int(key) for key in self.terminal_module.keys()]
|
243 |
+
|
244 |
+
def scales(self):
|
245 |
+
return self._scales.copy()
|
246 |
+
|
247 |
+
def forward(self, embeddings, feats, context, key):
|
248 |
+
feats = self.feat_input_modules[str(key)](feats)
|
249 |
+
embeddings = torch.cat([feats, embeddings], dim=1)
|
250 |
+
embeddings = self.rrb_d[str(key)](embeddings)
|
251 |
+
context = self.cab[str(key)]([context, embeddings])
|
252 |
+
context = self.rrb_u[str(key)](context)
|
253 |
+
preds = self.terminal_module[str(key)](context)
|
254 |
+
pred_coord = preds[:, -2:]
|
255 |
+
pred_certainty = preds[:, :-2]
|
256 |
+
return pred_coord, pred_certainty, context
|
257 |
+
|
258 |
+
|
259 |
+
class GP(nn.Module):
|
260 |
+
def __init__(
|
261 |
+
self,
|
262 |
+
kernel,
|
263 |
+
T=1,
|
264 |
+
learn_temperature=False,
|
265 |
+
only_attention=False,
|
266 |
+
gp_dim=64,
|
267 |
+
basis="fourier",
|
268 |
+
covar_size=5,
|
269 |
+
only_nearest_neighbour=False,
|
270 |
+
sigma_noise=0.1,
|
271 |
+
no_cov=False,
|
272 |
+
predict_features = False,
|
273 |
+
):
|
274 |
+
super().__init__()
|
275 |
+
self.K = kernel(T=T, learn_temperature=learn_temperature)
|
276 |
+
self.sigma_noise = sigma_noise
|
277 |
+
self.covar_size = covar_size
|
278 |
+
self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
|
279 |
+
self.only_attention = only_attention
|
280 |
+
self.only_nearest_neighbour = only_nearest_neighbour
|
281 |
+
self.basis = basis
|
282 |
+
self.no_cov = no_cov
|
283 |
+
self.dim = gp_dim
|
284 |
+
self.predict_features = predict_features
|
285 |
+
|
286 |
+
def get_local_cov(self, cov):
|
287 |
+
K = self.covar_size
|
288 |
+
b, h, w, h, w = cov.shape
|
289 |
+
hw = h * w
|
290 |
+
cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
|
291 |
+
delta = torch.stack(
|
292 |
+
torch.meshgrid(
|
293 |
+
torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
|
294 |
+
),
|
295 |
+
dim=-1,
|
296 |
+
)
|
297 |
+
positions = torch.stack(
|
298 |
+
torch.meshgrid(
|
299 |
+
torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
|
300 |
+
),
|
301 |
+
dim=-1,
|
302 |
+
)
|
303 |
+
neighbours = positions[:, :, None, None, :] + delta[None, :, :]
|
304 |
+
points = torch.arange(hw)[:, None].expand(hw, K**2)
|
305 |
+
local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
|
306 |
+
:,
|
307 |
+
points.flatten(),
|
308 |
+
neighbours[..., 0].flatten(),
|
309 |
+
neighbours[..., 1].flatten(),
|
310 |
+
].reshape(b, h, w, K**2)
|
311 |
+
return local_cov
|
312 |
+
|
313 |
+
def reshape(self, x):
|
314 |
+
return rearrange(x, "b d h w -> b (h w) d")
|
315 |
+
|
316 |
+
def project_to_basis(self, x):
|
317 |
+
if self.basis == "fourier":
|
318 |
+
return torch.cos(8 * math.pi * self.pos_conv(x))
|
319 |
+
elif self.basis == "linear":
|
320 |
+
return self.pos_conv(x)
|
321 |
+
else:
|
322 |
+
raise ValueError(
|
323 |
+
"No other bases other than fourier and linear currently supported in public release"
|
324 |
+
)
|
325 |
+
|
326 |
+
def get_pos_enc(self, y):
|
327 |
+
b, c, h, w = y.shape
|
328 |
+
coarse_coords = torch.meshgrid(
|
329 |
+
(
|
330 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
|
331 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
|
332 |
+
)
|
333 |
+
)
|
334 |
+
|
335 |
+
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
336 |
+
None
|
337 |
+
].expand(b, h, w, 2)
|
338 |
+
coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
339 |
+
coarse_embedded_coords = self.project_to_basis(coarse_coords)
|
340 |
+
return coarse_embedded_coords
|
341 |
+
|
342 |
+
def forward(self, x, y, **kwargs):
|
343 |
+
b, c, h1, w1 = x.shape
|
344 |
+
b, c, h2, w2 = y.shape
|
345 |
+
f = self.get_pos_enc(y)
|
346 |
+
if self.predict_features:
|
347 |
+
f = f + y[:,:self.dim] # Stupid way to predict features
|
348 |
+
b, d, h2, w2 = f.shape
|
349 |
+
#assert x.shape == y.shape
|
350 |
+
x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
|
351 |
+
K_xx = self.K(x, x)
|
352 |
+
K_yy = self.K(y, y)
|
353 |
+
K_xy = self.K(x, y)
|
354 |
+
K_yx = K_xy.permute(0, 2, 1)
|
355 |
+
sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
|
356 |
+
# Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large
|
357 |
+
if len(K_yy[0]) > 2000:
|
358 |
+
K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)])
|
359 |
+
else:
|
360 |
+
K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
|
361 |
+
|
362 |
+
mu_x = K_xy.matmul(K_yy_inv.matmul(f))
|
363 |
+
mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
|
364 |
+
if not self.no_cov:
|
365 |
+
cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
|
366 |
+
cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
|
367 |
+
local_cov_x = self.get_local_cov(cov_x)
|
368 |
+
local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
|
369 |
+
gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
|
370 |
+
else:
|
371 |
+
gp_feats = mu_x
|
372 |
+
return gp_feats
|
373 |
+
|
374 |
+
|
375 |
+
class Encoder(nn.Module):
|
376 |
+
def __init__(self, resnet):
|
377 |
+
super().__init__()
|
378 |
+
self.resnet = resnet
|
379 |
+
def forward(self, x):
|
380 |
+
x0 = x
|
381 |
+
b, c, h, w = x.shape
|
382 |
+
x = self.resnet.conv1(x)
|
383 |
+
x = self.resnet.bn1(x)
|
384 |
+
x1 = self.resnet.relu(x)
|
385 |
+
|
386 |
+
x = self.resnet.maxpool(x1)
|
387 |
+
x2 = self.resnet.layer1(x)
|
388 |
+
|
389 |
+
x3 = self.resnet.layer2(x2)
|
390 |
+
|
391 |
+
x4 = self.resnet.layer3(x3)
|
392 |
+
|
393 |
+
x5 = self.resnet.layer4(x4)
|
394 |
+
feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}
|
395 |
+
return feats
|
396 |
+
|
397 |
+
def train(self, mode=True):
|
398 |
+
super().train(mode)
|
399 |
+
for m in self.modules():
|
400 |
+
if isinstance(m, nn.BatchNorm2d):
|
401 |
+
m.eval()
|
402 |
+
pass
|
403 |
+
|
404 |
+
|
405 |
+
class Decoder(nn.Module):
|
406 |
+
def __init__(
|
407 |
+
self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None,
|
408 |
+
):
|
409 |
+
super().__init__()
|
410 |
+
self.embedding_decoder = embedding_decoder
|
411 |
+
self.gps = gps
|
412 |
+
self.proj = proj
|
413 |
+
self.conv_refiner = conv_refiner
|
414 |
+
self.detach = detach
|
415 |
+
if scales == "all":
|
416 |
+
self.scales = ["32", "16", "8", "4", "2", "1"]
|
417 |
+
else:
|
418 |
+
self.scales = scales
|
419 |
+
|
420 |
+
def upsample_preds(self, flow, certainty, query, support):
|
421 |
+
b, hs, ws, d = flow.shape
|
422 |
+
b, c, h, w = query.shape
|
423 |
+
flow = flow.permute(0, 3, 1, 2)
|
424 |
+
certainty = F.interpolate(
|
425 |
+
certainty, size=(h, w), align_corners=False, mode="bilinear"
|
426 |
+
)
|
427 |
+
flow = F.interpolate(
|
428 |
+
flow, size=(h, w), align_corners=False, mode="bilinear"
|
429 |
+
)
|
430 |
+
delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
|
431 |
+
flow = torch.stack(
|
432 |
+
(
|
433 |
+
flow[:, 0] + delta_flow[:, 0] / (4 * w),
|
434 |
+
flow[:, 1] + delta_flow[:, 1] / (4 * h),
|
435 |
+
),
|
436 |
+
dim=1,
|
437 |
+
)
|
438 |
+
flow = flow.permute(0, 2, 3, 1)
|
439 |
+
certainty = certainty + delta_certainty
|
440 |
+
return flow, certainty
|
441 |
+
|
442 |
+
def get_placeholder_flow(self, b, h, w, device):
|
443 |
+
coarse_coords = torch.meshgrid(
|
444 |
+
(
|
445 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
446 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
447 |
+
)
|
448 |
+
)
|
449 |
+
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
450 |
+
None
|
451 |
+
].expand(b, h, w, 2)
|
452 |
+
coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
453 |
+
return coarse_coords
|
454 |
+
|
455 |
+
|
456 |
+
def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None):
|
457 |
+
coarse_scales = self.embedding_decoder.scales()
|
458 |
+
all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
|
459 |
+
sizes = {scale: f1[scale].shape[-2:] for scale in f1}
|
460 |
+
h, w = sizes[1]
|
461 |
+
b = f1[1].shape[0]
|
462 |
+
device = f1[1].device
|
463 |
+
coarsest_scale = int(all_scales[0])
|
464 |
+
old_stuff = torch.zeros(
|
465 |
+
b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
|
466 |
+
)
|
467 |
+
dense_corresps = {}
|
468 |
+
if not upsample:
|
469 |
+
dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
|
470 |
+
dense_certainty = 0.0
|
471 |
+
else:
|
472 |
+
dense_flow = F.interpolate(
|
473 |
+
dense_flow,
|
474 |
+
size=sizes[coarsest_scale],
|
475 |
+
align_corners=False,
|
476 |
+
mode="bilinear",
|
477 |
+
)
|
478 |
+
dense_certainty = F.interpolate(
|
479 |
+
dense_certainty,
|
480 |
+
size=sizes[coarsest_scale],
|
481 |
+
align_corners=False,
|
482 |
+
mode="bilinear",
|
483 |
+
)
|
484 |
+
for new_scale in all_scales:
|
485 |
+
ins = int(new_scale)
|
486 |
+
f1_s, f2_s = f1[ins], f2[ins]
|
487 |
+
if new_scale in self.proj:
|
488 |
+
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
489 |
+
b, c, hs, ws = f1_s.shape
|
490 |
+
if ins in coarse_scales:
|
491 |
+
old_stuff = F.interpolate(
|
492 |
+
old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
|
493 |
+
)
|
494 |
+
new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
|
495 |
+
dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
|
496 |
+
new_stuff, f1_s, old_stuff, new_scale
|
497 |
+
)
|
498 |
+
|
499 |
+
if new_scale in self.conv_refiner:
|
500 |
+
delta_certainty, displacement = self.conv_refiner[new_scale](
|
501 |
+
f1_s, f2_s, dense_flow
|
502 |
+
)
|
503 |
+
dense_flow = torch.stack(
|
504 |
+
(
|
505 |
+
dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w),
|
506 |
+
dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h),
|
507 |
+
),
|
508 |
+
dim=1,
|
509 |
+
)
|
510 |
+
dense_certainty = (
|
511 |
+
dense_certainty + delta_certainty
|
512 |
+
) # predict both certainty and displacement
|
513 |
+
|
514 |
+
dense_corresps[ins] = {
|
515 |
+
"dense_flow": dense_flow,
|
516 |
+
"dense_certainty": dense_certainty,
|
517 |
+
}
|
518 |
+
|
519 |
+
if new_scale != "1":
|
520 |
+
dense_flow = F.interpolate(
|
521 |
+
dense_flow,
|
522 |
+
size=sizes[ins // 2],
|
523 |
+
align_corners=False,
|
524 |
+
mode="bilinear",
|
525 |
+
)
|
526 |
+
|
527 |
+
dense_certainty = F.interpolate(
|
528 |
+
dense_certainty,
|
529 |
+
size=sizes[ins // 2],
|
530 |
+
align_corners=False,
|
531 |
+
mode="bilinear",
|
532 |
+
)
|
533 |
+
if self.detach:
|
534 |
+
dense_flow = dense_flow.detach()
|
535 |
+
dense_certainty = dense_certainty.detach()
|
536 |
+
return dense_corresps
|
537 |
+
|
538 |
+
|
539 |
+
class RegressionMatcher(nn.Module):
|
540 |
+
def __init__(
|
541 |
+
self,
|
542 |
+
encoder,
|
543 |
+
decoder,
|
544 |
+
h=384,
|
545 |
+
w=512,
|
546 |
+
use_contrastive_loss = False,
|
547 |
+
alpha = 1,
|
548 |
+
beta = 0,
|
549 |
+
sample_mode = "threshold",
|
550 |
+
upsample_preds = False,
|
551 |
+
symmetric = False,
|
552 |
+
name = None,
|
553 |
+
use_soft_mutual_nearest_neighbours = False,
|
554 |
+
):
|
555 |
+
super().__init__()
|
556 |
+
self.encoder = encoder
|
557 |
+
self.decoder = decoder
|
558 |
+
self.w_resized = w
|
559 |
+
self.h_resized = h
|
560 |
+
self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
|
561 |
+
self.use_contrastive_loss = use_contrastive_loss
|
562 |
+
self.alpha = alpha
|
563 |
+
self.beta = beta
|
564 |
+
self.sample_mode = sample_mode
|
565 |
+
self.upsample_preds = upsample_preds
|
566 |
+
self.symmetric = symmetric
|
567 |
+
self.name = name
|
568 |
+
self.sample_thresh = 0.05
|
569 |
+
self.upsample_res = (864,1152)
|
570 |
+
if use_soft_mutual_nearest_neighbours:
|
571 |
+
assert symmetric, "MNS requires symmetric inference"
|
572 |
+
self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours
|
573 |
+
|
574 |
+
def extract_backbone_features(self, batch, batched = True, upsample = True):
|
575 |
+
#TODO: only extract stride [1,2,4,8] for upsample = True
|
576 |
+
x_q = batch["query"]
|
577 |
+
x_s = batch["support"]
|
578 |
+
if batched:
|
579 |
+
X = torch.cat((x_q, x_s))
|
580 |
+
feature_pyramid = self.encoder(X)
|
581 |
+
else:
|
582 |
+
feature_pyramid = self.encoder(x_q), self.encoder(x_s)
|
583 |
+
return feature_pyramid
|
584 |
+
|
585 |
+
def sample(
|
586 |
+
self,
|
587 |
+
dense_matches,
|
588 |
+
dense_certainty,
|
589 |
+
num=10000,
|
590 |
+
):
|
591 |
+
if "threshold" in self.sample_mode:
|
592 |
+
upper_thresh = self.sample_thresh
|
593 |
+
dense_certainty = dense_certainty.clone()
|
594 |
+
dense_certainty[dense_certainty > upper_thresh] = 1
|
595 |
+
elif "pow" in self.sample_mode:
|
596 |
+
dense_certainty = dense_certainty**(1/3)
|
597 |
+
elif "naive" in self.sample_mode:
|
598 |
+
dense_certainty = torch.ones_like(dense_certainty)
|
599 |
+
matches, certainty = (
|
600 |
+
dense_matches.reshape(-1, 4),
|
601 |
+
dense_certainty.reshape(-1),
|
602 |
+
)
|
603 |
+
expansion_factor = 4 if "balanced" in self.sample_mode else 1
|
604 |
+
good_samples = torch.multinomial(certainty,
|
605 |
+
num_samples = min(expansion_factor*num, len(certainty)),
|
606 |
+
replacement=False)
|
607 |
+
good_matches, good_certainty = matches[good_samples], certainty[good_samples]
|
608 |
+
if "balanced" not in self.sample_mode:
|
609 |
+
return good_matches, good_certainty
|
610 |
+
|
611 |
+
from ..utils.kde import kde
|
612 |
+
density = kde(good_matches, std=0.1)
|
613 |
+
p = 1 / (density+1)
|
614 |
+
p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
|
615 |
+
balanced_samples = torch.multinomial(p,
|
616 |
+
num_samples = min(num,len(good_certainty)),
|
617 |
+
replacement=False)
|
618 |
+
return good_matches[balanced_samples], good_certainty[balanced_samples]
|
619 |
+
|
620 |
+
def forward(self, batch, batched = True):
|
621 |
+
feature_pyramid = self.extract_backbone_features(batch, batched=batched)
|
622 |
+
if batched:
|
623 |
+
f_q_pyramid = {
|
624 |
+
scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
|
625 |
+
}
|
626 |
+
f_s_pyramid = {
|
627 |
+
scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
|
628 |
+
}
|
629 |
+
else:
|
630 |
+
f_q_pyramid, f_s_pyramid = feature_pyramid
|
631 |
+
dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
|
632 |
+
if self.training and self.use_contrastive_loss:
|
633 |
+
return dense_corresps, (f_q_pyramid, f_s_pyramid)
|
634 |
+
else:
|
635 |
+
return dense_corresps
|
636 |
+
|
637 |
+
def forward_symmetric(self, batch, upsample = False, batched = True):
|
638 |
+
feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched)
|
639 |
+
f_q_pyramid = feature_pyramid
|
640 |
+
f_s_pyramid = {
|
641 |
+
scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
|
642 |
+
for scale, f_scale in feature_pyramid.items()
|
643 |
+
}
|
644 |
+
dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {}))
|
645 |
+
return dense_corresps
|
646 |
+
|
647 |
+
def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
|
648 |
+
kpts_A, kpts_B = matches[...,:2], matches[...,2:]
|
649 |
+
kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
|
650 |
+
kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
|
651 |
+
return kpts_A, kpts_B
|
652 |
+
|
653 |
+
def match(
|
654 |
+
self,
|
655 |
+
im1_path,
|
656 |
+
im2_path,
|
657 |
+
*args,
|
658 |
+
batched=False,
|
659 |
+
device = None
|
660 |
+
):
|
661 |
+
assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
|
662 |
+
if isinstance(im1_path, (str, os.PathLike)):
|
663 |
+
im1, im2 = Image.open(im1_path), Image.open(im2_path)
|
664 |
+
else: # assume it is a PIL Image
|
665 |
+
im1, im2 = im1_path, im2_path
|
666 |
+
if device is None:
|
667 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
668 |
+
symmetric = self.symmetric
|
669 |
+
self.train(False)
|
670 |
+
with torch.no_grad():
|
671 |
+
if not batched:
|
672 |
+
b = 1
|
673 |
+
w, h = im1.size
|
674 |
+
w2, h2 = im2.size
|
675 |
+
# Get images in good format
|
676 |
+
ws = self.w_resized
|
677 |
+
hs = self.h_resized
|
678 |
+
|
679 |
+
test_transform = get_tuple_transform_ops(
|
680 |
+
resize=(hs, ws), normalize=True
|
681 |
+
)
|
682 |
+
query, support = test_transform((im1, im2))
|
683 |
+
batch = {"query": query[None].to(device), "support": support[None].to(device)}
|
684 |
+
else:
|
685 |
+
b, c, h, w = im1.shape
|
686 |
+
b, c, h2, w2 = im2.shape
|
687 |
+
assert w == w2 and h == h2, "For batched images we assume same size"
|
688 |
+
batch = {"query": im1.to(device), "support": im2.to(device)}
|
689 |
+
hs, ws = self.h_resized, self.w_resized
|
690 |
+
finest_scale = 1
|
691 |
+
# Run matcher
|
692 |
+
if symmetric:
|
693 |
+
dense_corresps = self.forward_symmetric(batch, batched = True)
|
694 |
+
else:
|
695 |
+
dense_corresps = self.forward(batch, batched = True)
|
696 |
+
|
697 |
+
if self.upsample_preds:
|
698 |
+
hs, ws = self.upsample_res
|
699 |
+
low_res_certainty = F.interpolate(
|
700 |
+
dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
701 |
+
)
|
702 |
+
cert_clamp = 0
|
703 |
+
factor = 0.5
|
704 |
+
low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
|
705 |
+
|
706 |
+
if self.upsample_preds:
|
707 |
+
test_transform = get_tuple_transform_ops(
|
708 |
+
resize=(hs, ws), normalize=True
|
709 |
+
)
|
710 |
+
query, support = test_transform((im1, im2))
|
711 |
+
query, support = query[None].to(device), support[None].to(device)
|
712 |
+
batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]}
|
713 |
+
if symmetric:
|
714 |
+
dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True)
|
715 |
+
else:
|
716 |
+
dense_corresps = self.forward(batch, batched = True, upsample=True)
|
717 |
+
query_to_support = dense_corresps[finest_scale]["dense_flow"]
|
718 |
+
dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
|
719 |
+
|
720 |
+
# Get certainty interpolation
|
721 |
+
dense_certainty = dense_certainty - low_res_certainty
|
722 |
+
query_to_support = query_to_support.permute(
|
723 |
+
0, 2, 3, 1
|
724 |
+
)
|
725 |
+
# Create im1 meshgrid
|
726 |
+
query_coords = torch.meshgrid(
|
727 |
+
(
|
728 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
729 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
730 |
+
)
|
731 |
+
)
|
732 |
+
query_coords = torch.stack((query_coords[1], query_coords[0]))
|
733 |
+
query_coords = query_coords[None].expand(b, 2, hs, ws)
|
734 |
+
dense_certainty = dense_certainty.sigmoid() # logits -> probs
|
735 |
+
query_coords = query_coords.permute(0, 2, 3, 1)
|
736 |
+
if (query_to_support.abs() > 1).any() and True:
|
737 |
+
wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
|
738 |
+
dense_certainty[wrong[:,None]] = 0
|
739 |
+
|
740 |
+
query_to_support = torch.clamp(query_to_support, -1, 1)
|
741 |
+
if symmetric:
|
742 |
+
support_coords = query_coords
|
743 |
+
qts, stq = query_to_support.chunk(2)
|
744 |
+
q_warp = torch.cat((query_coords, qts), dim=-1)
|
745 |
+
s_warp = torch.cat((stq, support_coords), dim=-1)
|
746 |
+
warp = torch.cat((q_warp, s_warp),dim=2)
|
747 |
+
dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0]
|
748 |
+
else:
|
749 |
+
warp = torch.cat((query_coords, query_to_support), dim=-1)
|
750 |
+
if batched:
|
751 |
+
return (
|
752 |
+
warp,
|
753 |
+
dense_certainty
|
754 |
+
)
|
755 |
+
else:
|
756 |
+
return (
|
757 |
+
warp[0],
|
758 |
+
dense_certainty[0],
|
759 |
+
)
|
third_party/DKM/dkm/models/encoders.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as tvm
|
5 |
+
|
6 |
+
class ResNet18(nn.Module):
|
7 |
+
def __init__(self, pretrained=False) -> None:
|
8 |
+
super().__init__()
|
9 |
+
self.net = tvm.resnet18(pretrained=pretrained)
|
10 |
+
def forward(self, x):
|
11 |
+
self = self.net
|
12 |
+
x1 = x
|
13 |
+
x = self.conv1(x1)
|
14 |
+
x = self.bn1(x)
|
15 |
+
x2 = self.relu(x)
|
16 |
+
x = self.maxpool(x2)
|
17 |
+
x4 = self.layer1(x)
|
18 |
+
x8 = self.layer2(x4)
|
19 |
+
x16 = self.layer3(x8)
|
20 |
+
x32 = self.layer4(x16)
|
21 |
+
return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1}
|
22 |
+
|
23 |
+
def train(self, mode=True):
|
24 |
+
super().train(mode)
|
25 |
+
for m in self.modules():
|
26 |
+
if isinstance(m, nn.BatchNorm2d):
|
27 |
+
m.eval()
|
28 |
+
pass
|
29 |
+
|
30 |
+
class ResNet50(nn.Module):
|
31 |
+
def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
|
32 |
+
super().__init__()
|
33 |
+
if dilation is None:
|
34 |
+
dilation = [False,False,False]
|
35 |
+
if anti_aliased:
|
36 |
+
pass
|
37 |
+
else:
|
38 |
+
if weights is not None:
|
39 |
+
self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
|
40 |
+
else:
|
41 |
+
self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
|
42 |
+
|
43 |
+
self.high_res = high_res
|
44 |
+
self.freeze_bn = freeze_bn
|
45 |
+
def forward(self, x):
|
46 |
+
net = self.net
|
47 |
+
feats = {1:x}
|
48 |
+
x = net.conv1(x)
|
49 |
+
x = net.bn1(x)
|
50 |
+
x = net.relu(x)
|
51 |
+
feats[2] = x
|
52 |
+
x = net.maxpool(x)
|
53 |
+
x = net.layer1(x)
|
54 |
+
feats[4] = x
|
55 |
+
x = net.layer2(x)
|
56 |
+
feats[8] = x
|
57 |
+
x = net.layer3(x)
|
58 |
+
feats[16] = x
|
59 |
+
x = net.layer4(x)
|
60 |
+
feats[32] = x
|
61 |
+
return feats
|
62 |
+
|
63 |
+
def train(self, mode=True):
|
64 |
+
super().train(mode)
|
65 |
+
if self.freeze_bn:
|
66 |
+
for m in self.modules():
|
67 |
+
if isinstance(m, nn.BatchNorm2d):
|
68 |
+
m.eval()
|
69 |
+
pass
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
class ResNet101(nn.Module):
|
75 |
+
def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
|
76 |
+
super().__init__()
|
77 |
+
if weights is not None:
|
78 |
+
self.net = tvm.resnet101(weights = weights)
|
79 |
+
else:
|
80 |
+
self.net = tvm.resnet101(pretrained=pretrained)
|
81 |
+
self.high_res = high_res
|
82 |
+
self.scale_factor = 1 if not high_res else 1.5
|
83 |
+
def forward(self, x):
|
84 |
+
net = self.net
|
85 |
+
feats = {1:x}
|
86 |
+
sf = self.scale_factor
|
87 |
+
if self.high_res:
|
88 |
+
x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
|
89 |
+
x = net.conv1(x)
|
90 |
+
x = net.bn1(x)
|
91 |
+
x = net.relu(x)
|
92 |
+
feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
93 |
+
x = net.maxpool(x)
|
94 |
+
x = net.layer1(x)
|
95 |
+
feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
96 |
+
x = net.layer2(x)
|
97 |
+
feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
98 |
+
x = net.layer3(x)
|
99 |
+
feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
100 |
+
x = net.layer4(x)
|
101 |
+
feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
102 |
+
return feats
|
103 |
+
|
104 |
+
def train(self, mode=True):
|
105 |
+
super().train(mode)
|
106 |
+
for m in self.modules():
|
107 |
+
if isinstance(m, nn.BatchNorm2d):
|
108 |
+
m.eval()
|
109 |
+
pass
|
110 |
+
|
111 |
+
|
112 |
+
class WideResNet50(nn.Module):
|
113 |
+
def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
|
114 |
+
super().__init__()
|
115 |
+
if weights is not None:
|
116 |
+
self.net = tvm.wide_resnet50_2(weights = weights)
|
117 |
+
else:
|
118 |
+
self.net = tvm.wide_resnet50_2(pretrained=pretrained)
|
119 |
+
self.high_res = high_res
|
120 |
+
self.scale_factor = 1 if not high_res else 1.5
|
121 |
+
def forward(self, x):
|
122 |
+
net = self.net
|
123 |
+
feats = {1:x}
|
124 |
+
sf = self.scale_factor
|
125 |
+
if self.high_res:
|
126 |
+
x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
|
127 |
+
x = net.conv1(x)
|
128 |
+
x = net.bn1(x)
|
129 |
+
x = net.relu(x)
|
130 |
+
feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
131 |
+
x = net.maxpool(x)
|
132 |
+
x = net.layer1(x)
|
133 |
+
feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
134 |
+
x = net.layer2(x)
|
135 |
+
feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
136 |
+
x = net.layer3(x)
|
137 |
+
feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
138 |
+
x = net.layer4(x)
|
139 |
+
feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
140 |
+
return feats
|
141 |
+
|
142 |
+
def train(self, mode=True):
|
143 |
+
super().train(mode)
|
144 |
+
for m in self.modules():
|
145 |
+
if isinstance(m, nn.BatchNorm2d):
|
146 |
+
m.eval()
|
147 |
+
pass
|