Vincentqyw
committed on
Commit
·
8b973ee
1
Parent(s):
62c7319
fix: roma
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +3 -2
- third_party/ALIKE/alike.py +91 -36
- third_party/ALIKE/alnet.py +66 -36
- third_party/ALIKE/demo.py +82 -48
- third_party/ALIKE/hseq/eval.py +71 -36
- third_party/ALIKE/hseq/extract.py +45 -29
- third_party/ALIKE/soft_detect.py +72 -32
- third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py +5 -4
- third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py +4 -3
- third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py +6 -5
- third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py +4 -3
- third_party/ASpanFormer/configs/data/base.py +1 -0
- third_party/ASpanFormer/configs/data/megadepth_test_1500.py +3 -3
- third_party/ASpanFormer/configs/data/megadepth_trainval_832.py +7 -3
- third_party/ASpanFormer/configs/data/scannet_trainval.py +7 -3
- third_party/ASpanFormer/demo/demo.py +68 -40
- third_party/ASpanFormer/demo/demo_utils.py +71 -27
- third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py +1 -1
- third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py +224 -110
- third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py +36 -20
- third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py +22 -26
- third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py +247 -140
- third_party/ASpanFormer/src/ASpanFormer/aspanformer.py +107 -62
- third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py +8 -6
- third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py +36 -21
- third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py +168 -132
- third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py +6 -6
- third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py +32 -22
- third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py +29 -10
- third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py +36 -17
- third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py +62 -41
- third_party/ASpanFormer/src/config/default.py +50 -31
- third_party/ASpanFormer/src/datasets/__init__.py +0 -1
- third_party/ASpanFormer/src/datasets/megadepth.py +83 -56
- third_party/ASpanFormer/src/datasets/sampler.py +33 -20
- third_party/ASpanFormer/src/datasets/scannet.py +52 -42
- third_party/ASpanFormer/src/lightning/data.py +222 -143
- third_party/ASpanFormer/src/lightning/lightning_aspanformer.py +218 -120
- third_party/ASpanFormer/src/losses/aspan_loss.py +155 -97
- third_party/ASpanFormer/src/optimizers/__init__.py +22 -9
- third_party/ASpanFormer/src/utils/augment.py +33 -23
- third_party/ASpanFormer/src/utils/comm.py +12 -7
- third_party/ASpanFormer/src/utils/dataloader.py +8 -7
- third_party/ASpanFormer/src/utils/dataset.py +48 -38
- third_party/ASpanFormer/src/utils/metrics.py +100 -67
- third_party/ASpanFormer/src/utils/misc.py +83 -38
- third_party/ASpanFormer/src/utils/plotting.py +128 -94
- third_party/ASpanFormer/src/utils/profiler.py +5 -4
- third_party/ASpanFormer/test.py +43 -17
- third_party/ASpanFormer/tools/extract.py +59 -25
app.py
CHANGED
@@ -9,9 +9,10 @@ from extra_utils.utils import (
|
|
9 |
match_features,
|
10 |
get_model,
|
11 |
get_feature_model,
|
12 |
-
display_matches
|
13 |
)
|
14 |
|
|
|
15 |
def run_matching(
|
16 |
match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
|
17 |
):
|
@@ -277,7 +278,7 @@ def run(config):
|
|
277 |
matcher_info,
|
278 |
]
|
279 |
button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
|
280 |
-
|
281 |
app.launch(share=False)
|
282 |
|
283 |
|
|
|
9 |
match_features,
|
10 |
get_model,
|
11 |
get_feature_model,
|
12 |
+
display_matches,
|
13 |
)
|
14 |
|
15 |
+
|
16 |
def run_matching(
|
17 |
match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
|
18 |
):
|
|
|
278 |
matcher_info,
|
279 |
]
|
280 |
button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
|
281 |
+
|
282 |
app.launch(share=False)
|
283 |
|
284 |
|
third_party/ALIKE/alike.py
CHANGED
@@ -12,46 +12,89 @@ from soft_detect import DKD
|
|
12 |
import time
|
13 |
|
14 |
configs = {
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
}
|
24 |
|
25 |
|
26 |
class ALike(ALNet):
|
27 |
-
def __init__(
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
super().__init__(c1, c2, c3, c4, dim, single_head)
|
39 |
self.radius = radius
|
40 |
self.top_k = top_k
|
41 |
self.n_limit = n_limit
|
42 |
self.scores_th = scores_th
|
43 |
-
self.dkd = DKD(
|
44 |
-
|
|
|
|
|
|
|
|
|
45 |
self.device = device
|
46 |
|
47 |
-
if model_path !=
|
48 |
state_dict = torch.load(model_path, self.device)
|
49 |
self.load_state_dict(state_dict)
|
50 |
self.to(self.device)
|
51 |
self.eval()
|
52 |
-
logging.info(f
|
53 |
logging.info(
|
54 |
-
f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB"
|
|
|
55 |
|
56 |
def extract_dense_map(self, image, ret_dict=False):
|
57 |
# ====================================================
|
@@ -81,7 +124,10 @@ class ALike(ALNet):
|
|
81 |
descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1)
|
82 |
|
83 |
if ret_dict:
|
84 |
-
return {
|
|
|
|
|
|
|
85 |
else:
|
86 |
return descriptor_map, scores_map
|
87 |
|
@@ -104,15 +150,22 @@ class ALike(ALNet):
|
|
104 |
image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
|
105 |
|
106 |
# ==================== convert image to tensor
|
107 |
-
image =
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
# ==================== extract keypoints
|
110 |
start = time.time()
|
111 |
|
112 |
with torch.no_grad():
|
113 |
descriptor_map, scores_map = self.extract_dense_map(image)
|
114 |
-
keypoints, descriptors, scores, _ = self.dkd(
|
115 |
-
|
|
|
116 |
keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0]
|
117 |
keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]])
|
118 |
|
@@ -124,14 +177,16 @@ class ALike(ALNet):
|
|
124 |
|
125 |
end = time.time()
|
126 |
|
127 |
-
return {
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
132 |
|
133 |
|
134 |
-
if __name__ ==
|
135 |
import numpy as np
|
136 |
from thop import profile
|
137 |
|
@@ -139,5 +194,5 @@ if __name__ == '__main__':
|
|
139 |
|
140 |
image = np.random.random((640, 480, 3)).astype(np.float32)
|
141 |
flops, params = profile(net, inputs=(image, 9999, False), verbose=False)
|
142 |
-
print(
|
143 |
-
print(
|
|
|
12 |
import time
|
13 |
|
14 |
configs = {
|
15 |
+
"alike-t": {
|
16 |
+
"c1": 8,
|
17 |
+
"c2": 16,
|
18 |
+
"c3": 32,
|
19 |
+
"c4": 64,
|
20 |
+
"dim": 64,
|
21 |
+
"single_head": True,
|
22 |
+
"radius": 2,
|
23 |
+
"model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-t.pth"),
|
24 |
+
},
|
25 |
+
"alike-s": {
|
26 |
+
"c1": 8,
|
27 |
+
"c2": 16,
|
28 |
+
"c3": 48,
|
29 |
+
"c4": 96,
|
30 |
+
"dim": 96,
|
31 |
+
"single_head": True,
|
32 |
+
"radius": 2,
|
33 |
+
"model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-s.pth"),
|
34 |
+
},
|
35 |
+
"alike-n": {
|
36 |
+
"c1": 16,
|
37 |
+
"c2": 32,
|
38 |
+
"c3": 64,
|
39 |
+
"c4": 128,
|
40 |
+
"dim": 128,
|
41 |
+
"single_head": True,
|
42 |
+
"radius": 2,
|
43 |
+
"model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-n.pth"),
|
44 |
+
},
|
45 |
+
"alike-l": {
|
46 |
+
"c1": 32,
|
47 |
+
"c2": 64,
|
48 |
+
"c3": 128,
|
49 |
+
"c4": 128,
|
50 |
+
"dim": 128,
|
51 |
+
"single_head": False,
|
52 |
+
"radius": 2,
|
53 |
+
"model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-l.pth"),
|
54 |
+
},
|
55 |
}
|
56 |
|
57 |
|
58 |
class ALike(ALNet):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
# ================================== feature encoder
|
62 |
+
c1: int = 32,
|
63 |
+
c2: int = 64,
|
64 |
+
c3: int = 128,
|
65 |
+
c4: int = 128,
|
66 |
+
dim: int = 128,
|
67 |
+
single_head: bool = False,
|
68 |
+
# ================================== detect parameters
|
69 |
+
radius: int = 2,
|
70 |
+
top_k: int = 500,
|
71 |
+
scores_th: float = 0.5,
|
72 |
+
n_limit: int = 5000,
|
73 |
+
device: str = "cpu",
|
74 |
+
model_path: str = "",
|
75 |
+
):
|
76 |
super().__init__(c1, c2, c3, c4, dim, single_head)
|
77 |
self.radius = radius
|
78 |
self.top_k = top_k
|
79 |
self.n_limit = n_limit
|
80 |
self.scores_th = scores_th
|
81 |
+
self.dkd = DKD(
|
82 |
+
radius=self.radius,
|
83 |
+
top_k=self.top_k,
|
84 |
+
scores_th=self.scores_th,
|
85 |
+
n_limit=self.n_limit,
|
86 |
+
)
|
87 |
self.device = device
|
88 |
|
89 |
+
if model_path != "":
|
90 |
state_dict = torch.load(model_path, self.device)
|
91 |
self.load_state_dict(state_dict)
|
92 |
self.to(self.device)
|
93 |
self.eval()
|
94 |
+
logging.info(f"Loaded model parameters from {model_path}")
|
95 |
logging.info(
|
96 |
+
f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB"
|
97 |
+
)
|
98 |
|
99 |
def extract_dense_map(self, image, ret_dict=False):
|
100 |
# ====================================================
|
|
|
124 |
descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1)
|
125 |
|
126 |
if ret_dict:
|
127 |
+
return {
|
128 |
+
"descriptor_map": descriptor_map,
|
129 |
+
"scores_map": scores_map,
|
130 |
+
}
|
131 |
else:
|
132 |
return descriptor_map, scores_map
|
133 |
|
|
|
150 |
image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
|
151 |
|
152 |
# ==================== convert image to tensor
|
153 |
+
image = (
|
154 |
+
torch.from_numpy(image)
|
155 |
+
.to(self.device)
|
156 |
+
.to(torch.float32)
|
157 |
+
.permute(2, 0, 1)[None]
|
158 |
+
/ 255.0
|
159 |
+
)
|
160 |
|
161 |
# ==================== extract keypoints
|
162 |
start = time.time()
|
163 |
|
164 |
with torch.no_grad():
|
165 |
descriptor_map, scores_map = self.extract_dense_map(image)
|
166 |
+
keypoints, descriptors, scores, _ = self.dkd(
|
167 |
+
scores_map, descriptor_map, sub_pixel=sub_pixel
|
168 |
+
)
|
169 |
keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0]
|
170 |
keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]])
|
171 |
|
|
|
177 |
|
178 |
end = time.time()
|
179 |
|
180 |
+
return {
|
181 |
+
"keypoints": keypoints.cpu().numpy(),
|
182 |
+
"descriptors": descriptors.cpu().numpy(),
|
183 |
+
"scores": scores.cpu().numpy(),
|
184 |
+
"scores_map": scores_map.cpu().numpy(),
|
185 |
+
"time": end - start,
|
186 |
+
}
|
187 |
|
188 |
|
189 |
+
if __name__ == "__main__":
|
190 |
import numpy as np
|
191 |
from thop import profile
|
192 |
|
|
|
194 |
|
195 |
image = np.random.random((640, 480, 3)).astype(np.float32)
|
196 |
flops, params = profile(net, inputs=(image, 9999, False), verbose=False)
|
197 |
+
print("{:<30} {:<8} GFLops".format("Computational complexity: ", flops / 1e9))
|
198 |
+
print("{:<30} {:<8} KB".format("Number of parameters: ", params / 1e3))
|
third_party/ALIKE/alnet.py
CHANGED
@@ -5,9 +5,13 @@ from typing import Optional, Callable
|
|
5 |
|
6 |
|
7 |
class ConvBlock(nn.Module):
|
8 |
-
def __init__(
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
super().__init__()
|
12 |
if gate is None:
|
13 |
self.gate = nn.ReLU(inplace=True)
|
@@ -31,16 +35,16 @@ class ResBlock(nn.Module):
|
|
31 |
expansion: int = 1
|
32 |
|
33 |
def __init__(
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
) -> None:
|
45 |
super(ResBlock, self).__init__()
|
46 |
if gate is None:
|
@@ -50,7 +54,7 @@ class ResBlock(nn.Module):
|
|
50 |
if norm_layer is None:
|
51 |
norm_layer = nn.BatchNorm2d
|
52 |
if groups != 1 or base_width != 64:
|
53 |
-
raise ValueError(
|
54 |
if dilation > 1:
|
55 |
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
|
56 |
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
@@ -81,9 +85,15 @@ class ResBlock(nn.Module):
|
|
81 |
|
82 |
|
83 |
class ALNet(nn.Module):
|
84 |
-
def __init__(
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
super().__init__()
|
88 |
|
89 |
self.gate = nn.ReLU(inplace=True)
|
@@ -93,28 +103,48 @@ class ALNet(nn.Module):
|
|
93 |
|
94 |
self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
|
95 |
|
96 |
-
self.block2 = ResBlock(
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
self.
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
# ================================== feature aggregation
|
110 |
self.conv1 = resnet.conv1x1(c1, dim // 4)
|
111 |
self.conv2 = resnet.conv1x1(c2, dim // 4)
|
112 |
self.conv3 = resnet.conv1x1(c3, dim // 4)
|
113 |
self.conv4 = resnet.conv1x1(dim, dim // 4)
|
114 |
-
self.upsample2 = nn.Upsample(
|
115 |
-
|
116 |
-
|
117 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
# ================================== detector and descriptor head
|
120 |
self.single_head = single_head
|
@@ -153,12 +183,12 @@ class ALNet(nn.Module):
|
|
153 |
return scores_map, descriptor_map
|
154 |
|
155 |
|
156 |
-
if __name__ ==
|
157 |
from thop import profile
|
158 |
|
159 |
net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)
|
160 |
|
161 |
image = torch.randn(1, 3, 640, 480)
|
162 |
flops, params = profile(net, inputs=(image,), verbose=False)
|
163 |
-
print(
|
164 |
-
print(
|
|
|
5 |
|
6 |
|
7 |
class ConvBlock(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
in_channels,
|
11 |
+
out_channels,
|
12 |
+
gate: Optional[Callable[..., nn.Module]] = None,
|
13 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
14 |
+
):
|
15 |
super().__init__()
|
16 |
if gate is None:
|
17 |
self.gate = nn.ReLU(inplace=True)
|
|
|
35 |
expansion: int = 1
|
36 |
|
37 |
def __init__(
|
38 |
+
self,
|
39 |
+
inplanes: int,
|
40 |
+
planes: int,
|
41 |
+
stride: int = 1,
|
42 |
+
downsample: Optional[nn.Module] = None,
|
43 |
+
groups: int = 1,
|
44 |
+
base_width: int = 64,
|
45 |
+
dilation: int = 1,
|
46 |
+
gate: Optional[Callable[..., nn.Module]] = None,
|
47 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
48 |
) -> None:
|
49 |
super(ResBlock, self).__init__()
|
50 |
if gate is None:
|
|
|
54 |
if norm_layer is None:
|
55 |
norm_layer = nn.BatchNorm2d
|
56 |
if groups != 1 or base_width != 64:
|
57 |
+
raise ValueError("ResBlock only supports groups=1 and base_width=64")
|
58 |
if dilation > 1:
|
59 |
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
|
60 |
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
|
|
85 |
|
86 |
|
87 |
class ALNet(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
c1: int = 32,
|
91 |
+
c2: int = 64,
|
92 |
+
c3: int = 128,
|
93 |
+
c4: int = 128,
|
94 |
+
dim: int = 128,
|
95 |
+
single_head: bool = True,
|
96 |
+
):
|
97 |
super().__init__()
|
98 |
|
99 |
self.gate = nn.ReLU(inplace=True)
|
|
|
103 |
|
104 |
self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
|
105 |
|
106 |
+
self.block2 = ResBlock(
|
107 |
+
inplanes=c1,
|
108 |
+
planes=c2,
|
109 |
+
stride=1,
|
110 |
+
downsample=nn.Conv2d(c1, c2, 1),
|
111 |
+
gate=self.gate,
|
112 |
+
norm_layer=nn.BatchNorm2d,
|
113 |
+
)
|
114 |
+
self.block3 = ResBlock(
|
115 |
+
inplanes=c2,
|
116 |
+
planes=c3,
|
117 |
+
stride=1,
|
118 |
+
downsample=nn.Conv2d(c2, c3, 1),
|
119 |
+
gate=self.gate,
|
120 |
+
norm_layer=nn.BatchNorm2d,
|
121 |
+
)
|
122 |
+
self.block4 = ResBlock(
|
123 |
+
inplanes=c3,
|
124 |
+
planes=c4,
|
125 |
+
stride=1,
|
126 |
+
downsample=nn.Conv2d(c3, c4, 1),
|
127 |
+
gate=self.gate,
|
128 |
+
norm_layer=nn.BatchNorm2d,
|
129 |
+
)
|
130 |
|
131 |
# ================================== feature aggregation
|
132 |
self.conv1 = resnet.conv1x1(c1, dim // 4)
|
133 |
self.conv2 = resnet.conv1x1(c2, dim // 4)
|
134 |
self.conv3 = resnet.conv1x1(c3, dim // 4)
|
135 |
self.conv4 = resnet.conv1x1(dim, dim // 4)
|
136 |
+
self.upsample2 = nn.Upsample(
|
137 |
+
scale_factor=2, mode="bilinear", align_corners=True
|
138 |
+
)
|
139 |
+
self.upsample4 = nn.Upsample(
|
140 |
+
scale_factor=4, mode="bilinear", align_corners=True
|
141 |
+
)
|
142 |
+
self.upsample8 = nn.Upsample(
|
143 |
+
scale_factor=8, mode="bilinear", align_corners=True
|
144 |
+
)
|
145 |
+
self.upsample32 = nn.Upsample(
|
146 |
+
scale_factor=32, mode="bilinear", align_corners=True
|
147 |
+
)
|
148 |
|
149 |
# ================================== detector and descriptor head
|
150 |
self.single_head = single_head
|
|
|
183 |
return scores_map, descriptor_map
|
184 |
|
185 |
|
186 |
+
if __name__ == "__main__":
|
187 |
from thop import profile
|
188 |
|
189 |
net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)
|
190 |
|
191 |
image = torch.randn(1, 3, 640, 480)
|
192 |
flops, params = profile(net, inputs=(image,), verbose=False)
|
193 |
+
print("{:<30} {:<8} GFLops".format("Computational complexity: ", flops / 1e9))
|
194 |
+
print("{:<30} {:<8} KB".format("Number of parameters: ", params / 1e3))
|
third_party/ALIKE/demo.py
CHANGED
@@ -12,13 +12,13 @@ from alike import ALike, configs
|
|
12 |
class ImageLoader(object):
|
13 |
def __init__(self, filepath: str):
|
14 |
self.N = 3000
|
15 |
-
if filepath.startswith(
|
16 |
camera = int(filepath[6:])
|
17 |
self.cap = cv2.VideoCapture(camera)
|
18 |
if not self.cap.isOpened():
|
19 |
raise IOError(f"Can't open camera {camera}!")
|
20 |
-
logging.info(f
|
21 |
-
self.mode =
|
22 |
elif os.path.exists(filepath):
|
23 |
if os.path.isfile(filepath):
|
24 |
self.cap = cv2.VideoCapture(filepath)
|
@@ -27,34 +27,38 @@ class ImageLoader(object):
|
|
27 |
rate = self.cap.get(cv2.CAP_PROP_FPS)
|
28 |
self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
|
29 |
duration = self.N / rate
|
30 |
-
logging.info(f
|
31 |
-
logging.info(f
|
32 |
-
self.mode =
|
33 |
else:
|
34 |
-
self.images =
|
35 |
-
|
36 |
-
|
|
|
|
|
37 |
self.images.sort()
|
38 |
self.N = len(self.images)
|
39 |
-
logging.info(f
|
40 |
-
self.mode =
|
41 |
else:
|
42 |
-
raise IOError(
|
|
|
|
|
43 |
|
44 |
def __getitem__(self, item):
|
45 |
-
if self.mode ==
|
46 |
if item > self.N:
|
47 |
return None
|
48 |
ret, img = self.cap.read()
|
49 |
if not ret:
|
50 |
raise "Can't read image from camera"
|
51 |
-
if self.mode ==
|
52 |
self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
|
53 |
-
elif self.mode ==
|
54 |
filename = self.images[item]
|
55 |
img = cv2.imread(filename)
|
56 |
if img is None:
|
57 |
-
raise Exception(
|
58 |
return img
|
59 |
|
60 |
def __len__(self):
|
@@ -99,38 +103,68 @@ class SimpleTracker(object):
|
|
99 |
nn12 = np.argmax(sim, axis=1)
|
100 |
nn21 = np.argmax(sim, axis=0)
|
101 |
ids1 = np.arange(0, sim.shape[0])
|
102 |
-
mask =
|
103 |
matches = np.stack([ids1[mask], nn12[mask]])
|
104 |
return matches.transpose()
|
105 |
|
106 |
|
107 |
-
if __name__ ==
|
108 |
-
parser = argparse.ArgumentParser(description=
|
109 |
-
parser.add_argument(
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
args = parser.parse_args()
|
125 |
|
126 |
logging.basicConfig(level=logging.INFO)
|
127 |
|
128 |
image_loader = ImageLoader(args.input)
|
129 |
-
model = ALike(
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
134 |
tracker = SimpleTracker()
|
135 |
|
136 |
if not args.no_display:
|
@@ -142,26 +176,26 @@ if __name__ == '__main__':
|
|
142 |
for img in progress_bar:
|
143 |
if img is None:
|
144 |
break
|
145 |
-
|
146 |
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
147 |
pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
|
148 |
-
kpts = pred[
|
149 |
-
desc = pred[
|
150 |
-
runtime.append(pred[
|
151 |
|
152 |
out, N_matches = tracker.update(img, kpts, desc)
|
153 |
|
154 |
-
ave_fps = (1. / np.stack(runtime)).mean()
|
155 |
status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
|
156 |
progress_bar.set_description(status)
|
157 |
|
158 |
if not args.no_display:
|
159 |
-
cv2.setWindowTitle(args.model, args.model +
|
160 |
cv2.imshow(args.model, out)
|
161 |
-
if cv2.waitKey(1) == ord(
|
162 |
break
|
163 |
|
164 |
-
logging.info(
|
165 |
if not args.no_display:
|
166 |
-
logging.info(
|
167 |
cv2.waitKey()
|
|
|
12 |
class ImageLoader(object):
|
13 |
def __init__(self, filepath: str):
|
14 |
self.N = 3000
|
15 |
+
if filepath.startswith("camera"):
|
16 |
camera = int(filepath[6:])
|
17 |
self.cap = cv2.VideoCapture(camera)
|
18 |
if not self.cap.isOpened():
|
19 |
raise IOError(f"Can't open camera {camera}!")
|
20 |
+
logging.info(f"Opened camera {camera}")
|
21 |
+
self.mode = "camera"
|
22 |
elif os.path.exists(filepath):
|
23 |
if os.path.isfile(filepath):
|
24 |
self.cap = cv2.VideoCapture(filepath)
|
|
|
27 |
rate = self.cap.get(cv2.CAP_PROP_FPS)
|
28 |
self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
|
29 |
duration = self.N / rate
|
30 |
+
logging.info(f"Opened video {filepath}")
|
31 |
+
logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
|
32 |
+
self.mode = "video"
|
33 |
else:
|
34 |
+
self.images = (
|
35 |
+
glob.glob(os.path.join(filepath, "*.png"))
|
36 |
+
+ glob.glob(os.path.join(filepath, "*.jpg"))
|
37 |
+
+ glob.glob(os.path.join(filepath, "*.ppm"))
|
38 |
+
)
|
39 |
self.images.sort()
|
40 |
self.N = len(self.images)
|
41 |
+
logging.info(f"Loading {self.N} images")
|
42 |
+
self.mode = "images"
|
43 |
else:
|
44 |
+
raise IOError(
|
45 |
+
"Error filepath (camerax/path of images/path of videos): ", filepath
|
46 |
+
)
|
47 |
|
48 |
def __getitem__(self, item):
|
49 |
+
if self.mode == "camera" or self.mode == "video":
|
50 |
if item > self.N:
|
51 |
return None
|
52 |
ret, img = self.cap.read()
|
53 |
if not ret:
|
54 |
raise "Can't read image from camera"
|
55 |
+
if self.mode == "video":
|
56 |
self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
|
57 |
+
elif self.mode == "images":
|
58 |
filename = self.images[item]
|
59 |
img = cv2.imread(filename)
|
60 |
if img is None:
|
61 |
+
raise Exception("Error reading image %s" % filename)
|
62 |
return img
|
63 |
|
64 |
def __len__(self):
|
|
|
103 |
nn12 = np.argmax(sim, axis=1)
|
104 |
nn21 = np.argmax(sim, axis=0)
|
105 |
ids1 = np.arange(0, sim.shape[0])
|
106 |
+
mask = ids1 == nn21[nn12]
|
107 |
matches = np.stack([ids1[mask], nn12[mask]])
|
108 |
return matches.transpose()
|
109 |
|
110 |
|
111 |
+
if __name__ == "__main__":
|
112 |
+
parser = argparse.ArgumentParser(description="ALike Demo.")
|
113 |
+
parser.add_argument(
|
114 |
+
"input",
|
115 |
+
type=str,
|
116 |
+
default="",
|
117 |
+
help='Image directory or movie file or "camera0" (for webcam0).',
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--model",
|
121 |
+
choices=["alike-t", "alike-s", "alike-n", "alike-l"],
|
122 |
+
default="alike-t",
|
123 |
+
help="The model configuration",
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--device", type=str, default="cuda", help="Running device (default: cuda)."
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--top_k",
|
130 |
+
type=int,
|
131 |
+
default=-1,
|
132 |
+
help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)",
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"--scores_th",
|
136 |
+
type=float,
|
137 |
+
default=0.2,
|
138 |
+
help="Detector score threshold (default: 0.2).",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--n_limit",
|
142 |
+
type=int,
|
143 |
+
default=5000,
|
144 |
+
help="Maximum number of keypoints to be detected (default: 5000).",
|
145 |
+
)
|
146 |
+
parser.add_argument(
|
147 |
+
"--no_display",
|
148 |
+
action="store_true",
|
149 |
+
help="Do not display images to screen. Useful if running remotely (default: False).",
|
150 |
+
)
|
151 |
+
parser.add_argument(
|
152 |
+
"--no_sub_pixel",
|
153 |
+
action="store_true",
|
154 |
+
help="Do not detect sub-pixel keypoints (default: False).",
|
155 |
+
)
|
156 |
args = parser.parse_args()
|
157 |
|
158 |
logging.basicConfig(level=logging.INFO)
|
159 |
|
160 |
image_loader = ImageLoader(args.input)
|
161 |
+
model = ALike(
|
162 |
+
**configs[args.model],
|
163 |
+
device=args.device,
|
164 |
+
top_k=args.top_k,
|
165 |
+
scores_th=args.scores_th,
|
166 |
+
n_limit=args.n_limit,
|
167 |
+
)
|
168 |
tracker = SimpleTracker()
|
169 |
|
170 |
if not args.no_display:
|
|
|
176 |
for img in progress_bar:
|
177 |
if img is None:
|
178 |
break
|
179 |
+
|
180 |
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
181 |
pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
|
182 |
+
kpts = pred["keypoints"]
|
183 |
+
desc = pred["descriptors"]
|
184 |
+
runtime.append(pred["time"])
|
185 |
|
186 |
out, N_matches = tracker.update(img, kpts, desc)
|
187 |
|
188 |
+
ave_fps = (1.0 / np.stack(runtime)).mean()
|
189 |
status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
|
190 |
progress_bar.set_description(status)
|
191 |
|
192 |
if not args.no_display:
|
193 |
+
cv2.setWindowTitle(args.model, args.model + ": " + status)
|
194 |
cv2.imshow(args.model, out)
|
195 |
+
if cv2.waitKey(1) == ord("q"):
|
196 |
break
|
197 |
|
198 |
+
logging.info("Finished!")
|
199 |
if not args.no_display:
|
200 |
+
logging.info("Press any key to exit!")
|
201 |
cv2.waitKey()
|
third_party/ALIKE/hseq/eval.py
CHANGED
@@ -6,29 +6,53 @@ import numpy as np
|
|
6 |
from extract import extract_method
|
7 |
|
8 |
use_cuda = torch.cuda.is_available()
|
9 |
-
device = torch.device(
|
10 |
-
|
11 |
-
methods = [
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
top_k = None
|
17 |
n_i = 52
|
18 |
n_v = 56
|
19 |
-
cache_dir =
|
20 |
-
dataset_path =
|
21 |
|
22 |
|
23 |
-
def generate_read_function(method, extension=
|
24 |
def read_function(seq_name, im_idx):
|
25 |
-
aux = np.load(
|
|
|
|
|
|
|
|
|
26 |
if top_k is None:
|
27 |
-
return aux[
|
28 |
else:
|
29 |
-
assert
|
30 |
-
ids = np.argsort(aux[
|
31 |
-
return aux[
|
32 |
|
33 |
return read_function
|
34 |
|
@@ -39,7 +63,7 @@ def mnn_matcher(descriptors_a, descriptors_b):
|
|
39 |
nn12 = torch.max(sim, dim=1)[1]
|
40 |
nn21 = torch.max(sim, dim=0)[1]
|
41 |
ids1 = torch.arange(0, sim.shape[0], device=device)
|
42 |
-
mask =
|
43 |
matches = torch.stack([ids1[mask], nn12[mask]])
|
44 |
return matches.t().data.cpu().numpy()
|
45 |
|
@@ -73,7 +97,7 @@ def benchmark_features(read_feats):
|
|
73 |
n_feats.append(keypoints_a.shape[0])
|
74 |
|
75 |
# =========== compute homography
|
76 |
-
ref_img = cv2.imread(os.path.join(dataset_path, seq_name,
|
77 |
ref_img_shape = ref_img.shape
|
78 |
|
79 |
for im_idx in range(2, 7):
|
@@ -82,17 +106,19 @@ def benchmark_features(read_feats):
|
|
82 |
|
83 |
matches = mnn_matcher(
|
84 |
torch.from_numpy(descriptors_a).to(device=device),
|
85 |
-
torch.from_numpy(descriptors_b).to(device=device)
|
86 |
)
|
87 |
|
88 |
-
homography = np.loadtxt(
|
|
|
|
|
89 |
|
90 |
-
pos_a = keypoints_a[matches[:, 0], :
|
91 |
pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1)
|
92 |
pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
|
93 |
-
pos_b_proj = pos_b_proj_h[:, :
|
94 |
|
95 |
-
pos_b = keypoints_b[matches[:, 1], :
|
96 |
|
97 |
dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
|
98 |
|
@@ -103,28 +129,37 @@ def benchmark_features(read_feats):
|
|
103 |
dist = np.array([float("inf")])
|
104 |
|
105 |
for thr in rng:
|
106 |
-
if seq_name[0] ==
|
107 |
i_err[thr] += np.mean(dist <= thr)
|
108 |
else:
|
109 |
v_err[thr] += np.mean(dist <= thr)
|
110 |
|
111 |
# =========== compute homography
|
112 |
gt_homo = homography
|
113 |
-
pred_homo, _ = cv2.findHomography(
|
114 |
-
|
|
|
|
|
|
|
115 |
if pred_homo is None:
|
116 |
homo_dist = np.array([float("inf")])
|
117 |
else:
|
118 |
-
corners = np.array(
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
real_warped_corners = homo_trans(corners, gt_homo)
|
123 |
warped_corners = homo_trans(corners, pred_homo)
|
124 |
-
homo_dist = np.mean(
|
|
|
|
|
125 |
|
126 |
for thr in rng:
|
127 |
-
if seq_name[0] ==
|
128 |
i_err_homo[thr] += np.mean(homo_dist <= thr)
|
129 |
else:
|
130 |
v_err_homo[thr] += np.mean(homo_dist <= thr)
|
@@ -136,10 +171,10 @@ def benchmark_features(read_feats):
|
|
136 |
return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
|
137 |
|
138 |
|
139 |
-
if __name__ ==
|
140 |
errors = {}
|
141 |
for method in methods:
|
142 |
-
output_file = os.path.join(cache_dir, method +
|
143 |
read_function = generate_read_function(method)
|
144 |
if os.path.exists(output_file):
|
145 |
errors[method] = np.load(output_file, allow_pickle=True)
|
@@ -152,11 +187,11 @@ if __name__ == '__main__':
|
|
152 |
i_err, v_err, i_err_hom, v_err_hom, _ = errors[method]
|
153 |
|
154 |
print(f"====={name}=====")
|
155 |
-
print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end=
|
156 |
for thr in range(1, 4):
|
157 |
err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5)
|
158 |
-
print(f"{err * 100:.2f}%", end=
|
159 |
for thr in range(1, 4):
|
160 |
err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5)
|
161 |
-
print(f"{err_hom * 100:.2f}%", end=
|
162 |
-
print(
|
|
|
6 |
from extract import extract_method
|
7 |
|
8 |
use_cuda = torch.cuda.is_available()
|
9 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
10 |
+
|
11 |
+
methods = [
|
12 |
+
"d2",
|
13 |
+
"lfnet",
|
14 |
+
"superpoint",
|
15 |
+
"r2d2",
|
16 |
+
"aslfeat",
|
17 |
+
"disk",
|
18 |
+
"alike-n",
|
19 |
+
"alike-l",
|
20 |
+
"alike-n-ms",
|
21 |
+
"alike-l-ms",
|
22 |
+
]
|
23 |
+
names = [
|
24 |
+
"D2-Net(MS)",
|
25 |
+
"LF-Net(MS)",
|
26 |
+
"SuperPoint",
|
27 |
+
"R2D2(MS)",
|
28 |
+
"ASLFeat(MS)",
|
29 |
+
"DISK",
|
30 |
+
"ALike-N",
|
31 |
+
"ALike-L",
|
32 |
+
"ALike-N(MS)",
|
33 |
+
"ALike-L(MS)",
|
34 |
+
]
|
35 |
|
36 |
top_k = None
|
37 |
n_i = 52
|
38 |
n_v = 56
|
39 |
+
cache_dir = "hseq/cache"
|
40 |
+
dataset_path = "hseq/hpatches-sequences-release"
|
41 |
|
42 |
|
43 |
+
def generate_read_function(method, extension="ppm"):
|
44 |
def read_function(seq_name, im_idx):
|
45 |
+
aux = np.load(
|
46 |
+
os.path.join(
|
47 |
+
dataset_path, seq_name, "%d.%s.%s" % (im_idx, extension, method)
|
48 |
+
)
|
49 |
+
)
|
50 |
if top_k is None:
|
51 |
+
return aux["keypoints"], aux["descriptors"]
|
52 |
else:
|
53 |
+
assert "scores" in aux
|
54 |
+
ids = np.argsort(aux["scores"])[-top_k:]
|
55 |
+
return aux["keypoints"][ids, :], aux["descriptors"][ids, :]
|
56 |
|
57 |
return read_function
|
58 |
|
|
|
63 |
nn12 = torch.max(sim, dim=1)[1]
|
64 |
nn21 = torch.max(sim, dim=0)[1]
|
65 |
ids1 = torch.arange(0, sim.shape[0], device=device)
|
66 |
+
mask = ids1 == nn21[nn12]
|
67 |
matches = torch.stack([ids1[mask], nn12[mask]])
|
68 |
return matches.t().data.cpu().numpy()
|
69 |
|
|
|
97 |
n_feats.append(keypoints_a.shape[0])
|
98 |
|
99 |
# =========== compute homography
|
100 |
+
ref_img = cv2.imread(os.path.join(dataset_path, seq_name, "1.ppm"))
|
101 |
ref_img_shape = ref_img.shape
|
102 |
|
103 |
for im_idx in range(2, 7):
|
|
|
106 |
|
107 |
matches = mnn_matcher(
|
108 |
torch.from_numpy(descriptors_a).to(device=device),
|
109 |
+
torch.from_numpy(descriptors_b).to(device=device),
|
110 |
)
|
111 |
|
112 |
+
homography = np.loadtxt(
|
113 |
+
os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx))
|
114 |
+
)
|
115 |
|
116 |
+
pos_a = keypoints_a[matches[:, 0], :2]
|
117 |
pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1)
|
118 |
pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
|
119 |
+
pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
|
120 |
|
121 |
+
pos_b = keypoints_b[matches[:, 1], :2]
|
122 |
|
123 |
dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
|
124 |
|
|
|
129 |
dist = np.array([float("inf")])
|
130 |
|
131 |
for thr in rng:
|
132 |
+
if seq_name[0] == "i":
|
133 |
i_err[thr] += np.mean(dist <= thr)
|
134 |
else:
|
135 |
v_err[thr] += np.mean(dist <= thr)
|
136 |
|
137 |
# =========== compute homography
|
138 |
gt_homo = homography
|
139 |
+
pred_homo, _ = cv2.findHomography(
|
140 |
+
keypoints_a[matches[:, 0], :2],
|
141 |
+
keypoints_b[matches[:, 1], :2],
|
142 |
+
cv2.RANSAC,
|
143 |
+
)
|
144 |
if pred_homo is None:
|
145 |
homo_dist = np.array([float("inf")])
|
146 |
else:
|
147 |
+
corners = np.array(
|
148 |
+
[
|
149 |
+
[0, 0],
|
150 |
+
[ref_img_shape[1] - 1, 0],
|
151 |
+
[0, ref_img_shape[0] - 1],
|
152 |
+
[ref_img_shape[1] - 1, ref_img_shape[0] - 1],
|
153 |
+
]
|
154 |
+
)
|
155 |
real_warped_corners = homo_trans(corners, gt_homo)
|
156 |
warped_corners = homo_trans(corners, pred_homo)
|
157 |
+
homo_dist = np.mean(
|
158 |
+
np.linalg.norm(real_warped_corners - warped_corners, axis=1)
|
159 |
+
)
|
160 |
|
161 |
for thr in rng:
|
162 |
+
if seq_name[0] == "i":
|
163 |
i_err_homo[thr] += np.mean(homo_dist <= thr)
|
164 |
else:
|
165 |
v_err_homo[thr] += np.mean(homo_dist <= thr)
|
|
|
171 |
return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
|
172 |
|
173 |
|
174 |
+
if __name__ == "__main__":
|
175 |
errors = {}
|
176 |
for method in methods:
|
177 |
+
output_file = os.path.join(cache_dir, method + ".npy")
|
178 |
read_function = generate_read_function(method)
|
179 |
if os.path.exists(output_file):
|
180 |
errors[method] = np.load(output_file, allow_pickle=True)
|
|
|
187 |
i_err, v_err, i_err_hom, v_err_hom, _ = errors[method]
|
188 |
|
189 |
print(f"====={name}=====")
|
190 |
+
print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end="")
|
191 |
for thr in range(1, 4):
|
192 |
err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5)
|
193 |
+
print(f"{err * 100:.2f}%", end=" ")
|
194 |
for thr in range(1, 4):
|
195 |
err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5)
|
196 |
+
print(f"{err_hom * 100:.2f}%", end=" ")
|
197 |
+
print("")
|
third_party/ALIKE/hseq/extract.py
CHANGED
@@ -9,23 +9,23 @@ from tqdm import tqdm
|
|
9 |
from copy import deepcopy
|
10 |
from torchvision.transforms import ToTensor
|
11 |
|
12 |
-
sys.path.append(os.path.join(os.path.dirname(__file__),
|
13 |
from alike import ALike, configs
|
14 |
|
15 |
-
dataset_root =
|
16 |
use_cuda = torch.cuda.is_available()
|
17 |
-
device =
|
18 |
-
methods = [
|
19 |
|
20 |
|
21 |
class HPatchesDataset(data.Dataset):
|
22 |
-
def __init__(self, root: str = dataset_root, alteration: str =
|
23 |
"""
|
24 |
Args:
|
25 |
root: dataset root path
|
26 |
alteration: # 'all', 'i' for illumination or 'v' for viewpoint
|
27 |
"""
|
28 |
-
assert
|
29 |
self.root = root
|
30 |
|
31 |
# get all image file name
|
@@ -35,15 +35,15 @@ class HPatchesDataset(data.Dataset):
|
|
35 |
folders = [x for x in Path(self.root).iterdir() if x.is_dir()]
|
36 |
self.seqs = []
|
37 |
for folder in folders:
|
38 |
-
if alteration ==
|
39 |
continue
|
40 |
-
if alteration ==
|
41 |
continue
|
42 |
|
43 |
self.seqs.append(folder)
|
44 |
|
45 |
self.len = len(self.seqs)
|
46 |
-
assert
|
47 |
|
48 |
def __getitem__(self, item):
|
49 |
folder = self.seqs[item]
|
@@ -51,12 +51,12 @@ class HPatchesDataset(data.Dataset):
|
|
51 |
imgs = []
|
52 |
homos = []
|
53 |
for i in range(1, 7):
|
54 |
-
img = cv2.imread(str(folder / f
|
55 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # HxWxC
|
56 |
imgs.append(img)
|
57 |
|
58 |
if i != 1:
|
59 |
-
homo = np.loadtxt(str(folder / f
|
60 |
homos.append(homo)
|
61 |
|
62 |
return imgs, homos, folder.stem
|
@@ -68,11 +68,18 @@ class HPatchesDataset(data.Dataset):
|
|
68 |
return self.__class__
|
69 |
|
70 |
|
71 |
-
def extract_multiscale(
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
H_, W_, three = img.shape
|
77 |
assert three == 3, "input image shape should be [HxWx3]"
|
78 |
|
@@ -100,7 +107,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
|
|
100 |
# extract descriptors
|
101 |
with torch.no_grad():
|
102 |
descriptor_map, scores_map = model.extract_dense_map(image)
|
103 |
-
keypoints_, descriptors_, scores_, _ = model.dkd(
|
|
|
|
|
104 |
|
105 |
keypoints.append(keypoints_[0])
|
106 |
descriptors.append(descriptors_[0])
|
@@ -110,7 +119,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
|
|
110 |
|
111 |
# down-scale the image for next iteration
|
112 |
nh, nw = round(H * s), round(W * s)
|
113 |
-
image = torch.nn.functional.interpolate(
|
|
|
|
|
114 |
|
115 |
# restore value
|
116 |
torch.backends.cudnn.benchmark = old_bm
|
@@ -131,29 +142,34 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
|
|
131 |
descriptors = descriptors[0:n_k]
|
132 |
scores = scores[0:n_k]
|
133 |
|
134 |
-
return {
|
135 |
|
136 |
|
137 |
def extract_method(m):
|
138 |
-
hpatches = HPatchesDataset(root=dataset_root, alteration=
|
139 |
model = m[:7]
|
140 |
-
min_scale = 0.3 if m[8:] ==
|
141 |
|
142 |
model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000)
|
143 |
|
144 |
-
progbar = tqdm(hpatches, desc=
|
145 |
for imgs, homos, seq_name in progbar:
|
146 |
for i in range(1, 7):
|
147 |
img = imgs[i - 1]
|
148 |
-
pred = extract_multiscale(
|
149 |
-
|
|
|
|
|
150 |
|
151 |
-
with open(os.path.join(dataset_root, seq_name, f
|
152 |
-
np.savez(
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
155 |
|
156 |
|
157 |
-
if __name__ ==
|
158 |
for method in methods:
|
159 |
extract_method(method)
|
|
|
9 |
from copy import deepcopy
|
10 |
from torchvision.transforms import ToTensor
|
11 |
|
12 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
13 |
from alike import ALike, configs
|
14 |
|
15 |
+
dataset_root = "hseq/hpatches-sequences-release"
|
16 |
use_cuda = torch.cuda.is_available()
|
17 |
+
device = "cuda" if use_cuda else "cpu"
|
18 |
+
methods = ["alike-n", "alike-l", "alike-n-ms", "alike-l-ms"]
|
19 |
|
20 |
|
21 |
class HPatchesDataset(data.Dataset):
|
22 |
+
def __init__(self, root: str = dataset_root, alteration: str = "all"):
|
23 |
"""
|
24 |
Args:
|
25 |
root: dataset root path
|
26 |
alteration: # 'all', 'i' for illumination or 'v' for viewpoint
|
27 |
"""
|
28 |
+
assert Path(root).exists(), f"Dataset root path {root} dose not exist!"
|
29 |
self.root = root
|
30 |
|
31 |
# get all image file name
|
|
|
35 |
folders = [x for x in Path(self.root).iterdir() if x.is_dir()]
|
36 |
self.seqs = []
|
37 |
for folder in folders:
|
38 |
+
if alteration == "i" and folder.stem[0] != "i":
|
39 |
continue
|
40 |
+
if alteration == "v" and folder.stem[0] != "v":
|
41 |
continue
|
42 |
|
43 |
self.seqs.append(folder)
|
44 |
|
45 |
self.len = len(self.seqs)
|
46 |
+
assert self.len > 0, f"Can not find PatchDataset in path {self.root}"
|
47 |
|
48 |
def __getitem__(self, item):
|
49 |
folder = self.seqs[item]
|
|
|
51 |
imgs = []
|
52 |
homos = []
|
53 |
for i in range(1, 7):
|
54 |
+
img = cv2.imread(str(folder / f"{i}.ppm"), cv2.IMREAD_COLOR)
|
55 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # HxWxC
|
56 |
imgs.append(img)
|
57 |
|
58 |
if i != 1:
|
59 |
+
homo = np.loadtxt(str(folder / f"H_1_{i}")).astype("float32")
|
60 |
homos.append(homo)
|
61 |
|
62 |
return imgs, homos, folder.stem
|
|
|
68 |
return self.__class__
|
69 |
|
70 |
|
71 |
+
def extract_multiscale(
|
72 |
+
model,
|
73 |
+
img,
|
74 |
+
scale_f=2**0.5,
|
75 |
+
min_scale=1.0,
|
76 |
+
max_scale=1.0,
|
77 |
+
min_size=0.0,
|
78 |
+
max_size=99999.0,
|
79 |
+
image_size_max=99999,
|
80 |
+
n_k=0,
|
81 |
+
sort=False,
|
82 |
+
):
|
83 |
H_, W_, three = img.shape
|
84 |
assert three == 3, "input image shape should be [HxWx3]"
|
85 |
|
|
|
107 |
# extract descriptors
|
108 |
with torch.no_grad():
|
109 |
descriptor_map, scores_map = model.extract_dense_map(image)
|
110 |
+
keypoints_, descriptors_, scores_, _ = model.dkd(
|
111 |
+
scores_map, descriptor_map
|
112 |
+
)
|
113 |
|
114 |
keypoints.append(keypoints_[0])
|
115 |
descriptors.append(descriptors_[0])
|
|
|
119 |
|
120 |
# down-scale the image for next iteration
|
121 |
nh, nw = round(H * s), round(W * s)
|
122 |
+
image = torch.nn.functional.interpolate(
|
123 |
+
image, (nh, nw), mode="bilinear", align_corners=False
|
124 |
+
)
|
125 |
|
126 |
# restore value
|
127 |
torch.backends.cudnn.benchmark = old_bm
|
|
|
142 |
descriptors = descriptors[0:n_k]
|
143 |
scores = scores[0:n_k]
|
144 |
|
145 |
+
return {"keypoints": keypoints, "descriptors": descriptors, "scores": scores}
|
146 |
|
147 |
|
148 |
def extract_method(m):
|
149 |
+
hpatches = HPatchesDataset(root=dataset_root, alteration="all")
|
150 |
model = m[:7]
|
151 |
+
min_scale = 0.3 if m[8:] == "ms" else 1.0
|
152 |
|
153 |
model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000)
|
154 |
|
155 |
+
progbar = tqdm(hpatches, desc="Extracting for {}".format(m))
|
156 |
for imgs, homos, seq_name in progbar:
|
157 |
for i in range(1, 7):
|
158 |
img = imgs[i - 1]
|
159 |
+
pred = extract_multiscale(
|
160 |
+
model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000
|
161 |
+
)
|
162 |
+
kpts, descs, scores = pred["keypoints"], pred["descriptors"], pred["scores"]
|
163 |
|
164 |
+
with open(os.path.join(dataset_root, seq_name, f"{i}.ppm.{m}"), "wb") as f:
|
165 |
+
np.savez(
|
166 |
+
f,
|
167 |
+
keypoints=kpts.cpu().numpy(),
|
168 |
+
scores=scores.cpu().numpy(),
|
169 |
+
descriptors=descs.cpu().numpy(),
|
170 |
+
)
|
171 |
|
172 |
|
173 |
+
if __name__ == "__main__":
|
174 |
for method in methods:
|
175 |
extract_method(method)
|
third_party/ALIKE/soft_detect.py
CHANGED
@@ -17,13 +17,15 @@ import torch.nn.functional as F
|
|
17 |
# v
|
18 |
# [ y: range=-1.0~1.0; h: range=0~H ]
|
19 |
|
|
|
20 |
def simple_nms(scores, nms_radius: int):
|
21 |
-
"""
|
22 |
-
assert
|
23 |
|
24 |
def max_pool(x):
|
25 |
return torch.nn.functional.max_pool2d(
|
26 |
-
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
|
|
27 |
|
28 |
zeros = torch.zeros_like(scores)
|
29 |
max_mask = scores == max_pool(scores)
|
@@ -50,8 +52,14 @@ def sample_descriptor(descriptor_map, kpts, bilinear_interp=False):
|
|
50 |
kptsi = kpts[index] # Nx2,(x,y)
|
51 |
|
52 |
if bilinear_interp:
|
53 |
-
descriptors_ = torch.nn.functional.grid_sample(
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
else:
|
56 |
kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]])
|
57 |
kptsi = kptsi.long()
|
@@ -94,10 +102,10 @@ class DKD(nn.Module):
|
|
94 |
nms_scores = simple_nms(scores_nograd, 2)
|
95 |
|
96 |
# remove border
|
97 |
-
nms_scores[:, :, :self.radius + 1, :] = 0
|
98 |
-
nms_scores[:, :, :, :self.radius + 1] = 0
|
99 |
-
nms_scores[:, :, h - self.radius:, :] = 0
|
100 |
-
nms_scores[:, :, :, w - self.radius:] = 0
|
101 |
|
102 |
# detect keypoints without grad
|
103 |
if self.top_k > 0:
|
@@ -121,7 +129,7 @@ class DKD(nn.Module):
|
|
121 |
if len(indices) > self.n_limit:
|
122 |
kpts_sc = scores[indices]
|
123 |
sort_idx = kpts_sc.sort(descending=True)[1]
|
124 |
-
sel_idx = sort_idx[:self.n_limit]
|
125 |
indices = indices[sel_idx]
|
126 |
indices_keypoints.append(indices)
|
127 |
|
@@ -134,42 +142,73 @@ class DKD(nn.Module):
|
|
134 |
self.hw_grid = self.hw_grid.to(patches) # to device
|
135 |
for b_idx in range(b):
|
136 |
patch = patches[b_idx].t() # (H*W) x (kernel**2)
|
137 |
-
indices_kpt = indices_keypoints[
|
|
|
|
|
138 |
patch_scores = patch[indices_kpt] # M x (kernel**2)
|
139 |
|
140 |
# max is detached to prevent undesired backprop loops in the graph
|
141 |
max_v = patch_scores.max(dim=1).values.detach()[:, None]
|
142 |
-
x_exp = (
|
|
|
|
|
143 |
|
144 |
# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
|
145 |
-
xy_residual =
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
|
150 |
|
151 |
# compute result keypoints
|
152 |
-
keypoints_xy_nms = torch.stack(
|
|
|
|
|
153 |
keypoints_xy = keypoints_xy_nms + xy_residual
|
154 |
-
keypoints_xy =
|
155 |
-
[w - 1, h - 1]) * 2 - 1
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
keypoints.append(keypoints_xy)
|
162 |
scoredispersitys.append(scoredispersity)
|
163 |
kptscores.append(kptscore)
|
164 |
else:
|
165 |
for b_idx in range(b):
|
166 |
-
indices_kpt = indices_keypoints[
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
keypoints.append(keypoints_xy)
|
174 |
scoredispersitys.append(None)
|
175 |
kptscores.append(kptscore)
|
@@ -183,8 +222,9 @@ class DKD(nn.Module):
|
|
183 |
:param sub_pixel: whether to use sub-pixel keypoint detection
|
184 |
:return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0
|
185 |
"""
|
186 |
-
keypoints, scoredispersitys, kptscores = self.detect_keypoints(
|
187 |
-
|
|
|
188 |
|
189 |
descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel)
|
190 |
|
|
|
17 |
# v
|
18 |
# [ y: range=-1.0~1.0; h: range=0~H ]
|
19 |
|
20 |
+
|
21 |
def simple_nms(scores, nms_radius: int):
|
22 |
+
"""Fast Non-maximum suppression to remove nearby points"""
|
23 |
+
assert nms_radius >= 0
|
24 |
|
25 |
def max_pool(x):
|
26 |
return torch.nn.functional.max_pool2d(
|
27 |
+
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
28 |
+
)
|
29 |
|
30 |
zeros = torch.zeros_like(scores)
|
31 |
max_mask = scores == max_pool(scores)
|
|
|
52 |
kptsi = kpts[index] # Nx2,(x,y)
|
53 |
|
54 |
if bilinear_interp:
|
55 |
+
descriptors_ = torch.nn.functional.grid_sample(
|
56 |
+
descriptor_map[index].unsqueeze(0),
|
57 |
+
kptsi.view(1, 1, -1, 2),
|
58 |
+
mode="bilinear",
|
59 |
+
align_corners=True,
|
60 |
+
)[
|
61 |
+
0, :, 0, :
|
62 |
+
] # CxN
|
63 |
else:
|
64 |
kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]])
|
65 |
kptsi = kptsi.long()
|
|
|
102 |
nms_scores = simple_nms(scores_nograd, 2)
|
103 |
|
104 |
# remove border
|
105 |
+
nms_scores[:, :, : self.radius + 1, :] = 0
|
106 |
+
nms_scores[:, :, :, : self.radius + 1] = 0
|
107 |
+
nms_scores[:, :, h - self.radius :, :] = 0
|
108 |
+
nms_scores[:, :, :, w - self.radius :] = 0
|
109 |
|
110 |
# detect keypoints without grad
|
111 |
if self.top_k > 0:
|
|
|
129 |
if len(indices) > self.n_limit:
|
130 |
kpts_sc = scores[indices]
|
131 |
sort_idx = kpts_sc.sort(descending=True)[1]
|
132 |
+
sel_idx = sort_idx[: self.n_limit]
|
133 |
indices = indices[sel_idx]
|
134 |
indices_keypoints.append(indices)
|
135 |
|
|
|
142 |
self.hw_grid = self.hw_grid.to(patches) # to device
|
143 |
for b_idx in range(b):
|
144 |
patch = patches[b_idx].t() # (H*W) x (kernel**2)
|
145 |
+
indices_kpt = indices_keypoints[
|
146 |
+
b_idx
|
147 |
+
] # one dimension vector, say its size is M
|
148 |
patch_scores = patch[indices_kpt] # M x (kernel**2)
|
149 |
|
150 |
# max is detached to prevent undesired backprop loops in the graph
|
151 |
max_v = patch_scores.max(dim=1).values.detach()[:, None]
|
152 |
+
x_exp = (
|
153 |
+
(patch_scores - max_v) / self.temperature
|
154 |
+
).exp() # M * (kernel**2), in [0, 1]
|
155 |
|
156 |
# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
|
157 |
+
xy_residual = (
|
158 |
+
x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
|
159 |
+
) # Soft-argmax, Mx2
|
160 |
+
|
161 |
+
hw_grid_dist2 = (
|
162 |
+
torch.norm(
|
163 |
+
(self.hw_grid[None, :, :] - xy_residual[:, None, :])
|
164 |
+
/ self.radius,
|
165 |
+
dim=-1,
|
166 |
+
)
|
167 |
+
** 2
|
168 |
+
)
|
169 |
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
|
170 |
|
171 |
# compute result keypoints
|
172 |
+
keypoints_xy_nms = torch.stack(
|
173 |
+
[indices_kpt % w, indices_kpt // w], dim=1
|
174 |
+
) # Mx2
|
175 |
keypoints_xy = keypoints_xy_nms + xy_residual
|
176 |
+
keypoints_xy = (
|
177 |
+
keypoints_xy / keypoints_xy.new_tensor([w - 1, h - 1]) * 2 - 1
|
178 |
+
) # (w,h) -> (-1~1,-1~1)
|
179 |
+
|
180 |
+
kptscore = torch.nn.functional.grid_sample(
|
181 |
+
scores_map[b_idx].unsqueeze(0),
|
182 |
+
keypoints_xy.view(1, 1, -1, 2),
|
183 |
+
mode="bilinear",
|
184 |
+
align_corners=True,
|
185 |
+
)[
|
186 |
+
0, 0, 0, :
|
187 |
+
] # CxN
|
188 |
|
189 |
keypoints.append(keypoints_xy)
|
190 |
scoredispersitys.append(scoredispersity)
|
191 |
kptscores.append(kptscore)
|
192 |
else:
|
193 |
for b_idx in range(b):
|
194 |
+
indices_kpt = indices_keypoints[
|
195 |
+
b_idx
|
196 |
+
] # one dimension vector, say its size is M
|
197 |
+
keypoints_xy_nms = torch.stack(
|
198 |
+
[indices_kpt % w, indices_kpt // w], dim=1
|
199 |
+
) # Mx2
|
200 |
+
keypoints_xy = (
|
201 |
+
keypoints_xy_nms / keypoints_xy_nms.new_tensor([w - 1, h - 1]) * 2
|
202 |
+
- 1
|
203 |
+
) # (w,h) -> (-1~1,-1~1)
|
204 |
+
kptscore = torch.nn.functional.grid_sample(
|
205 |
+
scores_map[b_idx].unsqueeze(0),
|
206 |
+
keypoints_xy.view(1, 1, -1, 2),
|
207 |
+
mode="bilinear",
|
208 |
+
align_corners=True,
|
209 |
+
)[
|
210 |
+
0, 0, 0, :
|
211 |
+
] # CxN
|
212 |
keypoints.append(keypoints_xy)
|
213 |
scoredispersitys.append(None)
|
214 |
kptscores.append(kptscore)
|
|
|
222 |
:param sub_pixel: whether to use sub-pixel keypoint detection
|
223 |
:return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0
|
224 |
"""
|
225 |
+
keypoints, scoredispersitys, kptscores = self.detect_keypoints(
|
226 |
+
scores_map, sub_pixel
|
227 |
+
)
|
228 |
|
229 |
descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel)
|
230 |
|
third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
|
|
4 |
from src.config.default import _CN as cfg
|
5 |
|
6 |
-
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE =
|
7 |
|
8 |
cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
|
9 |
-
cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20]
|
10 |
-
cfg.ASPAN.COARSE.TRAIN_RES = [480,640]
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
+
sys.path.append(str(Path(__file__).parent / "../../../"))
|
5 |
from src.config.default import _CN as cfg
|
6 |
|
7 |
+
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
|
8 |
|
9 |
cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
|
10 |
+
cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
|
11 |
+
cfg.ASPAN.COARSE.TRAIN_RES = [480, 640]
|
third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
|
|
4 |
from src.config.default import _CN as cfg
|
5 |
|
6 |
-
cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20]
|
7 |
-
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE =
|
8 |
|
9 |
cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
|
10 |
cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
+
sys.path.append(str(Path(__file__).parent / "../../../"))
|
5 |
from src.config.default import _CN as cfg
|
6 |
|
7 |
+
cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
|
8 |
+
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
|
9 |
|
10 |
cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
|
11 |
cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
|
third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
|
|
4 |
from src.config.default import _CN as cfg
|
5 |
|
6 |
-
cfg.ASPAN.COARSE.COARSEST_LEVEL= [36,36]
|
7 |
-
cfg.ASPAN.COARSE.TRAIN_RES = [832,832]
|
8 |
-
cfg.ASPAN.COARSE.TEST_RES = [1152,1152]
|
9 |
-
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE =
|
10 |
|
11 |
cfg.TRAINER.CANONICAL_LR = 8e-3
|
12 |
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
+
sys.path.append(str(Path(__file__).parent / "../../../"))
|
5 |
from src.config.default import _CN as cfg
|
6 |
|
7 |
+
cfg.ASPAN.COARSE.COARSEST_LEVEL = [36, 36]
|
8 |
+
cfg.ASPAN.COARSE.TRAIN_RES = [832, 832]
|
9 |
+
cfg.ASPAN.COARSE.TEST_RES = [1152, 1152]
|
10 |
+
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
|
11 |
|
12 |
cfg.TRAINER.CANONICAL_LR = 8e-3
|
13 |
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
|
third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
|
|
4 |
from src.config.default import _CN as cfg
|
5 |
|
6 |
-
cfg.ASPAN.COARSE.COARSEST_LEVEL= [26,26]
|
7 |
-
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE =
|
8 |
cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
|
9 |
|
10 |
cfg.TRAINER.CANONICAL_LR = 8e-3
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
+
sys.path.append(str(Path(__file__).parent / "../../../"))
|
5 |
from src.config.default import _CN as cfg
|
6 |
|
7 |
+
cfg.ASPAN.COARSE.COARSEST_LEVEL = [26, 26]
|
8 |
+
cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
|
9 |
cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
|
10 |
|
11 |
cfg.TRAINER.CANONICAL_LR = 8e-3
|
third_party/ASpanFormer/configs/data/base.py
CHANGED
@@ -4,6 +4,7 @@ Setups in data configs will override all existed setups!
|
|
4 |
"""
|
5 |
|
6 |
from yacs.config import CfgNode as CN
|
|
|
7 |
_CN = CN()
|
8 |
_CN.DATASET = CN()
|
9 |
_CN.TRAINER = CN()
|
|
|
4 |
"""
|
5 |
|
6 |
from yacs.config import CfgNode as CN
|
7 |
+
|
8 |
_CN = CN()
|
9 |
_CN.DATASET = CN()
|
10 |
_CN.TRAINER = CN()
|
third_party/ASpanFormer/configs/data/megadepth_test_1500.py
CHANGED
@@ -8,6 +8,6 @@ cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
|
|
8 |
cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
|
9 |
|
10 |
cfg.DATASET.MGDPT_IMG_RESIZE = 1152
|
11 |
-
cfg.DATASET.MGDPT_IMG_PAD=True
|
12 |
-
cfg.DATASET.MGDPT_DF =8
|
13 |
-
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
|
|
|
8 |
cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
|
9 |
|
10 |
cfg.DATASET.MGDPT_IMG_RESIZE = 1152
|
11 |
+
cfg.DATASET.MGDPT_IMG_PAD = True
|
12 |
+
cfg.DATASET.MGDPT_DF = 8
|
13 |
+
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
|
third_party/ASpanFormer/configs/data/megadepth_trainval_832.py
CHANGED
@@ -11,9 +11,13 @@ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
|
|
11 |
TEST_BASE_PATH = "data/megadepth/index"
|
12 |
cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
|
13 |
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
|
14 |
-
cfg.DATASET.VAL_NPZ_ROOT =
|
15 |
-
cfg.DATASET.
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# 368 scenes in total for MegaDepth
|
19 |
# (with difficulty balanced (further split each scene to 3 sub-scenes))
|
|
|
11 |
TEST_BASE_PATH = "data/megadepth/index"
|
12 |
cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
|
13 |
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
|
14 |
+
cfg.DATASET.VAL_NPZ_ROOT = (
|
15 |
+
cfg.DATASET.TEST_NPZ_ROOT
|
16 |
+
) = f"{TEST_BASE_PATH}/scene_info_val_1500"
|
17 |
+
cfg.DATASET.VAL_LIST_PATH = (
|
18 |
+
cfg.DATASET.TEST_LIST_PATH
|
19 |
+
) = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
|
20 |
+
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
|
21 |
|
22 |
# 368 scenes in total for MegaDepth
|
23 |
# (with difficulty balanced (further split each scene to 3 sub-scenes))
|
third_party/ASpanFormer/configs/data/scannet_trainval.py
CHANGED
@@ -12,6 +12,10 @@ TEST_BASE_PATH = "assets/scannet_test_1500"
|
|
12 |
cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
|
13 |
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
|
14 |
cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH
|
15 |
-
cfg.DATASET.VAL_LIST_PATH =
|
16 |
-
cfg.DATASET.
|
17 |
-
|
|
|
|
|
|
|
|
|
|
12 |
cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
|
13 |
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
|
14 |
cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH
|
15 |
+
cfg.DATASET.VAL_LIST_PATH = (
|
16 |
+
cfg.DATASET.TEST_LIST_PATH
|
17 |
+
) = f"{TEST_BASE_PATH}/scannet_test.txt"
|
18 |
+
cfg.DATASET.VAL_INTRINSIC_PATH = (
|
19 |
+
cfg.DATASET.TEST_INTRINSIC_PATH
|
20 |
+
) = f"{TEST_BASE_PATH}/intrinsics.npz"
|
21 |
+
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
|
third_party/ASpanFormer/demo/demo.py
CHANGED
@@ -1,63 +1,91 @@
|
|
1 |
import os
|
2 |
import sys
|
|
|
3 |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
4 |
sys.path.insert(0, ROOT_DIR)
|
5 |
|
6 |
-
from src.ASpanFormer.aspanformer import ASpanFormer
|
7 |
from src.config.default import get_cfg_defaults
|
8 |
from src.utils.misc import lower_config
|
9 |
-
import demo_utils
|
10 |
|
11 |
import cv2
|
12 |
import torch
|
13 |
import numpy as np
|
14 |
|
15 |
import argparse
|
|
|
16 |
parser = argparse.ArgumentParser()
|
17 |
-
parser.add_argument(
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
parser.add_argument(
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
args = parser.parse_args()
|
31 |
|
32 |
|
33 |
-
if __name__==
|
34 |
config = get_cfg_defaults()
|
35 |
config.merge_from_file(args.config_path)
|
36 |
_config = lower_config(config)
|
37 |
-
matcher = ASpanFormer(config=_config[
|
38 |
-
state_dict = torch.load(args.weights_path, map_location=
|
39 |
-
matcher.load_state_dict(state_dict,strict=False)
|
40 |
-
matcher.cuda(),matcher.eval()
|
41 |
-
|
42 |
-
img0,img1=cv2.imread(args.img0_path),cv2.imread(args.img1_path)
|
43 |
-
img0_g,img1_g=cv2.imread(args.img0_path,0),cv2.imread(args.img1_path,0)
|
44 |
-
img0,img1=demo_utils.resize(img0,args.long_dim0),demo_utils.resize(
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
if mask_F is not None:
|
54 |
-
|
55 |
else:
|
56 |
-
|
57 |
-
|
58 |
-
#visualize match
|
59 |
-
display=demo_utils.draw_match(img0,img1,corr0,corr1)
|
60 |
-
display_ransac=demo_utils.draw_match(img0,img1,corr0[mask_F],corr1[mask_F])
|
61 |
-
cv2.imwrite(
|
62 |
-
cv2.imwrite(
|
63 |
-
print(len(corr1),len(corr1[mask_F]))
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
|
4 |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
5 |
sys.path.insert(0, ROOT_DIR)
|
6 |
|
7 |
+
from src.ASpanFormer.aspanformer import ASpanFormer
|
8 |
from src.config.default import get_cfg_defaults
|
9 |
from src.utils.misc import lower_config
|
10 |
+
import demo_utils
|
11 |
|
12 |
import cv2
|
13 |
import torch
|
14 |
import numpy as np
|
15 |
|
16 |
import argparse
|
17 |
+
|
18 |
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument(
|
20 |
+
"--config_path",
|
21 |
+
type=str,
|
22 |
+
default="../configs/aspan/outdoor/aspan_test.py",
|
23 |
+
help="path for config file.",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--img0_path",
|
27 |
+
type=str,
|
28 |
+
default="../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg",
|
29 |
+
help="path for image0.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--img1_path",
|
33 |
+
type=str,
|
34 |
+
default="../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg",
|
35 |
+
help="path for image1.",
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--weights_path",
|
39 |
+
type=str,
|
40 |
+
default="../weights/outdoor.ckpt",
|
41 |
+
help="path for model weights.",
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--long_dim0", type=int, default=1024, help="resize for longest dim of image0."
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--long_dim1", type=int, default=1024, help="resize for longest dim of image1."
|
48 |
+
)
|
49 |
|
50 |
args = parser.parse_args()
|
51 |
|
52 |
|
53 |
+
if __name__ == "__main__":
|
54 |
config = get_cfg_defaults()
|
55 |
config.merge_from_file(args.config_path)
|
56 |
_config = lower_config(config)
|
57 |
+
matcher = ASpanFormer(config=_config["aspan"])
|
58 |
+
state_dict = torch.load(args.weights_path, map_location="cpu")["state_dict"]
|
59 |
+
matcher.load_state_dict(state_dict, strict=False)
|
60 |
+
matcher.cuda(), matcher.eval()
|
61 |
+
|
62 |
+
img0, img1 = cv2.imread(args.img0_path), cv2.imread(args.img1_path)
|
63 |
+
img0_g, img1_g = cv2.imread(args.img0_path, 0), cv2.imread(args.img1_path, 0)
|
64 |
+
img0, img1 = demo_utils.resize(img0, args.long_dim0), demo_utils.resize(
|
65 |
+
img1, args.long_dim1
|
66 |
+
)
|
67 |
+
img0_g, img1_g = demo_utils.resize(img0_g, args.long_dim0), demo_utils.resize(
|
68 |
+
img1_g, args.long_dim1
|
69 |
+
)
|
70 |
+
data = {
|
71 |
+
"image0": torch.from_numpy(img0_g / 255.0)[None, None].cuda().float(),
|
72 |
+
"image1": torch.from_numpy(img1_g / 255.0)[None, None].cuda().float(),
|
73 |
+
}
|
74 |
+
with torch.no_grad():
|
75 |
+
matcher(data, online_resize=True)
|
76 |
+
corr0, corr1 = data["mkpts0_f"].cpu().numpy(), data["mkpts1_f"].cpu().numpy()
|
77 |
+
|
78 |
+
F_hat, mask_F = cv2.findFundamentalMat(
|
79 |
+
corr0, corr1, method=cv2.FM_RANSAC, ransacReprojThreshold=1
|
80 |
+
)
|
81 |
if mask_F is not None:
|
82 |
+
mask_F = mask_F[:, 0].astype(bool)
|
83 |
else:
|
84 |
+
mask_F = np.zeros_like(corr0[:, 0]).astype(bool)
|
85 |
+
|
86 |
+
# visualize match
|
87 |
+
display = demo_utils.draw_match(img0, img1, corr0, corr1)
|
88 |
+
display_ransac = demo_utils.draw_match(img0, img1, corr0[mask_F], corr1[mask_F])
|
89 |
+
cv2.imwrite("match.png", display)
|
90 |
+
cv2.imwrite("match_ransac.png", display_ransac)
|
91 |
+
print(len(corr1), len(corr1[mask_F]))
|
third_party/ASpanFormer/demo/demo_utils.py
CHANGED
@@ -1,44 +1,88 @@
|
|
1 |
import cv2
|
2 |
import numpy as np
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
7 |
return image
|
8 |
|
9 |
-
|
|
|
10 |
dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
|
11 |
for i in range(points.shape[0]):
|
12 |
-
cv2.circle(img, dp[i],radius=radius,color=color)
|
13 |
return img
|
14 |
-
|
15 |
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
if resize is not None:
|
18 |
-
scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
assert len(corr1) == len(corr2)
|
25 |
|
26 |
draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
|
27 |
if color is None:
|
28 |
-
color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
|
29 |
-
if len(color)==1:
|
30 |
-
display = cv2.drawMatches(
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
else:
|
36 |
-
height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
|
37 |
-
display=np.zeros([height,width,3],np.uint8)
|
38 |
-
display[:img1.shape[0]
|
39 |
-
display[:img2.shape[0],img1.shape[1]:]=img2
|
40 |
for i in range(len(corr1)):
|
41 |
-
left_x,left_y,right_x,right_y=
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import cv2
|
2 |
import numpy as np
|
3 |
|
4 |
+
|
5 |
+
def resize(image, long_dim):
    """Rescale ``image`` so its longer side is scaled to ``long_dim``.

    Both target dimensions are computed with the same factor
    ``long_dim / max(h, w)`` and truncated with ``int`` before being handed
    to ``cv2.resize`` as ``(width, height)``.

    Args:
        image: array whose first two axes are (height, width).
        long_dim: desired length of the longer side.

    Returns:
        The array returned by ``cv2.resize``.
    """
    height, width = image.shape[0], image.shape[1]
    longest = max(height, width)
    target_w = int(width * long_dim / longest)
    target_h = int(height * long_dim / longest)
    return cv2.resize(image, (target_w, target_h))
|
11 |
|
12 |
+
|
13 |
+
def draw_points(img, points, color=(0, 255, 0), radius=3):
    """Draw one circle per row of ``points`` onto ``img`` and return ``img``.

    Args:
        img: image passed to ``cv2.circle`` (drawn on in place).
        points: 2-D indexable array; column 0 is used as x, column 1 as y.
        color: colour forwarded to ``cv2.circle``.
        radius: circle radius forwarded to ``cv2.circle``.

    Returns:
        The same ``img`` object that was passed in.
    """
    n_points = points.shape[0]
    # Convert every coordinate to int first; cv2.circle needs integer centres.
    centers = [(int(points[row, 0]), int(points[row, 1])) for row in range(n_points)]
    for center in centers:
        cv2.circle(img, center, radius=radius, color=color)
    return img
|
|
|
18 |
|
19 |
+
|
20 |
+
def draw_match(
    img1,
    img2,
    corr1,
    corr2,
    inlier=(True,),
    color=None,
    radius1=1,
    radius2=1,
    resize=None,
):
    """Render correspondences between two images side by side.

    Args:
        img1, img2: images to draw on (numpy arrays).
        corr1, corr2: per-match coordinates; column 0 is x, column 1 is y.
            Both must have the same length.
        inlier: iterable of flags, consulted only when ``color`` is None.
            Truthy entries map to (0, 255, 0), falsy ones to (0, 0, 255).
        color: optional list of colours. A list of length 1 draws every match
            in that colour through ``cv2.drawMatches``; otherwise entry ``i``
            is the colour of the line for match ``i``.
        radius1, radius2: keypoint sizes used for the ``cv2.drawMatches`` path.
        resize: optional ``(width, height)`` passed to ``cv2.resize``; the
            correspondences are rescaled accordingly.

    Returns:
        The image with the matches drawn.

    Raises:
        AssertionError: if ``corr1`` and ``corr2`` differ in length.
    """
    if resize is not None:
        # Per-axis (x, y) shrink factors must be read before the images change size.
        scale1 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]]
        scale2 = [img2.shape[1] / resize[0], img2.shape[0] / resize[1]]
        img1 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA)
        img2 = cv2.resize(img2, resize, interpolation=cv2.INTER_AREA)
        corr1 = corr1 / np.asarray(scale1)[np.newaxis]
        corr2 = corr2 / np.asarray(scale2)[np.newaxis]

    assert len(corr1) == len(corr2)

    if color is None:
        color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier]

    if len(color) == 1:
        # Single colour: let OpenCV do the layout. KeyPoints/DMatches are only
        # needed on this path, so they are built here rather than up front.
        # Coordinates are cast to plain Python floats so the KeyPoint
        # constructor does not depend on the numpy scalar dtype of the input.
        corr1_key = [
            cv2.KeyPoint(float(corr1[i, 0]), float(corr1[i, 1]), radius1)
            for i in range(corr1.shape[0])
        ]
        corr2_key = [
            cv2.KeyPoint(float(corr2[i, 0]), float(corr2[i, 1]), radius2)
            for i in range(corr2.shape[0])
        ]
        draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
        return cv2.drawMatches(
            img1,
            corr1_key,
            img2,
            corr2_key,
            draw_matches,
            None,
            matchColor=color[0],
            singlePointColor=color[0],
            flags=4,  # DRAW_RICH_KEYPOINTS in OpenCV's DrawMatchesFlags
        )

    # Per-match colours: compose the canvas manually and draw one line per match.
    # The canvas is 3-channel, so promote 2-D grayscale inputs first (the
    # drawMatches path above accepts them, this keeps both paths consistent).
    if img1.ndim == 2:
        img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
    if img2.ndim == 2:
        img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)

    height = max(img1.shape[0], img2.shape[0])
    width = img1.shape[1] + img2.shape[1]
    display = np.zeros([height, width, 3], np.uint8)
    display[: img1.shape[0], : img1.shape[1]] = img1
    display[: img2.shape[0], img1.shape[1] :] = img2

    x_shift = img1.shape[1]  # img2 sits to the right of img1 on the canvas
    for i in range(len(corr1)):
        left = (int(corr1[i][0]), int(corr1[i][1]))
        right = (int(corr2[i][0] + x_shift), int(corr2[i][1]))
        cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2]))
        cv2.line(display, left, right, cur_color, 1, lineType=cv2.LINE_AA)
    return display
|
third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
from .transformer import LocalFeatureTransformer_Flow
|
2 |
-
from .loftr import LocalFeatureTransformer
|
3 |
from .fine_preprocess import FinePreprocess
|
|
|
1 |
from .transformer import LocalFeatureTransformer_Flow
|
2 |
+
from .loftr import LocalFeatureTransformer
|
3 |
from .fine_preprocess import FinePreprocess
|
third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py
CHANGED
@@ -4,39 +4,59 @@ import torch.nn as nn
|
|
4 |
from itertools import product
|
5 |
from torch.nn import functional as F
|
6 |
|
|
|
7 |
class layernorm2d(nn.Module):
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
19 |
|
20 |
|
21 |
class HierachicalAttention(Module):
|
22 |
-
def __init__(self,d_model,nhead,nsample,radius_scale,nlevel=3):
|
23 |
super().__init__()
|
24 |
-
self.d_model=d_model
|
25 |
-
self.nhead=nhead
|
26 |
-
self.nsample=nsample
|
27 |
-
self.nlevel=nlevel
|
28 |
-
self.radius_scale=radius_scale
|
29 |
self.merge_head = nn.Sequential(
|
30 |
-
nn.Conv1d(d_model*3, d_model, kernel_size=1,bias=False),
|
31 |
nn.ReLU(True),
|
32 |
-
nn.Conv1d(d_model, d_model, kernel_size=1,bias=False),
|
33 |
)
|
34 |
-
self.fullattention=FullAttention(d_model,nhead)
|
35 |
-
self.temp=nn.parameter.Parameter(torch.tensor(1.),requires_grad=True)
|
36 |
-
sample_offset=torch.tensor(
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
"""
|
41 |
Args:
|
42 |
q,k,v (torch.Tensor): [B, C, L]
|
@@ -45,123 +65,217 @@ class HierachicalAttention(Module):
|
|
45 |
Return:
|
46 |
all_message (torch.Tensor): [B, C, H, W]
|
47 |
"""
|
48 |
-
|
49 |
-
variance=flow[
|
50 |
-
offset=flow[
|
51 |
-
bs=query.shape[0]
|
52 |
-
h0,w0=size_q[0],size_q[1]
|
53 |
-
h1,w1=size_kv[0],size_kv[1]
|
54 |
-
variance=torch.exp(0.5*variance)*self.radius_scale
|
55 |
-
span_scale=torch.clamp((variance*2/self.nsample[1]),min=1)
|
56 |
-
|
57 |
-
sub_sample0,sub_sample1=[ds0,2,1],[ds1,2,1]
|
58 |
-
q_list=[
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
if mask0 is not None:
|
66 |
-
mask0,mask1=mask0.view(bs,1,h0,w0),mask1.view(bs,1,h1,w1)
|
67 |
-
mask0_list=[
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
else:
|
70 |
-
mask0_list=mask1_list=[None,None,None]
|
71 |
-
|
72 |
-
message_list=[]
|
73 |
-
#full attention at coarse scale
|
74 |
-
mask0_flatten=mask0_list[0].view(bs
|
75 |
-
mask1_flatten=mask1_list[0].view(bs
|
76 |
-
message_list.append(
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
return all_message
|
90 |
-
|
91 |
-
def partition_token(self,q,k,v,offset,span_scale,maskv):
|
92 |
-
#q,k,v: B*C*H*W
|
93 |
-
#o: B*H/2*W/2*2
|
94 |
-
#span_scale:B*H*W
|
95 |
-
bs=q.shape[0]
|
96 |
-
h,w=q.shape[2],q.shape[3]
|
97 |
-
hk,wk=k.shape[2],k.shape[3]
|
98 |
-
offset=offset.view(bs
|
99 |
-
span_scale=span_scale.view(bs
|
100 |
-
#B*G*2
|
101 |
-
offset_sample=self.sample_offset[None,None]*span_scale
|
102 |
-
sample_pixel=offset[
|
103 |
-
sample_norm=
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
if maskv is not None:
|
112 |
-
mask_sample=
|
|
|
|
|
|
|
|
|
|
|
113 |
else:
|
114 |
-
mask_sample=None
|
115 |
-
return q,k,v,sample_pixel,mask_sample
|
116 |
-
|
117 |
|
118 |
-
def group_attention(self,query,key,value,temp,mask_sample=None):
|
119 |
-
#q,k,v: B*Head*D*G*N(G*N=H*W for q)
|
120 |
-
bs=query.shape[0]
|
121 |
-
#import pdb;pdb.set_trace()
|
122 |
QK = torch.einsum("bhdgn,bhdgm->bhgnm", query, key)
|
123 |
if mask_sample is not None:
|
124 |
-
num_head,number_n=QK.shape[1],QK.shape[3]
|
125 |
-
QK.masked_fill_(
|
|
|
|
|
|
|
|
|
|
|
126 |
# Compute the attention and the weighted average
|
127 |
-
softmax_temp = temp / query.size(2)
|
128 |
A = torch.softmax(softmax_temp * QK, dim=-1)
|
129 |
-
queried_values =
|
|
|
|
|
|
|
|
|
130 |
return queried_values
|
131 |
|
132 |
-
|
133 |
|
134 |
class FullAttention(Module):
|
135 |
-
def __init__(self,d_model,nhead):
|
136 |
super().__init__()
|
137 |
-
self.d_model=d_model
|
138 |
-
self.nhead=nhead
|
139 |
|
140 |
-
def forward(self, q, k,v
|
141 |
-
"""
|
142 |
Args:
|
143 |
q,k,v: [N, D, L]
|
144 |
mask: [N, L]
|
145 |
Returns:
|
146 |
msg: [N,L]
|
147 |
"""
|
148 |
-
bs=q.shape[0]
|
149 |
-
q,k,v=
|
|
|
|
|
|
|
|
|
150 |
# Compute the unnormalized attention and apply the masks
|
151 |
QK = torch.einsum("nhdl,nhds->nhls", q, k)
|
152 |
if mask0 is not None:
|
153 |
-
QK.masked_fill_(
|
|
|
|
|
154 |
# Compute the attention and the weighted average
|
155 |
-
softmax_temp = temp / q.size(2)
|
156 |
A = torch.softmax(softmax_temp * QK, dim=-1)
|
157 |
-
queried_values =
|
|
|
|
|
|
|
|
|
158 |
return queried_values
|
159 |
-
|
160 |
-
|
161 |
|
162 |
def elu_feature_map(x):
|
163 |
return F.elu(x) + 1
|
164 |
|
|
|
165 |
class LinearAttention(Module):
|
166 |
def __init__(self, eps=1e-6):
|
167 |
super().__init__()
|
@@ -169,7 +283,7 @@ class LinearAttention(Module):
|
|
169 |
self.eps = eps
|
170 |
|
171 |
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
172 |
-
"""
|
173 |
Args:
|
174 |
queries: [N, L, H, D]
|
175 |
keys: [N, S, H, D]
|
@@ -195,4 +309,4 @@ class LinearAttention(Module):
|
|
195 |
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
196 |
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
197 |
|
198 |
-
return queried_values.contiguous()
|
|
|
4 |
from itertools import product
|
5 |
from torch.nn import functional as F
|
6 |
|
7 |
+
|
8 |
class layernorm2d(nn.Module):
    """Channel-wise normalization for B*C*H*W feature maps.

    Each spatial location is normalized over the channel axis (dim=1) using
    ``x.mean`` and ``x.std`` with their default settings, then scaled by a
    learnable per-channel ``affine`` and shifted by a learnable ``bias``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Learnable per-channel gain (initialised to 1) and offset (initialised to 0).
        self.affine = nn.Parameter(torch.ones(dim), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)

    def forward(self, x):
        """Normalize ``x`` (B*C*H*W) over its channel axis."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_std = x.std(dim=1, keepdim=True)
        gain = self.affine.view(1, -1, 1, 1)
        shift = self.bias.view(1, -1, 1, 1)
        centered = x - channel_mean
        # 1e-6 keeps the division finite when a location has zero spread.
        return gain * centered / (channel_std + 1e-6) + shift
|
22 |
|
23 |
|
24 |
class HierachicalAttention(Module):
|
25 |
+
def __init__(self, d_model, nhead, nsample, radius_scale, nlevel=3):
    """Store hyper-parameters and build the attention sub-modules.

    Args:
        d_model: channel width of the features.
        nhead: number of attention heads.
        nsample: indexable pair; ``nsample[1]`` is the side length of the
            square sampling pattern built here.
        radius_scale: stored for use by the other methods.
        nlevel: number of levels, stored as ``self.nlevel``.
    """
    super().__init__()
    self.d_model, self.nhead = d_model, nhead
    self.nsample, self.nlevel = nsample, nlevel
    self.radius_scale = radius_scale

    # NOTE(review): the input width is fixed at 3 * d_model rather than
    # nlevel * d_model - confirm nlevel is always 3 in practice.
    self.merge_head = nn.Sequential(
        nn.Conv1d(d_model * 3, d_model, kernel_size=1, bias=False),
        nn.ReLU(True),
        nn.Conv1d(d_model, d_model, kernel_size=1, bias=False),
    )
    self.fullattention = FullAttention(d_model, nhead)
    self.temp = nn.parameter.Parameter(torch.tensor(1.0), requires_grad=True)

    # Regular side*side grid of offsets centred on zero (r^2*2), kept as a
    # non-trainable parameter.
    side = nsample[1]
    grid_points = [
        [row - side / 2 + 0.5, col - side / 2 + 0.5]
        for row, col in product(range(side), range(side))
    ]
    self.sample_offset = nn.parameter.Parameter(
        torch.tensor(grid_points), requires_grad=False
    )
|
46 |
|
47 |
+
def forward(
    self,
    query,
    key,
    value,
    flow,
    size_q,
    size_kv,
    mask0=None,
    mask1=None,
    ds0=[4, 4],
    ds1=[4, 4],
):
    """Combine coarse full attention with flow-guided local attention.

    Args:
        query, key, value (torch.Tensor): features viewable as [B, C, h, w],
            with (h, w) = ``size_q`` for query and ``size_kv`` for key/value.
        flow (torch.Tensor): 4-D tensor; ``flow[..., :2]`` is used as the
            offset and ``flow[..., 2:]`` is passed through ``exp(0.5 * x)``
            to obtain the sampling span.
        size_q, size_kv: (height, width) of the query and key/value maps.
        mask0, mask1: optional masks viewable as [B, 1, h, w]; both are
            expected when ``mask0`` is given.
        ds0, ds1: pooling factors of the coarsest level for query and
            key/value. The list defaults are only read, never mutated.

    Return:
        all_message (torch.Tensor): [B, C, H, W]
    """
    bs = query.shape[0]
    h0, w0 = size_q[0], size_q[1]
    h1, w1 = size_kv[0], size_kv[1]

    offset = flow[:, :, :, :2]  # B*H*W*2
    variance = torch.exp(0.5 * flow[:, :, :, 2:]) * self.radius_scale  # b*h*w*2
    span_scale = torch.clamp((variance * 2 / self.nsample[1]), min=1)  # b*h*w*2

    sub_sample0, sub_sample1 = [ds0, 2, 1], [ds1, 2, 1]

    def pool_features(feat, height, width, factors):
        # One average-pooled copy of the feature map per level.
        return [
            F.avg_pool2d(
                feat.view(bs, -1, height, width), kernel_size=sub_size, stride=sub_size
            )
            for sub_size in factors
        ]

    def pool_guidance(field, sub_size):
        # Pool a B*H*W*2 field down to one entry per query group.
        window = sub_size * self.nsample[0]
        return F.avg_pool2d(
            field.permute(0, 3, 1, 2), kernel_size=window, stride=window
        ).permute(0, 2, 3, 1)

    q_list = pool_features(query, h0, w0, sub_sample0)
    k_list = pool_features(key, h1, w1, sub_sample1)
    v_list = pool_features(value, h1, w1, sub_sample1)  # n_level

    # Offsets are additionally divided by the level's pooling factor.
    offset_list = [
        pool_guidance(offset, sub_size) / sub_size for sub_size in sub_sample0[1:]
    ]  # n_level-1
    span_list = [
        pool_guidance(span_scale, sub_size) for sub_size in sub_sample0[1:]
    ]  # n_level-1

    if mask0 is not None:
        mask0, mask1 = mask0.view(bs, 1, h0, w0), mask1.view(bs, 1, h1, w1)
        # -max_pool(-m) is a min-pool over each window.
        mask0_list = [
            -F.max_pool2d(-mask0, kernel_size=sub_size, stride=sub_size)
            for sub_size in sub_sample0
        ]
        mask1_list = [
            -F.max_pool2d(-mask1, kernel_size=sub_size, stride=sub_size)
            for sub_size in sub_sample1
        ]
    else:
        mask0_list = mask1_list = [None, None, None]

    # full attention at coarse scale
    mask0_flatten = mask0_list[0].view(bs, -1) if mask0 is not None else None
    mask1_flatten = mask1_list[0].view(bs, -1) if mask1 is not None else None
    message_list = [
        self.fullattention(
            q_list[0], k_list[0], v_list[0], mask0_flatten, mask1_flatten, self.temp
        ).view(bs, self.d_model, h0 // ds0[0], w0 // ds0[1])
    ]

    # flow-guided local attention at the finer levels
    for index in range(1, self.nlevel):
        q, k, v = q_list[index], k_list[index], v_list[index]
        level_mask0 = mask0_list[index]
        s, o = span_list[index - 1], offset_list[index - 1]  # B*h*w(*2)
        q, k, v, sample_pixel, mask_sample = self.partition_token(
            q, k, v, o, s, level_mask0
        )  # B*Head*D*G*N(G*N=H*W for q)
        message_list.append(
            self.group_attention(q, k, v, 1, mask_sample).view(
                bs, self.d_model, h0 // sub_sample0[index], w0 // sub_sample0[index]
            )
        )

    # fuse: bring every level back to full resolution and concatenate channels.
    # F.interpolate replaces the deprecated F.upsample (same arguments, same result).
    all_message = torch.cat(
        [
            F.interpolate(
                message_list[idx], scale_factor=sub_sample0[idx], mode="nearest"
            )
            for idx in range(self.nlevel)
        ],
        dim=1,
    ).view(bs, -1, h0 * w0)  # b*3d*(H*W)

    all_message = self.merge_head(all_message).view(bs, -1, h0, w0)  # b*d*H*W
    return all_message
|
165 |
+
|
166 |
+
def partition_token(self, q, k, v, offset, span_scale, maskv):
    """Group the query tokens and sample key/value tokens around each offset.

    Args:
        q, k, v: B*C*H*W feature maps.
        offset: per-group sampling centre, viewable as B*G*2.
        span_scale: per-group scale of the sampling pattern, viewable as B*G*1*2.
        maskv: optional mask viewable as B*1*h*w (h, w taken from ``q``), or None.

    Returns:
        Tuple ``(q, k, v, sample_pixel, mask_sample)`` where q is
        B*head*D*G*N, k and v are B*head*D*G*r^2, ``sample_pixel`` is
        B*G*r^2*2 and ``mask_sample`` is a boolean B*1*G*r^2 tensor or None.
    """
    bs = q.shape[0]
    h, w = q.shape[2], q.shape[3]
    hk, wk = k.shape[2], k.shape[3]
    offset = offset.view(bs, -1, 2)
    span_scale = span_scale.view(bs, -1, 1, 2)

    # B*G*r^2*2: the fixed sampling pattern stretched per group, then shifted
    # to the group's offset.
    offset_sample = self.sample_offset[None, None] * span_scale
    sample_pixel = offset[:, :, None] + offset_sample

    # Normalise for grid_sample. The divisor is created on the same device as
    # the coordinates instead of a hard-coded .cuda(), so CPU and non-default
    # GPU execution work too.
    half_extent = torch.tensor([wk / 2, hk / 2], device=sample_pixel.device)
    sample_norm = sample_pixel / half_extent[None, None, None] - 1

    group = self.nsample[0]
    head_dim = self.d_model // self.nhead
    q = (
        q.view(bs, -1, h // group, group, w // group, group)
        .permute(0, 1, 2, 4, 3, 5)
        .contiguous()
        .view(bs, self.nhead, head_dim, -1, group**2)
    )  # B*head*D*G*N(G*N=H*W for q)

    # sample token; align_corners=False is grid_sample's default, spelled out
    # to make the convention explicit and avoid the implicit-default warning.
    k = F.grid_sample(k, grid=sample_norm, align_corners=False).view(
        bs, self.nhead, head_dim, -1, self.nsample[1] ** 2
    )  # B*head*D*G*r^2
    v = F.grid_sample(v, grid=sample_norm, align_corners=False).view(
        bs, self.nhead, head_dim, -1, self.nsample[1] ** 2
    )  # B*head*D*G*r^2

    if maskv is not None:
        # NOTE(review): maskv is reshaped with the query size (h, w) but sampled
        # with key-normalised coordinates - confirm which mask is intended.
        mask_sample = (
            F.grid_sample(
                maskv.view(bs, -1, h, w).float(),
                grid=sample_norm,
                mode="nearest",
                align_corners=False,
            )
            == 1
        )  # B*1*G*r^2
    else:
        mask_sample = None
    return q, k, v, sample_pixel, mask_sample
|
|
|
213 |
|
214 |
+
def group_attention(self, query, key, value, temp, mask_sample=None):
    """Attend within each group: every query token sees only its group's samples.

    Args:
        query: B*Head*D*G*N tensor.
        key, value: B*Head*D*G*M tensors.
        temp: multiplier applied to the 1/sqrt(D) softmax scale.
        mask_sample: optional B*1*G*M tensor; positions that are False are
            filled with -1e8 before the softmax.

    Returns:
        Tensor viewed as [B, d_model, G*N].
    """
    batch = query.shape[0]
    scores = torch.einsum("bhdgn,bhdgm->bhgnm", query, key)

    if mask_sample is not None:
        heads, n_query = scores.shape[1], scores.shape[3]
        # Broadcast the per-sample mask over heads and over the query tokens.
        keep = mask_sample[:, :, :, None].expand(-1, heads, -1, n_query, -1).bool()
        scores.masked_fill_(~keep, float(-1e8))

    # Compute the attention and the weighted average
    scale = temp / query.size(2) ** 0.5  # sqrt(D)
    weights = torch.softmax(scale * scores, dim=-1)
    gathered = torch.einsum("bhgnm,bhdgm->bhdgn", weights, value)
    return gathered.contiguous().view(batch, self.d_model, -1)
|
236 |
|
|
|
237 |
|
238 |
class FullAttention(Module):
    """Multi-head scaled dot-product attention over channel-first tensors."""

    def __init__(self, d_model, nhead):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead

    def forward(self, q, k, v, mask0=None, mask1=None, temp=1):
        """Multi-head scaled dot-product attention, a.k.a full attention.
        Args:
            q,k,v: tensors viewable as [N, nhead, d_model // nhead, L]
            mask0: optional [N, L] mask for the query positions
            mask1: [N, S] mask for the key positions (used when mask0 is given)
            temp: multiplier applied to the 1/sqrt(D) softmax scale
        Returns:
            msg: [N, d_model, L]
        """
        batch = q.shape[0]
        head_dim = self.d_model // self.nhead
        q_heads = q.view(batch, self.nhead, head_dim, -1)
        k_heads = k.view(batch, self.nhead, head_dim, -1)
        v_heads = v.view(batch, self.nhead, head_dim, -1)

        # Compute the unnormalized attention and apply the masks
        scores = torch.einsum("nhdl,nhds->nhls", q_heads, k_heads)
        if mask0 is not None:
            # A (query, key) pair is kept only when both positions are valid.
            valid_pairs = (mask0[:, None, :, None] * mask1[:, None, None]).bool()
            scores.masked_fill_(~valid_pairs, float(-1e8))

        # Compute the attention and the weighted average
        scale = temp / q_heads.size(2) ** 0.5  # sqrt(D)
        weights = torch.softmax(scale * scores, dim=-1)
        merged = torch.einsum("nhls,nhds->nhdl", weights, v_heads)
        return merged.contiguous().view(batch, self.d_model, -1)
|
273 |
+
|
|
|
274 |
|
275 |
def elu_feature_map(x):
    """Return ``elu(x) + 1`` element-wise."""
    activated = F.elu(x)
    return activated + 1
|
277 |
|
278 |
+
|
279 |
class LinearAttention(Module):
|
280 |
def __init__(self, eps=1e-6):
|
281 |
super().__init__()
|
|
|
283 |
self.eps = eps
|
284 |
|
285 |
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
286 |
+
"""Multi-Head linear attention proposed in "Transformers are RNNs"
|
287 |
Args:
|
288 |
queries: [N, L, H, D]
|
289 |
keys: [N, S, H, D]
|
|
|
309 |
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
310 |
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
311 |
|
312 |
+
return queried_values.contiguous()
|
third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py
CHANGED
@@ -9,15 +9,15 @@ class FinePreprocess(nn.Module):
|
|
9 |
super().__init__()
|
10 |
|
11 |
self.config = config
|
12 |
-
self.cat_c_feat = config[
|
13 |
-
self.W = self.config[
|
14 |
|
15 |
-
d_model_c = self.config[
|
16 |
-
d_model_f = self.config[
|
17 |
self.d_model_f = d_model_f
|
18 |
if self.cat_c_feat:
|
19 |
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
|
20 |
-
self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
|
21 |
|
22 |
self._reset_parameters()
|
23 |
|
@@ -28,32 +28,48 @@ class FinePreprocess(nn.Module):
|
|
28 |
|
29 |
def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
|
30 |
W = self.W
|
31 |
-
stride = data[
|
32 |
|
33 |
-
data.update({
|
34 |
-
if data[
|
35 |
feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
|
36 |
feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
|
37 |
return feat0, feat1
|
38 |
|
39 |
# 1. unfold(crop) all local windows
|
40 |
-
feat_f0_unfold = F.unfold(
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
44 |
|
45 |
# 2. select only the predicted matches
|
46 |
-
feat_f0_unfold = feat_f0_unfold[data[
|
47 |
-
feat_f1_unfold = feat_f1_unfold[data[
|
48 |
|
49 |
# option: use coarse-level loftr feature as context: concat and linear
|
50 |
if self.cat_c_feat:
|
51 |
-
feat_c_win = self.down_proj(
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
|
58 |
|
59 |
return feat_f0_unfold, feat_f1_unfold
|
|
|
9 |
super().__init__()
|
10 |
|
11 |
self.config = config
|
12 |
+
self.cat_c_feat = config["fine_concat_coarse_feat"]
|
13 |
+
self.W = self.config["fine_window_size"]
|
14 |
|
15 |
+
d_model_c = self.config["coarse"]["d_model"]
|
16 |
+
d_model_f = self.config["fine"]["d_model"]
|
17 |
self.d_model_f = d_model_f
|
18 |
if self.cat_c_feat:
|
19 |
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
|
20 |
+
self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)
|
21 |
|
22 |
self._reset_parameters()
|
23 |
|
|
|
28 |
|
29 |
def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
    """Crop a W*W fine-level window around every coarse match.

    Args:
        feat_f0, feat_f1: fine-level feature maps fed to ``F.unfold``.
        feat_c0, feat_c1: coarse-level features, indexed by
            ``(b_ids, i_ids)`` / ``(b_ids, j_ids)``; only read when
            ``self.cat_c_feat`` is set.
        data: dict providing ``hw0_f``, ``hw0_c``, ``b_ids``, ``i_ids`` and
            ``j_ids``. The key ``"W"`` is written into it.

    Returns:
        Two tensors of window features, one per image. When there are no
        matches both are empty with shape [0, W**2, d_model_f].
    """
    W = self.W
    # Ratio between fine and coarse resolution; used as the unfold stride.
    stride = data["hw0_f"][0] // data["hw0_c"][0]
    data.update({"W": W})

    if data["b_ids"].shape[0] == 0:
        empty_shape = (0, self.W**2, self.d_model_f)
        return (
            torch.empty(*empty_shape, device=feat_f0.device),
            torch.empty(*empty_shape, device=feat_f0.device),
        )

    def crop_windows(fine_map):
        # 1. unfold(crop) all local windows
        windows = F.unfold(
            fine_map, kernel_size=(W, W), stride=stride, padding=W // 2
        )
        return rearrange(windows, "n (c ww) l -> n l ww c", ww=W**2)

    # 2. select only the predicted matches
    batch_ids = data["b_ids"]
    windows0 = crop_windows(feat_f0)[batch_ids, data["i_ids"]]  # [n, ww, cf]
    windows1 = crop_windows(feat_f1)[batch_ids, data["j_ids"]]

    if not self.cat_c_feat:
        return windows0, windows1

    # option: use coarse-level loftr feature as context: concat and linear
    coarse_ctx = self.down_proj(
        torch.cat(
            [
                feat_c0[batch_ids, data["i_ids"]],
                feat_c1[batch_ids, data["j_ids"]],
            ],
            0,
        )
    )  # [2n, c]
    fine_stack = torch.cat([windows0, windows1], 0)  # [2n, ww, cf]
    coarse_stack = repeat(coarse_ctx, "n c -> n ww c", ww=W**2)  # [2n, ww, cf]
    fused = self.merge_feat(torch.cat([fine_stack, coarse_stack], -1))
    windows0, windows1 = torch.chunk(fused, 2, dim=0)
    return windows0, windows1
|
third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py
CHANGED
@@ -3,11 +3,9 @@ import torch
|
|
3 |
import torch.nn as nn
|
4 |
from .attention import LinearAttention
|
5 |
|
|
|
6 |
class LoFTREncoderLayer(nn.Module):
|
7 |
-
def __init__(self,
|
8 |
-
d_model,
|
9 |
-
nhead,
|
10 |
-
attention='linear'):
|
11 |
super(LoFTREncoderLayer, self).__init__()
|
12 |
|
13 |
self.dim = d_model // nhead
|
@@ -22,9 +20,9 @@ class LoFTREncoderLayer(nn.Module):
|
|
22 |
|
23 |
# feed-forward network
|
24 |
self.mlp = nn.Sequential(
|
25 |
-
nn.Linear(d_model*2, d_model*2, bias=False),
|
26 |
nn.ReLU(True),
|
27 |
-
nn.Linear(d_model*2, d_model, bias=False),
|
28 |
)
|
29 |
|
30 |
# norm and dropout
|
@@ -43,16 +41,14 @@ class LoFTREncoderLayer(nn.Module):
|
|
43 |
query, key, value = x, source, source
|
44 |
|
45 |
# multi-head attention
|
46 |
-
query = self.q_proj(query).view(
|
47 |
-
|
48 |
-
key = self.k_proj(key).view(bs, -1, self.nhead,
|
49 |
-
self.dim) # [N, S, (H, D)]
|
50 |
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
|
51 |
|
52 |
message = self.attention(
|
53 |
-
query, key, value, q_mask=x_mask, kv_mask=source_mask
|
54 |
-
|
55 |
-
|
56 |
message = self.norm1(message)
|
57 |
|
58 |
# feed-forward network
|
@@ -69,13 +65,15 @@ class LocalFeatureTransformer(nn.Module):
|
|
69 |
super(LocalFeatureTransformer, self).__init__()
|
70 |
|
71 |
self.config = config
|
72 |
-
self.d_model = config[
|
73 |
-
self.nhead = config[
|
74 |
-
self.layer_names = config[
|
75 |
encoder_layer = LoFTREncoderLayer(
|
76 |
-
config[
|
|
|
77 |
self.layers = nn.ModuleList(
|
78 |
-
[copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
|
|
|
79 |
self._reset_parameters()
|
80 |
|
81 |
def _reset_parameters(self):
|
@@ -93,20 +91,18 @@ class LocalFeatureTransformer(nn.Module):
|
|
93 |
"""
|
94 |
|
95 |
assert self.d_model == feat0.size(
|
96 |
-
2
|
|
|
97 |
|
98 |
index = 0
|
99 |
for layer, name in zip(self.layers, self.layer_names):
|
100 |
-
if name ==
|
101 |
-
feat0 = layer(feat0, feat0, mask0, mask0,
|
102 |
-
type='self', index=index)
|
103 |
feat1 = layer(feat1, feat1, mask1, mask1)
|
104 |
-
elif name ==
|
105 |
feat0 = layer(feat0, feat1, mask0, mask1)
|
106 |
-
feat1 = layer(feat1, feat0, mask1, mask0,
|
107 |
-
type='cross', index=index)
|
108 |
index += 1
|
109 |
else:
|
110 |
raise KeyError
|
111 |
return feat0, feat1
|
112 |
-
|
|
|
3 |
import torch.nn as nn
|
4 |
from .attention import LinearAttention
|
5 |
|
6 |
+
|
7 |
class LoFTREncoderLayer(nn.Module):
|
8 |
+
def __init__(self, d_model, nhead, attention="linear"):
|
|
|
|
|
|
|
9 |
super(LoFTREncoderLayer, self).__init__()
|
10 |
|
11 |
self.dim = d_model // nhead
|
|
|
20 |
|
21 |
# feed-forward network
|
22 |
self.mlp = nn.Sequential(
|
23 |
+
nn.Linear(d_model * 2, d_model * 2, bias=False),
|
24 |
nn.ReLU(True),
|
25 |
+
nn.Linear(d_model * 2, d_model, bias=False),
|
26 |
)
|
27 |
|
28 |
# norm and dropout
|
|
|
41 |
query, key, value = x, source, source
|
42 |
|
43 |
# multi-head attention
|
44 |
+
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
|
45 |
+
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
|
|
|
|
|
46 |
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
|
47 |
|
48 |
message = self.attention(
|
49 |
+
query, key, value, q_mask=x_mask, kv_mask=source_mask
|
50 |
+
) # [N, L, (H, D)]
|
51 |
+
message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C]
|
52 |
message = self.norm1(message)
|
53 |
|
54 |
# feed-forward network
|
|
|
65 |
super(LocalFeatureTransformer, self).__init__()
|
66 |
|
67 |
self.config = config
|
68 |
+
self.d_model = config["d_model"]
|
69 |
+
self.nhead = config["nhead"]
|
70 |
+
self.layer_names = config["layer_names"]
|
71 |
encoder_layer = LoFTREncoderLayer(
|
72 |
+
config["d_model"], config["nhead"], config["attention"]
|
73 |
+
)
|
74 |
self.layers = nn.ModuleList(
|
75 |
+
[copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
|
76 |
+
)
|
77 |
self._reset_parameters()
|
78 |
|
79 |
def _reset_parameters(self):
|
|
|
91 |
"""
|
92 |
|
93 |
assert self.d_model == feat0.size(
|
94 |
+
2
|
95 |
+
), "the feature number of src and transformer must be equal"
|
96 |
|
97 |
index = 0
|
98 |
for layer, name in zip(self.layers, self.layer_names):
|
99 |
+
if name == "self":
|
100 |
+
feat0 = layer(feat0, feat0, mask0, mask0, type="self", index=index)
|
|
|
101 |
feat1 = layer(feat1, feat1, mask1, mask1)
|
102 |
+
elif name == "cross":
|
103 |
feat0 = layer(feat0, feat1, mask0, mask1)
|
104 |
+
feat1 = layer(feat1, feat0, mask1, mask0, type="cross", index=index)
|
|
|
105 |
index += 1
|
106 |
else:
|
107 |
raise KeyError
|
108 |
return feat0, feat1
|
|
third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py
CHANGED
@@ -2,44 +2,42 @@ import copy
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
5 |
-
from .attention import FullAttention, HierachicalAttention
|
6 |
|
7 |
|
8 |
class messageLayer_ini(nn.Module):
|
9 |
-
|
10 |
-
def __init__(self, d_model, d_flow,d_value, nhead):
|
11 |
super().__init__()
|
12 |
super(messageLayer_ini, self).__init__()
|
13 |
|
14 |
self.d_model = d_model
|
15 |
self.d_flow = d_flow
|
16 |
-
self.d_value=d_value
|
17 |
self.nhead = nhead
|
18 |
-
self.attention = FullAttention(d_model,nhead)
|
19 |
|
20 |
-
self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
|
21 |
-
self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
|
22 |
-
self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
|
23 |
-
self.merge_head=nn.Conv1d(d_model,d_model,kernel_size=1,bias=False)
|
24 |
|
25 |
-
self.merge_f= self.merge_f = nn.Sequential(
|
26 |
-
nn.Conv2d(d_model*2, d_model*2, kernel_size=1, bias=False),
|
27 |
nn.ReLU(True),
|
28 |
-
nn.Conv2d(d_model*2, d_model, kernel_size=1, bias=False),
|
29 |
)
|
30 |
|
31 |
self.norm1 = layernorm2d(d_model)
|
32 |
self.norm2 = layernorm2d(d_model)
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
def
|
36 |
-
#x1,x2: b*d*L
|
37 |
-
x0,x1=self.update(x0,x1,pos1,mask0,mask1),\
|
38 |
-
self.update(x1,x0,pos0,mask1,mask0)
|
39 |
-
return x0,x1
|
40 |
-
|
41 |
-
|
42 |
-
def update(self,f0,f1,pos1,mask0,mask1):
|
43 |
"""
|
44 |
Args:
|
45 |
f0: [N, D, H, W]
|
@@ -47,53 +45,77 @@ class messageLayer_ini(nn.Module):
|
|
47 |
Returns:
|
48 |
f0_new: (N, d, h, w)
|
49 |
"""
|
50 |
-
bs,h,w=f0.shape[0],f0.shape[2],f0.shape[3]
|
51 |
|
52 |
-
f0_flatten,f1_flatten=f0.view(bs,self.d_model
|
53 |
-
|
54 |
-
|
|
|
|
|
55 |
|
56 |
-
queries,keys=self.q_proj(f0_flatten),self.k_proj(f1_flatten)
|
57 |
-
values=self.v_proj(f1_flatten_v).view(
|
58 |
-
|
59 |
-
|
60 |
-
msg=self.merge_head(queried_values).view(bs,-1,h,w)
|
61 |
-
msg=self.norm2(self.merge_f(torch.cat([f0,self.norm1(msg)],dim=1)))
|
62 |
-
return f0+msg
|
63 |
|
|
|
|
|
|
|
|
|
64 |
|
65 |
|
66 |
class messageLayer_gla(nn.Module):
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
super().__init__()
|
71 |
self.d_model = d_model
|
72 |
-
self.d_flow=d_flow
|
73 |
-
self.d_value=d_value
|
74 |
self.nhead = nhead
|
75 |
-
self.radius_scale=radius_scale
|
76 |
-
self.update_flow=update_flow
|
77 |
-
self.flow_decoder=nn.Sequential(
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
|
84 |
-
self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
|
85 |
-
self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
|
86 |
-
|
87 |
-
d_extra=d_flow if update_flow else 0
|
88 |
-
self.merge_f=nn.Sequential(
|
89 |
-
nn.Conv2d(d_model*2+d_extra, d_model+d_flow, kernel_size=1, bias=False),
|
90 |
-
nn.ReLU(True),
|
91 |
-
nn.Conv2d(d_model+d_flow, d_model+d_extra, kernel_size=3,padding=1, bias=False),
|
92 |
-
)
|
93 |
-
self.norm1 = layernorm2d(d_model)
|
94 |
-
self.norm2 = layernorm2d(d_model+d_extra)
|
95 |
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
"""
|
98 |
Args:
|
99 |
x0 (torch.Tensor): [B, C, H, W]
|
@@ -101,88 +123,135 @@ class messageLayer_gla(nn.Module):
|
|
101 |
flow_feature0 (torch.Tensor): [B, C', H, W]
|
102 |
flow_feature1 (torch.Tensor): [B, C', H, W]
|
103 |
"""
|
104 |
-
flow0,flow1=self.decode_flow(
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
if self.update_flow:
|
117 |
-
update_feature=torch.cat([x0,flow_feature0],dim=1)
|
118 |
else:
|
119 |
-
update_feature=x0
|
120 |
-
msg=self.norm2(
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
return flow
|
134 |
|
135 |
|
136 |
class flow_initializer(nn.Module):
|
137 |
-
|
138 |
def __init__(self, dim, dim_flow, nhead, layer_num):
|
139 |
super().__init__()
|
140 |
-
self.layer_num= layer_num
|
141 |
self.dim = dim
|
142 |
self.dim_flow = dim_flow
|
143 |
|
144 |
-
encoder_layer = messageLayer_ini(
|
145 |
-
dim ,dim_flow,dim+dim_flow , nhead)
|
146 |
self.layers_coarse = nn.ModuleList(
|
147 |
-
[copy.deepcopy(encoder_layer) for _ in range(layer_num)]
|
148 |
-
|
149 |
-
|
150 |
-
self.up_merge = nn.Conv2d(2*dim, dim, kernel_size=1)
|
151 |
|
152 |
-
def forward(
|
|
|
|
|
153 |
# feat0: [B, C, H0, W0]
|
154 |
# feat1: [B, C, H1, W1]
|
155 |
# use low-res MHA to initialize flow feature
|
156 |
bs = feat0.size(0)
|
157 |
-
h0,w0,h1,w1=feat0.shape[2],feat0.shape[3],feat1.shape[2],feat1.shape[3]
|
158 |
|
159 |
# coarse level
|
160 |
-
sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0),
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
-
sub_pos0,sub_pos1=F.avg_pool2d(pos0, ds0, stride=ds0), \
|
164 |
-
F.avg_pool2d(pos1, ds1, stride=ds1)
|
165 |
-
|
166 |
if mask0 is not None:
|
167 |
-
mask0,mask1
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
170 |
for layer in self.layers_coarse:
|
171 |
-
sub_feat0, sub_feat1 = layer(
|
|
|
|
|
172 |
# decouple flow and visual features
|
173 |
-
decoupled_feature0, decoupled_feature1 = self.decoupler(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
update_feat1, flow_feature1 = F.upsample(sub_feat1, scale_factor=ds1, mode='bilinear'),\
|
180 |
-
F.upsample(sub_flow_feature1, scale_factor=ds1, mode='bilinear')
|
181 |
-
|
182 |
-
feat0 = feat0+self.up_merge(torch.cat([feat0, update_feat0], dim=1))
|
183 |
-
feat1 = feat1+self.up_merge(torch.cat([feat1, update_feat1], dim=1))
|
184 |
-
|
185 |
-
return feat0,feat1,flow_feature0,flow_feature1 #b*c*h*w
|
186 |
|
187 |
|
188 |
class LocalFeatureTransformer_Flow(nn.Module):
|
@@ -192,27 +261,49 @@ class LocalFeatureTransformer_Flow(nn.Module):
|
|
192 |
super(LocalFeatureTransformer_Flow, self).__init__()
|
193 |
|
194 |
self.config = config
|
195 |
-
self.d_model = config[
|
196 |
-
self.nhead = config[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
-
self.pos_transform=nn.Conv2d(config['d_model'],config['d_flow'],kernel_size=1,bias=False)
|
199 |
-
self.ini_layer = flow_initializer(self.d_model, config['d_flow'], config['nhead'],config['ini_layer_num'])
|
200 |
-
|
201 |
encoder_layer = messageLayer_gla(
|
202 |
-
config[
|
203 |
-
|
204 |
-
config[
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
self._reset_parameters()
|
207 |
-
|
208 |
def _reset_parameters(self):
|
209 |
-
for name,p in self.named_parameters():
|
210 |
-
if
|
211 |
continue
|
212 |
if p.dim() > 1:
|
213 |
nn.init.xavier_uniform_(p)
|
214 |
|
215 |
-
def forward(
|
|
|
|
|
216 |
"""
|
217 |
Args:
|
218 |
feat0 (torch.Tensor): [N, C, H, W]
|
@@ -224,21 +315,37 @@ class LocalFeatureTransformer_Flow(nn.Module):
|
|
224 |
flow_list: [L,N,H,W,4]*1(2)
|
225 |
"""
|
226 |
bs = feat0.size(0)
|
227 |
-
|
228 |
-
pos0,pos1=self.pos_transform(pos0),self.pos_transform(pos1)
|
229 |
-
pos0,pos1=pos0.expand(bs
|
230 |
assert self.d_model == feat0.size(
|
231 |
-
1
|
232 |
-
|
233 |
-
|
|
|
234 |
if mask0 is not None:
|
235 |
-
mask0,mask1=mask0[:,None].float(),mask1[:,None].float()
|
236 |
-
feat0,feat1, flow_feature0, flow_feature1 = self.ini_layer(
|
|
|
|
|
237 |
for layer in self.layers:
|
238 |
-
feat0,feat1,flow_feature0,flow_feature1,flow0,flow1=layer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
flow_list[0].append(flow0)
|
240 |
flow_list[1].append(flow1)
|
241 |
-
flow_list[0]=torch.stack(flow_list[0],dim=0)
|
242 |
-
flow_list[1]=torch.stack(flow_list[1],dim=0)
|
243 |
-
feat0, feat1 = feat0.permute(0, 2, 3, 1).view(
|
244 |
-
|
|
|
|
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
5 |
+
from .attention import FullAttention, HierachicalAttention, layernorm2d
|
6 |
|
7 |
|
8 |
class messageLayer_ini(nn.Module):
|
9 |
+
def __init__(self, d_model, d_flow, d_value, nhead):
|
|
|
10 |
super().__init__()
|
11 |
super(messageLayer_ini, self).__init__()
|
12 |
|
13 |
self.d_model = d_model
|
14 |
self.d_flow = d_flow
|
15 |
+
self.d_value = d_value
|
16 |
self.nhead = nhead
|
17 |
+
self.attention = FullAttention(d_model, nhead)
|
18 |
|
19 |
+
self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
|
20 |
+
self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
|
21 |
+
self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False)
|
22 |
+
self.merge_head = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
|
23 |
|
24 |
+
self.merge_f = self.merge_f = nn.Sequential(
|
25 |
+
nn.Conv2d(d_model * 2, d_model * 2, kernel_size=1, bias=False),
|
26 |
nn.ReLU(True),
|
27 |
+
nn.Conv2d(d_model * 2, d_model, kernel_size=1, bias=False),
|
28 |
)
|
29 |
|
30 |
self.norm1 = layernorm2d(d_model)
|
31 |
self.norm2 = layernorm2d(d_model)
|
32 |
|
33 |
+
def forward(self, x0, x1, pos0, pos1, mask0=None, mask1=None):
|
34 |
+
# x1,x2: b*d*L
|
35 |
+
x0, x1 = self.update(x0, x1, pos1, mask0, mask1), self.update(
|
36 |
+
x1, x0, pos0, mask1, mask0
|
37 |
+
)
|
38 |
+
return x0, x1
|
39 |
|
40 |
+
def update(self, f0, f1, pos1, mask0, mask1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
"""
|
42 |
Args:
|
43 |
f0: [N, D, H, W]
|
|
|
45 |
Returns:
|
46 |
f0_new: (N, d, h, w)
|
47 |
"""
|
48 |
+
bs, h, w = f0.shape[0], f0.shape[2], f0.shape[3]
|
49 |
|
50 |
+
f0_flatten, f1_flatten = f0.view(bs, self.d_model, -1), f1.view(
|
51 |
+
bs, self.d_model, -1
|
52 |
+
)
|
53 |
+
pos1_flatten = pos1.view(bs, self.d_value - self.d_model, -1)
|
54 |
+
f1_flatten_v = torch.cat([f1_flatten, pos1_flatten], dim=1)
|
55 |
|
56 |
+
queries, keys = self.q_proj(f0_flatten), self.k_proj(f1_flatten)
|
57 |
+
values = self.v_proj(f1_flatten_v).view(
|
58 |
+
bs, self.nhead, self.d_model // self.nhead, -1
|
59 |
+
)
|
|
|
|
|
|
|
60 |
|
61 |
+
queried_values = self.attention(queries, keys, values, mask0, mask1)
|
62 |
+
msg = self.merge_head(queried_values).view(bs, -1, h, w)
|
63 |
+
msg = self.norm2(self.merge_f(torch.cat([f0, self.norm1(msg)], dim=1)))
|
64 |
+
return f0 + msg
|
65 |
|
66 |
|
67 |
class messageLayer_gla(nn.Module):
|
68 |
+
def __init__(
|
69 |
+
self, d_model, d_flow, d_value, nhead, radius_scale, nsample, update_flow=True
|
70 |
+
):
|
71 |
super().__init__()
|
72 |
self.d_model = d_model
|
73 |
+
self.d_flow = d_flow
|
74 |
+
self.d_value = d_value
|
75 |
self.nhead = nhead
|
76 |
+
self.radius_scale = radius_scale
|
77 |
+
self.update_flow = update_flow
|
78 |
+
self.flow_decoder = nn.Sequential(
|
79 |
+
nn.Conv1d(d_flow, d_flow // 2, kernel_size=1, bias=False),
|
80 |
+
nn.ReLU(True),
|
81 |
+
nn.Conv1d(d_flow // 2, 4, kernel_size=1, bias=False),
|
82 |
+
)
|
83 |
+
self.attention = HierachicalAttention(d_model, nhead, nsample, radius_scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
|
86 |
+
self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
|
87 |
+
self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False)
|
88 |
+
|
89 |
+
d_extra = d_flow if update_flow else 0
|
90 |
+
self.merge_f = nn.Sequential(
|
91 |
+
nn.Conv2d(
|
92 |
+
d_model * 2 + d_extra, d_model + d_flow, kernel_size=1, bias=False
|
93 |
+
),
|
94 |
+
nn.ReLU(True),
|
95 |
+
nn.Conv2d(
|
96 |
+
d_model + d_flow,
|
97 |
+
d_model + d_extra,
|
98 |
+
kernel_size=3,
|
99 |
+
padding=1,
|
100 |
+
bias=False,
|
101 |
+
),
|
102 |
+
)
|
103 |
+
self.norm1 = layernorm2d(d_model)
|
104 |
+
self.norm2 = layernorm2d(d_model + d_extra)
|
105 |
+
|
106 |
+
def forward(
|
107 |
+
self,
|
108 |
+
x0,
|
109 |
+
x1,
|
110 |
+
flow_feature0,
|
111 |
+
flow_feature1,
|
112 |
+
pos0,
|
113 |
+
pos1,
|
114 |
+
mask0=None,
|
115 |
+
mask1=None,
|
116 |
+
ds0=[4, 4],
|
117 |
+
ds1=[4, 4],
|
118 |
+
):
|
119 |
"""
|
120 |
Args:
|
121 |
x0 (torch.Tensor): [B, C, H, W]
|
|
|
123 |
flow_feature0 (torch.Tensor): [B, C', H, W]
|
124 |
flow_feature1 (torch.Tensor): [B, C', H, W]
|
125 |
"""
|
126 |
+
flow0, flow1 = self.decode_flow(
|
127 |
+
flow_feature0, flow_feature1.shape[2:]
|
128 |
+
), self.decode_flow(flow_feature1, flow_feature0.shape[2:])
|
129 |
+
x0_new, flow_feature0_new = self.update(
|
130 |
+
x0, x1, flow0.detach(), flow_feature0, pos1, mask0, mask1, ds0, ds1
|
131 |
+
)
|
132 |
+
x1_new, flow_feature1_new = self.update(
|
133 |
+
x1, x0, flow1.detach(), flow_feature1, pos0, mask1, mask0, ds1, ds0
|
134 |
+
)
|
135 |
+
return x0_new, x1_new, flow_feature0_new, flow_feature1_new, flow0, flow1
|
136 |
+
|
137 |
+
def update(self, x0, x1, flow0, flow_feature0, pos1, mask0, mask1, ds0, ds1):
|
138 |
+
bs = x0.shape[0]
|
139 |
+
queries, keys = self.q_proj(x0.view(bs, self.d_model, -1)), self.k_proj(
|
140 |
+
x1.view(bs, self.d_model, -1)
|
141 |
+
)
|
142 |
+
x1_pos = torch.cat([x1, pos1], dim=1)
|
143 |
+
values = self.v_proj(x1_pos.view(bs, self.d_value, -1))
|
144 |
+
msg = self.attention(
|
145 |
+
queries,
|
146 |
+
keys,
|
147 |
+
values,
|
148 |
+
flow0,
|
149 |
+
x0.shape[2:],
|
150 |
+
x1.shape[2:],
|
151 |
+
mask0,
|
152 |
+
mask1,
|
153 |
+
ds0,
|
154 |
+
ds1,
|
155 |
+
)
|
156 |
|
157 |
if self.update_flow:
|
158 |
+
update_feature = torch.cat([x0, flow_feature0], dim=1)
|
159 |
else:
|
160 |
+
update_feature = x0
|
161 |
+
msg = self.norm2(
|
162 |
+
self.merge_f(torch.cat([update_feature, self.norm1(msg)], dim=1))
|
163 |
+
)
|
164 |
+
update_feature = update_feature + msg
|
165 |
+
|
166 |
+
x0_new, flow_feature0_new = (
|
167 |
+
update_feature[:, : self.d_model],
|
168 |
+
update_feature[:, self.d_model :],
|
169 |
+
)
|
170 |
+
return x0_new, flow_feature0_new
|
171 |
+
|
172 |
+
def decode_flow(self, flow_feature, kshape):
|
173 |
+
bs, h, w = flow_feature.shape[0], flow_feature.shape[2], flow_feature.shape[3]
|
174 |
+
scale_factor = torch.tensor([kshape[1], kshape[0]]).cuda()[None, None, None]
|
175 |
+
flow = (
|
176 |
+
self.flow_decoder(flow_feature.view(bs, -1, h * w))
|
177 |
+
.permute(0, 2, 1)
|
178 |
+
.view(bs, h, w, 4)
|
179 |
+
)
|
180 |
+
flow_coordinates = torch.sigmoid(flow[:, :, :, :2]) * scale_factor
|
181 |
+
flow_var = flow[:, :, :, 2:]
|
182 |
+
flow = torch.cat([flow_coordinates, flow_var], dim=-1) # B*H*W*4
|
183 |
return flow
|
184 |
|
185 |
|
186 |
class flow_initializer(nn.Module):
|
|
|
187 |
def __init__(self, dim, dim_flow, nhead, layer_num):
|
188 |
super().__init__()
|
189 |
+
self.layer_num = layer_num
|
190 |
self.dim = dim
|
191 |
self.dim_flow = dim_flow
|
192 |
|
193 |
+
encoder_layer = messageLayer_ini(dim, dim_flow, dim + dim_flow, nhead)
|
|
|
194 |
self.layers_coarse = nn.ModuleList(
|
195 |
+
[copy.deepcopy(encoder_layer) for _ in range(layer_num)]
|
196 |
+
)
|
197 |
+
self.decoupler = nn.Conv2d(self.dim, self.dim + self.dim_flow, kernel_size=1)
|
198 |
+
self.up_merge = nn.Conv2d(2 * dim, dim, kernel_size=1)
|
199 |
|
200 |
+
def forward(
|
201 |
+
self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4]
|
202 |
+
):
|
203 |
# feat0: [B, C, H0, W0]
|
204 |
# feat1: [B, C, H1, W1]
|
205 |
# use low-res MHA to initialize flow feature
|
206 |
bs = feat0.size(0)
|
207 |
+
h0, w0, h1, w1 = feat0.shape[2], feat0.shape[3], feat1.shape[2], feat1.shape[3]
|
208 |
|
209 |
# coarse level
|
210 |
+
sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), F.avg_pool2d(
|
211 |
+
feat1, ds1, stride=ds1
|
212 |
+
)
|
213 |
+
|
214 |
+
sub_pos0, sub_pos1 = F.avg_pool2d(pos0, ds0, stride=ds0), F.avg_pool2d(
|
215 |
+
pos1, ds1, stride=ds1
|
216 |
+
)
|
217 |
|
|
|
|
|
|
|
218 |
if mask0 is not None:
|
219 |
+
mask0, mask1 = -F.max_pool2d(
|
220 |
+
-mask0.view(bs, 1, h0, w0), ds0, stride=ds0
|
221 |
+
).view(bs, -1), -F.max_pool2d(
|
222 |
+
-mask1.view(bs, 1, h1, w1), ds1, stride=ds1
|
223 |
+
).view(
|
224 |
+
bs, -1
|
225 |
+
)
|
226 |
+
|
227 |
for layer in self.layers_coarse:
|
228 |
+
sub_feat0, sub_feat1 = layer(
|
229 |
+
sub_feat0, sub_feat1, sub_pos0, sub_pos1, mask0, mask1
|
230 |
+
)
|
231 |
# decouple flow and visual features
|
232 |
+
decoupled_feature0, decoupled_feature1 = self.decoupler(
|
233 |
+
sub_feat0
|
234 |
+
), self.decoupler(sub_feat1)
|
235 |
+
|
236 |
+
sub_feat0, sub_flow_feature0 = (
|
237 |
+
decoupled_feature0[:, : self.dim],
|
238 |
+
decoupled_feature0[:, self.dim :],
|
239 |
+
)
|
240 |
+
sub_feat1, sub_flow_feature1 = (
|
241 |
+
decoupled_feature1[:, : self.dim],
|
242 |
+
decoupled_feature1[:, self.dim :],
|
243 |
+
)
|
244 |
+
update_feat0, flow_feature0 = F.upsample(
|
245 |
+
sub_feat0, scale_factor=ds0, mode="bilinear"
|
246 |
+
), F.upsample(sub_flow_feature0, scale_factor=ds0, mode="bilinear")
|
247 |
+
update_feat1, flow_feature1 = F.upsample(
|
248 |
+
sub_feat1, scale_factor=ds1, mode="bilinear"
|
249 |
+
), F.upsample(sub_flow_feature1, scale_factor=ds1, mode="bilinear")
|
250 |
|
251 |
+
feat0 = feat0 + self.up_merge(torch.cat([feat0, update_feat0], dim=1))
|
252 |
+
feat1 = feat1 + self.up_merge(torch.cat([feat1, update_feat1], dim=1))
|
253 |
+
|
254 |
+
return feat0, feat1, flow_feature0, flow_feature1 # b*c*h*w
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
|
257 |
class LocalFeatureTransformer_Flow(nn.Module):
|
|
|
261 |
super(LocalFeatureTransformer_Flow, self).__init__()
|
262 |
|
263 |
self.config = config
|
264 |
+
self.d_model = config["d_model"]
|
265 |
+
self.nhead = config["nhead"]
|
266 |
+
|
267 |
+
self.pos_transform = nn.Conv2d(
|
268 |
+
config["d_model"], config["d_flow"], kernel_size=1, bias=False
|
269 |
+
)
|
270 |
+
self.ini_layer = flow_initializer(
|
271 |
+
self.d_model, config["d_flow"], config["nhead"], config["ini_layer_num"]
|
272 |
+
)
|
273 |
|
|
|
|
|
|
|
274 |
encoder_layer = messageLayer_gla(
|
275 |
+
config["d_model"],
|
276 |
+
config["d_flow"],
|
277 |
+
config["d_flow"] + config["d_model"],
|
278 |
+
config["nhead"],
|
279 |
+
config["radius_scale"],
|
280 |
+
config["nsample"],
|
281 |
+
)
|
282 |
+
encoder_layer_last = messageLayer_gla(
|
283 |
+
config["d_model"],
|
284 |
+
config["d_flow"],
|
285 |
+
config["d_flow"] + config["d_model"],
|
286 |
+
config["nhead"],
|
287 |
+
config["radius_scale"],
|
288 |
+
config["nsample"],
|
289 |
+
update_flow=False,
|
290 |
+
)
|
291 |
+
self.layers = nn.ModuleList(
|
292 |
+
[copy.deepcopy(encoder_layer) for _ in range(config["layer_num"] - 1)]
|
293 |
+
+ [encoder_layer_last]
|
294 |
+
)
|
295 |
self._reset_parameters()
|
296 |
+
|
297 |
def _reset_parameters(self):
|
298 |
+
for name, p in self.named_parameters():
|
299 |
+
if "temp" in name or "sample_offset" in name:
|
300 |
continue
|
301 |
if p.dim() > 1:
|
302 |
nn.init.xavier_uniform_(p)
|
303 |
|
304 |
+
def forward(
|
305 |
+
self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4]
|
306 |
+
):
|
307 |
"""
|
308 |
Args:
|
309 |
feat0 (torch.Tensor): [N, C, H, W]
|
|
|
315 |
flow_list: [L,N,H,W,4]*1(2)
|
316 |
"""
|
317 |
bs = feat0.size(0)
|
318 |
+
|
319 |
+
pos0, pos1 = self.pos_transform(pos0), self.pos_transform(pos1)
|
320 |
+
pos0, pos1 = pos0.expand(bs, -1, -1, -1), pos1.expand(bs, -1, -1, -1)
|
321 |
assert self.d_model == feat0.size(
|
322 |
+
1
|
323 |
+
), "the feature number of src and transformer must be equal"
|
324 |
+
|
325 |
+
flow_list = [[], []] # [px,py,sx,sy]
|
326 |
if mask0 is not None:
|
327 |
+
mask0, mask1 = mask0[:, None].float(), mask1[:, None].float()
|
328 |
+
feat0, feat1, flow_feature0, flow_feature1 = self.ini_layer(
|
329 |
+
feat0, feat1, pos0, pos1, mask0, mask1, ds0, ds1
|
330 |
+
)
|
331 |
for layer in self.layers:
|
332 |
+
feat0, feat1, flow_feature0, flow_feature1, flow0, flow1 = layer(
|
333 |
+
feat0,
|
334 |
+
feat1,
|
335 |
+
flow_feature0,
|
336 |
+
flow_feature1,
|
337 |
+
pos0,
|
338 |
+
pos1,
|
339 |
+
mask0,
|
340 |
+
mask1,
|
341 |
+
ds0,
|
342 |
+
ds1,
|
343 |
+
)
|
344 |
flow_list[0].append(flow0)
|
345 |
flow_list[1].append(flow1)
|
346 |
+
flow_list[0] = torch.stack(flow_list[0], dim=0)
|
347 |
+
flow_list[1] = torch.stack(flow_list[1], dim=0)
|
348 |
+
feat0, feat1 = feat0.permute(0, 2, 3, 1).view(
|
349 |
+
bs, -1, self.d_model
|
350 |
+
), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model)
|
351 |
+
return feat0, feat1, flow_list
|
third_party/ASpanFormer/src/ASpanFormer/aspanformer.py
CHANGED
@@ -5,7 +5,11 @@ from einops.einops import rearrange
|
|
5 |
|
6 |
from .backbone import build_backbone
|
7 |
from .utils.position_encoding import PositionEncodingSine
|
8 |
-
from .aspan_module import
|
|
|
|
|
|
|
|
|
9 |
from .utils.coarse_matching import CoarseMatching
|
10 |
from .utils.fine_matching import FineMatching
|
11 |
|
@@ -19,16 +23,18 @@ class ASpanFormer(nn.Module):
|
|
19 |
# Modules
|
20 |
self.backbone = build_backbone(config)
|
21 |
self.pos_encoding = PositionEncodingSine(
|
22 |
-
config[
|
23 |
-
|
24 |
-
|
|
|
|
|
25 |
self.fine_preprocess = FinePreprocess(config)
|
26 |
self.loftr_fine = LocalFeatureTransformer(config["fine"])
|
27 |
self.fine_matching = FineMatching()
|
28 |
-
self.coarsest_level=config[
|
29 |
|
30 |
def forward(self, data, online_resize=False):
|
31 |
-
"""
|
32 |
Update:
|
33 |
data (dict): {
|
34 |
'image0': (torch.Tensor): (N, 1, H, W)
|
@@ -38,96 +44,135 @@ class ASpanFormer(nn.Module):
|
|
38 |
}
|
39 |
"""
|
40 |
if online_resize:
|
41 |
-
assert data[
|
42 |
-
self.resize_input(data,self.config[
|
43 |
else:
|
44 |
-
data[
|
45 |
|
46 |
# 1. Local Feature CNN
|
47 |
-
data.update(
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
53 |
feats_c, feats_f = self.backbone(
|
54 |
-
torch.cat([data[
|
|
|
55 |
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
|
56 |
-
data[
|
|
|
57 |
else: # handle different input shapes
|
58 |
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
|
59 |
-
data[
|
|
|
60 |
|
61 |
-
data.update(
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
|
66 |
# 2. coarse-level loftr module
|
67 |
# add featmap with positional encoding, then flatten it to sequence [N, HW, C]
|
68 |
-
[feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(
|
69 |
-
|
70 |
-
|
|
|
|
|
71 |
|
72 |
-
#TODO:adjust ds
|
73 |
-
ds0=[
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
if online_resize:
|
76 |
-
ds0,ds1=[4,4],[4,4]
|
77 |
|
78 |
mask_c0 = mask_c1 = None # mask is useful in training
|
79 |
-
if
|
80 |
-
mask_c0, mask_c1 = data[
|
81 |
-
-2), data['mask1'].flatten(-2)
|
82 |
feat_c0, feat_c1, flow_list = self.loftr_coarse(
|
83 |
-
feat_c0, feat_c1,pos_encoding0,pos_encoding1,mask_c0,mask_c1,ds0,ds1
|
|
|
84 |
|
85 |
# 3. match coarse-level and register predicted offset
|
86 |
-
self.coarse_matching(
|
87 |
-
|
|
|
88 |
|
89 |
# 4. fine-level refinement
|
90 |
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
|
91 |
-
feat_f0, feat_f1, feat_c0, feat_c1, data
|
|
|
92 |
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
|
93 |
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
|
94 |
-
feat_f0_unfold, feat_f1_unfold
|
|
|
95 |
|
96 |
# 5. match fine-level
|
97 |
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
|
98 |
|
99 |
# 6. resize match coordinates back to input resolution
|
100 |
if online_resize:
|
101 |
-
data[
|
102 |
-
data[
|
103 |
-
|
104 |
def load_state_dict(self, state_dict, *args, **kwargs):
|
105 |
for k in list(state_dict.keys()):
|
106 |
-
if k.startswith(
|
107 |
-
if
|
108 |
state_dict.pop(k)
|
109 |
else:
|
110 |
-
state_dict[k.replace(
|
111 |
return super().load_state_dict(state_dict, *args, **kwargs)
|
112 |
-
|
113 |
-
def resize_input(self,data,train_res,df=32):
|
114 |
-
h0,w0,h1,w1=
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
else:
|
120 |
-
train_res_h,train_res_w=train_res[0],train_res[1]
|
121 |
-
data[
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
else:
|
132 |
-
img_new=image
|
133 |
return img_new
|
|
|
5 |
|
6 |
from .backbone import build_backbone
|
7 |
from .utils.position_encoding import PositionEncodingSine
|
8 |
+
from .aspan_module import (
|
9 |
+
LocalFeatureTransformer_Flow,
|
10 |
+
LocalFeatureTransformer,
|
11 |
+
FinePreprocess,
|
12 |
+
)
|
13 |
from .utils.coarse_matching import CoarseMatching
|
14 |
from .utils.fine_matching import FineMatching
|
15 |
|
|
|
23 |
# Modules
|
24 |
self.backbone = build_backbone(config)
|
25 |
self.pos_encoding = PositionEncodingSine(
|
26 |
+
config["coarse"]["d_model"],
|
27 |
+
pre_scaling=[config["coarse"]["train_res"], config["coarse"]["test_res"]],
|
28 |
+
)
|
29 |
+
self.loftr_coarse = LocalFeatureTransformer_Flow(config["coarse"])
|
30 |
+
self.coarse_matching = CoarseMatching(config["match_coarse"])
|
31 |
self.fine_preprocess = FinePreprocess(config)
|
32 |
self.loftr_fine = LocalFeatureTransformer(config["fine"])
|
33 |
self.fine_matching = FineMatching()
|
34 |
+
self.coarsest_level = config["coarse"]["coarsest_level"]
|
35 |
|
36 |
def forward(self, data, online_resize=False):
|
37 |
+
"""
|
38 |
Update:
|
39 |
data (dict): {
|
40 |
'image0': (torch.Tensor): (N, 1, H, W)
|
|
|
44 |
}
|
45 |
"""
|
46 |
if online_resize:
|
47 |
+
assert data["image0"].shape[0] == 1 and data["image1"].shape[1] == 1
|
48 |
+
self.resize_input(data, self.config["coarse"]["train_res"])
|
49 |
else:
|
50 |
+
data["pos_scale0"], data["pos_scale1"] = None, None
|
51 |
|
52 |
# 1. Local Feature CNN
|
53 |
+
data.update(
|
54 |
+
{
|
55 |
+
"bs": data["image0"].size(0),
|
56 |
+
"hw0_i": data["image0"].shape[2:],
|
57 |
+
"hw1_i": data["image1"].shape[2:],
|
58 |
+
}
|
59 |
+
)
|
60 |
+
|
61 |
+
if data["hw0_i"] == data["hw1_i"]: # faster & better BN convergence
|
62 |
feats_c, feats_f = self.backbone(
|
63 |
+
torch.cat([data["image0"], data["image1"]], dim=0)
|
64 |
+
)
|
65 |
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
|
66 |
+
data["bs"]
|
67 |
+
), feats_f.split(data["bs"])
|
68 |
else: # handle different input shapes
|
69 |
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
|
70 |
+
data["image0"]
|
71 |
+
), self.backbone(data["image1"])
|
72 |
|
73 |
+
data.update(
|
74 |
+
{
|
75 |
+
"hw0_c": feat_c0.shape[2:],
|
76 |
+
"hw1_c": feat_c1.shape[2:],
|
77 |
+
"hw0_f": feat_f0.shape[2:],
|
78 |
+
"hw1_f": feat_f1.shape[2:],
|
79 |
+
}
|
80 |
+
)
|
81 |
|
82 |
# 2. coarse-level loftr module
|
83 |
# add featmap with positional encoding, then flatten it to sequence [N, HW, C]
|
84 |
+
[feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(
|
85 |
+
feat_c0, data["pos_scale0"]
|
86 |
+
), self.pos_encoding(feat_c1, data["pos_scale1"])
|
87 |
+
feat_c0 = rearrange(feat_c0, "n c h w -> n c h w ")
|
88 |
+
feat_c1 = rearrange(feat_c1, "n c h w -> n c h w ")
|
89 |
|
90 |
+
# TODO:adjust ds
|
91 |
+
ds0 = [
|
92 |
+
int(data["hw0_c"][0] / self.coarsest_level[0]),
|
93 |
+
int(data["hw0_c"][1] / self.coarsest_level[1]),
|
94 |
+
]
|
95 |
+
ds1 = [
|
96 |
+
int(data["hw1_c"][0] / self.coarsest_level[0]),
|
97 |
+
int(data["hw1_c"][1] / self.coarsest_level[1]),
|
98 |
+
]
|
99 |
if online_resize:
|
100 |
+
ds0, ds1 = [4, 4], [4, 4]
|
101 |
|
102 |
mask_c0 = mask_c1 = None # mask is useful in training
|
103 |
+
if "mask0" in data:
|
104 |
+
mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2)
|
|
|
105 |
feat_c0, feat_c1, flow_list = self.loftr_coarse(
|
106 |
+
feat_c0, feat_c1, pos_encoding0, pos_encoding1, mask_c0, mask_c1, ds0, ds1
|
107 |
+
)
|
108 |
|
109 |
# 3. match coarse-level and register predicted offset
|
110 |
+
self.coarse_matching(
|
111 |
+
feat_c0, feat_c1, flow_list, data, mask_c0=mask_c0, mask_c1=mask_c1
|
112 |
+
)
|
113 |
|
114 |
# 4. fine-level refinement
|
115 |
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
|
116 |
+
feat_f0, feat_f1, feat_c0, feat_c1, data
|
117 |
+
)
|
118 |
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
|
119 |
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
|
120 |
+
feat_f0_unfold, feat_f1_unfold
|
121 |
+
)
|
122 |
|
123 |
# 5. match fine-level
|
124 |
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
|
125 |
|
126 |
# 6. resize match coordinates back to input resolution
|
127 |
if online_resize:
|
128 |
+
data["mkpts0_f"] *= data["online_resize_scale0"]
|
129 |
+
data["mkpts1_f"] *= data["online_resize_scale1"]
|
130 |
+
|
131 |
def load_state_dict(self, state_dict, *args, **kwargs):
|
132 |
for k in list(state_dict.keys()):
|
133 |
+
if k.startswith("matcher."):
|
134 |
+
if "sample_offset" in k:
|
135 |
state_dict.pop(k)
|
136 |
else:
|
137 |
+
state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k)
|
138 |
return super().load_state_dict(state_dict, *args, **kwargs)
|
139 |
+
|
140 |
+
def resize_input(self, data, train_res, df=32):
|
141 |
+
h0, w0, h1, w1 = (
|
142 |
+
data["image0"].shape[2],
|
143 |
+
data["image0"].shape[3],
|
144 |
+
data["image1"].shape[2],
|
145 |
+
data["image1"].shape[3],
|
146 |
+
)
|
147 |
+
data["image0"], data["image1"] = self.resize_df(
|
148 |
+
data["image0"], df
|
149 |
+
), self.resize_df(data["image1"], df)
|
150 |
+
|
151 |
+
if len(train_res) == 1:
|
152 |
+
train_res_h = train_res_w = train_res
|
153 |
else:
|
154 |
+
train_res_h, train_res_w = train_res[0], train_res[1]
|
155 |
+
data["pos_scale0"], data["pos_scale1"] = [
|
156 |
+
train_res_h / data["image0"].shape[2],
|
157 |
+
train_res_w / data["image0"].shape[3],
|
158 |
+
], [
|
159 |
+
train_res_h / data["image1"].shape[2],
|
160 |
+
train_res_w / data["image1"].shape[3],
|
161 |
+
]
|
162 |
+
data["online_resize_scale0"], data["online_resize_scale1"] = (
|
163 |
+
torch.tensor([w0 / data["image0"].shape[3], h0 / data["image0"].shape[2]])[
|
164 |
+
None
|
165 |
+
].cuda(),
|
166 |
+
torch.tensor([w1 / data["image1"].shape[3], h1 / data["image1"].shape[2]])[
|
167 |
+
None
|
168 |
+
].cuda(),
|
169 |
+
)
|
170 |
+
|
171 |
+
def resize_df(self, image, df=32):
|
172 |
+
h, w = image.shape[2], image.shape[3]
|
173 |
+
h_new, w_new = h // df * df, w // df * df
|
174 |
+
if h != h_new or w != w_new:
|
175 |
+
img_new = transforms.Resize([h_new, w_new]).forward(image)
|
176 |
else:
|
177 |
+
img_new = image
|
178 |
return img_new
|
third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py
CHANGED
@@ -2,10 +2,12 @@ from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
|
|
2 |
|
3 |
|
4 |
def build_backbone(config):
|
5 |
-
if config[
|
6 |
-
if config[
|
7 |
-
return ResNetFPN_8_2(config[
|
8 |
-
elif config[
|
9 |
-
return ResNetFPN_16_4(config[
|
10 |
else:
|
11 |
-
raise ValueError(
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def build_backbone(config):
|
5 |
+
if config["backbone_type"] == "ResNetFPN":
|
6 |
+
if config["resolution"] == (8, 2):
|
7 |
+
return ResNetFPN_8_2(config["resnetfpn"])
|
8 |
+
elif config["resolution"] == (16, 4):
|
9 |
+
return ResNetFPN_16_4(config["resnetfpn"])
|
10 |
else:
|
11 |
+
raise ValueError(
|
12 |
+
f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported."
|
13 |
+
)
|
third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py
CHANGED
@@ -4,12 +4,16 @@ import torch.nn.functional as F
|
|
4 |
|
5 |
def conv1x1(in_planes, out_planes, stride=1):
|
6 |
"""1x1 convolution without padding"""
|
7 |
-
return nn.Conv2d(
|
|
|
|
|
8 |
|
9 |
|
10 |
def conv3x3(in_planes, out_planes, stride=1):
|
11 |
"""3x3 convolution with padding"""
|
12 |
-
return nn.Conv2d(
|
|
|
|
|
13 |
|
14 |
|
15 |
class BasicBlock(nn.Module):
|
@@ -25,8 +29,7 @@ class BasicBlock(nn.Module):
|
|
25 |
self.downsample = None
|
26 |
else:
|
27 |
self.downsample = nn.Sequential(
|
28 |
-
conv1x1(in_planes, planes, stride=stride),
|
29 |
-
nn.BatchNorm2d(planes)
|
30 |
)
|
31 |
|
32 |
def forward(self, x):
|
@@ -37,7 +40,7 @@ class BasicBlock(nn.Module):
|
|
37 |
if self.downsample is not None:
|
38 |
x = self.downsample(x)
|
39 |
|
40 |
-
return self.relu(x+y)
|
41 |
|
42 |
|
43 |
class ResNetFPN_8_2(nn.Module):
|
@@ -50,14 +53,16 @@ class ResNetFPN_8_2(nn.Module):
|
|
50 |
super().__init__()
|
51 |
# Config
|
52 |
block = BasicBlock
|
53 |
-
initial_dim = config[
|
54 |
-
block_dims = config[
|
55 |
|
56 |
# Class Variable
|
57 |
self.in_planes = initial_dim
|
58 |
|
59 |
# Networks
|
60 |
-
self.conv1 = nn.Conv2d(
|
|
|
|
|
61 |
self.bn1 = nn.BatchNorm2d(initial_dim)
|
62 |
self.relu = nn.ReLU(inplace=True)
|
63 |
|
@@ -84,7 +89,7 @@ class ResNetFPN_8_2(nn.Module):
|
|
84 |
|
85 |
for m in self.modules():
|
86 |
if isinstance(m, nn.Conv2d):
|
87 |
-
nn.init.kaiming_normal_(m.weight, mode=
|
88 |
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
89 |
nn.init.constant_(m.weight, 1)
|
90 |
nn.init.constant_(m.bias, 0)
|
@@ -107,13 +112,17 @@ class ResNetFPN_8_2(nn.Module):
|
|
107 |
# FPN
|
108 |
x3_out = self.layer3_outconv(x3)
|
109 |
|
110 |
-
x3_out_2x = F.interpolate(
|
|
|
|
|
111 |
x2_out = self.layer2_outconv(x2)
|
112 |
-
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
|
113 |
|
114 |
-
x2_out_2x = F.interpolate(
|
|
|
|
|
115 |
x1_out = self.layer1_outconv(x1)
|
116 |
-
x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
|
117 |
|
118 |
return [x3_out, x1_out]
|
119 |
|
@@ -128,14 +137,16 @@ class ResNetFPN_16_4(nn.Module):
|
|
128 |
super().__init__()
|
129 |
# Config
|
130 |
block = BasicBlock
|
131 |
-
initial_dim = config[
|
132 |
-
block_dims = config[
|
133 |
|
134 |
# Class Variable
|
135 |
self.in_planes = initial_dim
|
136 |
|
137 |
# Networks
|
138 |
-
self.conv1 = nn.Conv2d(
|
|
|
|
|
139 |
self.bn1 = nn.BatchNorm2d(initial_dim)
|
140 |
self.relu = nn.ReLU(inplace=True)
|
141 |
|
@@ -164,7 +175,7 @@ class ResNetFPN_16_4(nn.Module):
|
|
164 |
|
165 |
for m in self.modules():
|
166 |
if isinstance(m, nn.Conv2d):
|
167 |
-
nn.init.kaiming_normal_(m.weight, mode=
|
168 |
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
169 |
nn.init.constant_(m.weight, 1)
|
170 |
nn.init.constant_(m.bias, 0)
|
@@ -188,12 +199,16 @@ class ResNetFPN_16_4(nn.Module):
|
|
188 |
# FPN
|
189 |
x4_out = self.layer4_outconv(x4)
|
190 |
|
191 |
-
x4_out_2x = F.interpolate(
|
|
|
|
|
192 |
x3_out = self.layer3_outconv(x3)
|
193 |
-
x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
|
194 |
|
195 |
-
x3_out_2x = F.interpolate(
|
|
|
|
|
196 |
x2_out = self.layer2_outconv(x2)
|
197 |
-
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
|
198 |
|
199 |
return [x4_out, x2_out]
|
|
|
4 |
|
5 |
def conv1x1(in_planes, out_planes, stride=1):
|
6 |
"""1x1 convolution without padding"""
|
7 |
+
return nn.Conv2d(
|
8 |
+
in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False
|
9 |
+
)
|
10 |
|
11 |
|
12 |
def conv3x3(in_planes, out_planes, stride=1):
|
13 |
"""3x3 convolution with padding"""
|
14 |
+
return nn.Conv2d(
|
15 |
+
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
|
16 |
+
)
|
17 |
|
18 |
|
19 |
class BasicBlock(nn.Module):
|
|
|
29 |
self.downsample = None
|
30 |
else:
|
31 |
self.downsample = nn.Sequential(
|
32 |
+
conv1x1(in_planes, planes, stride=stride), nn.BatchNorm2d(planes)
|
|
|
33 |
)
|
34 |
|
35 |
def forward(self, x):
|
|
|
40 |
if self.downsample is not None:
|
41 |
x = self.downsample(x)
|
42 |
|
43 |
+
return self.relu(x + y)
|
44 |
|
45 |
|
46 |
class ResNetFPN_8_2(nn.Module):
|
|
|
53 |
super().__init__()
|
54 |
# Config
|
55 |
block = BasicBlock
|
56 |
+
initial_dim = config["initial_dim"]
|
57 |
+
block_dims = config["block_dims"]
|
58 |
|
59 |
# Class Variable
|
60 |
self.in_planes = initial_dim
|
61 |
|
62 |
# Networks
|
63 |
+
self.conv1 = nn.Conv2d(
|
64 |
+
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False
|
65 |
+
)
|
66 |
self.bn1 = nn.BatchNorm2d(initial_dim)
|
67 |
self.relu = nn.ReLU(inplace=True)
|
68 |
|
|
|
89 |
|
90 |
for m in self.modules():
|
91 |
if isinstance(m, nn.Conv2d):
|
92 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
93 |
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
94 |
nn.init.constant_(m.weight, 1)
|
95 |
nn.init.constant_(m.bias, 0)
|
|
|
112 |
# FPN
|
113 |
x3_out = self.layer3_outconv(x3)
|
114 |
|
115 |
+
x3_out_2x = F.interpolate(
|
116 |
+
x3_out, scale_factor=2.0, mode="bilinear", align_corners=True
|
117 |
+
)
|
118 |
x2_out = self.layer2_outconv(x2)
|
119 |
+
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
|
120 |
|
121 |
+
x2_out_2x = F.interpolate(
|
122 |
+
x2_out, scale_factor=2.0, mode="bilinear", align_corners=True
|
123 |
+
)
|
124 |
x1_out = self.layer1_outconv(x1)
|
125 |
+
x1_out = self.layer1_outconv2(x1_out + x2_out_2x)
|
126 |
|
127 |
return [x3_out, x1_out]
|
128 |
|
|
|
137 |
super().__init__()
|
138 |
# Config
|
139 |
block = BasicBlock
|
140 |
+
initial_dim = config["initial_dim"]
|
141 |
+
block_dims = config["block_dims"]
|
142 |
|
143 |
# Class Variable
|
144 |
self.in_planes = initial_dim
|
145 |
|
146 |
# Networks
|
147 |
+
self.conv1 = nn.Conv2d(
|
148 |
+
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False
|
149 |
+
)
|
150 |
self.bn1 = nn.BatchNorm2d(initial_dim)
|
151 |
self.relu = nn.ReLU(inplace=True)
|
152 |
|
|
|
175 |
|
176 |
for m in self.modules():
|
177 |
if isinstance(m, nn.Conv2d):
|
178 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
179 |
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
180 |
nn.init.constant_(m.weight, 1)
|
181 |
nn.init.constant_(m.bias, 0)
|
|
|
199 |
# FPN
|
200 |
x4_out = self.layer4_outconv(x4)
|
201 |
|
202 |
+
x4_out_2x = F.interpolate(
|
203 |
+
x4_out, scale_factor=2.0, mode="bilinear", align_corners=True
|
204 |
+
)
|
205 |
x3_out = self.layer3_outconv(x3)
|
206 |
+
x3_out = self.layer3_outconv2(x3_out + x4_out_2x)
|
207 |
|
208 |
+
x3_out_2x = F.interpolate(
|
209 |
+
x3_out, scale_factor=2.0, mode="bilinear", align_corners=True
|
210 |
+
)
|
211 |
x2_out = self.layer2_outconv(x2)
|
212 |
+
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
|
213 |
|
214 |
return [x4_out, x2_out]
|
third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py
CHANGED
@@ -7,8 +7,9 @@ from time import time
|
|
7 |
|
8 |
INF = 1e9
|
9 |
|
|
|
10 |
def mask_border(m, b: int, v):
|
11 |
-
"""
|
12 |
Args:
|
13 |
m (torch.Tensor): [N, H0, W0, H1, W1]
|
14 |
b (int)
|
@@ -39,22 +40,21 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
|
|
39 |
h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
|
40 |
h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
|
41 |
for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
|
42 |
-
m[b_idx, h0 - bd:] = v
|
43 |
-
m[b_idx, :, w0 - bd:] = v
|
44 |
-
m[b_idx, :, :, h1 - bd:] = v
|
45 |
-
m[b_idx, :, :, :, w1 - bd:] = v
|
46 |
|
47 |
|
48 |
def compute_max_candidates(p_m0, p_m1):
|
49 |
"""Compute the max candidates of all pairs within a batch
|
50 |
-
|
51 |
Args:
|
52 |
p_m0, p_m1 (torch.Tensor): padded masks
|
53 |
"""
|
54 |
h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
|
55 |
h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
|
56 |
-
max_cand = torch.sum(
|
57 |
-
torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
|
58 |
return max_cand
|
59 |
|
60 |
|
@@ -63,29 +63,32 @@ class CoarseMatching(nn.Module):
|
|
63 |
super().__init__()
|
64 |
self.config = config
|
65 |
# general config
|
66 |
-
self.thr = config[
|
67 |
-
self.border_rm = config[
|
68 |
# -- # for trainig fine-level LoFTR
|
69 |
-
self.train_coarse_percent = config[
|
70 |
-
self.train_pad_num_gt_min = config[
|
71 |
-
|
72 |
# we provide 2 options for differentiable matching
|
73 |
-
self.match_type = config[
|
74 |
-
if self.match_type ==
|
75 |
-
self.temperature=nn.parameter.Parameter(
|
76 |
-
|
|
|
|
|
77 |
try:
|
78 |
from .superglue import log_optimal_transport
|
79 |
except ImportError:
|
80 |
raise ImportError("download superglue.py first!")
|
81 |
self.log_optimal_transport = log_optimal_transport
|
82 |
self.bin_score = nn.Parameter(
|
83 |
-
torch.tensor(config[
|
84 |
-
|
85 |
-
self.
|
|
|
86 |
else:
|
87 |
raise NotImplementedError()
|
88 |
-
|
89 |
def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None):
|
90 |
"""
|
91 |
Args:
|
@@ -108,29 +111,32 @@ class CoarseMatching(nn.Module):
|
|
108 |
"""
|
109 |
N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
|
110 |
# normalize
|
111 |
-
feat_c0, feat_c1 = map(
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
117 |
if mask_c0 is not None:
|
118 |
sim_matrix.masked_fill_(
|
119 |
-
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
|
120 |
-
|
121 |
conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
|
122 |
-
|
123 |
-
elif self.match_type ==
|
124 |
# sinkhorn, dustbin included
|
125 |
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
|
126 |
if mask_c0 is not None:
|
127 |
sim_matrix[:, :L, :S].masked_fill_(
|
128 |
-
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
|
129 |
-
|
130 |
|
131 |
# build uniform prior & use sinkhorn
|
132 |
log_assign_matrix = self.log_optimal_transport(
|
133 |
-
sim_matrix, self.bin_score, self.skh_iters
|
|
|
134 |
assign_matrix = log_assign_matrix.exp()
|
135 |
conf_matrix = assign_matrix[:, :-1, :-1]
|
136 |
|
@@ -141,18 +147,21 @@ class CoarseMatching(nn.Module):
|
|
141 |
conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
|
142 |
conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
|
143 |
|
144 |
-
if self.config[
|
145 |
-
data.update({
|
146 |
|
147 |
-
data.update({
|
148 |
# predict coarse matches from conf_matrix
|
149 |
data.update(**self.get_coarse_match(conf_matrix, data))
|
150 |
|
151 |
-
#update predicted offset
|
152 |
-
if
|
153 |
-
flow_list
|
154 |
-
|
155 |
-
|
|
|
|
|
|
|
156 |
|
157 |
@torch.no_grad()
|
158 |
def get_coarse_match(self, conf_matrix, data):
|
@@ -172,28 +181,33 @@ class CoarseMatching(nn.Module):
|
|
172 |
'mconf' (torch.Tensor): [M]}
|
173 |
"""
|
174 |
axes_lengths = {
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
}
|
180 |
_device = conf_matrix.device
|
181 |
# 1. confidence thresholding
|
182 |
mask = conf_matrix > self.thr
|
183 |
-
mask = rearrange(
|
184 |
-
|
185 |
-
|
|
|
186 |
mask_border(mask, self.border_rm, False)
|
187 |
else:
|
188 |
-
mask_border_with_padding(
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
192 |
|
193 |
# 2. mutual nearest
|
194 |
-
mask =
|
195 |
-
|
|
|
196 |
* (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
|
|
|
197 |
|
198 |
# 3. find all valid coarse matches
|
199 |
# this only works when at most one `True` in each row
|
@@ -208,67 +222,79 @@ class CoarseMatching(nn.Module):
|
|
208 |
# NOTE:
|
209 |
# The sampling is performed across all pairs in a batch without manually balancing
|
210 |
# #samples for fine-level increases w.r.t. batch_size
|
211 |
-
if
|
212 |
-
num_candidates_max = mask.size(0) * max(
|
213 |
-
mask.size(1), mask.size(2))
|
214 |
else:
|
215 |
num_candidates_max = compute_max_candidates(
|
216 |
-
data[
|
217 |
-
|
218 |
-
|
219 |
num_matches_pred = len(b_ids)
|
220 |
-
assert
|
221 |
-
|
|
|
|
|
222 |
# pred_indices is to select from prediction
|
223 |
if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
|
224 |
pred_indices = torch.arange(num_matches_pred, device=_device)
|
225 |
else:
|
226 |
pred_indices = torch.randint(
|
227 |
num_matches_pred,
|
228 |
-
(num_matches_train - self.train_pad_num_gt_min,
|
229 |
-
device=_device
|
|
|
230 |
|
231 |
# gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
|
232 |
gt_pad_indices = torch.randint(
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
mconf_gt = torch.zeros(
|
|
|
|
|
238 |
|
239 |
b_ids, i_ids, j_ids, mconf = map(
|
240 |
-
lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
|
241 |
-
|
242 |
-
|
243 |
-
|
|
|
|
|
|
|
|
|
244 |
|
245 |
# These matches select patches that feed into fine-level network
|
246 |
-
coarse_matches = {
|
247 |
|
248 |
# 4. Update with matches in original image resolution
|
249 |
-
scale = data[
|
250 |
-
scale0 = scale * data[
|
251 |
-
scale1 = scale * data[
|
252 |
-
mkpts0_c =
|
253 |
-
[i_ids % data[
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
dim=1)
|
|
|
|
|
258 |
|
259 |
# These matches is the current prediction (for visualization)
|
260 |
-
coarse_matches.update(
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
|
|
|
|
267 |
|
268 |
return coarse_matches
|
269 |
|
270 |
@torch.no_grad()
|
271 |
-
def get_offset_match(self, flow_list, data,mask1,mask2):
|
272 |
"""
|
273 |
Args:
|
274 |
offset (torch.Tensor): [L, B, H, W, 2]
|
@@ -280,52 +306,62 @@ class CoarseMatching(nn.Module):
|
|
280 |
'mkpts1_c' (torch.Tensor): [M, 2],
|
281 |
'mconf' (torch.Tensor): [M]}
|
282 |
"""
|
283 |
-
offset1=flow_list[0]
|
284 |
-
bs,layer_num=offset1.shape[1],offset1.shape[0]
|
285 |
-
|
286 |
-
#left side
|
287 |
-
offset1=offset1.view(layer_num,bs
|
288 |
-
conf1=offset1[
|
289 |
if mask1 is not None:
|
290 |
-
conf1.masked_fill_(~mask1.bool()[None].expand(layer_num
|
291 |
-
offset1=offset1[
|
292 |
-
self.get_offset_match_work(offset1,conf1,data,
|
293 |
-
|
294 |
-
#rihgt side
|
295 |
-
if len(flow_list)==2:
|
296 |
-
offset2=flow_list[1].view(layer_num,bs
|
297 |
-
conf2=offset2[
|
298 |
if mask2 is not None:
|
299 |
-
conf2.masked_fill_(~mask2.bool()[None].expand(layer_num
|
300 |
-
offset2=offset2[
|
301 |
-
self.get_offset_match_work(offset2,conf2,data,
|
302 |
-
|
303 |
|
304 |
@torch.no_grad()
|
305 |
-
def get_offset_match_work(self, offset,conf, data,side):
|
306 |
-
bs,layer_num=offset.shape[1],offset.shape[0]
|
307 |
# 1. confidence thresholding
|
308 |
-
mask_conf= conf<2
|
309 |
for index in range(bs):
|
310 |
-
mask_conf[:,index,0]=True
|
311 |
# 3. find offset matches
|
312 |
-
scale = data[
|
313 |
-
l_ids,b_ids,i_ids = torch.where(mask_conf)
|
314 |
-
j_coor=offset[l_ids,b_ids,i_ids
|
315 |
-
i_coor=
|
316 |
-
|
|
|
|
|
|
|
317 |
# These matches is the current prediction (for visualization)
|
318 |
-
data.update(
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
else:
|
328 |
-
data.update(
|
329 |
-
|
330 |
-
|
331 |
-
|
|
|
|
|
|
7 |
|
8 |
INF = 1e9
|
9 |
|
10 |
+
|
11 |
def mask_border(m, b: int, v):
|
12 |
+
"""Mask borders with value
|
13 |
Args:
|
14 |
m (torch.Tensor): [N, H0, W0, H1, W1]
|
15 |
b (int)
|
|
|
40 |
h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
|
41 |
h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
|
42 |
for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
|
43 |
+
m[b_idx, h0 - bd :] = v
|
44 |
+
m[b_idx, :, w0 - bd :] = v
|
45 |
+
m[b_idx, :, :, h1 - bd :] = v
|
46 |
+
m[b_idx, :, :, :, w1 - bd :] = v
|
47 |
|
48 |
|
49 |
def compute_max_candidates(p_m0, p_m1):
|
50 |
"""Compute the max candidates of all pairs within a batch
|
51 |
+
|
52 |
Args:
|
53 |
p_m0, p_m1 (torch.Tensor): padded masks
|
54 |
"""
|
55 |
h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
|
56 |
h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
|
57 |
+
max_cand = torch.sum(torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
|
|
|
58 |
return max_cand
|
59 |
|
60 |
|
|
|
63 |
super().__init__()
|
64 |
self.config = config
|
65 |
# general config
|
66 |
+
self.thr = config["thr"]
|
67 |
+
self.border_rm = config["border_rm"]
|
68 |
# -- # for trainig fine-level LoFTR
|
69 |
+
self.train_coarse_percent = config["train_coarse_percent"]
|
70 |
+
self.train_pad_num_gt_min = config["train_pad_num_gt_min"]
|
71 |
+
|
72 |
# we provide 2 options for differentiable matching
|
73 |
+
self.match_type = config["match_type"]
|
74 |
+
if self.match_type == "dual_softmax":
|
75 |
+
self.temperature = nn.parameter.Parameter(
|
76 |
+
torch.tensor(10.0), requires_grad=True
|
77 |
+
)
|
78 |
+
elif self.match_type == "sinkhorn":
|
79 |
try:
|
80 |
from .superglue import log_optimal_transport
|
81 |
except ImportError:
|
82 |
raise ImportError("download superglue.py first!")
|
83 |
self.log_optimal_transport = log_optimal_transport
|
84 |
self.bin_score = nn.Parameter(
|
85 |
+
torch.tensor(config["skh_init_bin_score"], requires_grad=True)
|
86 |
+
)
|
87 |
+
self.skh_iters = config["skh_iters"]
|
88 |
+
self.skh_prefilter = config["skh_prefilter"]
|
89 |
else:
|
90 |
raise NotImplementedError()
|
91 |
+
|
92 |
def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None):
|
93 |
"""
|
94 |
Args:
|
|
|
111 |
"""
|
112 |
N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
|
113 |
# normalize
|
114 |
+
feat_c0, feat_c1 = map(
|
115 |
+
lambda feat: feat / feat.shape[-1] ** 0.5, [feat_c0, feat_c1]
|
116 |
+
)
|
117 |
+
|
118 |
+
if self.match_type == "dual_softmax":
|
119 |
+
sim_matrix = (
|
120 |
+
torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) * self.temperature
|
121 |
+
)
|
122 |
if mask_c0 is not None:
|
123 |
sim_matrix.masked_fill_(
|
124 |
+
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF
|
125 |
+
)
|
126 |
conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
|
127 |
+
|
128 |
+
elif self.match_type == "sinkhorn":
|
129 |
# sinkhorn, dustbin included
|
130 |
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
|
131 |
if mask_c0 is not None:
|
132 |
sim_matrix[:, :L, :S].masked_fill_(
|
133 |
+
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF
|
134 |
+
)
|
135 |
|
136 |
# build uniform prior & use sinkhorn
|
137 |
log_assign_matrix = self.log_optimal_transport(
|
138 |
+
sim_matrix, self.bin_score, self.skh_iters
|
139 |
+
)
|
140 |
assign_matrix = log_assign_matrix.exp()
|
141 |
conf_matrix = assign_matrix[:, :-1, :-1]
|
142 |
|
|
|
147 |
conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
|
148 |
conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
|
149 |
|
150 |
+
if self.config["sparse_spvs"]:
|
151 |
+
data.update({"conf_matrix_with_bin": assign_matrix.clone()})
|
152 |
|
153 |
+
data.update({"conf_matrix": conf_matrix})
|
154 |
# predict coarse matches from conf_matrix
|
155 |
data.update(**self.get_coarse_match(conf_matrix, data))
|
156 |
|
157 |
+
# update predicted offset
|
158 |
+
if (
|
159 |
+
flow_list[0].shape[2] == flow_list[1].shape[2]
|
160 |
+
and flow_list[0].shape[3] == flow_list[1].shape[3]
|
161 |
+
):
|
162 |
+
flow_list = torch.stack(flow_list, dim=0)
|
163 |
+
data.update({"predict_flow": flow_list}) # [2*L*B*H*W*4]
|
164 |
+
self.get_offset_match(flow_list, data, mask_c0, mask_c1)
|
165 |
|
166 |
@torch.no_grad()
|
167 |
def get_coarse_match(self, conf_matrix, data):
|
|
|
181 |
'mconf' (torch.Tensor): [M]}
|
182 |
"""
|
183 |
axes_lengths = {
|
184 |
+
"h0c": data["hw0_c"][0],
|
185 |
+
"w0c": data["hw0_c"][1],
|
186 |
+
"h1c": data["hw1_c"][0],
|
187 |
+
"w1c": data["hw1_c"][1],
|
188 |
}
|
189 |
_device = conf_matrix.device
|
190 |
# 1. confidence thresholding
|
191 |
mask = conf_matrix > self.thr
|
192 |
+
mask = rearrange(
|
193 |
+
mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths
|
194 |
+
)
|
195 |
+
if "mask0" not in data:
|
196 |
mask_border(mask, self.border_rm, False)
|
197 |
else:
|
198 |
+
mask_border_with_padding(
|
199 |
+
mask, self.border_rm, False, data["mask0"], data["mask1"]
|
200 |
+
)
|
201 |
+
mask = rearrange(
|
202 |
+
mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths
|
203 |
+
)
|
204 |
|
205 |
# 2. mutual nearest
|
206 |
+
mask = (
|
207 |
+
mask
|
208 |
+
* (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0])
|
209 |
* (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
|
210 |
+
)
|
211 |
|
212 |
# 3. find all valid coarse matches
|
213 |
# this only works when at most one `True` in each row
|
|
|
222 |
# NOTE:
|
223 |
# The sampling is performed across all pairs in a batch without manually balancing
|
224 |
# #samples for fine-level increases w.r.t. batch_size
|
225 |
+
if "mask0" not in data:
|
226 |
+
num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2))
|
|
|
227 |
else:
|
228 |
num_candidates_max = compute_max_candidates(
|
229 |
+
data["mask0"], data["mask1"]
|
230 |
+
)
|
231 |
+
num_matches_train = int(num_candidates_max * self.train_coarse_percent)
|
232 |
num_matches_pred = len(b_ids)
|
233 |
+
assert (
|
234 |
+
self.train_pad_num_gt_min < num_matches_train
|
235 |
+
), "min-num-gt-pad should be less than num-train-matches"
|
236 |
+
|
237 |
# pred_indices is to select from prediction
|
238 |
if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
|
239 |
pred_indices = torch.arange(num_matches_pred, device=_device)
|
240 |
else:
|
241 |
pred_indices = torch.randint(
|
242 |
num_matches_pred,
|
243 |
+
(num_matches_train - self.train_pad_num_gt_min,),
|
244 |
+
device=_device,
|
245 |
+
)
|
246 |
|
247 |
# gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
|
248 |
gt_pad_indices = torch.randint(
|
249 |
+
len(data["spv_b_ids"]),
|
250 |
+
(max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),),
|
251 |
+
device=_device,
|
252 |
+
)
|
253 |
+
mconf_gt = torch.zeros(
|
254 |
+
len(data["spv_b_ids"]), device=_device
|
255 |
+
) # set conf of gt paddings to all zero
|
256 |
|
257 |
b_ids, i_ids, j_ids, mconf = map(
|
258 |
+
lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], dim=0),
|
259 |
+
*zip(
|
260 |
+
[b_ids, data["spv_b_ids"]],
|
261 |
+
[i_ids, data["spv_i_ids"]],
|
262 |
+
[j_ids, data["spv_j_ids"]],
|
263 |
+
[mconf, mconf_gt],
|
264 |
+
)
|
265 |
+
)
|
266 |
|
267 |
# These matches select patches that feed into fine-level network
|
268 |
+
coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids}
|
269 |
|
270 |
# 4. Update with matches in original image resolution
|
271 |
+
scale = data["hw0_i"][0] / data["hw0_c"][0]
|
272 |
+
scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
|
273 |
+
scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
|
274 |
+
mkpts0_c = (
|
275 |
+
torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
|
276 |
+
* scale0
|
277 |
+
)
|
278 |
+
mkpts1_c = (
|
279 |
+
torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1)
|
280 |
+
* scale1
|
281 |
+
)
|
282 |
|
283 |
# These matches is the current prediction (for visualization)
|
284 |
+
coarse_matches.update(
|
285 |
+
{
|
286 |
+
"gt_mask": mconf == 0,
|
287 |
+
"m_bids": b_ids[mconf != 0], # mconf == 0 => gt matches
|
288 |
+
"mkpts0_c": mkpts0_c[mconf != 0],
|
289 |
+
"mkpts1_c": mkpts1_c[mconf != 0],
|
290 |
+
"mconf": mconf[mconf != 0],
|
291 |
+
}
|
292 |
+
)
|
293 |
|
294 |
return coarse_matches
|
295 |
|
296 |
@torch.no_grad()
|
297 |
+
def get_offset_match(self, flow_list, data, mask1, mask2):
|
298 |
"""
|
299 |
Args:
|
300 |
offset (torch.Tensor): [L, B, H, W, 2]
|
|
|
306 |
'mkpts1_c' (torch.Tensor): [M, 2],
|
307 |
'mconf' (torch.Tensor): [M]}
|
308 |
"""
|
309 |
+
offset1 = flow_list[0]
|
310 |
+
bs, layer_num = offset1.shape[1], offset1.shape[0]
|
311 |
+
|
312 |
+
# left side
|
313 |
+
offset1 = offset1.view(layer_num, bs, -1, 4)
|
314 |
+
conf1 = offset1[:, :, :, 2:].mean(dim=-1)
|
315 |
if mask1 is not None:
|
316 |
+
conf1.masked_fill_(~mask1.bool()[None].expand(layer_num, -1, -1), 100)
|
317 |
+
offset1 = offset1[:, :, :, :2]
|
318 |
+
self.get_offset_match_work(offset1, conf1, data, "left")
|
319 |
+
|
320 |
+
# rihgt side
|
321 |
+
if len(flow_list) == 2:
|
322 |
+
offset2 = flow_list[1].view(layer_num, bs, -1, 4)
|
323 |
+
conf2 = offset2[:, :, :, 2:].mean(dim=-1)
|
324 |
if mask2 is not None:
|
325 |
+
conf2.masked_fill_(~mask2.bool()[None].expand(layer_num, -1, -1), 100)
|
326 |
+
offset2 = offset2[:, :, :, :2]
|
327 |
+
self.get_offset_match_work(offset2, conf2, data, "right")
|
|
|
328 |
|
329 |
@torch.no_grad()
|
330 |
+
def get_offset_match_work(self, offset, conf, data, side):
|
331 |
+
bs, layer_num = offset.shape[1], offset.shape[0]
|
332 |
# 1. confidence thresholding
|
333 |
+
mask_conf = conf < 2
|
334 |
for index in range(bs):
|
335 |
+
mask_conf[:, index, 0] = True # safe guard in case that no match survives
|
336 |
# 3. find offset matches
|
337 |
+
scale = data["hw0_i"][0] / data["hw0_c"][0]
|
338 |
+
l_ids, b_ids, i_ids = torch.where(mask_conf)
|
339 |
+
j_coor = offset[l_ids, b_ids, i_ids, :2] * scale # [N,2]
|
340 |
+
i_coor = (
|
341 |
+
torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
|
342 |
+
* scale
|
343 |
+
)
|
344 |
+
# i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).cuda().float()*scale #[N,2]
|
345 |
# These matches is the current prediction (for visualization)
|
346 |
+
data.update(
|
347 |
+
{
|
348 |
+
"offset_bids_" + side: b_ids, # mconf == 0 => gt matches
|
349 |
+
"offset_lids_" + side: l_ids,
|
350 |
+
"conf" + side: conf[mask_conf],
|
351 |
+
}
|
352 |
+
)
|
353 |
+
|
354 |
+
if side == "right":
|
355 |
+
data.update(
|
356 |
+
{
|
357 |
+
"offset_kpts0_f_" + side: j_coor.detach(),
|
358 |
+
"offset_kpts1_f_" + side: i_coor,
|
359 |
+
}
|
360 |
+
)
|
361 |
else:
|
362 |
+
data.update(
|
363 |
+
{
|
364 |
+
"offset_kpts0_f_" + side: i_coor,
|
365 |
+
"offset_kpts1_f_" + side: j_coor.detach(),
|
366 |
+
}
|
367 |
+
)
|
third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py
CHANGED
@@ -8,7 +8,7 @@ def lower_config(yacs_cfg):
|
|
8 |
|
9 |
|
10 |
_CN = CN()
|
11 |
-
_CN.BACKBONE_TYPE =
|
12 |
_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
|
13 |
_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
14 |
_CN.FINE_CONCAT_COARSE_FEAT = True
|
@@ -23,15 +23,15 @@ _CN.COARSE = CN()
|
|
23 |
_CN.COARSE.D_MODEL = 256
|
24 |
_CN.COARSE.D_FFN = 256
|
25 |
_CN.COARSE.NHEAD = 8
|
26 |
-
_CN.COARSE.LAYER_NAMES = [
|
27 |
-
_CN.COARSE.ATTENTION =
|
28 |
_CN.COARSE.TEMP_BUG_FIX = False
|
29 |
|
30 |
# 3. Coarse-Matching config
|
31 |
_CN.MATCH_COARSE = CN()
|
32 |
_CN.MATCH_COARSE.THR = 0.1
|
33 |
_CN.MATCH_COARSE.BORDER_RM = 2
|
34 |
-
_CN.MATCH_COARSE.MATCH_TYPE =
|
35 |
_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
|
36 |
_CN.MATCH_COARSE.SKH_ITERS = 3
|
37 |
_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
|
@@ -44,7 +44,7 @@ _CN.FINE = CN()
|
|
44 |
_CN.FINE.D_MODEL = 128
|
45 |
_CN.FINE.D_FFN = 128
|
46 |
_CN.FINE.NHEAD = 8
|
47 |
-
_CN.FINE.LAYER_NAMES = [
|
48 |
-
_CN.FINE.ATTENTION =
|
49 |
|
50 |
default_cfg = lower_config(_CN)
|
|
|
8 |
|
9 |
|
10 |
_CN = CN()
|
11 |
+
_CN.BACKBONE_TYPE = "ResNetFPN"
|
12 |
_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
|
13 |
_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
14 |
_CN.FINE_CONCAT_COARSE_FEAT = True
|
|
|
23 |
_CN.COARSE.D_MODEL = 256
|
24 |
_CN.COARSE.D_FFN = 256
|
25 |
_CN.COARSE.NHEAD = 8
|
26 |
+
_CN.COARSE.LAYER_NAMES = ["self", "cross"] * 4
|
27 |
+
_CN.COARSE.ATTENTION = "linear" # options: ['linear', 'full']
|
28 |
_CN.COARSE.TEMP_BUG_FIX = False
|
29 |
|
30 |
# 3. Coarse-Matching config
|
31 |
_CN.MATCH_COARSE = CN()
|
32 |
_CN.MATCH_COARSE.THR = 0.1
|
33 |
_CN.MATCH_COARSE.BORDER_RM = 2
|
34 |
+
_CN.MATCH_COARSE.MATCH_TYPE = "dual_softmax" # options: ['dual_softmax, 'sinkhorn']
|
35 |
_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
|
36 |
_CN.MATCH_COARSE.SKH_ITERS = 3
|
37 |
_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
|
|
|
44 |
_CN.FINE.D_MODEL = 128
|
45 |
_CN.FINE.D_FFN = 128
|
46 |
_CN.FINE.NHEAD = 8
|
47 |
+
_CN.FINE.LAYER_NAMES = ["self", "cross"] * 1
|
48 |
+
_CN.FINE.ATTENTION = "linear"
|
49 |
|
50 |
default_cfg = lower_config(_CN)
|
third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py
CHANGED
@@ -26,35 +26,46 @@ class FineMatching(nn.Module):
|
|
26 |
"""
|
27 |
M, WW, C = feat_f0.shape
|
28 |
W = int(math.sqrt(WW))
|
29 |
-
scale = data[
|
30 |
self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
|
31 |
|
32 |
# corner case: if no coarse matches found
|
33 |
if M == 0:
|
34 |
-
assert
|
|
|
|
|
35 |
# logger.warning('No matches found in coarse-level.')
|
36 |
-
data.update(
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
return
|
42 |
|
43 |
-
feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
|
44 |
-
sim_matrix = torch.einsum(
|
45 |
-
softmax_temp = 1. / C
|
46 |
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
|
47 |
|
48 |
# compute coordinates from heatmap
|
49 |
coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
|
50 |
-
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
|
|
|
|
|
51 |
|
52 |
# compute std over <x, y>
|
53 |
-
var =
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
# for fine-level supervision
|
57 |
-
data.update({
|
58 |
|
59 |
# compute absolute kpt coords
|
60 |
self.get_fine_match(coords_normalized, data)
|
@@ -64,11 +75,10 @@ class FineMatching(nn.Module):
|
|
64 |
W, WW, C, scale = self.W, self.WW, self.C, self.scale
|
65 |
|
66 |
# mkpts0_f and mkpts1_f
|
67 |
-
mkpts0_f = data[
|
68 |
-
scale1 = scale * data[
|
69 |
-
mkpts1_f =
|
|
|
|
|
70 |
|
71 |
-
data.update({
|
72 |
-
"mkpts0_f": mkpts0_f,
|
73 |
-
"mkpts1_f": mkpts1_f
|
74 |
-
})
|
|
|
26 |
"""
|
27 |
M, WW, C = feat_f0.shape
|
28 |
W = int(math.sqrt(WW))
|
29 |
+
scale = data["hw0_i"][0] / data["hw0_f"][0]
|
30 |
self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
|
31 |
|
32 |
# corner case: if no coarse matches found
|
33 |
if M == 0:
|
34 |
+
assert (
|
35 |
+
self.training == False
|
36 |
+
), "M is always >0, when training, see coarse_matching.py"
|
37 |
# logger.warning('No matches found in coarse-level.')
|
38 |
+
data.update(
|
39 |
+
{
|
40 |
+
"expec_f": torch.empty(0, 3, device=feat_f0.device),
|
41 |
+
"mkpts0_f": data["mkpts0_c"],
|
42 |
+
"mkpts1_f": data["mkpts1_c"],
|
43 |
+
}
|
44 |
+
)
|
45 |
return
|
46 |
|
47 |
+
feat_f0_picked = feat_f0_picked = feat_f0[:, WW // 2, :]
|
48 |
+
sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1)
|
49 |
+
softmax_temp = 1.0 / C**0.5
|
50 |
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
|
51 |
|
52 |
# compute coordinates from heatmap
|
53 |
coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
|
54 |
+
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
|
55 |
+
1, -1, 2
|
56 |
+
) # [1, WW, 2]
|
57 |
|
58 |
# compute std over <x, y>
|
59 |
+
var = (
|
60 |
+
torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1)
|
61 |
+
- coords_normalized**2
|
62 |
+
) # [M, 2]
|
63 |
+
std = torch.sum(
|
64 |
+
torch.sqrt(torch.clamp(var, min=1e-10)), -1
|
65 |
+
) # [M] clamp needed for numerical stability
|
66 |
+
|
67 |
# for fine-level supervision
|
68 |
+
data.update({"expec_f": torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
|
69 |
|
70 |
# compute absolute kpt coords
|
71 |
self.get_fine_match(coords_normalized, data)
|
|
|
75 |
W, WW, C, scale = self.W, self.WW, self.C, self.scale
|
76 |
|
77 |
# mkpts0_f and mkpts1_f
|
78 |
+
mkpts0_f = data["mkpts0_c"]
|
79 |
+
scale1 = scale * data["scale1"][data["b_ids"]] if "scale0" in data else scale
|
80 |
+
mkpts1_f = (
|
81 |
+
data["mkpts1_c"] + (coords_normed * (W // 2) * scale1)[: len(data["mconf"])]
|
82 |
+
)
|
83 |
|
84 |
+
data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f})
|
|
|
|
|
|
third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py
CHANGED
@@ -3,10 +3,10 @@ import torch
|
|
3 |
|
4 |
@torch.no_grad()
|
5 |
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
|
6 |
-
"""
|
7 |
Also check covisibility and depth consistency.
|
8 |
Depth is consistent if relative error < 0.2 (hard-coded).
|
9 |
-
|
10 |
Args:
|
11 |
kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
|
12 |
depth0 (torch.Tensor): [N, H, W],
|
@@ -22,33 +22,52 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
|
|
22 |
|
23 |
# Sample depth, get calculable_mask on depth != 0
|
24 |
kpts0_depth = torch.stack(
|
25 |
-
[
|
|
|
|
|
|
|
|
|
26 |
) # (N, L)
|
27 |
nonzero_mask = kpts0_depth != 0
|
28 |
|
29 |
# Unproject
|
30 |
-
kpts0_h =
|
|
|
|
|
|
|
31 |
kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
|
32 |
|
33 |
# Rigid Transform
|
34 |
-
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]
|
35 |
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
|
36 |
|
37 |
# Project
|
38 |
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
|
39 |
-
w_kpts0 = w_kpts0_h[:, :, :2] / (
|
|
|
|
|
40 |
|
41 |
# Covisible Check
|
42 |
h, w = depth1.shape[1:3]
|
43 |
-
covisible_mask = (
|
44 |
-
(w_kpts0[:, :,
|
|
|
|
|
|
|
|
|
45 |
w_kpts0_long = w_kpts0.long()
|
46 |
w_kpts0_long[~covisible_mask, :] = 0
|
47 |
|
48 |
w_kpts0_depth = torch.stack(
|
49 |
-
[
|
|
|
|
|
|
|
|
|
50 |
) # (N, L)
|
51 |
-
consistent_mask = (
|
|
|
|
|
52 |
valid_mask = nonzero_mask * covisible_mask * consistent_mask
|
53 |
|
54 |
return valid_mask, w_kpts0
|
|
|
3 |
|
4 |
@torch.no_grad()
|
5 |
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
|
6 |
+
"""Warp kpts0 from I0 to I1 with depth, K and Rt
|
7 |
Also check covisibility and depth consistency.
|
8 |
Depth is consistent if relative error < 0.2 (hard-coded).
|
9 |
+
|
10 |
Args:
|
11 |
kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
|
12 |
depth0 (torch.Tensor): [N, H, W],
|
|
|
22 |
|
23 |
# Sample depth, get calculable_mask on depth != 0
|
24 |
kpts0_depth = torch.stack(
|
25 |
+
[
|
26 |
+
depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]]
|
27 |
+
for i in range(kpts0.shape[0])
|
28 |
+
],
|
29 |
+
dim=0,
|
30 |
) # (N, L)
|
31 |
nonzero_mask = kpts0_depth != 0
|
32 |
|
33 |
# Unproject
|
34 |
+
kpts0_h = (
|
35 |
+
torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
|
36 |
+
* kpts0_depth[..., None]
|
37 |
+
) # (N, L, 3)
|
38 |
kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
|
39 |
|
40 |
# Rigid Transform
|
41 |
+
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
|
42 |
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
|
43 |
|
44 |
# Project
|
45 |
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
|
46 |
+
w_kpts0 = w_kpts0_h[:, :, :2] / (
|
47 |
+
w_kpts0_h[:, :, [2]] + 1e-4
|
48 |
+
) # (N, L, 2), +1e-4 to avoid zero depth
|
49 |
|
50 |
# Covisible Check
|
51 |
h, w = depth1.shape[1:3]
|
52 |
+
covisible_mask = (
|
53 |
+
(w_kpts0[:, :, 0] > 0)
|
54 |
+
* (w_kpts0[:, :, 0] < w - 1)
|
55 |
+
* (w_kpts0[:, :, 1] > 0)
|
56 |
+
* (w_kpts0[:, :, 1] < h - 1)
|
57 |
+
)
|
58 |
w_kpts0_long = w_kpts0.long()
|
59 |
w_kpts0_long[~covisible_mask, :] = 0
|
60 |
|
61 |
w_kpts0_depth = torch.stack(
|
62 |
+
[
|
63 |
+
depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]]
|
64 |
+
for i in range(w_kpts0_long.shape[0])
|
65 |
+
],
|
66 |
+
dim=0,
|
67 |
) # (N, L)
|
68 |
+
consistent_mask = (
|
69 |
+
(w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
|
70 |
+
).abs() < 0.2
|
71 |
valid_mask = nonzero_mask * covisible_mask * consistent_mask
|
72 |
|
73 |
return valid_mask, w_kpts0
|
third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py
CHANGED
@@ -8,7 +8,7 @@ class PositionEncodingSine(nn.Module):
|
|
8 |
This is a sinusoidal position encoding that generalized to 2-dimensional images
|
9 |
"""
|
10 |
|
11 |
-
def __init__(self, d_model, max_shape=(256, 256),pre_scaling=None):
|
12 |
"""
|
13 |
Args:
|
14 |
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
|
@@ -18,44 +18,63 @@ class PositionEncodingSine(nn.Module):
|
|
18 |
We will remove the buggy impl after re-training all variants of our released models.
|
19 |
"""
|
20 |
super().__init__()
|
21 |
-
self.d_model=d_model
|
22 |
-
self.max_shape=max_shape
|
23 |
-
self.pre_scaling=pre_scaling
|
24 |
|
25 |
pe = torch.zeros((d_model, *max_shape))
|
26 |
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
|
27 |
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
|
28 |
|
29 |
if pre_scaling[0] is not None and pre_scaling[1] is not None:
|
30 |
-
train_res,test_res=pre_scaling[0],pre_scaling[1]
|
31 |
-
x_position,y_position=
|
|
|
|
|
|
|
32 |
|
33 |
-
div_term = torch.exp(
|
|
|
|
|
|
|
34 |
div_term = div_term[:, None, None] # [C//4, 1, 1]
|
35 |
pe[0::4, :, :] = torch.sin(x_position * div_term)
|
36 |
pe[1::4, :, :] = torch.cos(x_position * div_term)
|
37 |
pe[2::4, :, :] = torch.sin(y_position * div_term)
|
38 |
pe[3::4, :, :] = torch.cos(y_position * div_term)
|
39 |
|
40 |
-
self.register_buffer(
|
41 |
|
42 |
-
def forward(self, x,scaling=None):
|
43 |
"""
|
44 |
Args:
|
45 |
x: [N, C, H, W]
|
46 |
"""
|
47 |
-
if scaling is None:
|
48 |
-
return
|
|
|
|
|
|
|
49 |
else:
|
50 |
pe = torch.zeros((self.d_model, *self.max_shape))
|
51 |
-
y_position =
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
div_term = div_term[:, None, None] # [C//4, 1, 1]
|
56 |
pe[0::4, :, :] = torch.sin(x_position * div_term)
|
57 |
pe[1::4, :, :] = torch.cos(x_position * div_term)
|
58 |
pe[2::4, :, :] = torch.sin(y_position * div_term)
|
59 |
pe[3::4, :, :] = torch.cos(y_position * div_term)
|
60 |
-
pe=pe.unsqueeze(0).to(x.device)
|
61 |
-
return
|
|
|
|
|
|
|
|
8 |
This is a sinusoidal position encoding that generalized to 2-dimensional images
|
9 |
"""
|
10 |
|
11 |
+
def __init__(self, d_model, max_shape=(256, 256), pre_scaling=None):
|
12 |
"""
|
13 |
Args:
|
14 |
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
|
|
|
18 |
We will remove the buggy impl after re-training all variants of our released models.
|
19 |
"""
|
20 |
super().__init__()
|
21 |
+
self.d_model = d_model
|
22 |
+
self.max_shape = max_shape
|
23 |
+
self.pre_scaling = pre_scaling
|
24 |
|
25 |
pe = torch.zeros((d_model, *max_shape))
|
26 |
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
|
27 |
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
|
28 |
|
29 |
if pre_scaling[0] is not None and pre_scaling[1] is not None:
|
30 |
+
train_res, test_res = pre_scaling[0], pre_scaling[1]
|
31 |
+
x_position, y_position = (
|
32 |
+
x_position * train_res[1] / test_res[1],
|
33 |
+
y_position * train_res[0] / test_res[0],
|
34 |
+
)
|
35 |
|
36 |
+
div_term = torch.exp(
|
37 |
+
torch.arange(0, d_model // 2, 2).float()
|
38 |
+
* (-math.log(10000.0) / (d_model // 2))
|
39 |
+
)
|
40 |
div_term = div_term[:, None, None] # [C//4, 1, 1]
|
41 |
pe[0::4, :, :] = torch.sin(x_position * div_term)
|
42 |
pe[1::4, :, :] = torch.cos(x_position * div_term)
|
43 |
pe[2::4, :, :] = torch.sin(y_position * div_term)
|
44 |
pe[3::4, :, :] = torch.cos(y_position * div_term)
|
45 |
|
46 |
+
self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W]
|
47 |
|
48 |
+
def forward(self, x, scaling=None):
    """Add the 2-D sinusoidal position encoding to a feature map.

    Args:
        x: feature map of shape [N, C, H, W].
        scaling: optional pair ``(y_scale, x_scale)`` multiplied onto the
            row / column coordinates before they are encoded (assumed to
            be scalar factors -- TODO confirm against the callers). When
            given, the encoding is rebuilt on the fly and the table cached
            in ``self.pe`` (which already contains any ``pre_scaling``) is
            not used.

    Returns:
        tuple: ``(x + pe, pe)`` where ``pe`` is the encoding cropped to the
        spatial size of ``x`` (and to at most ``self.max_shape``).
    """
    height, width = x.size(2), x.size(3)

    if scaling is None:
        # No online scaling requested: reuse the table built in __init__.
        cached = self.pe[:, :, :height, :width]
        return x + cached, cached

    # Online scaling overrides pre_scaling. Only the crop that is actually
    # returned is computed, and it is created directly on x's device, so no
    # full max_shape table is allocated on the CPU and copied per call.
    rows = min(height, self.max_shape[0])
    cols = min(width, self.max_shape[1])
    device = x.device

    # 1-based coordinates, identical to cumsum() over a tensor of ones.
    y_position = (
        torch.arange(1, rows + 1, device=device).float().view(1, rows, 1)
        * scaling[0]
    )
    x_position = (
        torch.arange(1, cols + 1, device=device).float().view(1, 1, cols)
        * scaling[1]
    )

    div_term = torch.exp(
        torch.arange(0, self.d_model // 2, 2, device=device).float()
        * (-math.log(10000.0) / (self.d_model // 2))
    )
    div_term = div_term[:, None, None]  # [C//4, 1, 1]

    pe = torch.zeros((self.d_model, rows, cols), device=device)
    # Channel layout: sin(x), cos(x), sin(y), cos(y), repeated every 4 channels.
    pe[0::4, :, :] = torch.sin(x_position * div_term)
    pe[1::4, :, :] = torch.cos(x_position * div_term)
    pe[2::4, :, :] = torch.sin(y_position * div_term)
    pe[3::4, :, :] = torch.cos(y_position * div_term)
    pe = pe.unsqueeze(0)  # [1, C, rows, cols]
    return x + pe, pe
|
third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py
CHANGED
@@ -13,7 +13,7 @@ from .geometry import warp_kpts
|
|
13 |
@torch.no_grad()
|
14 |
def mask_pts_at_padded_regions(grid_pt, mask):
|
15 |
"""For megadepth dataset, zero-padding exists in images"""
|
16 |
-
mask = repeat(mask,
|
17 |
grid_pt[~mask.bool()] = 0
|
18 |
return grid_pt
|
19 |
|
@@ -30,37 +30,55 @@ def spvs_coarse(data, config):
|
|
30 |
'spv_w_pt0_i': [N, hw0, 2], in original image resolution
|
31 |
'spv_pt1_i': [N, hw1, 2], in original image resolution
|
32 |
}
|
33 |
-
|
34 |
NOTE:
|
35 |
- for scannet dataset, there're 3 kinds of resolution {i, c, f}
|
36 |
- for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
|
37 |
"""
|
38 |
# 1. misc
|
39 |
-
device = data[
|
40 |
-
N, _, H0, W0 = data[
|
41 |
-
_, _, H1, W1 = data[
|
42 |
-
scale = config[
|
43 |
-
scale0 = scale * data[
|
44 |
-
scale1 = scale * data[
|
45 |
h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
|
46 |
|
47 |
# 2. warp grids
|
48 |
# create kpts in meshgrid and resize them to image resolution
|
49 |
-
grid_pt0_c =
|
|
|
|
|
50 |
grid_pt0_i = scale0 * grid_pt0_c
|
51 |
-
grid_pt1_c =
|
|
|
|
|
52 |
grid_pt1_i = scale1 * grid_pt1_c
|
53 |
|
54 |
# mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
|
55 |
-
if
|
56 |
-
grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data[
|
57 |
-
grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data[
|
58 |
|
59 |
# warp kpts bi-directionally and resize them to coarse-level resolution
|
60 |
# (no depth consistency check, since it leads to worse results experimentally)
|
61 |
# (unhandled edge case: points with 0-depth will be warped to the left-up corner)
|
62 |
-
_, w_pt0_i = warp_kpts(
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
w_pt0_c = w_pt0_i / scale1
|
65 |
w_pt1_c = w_pt1_i / scale0
|
66 |
|
@@ -72,21 +90,26 @@ def spvs_coarse(data, config):
|
|
72 |
|
73 |
# corner case: out of boundary
|
74 |
def out_bound_mask(pt, w, h):
|
75 |
-
return (
|
|
|
|
|
|
|
76 |
nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
|
77 |
nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
|
78 |
|
79 |
-
loop_back = torch.stack(
|
80 |
-
|
|
|
|
|
81 |
correct_0to1[:, 0] = False # ignore the top-left corner
|
82 |
|
83 |
# 4. construct a gt conf_matrix
|
84 |
-
conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
|
85 |
b_ids, i_ids = torch.where(correct_0to1 != 0)
|
86 |
j_ids = nearest_index1[b_ids, i_ids]
|
87 |
|
88 |
conf_matrix_gt[b_ids, i_ids, j_ids] = 1
|
89 |
-
data.update({
|
90 |
|
91 |
# 5. save coarse matches(gt) for training fine level
|
92 |
if len(b_ids) == 0:
|
@@ -96,30 +119,26 @@ def spvs_coarse(data, config):
|
|
96 |
i_ids = torch.tensor([0], device=device)
|
97 |
j_ids = torch.tensor([0], device=device)
|
98 |
|
99 |
-
data.update({
|
100 |
-
'spv_b_ids': b_ids,
|
101 |
-
'spv_i_ids': i_ids,
|
102 |
-
'spv_j_ids': j_ids
|
103 |
-
})
|
104 |
|
105 |
# 6. save intermediate results (for fast fine-level computation)
|
106 |
-
data.update({
|
107 |
-
'spv_w_pt0_i': w_pt0_i,
|
108 |
-
'spv_pt1_i': grid_pt1_i
|
109 |
-
})
|
110 |
|
111 |
|
112 |
def compute_supervision_coarse(data, config):
|
113 |
-
assert
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
spvs_coarse(data, config)
|
117 |
else:
|
118 |
-
raise ValueError(f
|
119 |
|
120 |
|
121 |
############## ↓ Fine-Level supervision ↓ ##############
|
122 |
|
|
|
123 |
@torch.no_grad()
|
124 |
def spvs_fine(data, config):
|
125 |
"""
|
@@ -129,23 +148,25 @@ def spvs_fine(data, config):
|
|
129 |
"""
|
130 |
# 1. misc
|
131 |
# w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
|
132 |
-
w_pt0_i, pt1_i = data[
|
133 |
-
scale = config[
|
134 |
-
radius = config[
|
135 |
|
136 |
# 2. get coarse prediction
|
137 |
-
b_ids, i_ids, j_ids = data[
|
138 |
|
139 |
# 3. compute gt
|
140 |
-
scale = scale * data[
|
141 |
# `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
|
142 |
-
expec_f_gt = (
|
|
|
|
|
143 |
data.update({"expec_f_gt": expec_f_gt})
|
144 |
|
145 |
|
146 |
def compute_supervision_fine(data, config):
|
147 |
-
data_source = data[
|
148 |
-
if data_source.lower() in [
|
149 |
spvs_fine(data, config)
|
150 |
else:
|
151 |
raise NotImplementedError
|
|
|
13 |
@torch.no_grad()
def mask_pts_at_padded_regions(grid_pt, mask):
    """Zero out, in place, the grid points that fall on padded pixels.

    For the megadepth dataset, zero-padding exists in images.

    Args:
        grid_pt: point coordinates indexed as [n, h*w, 2]; modified in place.
        mask: [n, h, w] validity mask; entries that evaluate to False mark
            padded positions.

    Returns:
        The same ``grid_pt`` tensor, with both coordinates of every padded
        position set to 0.
    """
    n, h, w = mask.shape
    # Flatten the spatial dims so the mask lines up with the (h w) axis of
    # grid_pt; indexing with an [n, h*w] boolean mask selects whole
    # coordinate pairs, i.e. both channels of each padded point.
    is_padded = ~mask.bool().reshape(n, h * w)
    grid_pt[is_padded] = 0
    return grid_pt
|
19 |
|
|
|
30 |
'spv_w_pt0_i': [N, hw0, 2], in original image resolution
|
31 |
'spv_pt1_i': [N, hw1, 2], in original image resolution
|
32 |
}
|
33 |
+
|
34 |
NOTE:
|
35 |
- for scannet dataset, there're 3 kinds of resolution {i, c, f}
|
36 |
- for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
|
37 |
"""
|
38 |
# 1. misc
|
39 |
+
device = data["image0"].device
|
40 |
+
N, _, H0, W0 = data["image0"].shape
|
41 |
+
_, _, H1, W1 = data["image1"].shape
|
42 |
+
scale = config["ASPAN"]["RESOLUTION"][0]
|
43 |
+
scale0 = scale * data["scale0"][:, None] if "scale0" in data else scale
|
44 |
+
scale1 = scale * data["scale1"][:, None] if "scale0" in data else scale
|
45 |
h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
|
46 |
|
47 |
# 2. warp grids
|
48 |
# create kpts in meshgrid and resize them to image resolution
|
49 |
+
grid_pt0_c = (
|
50 |
+
create_meshgrid(h0, w0, False, device).reshape(1, h0 * w0, 2).repeat(N, 1, 1)
|
51 |
+
) # [N, hw, 2]
|
52 |
grid_pt0_i = scale0 * grid_pt0_c
|
53 |
+
grid_pt1_c = (
|
54 |
+
create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1)
|
55 |
+
)
|
56 |
grid_pt1_i = scale1 * grid_pt1_c
|
57 |
|
58 |
# mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
|
59 |
+
if "mask0" in data:
|
60 |
+
grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data["mask0"])
|
61 |
+
grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data["mask1"])
|
62 |
|
63 |
# warp kpts bi-directionally and resize them to coarse-level resolution
|
64 |
# (no depth consistency check, since it leads to worse results experimentally)
|
65 |
# (unhandled edge case: points with 0-depth will be warped to the left-up corner)
|
66 |
+
_, w_pt0_i = warp_kpts(
|
67 |
+
grid_pt0_i,
|
68 |
+
data["depth0"],
|
69 |
+
data["depth1"],
|
70 |
+
data["T_0to1"],
|
71 |
+
data["K0"],
|
72 |
+
data["K1"],
|
73 |
+
)
|
74 |
+
_, w_pt1_i = warp_kpts(
|
75 |
+
grid_pt1_i,
|
76 |
+
data["depth1"],
|
77 |
+
data["depth0"],
|
78 |
+
data["T_1to0"],
|
79 |
+
data["K1"],
|
80 |
+
data["K0"],
|
81 |
+
)
|
82 |
w_pt0_c = w_pt0_i / scale1
|
83 |
w_pt1_c = w_pt1_i / scale0
|
84 |
|
|
|
90 |
|
91 |
# corner case: out of boundary
|
92 |
def out_bound_mask(pt, w, h):
|
93 |
+
return (
|
94 |
+
(pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
|
95 |
+
)
|
96 |
+
|
97 |
nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
|
98 |
nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
|
99 |
|
100 |
+
loop_back = torch.stack(
|
101 |
+
[nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0
|
102 |
+
)
|
103 |
+
correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None].repeat(N, 1)
|
104 |
correct_0to1[:, 0] = False # ignore the top-left corner
|
105 |
|
106 |
# 4. construct a gt conf_matrix
|
107 |
+
conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
|
108 |
b_ids, i_ids = torch.where(correct_0to1 != 0)
|
109 |
j_ids = nearest_index1[b_ids, i_ids]
|
110 |
|
111 |
conf_matrix_gt[b_ids, i_ids, j_ids] = 1
|
112 |
+
data.update({"conf_matrix_gt": conf_matrix_gt})
|
113 |
|
114 |
# 5. save coarse matches(gt) for training fine level
|
115 |
if len(b_ids) == 0:
|
|
|
119 |
i_ids = torch.tensor([0], device=device)
|
120 |
j_ids = torch.tensor([0], device=device)
|
121 |
|
122 |
+
data.update({"spv_b_ids": b_ids, "spv_i_ids": i_ids, "spv_j_ids": j_ids})
|
|
|
|
|
|
|
|
|
123 |
|
124 |
# 6. save intermediate results (for fast fine-level computation)
|
125 |
+
data.update({"spv_w_pt0_i": w_pt0_i, "spv_pt1_i": grid_pt1_i})
|
|
|
|
|
|
|
126 |
|
127 |
|
128 |
def compute_supervision_coarse(data, config):
    """Dispatch coarse-level supervision for a batch.

    Args:
        data: batch dict; ``data["dataset_name"]`` holds one dataset name per
            sample. Results are written into this dict by ``spvs_coarse``.
        config: configuration passed through to ``spvs_coarse``.

    Raises:
        AssertionError: if the batch mixes samples from different datasets.
        ValueError: if the dataset is neither "scannet" nor "megadepth"
            (compared case-insensitively).
    """
    dataset_names = data["dataset_name"]
    assert (
        len(set(dataset_names)) == 1
    ), "Do not support mixed datasets training!"

    data_source = dataset_names[0]
    if data_source.lower() not in ("scannet", "megadepth"):
        raise ValueError(f"Unknown data source: {data_source}")
    spvs_coarse(data, config)
|
137 |
|
138 |
|
139 |
############## ↓ Fine-Level supervision ↓ ##############
|
140 |
|
141 |
+
|
142 |
@torch.no_grad()
|
143 |
def spvs_fine(data, config):
|
144 |
"""
|
|
|
148 |
"""
|
149 |
# 1. misc
|
150 |
# w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
|
151 |
+
w_pt0_i, pt1_i = data["spv_w_pt0_i"], data["spv_pt1_i"]
|
152 |
+
scale = config["ASPAN"]["RESOLUTION"][1]
|
153 |
+
radius = config["ASPAN"]["FINE_WINDOW_SIZE"] // 2
|
154 |
|
155 |
# 2. get coarse prediction
|
156 |
+
b_ids, i_ids, j_ids = data["b_ids"], data["i_ids"], data["j_ids"]
|
157 |
|
158 |
# 3. compute gt
|
159 |
+
scale = scale * data["scale1"][b_ids] if "scale0" in data else scale
|
160 |
# `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
|
161 |
+
expec_f_gt = (
|
162 |
+
(w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius
|
163 |
+
) # [M, 2]
|
164 |
data.update({"expec_f_gt": expec_f_gt})
|
165 |
|
166 |
|
167 |
def compute_supervision_fine(data, config):
    """Dispatch fine-level supervision for a batch.

    Only the first entry of ``data["dataset_name"]`` is inspected; no
    mixed-dataset check is performed here.

    Args:
        data: batch dict; results are written into it by ``spvs_fine``.
        config: configuration passed through to ``spvs_fine``.

    Raises:
        NotImplementedError: if the dataset is neither "scannet" nor
            "megadepth" (compared case-insensitively).
    """
    source = data["dataset_name"][0]
    if source.lower() not in ("scannet", "megadepth"):
        raise NotImplementedError
    spvs_fine(data, config)
|
third_party/ASpanFormer/src/config/default.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
from yacs.config import CfgNode as CN
|
|
|
2 |
_CN = CN()
|
3 |
|
4 |
############## ↓ ASPAN Pipeline ↓ ##############
|
5 |
_CN.ASPAN = CN()
|
6 |
-
_CN.ASPAN.BACKBONE_TYPE =
|
7 |
_CN.ASPAN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
|
8 |
_CN.ASPAN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
9 |
_CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True
|
@@ -17,14 +18,14 @@ _CN.ASPAN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
|
|
17 |
_CN.ASPAN.COARSE = CN()
|
18 |
_CN.ASPAN.COARSE.D_MODEL = 256
|
19 |
_CN.ASPAN.COARSE.D_FFN = 256
|
20 |
-
_CN.ASPAN.COARSE.D_FLOW= 128
|
21 |
_CN.ASPAN.COARSE.NHEAD = 8
|
22 |
-
_CN.ASPAN.COARSE.NLEVEL= 3
|
23 |
-
_CN.ASPAN.COARSE.INI_LAYER_NUM =
|
24 |
-
_CN.ASPAN.COARSE.LAYER_NUM =
|
25 |
-
_CN.ASPAN.COARSE.NSAMPLE = [2,8]
|
26 |
-
_CN.ASPAN.COARSE.RADIUS_SCALE= 5
|
27 |
-
_CN.ASPAN.COARSE.COARSEST_LEVEL= [26,26]
|
28 |
_CN.ASPAN.COARSE.TRAIN_RES = None
|
29 |
_CN.ASPAN.COARSE.TEST_RES = None
|
30 |
|
@@ -32,7 +33,9 @@ _CN.ASPAN.COARSE.TEST_RES = None
|
|
32 |
_CN.ASPAN.MATCH_COARSE = CN()
|
33 |
_CN.ASPAN.MATCH_COARSE.THR = 0.2
|
34 |
_CN.ASPAN.MATCH_COARSE.BORDER_RM = 2
|
35 |
-
_CN.ASPAN.MATCH_COARSE.MATCH_TYPE =
|
|
|
|
|
36 |
_CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3
|
37 |
_CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
|
38 |
_CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False
|
@@ -46,13 +49,13 @@ _CN.ASPAN.FINE = CN()
|
|
46 |
_CN.ASPAN.FINE.D_MODEL = 128
|
47 |
_CN.ASPAN.FINE.D_FFN = 128
|
48 |
_CN.ASPAN.FINE.NHEAD = 8
|
49 |
-
_CN.ASPAN.FINE.LAYER_NAMES = [
|
50 |
-
_CN.ASPAN.FINE.ATTENTION =
|
51 |
|
52 |
# 5. ASPAN Losses
|
53 |
# -- # coarse-level
|
54 |
_CN.ASPAN.LOSS = CN()
|
55 |
-
_CN.ASPAN.LOSS.COARSE_TYPE =
|
56 |
_CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0
|
57 |
# _CN.ASPAN.LOSS.SPARSE_SPVS = False
|
58 |
# -- - -- # focal loss (coarse)
|
@@ -64,7 +67,7 @@ _CN.ASPAN.LOSS.NEG_WEIGHT = 1.0
|
|
64 |
# use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE`
|
65 |
|
66 |
# -- # fine-level
|
67 |
-
_CN.ASPAN.LOSS.FINE_TYPE =
|
68 |
_CN.ASPAN.LOSS.FINE_WEIGHT = 1.0
|
69 |
_CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
|
70 |
|
@@ -85,24 +88,32 @@ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
|
|
85 |
_CN.DATASET.VAL_DATA_ROOT = None
|
86 |
_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
|
87 |
_CN.DATASET.VAL_NPZ_ROOT = None
|
88 |
-
_CN.DATASET.VAL_LIST_PATH =
|
|
|
|
|
89 |
_CN.DATASET.VAL_INTRINSIC_PATH = None
|
90 |
# testing
|
91 |
_CN.DATASET.TEST_DATA_SOURCE = None
|
92 |
_CN.DATASET.TEST_DATA_ROOT = None
|
93 |
_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
|
94 |
_CN.DATASET.TEST_NPZ_ROOT = None
|
95 |
-
_CN.DATASET.TEST_LIST_PATH =
|
|
|
|
|
96 |
_CN.DATASET.TEST_INTRINSIC_PATH = None
|
97 |
|
98 |
# 2. dataset config
|
99 |
# general options
|
100 |
-
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN =
|
|
|
|
|
101 |
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
|
102 |
_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
|
103 |
|
104 |
# MegaDepth options
|
105 |
-
_CN.DATASET.MGDPT_IMG_RESIZE =
|
|
|
|
|
106 |
_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
|
107 |
_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
|
108 |
_CN.DATASET.MGDPT_DF = 8
|
@@ -118,17 +129,17 @@ _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
|
|
118 |
# optimizer
|
119 |
_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
|
120 |
_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
|
121 |
-
_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
|
122 |
_CN.TRAINER.ADAMW_DECAY = 0.1
|
123 |
|
124 |
# step-based warm-up
|
125 |
-
_CN.TRAINER.WARMUP_TYPE =
|
126 |
-
_CN.TRAINER.WARMUP_RATIO = 0.
|
127 |
_CN.TRAINER.WARMUP_STEP = 4800
|
128 |
|
129 |
# learning rate scheduler
|
130 |
-
_CN.TRAINER.SCHEDULER =
|
131 |
-
_CN.TRAINER.SCHEDULER_INTERVAL =
|
132 |
_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
|
133 |
_CN.TRAINER.MSLR_GAMMA = 0.5
|
134 |
_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
|
@@ -136,25 +147,33 @@ _CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' in
|
|
136 |
|
137 |
# plotting related
|
138 |
_CN.TRAINER.ENABLE_PLOTTING = True
|
139 |
-
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32
|
140 |
-
_CN.TRAINER.PLOT_MODE =
|
141 |
-
_CN.TRAINER.PLOT_MATCHES_ALPHA =
|
142 |
|
143 |
# geometric metrics and pose solver
|
144 |
-
_CN.TRAINER.EPI_ERR_THR =
|
145 |
-
|
146 |
-
|
|
|
|
|
147 |
_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
|
148 |
_CN.TRAINER.RANSAC_CONF = 0.99999
|
149 |
_CN.TRAINER.RANSAC_MAX_ITERS = 10000
|
150 |
_CN.TRAINER.USE_MAGSACPP = False
|
151 |
|
152 |
# data sampler for train_dataloader
|
153 |
-
_CN.TRAINER.DATA_SAMPLER =
|
|
|
|
|
154 |
# 'scene_balance' config
|
155 |
_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
|
156 |
-
_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT =
|
157 |
-
|
|
|
|
|
|
|
|
|
158 |
_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
|
159 |
# 'random' config
|
160 |
_CN.TRAINER.RDM_REPLACEMENT = True
|
|
|
1 |
from yacs.config import CfgNode as CN
|
2 |
+
|
3 |
_CN = CN()
|
4 |
|
5 |
############## ↓ ASPAN Pipeline ↓ ##############
|
6 |
_CN.ASPAN = CN()
|
7 |
+
_CN.ASPAN.BACKBONE_TYPE = "ResNetFPN"
|
8 |
_CN.ASPAN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
|
9 |
_CN.ASPAN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
10 |
_CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True
|
|
|
18 |
_CN.ASPAN.COARSE = CN()
|
19 |
_CN.ASPAN.COARSE.D_MODEL = 256
|
20 |
_CN.ASPAN.COARSE.D_FFN = 256
|
21 |
+
_CN.ASPAN.COARSE.D_FLOW = 128
|
22 |
_CN.ASPAN.COARSE.NHEAD = 8
|
23 |
+
_CN.ASPAN.COARSE.NLEVEL = 3
|
24 |
+
_CN.ASPAN.COARSE.INI_LAYER_NUM = 2
|
25 |
+
_CN.ASPAN.COARSE.LAYER_NUM = 4
|
26 |
+
_CN.ASPAN.COARSE.NSAMPLE = [2, 8]
|
27 |
+
_CN.ASPAN.COARSE.RADIUS_SCALE = 5
|
28 |
+
_CN.ASPAN.COARSE.COARSEST_LEVEL = [26, 26]
|
29 |
_CN.ASPAN.COARSE.TRAIN_RES = None
|
30 |
_CN.ASPAN.COARSE.TEST_RES = None
|
31 |
|
|
|
33 |
_CN.ASPAN.MATCH_COARSE = CN()
|
34 |
_CN.ASPAN.MATCH_COARSE.THR = 0.2
|
35 |
_CN.ASPAN.MATCH_COARSE.BORDER_RM = 2
|
36 |
+
_CN.ASPAN.MATCH_COARSE.MATCH_TYPE = (
|
37 |
+
"dual_softmax" # options: ['dual_softmax', 'sinkhorn']
|
38 |
+
)
|
39 |
_CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3
|
40 |
_CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
|
41 |
_CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False
|
|
|
49 |
_CN.ASPAN.FINE.D_MODEL = 128
|
50 |
_CN.ASPAN.FINE.D_FFN = 128
|
51 |
_CN.ASPAN.FINE.NHEAD = 8
|
52 |
+
_CN.ASPAN.FINE.LAYER_NAMES = ["self", "cross"] * 1
|
53 |
+
_CN.ASPAN.FINE.ATTENTION = "linear"
|
54 |
|
55 |
# 5. ASPAN Losses
|
56 |
# -- # coarse-level
|
57 |
_CN.ASPAN.LOSS = CN()
|
58 |
+
_CN.ASPAN.LOSS.COARSE_TYPE = "focal" # ['focal', 'cross_entropy']
|
59 |
_CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0
|
60 |
# _CN.ASPAN.LOSS.SPARSE_SPVS = False
|
61 |
# -- - -- # focal loss (coarse)
|
|
|
67 |
# use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE`
|
68 |
|
69 |
# -- # fine-level
|
70 |
+
_CN.ASPAN.LOSS.FINE_TYPE = "l2_with_std" # ['l2_with_std', 'l2']
|
71 |
_CN.ASPAN.LOSS.FINE_WEIGHT = 1.0
|
72 |
_CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
|
73 |
|
|
|
88 |
_CN.DATASET.VAL_DATA_ROOT = None
|
89 |
_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
|
90 |
_CN.DATASET.VAL_NPZ_ROOT = None
|
91 |
+
_CN.DATASET.VAL_LIST_PATH = (
|
92 |
+
None # None if val data from all scenes are bundled into a single npz file
|
93 |
+
)
|
94 |
_CN.DATASET.VAL_INTRINSIC_PATH = None
|
95 |
# testing
|
96 |
_CN.DATASET.TEST_DATA_SOURCE = None
|
97 |
_CN.DATASET.TEST_DATA_ROOT = None
|
98 |
_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
|
99 |
_CN.DATASET.TEST_NPZ_ROOT = None
|
100 |
+
_CN.DATASET.TEST_LIST_PATH = (
|
101 |
+
None # None if test data from all scenes are bundled into a single npz file
|
102 |
+
)
|
103 |
_CN.DATASET.TEST_INTRINSIC_PATH = None
|
104 |
|
105 |
# 2. dataset config
|
106 |
# general options
|
107 |
+
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = (
|
108 |
+
0.4 # discard data with overlap_score < min_overlap_score
|
109 |
+
)
|
110 |
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
|
111 |
_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
|
112 |
|
113 |
# MegaDepth options
|
114 |
+
_CN.DATASET.MGDPT_IMG_RESIZE = (
|
115 |
+
640 # resize the longer side, zero-pad bottom-right to square.
|
116 |
+
)
|
117 |
_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
|
118 |
_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
|
119 |
_CN.DATASET.MGDPT_DF = 8
|
|
|
129 |
# optimizer
|
130 |
_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
|
131 |
_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
|
132 |
+
_CN.TRAINER.ADAM_DECAY = 0.0 # ADAM: for adam
|
133 |
_CN.TRAINER.ADAMW_DECAY = 0.1
|
134 |
|
135 |
# step-based warm-up
|
136 |
+
_CN.TRAINER.WARMUP_TYPE = "linear" # [linear, constant]
|
137 |
+
_CN.TRAINER.WARMUP_RATIO = 0.0
|
138 |
_CN.TRAINER.WARMUP_STEP = 4800
|
139 |
|
140 |
# learning rate scheduler
|
141 |
+
_CN.TRAINER.SCHEDULER = "MultiStepLR" # [MultiStepLR, CosineAnnealing, ExponentialLR]
|
142 |
+
_CN.TRAINER.SCHEDULER_INTERVAL = "epoch" # [epoch, step]
|
143 |
_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
|
144 |
_CN.TRAINER.MSLR_GAMMA = 0.5
|
145 |
_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
|
|
|
147 |
|
148 |
# plotting related
|
149 |
_CN.TRAINER.ENABLE_PLOTTING = True
|
150 |
+
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting
|
151 |
+
_CN.TRAINER.PLOT_MODE = "evaluation" # ['evaluation', 'confidence']
|
152 |
+
_CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic"
|
153 |
|
154 |
# geometric metrics and pose solver
|
155 |
+
_CN.TRAINER.EPI_ERR_THR = (
|
156 |
+
5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
|
157 |
+
)
|
158 |
+
_CN.TRAINER.POSE_GEO_MODEL = "E" # ['E', 'F', 'H']
|
159 |
+
_CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC" # [RANSAC, DEGENSAC, MAGSAC]
|
160 |
_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
|
161 |
_CN.TRAINER.RANSAC_CONF = 0.99999
|
162 |
_CN.TRAINER.RANSAC_MAX_ITERS = 10000
|
163 |
_CN.TRAINER.USE_MAGSACPP = False
|
164 |
|
165 |
# data sampler for train_dataloader
|
166 |
+
_CN.TRAINER.DATA_SAMPLER = (
|
167 |
+
"scene_balance" # options: ['scene_balance', 'random', 'normal']
|
168 |
+
)
|
169 |
# 'scene_balance' config
|
170 |
_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
|
171 |
+
_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = (
|
172 |
+
True # whether sample each scene with replacement or not
|
173 |
+
)
|
174 |
+
_CN.TRAINER.SB_SUBSET_SHUFFLE = (
|
175 |
+
True # after sampling from scenes, whether shuffle within the epoch or not
|
176 |
+
)
|
177 |
_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
|
178 |
# 'random' config
|
179 |
_CN.TRAINER.RDM_REPLACEMENT = True
|
third_party/ASpanFormer/src/datasets/__init__.py
CHANGED
@@ -1,3 +1,2 @@
|
|
1 |
from .scannet import ScanNetDataset
|
2 |
from .megadepth import MegaDepthDataset
|
3 |
-
|
|
|
1 |
from .scannet import ScanNetDataset
|
2 |
from .megadepth import MegaDepthDataset
|
|
third_party/ASpanFormer/src/datasets/megadepth.py
CHANGED
@@ -9,20 +9,22 @@ from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
|
|
9 |
|
10 |
|
11 |
class MegaDepthDataset(Dataset):
|
12 |
-
def __init__(
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
"""
|
24 |
Manage one scene(npz_path) of MegaDepth dataset.
|
25 |
-
|
26 |
Args:
|
27 |
root_dir (str): megadepth root directory that has `phoenix`.
|
28 |
npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
|
@@ -38,28 +40,36 @@ class MegaDepthDataset(Dataset):
|
|
38 |
super().__init__()
|
39 |
self.root_dir = root_dir
|
40 |
self.mode = mode
|
41 |
-
self.scene_id = npz_path.split(
|
42 |
|
43 |
# prepare scene_info and pair_info
|
44 |
-
if mode ==
|
45 |
-
logger.warning(
|
|
|
|
|
46 |
min_overlap_score = 0
|
47 |
self.scene_info = np.load(npz_path, allow_pickle=True)
|
48 |
-
self.pair_infos = self.scene_info[
|
49 |
-
del self.scene_info[
|
50 |
-
self.pair_infos = [
|
|
|
|
|
|
|
|
|
51 |
|
52 |
# parameters for image resizing, padding and depthmap padding
|
53 |
-
if mode ==
|
54 |
assert img_resize is not None and img_padding and depth_padding
|
55 |
self.img_resize = img_resize
|
56 |
self.df = df
|
57 |
self.img_padding = img_padding
|
58 |
-
self.depth_max_size =
|
|
|
|
|
59 |
|
60 |
# for training LoFTR
|
61 |
-
self.augment_fn = augment_fn if mode ==
|
62 |
-
self.coarse_scale = getattr(kwargs,
|
63 |
|
64 |
def __len__(self):
|
65 |
return len(self.pair_infos)
|
@@ -68,60 +78,77 @@ class MegaDepthDataset(Dataset):
|
|
68 |
(idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
|
69 |
|
70 |
# read grayscale image and mask. (1, h, w) and (h, w)
|
71 |
-
img_name0 = osp.join(self.root_dir, self.scene_info[
|
72 |
-
img_name1 = osp.join(self.root_dir, self.scene_info[
|
73 |
-
|
74 |
# TODO: Support augmentation & handle seeds for each worker correctly.
|
75 |
image0, mask0, scale0 = read_megadepth_gray(
|
76 |
-
img_name0, self.img_resize, self.df, self.img_padding, None
|
77 |
-
|
|
|
78 |
image1, mask1, scale1 = read_megadepth_gray(
|
79 |
-
img_name1, self.img_resize, self.df, self.img_padding, None
|
80 |
-
|
|
|
81 |
|
82 |
# read depth. shape: (h, w)
|
83 |
-
if self.mode in [
|
84 |
depth0 = read_megadepth_depth(
|
85 |
-
osp.join(self.root_dir, self.scene_info[
|
|
|
|
|
86 |
depth1 = read_megadepth_depth(
|
87 |
-
osp.join(self.root_dir, self.scene_info[
|
|
|
|
|
88 |
else:
|
89 |
depth0 = depth1 = torch.tensor([])
|
90 |
|
91 |
# read intrinsics of original size
|
92 |
-
K_0 = torch.tensor(
|
93 |
-
|
|
|
|
|
|
|
|
|
94 |
|
95 |
# read and compute relative poses
|
96 |
-
T0 = self.scene_info[
|
97 |
-
T1 = self.scene_info[
|
98 |
-
T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[
|
|
|
|
|
99 |
T_1to0 = T_0to1.inverse()
|
100 |
|
101 |
data = {
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
116 |
}
|
117 |
|
118 |
# for LoFTR training
|
119 |
if mask0 is not None: # img_padding is True
|
120 |
if self.coarse_scale:
|
121 |
-
[ts_mask_0, ts_mask_1] = F.interpolate(
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
126 |
|
127 |
return data
|
|
|
9 |
|
10 |
|
11 |
class MegaDepthDataset(Dataset):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
root_dir,
|
15 |
+
npz_path,
|
16 |
+
mode="train",
|
17 |
+
min_overlap_score=0.4,
|
18 |
+
img_resize=None,
|
19 |
+
df=None,
|
20 |
+
img_padding=False,
|
21 |
+
depth_padding=False,
|
22 |
+
augment_fn=None,
|
23 |
+
**kwargs
|
24 |
+
):
|
25 |
"""
|
26 |
Manage one scene(npz_path) of MegaDepth dataset.
|
27 |
+
|
28 |
Args:
|
29 |
root_dir (str): megadepth root directory that has `phoenix`.
|
30 |
npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
|
|
|
40 |
super().__init__()
|
41 |
self.root_dir = root_dir
|
42 |
self.mode = mode
|
43 |
+
self.scene_id = npz_path.split(".")[0]
|
44 |
|
45 |
# prepare scene_info and pair_info
|
46 |
+
if mode == "test" and min_overlap_score != 0:
|
47 |
+
logger.warning(
|
48 |
+
"You are using `min_overlap_score`!=0 in test mode. Set to 0."
|
49 |
+
)
|
50 |
min_overlap_score = 0
|
51 |
self.scene_info = np.load(npz_path, allow_pickle=True)
|
52 |
+
self.pair_infos = self.scene_info["pair_infos"].copy()
|
53 |
+
del self.scene_info["pair_infos"]
|
54 |
+
self.pair_infos = [
|
55 |
+
pair_info
|
56 |
+
for pair_info in self.pair_infos
|
57 |
+
if pair_info[1] > min_overlap_score
|
58 |
+
]
|
59 |
|
60 |
# parameters for image resizing, padding and depthmap padding
|
61 |
+
if mode == "train":
|
62 |
assert img_resize is not None and img_padding and depth_padding
|
63 |
self.img_resize = img_resize
|
64 |
self.df = df
|
65 |
self.img_padding = img_padding
|
66 |
+
self.depth_max_size = (
|
67 |
+
2000 if depth_padding else None
|
68 |
+
) # the upperbound of depthmaps size in megadepth.
|
69 |
|
70 |
# for training LoFTR
|
71 |
+
self.augment_fn = augment_fn if mode == "train" else None
|
72 |
+
self.coarse_scale = getattr(kwargs, "coarse_scale", 0.125)
|
73 |
|
74 |
def __len__(self):
|
75 |
return len(self.pair_infos)
|
|
|
78 |
(idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
|
79 |
|
80 |
# read grayscale image and mask. (1, h, w) and (h, w)
|
81 |
+
img_name0 = osp.join(self.root_dir, self.scene_info["image_paths"][idx0])
|
82 |
+
img_name1 = osp.join(self.root_dir, self.scene_info["image_paths"][idx1])
|
83 |
+
|
84 |
# TODO: Support augmentation & handle seeds for each worker correctly.
|
85 |
image0, mask0, scale0 = read_megadepth_gray(
|
86 |
+
img_name0, self.img_resize, self.df, self.img_padding, None
|
87 |
+
)
|
88 |
+
# np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
|
89 |
image1, mask1, scale1 = read_megadepth_gray(
|
90 |
+
img_name1, self.img_resize, self.df, self.img_padding, None
|
91 |
+
)
|
92 |
+
# np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
|
93 |
|
94 |
# read depth. shape: (h, w)
|
95 |
+
if self.mode in ["train", "val"]:
|
96 |
depth0 = read_megadepth_depth(
|
97 |
+
osp.join(self.root_dir, self.scene_info["depth_paths"][idx0]),
|
98 |
+
pad_to=self.depth_max_size,
|
99 |
+
)
|
100 |
depth1 = read_megadepth_depth(
|
101 |
+
osp.join(self.root_dir, self.scene_info["depth_paths"][idx1]),
|
102 |
+
pad_to=self.depth_max_size,
|
103 |
+
)
|
104 |
else:
|
105 |
depth0 = depth1 = torch.tensor([])
|
106 |
|
107 |
# read intrinsics of original size
|
108 |
+
K_0 = torch.tensor(
|
109 |
+
self.scene_info["intrinsics"][idx0].copy(), dtype=torch.float
|
110 |
+
).reshape(3, 3)
|
111 |
+
K_1 = torch.tensor(
|
112 |
+
self.scene_info["intrinsics"][idx1].copy(), dtype=torch.float
|
113 |
+
).reshape(3, 3)
|
114 |
|
115 |
# read and compute relative poses
|
116 |
+
T0 = self.scene_info["poses"][idx0]
|
117 |
+
T1 = self.scene_info["poses"][idx1]
|
118 |
+
T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[
|
119 |
+
:4, :4
|
120 |
+
] # (4, 4)
|
121 |
T_1to0 = T_0to1.inverse()
|
122 |
|
123 |
data = {
|
124 |
+
"image0": image0, # (1, h, w)
|
125 |
+
"depth0": depth0, # (h, w)
|
126 |
+
"image1": image1,
|
127 |
+
"depth1": depth1,
|
128 |
+
"T_0to1": T_0to1, # (4, 4)
|
129 |
+
"T_1to0": T_1to0,
|
130 |
+
"K0": K_0, # (3, 3)
|
131 |
+
"K1": K_1,
|
132 |
+
"scale0": scale0, # [scale_w, scale_h]
|
133 |
+
"scale1": scale1,
|
134 |
+
"dataset_name": "MegaDepth",
|
135 |
+
"scene_id": self.scene_id,
|
136 |
+
"pair_id": idx,
|
137 |
+
"pair_names": (
|
138 |
+
self.scene_info["image_paths"][idx0],
|
139 |
+
self.scene_info["image_paths"][idx1],
|
140 |
+
),
|
141 |
}
|
142 |
|
143 |
# for LoFTR training
|
144 |
if mask0 is not None: # img_padding is True
|
145 |
if self.coarse_scale:
|
146 |
+
[ts_mask_0, ts_mask_1] = F.interpolate(
|
147 |
+
torch.stack([mask0, mask1], dim=0)[None].float(),
|
148 |
+
scale_factor=self.coarse_scale,
|
149 |
+
mode="nearest",
|
150 |
+
recompute_scale_factor=False,
|
151 |
+
)[0].bool()
|
152 |
+
data.update({"mask0": ts_mask_0, "mask1": ts_mask_1})
|
153 |
|
154 |
return data
|
third_party/ASpanFormer/src/datasets/sampler.py
CHANGED
@@ -3,10 +3,10 @@ from torch.utils.data import Sampler, ConcatDataset
|
|
3 |
|
4 |
|
5 |
class RandomConcatSampler(Sampler):
|
6 |
-
"""
|
7 |
in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
|
8 |
However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
|
9 |
-
|
10 |
For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
|
11 |
Args:
|
12 |
shuffle (bool): shuffle the random sampled indices across all sub-datsets.
|
@@ -18,16 +18,19 @@ class RandomConcatSampler(Sampler):
|
|
18 |
TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
|
19 |
ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
|
20 |
"""
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
28 |
if not isinstance(data_source, ConcatDataset):
|
29 |
raise TypeError("data_source should be torch.utils.data.ConcatDataset")
|
30 |
-
|
31 |
self.data_source = data_source
|
32 |
self.n_subset = len(self.data_source.datasets)
|
33 |
self.n_samples_per_subset = n_samples_per_subset
|
@@ -37,27 +40,37 @@ class RandomConcatSampler(Sampler):
|
|
37 |
self.shuffle = shuffle
|
38 |
self.generator = torch.manual_seed(seed)
|
39 |
assert self.repeat >= 1
|
40 |
-
|
41 |
def __len__(self):
|
42 |
return self.n_samples
|
43 |
-
|
44 |
def __iter__(self):
|
45 |
indices = []
|
46 |
# sample from each sub-dataset
|
47 |
for d_idx in range(self.n_subset):
|
48 |
-
low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
|
49 |
high = self.data_source.cumulative_sizes[d_idx]
|
50 |
if self.subset_replacement:
|
51 |
-
rand_tensor = torch.randint(
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
53 |
else: # sample without replacement
|
54 |
len_subset = len(self.data_source.datasets[d_idx])
|
55 |
rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
|
56 |
if len_subset >= self.n_samples_per_subset:
|
57 |
-
rand_tensor = rand_tensor[:self.n_samples_per_subset]
|
58 |
-
else:
|
59 |
-
rand_tensor_replacement = torch.randint(
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
|
62 |
indices.append(rand_tensor)
|
63 |
indices = torch.cat(indices)
|
@@ -72,6 +85,6 @@ class RandomConcatSampler(Sampler):
|
|
72 |
_choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
|
73 |
repeat_indices = map(_choice, repeat_indices)
|
74 |
indices = torch.cat([indices, *repeat_indices], 0)
|
75 |
-
|
76 |
assert indices.shape[0] == self.n_samples
|
77 |
return iter(indices.tolist())
|
|
|
3 |
|
4 |
|
5 |
class RandomConcatSampler(Sampler):
|
6 |
+
"""Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
|
7 |
in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
|
8 |
However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
|
9 |
+
|
10 |
For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
|
11 |
Args:
|
12 |
shuffle (bool): shuffle the random sampled indices across all sub-datsets.
|
|
|
18 |
TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
|
19 |
ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
|
20 |
"""
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
data_source: ConcatDataset,
|
25 |
+
n_samples_per_subset: int,
|
26 |
+
subset_replacement: bool = True,
|
27 |
+
shuffle: bool = True,
|
28 |
+
repeat: int = 1,
|
29 |
+
seed: int = None,
|
30 |
+
):
|
31 |
if not isinstance(data_source, ConcatDataset):
|
32 |
raise TypeError("data_source should be torch.utils.data.ConcatDataset")
|
33 |
+
|
34 |
self.data_source = data_source
|
35 |
self.n_subset = len(self.data_source.datasets)
|
36 |
self.n_samples_per_subset = n_samples_per_subset
|
|
|
40 |
self.shuffle = shuffle
|
41 |
self.generator = torch.manual_seed(seed)
|
42 |
assert self.repeat >= 1
|
43 |
+
|
44 |
def __len__(self):
|
45 |
return self.n_samples
|
46 |
+
|
47 |
def __iter__(self):
|
48 |
indices = []
|
49 |
# sample from each sub-dataset
|
50 |
for d_idx in range(self.n_subset):
|
51 |
+
low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
|
52 |
high = self.data_source.cumulative_sizes[d_idx]
|
53 |
if self.subset_replacement:
|
54 |
+
rand_tensor = torch.randint(
|
55 |
+
low,
|
56 |
+
high,
|
57 |
+
(self.n_samples_per_subset,),
|
58 |
+
generator=self.generator,
|
59 |
+
dtype=torch.int64,
|
60 |
+
)
|
61 |
else: # sample without replacement
|
62 |
len_subset = len(self.data_source.datasets[d_idx])
|
63 |
rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
|
64 |
if len_subset >= self.n_samples_per_subset:
|
65 |
+
rand_tensor = rand_tensor[: self.n_samples_per_subset]
|
66 |
+
else: # padding with replacement
|
67 |
+
rand_tensor_replacement = torch.randint(
|
68 |
+
low,
|
69 |
+
high,
|
70 |
+
(self.n_samples_per_subset - len_subset,),
|
71 |
+
generator=self.generator,
|
72 |
+
dtype=torch.int64,
|
73 |
+
)
|
74 |
rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
|
75 |
indices.append(rand_tensor)
|
76 |
indices = torch.cat(indices)
|
|
|
85 |
_choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
|
86 |
repeat_indices = map(_choice, repeat_indices)
|
87 |
indices = torch.cat([indices, *repeat_indices], 0)
|
88 |
+
|
89 |
assert indices.shape[0] == self.n_samples
|
90 |
return iter(indices.tolist())
|
third_party/ASpanFormer/src/datasets/scannet.py
CHANGED
@@ -10,20 +10,22 @@ from src.utils.dataset import (
|
|
10 |
read_scannet_gray,
|
11 |
read_scannet_depth,
|
12 |
read_scannet_pose,
|
13 |
-
read_scannet_intrinsic
|
14 |
)
|
15 |
|
16 |
|
17 |
class ScanNetDataset(utils.data.Dataset):
|
18 |
-
def __init__(
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
27 |
"""Manage one scene of ScanNet Dataset.
|
28 |
Args:
|
29 |
root_dir (str): ScanNet root directory that contains scene folders.
|
@@ -41,73 +43,81 @@ class ScanNetDataset(utils.data.Dataset):
|
|
41 |
|
42 |
# prepare data_names, intrinsics and extrinsics(T)
|
43 |
with np.load(npz_path) as data:
|
44 |
-
self.data_names = data[
|
45 |
-
if
|
46 |
-
kept_mask = data[
|
47 |
self.data_names = self.data_names[kept_mask]
|
48 |
self.intrinsics = dict(np.load(intrinsic_path))
|
49 |
|
50 |
# for training LoFTR
|
51 |
-
self.augment_fn = augment_fn if mode ==
|
52 |
|
53 |
def __len__(self):
|
54 |
return len(self.data_names)
|
55 |
|
56 |
def _read_abs_pose(self, scene_name, name):
|
57 |
-
pth = osp.join(self.pose_dir,
|
58 |
-
scene_name,
|
59 |
-
'pose', f'{name}.txt')
|
60 |
return read_scannet_pose(pth)
|
61 |
|
62 |
def _compute_rel_pose(self, scene_name, name0, name1):
|
63 |
pose0 = self._read_abs_pose(scene_name, name0)
|
64 |
pose1 = self._read_abs_pose(scene_name, name1)
|
65 |
-
|
66 |
return np.matmul(pose1, inv(pose0)) # (4, 4)
|
67 |
|
68 |
def __getitem__(self, idx):
|
69 |
data_name = self.data_names[idx]
|
70 |
scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
|
71 |
-
scene_name = f
|
72 |
|
73 |
# read the grayscale image which will be resized to (1, 480, 640)
|
74 |
-
img_name0 = osp.join(self.root_dir, scene_name,
|
75 |
-
img_name1 = osp.join(self.root_dir, scene_name,
|
76 |
# TODO: Support augmentation & handle seeds for each worker correctly.
|
77 |
image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
|
78 |
-
|
79 |
image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
|
80 |
-
|
81 |
|
82 |
# read the depthmap which is stored as (480, 640)
|
83 |
-
if self.mode in [
|
84 |
-
depth0 = read_scannet_depth(
|
85 |
-
|
|
|
|
|
|
|
|
|
86 |
else:
|
87 |
depth0 = depth1 = torch.tensor([])
|
88 |
|
89 |
# read the intrinsic of depthmap
|
90 |
-
K_0 = K_1 = torch.tensor(
|
|
|
|
|
91 |
|
92 |
# read and compute relative poses
|
93 |
-
T_0to1 = torch.tensor(
|
94 |
-
|
|
|
|
|
95 |
T_1to0 = T_0to1.inverse()
|
96 |
|
97 |
data = {
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
111 |
}
|
112 |
|
113 |
return data
|
|
|
10 |
read_scannet_gray,
|
11 |
read_scannet_depth,
|
12 |
read_scannet_pose,
|
13 |
+
read_scannet_intrinsic,
|
14 |
)
|
15 |
|
16 |
|
17 |
class ScanNetDataset(utils.data.Dataset):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
root_dir,
|
21 |
+
npz_path,
|
22 |
+
intrinsic_path,
|
23 |
+
mode="train",
|
24 |
+
min_overlap_score=0.4,
|
25 |
+
augment_fn=None,
|
26 |
+
pose_dir=None,
|
27 |
+
**kwargs,
|
28 |
+
):
|
29 |
"""Manage one scene of ScanNet Dataset.
|
30 |
Args:
|
31 |
root_dir (str): ScanNet root directory that contains scene folders.
|
|
|
43 |
|
44 |
# prepare data_names, intrinsics and extrinsics(T)
|
45 |
with np.load(npz_path) as data:
|
46 |
+
self.data_names = data["name"]
|
47 |
+
if "score" in data.keys() and mode not in ["val" or "test"]:
|
48 |
+
kept_mask = data["score"] > min_overlap_score
|
49 |
self.data_names = self.data_names[kept_mask]
|
50 |
self.intrinsics = dict(np.load(intrinsic_path))
|
51 |
|
52 |
# for training LoFTR
|
53 |
+
self.augment_fn = augment_fn if mode == "train" else None
|
54 |
|
55 |
def __len__(self):
|
56 |
return len(self.data_names)
|
57 |
|
58 |
def _read_abs_pose(self, scene_name, name):
    """Read the pose stored at `<pose_dir>/<scene_name>/pose/<name>.txt`.

    The file is parsed by `read_scannet_pose`; its return value is passed
    through unchanged.
    """
    return read_scannet_pose(
        osp.join(self.pose_dir, scene_name, "pose", f"{name}.txt")
    )
|
61 |
|
62 |
def _compute_rel_pose(self, scene_name, name0, name1):
|
63 |
pose0 = self._read_abs_pose(scene_name, name0)
|
64 |
pose1 = self._read_abs_pose(scene_name, name1)
|
65 |
+
|
66 |
return np.matmul(pose1, inv(pose0)) # (4, 4)
|
67 |
|
68 |
def __getitem__(self, idx):
    """Assemble the sample dict for the `idx`-th image pair.

    Each entry of `self.data_names` unpacks into
    `(scene, sub_scene, stem_0, stem_1)`. Images are read from
    `<root_dir>/<scene_name>/color/<stem>.jpg` with `resize=(640, 480)`;
    depth maps from `<root_dir>/<scene_name>/depth/<stem>.png`, and only in
    "train"/"val" mode (empty tensors otherwise).

    Returns:
        dict with keys image0/1, depth0/1, T_0to1, T_1to0, K0, K1,
        dataset_name, scene_id, pair_id and pair_names.
    """
    scene_idx, sub_idx, stem_name_0, stem_name_1 = self.data_names[idx]
    scene_name = f"scene{scene_idx:04d}_{sub_idx:02d}"
    stems = (stem_name_0, stem_name_1)

    # Paths relative to the dataset root; also reported back as "pair_names".
    rel_color_paths = tuple(
        osp.join(scene_name, "color", f"{stem}.jpg") for stem in stems
    )

    # read the grayscale image which will be resized to (1, 480, 640)
    # TODO: Support augmentation & handle seeds for each worker correctly.
    image0, image1 = [
        read_scannet_gray(
            osp.join(self.root_dir, rel_path), resize=(640, 480), augment_fn=None
        )
        for rel_path in rel_color_paths
    ]

    # read the depthmap which is stored as (480, 640)
    if self.mode not in ["train", "val"]:
        depth0 = depth1 = torch.tensor([])
    else:
        depth0, depth1 = [
            read_scannet_depth(
                osp.join(self.root_dir, scene_name, "depth", f"{stem}.png")
            )
            for stem in stems
        ]

    # read the intrinsic of depthmap; both views share the same tensor object
    intrinsic = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float)
    K_0 = K_1 = intrinsic.reshape(3, 3)

    # read and compute relative poses
    rel_pose = self._compute_rel_pose(scene_name, stem_name_0, stem_name_1)
    T_0to1 = torch.tensor(rel_pose, dtype=torch.float32)
    T_1to0 = T_0to1.inverse()

    return {
        "image0": image0,  # (1, h, w)
        "depth0": depth0,  # (h, w)
        "image1": image1,
        "depth1": depth1,
        "T_0to1": T_0to1,  # (4, 4)
        "T_1to0": T_1to0,
        "K0": K_0,  # (3, 3)
        "K1": K_1,
        "dataset_name": "ScanNet",
        "scene_id": scene_name,
        "pair_id": idx,
        "pair_names": rel_color_paths,
    }
|
third_party/ASpanFormer/src/lightning/data.py
CHANGED
@@ -16,7 +16,7 @@ from torch.utils.data import (
|
|
16 |
ConcatDataset,
|
17 |
DistributedSampler,
|
18 |
RandomSampler,
|
19 |
-
dataloader
|
20 |
)
|
21 |
|
22 |
from src.utils.augment import build_augmentor
|
@@ -29,10 +29,11 @@ from src.datasets.sampler import RandomConcatSampler
|
|
29 |
|
30 |
|
31 |
class MultiSceneDataModule(pl.LightningDataModule):
|
32 |
-
"""
|
33 |
For distributed training, each training process is assgined
|
34 |
only a part of the training scenes to reduce memory overhead.
|
35 |
"""
|
|
|
36 |
def __init__(self, args, config):
|
37 |
super().__init__()
|
38 |
|
@@ -60,47 +61,51 @@ class MultiSceneDataModule(pl.LightningDataModule):
|
|
60 |
|
61 |
# 2. dataset config
|
62 |
# general options
|
63 |
-
self.min_overlap_score_test =
|
|
|
|
|
64 |
self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
|
65 |
-
self.augment_fn = build_augmentor(
|
|
|
|
|
66 |
|
67 |
# MegaDepth options
|
68 |
self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
|
69 |
-
self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD
|
70 |
-
self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD
|
71 |
self.mgdpt_df = config.DATASET.MGDPT_DF # 8
|
72 |
self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0] # 0.125. for training loftr.
|
73 |
|
74 |
# 3.loader parameters
|
75 |
self.train_loader_params = {
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
}
|
80 |
self.val_loader_params = {
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
}
|
86 |
self.test_loader_params = {
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
}
|
92 |
-
|
93 |
# 4. sampler
|
94 |
self.data_sampler = config.TRAINER.DATA_SAMPLER
|
95 |
self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
|
96 |
self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
|
97 |
self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
|
98 |
self.repeat = config.TRAINER.SB_REPEAT
|
99 |
-
|
100 |
# (optional) RandomSampler for debugging
|
101 |
|
102 |
# misc configurations
|
103 |
-
self.parallel_load_data = getattr(args,
|
104 |
self.seed = config.TRAINER.SEED # 66
|
105 |
|
106 |
def setup(self, stage=None):
|
@@ -110,7 +115,7 @@ class MultiSceneDataModule(pl.LightningDataModule):
|
|
110 |
stage (str): 'fit' in training phase, and 'test' in testing phase.
|
111 |
"""
|
112 |
|
113 |
-
assert stage in [
|
114 |
|
115 |
try:
|
116 |
self.world_size = dist.get_world_size()
|
@@ -121,73 +126,94 @@ class MultiSceneDataModule(pl.LightningDataModule):
|
|
121 |
self.rank = 0
|
122 |
logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)")
|
123 |
|
124 |
-
if stage ==
|
125 |
self.train_dataset = self._setup_dataset(
|
126 |
self.train_data_root,
|
127 |
self.train_npz_root,
|
128 |
self.train_list_path,
|
129 |
self.train_intrinsic_path,
|
130 |
-
mode=
|
131 |
min_overlap_score=self.min_overlap_score_train,
|
132 |
-
pose_dir=self.train_pose_root
|
|
|
133 |
# setup multiple (optional) validation subsets
|
134 |
if isinstance(self.val_list_path, (list, tuple)):
|
135 |
self.val_dataset = []
|
136 |
if not isinstance(self.val_npz_root, (list, tuple)):
|
137 |
-
self.val_npz_root = [
|
|
|
|
|
138 |
for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
|
139 |
-
self.val_dataset.append(
|
140 |
-
self.
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
147 |
else:
|
148 |
self.val_dataset = self._setup_dataset(
|
149 |
self.val_data_root,
|
150 |
self.val_npz_root,
|
151 |
self.val_list_path,
|
152 |
self.val_intrinsic_path,
|
153 |
-
mode=
|
154 |
min_overlap_score=self.min_overlap_score_test,
|
155 |
-
pose_dir=self.val_pose_root
|
156 |
-
|
|
|
157 |
else: # stage == 'test
|
158 |
self.test_dataset = self._setup_dataset(
|
159 |
self.test_data_root,
|
160 |
self.test_npz_root,
|
161 |
self.test_list_path,
|
162 |
self.test_intrinsic_path,
|
163 |
-
mode=
|
164 |
min_overlap_score=self.min_overlap_score_test,
|
165 |
-
pose_dir=self.test_pose_root
|
166 |
-
|
|
|
167 |
|
168 |
-
def _setup_dataset(
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
|
|
178 |
npz_names = [name.split()[0] for name in f.readlines()]
|
179 |
|
180 |
-
if mode ==
|
181 |
-
local_npz_names = get_local_split(
|
|
|
|
|
182 |
else:
|
183 |
local_npz_names = npz_names
|
184 |
-
logger.info(f
|
185 |
-
|
186 |
-
dataset_builder =
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
def _build_concat_dataset(
|
193 |
self,
|
@@ -196,49 +222,61 @@ class MultiSceneDataModule(pl.LightningDataModule):
|
|
196 |
npz_dir,
|
197 |
intrinsic_path,
|
198 |
mode,
|
199 |
-
min_overlap_score=0
|
200 |
-
pose_dir=None
|
201 |
):
|
202 |
datasets = []
|
203 |
-
augment_fn = self.augment_fn if mode ==
|
204 |
-
data_source =
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
if
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
# `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
|
216 |
npz_path = osp.join(npz_dir, npz_name)
|
217 |
-
if data_source ==
|
218 |
datasets.append(
|
219 |
-
ScanNetDataset(
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
227 |
datasets.append(
|
228 |
-
MegaDepthDataset(
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
|
|
|
|
|
|
238 |
else:
|
239 |
raise NotImplementedError()
|
240 |
return ConcatDataset(datasets)
|
241 |
-
|
242 |
def _build_concat_dataset_parallel(
|
243 |
self,
|
244 |
data_root,
|
@@ -246,78 +284,119 @@ class MultiSceneDataModule(pl.LightningDataModule):
|
|
246 |
npz_dir,
|
247 |
intrinsic_path,
|
248 |
mode,
|
249 |
-
min_overlap_score=0
|
250 |
pose_dir=None,
|
251 |
):
|
252 |
-
augment_fn = self.augment_fn if mode ==
|
253 |
-
data_source =
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
# TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
|
273 |
raise NotImplementedError()
|
274 |
-
datasets = Parallel(
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
else:
|
289 |
-
raise ValueError(f
|
290 |
return ConcatDataset(datasets)
|
291 |
|
292 |
def train_dataloader(self):
|
293 |
-
"""
|
294 |
-
assert self.data_sampler in [
|
295 |
-
logger.info(
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
else:
|
302 |
sampler = None
|
303 |
-
dataloader = DataLoader(
|
|
|
|
|
304 |
return dataloader
|
305 |
-
|
306 |
def val_dataloader(self):
|
307 |
-
"""
|
308 |
-
logger.info(
|
|
|
|
|
309 |
if not isinstance(self.val_dataset, abc.Sequence):
|
310 |
sampler = DistributedSampler(self.val_dataset, shuffle=False)
|
311 |
-
return DataLoader(
|
|
|
|
|
312 |
else:
|
313 |
dataloaders = []
|
314 |
for dataset in self.val_dataset:
|
315 |
sampler = DistributedSampler(dataset, shuffle=False)
|
316 |
-
dataloaders.append(
|
|
|
|
|
317 |
return dataloaders
|
318 |
|
319 |
def test_dataloader(self, *args, **kwargs):
|
320 |
-
logger.info(
|
|
|
|
|
321 |
sampler = DistributedSampler(self.test_dataset, shuffle=False)
|
322 |
return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
|
323 |
|
|
|
16 |
ConcatDataset,
|
17 |
DistributedSampler,
|
18 |
RandomSampler,
|
19 |
+
dataloader,
|
20 |
)
|
21 |
|
22 |
from src.utils.augment import build_augmentor
|
|
|
29 |
|
30 |
|
31 |
class MultiSceneDataModule(pl.LightningDataModule):
|
32 |
+
"""
|
33 |
For distributed training, each training process is assgined
|
34 |
only a part of the training scenes to reduce memory overhead.
|
35 |
"""
|
36 |
+
|
37 |
def __init__(self, args, config):
|
38 |
super().__init__()
|
39 |
|
|
|
61 |
|
62 |
# 2. dataset config
|
63 |
# general options
|
64 |
+
self.min_overlap_score_test = (
|
65 |
+
config.DATASET.MIN_OVERLAP_SCORE_TEST
|
66 |
+
) # 0.4, omit data with overlap_score < min_overlap_score
|
67 |
self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
|
68 |
+
self.augment_fn = build_augmentor(
|
69 |
+
config.DATASET.AUGMENTATION_TYPE
|
70 |
+
) # None, options: [None, 'dark', 'mobile']
|
71 |
|
72 |
# MegaDepth options
|
73 |
self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
|
74 |
+
self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
|
75 |
+
self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
|
76 |
self.mgdpt_df = config.DATASET.MGDPT_DF # 8
|
77 |
self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0] # 0.125. for training loftr.
|
78 |
|
79 |
# 3.loader parameters
|
80 |
self.train_loader_params = {
|
81 |
+
"batch_size": args.batch_size,
|
82 |
+
"num_workers": args.num_workers,
|
83 |
+
"pin_memory": getattr(args, "pin_memory", True),
|
84 |
}
|
85 |
self.val_loader_params = {
|
86 |
+
"batch_size": 1,
|
87 |
+
"shuffle": False,
|
88 |
+
"num_workers": args.num_workers,
|
89 |
+
"pin_memory": getattr(args, "pin_memory", True),
|
90 |
}
|
91 |
self.test_loader_params = {
|
92 |
+
"batch_size": 1,
|
93 |
+
"shuffle": False,
|
94 |
+
"num_workers": args.num_workers,
|
95 |
+
"pin_memory": True,
|
96 |
}
|
97 |
+
|
98 |
# 4. sampler
|
99 |
self.data_sampler = config.TRAINER.DATA_SAMPLER
|
100 |
self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
|
101 |
self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
|
102 |
self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
|
103 |
self.repeat = config.TRAINER.SB_REPEAT
|
104 |
+
|
105 |
# (optional) RandomSampler for debugging
|
106 |
|
107 |
# misc configurations
|
108 |
+
self.parallel_load_data = getattr(args, "parallel_load_data", False)
|
109 |
self.seed = config.TRAINER.SEED # 66
|
110 |
|
111 |
def setup(self, stage=None):
|
|
|
115 |
stage (str): 'fit' in training phase, and 'test' in testing phase.
|
116 |
"""
|
117 |
|
118 |
+
assert stage in ["fit", "test"], "stage must be either fit or test"
|
119 |
|
120 |
try:
|
121 |
self.world_size = dist.get_world_size()
|
|
|
126 |
self.rank = 0
|
127 |
logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)")
|
128 |
|
129 |
+
if stage == "fit":
|
130 |
self.train_dataset = self._setup_dataset(
|
131 |
self.train_data_root,
|
132 |
self.train_npz_root,
|
133 |
self.train_list_path,
|
134 |
self.train_intrinsic_path,
|
135 |
+
mode="train",
|
136 |
min_overlap_score=self.min_overlap_score_train,
|
137 |
+
pose_dir=self.train_pose_root,
|
138 |
+
)
|
139 |
# setup multiple (optional) validation subsets
|
140 |
if isinstance(self.val_list_path, (list, tuple)):
|
141 |
self.val_dataset = []
|
142 |
if not isinstance(self.val_npz_root, (list, tuple)):
|
143 |
+
self.val_npz_root = [
|
144 |
+
self.val_npz_root for _ in range(len(self.val_list_path))
|
145 |
+
]
|
146 |
for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
|
147 |
+
self.val_dataset.append(
|
148 |
+
self._setup_dataset(
|
149 |
+
self.val_data_root,
|
150 |
+
npz_root,
|
151 |
+
npz_list,
|
152 |
+
self.val_intrinsic_path,
|
153 |
+
mode="val",
|
154 |
+
min_overlap_score=self.min_overlap_score_test,
|
155 |
+
pose_dir=self.val_pose_root,
|
156 |
+
)
|
157 |
+
)
|
158 |
else:
|
159 |
self.val_dataset = self._setup_dataset(
|
160 |
self.val_data_root,
|
161 |
self.val_npz_root,
|
162 |
self.val_list_path,
|
163 |
self.val_intrinsic_path,
|
164 |
+
mode="val",
|
165 |
min_overlap_score=self.min_overlap_score_test,
|
166 |
+
pose_dir=self.val_pose_root,
|
167 |
+
)
|
168 |
+
logger.info(f"[rank:{self.rank}] Train & Val Dataset loaded!")
|
169 |
else: # stage == 'test
|
170 |
self.test_dataset = self._setup_dataset(
|
171 |
self.test_data_root,
|
172 |
self.test_npz_root,
|
173 |
self.test_list_path,
|
174 |
self.test_intrinsic_path,
|
175 |
+
mode="test",
|
176 |
min_overlap_score=self.min_overlap_score_test,
|
177 |
+
pose_dir=self.test_pose_root,
|
178 |
+
)
|
179 |
+
logger.info(f"[rank:{self.rank}]: Test Dataset loaded!")
|
180 |
|
181 |
+
def _setup_dataset(
    self,
    data_root,
    split_npz_root,
    scene_list_path,
    intri_path,
    mode="train",
    min_overlap_score=0.0,
    pose_dir=None,
):
    """Setup train / val / test set.

    Reads scene names from `scene_list_path` (first whitespace-separated
    token of each line), keeps only this rank's share in "train" mode, and
    hands the names to the serial or parallel concat-dataset builder.

    Args:
        data_root: forwarded to the dataset builder.
        split_npz_root: directory holding the per-scene index files.
        scene_list_path: text file with one scene name per line.
        intri_path: forwarded to the dataset builder.
        mode (str): "train" splits scenes across ranks via `get_local_split`;
            any other mode uses every listed scene.
        min_overlap_score (float): forwarded to the dataset builder.
        pose_dir: forwarded to the dataset builder.

    Returns:
        Whatever the selected builder returns.
    """
    with open(scene_list_path, "r") as f:
        # Skip blank lines: `"".split()[0]` would raise IndexError.
        npz_names = [tokens[0] for tokens in map(str.split, f) if tokens]

    if mode == "train":
        local_npz_names = get_local_split(
            npz_names, self.world_size, self.rank, self.seed
        )
    else:
        local_npz_names = npz_names
    logger.info(f"[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.")

    dataset_builder = (
        self._build_concat_dataset_parallel
        if self.parallel_load_data
        else self._build_concat_dataset
    )
    return dataset_builder(
        data_root,
        local_npz_names,
        split_npz_root,
        intri_path,
        mode=mode,
        min_overlap_score=min_overlap_score,
        pose_dir=pose_dir,
    )
|
217 |
|
218 |
def _build_concat_dataset(
|
219 |
self,
|
|
|
222 |
npz_dir,
|
223 |
intrinsic_path,
|
224 |
mode,
|
225 |
+
min_overlap_score=0.0,
|
226 |
+
pose_dir=None,
|
227 |
):
|
228 |
datasets = []
|
229 |
+
augment_fn = self.augment_fn if mode == "train" else None
|
230 |
+
data_source = (
|
231 |
+
self.trainval_data_source
|
232 |
+
if mode in ["train", "val"]
|
233 |
+
else self.test_data_source
|
234 |
+
)
|
235 |
+
if data_source == "GL3D" and mode == "val":
|
236 |
+
data_source = "MegaDepth"
|
237 |
+
if str(data_source).lower() == "megadepth":
|
238 |
+
npz_names = [f"{n}.npz" for n in npz_names]
|
239 |
+
if str(data_source).lower() == "gl3d":
|
240 |
+
npz_names = [f"{n}.txt" for n in npz_names]
|
241 |
+
# npz_names=npz_names[:8]
|
242 |
+
for npz_name in tqdm(
|
243 |
+
npz_names,
|
244 |
+
desc=f"[rank:{self.rank}] loading {mode} datasets",
|
245 |
+
disable=int(self.rank) != 0,
|
246 |
+
):
|
247 |
# `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
|
248 |
npz_path = osp.join(npz_dir, npz_name)
|
249 |
+
if data_source == "ScanNet":
|
250 |
datasets.append(
|
251 |
+
ScanNetDataset(
|
252 |
+
data_root,
|
253 |
+
npz_path,
|
254 |
+
intrinsic_path,
|
255 |
+
mode=mode,
|
256 |
+
min_overlap_score=min_overlap_score,
|
257 |
+
augment_fn=augment_fn,
|
258 |
+
pose_dir=pose_dir,
|
259 |
+
)
|
260 |
+
)
|
261 |
+
elif data_source == "MegaDepth":
|
262 |
datasets.append(
|
263 |
+
MegaDepthDataset(
|
264 |
+
data_root,
|
265 |
+
npz_path,
|
266 |
+
mode=mode,
|
267 |
+
min_overlap_score=min_overlap_score,
|
268 |
+
img_resize=self.mgdpt_img_resize,
|
269 |
+
df=self.mgdpt_df,
|
270 |
+
img_padding=self.mgdpt_img_pad,
|
271 |
+
depth_padding=self.mgdpt_depth_pad,
|
272 |
+
augment_fn=augment_fn,
|
273 |
+
coarse_scale=self.coarse_scale,
|
274 |
+
)
|
275 |
+
)
|
276 |
else:
|
277 |
raise NotImplementedError()
|
278 |
return ConcatDataset(datasets)
|
279 |
+
|
280 |
def _build_concat_dataset_parallel(
|
281 |
self,
|
282 |
data_root,
|
|
|
284 |
npz_dir,
|
285 |
intrinsic_path,
|
286 |
mode,
|
287 |
+
min_overlap_score=0.0,
|
288 |
pose_dir=None,
|
289 |
):
|
290 |
+
augment_fn = self.augment_fn if mode == "train" else None
|
291 |
+
data_source = (
|
292 |
+
self.trainval_data_source
|
293 |
+
if mode in ["train", "val"]
|
294 |
+
else self.test_data_source
|
295 |
+
)
|
296 |
+
if str(data_source).lower() == "megadepth":
|
297 |
+
npz_names = [f"{n}.npz" for n in npz_names]
|
298 |
+
# npz_names=npz_names[:8]
|
299 |
+
with tqdm_joblib(
|
300 |
+
tqdm(
|
301 |
+
desc=f"[rank:{self.rank}] loading {mode} datasets",
|
302 |
+
total=len(npz_names),
|
303 |
+
disable=int(self.rank) != 0,
|
304 |
+
)
|
305 |
+
):
|
306 |
+
if data_source == "ScanNet":
|
307 |
+
datasets = Parallel(
|
308 |
+
n_jobs=math.floor(
|
309 |
+
len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
|
310 |
+
)
|
311 |
+
)(
|
312 |
+
delayed(
|
313 |
+
lambda x: _build_dataset(
|
314 |
+
ScanNetDataset,
|
315 |
+
data_root,
|
316 |
+
osp.join(npz_dir, x),
|
317 |
+
intrinsic_path,
|
318 |
+
mode=mode,
|
319 |
+
min_overlap_score=min_overlap_score,
|
320 |
+
augment_fn=augment_fn,
|
321 |
+
pose_dir=pose_dir,
|
322 |
+
)
|
323 |
+
)(name)
|
324 |
+
for name in npz_names
|
325 |
+
)
|
326 |
+
elif data_source == "MegaDepth":
|
327 |
# TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
|
328 |
raise NotImplementedError()
|
329 |
+
datasets = Parallel(
|
330 |
+
n_jobs=math.floor(
|
331 |
+
len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
|
332 |
+
)
|
333 |
+
)(
|
334 |
+
delayed(
|
335 |
+
lambda x: _build_dataset(
|
336 |
+
MegaDepthDataset,
|
337 |
+
data_root,
|
338 |
+
osp.join(npz_dir, x),
|
339 |
+
mode=mode,
|
340 |
+
min_overlap_score=min_overlap_score,
|
341 |
+
img_resize=self.mgdpt_img_resize,
|
342 |
+
df=self.mgdpt_df,
|
343 |
+
img_padding=self.mgdpt_img_pad,
|
344 |
+
depth_padding=self.mgdpt_depth_pad,
|
345 |
+
augment_fn=augment_fn,
|
346 |
+
coarse_scale=self.coarse_scale,
|
347 |
+
)
|
348 |
+
)(name)
|
349 |
+
for name in npz_names
|
350 |
+
)
|
351 |
else:
|
352 |
+
raise ValueError(f"Unknown dataset: {data_source}")
|
353 |
return ConcatDataset(datasets)
|
354 |
|
355 |
def train_dataloader(self):
    """Create the training DataLoader for ScanNet / MegaDepth.

    Only the ``"scene_balance"`` sampler is accepted; any other value of
    ``self.data_sampler`` trips the assertion below.

    Returns:
        DataLoader over ``self.train_dataset`` configured with
        ``self.train_loader_params``.
    """
    assert self.data_sampler in ["scene_balance"]
    logger.info(
        f"[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!)."
    )

    # Start without a sampler; the supported mode replaces it right below.
    train_sampler = None
    if self.data_sampler == "scene_balance":
        train_sampler = RandomConcatSampler(
            self.train_dataset,
            self.n_samples_per_subset,
            self.subset_replacement,
            self.shuffle,
            self.repeat,
            self.seed,
        )

    return DataLoader(
        self.train_dataset, sampler=train_sampler, **self.train_loader_params
    )
|
376 |
+
|
377 |
def val_dataloader(self):
    """Create the validation DataLoader(s) for ScanNet / MegaDepth.

    Returns:
        A single DataLoader when ``self.val_dataset`` is not a sequence,
        otherwise a list with one DataLoader per contained dataset. Every
        loader uses a non-shuffling ``DistributedSampler``.
    """
    logger.info(
        f"[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init."
    )

    # Several validation sets: one loader (and one sampler) per dataset.
    if isinstance(self.val_dataset, abc.Sequence):
        return [
            DataLoader(
                val_set,
                sampler=DistributedSampler(val_set, shuffle=False),
                **self.val_loader_params,
            )
            for val_set in self.val_dataset
        ]

    # Single validation set.
    single_sampler = DistributedSampler(self.val_dataset, shuffle=False)
    return DataLoader(
        self.val_dataset, sampler=single_sampler, **self.val_loader_params
    )
|
395 |
|
396 |
def test_dataloader(self, *args, **kwargs):
    """Create the test DataLoader with a non-shuffling DistributedSampler.

    Extra positional and keyword arguments are accepted but not used.
    """
    logger.info(
        f"[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init."
    )
    test_sampler = DistributedSampler(self.test_dataset, shuffle=False)
    test_loader = DataLoader(
        self.test_dataset, sampler=test_sampler, **self.test_loader_params
    )
    return test_loader
|
402 |
|
third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
from collections import defaultdict
|
3 |
import pprint
|
4 |
from loguru import logger
|
@@ -10,15 +9,19 @@ import pytorch_lightning as pl
|
|
10 |
from matplotlib import pyplot as plt
|
11 |
|
12 |
from src.ASpanFormer.aspanformer import ASpanFormer
|
13 |
-
from src.ASpanFormer.utils.supervision import
|
|
|
|
|
|
|
14 |
from src.losses.aspan_loss import ASpanLoss
|
15 |
from src.optimizers import build_optimizer, build_scheduler
|
16 |
from src.utils.metrics import (
|
17 |
-
compute_symmetrical_epipolar_errors,
|
|
|
18 |
compute_pose_errors,
|
19 |
-
aggregate_metrics
|
20 |
)
|
21 |
-
from src.utils.plotting import make_matching_figures,make_matching_figures_offset
|
22 |
from src.utils.comm import gather, all_gather
|
23 |
from src.utils.misc import lower_config, flattenList
|
24 |
from src.utils.profiler import PassThroughProfiler
|
@@ -34,200 +37,288 @@ class PL_ASpanFormer(pl.LightningModule):
|
|
34 |
# Misc
|
35 |
self.config = config # full config
|
36 |
_config = lower_config(self.config)
|
37 |
-
self.loftr_cfg = lower_config(_config[
|
38 |
self.profiler = profiler or PassThroughProfiler()
|
39 |
-
self.n_vals_plot = max(
|
|
|
|
|
40 |
|
41 |
# Matcher: LoFTR
|
42 |
-
self.matcher = ASpanFormer(config=_config[
|
43 |
self.loss = ASpanLoss(_config)
|
44 |
|
45 |
# Pretrained weights
|
46 |
print(pretrained_ckpt)
|
47 |
if pretrained_ckpt:
|
48 |
-
print(
|
49 |
-
state_dict = torch.load(pretrained_ckpt, map_location=
|
50 |
-
msg=self.matcher.load_state_dict(state_dict, strict=False)
|
51 |
print(msg)
|
52 |
-
logger.info(f"Load
|
53 |
-
|
54 |
# Testing
|
55 |
self.dump_dir = dump_dir
|
56 |
-
|
57 |
def configure_optimizers(self):
|
58 |
# FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
|
59 |
optimizer = build_optimizer(self, self.config)
|
60 |
scheduler = build_scheduler(self.config, optimizer)
|
61 |
return [optimizer], [scheduler]
|
62 |
-
|
63 |
def optimizer_step(
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
# learning rate warm up
|
67 |
warmup_step = self.config.TRAINER.WARMUP_STEP
|
68 |
if self.trainer.global_step < warmup_step:
|
69 |
-
if self.config.TRAINER.WARMUP_TYPE ==
|
70 |
base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
|
71 |
-
lr = base_lr +
|
72 |
-
|
73 |
-
|
74 |
for pg in optimizer.param_groups:
|
75 |
-
pg[
|
76 |
-
elif self.config.TRAINER.WARMUP_TYPE ==
|
77 |
pass
|
78 |
else:
|
79 |
-
raise ValueError(
|
|
|
|
|
80 |
|
81 |
# update params
|
82 |
optimizer.step(closure=optimizer_closure)
|
83 |
optimizer.zero_grad()
|
84 |
-
|
85 |
def _trainval_inference(self, batch):
|
86 |
with self.profiler.profile("Compute coarse supervision"):
|
87 |
-
compute_supervision_coarse(batch, self.config)
|
88 |
-
|
89 |
with self.profiler.profile("LoFTR"):
|
90 |
-
self.matcher(batch)
|
91 |
-
|
92 |
with self.profiler.profile("Compute fine supervision"):
|
93 |
-
compute_supervision_fine(batch, self.config)
|
94 |
-
|
95 |
with self.profiler.profile("Compute losses"):
|
96 |
-
self.loss(batch)
|
97 |
-
|
98 |
def _compute_metrics(self, batch):
|
99 |
with self.profiler.profile("Copmute metrics"):
|
100 |
-
compute_symmetrical_epipolar_errors(
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
rel_pair_names = list(zip(*batch[
|
105 |
-
bs = batch[
|
106 |
metrics = {
|
107 |
# to filter duplicate pairs caused by DistributedSampler
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
return ret_dict, rel_pair_names
|
116 |
-
|
117 |
-
|
118 |
def training_step(self, batch, batch_idx):
|
119 |
self._trainval_inference(batch)
|
120 |
-
|
121 |
# logging
|
122 |
-
if
|
|
|
|
|
|
|
123 |
# scalars
|
124 |
-
for k, v in batch[
|
125 |
-
if not k.startswith(
|
126 |
-
self.logger.experiment.add_scalar(f
|
127 |
-
|
128 |
-
#log offset_loss and conf for each layer and level
|
129 |
-
layer_num=self.loftr_cfg[
|
130 |
for layer_index in range(layer_num):
|
131 |
-
log_title=
|
132 |
-
self.logger.experiment.add_scalar(
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
# net-params
|
136 |
-
if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE ==
|
137 |
self.logger.experiment.add_scalar(
|
138 |
-
f
|
|
|
|
|
|
|
139 |
|
140 |
# figures
|
141 |
if self.config.TRAINER.ENABLE_PLOTTING:
|
142 |
-
compute_symmetrical_epipolar_errors(
|
143 |
-
|
|
|
|
|
|
|
|
|
144 |
for k, v in figures.items():
|
145 |
-
self.logger.experiment.add_figure(
|
|
|
|
|
146 |
|
147 |
-
#plot offset
|
148 |
-
if self.global_step%200==0:
|
149 |
compute_symmetrical_epipolar_errors_offset_bidirectional(batch)
|
150 |
-
figures_left = make_matching_figures_offset(
|
151 |
-
|
|
|
|
|
|
|
|
|
152 |
for k, v in figures_left.items():
|
153 |
-
self.logger.experiment.add_figure(
|
154 |
-
|
|
|
|
|
|
|
|
|
155 |
for k, v in figures_right.items():
|
156 |
-
self.logger.experiment.add_figure(
|
157 |
-
|
158 |
-
|
|
|
|
|
159 |
|
160 |
def training_epoch_end(self, outputs):
|
161 |
-
avg_loss = torch.stack([x[
|
162 |
if self.trainer.global_rank == 0:
|
163 |
self.logger.experiment.add_scalar(
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
def validation_step(self, batch, batch_idx):
|
168 |
self._trainval_inference(batch)
|
169 |
-
|
170 |
-
ret_dict, _ = self._compute_metrics(
|
171 |
-
|
|
|
|
|
172 |
val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
|
173 |
figures = {self.config.TRAINER.PLOT_MODE: []}
|
174 |
figures_offset = {self.config.TRAINER.PLOT_MODE: []}
|
175 |
if batch_idx % val_plot_interval == 0:
|
176 |
-
figures = make_matching_figures(
|
177 |
-
|
|
|
|
|
|
|
|
|
178 |
return {
|
179 |
**ret_dict,
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
}
|
184 |
-
|
185 |
def validation_epoch_end(self, outputs):
|
186 |
# handle multiple validation sets
|
187 |
-
multi_outputs =
|
|
|
|
|
188 |
multi_val_metrics = defaultdict(list)
|
189 |
-
|
190 |
for valset_idx, outputs in enumerate(multi_outputs):
|
191 |
# since pl performs sanity_check at the very begining of the training
|
192 |
cur_epoch = self.trainer.current_epoch
|
193 |
-
if
|
|
|
|
|
|
|
194 |
cur_epoch = -1
|
195 |
|
196 |
# 1. loss_scalars: dict of list, on cpu
|
197 |
-
_loss_scalars = [o[
|
198 |
-
loss_scalars = {
|
|
|
|
|
|
|
199 |
|
200 |
# 2. val metrics: dict of list, numpy
|
201 |
-
_metrics = [o[
|
202 |
-
metrics = {
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
205 |
for thr in [5, 10, 20]:
|
206 |
-
multi_val_metrics[f
|
207 |
-
|
208 |
# 3. figures
|
209 |
-
_figures = [o[
|
210 |
-
figures = {
|
|
|
|
|
|
|
211 |
|
212 |
# tensorboard records only on rank 0
|
213 |
if self.trainer.global_rank == 0:
|
214 |
for k, v in loss_scalars.items():
|
215 |
mean_v = torch.stack(v).mean()
|
216 |
-
self.logger.experiment.add_scalar(
|
|
|
|
|
217 |
|
218 |
for k, v in val_metrics_4tb.items():
|
219 |
-
self.logger.experiment.add_scalar(
|
220 |
-
|
|
|
|
|
221 |
for k, v in figures.items():
|
222 |
if self.trainer.global_rank == 0:
|
223 |
for plot_idx, fig in enumerate(v):
|
224 |
self.logger.experiment.add_figure(
|
225 |
-
f
|
226 |
-
|
|
|
|
|
|
|
|
|
227 |
|
228 |
for thr in [5, 10, 20]:
|
229 |
# log on all ranks for ModelCheckpoint callback to work properly
|
230 |
-
self.log(
|
|
|
|
|
231 |
|
232 |
def test_step(self, batch, batch_idx):
|
233 |
with self.profiler.profile("LoFTR"):
|
@@ -238,39 +329,46 @@ class PL_ASpanFormer(pl.LightningModule):
|
|
238 |
with self.profiler.profile("dump_results"):
|
239 |
if self.dump_dir is not None:
|
240 |
# dump results for further analysis
|
241 |
-
keys_to_save = {
|
242 |
-
pair_names = list(zip(*batch[
|
243 |
-
bs = batch[
|
244 |
dumps = []
|
245 |
for b_id in range(bs):
|
246 |
item = {}
|
247 |
-
mask = batch[
|
248 |
-
item[
|
249 |
-
item[
|
250 |
for key in keys_to_save:
|
251 |
item[key] = batch[key][mask].cpu().numpy()
|
252 |
-
for key in [
|
253 |
item[key] = batch[key][b_id]
|
254 |
dumps.append(item)
|
255 |
-
ret_dict[
|
256 |
|
257 |
return ret_dict
|
258 |
|
259 |
def test_epoch_end(self, outputs):
|
260 |
# metrics: dict of list, numpy
|
261 |
-
_metrics = [o[
|
262 |
-
metrics = {
|
|
|
|
|
|
|
263 |
|
264 |
# [{key: [{...}, *#bs]}, *#batch]
|
265 |
if self.dump_dir is not None:
|
266 |
Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
|
267 |
-
_dumps = flattenList([o[
|
268 |
dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
|
269 |
-
logger.info(
|
|
|
|
|
270 |
|
271 |
if self.trainer.global_rank == 0:
|
272 |
print(self.profiler.summary())
|
273 |
-
val_metrics_4tb = aggregate_metrics(
|
274 |
-
|
|
|
|
|
275 |
if self.dump_dir is not None:
|
276 |
-
np.save(Path(self.dump_dir) /
|
|
|
|
|
1 |
from collections import defaultdict
|
2 |
import pprint
|
3 |
from loguru import logger
|
|
|
9 |
from matplotlib import pyplot as plt
|
10 |
|
11 |
from src.ASpanFormer.aspanformer import ASpanFormer
|
12 |
+
from src.ASpanFormer.utils.supervision import (
|
13 |
+
compute_supervision_coarse,
|
14 |
+
compute_supervision_fine,
|
15 |
+
)
|
16 |
from src.losses.aspan_loss import ASpanLoss
|
17 |
from src.optimizers import build_optimizer, build_scheduler
|
18 |
from src.utils.metrics import (
|
19 |
+
compute_symmetrical_epipolar_errors,
|
20 |
+
compute_symmetrical_epipolar_errors_offset_bidirectional,
|
21 |
compute_pose_errors,
|
22 |
+
aggregate_metrics,
|
23 |
)
|
24 |
+
from src.utils.plotting import make_matching_figures, make_matching_figures_offset
|
25 |
from src.utils.comm import gather, all_gather
|
26 |
from src.utils.misc import lower_config, flattenList
|
27 |
from src.utils.profiler import PassThroughProfiler
|
|
|
37 |
# Misc
|
38 |
self.config = config # full config
|
39 |
_config = lower_config(self.config)
|
40 |
+
self.loftr_cfg = lower_config(_config["aspan"])
|
41 |
self.profiler = profiler or PassThroughProfiler()
|
42 |
+
self.n_vals_plot = max(
|
43 |
+
config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1
|
44 |
+
)
|
45 |
|
46 |
# Matcher: LoFTR
|
47 |
+
self.matcher = ASpanFormer(config=_config["aspan"])
|
48 |
self.loss = ASpanLoss(_config)
|
49 |
|
50 |
# Pretrained weights
|
51 |
print(pretrained_ckpt)
|
52 |
if pretrained_ckpt:
|
53 |
+
print("load")
|
54 |
+
state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"]
|
55 |
+
msg = self.matcher.load_state_dict(state_dict, strict=False)
|
56 |
print(msg)
|
57 |
+
logger.info(f"Load '{pretrained_ckpt}' as pretrained checkpoint")
|
58 |
+
|
59 |
# Testing
|
60 |
self.dump_dir = dump_dir
|
61 |
+
|
62 |
def configure_optimizers(self):
    """Build the optimizer and its scheduler from ``self.config``.

    Returns:
        tuple: ``([optimizer], [scheduler])``.
    """
    # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
    optim = build_optimizer(self, self.config)
    sched = build_scheduler(self.config, optim)
    return [optim], [sched]
|
67 |
+
|
68 |
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu,
    using_native_amp,
    using_lbfgs,
):
    """Apply learning-rate warm-up, then step and reset the optimizer.

    While ``self.trainer.global_step`` is below ``TRAINER.WARMUP_STEP``:

    * ``"linear"``: every param group's ``lr`` is set to a value interpolated
      from ``WARMUP_RATIO * TRUE_LR`` towards ``TRUE_LR``.
    * ``"constant"``: learning rates are left untouched.
    * anything else raises ``ValueError`` before the optimizer is stepped.

    Raises:
        ValueError: If ``TRAINER.WARMUP_TYPE`` is unknown during warm-up.
    """
    trainer_cfg = self.config.TRAINER
    current_step = self.trainer.global_step

    # learning rate warm up
    if current_step < trainer_cfg.WARMUP_STEP:
        warmup_type = trainer_cfg.WARMUP_TYPE
        if warmup_type == "linear":
            start_lr = trainer_cfg.WARMUP_RATIO * trainer_cfg.TRUE_LR
            progress = current_step / trainer_cfg.WARMUP_STEP
            warm_lr = start_lr + progress * abs(trainer_cfg.TRUE_LR - start_lr)
            for group in optimizer.param_groups:
                group["lr"] = warm_lr
        elif warmup_type != "constant":
            raise ValueError(f"Unknown lr warm-up strategy: {warmup_type}")

    # update params
    optimizer.step(closure=optimizer_closure)
    optimizer.zero_grad()
|
99 |
+
|
100 |
def _trainval_inference(self, batch):
    """Run the shared train/val pipeline on ``batch``, one profiled stage at a time.

    The stages run in this order: coarse supervision, matcher forward pass,
    fine supervision, loss computation. Each stage receives ``batch`` and its
    return value is ignored here.
    """
    stages = (
        (
            "Compute coarse supervision",
            lambda: compute_supervision_coarse(batch, self.config),
        ),
        ("LoFTR", lambda: self.matcher(batch)),
        (
            "Compute fine supervision",
            lambda: compute_supervision_fine(batch, self.config),
        ),
        ("Compute losses", lambda: self.loss(batch)),
    )
    for label, run_stage in stages:
        with self.profiler.profile(label):
            run_stage()
|
112 |
+
|
113 |
def _compute_metrics(self, batch):
    """Compute evaluation metrics for ``batch`` and regroup them per sample.

    Returns:
        tuple: ``({"metrics": metrics}, rel_pair_names)`` where ``metrics``
        holds per-sample identifiers, epipolar errors (matches and left-side
        offsets) and the ``R_errs`` / ``t_errs`` / ``inliers`` entries of
        ``batch``.
    """
    with self.profiler.profile("Copmute metrics"):
        # epi_errs for each match
        compute_symmetrical_epipolar_errors(batch)
        # epi_errs for offset match
        compute_symmetrical_epipolar_errors_offset_bidirectional(batch)
        # R_errs, t_errs, pose_errs for each pair
        compute_pose_errors(batch, self.config)

        rel_pair_names = list(zip(*batch["pair_names"]))
        batch_size = batch["image0"].size(0)

        identifiers = []
        match_errs = []
        offset_errs = []
        for sample in range(batch_size):
            identifiers.append("#".join(rel_pair_names[sample]))
            match_errs.append(
                batch["epi_errs"][batch["m_bids"] == sample].cpu().numpy()
            )
            # only consider left side
            offset_errs.append(
                batch["epi_errs_offset_left"][batch["offset_bids_left"] == sample]
                .cpu()
                .numpy()
            )

        metrics = {
            # to filter duplicate pairs caused by DistributedSampler
            "identifiers": identifiers,
            "epi_errs": match_errs,
            "epi_errs_offset": offset_errs,
            "R_errs": batch["R_errs"],
            "t_errs": batch["t_errs"],
            "inliers": batch["inliers"],
        }
    return {"metrics": metrics}, rel_pair_names
|
146 |
+
|
|
|
147 |
def training_step(self, batch, batch_idx):
    """Run one training iteration and, on rank 0, log scalars and figures.

    Args:
        batch: Mutable batch dict. After ``self._trainval_inference`` it is
            read for ``batch["loss"]`` and ``batch["loss_scalars"]``.
        batch_idx: Index of the batch (not used in this method).

    Returns:
        dict: ``{"loss": batch["loss"]}``.
    """
    self._trainval_inference(batch)

    step = self.global_step
    # logging: rank 0 only, every `log_every_n_steps` steps
    if self.trainer.global_rank == 0 and step % self.trainer.log_every_n_steps == 0:
        board = self.logger.experiment
        loss_scalars = batch["loss_scalars"]

        # scalars (per-layer flow / conf entries are logged separately below)
        for k, v in loss_scalars.items():
            if not k.startswith(("loss_flow", "conf_")):
                board.add_scalar(f"train/{k}", v, step)

        # log offset_loss and conf for each layer and level
        layer_num = self.loftr_cfg["coarse"]["layer_num"]
        for layer_index in range(layer_num):
            log_title = "layer_" + str(layer_index)
            board.add_scalar(
                log_title + "/offset_loss",
                loss_scalars["loss_flow_" + str(layer_index)],
                step,
            )
            board.add_scalar(
                log_title + "/conf_",
                loss_scalars["conf_" + str(layer_index)],
                step,
            )

        # net-params
        if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == "sinkhorn":
            board.add_scalar(
                "skh_bin_score",
                self.matcher.coarse_matching.bin_score.clone().detach().cpu().data,
                step,
            )

        # figures
        if self.config.TRAINER.ENABLE_PLOTTING:
            # compute epi_errs for each match
            compute_symmetrical_epipolar_errors(batch)
            figures = make_matching_figures(
                batch, self.config, self.config.TRAINER.PLOT_MODE
            )
            for k, v in figures.items():
                board.add_figure(f"train_match/{k}", v, step)

        # plot offset
        # NOTE(review): placed inside the rank-0 logging guard, next to the
        # ENABLE_PLOTTING branch -- confirm this nesting against the upstream file.
        if step % 200 == 0:
            compute_symmetrical_epipolar_errors_offset_bidirectional(batch)
            plot_mode = self.config.TRAINER.PLOT_MODE
            # Each side is rendered exactly once; the previous version built the
            # right-side figures a second time and discarded the result.
            figures_left = make_matching_figures_offset(
                batch, self.config, plot_mode, side="_left"
            )
            figures_right = make_matching_figures_offset(
                batch, self.config, plot_mode, side="_right"
            )
            for k, v in figures_left.items():
                board.add_figure(f"train_offset/{k}" + "_left", v, step)
            for k, v in figures_right.items():
                board.add_figure(f"train_offset/{k}" + "_right", v, step)

    return {"loss": batch["loss"]}
|
218 |
|
219 |
def training_epoch_end(self, outputs):
    """Log the mean of the per-step ``"loss"`` values of the finished epoch.

    The mean is computed on every rank; only rank 0 writes it to the
    experiment logger, tagged with ``self.current_epoch``.
    """
    step_losses = [step_out["loss"] for step_out in outputs]
    mean_loss = torch.stack(step_losses).mean()
    if self.trainer.global_rank != 0:
        return
    self.logger.experiment.add_scalar(
        "train/avg_loss_on_epoch", mean_loss, global_step=self.current_epoch
    )
|
225 |
+
|
226 |
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch and collect its metrics and figures.

    Figures are rendered only for batches whose index is a multiple of the
    plotting interval; every other batch contributes empty figure lists keyed
    by ``TRAINER.PLOT_MODE``.

    Returns:
        dict: the entries returned by ``self._compute_metrics`` plus
        ``"loss_scalars"``, ``"figures"`` and ``"figures_offset_left"``.
    """
    self._trainval_inference(batch)

    # this call also computes the epi_errors
    metrics_out, _ = self._compute_metrics(batch)

    plot_mode = self.config.TRAINER.PLOT_MODE
    plot_every = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
    if batch_idx % plot_every == 0:
        match_figs = make_matching_figures(batch, self.config, mode=plot_mode)
        offset_figs = make_matching_figures_offset(
            batch, self.config, plot_mode, "_left"
        )
    else:
        match_figs = {plot_mode: []}
        offset_figs = {plot_mode: []}

    step_out = dict(metrics_out)
    step_out.update(
        {
            "loss_scalars": batch["loss_scalars"],
            "figures": match_figs,
            "figures_offset_left": offset_figs,
        }
    )
    return step_out
|
249 |
+
|
250 |
def validation_epoch_end(self, outputs):
    """Aggregate validation outputs, log them on rank 0 and publish AUC values.

    ``outputs`` is either the list of step outputs of a single validation set
    or a list/tuple of such lists (one per validation set). For every set the
    loss scalars and metrics are gathered across ranks and aggregated; rank 0
    writes scalars and figures to the experiment logger. Finally the mean
    ``auc@5/10/20`` over all validation sets is logged through ``self.log``
    on every rank.
    """
    # handle multiple validation sets
    per_valset_outputs = outputs if isinstance(outputs[0], (list, tuple)) else [outputs]
    auc_by_threshold = defaultdict(list)

    for valset_idx, valset_outputs in enumerate(per_valset_outputs):
        # since pl performs sanity_check at the very beginning of the training
        cur_epoch = self.trainer.current_epoch
        if (
            not self.trainer.resume_from_checkpoint
            and self.trainer.running_sanity_check
        ):
            cur_epoch = -1

        # 1. loss_scalars: dict of list, on cpu
        step_scalars = [out["loss_scalars"] for out in valset_outputs]
        loss_scalars = {}
        for name in step_scalars[0]:
            loss_scalars[name] = flattenList(
                all_gather([scalars[name] for scalars in step_scalars])
            )

        # 2. val metrics: dict of list, numpy
        step_metrics = [out["metrics"] for out in valset_outputs]
        metrics = {}
        for name in step_metrics[0]:
            metrics[name] = flattenList(
                all_gather(flattenList([m[name] for m in step_metrics]))
            )
        # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
        val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
        for thr in [5, 10, 20]:
            auc_by_threshold[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"])

        # 3. figures
        step_figures = [out["figures"] for out in valset_outputs]
        figures = {}
        for name in step_figures[0]:
            figures[name] = flattenList(
                gather(flattenList([f[name] for f in step_figures]))
            )

        # tensorboard records only on rank 0
        if self.trainer.global_rank == 0:
            for name, values in loss_scalars.items():
                mean_value = torch.stack(values).mean()
                self.logger.experiment.add_scalar(
                    f"val_{valset_idx}/avg_{name}", mean_value, global_step=cur_epoch
                )

            for name, value in val_metrics_4tb.items():
                self.logger.experiment.add_scalar(
                    f"metrics_{valset_idx}/{name}", value, global_step=cur_epoch
                )

            for name, figs in figures.items():
                for plot_idx, fig in enumerate(figs):
                    self.logger.experiment.add_figure(
                        f"val_match_{valset_idx}/{name}/pair-{plot_idx}",
                        fig,
                        cur_epoch,
                        close=True,
                    )
        plt.close("all")

    for thr in [5, 10, 20]:
        # log on all ranks for ModelCheckpoint callback to work properly
        mean_auc = torch.tensor(np.mean(auc_by_threshold[f"auc@{thr}"]))
        self.log(f"auc@{thr}", mean_auc)  # ckpt monitors on this
|
322 |
|
323 |
def test_step(self, batch, batch_idx):
|
324 |
with self.profiler.profile("LoFTR"):
|
|
|
329 |
with self.profiler.profile("dump_results"):
|
330 |
if self.dump_dir is not None:
|
331 |
# dump results for further analysis
|
332 |
+
keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"}
|
333 |
+
pair_names = list(zip(*batch["pair_names"]))
|
334 |
+
bs = batch["image0"].shape[0]
|
335 |
dumps = []
|
336 |
for b_id in range(bs):
|
337 |
item = {}
|
338 |
+
mask = batch["m_bids"] == b_id
|
339 |
+
item["pair_names"] = pair_names[b_id]
|
340 |
+
item["identifier"] = "#".join(rel_pair_names[b_id])
|
341 |
for key in keys_to_save:
|
342 |
item[key] = batch[key][mask].cpu().numpy()
|
343 |
+
for key in ["R_errs", "t_errs", "inliers"]:
|
344 |
item[key] = batch[key][b_id]
|
345 |
dumps.append(item)
|
346 |
+
ret_dict["dumps"] = dumps
|
347 |
|
348 |
return ret_dict
|
349 |
|
350 |
def test_epoch_end(self, outputs):
    """Gather test metrics (and optional dumps), then report them on rank 0.

    When ``self.dump_dir`` is set, the directory is created, per-sample dumps
    are gathered, and rank 0 saves them to ``<dump_dir>/LoFTR_pred_eval`` via
    ``np.save``. Rank 0 also prints the profiler summary and logs the
    aggregated metrics.
    """
    # metrics: dict of list, numpy
    step_metrics = [out["metrics"] for out in outputs]
    metrics = {}
    for name in step_metrics[0]:
        metrics[name] = flattenList(
            gather(flattenList([m[name] for m in step_metrics]))
        )

    should_dump = self.dump_dir is not None

    # [{key: [{...}, *#bs]}, *#batch]
    if should_dump:
        Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
        local_dumps = flattenList(
            [out["dumps"] for out in outputs]
        )  # [{...}, #bs*#batch]
        dumps = flattenList(gather(local_dumps))  # [{...}, #proc*#bs*#batch]
        logger.info(
            f"Prediction and evaluation results will be saved to: {self.dump_dir}"
        )

    if self.trainer.global_rank != 0:
        return

    print(self.profiler.summary())
    val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
    logger.info("\n" + pprint.pformat(val_metrics_4tb))
    if should_dump:
        np.save(Path(self.dump_dir) / "LoFTR_pred_eval", dumps)
|
third_party/ASpanFormer/src/losses/aspan_loss.py
CHANGED
@@ -3,48 +3,55 @@ from loguru import logger
|
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
|
|
|
6 |
class ASpanLoss(nn.Module):
|
7 |
def __init__(self, config):
|
8 |
super().__init__()
|
9 |
self.config = config # config under the global namespace
|
10 |
-
self.loss_config = config[
|
11 |
-
self.match_type = self.config[
|
12 |
-
self.sparse_spvs = self.config[
|
13 |
-
self.flow_weight=self.config[
|
14 |
|
15 |
# coarse-level
|
16 |
-
self.correct_thr = self.loss_config[
|
17 |
-
self.c_pos_w = self.loss_config[
|
18 |
-
self.c_neg_w = self.loss_config[
|
19 |
# fine-level
|
20 |
-
self.fine_type = self.loss_config[
|
21 |
-
|
22 |
-
def compute_flow_loss(self,coarse_corr_gt,flow_list,h0,w0,h1,w1):
|
23 |
-
#coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]]
|
24 |
-
#flow_list: [L,B,H,W,4]
|
25 |
-
loss1=self.flow_loss_worker(
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
return total_loss
|
29 |
|
30 |
-
def flow_loss_worker(self,flow,batch_indicies,self_indicies,cross_indicies,w):
|
31 |
-
bs,layer_num=flow.shape[1],flow.shape[0]
|
32 |
-
flow=flow.view(layer_num,bs
|
33 |
-
gt_flow=torch.stack([cross_indicies%w,cross_indicies//w],dim=1)
|
34 |
|
35 |
-
total_loss_list=[]
|
36 |
for layer_index in range(layer_num):
|
37 |
-
cur_flow_list=flow[layer_index]
|
38 |
-
spv_flow=cur_flow_list[batch_indicies,self_indicies][
|
39 |
-
spv_conf=cur_flow_list[batch_indicies,self_indicies][
|
40 |
-
|
41 |
-
|
|
|
|
|
42 |
total_loss_list.append(total_loss.mean())
|
43 |
-
total_loss=torch.stack(total_loss_list,dim=-1)*self.flow_weight
|
44 |
return total_loss
|
45 |
-
|
46 |
def compute_coarse_loss(self, conf, conf_gt, weight=None):
|
47 |
-
"""
|
48 |
Args:
|
49 |
conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
|
50 |
conf_gt (torch.Tensor): (N, HW0, HW1)
|
@@ -56,38 +63,44 @@ class ASpanLoss(nn.Module):
|
|
56 |
if not pos_mask.any(): # assign a wrong gt
|
57 |
pos_mask[0, 0, 0] = True
|
58 |
if weight is not None:
|
59 |
-
weight[0, 0, 0] = 0.
|
60 |
-
c_pos_w = 0.
|
61 |
if not neg_mask.any():
|
62 |
neg_mask[0, 0, 0] = True
|
63 |
if weight is not None:
|
64 |
-
weight[0, 0, 0] = 0.
|
65 |
-
c_neg_w = 0.
|
66 |
-
|
67 |
-
if self.loss_config[
|
68 |
-
assert
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
72 |
if weight is not None:
|
73 |
loss_pos = loss_pos * weight[pos_mask]
|
74 |
loss_neg = loss_neg * weight[neg_mask]
|
75 |
return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
|
76 |
-
elif self.loss_config[
|
77 |
-
conf = torch.clamp(conf, 1e-6, 1-1e-6)
|
78 |
-
alpha = self.loss_config[
|
79 |
-
gamma = self.loss_config[
|
80 |
-
|
81 |
if self.sparse_spvs:
|
82 |
-
pos_conf =
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
# calculate losses for negative samples
|
87 |
-
if self.match_type ==
|
88 |
neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
|
89 |
-
neg_conf = torch.cat(
|
90 |
-
|
|
|
|
|
91 |
else:
|
92 |
# These is no dustbin for dual_softmax, so we left unmatchable patches without supervision.
|
93 |
# we could also add 'pseudo negtive-samples'
|
@@ -97,32 +110,46 @@ class ASpanLoss(nn.Module):
|
|
97 |
# Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out,
|
98 |
# but only through manually setting corresponding regions in sim_matrix to '-inf'.
|
99 |
loss_pos = loss_pos * weight[pos_mask]
|
100 |
-
if self.match_type ==
|
101 |
neg_w0 = (weight.sum(-1) != 0)[neg0]
|
102 |
neg_w1 = (weight.sum(1) != 0)[neg1]
|
103 |
neg_mask = torch.cat([neg_w0, neg_w1], 0)
|
104 |
loss_neg = loss_neg[neg_mask]
|
105 |
-
|
106 |
-
loss =
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
return loss
|
110 |
# positive and negative elements occupy similar propotions. => more balanced loss weights needed
|
111 |
else: # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.)
|
112 |
-
loss_pos =
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
if weight is not None:
|
115 |
loss_pos = loss_pos * weight[pos_mask]
|
116 |
loss_neg = loss_neg * weight[neg_mask]
|
117 |
return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
|
118 |
# each negative element occupy a smaller propotion than positive elements. => higher negative loss weight needed
|
119 |
else:
|
120 |
-
raise ValueError(
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
def compute_fine_loss(self, expec_f, expec_f_gt):
|
123 |
-
if self.fine_type ==
|
124 |
return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
|
125 |
-
elif self.fine_type ==
|
126 |
return self._compute_fine_loss_l2(expec_f, expec_f_gt)
|
127 |
else:
|
128 |
raise NotImplementedError()
|
@@ -133,9 +160,13 @@ class ASpanLoss(nn.Module):
|
|
133 |
expec_f (torch.Tensor): [M, 2] <x, y>
|
134 |
expec_f_gt (torch.Tensor): [M, 2] <x, y>
|
135 |
"""
|
136 |
-
correct_mask =
|
|
|
|
|
137 |
if correct_mask.sum() == 0:
|
138 |
-
if
|
|
|
|
|
139 |
logger.warning("assign a false supervision to avoid ddp deadlock")
|
140 |
correct_mask[0] = True
|
141 |
else:
|
@@ -150,20 +181,26 @@ class ASpanLoss(nn.Module):
|
|
150 |
expec_f_gt (torch.Tensor): [M, 2] <x, y>
|
151 |
"""
|
152 |
# correct_mask tells you which pair to compute fine-loss
|
153 |
-
correct_mask =
|
|
|
|
|
154 |
|
155 |
# use std as weight that measures uncertainty
|
156 |
std = expec_f[:, 2]
|
157 |
-
inverse_std = 1. / torch.clamp(std, min=1e-10)
|
158 |
-
weight = (
|
|
|
|
|
159 |
|
160 |
# corner case: no correct coarse match found
|
161 |
if not correct_mask.any():
|
162 |
-
if
|
163 |
-
|
|
|
|
|
164 |
logger.warning("assign a false supervision to avoid ddp deadlock")
|
165 |
correct_mask[0] = True
|
166 |
-
weight[0] = 0.
|
167 |
else:
|
168 |
return None
|
169 |
|
@@ -172,12 +209,15 @@ class ASpanLoss(nn.Module):
|
|
172 |
loss = (flow_l2 * weight[correct_mask]).mean()
|
173 |
|
174 |
return loss
|
175 |
-
|
176 |
@torch.no_grad()
|
177 |
def compute_c_weight(self, data):
|
178 |
-
"""
|
179 |
-
if
|
180 |
-
c_weight = (
|
|
|
|
|
|
|
181 |
else:
|
182 |
c_weight = None
|
183 |
return c_weight
|
@@ -196,36 +236,54 @@ class ASpanLoss(nn.Module):
|
|
196 |
|
197 |
# 1. coarse-level loss
|
198 |
loss_c = self.compute_coarse_loss(
|
199 |
-
data[
|
200 |
-
|
201 |
-
data[
|
202 |
-
|
203 |
-
|
|
|
|
|
204 |
loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
|
205 |
|
206 |
# 2. fine-level loss
|
207 |
-
loss_f = self.compute_fine_loss(data[
|
208 |
if loss_f is not None:
|
209 |
-
loss += loss_f * self.loss_config[
|
210 |
-
loss_scalars.update({"loss_f":
|
211 |
else:
|
212 |
assert self.training is False
|
213 |
-
loss_scalars.update({
|
214 |
-
|
215 |
# 3. flow loss
|
216 |
-
coarse_corr=[data[
|
217 |
-
loss_flow = self.compute_flow_loss(
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
for layer_index in range(layer_num):
|
225 |
-
loss_scalars.update(
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
data.update({"loss": loss, "loss_scalars": loss_scalars})
|
|
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
|
6 |
+
|
7 |
class ASpanLoss(nn.Module):
|
8 |
def __init__(self, config):
    """Cache the loss-related settings found under ``config["aspan"]``.

    Args:
        config: Mapping under the global namespace; ``config["aspan"]["loss"]``
            and ``config["aspan"]["match_coarse"]`` are read here.
    """
    super().__init__()
    self.config = config  # config under the global namespace

    aspan_cfg = config["aspan"]
    loss_cfg = aspan_cfg["loss"]
    coarse_cfg = aspan_cfg["match_coarse"]

    self.loss_config = loss_cfg
    self.match_type = coarse_cfg["match_type"]
    self.sparse_spvs = coarse_cfg["sparse_spvs"]
    self.flow_weight = loss_cfg["flow_weight"]

    # coarse-level
    self.correct_thr = loss_cfg["fine_correct_thr"]
    self.c_pos_w = loss_cfg["pos_weight"]
    self.c_neg_w = loss_cfg["neg_weight"]

    # fine-level
    self.fine_type = loss_cfg["fine_type"]
|
22 |
+
|
23 |
+
def compute_flow_loss(self, coarse_corr_gt, flow_list, h0, w0, h1, w1):
|
24 |
+
# coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]]
|
25 |
+
# flow_list: [L,B,H,W,4]
|
26 |
+
loss1 = self.flow_loss_worker(
|
27 |
+
flow_list[0], coarse_corr_gt[0], coarse_corr_gt[1], coarse_corr_gt[2], w1
|
28 |
+
)
|
29 |
+
loss2 = self.flow_loss_worker(
|
30 |
+
flow_list[1], coarse_corr_gt[0], coarse_corr_gt[2], coarse_corr_gt[1], w0
|
31 |
+
)
|
32 |
+
total_loss = (loss1 + loss2) / 2
|
33 |
return total_loss
|
34 |
|
35 |
+
def flow_loss_worker(self, flow, batch_indicies, self_indicies, cross_indicies, w):
|
36 |
+
bs, layer_num = flow.shape[1], flow.shape[0]
|
37 |
+
flow = flow.view(layer_num, bs, -1, 4)
|
38 |
+
gt_flow = torch.stack([cross_indicies % w, cross_indicies // w], dim=1)
|
39 |
|
40 |
+
total_loss_list = []
|
41 |
for layer_index in range(layer_num):
|
42 |
+
cur_flow_list = flow[layer_index]
|
43 |
+
spv_flow = cur_flow_list[batch_indicies, self_indicies][:, :2]
|
44 |
+
spv_conf = cur_flow_list[batch_indicies, self_indicies][
|
45 |
+
:, 2:
|
46 |
+
] # [#coarse,2]
|
47 |
+
l2_flow_dis = (gt_flow - spv_flow) ** 2 # [#coarse,2]
|
48 |
+
total_loss = spv_conf + torch.exp(-spv_conf) * l2_flow_dis # [#coarse,2]
|
49 |
total_loss_list.append(total_loss.mean())
|
50 |
+
total_loss = torch.stack(total_loss_list, dim=-1) * self.flow_weight
|
51 |
return total_loss
|
52 |
+
|
53 |
def compute_coarse_loss(self, conf, conf_gt, weight=None):
|
54 |
+
"""Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
|
55 |
Args:
|
56 |
conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
|
57 |
conf_gt (torch.Tensor): (N, HW0, HW1)
|
|
|
63 |
if not pos_mask.any(): # assign a wrong gt
|
64 |
pos_mask[0, 0, 0] = True
|
65 |
if weight is not None:
|
66 |
+
weight[0, 0, 0] = 0.0
|
67 |
+
c_pos_w = 0.0
|
68 |
if not neg_mask.any():
|
69 |
neg_mask[0, 0, 0] = True
|
70 |
if weight is not None:
|
71 |
+
weight[0, 0, 0] = 0.0
|
72 |
+
c_neg_w = 0.0
|
73 |
+
|
74 |
+
if self.loss_config["coarse_type"] == "cross_entropy":
|
75 |
+
assert (
|
76 |
+
not self.sparse_spvs
|
77 |
+
), "Sparse Supervision for cross-entropy not implemented!"
|
78 |
+
conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
|
79 |
+
loss_pos = -torch.log(conf[pos_mask])
|
80 |
+
loss_neg = -torch.log(1 - conf[neg_mask])
|
81 |
if weight is not None:
|
82 |
loss_pos = loss_pos * weight[pos_mask]
|
83 |
loss_neg = loss_neg * weight[neg_mask]
|
84 |
return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
|
85 |
+
elif self.loss_config["coarse_type"] == "focal":
|
86 |
+
conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
|
87 |
+
alpha = self.loss_config["focal_alpha"]
|
88 |
+
gamma = self.loss_config["focal_gamma"]
|
89 |
+
|
90 |
if self.sparse_spvs:
|
91 |
+
pos_conf = (
|
92 |
+
conf[:, :-1, :-1][pos_mask]
|
93 |
+
if self.match_type == "sinkhorn"
|
94 |
+
else conf[pos_mask]
|
95 |
+
)
|
96 |
+
loss_pos = -alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
|
97 |
# calculate losses for negative samples
|
98 |
+
if self.match_type == "sinkhorn":
|
99 |
neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
|
100 |
+
neg_conf = torch.cat(
|
101 |
+
[conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0
|
102 |
+
)
|
103 |
+
loss_neg = -alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
|
104 |
else:
|
105 |
# These is no dustbin for dual_softmax, so we left unmatchable patches without supervision.
|
106 |
# we could also add 'pseudo negtive-samples'
|
|
|
110 |
# Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out,
|
111 |
# but only through manually setting corresponding regions in sim_matrix to '-inf'.
|
112 |
loss_pos = loss_pos * weight[pos_mask]
|
113 |
+
if self.match_type == "sinkhorn":
|
114 |
neg_w0 = (weight.sum(-1) != 0)[neg0]
|
115 |
neg_w1 = (weight.sum(1) != 0)[neg1]
|
116 |
neg_mask = torch.cat([neg_w0, neg_w1], 0)
|
117 |
loss_neg = loss_neg[neg_mask]
|
118 |
+
|
119 |
+
loss = (
|
120 |
+
c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
|
121 |
+
if self.match_type == "sinkhorn"
|
122 |
+
else c_pos_w * loss_pos.mean()
|
123 |
+
)
|
124 |
return loss
|
125 |
# positive and negative elements occupy similar propotions. => more balanced loss weights needed
|
126 |
else: # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.)
|
127 |
+
loss_pos = (
|
128 |
+
-alpha
|
129 |
+
* torch.pow(1 - conf[pos_mask], gamma)
|
130 |
+
* (conf[pos_mask]).log()
|
131 |
+
)
|
132 |
+
loss_neg = (
|
133 |
+
-alpha
|
134 |
+
* torch.pow(conf[neg_mask], gamma)
|
135 |
+
* (1 - conf[neg_mask]).log()
|
136 |
+
)
|
137 |
if weight is not None:
|
138 |
loss_pos = loss_pos * weight[pos_mask]
|
139 |
loss_neg = loss_neg * weight[neg_mask]
|
140 |
return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
|
141 |
# each negative element occupy a smaller propotion than positive elements. => higher negative loss weight needed
|
142 |
else:
|
143 |
+
raise ValueError(
|
144 |
+
"Unknown coarse loss: {type}".format(
|
145 |
+
type=self.loss_config["coarse_type"]
|
146 |
+
)
|
147 |
+
)
|
148 |
+
|
149 |
def compute_fine_loss(self, expec_f, expec_f_gt):
|
150 |
+
if self.fine_type == "l2_with_std":
|
151 |
return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
|
152 |
+
elif self.fine_type == "l2":
|
153 |
return self._compute_fine_loss_l2(expec_f, expec_f_gt)
|
154 |
else:
|
155 |
raise NotImplementedError()
|
|
|
160 |
expec_f (torch.Tensor): [M, 2] <x, y>
|
161 |
expec_f_gt (torch.Tensor): [M, 2] <x, y>
|
162 |
"""
|
163 |
+
correct_mask = (
|
164 |
+
torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
|
165 |
+
)
|
166 |
if correct_mask.sum() == 0:
|
167 |
+
if (
|
168 |
+
self.training
|
169 |
+
): # this seldomly happen when training, since we pad prediction with gt
|
170 |
logger.warning("assign a false supervision to avoid ddp deadlock")
|
171 |
correct_mask[0] = True
|
172 |
else:
|
|
|
181 |
expec_f_gt (torch.Tensor): [M, 2] <x, y>
|
182 |
"""
|
183 |
# correct_mask tells you which pair to compute fine-loss
|
184 |
+
correct_mask = (
|
185 |
+
torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
|
186 |
+
)
|
187 |
|
188 |
# use std as weight that measures uncertainty
|
189 |
std = expec_f[:, 2]
|
190 |
+
inverse_std = 1.0 / torch.clamp(std, min=1e-10)
|
191 |
+
weight = (
|
192 |
+
inverse_std / torch.mean(inverse_std)
|
193 |
+
).detach() # avoid minizing loss through increase std
|
194 |
|
195 |
# corner case: no correct coarse match found
|
196 |
if not correct_mask.any():
|
197 |
+
if (
|
198 |
+
self.training
|
199 |
+
): # this seldomly happen during training, since we pad prediction with gt
|
200 |
+
# sometimes there is not coarse-level gt at all.
|
201 |
logger.warning("assign a false supervision to avoid ddp deadlock")
|
202 |
correct_mask[0] = True
|
203 |
+
weight[0] = 0.0
|
204 |
else:
|
205 |
return None
|
206 |
|
|
|
209 |
loss = (flow_l2 * weight[correct_mask]).mean()
|
210 |
|
211 |
return loss
|
212 |
+
|
213 |
@torch.no_grad()
|
214 |
def compute_c_weight(self, data):
|
215 |
+
"""compute element-wise weights for computing coarse-level loss."""
|
216 |
+
if "mask0" in data:
|
217 |
+
c_weight = (
|
218 |
+
data["mask0"].flatten(-2)[..., None]
|
219 |
+
* data["mask1"].flatten(-2)[:, None]
|
220 |
+
).float()
|
221 |
else:
|
222 |
c_weight = None
|
223 |
return c_weight
|
|
|
236 |
|
237 |
# 1. coarse-level loss
|
238 |
loss_c = self.compute_coarse_loss(
|
239 |
+
data["conf_matrix_with_bin"]
|
240 |
+
if self.sparse_spvs and self.match_type == "sinkhorn"
|
241 |
+
else data["conf_matrix"],
|
242 |
+
data["conf_matrix_gt"],
|
243 |
+
weight=c_weight,
|
244 |
+
)
|
245 |
+
loss = loss_c * self.loss_config["coarse_weight"]
|
246 |
loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
|
247 |
|
248 |
# 2. fine-level loss
|
249 |
+
loss_f = self.compute_fine_loss(data["expec_f"], data["expec_f_gt"])
|
250 |
if loss_f is not None:
|
251 |
+
loss += loss_f * self.loss_config["fine_weight"]
|
252 |
+
loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
|
253 |
else:
|
254 |
assert self.training is False
|
255 |
+
loss_scalars.update({"loss_f": torch.tensor(1.0)}) # 1 is the upper bound
|
256 |
+
|
257 |
# 3. flow loss
|
258 |
+
coarse_corr = [data["spv_b_ids"], data["spv_i_ids"], data["spv_j_ids"]]
|
259 |
+
loss_flow = self.compute_flow_loss(
|
260 |
+
coarse_corr,
|
261 |
+
data["predict_flow"],
|
262 |
+
data["hw0_c"][0],
|
263 |
+
data["hw0_c"][1],
|
264 |
+
data["hw1_c"][0],
|
265 |
+
data["hw1_c"][1],
|
266 |
+
)
|
267 |
+
loss_flow = loss_flow * self.flow_weight
|
268 |
+
for index, loss_off in enumerate(loss_flow):
|
269 |
+
loss_scalars.update(
|
270 |
+
{"loss_flow_" + str(index): loss_off.clone().detach().cpu()}
|
271 |
+
) # 1 is the upper bound
|
272 |
+
conf = data["predict_flow"][0][:, :, :, :, 2:]
|
273 |
+
layer_num = conf.shape[0]
|
274 |
for layer_index in range(layer_num):
|
275 |
+
loss_scalars.update(
|
276 |
+
{
|
277 |
+
"conf_"
|
278 |
+
+ str(layer_index): conf[layer_index]
|
279 |
+
.mean()
|
280 |
+
.clone()
|
281 |
+
.detach()
|
282 |
+
.cpu()
|
283 |
+
}
|
284 |
+
) # 1 is the upper bound
|
285 |
+
|
286 |
+
loss += loss_flow.sum()
|
287 |
+
# print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data)
|
288 |
+
loss_scalars.update({"loss": loss.clone().detach().cpu()})
|
289 |
data.update({"loss": loss, "loss_scalars": loss_scalars})
|
third_party/ASpanFormer/src/optimizers/__init__.py
CHANGED
@@ -7,9 +7,13 @@ def build_optimizer(model, config):
|
|
7 |
lr = config.TRAINER.TRUE_LR
|
8 |
|
9 |
if name == "adam":
|
10 |
-
return torch.optim.Adam(
|
|
|
|
|
11 |
elif name == "adamw":
|
12 |
-
return torch.optim.AdamW(
|
|
|
|
|
13 |
else:
|
14 |
raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
|
15 |
|
@@ -24,18 +28,27 @@ def build_scheduler(config, optimizer):
|
|
24 |
'frequency': x, (optional)
|
25 |
}
|
26 |
"""
|
27 |
-
scheduler = {
|
28 |
name = config.TRAINER.SCHEDULER
|
29 |
|
30 |
-
if name ==
|
31 |
scheduler.update(
|
32 |
-
{
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
scheduler.update(
|
35 |
-
{
|
36 |
-
|
|
|
37 |
scheduler.update(
|
38 |
-
{
|
|
|
39 |
else:
|
40 |
raise NotImplementedError()
|
41 |
|
|
|
7 |
lr = config.TRAINER.TRUE_LR
|
8 |
|
9 |
if name == "adam":
|
10 |
+
return torch.optim.Adam(
|
11 |
+
model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY
|
12 |
+
)
|
13 |
elif name == "adamw":
|
14 |
+
return torch.optim.AdamW(
|
15 |
+
model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY
|
16 |
+
)
|
17 |
else:
|
18 |
raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
|
19 |
|
|
|
28 |
'frequency': x, (optional)
|
29 |
}
|
30 |
"""
|
31 |
+
scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL}
|
32 |
name = config.TRAINER.SCHEDULER
|
33 |
|
34 |
+
if name == "MultiStepLR":
|
35 |
scheduler.update(
|
36 |
+
{
|
37 |
+
"scheduler": MultiStepLR(
|
38 |
+
optimizer,
|
39 |
+
config.TRAINER.MSLR_MILESTONES,
|
40 |
+
gamma=config.TRAINER.MSLR_GAMMA,
|
41 |
+
)
|
42 |
+
}
|
43 |
+
)
|
44 |
+
elif name == "CosineAnnealing":
|
45 |
scheduler.update(
|
46 |
+
{"scheduler": CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}
|
47 |
+
)
|
48 |
+
elif name == "ExponentialLR":
|
49 |
scheduler.update(
|
50 |
+
{"scheduler": ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}
|
51 |
+
)
|
52 |
else:
|
53 |
raise NotImplementedError()
|
54 |
|
third_party/ASpanFormer/src/utils/augment.py
CHANGED
@@ -7,16 +7,21 @@ class DarkAug(object):
|
|
7 |
"""
|
8 |
|
9 |
def __init__(self) -> None:
|
10 |
-
self.augmentor = A.Compose(
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def __call__(self, x):
|
19 |
-
return self.augmentor(image=x)[
|
20 |
|
21 |
|
22 |
class MobileAug(object):
|
@@ -25,31 +30,36 @@ class MobileAug(object):
|
|
25 |
"""
|
26 |
|
27 |
def __init__(self):
|
28 |
-
self.augmentor = A.Compose(
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
36 |
|
37 |
def __call__(self, x):
|
38 |
-
return self.augmentor(image=x)[
|
39 |
|
40 |
|
41 |
def build_augmentor(method=None, **kwargs):
|
42 |
if method is not None:
|
43 |
-
raise NotImplementedError(
|
44 |
-
|
|
|
|
|
45 |
return DarkAug()
|
46 |
-
elif method ==
|
47 |
return MobileAug()
|
48 |
elif method is None:
|
49 |
return None
|
50 |
else:
|
51 |
-
raise ValueError(f
|
52 |
|
53 |
|
54 |
-
if __name__ ==
|
55 |
-
augmentor = build_augmentor(
|
|
|
7 |
"""
|
8 |
|
9 |
def __init__(self) -> None:
|
10 |
+
self.augmentor = A.Compose(
|
11 |
+
[
|
12 |
+
A.RandomBrightnessContrast(
|
13 |
+
p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)
|
14 |
+
),
|
15 |
+
A.Blur(p=0.1, blur_limit=(3, 9)),
|
16 |
+
A.MotionBlur(p=0.2, blur_limit=(3, 25)),
|
17 |
+
A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
|
18 |
+
A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
|
19 |
+
],
|
20 |
+
p=0.75,
|
21 |
+
)
|
22 |
|
23 |
def __call__(self, x):
|
24 |
+
return self.augmentor(image=x)["image"]
|
25 |
|
26 |
|
27 |
class MobileAug(object):
|
|
|
30 |
"""
|
31 |
|
32 |
def __init__(self):
|
33 |
+
self.augmentor = A.Compose(
|
34 |
+
[
|
35 |
+
A.MotionBlur(p=0.25),
|
36 |
+
A.ColorJitter(p=0.5),
|
37 |
+
A.RandomRain(p=0.1), # random occlusion
|
38 |
+
A.RandomSunFlare(p=0.1),
|
39 |
+
A.JpegCompression(p=0.25),
|
40 |
+
A.ISONoise(p=0.25),
|
41 |
+
],
|
42 |
+
p=1.0,
|
43 |
+
)
|
44 |
|
45 |
def __call__(self, x):
|
46 |
+
return self.augmentor(image=x)["image"]
|
47 |
|
48 |
|
49 |
def build_augmentor(method=None, **kwargs):
|
50 |
if method is not None:
|
51 |
+
raise NotImplementedError(
|
52 |
+
"Using of augmentation functions are not supported yet!"
|
53 |
+
)
|
54 |
+
if method == "dark":
|
55 |
return DarkAug()
|
56 |
+
elif method == "mobile":
|
57 |
return MobileAug()
|
58 |
elif method is None:
|
59 |
return None
|
60 |
else:
|
61 |
+
raise ValueError(f"Invalid augmentation method: {method}")
|
62 |
|
63 |
|
64 |
+
if __name__ == "__main__":
|
65 |
+
augmentor = build_augmentor("FDA")
|
third_party/ASpanFormer/src/utils/comm.py
CHANGED
@@ -98,11 +98,11 @@ def _serialize_to_tensor(data, group):
|
|
98 |
device = torch.device("cpu" if backend == "gloo" else "cuda")
|
99 |
|
100 |
buffer = pickle.dumps(data)
|
101 |
-
if len(buffer) > 1024
|
102 |
logger = logging.getLogger(__name__)
|
103 |
logger.warning(
|
104 |
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
|
105 |
-
get_rank(), len(buffer) / (1024
|
106 |
)
|
107 |
)
|
108 |
storage = torch.ByteStorage.from_buffer(buffer)
|
@@ -122,7 +122,8 @@ def _pad_to_largest_tensor(tensor, group):
|
|
122 |
), "comm.gather/all_gather must be called from ranks within the given group!"
|
123 |
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
|
124 |
size_list = [
|
125 |
-
torch.zeros([1], dtype=torch.int64, device=tensor.device)
|
|
|
126 |
]
|
127 |
dist.all_gather(size_list, local_size, group=group)
|
128 |
|
@@ -133,7 +134,9 @@ def _pad_to_largest_tensor(tensor, group):
|
|
133 |
# we pad the tensor because torch all_gather does not support
|
134 |
# gathering tensors of different shapes
|
135 |
if local_size != max_size:
|
136 |
-
padding = torch.zeros(
|
|
|
|
|
137 |
tensor = torch.cat((tensor, padding), dim=0)
|
138 |
return size_list, tensor
|
139 |
|
@@ -164,7 +167,8 @@ def all_gather(data, group=None):
|
|
164 |
|
165 |
# receiving Tensor from all ranks
|
166 |
tensor_list = [
|
167 |
-
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
|
|
|
168 |
]
|
169 |
dist.all_gather(tensor_list, tensor, group=group)
|
170 |
|
@@ -205,7 +209,8 @@ def gather(data, dst=0, group=None):
|
|
205 |
if rank == dst:
|
206 |
max_size = max(size_list)
|
207 |
tensor_list = [
|
208 |
-
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
|
|
|
209 |
]
|
210 |
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
211 |
|
@@ -228,7 +233,7 @@ def shared_random_seed():
|
|
228 |
|
229 |
All workers must call this function, otherwise it will deadlock.
|
230 |
"""
|
231 |
-
ints = np.random.randint(2
|
232 |
all_ints = all_gather(ints)
|
233 |
return all_ints[0]
|
234 |
|
|
|
98 |
device = torch.device("cpu" if backend == "gloo" else "cuda")
|
99 |
|
100 |
buffer = pickle.dumps(data)
|
101 |
+
if len(buffer) > 1024**3:
|
102 |
logger = logging.getLogger(__name__)
|
103 |
logger.warning(
|
104 |
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
|
105 |
+
get_rank(), len(buffer) / (1024**3), device
|
106 |
)
|
107 |
)
|
108 |
storage = torch.ByteStorage.from_buffer(buffer)
|
|
|
122 |
), "comm.gather/all_gather must be called from ranks within the given group!"
|
123 |
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
|
124 |
size_list = [
|
125 |
+
torch.zeros([1], dtype=torch.int64, device=tensor.device)
|
126 |
+
for _ in range(world_size)
|
127 |
]
|
128 |
dist.all_gather(size_list, local_size, group=group)
|
129 |
|
|
|
134 |
# we pad the tensor because torch all_gather does not support
|
135 |
# gathering tensors of different shapes
|
136 |
if local_size != max_size:
|
137 |
+
padding = torch.zeros(
|
138 |
+
(max_size - local_size,), dtype=torch.uint8, device=tensor.device
|
139 |
+
)
|
140 |
tensor = torch.cat((tensor, padding), dim=0)
|
141 |
return size_list, tensor
|
142 |
|
|
|
167 |
|
168 |
# receiving Tensor from all ranks
|
169 |
tensor_list = [
|
170 |
+
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
|
171 |
+
for _ in size_list
|
172 |
]
|
173 |
dist.all_gather(tensor_list, tensor, group=group)
|
174 |
|
|
|
209 |
if rank == dst:
|
210 |
max_size = max(size_list)
|
211 |
tensor_list = [
|
212 |
+
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
|
213 |
+
for _ in size_list
|
214 |
]
|
215 |
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
216 |
|
|
|
233 |
|
234 |
All workers must call this function, otherwise it will deadlock.
|
235 |
"""
|
236 |
+
ints = np.random.randint(2**31)
|
237 |
all_ints = all_gather(ints)
|
238 |
return all_ints[0]
|
239 |
|
third_party/ASpanFormer/src/utils/dataloader.py
CHANGED
@@ -3,21 +3,22 @@ import numpy as np
|
|
3 |
|
4 |
# --- PL-DATAMODULE ---
|
5 |
|
|
|
6 |
def get_local_split(items: list, world_size: int, rank: int, seed: int):
|
7 |
-
"""
|
8 |
n_items = len(items)
|
9 |
items_permute = np.random.RandomState(seed).permutation(items)
|
10 |
if n_items % world_size == 0:
|
11 |
padded_items = items_permute
|
12 |
else:
|
13 |
padding = np.random.RandomState(seed).choice(
|
14 |
-
items,
|
15 |
-
|
16 |
-
replace=True)
|
17 |
padded_items = np.concatenate([items_permute, padding])
|
18 |
-
assert
|
19 |
-
|
|
|
20 |
n_per_rank = len(padded_items) // world_size
|
21 |
-
local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
|
22 |
|
23 |
return local_items
|
|
|
3 |
|
4 |
# --- PL-DATAMODULE ---
|
5 |
|
6 |
+
|
7 |
def get_local_split(items: list, world_size: int, rank: int, seed: int):
|
8 |
+
"""The local rank only loads a split of the dataset."""
|
9 |
n_items = len(items)
|
10 |
items_permute = np.random.RandomState(seed).permutation(items)
|
11 |
if n_items % world_size == 0:
|
12 |
padded_items = items_permute
|
13 |
else:
|
14 |
padding = np.random.RandomState(seed).choice(
|
15 |
+
items, world_size - (n_items % world_size), replace=True
|
16 |
+
)
|
|
|
17 |
padded_items = np.concatenate([items_permute, padding])
|
18 |
+
assert (
|
19 |
+
len(padded_items) % world_size == 0
|
20 |
+
), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}"
|
21 |
n_per_rank = len(padded_items) // world_size
|
22 |
+
local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)]
|
23 |
|
24 |
return local_items
|
third_party/ASpanFormer/src/utils/dataset.py
CHANGED
@@ -15,8 +15,11 @@ except Exception:
|
|
15 |
|
16 |
# --- DATA IO ---
|
17 |
|
|
|
18 |
def load_array_from_s3(
|
19 |
-
path,
|
|
|
|
|
20 |
use_h5py=False,
|
21 |
):
|
22 |
byte_str = client.Get(path)
|
@@ -26,7 +29,7 @@ def load_array_from_s3(
|
|
26 |
data = cv2.imdecode(raw_array, cv_type)
|
27 |
else:
|
28 |
f = io.BytesIO(byte_str)
|
29 |
-
data = np.array(h5py.File(f,
|
30 |
except Exception as ex:
|
31 |
print(f"==> Data loading failure: {path}")
|
32 |
raise ex
|
@@ -36,9 +39,8 @@ def load_array_from_s3(
|
|
36 |
|
37 |
|
38 |
def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
|
39 |
-
cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None
|
40 |
-
|
41 |
-
if str(path).startswith('s3://'):
|
42 |
image = load_array_from_s3(str(path), client, cv_type)
|
43 |
else:
|
44 |
image = cv2.imread(str(path), cv_type)
|
@@ -54,7 +56,7 @@ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
|
|
54 |
def get_resized_wh(w, h, resize=None):
|
55 |
if resize is not None: # resize the longer edge
|
56 |
scale = resize / max(h, w)
|
57 |
-
w_new, h_new = int(round(w*scale)), int(round(h*scale))
|
58 |
else:
|
59 |
w_new, h_new = w, h
|
60 |
return w_new, h_new
|
@@ -69,20 +71,22 @@ def get_divisible_wh(w, h, df=None):
|
|
69 |
|
70 |
|
71 |
def pad_bottom_right(inp, pad_size, ret_mask=False):
|
72 |
-
assert isinstance(pad_size, int) and pad_size >= max(
|
|
|
|
|
73 |
mask = None
|
74 |
if inp.ndim == 2:
|
75 |
padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
|
76 |
-
padded[:inp.shape[0], :inp.shape[1]] = inp
|
77 |
if ret_mask:
|
78 |
mask = np.zeros((pad_size, pad_size), dtype=bool)
|
79 |
-
mask[:inp.shape[0], :inp.shape[1]] = True
|
80 |
elif inp.ndim == 3:
|
81 |
padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
|
82 |
-
padded[:, :inp.shape[1], :inp.shape[2]] = inp
|
83 |
if ret_mask:
|
84 |
mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
|
85 |
-
mask[:, :inp.shape[1], :inp.shape[2]] = True
|
86 |
else:
|
87 |
raise NotImplementedError()
|
88 |
return padded, mask
|
@@ -90,6 +94,7 @@ def pad_bottom_right(inp, pad_size, ret_mask=False):
|
|
90 |
|
91 |
# --- MEGADEPTH ---
|
92 |
|
|
|
93 |
def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
|
94 |
"""
|
95 |
Args:
|
@@ -99,7 +104,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
|
|
99 |
Returns:
|
100 |
image (torch.tensor): (1, h, w)
|
101 |
mask (torch.tensor): (h, w)
|
102 |
-
scale (torch.tensor): [w/w_new, h/h_new]
|
103 |
"""
|
104 |
# read image
|
105 |
image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
|
@@ -110,7 +115,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
|
|
110 |
w_new, h_new = get_divisible_wh(w_new, h_new, df)
|
111 |
|
112 |
image = cv2.resize(image, (w_new, h_new))
|
113 |
-
scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
|
114 |
|
115 |
if padding: # padding
|
116 |
pad_to = max(h_new, w_new)
|
@@ -118,7 +123,9 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
|
|
118 |
else:
|
119 |
mask = None
|
120 |
|
121 |
-
image =
|
|
|
|
|
122 |
if mask is not None:
|
123 |
mask = torch.from_numpy(mask)
|
124 |
|
@@ -126,10 +133,10 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
|
|
126 |
|
127 |
|
128 |
def read_megadepth_depth(path, pad_to=None):
|
129 |
-
if str(path).startswith(
|
130 |
depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
|
131 |
else:
|
132 |
-
depth = np.array(h5py.File(path,
|
133 |
if pad_to is not None:
|
134 |
depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
|
135 |
depth = torch.from_numpy(depth).float() # (h, w)
|
@@ -138,6 +145,7 @@ def read_megadepth_depth(path, pad_to=None):
|
|
138 |
|
139 |
# --- ScanNet ---
|
140 |
|
|
|
141 |
def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
|
142 |
"""
|
143 |
Args:
|
@@ -146,7 +154,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
|
|
146 |
Returns:
|
147 |
image (torch.tensor): (1, h, w)
|
148 |
mask (torch.tensor): (h, w)
|
149 |
-
scale (torch.tensor): [w/w_new, h/h_new]
|
150 |
"""
|
151 |
# read and resize image
|
152 |
image = imread_gray(path, augment_fn)
|
@@ -158,7 +166,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
|
|
158 |
|
159 |
|
160 |
def read_scannet_depth(path):
|
161 |
-
if str(path).startswith(
|
162 |
depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
|
163 |
else:
|
164 |
depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
|
@@ -168,55 +176,57 @@ def read_scannet_depth(path):
|
|
168 |
|
169 |
|
170 |
def read_scannet_pose(path):
|
171 |
-
"""
|
172 |
-
|
173 |
Returns:
|
174 |
pose_w2c (np.ndarray): (4, 4)
|
175 |
"""
|
176 |
-
cam2world = np.loadtxt(path, delimiter=
|
177 |
world2cam = inv(cam2world)
|
178 |
return world2cam
|
179 |
|
180 |
|
181 |
def read_scannet_intrinsic(path):
|
182 |
-
"""
|
183 |
-
""
|
184 |
-
intrinsic = np.loadtxt(path, delimiter=' ')
|
185 |
return intrinsic[:-1, :-1]
|
186 |
|
187 |
|
188 |
-
def read_gl3d_gray(path,resize):
|
189 |
-
img=cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),(int(resize),int(resize)))
|
190 |
-
img =
|
|
|
|
|
191 |
return img
|
192 |
|
|
|
193 |
def read_gl3d_depth(file_path):
|
194 |
-
with open(file_path,
|
195 |
color = None
|
196 |
width = None
|
197 |
height = None
|
198 |
scale = None
|
199 |
data_type = None
|
200 |
-
header = str(fin.readline().decode(
|
201 |
-
if header ==
|
202 |
color = True
|
203 |
-
elif header ==
|
204 |
color = False
|
205 |
else:
|
206 |
-
raise Exception(
|
207 |
-
dim_match = re.match(r
|
208 |
if dim_match:
|
209 |
width, height = map(int, dim_match.groups())
|
210 |
else:
|
211 |
-
raise Exception(
|
212 |
-
scale = float((fin.readline().decode(
|
213 |
if scale < 0: # little-endian
|
214 |
-
data_type =
|
215 |
else:
|
216 |
-
data_type =
|
217 |
data_string = fin.read()
|
218 |
data = np.fromstring(data_string, data_type)
|
219 |
shape = (height, width, 3) if color else (height, width)
|
220 |
data = np.reshape(data, shape)
|
221 |
data = np.flip(data, 0)
|
222 |
-
return torch.from_numpy(data.copy()).float()
|
|
|
15 |
|
16 |
# --- DATA IO ---
|
17 |
|
18 |
+
|
19 |
def load_array_from_s3(
|
20 |
+
path,
|
21 |
+
client,
|
22 |
+
cv_type,
|
23 |
use_h5py=False,
|
24 |
):
|
25 |
byte_str = client.Get(path)
|
|
|
29 |
data = cv2.imdecode(raw_array, cv_type)
|
30 |
else:
|
31 |
f = io.BytesIO(byte_str)
|
32 |
+
data = np.array(h5py.File(f, "r")["/depth"])
|
33 |
except Exception as ex:
|
34 |
print(f"==> Data loading failure: {path}")
|
35 |
raise ex
|
|
|
39 |
|
40 |
|
41 |
def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
|
42 |
+
cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
|
43 |
+
if str(path).startswith("s3://"):
|
|
|
44 |
image = load_array_from_s3(str(path), client, cv_type)
|
45 |
else:
|
46 |
image = cv2.imread(str(path), cv_type)
|
|
|
56 |
def get_resized_wh(w, h, resize=None):
|
57 |
if resize is not None: # resize the longer edge
|
58 |
scale = resize / max(h, w)
|
59 |
+
w_new, h_new = int(round(w * scale)), int(round(h * scale))
|
60 |
else:
|
61 |
w_new, h_new = w, h
|
62 |
return w_new, h_new
|
|
|
71 |
|
72 |
|
73 |
def pad_bottom_right(inp, pad_size, ret_mask=False):
|
74 |
+
assert isinstance(pad_size, int) and pad_size >= max(
|
75 |
+
inp.shape[-2:]
|
76 |
+
), f"{pad_size} < {max(inp.shape[-2:])}"
|
77 |
mask = None
|
78 |
if inp.ndim == 2:
|
79 |
padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
|
80 |
+
padded[: inp.shape[0], : inp.shape[1]] = inp
|
81 |
if ret_mask:
|
82 |
mask = np.zeros((pad_size, pad_size), dtype=bool)
|
83 |
+
mask[: inp.shape[0], : inp.shape[1]] = True
|
84 |
elif inp.ndim == 3:
|
85 |
padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
|
86 |
+
padded[:, : inp.shape[1], : inp.shape[2]] = inp
|
87 |
if ret_mask:
|
88 |
mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
|
89 |
+
mask[:, : inp.shape[1], : inp.shape[2]] = True
|
90 |
else:
|
91 |
raise NotImplementedError()
|
92 |
return padded, mask
|
|
|
94 |
|
95 |
# --- MEGADEPTH ---
|
96 |
|
97 |
+
|
98 |
def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
|
99 |
"""
|
100 |
Args:
|
|
|
104 |
Returns:
|
105 |
image (torch.tensor): (1, h, w)
|
106 |
mask (torch.tensor): (h, w)
|
107 |
+
scale (torch.tensor): [w/w_new, h/h_new]
|
108 |
"""
|
109 |
# read image
|
110 |
image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
|
|
|
115 |
w_new, h_new = get_divisible_wh(w_new, h_new, df)
|
116 |
|
117 |
image = cv2.resize(image, (w_new, h_new))
|
118 |
+
scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float)
|
119 |
|
120 |
if padding: # padding
|
121 |
pad_to = max(h_new, w_new)
|
|
|
123 |
else:
|
124 |
mask = None
|
125 |
|
126 |
+
image = (
|
127 |
+
torch.from_numpy(image).float()[None] / 255
|
128 |
+
) # (h, w) -> (1, h, w) and normalized
|
129 |
if mask is not None:
|
130 |
mask = torch.from_numpy(mask)
|
131 |
|
|
|
133 |
|
134 |
|
135 |
def read_megadepth_depth(path, pad_to=None):
|
136 |
+
if str(path).startswith("s3://"):
|
137 |
depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
|
138 |
else:
|
139 |
+
depth = np.array(h5py.File(path, "r")["depth"])
|
140 |
if pad_to is not None:
|
141 |
depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
|
142 |
depth = torch.from_numpy(depth).float() # (h, w)
|
|
|
145 |
|
146 |
# --- ScanNet ---
|
147 |
|
148 |
+
|
149 |
def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
|
150 |
"""
|
151 |
Args:
|
|
|
154 |
Returns:
|
155 |
image (torch.tensor): (1, h, w)
|
156 |
mask (torch.tensor): (h, w)
|
157 |
+
scale (torch.tensor): [w/w_new, h/h_new]
|
158 |
"""
|
159 |
# read and resize image
|
160 |
image = imread_gray(path, augment_fn)
|
|
|
166 |
|
167 |
|
168 |
def read_scannet_depth(path):
|
169 |
+
if str(path).startswith("s3://"):
|
170 |
depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
|
171 |
else:
|
172 |
depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
|
|
|
176 |
|
177 |
|
178 |
def read_scannet_pose(path):
|
179 |
+
"""Read ScanNet's Camera2World pose and transform it to World2Camera.
|
180 |
+
|
181 |
Returns:
|
182 |
pose_w2c (np.ndarray): (4, 4)
|
183 |
"""
|
184 |
+
cam2world = np.loadtxt(path, delimiter=" ")
|
185 |
world2cam = inv(cam2world)
|
186 |
return world2cam
|
187 |
|
188 |
|
189 |
def read_scannet_intrinsic(path):
|
190 |
+
"""Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
|
191 |
+
intrinsic = np.loadtxt(path, delimiter=" ")
|
|
|
192 |
return intrinsic[:-1, :-1]
|
193 |
|
194 |
|
195 |
+
def read_gl3d_gray(path, resize):
|
196 |
+
img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (int(resize), int(resize)))
|
197 |
+
img = (
|
198 |
+
torch.from_numpy(img).float()[None] / 255
|
199 |
+
) # (h, w) -> (1, h, w) and normalized
|
200 |
return img
|
201 |
|
202 |
+
|
203 |
def read_gl3d_depth(file_path):
|
204 |
+
with open(file_path, "rb") as fin:
|
205 |
color = None
|
206 |
width = None
|
207 |
height = None
|
208 |
scale = None
|
209 |
data_type = None
|
210 |
+
header = str(fin.readline().decode("UTF-8")).rstrip()
|
211 |
+
if header == "PF":
|
212 |
color = True
|
213 |
+
elif header == "Pf":
|
214 |
color = False
|
215 |
else:
|
216 |
+
raise Exception("Not a PFM file.")
|
217 |
+
dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
|
218 |
if dim_match:
|
219 |
width, height = map(int, dim_match.groups())
|
220 |
else:
|
221 |
+
raise Exception("Malformed PFM header.")
|
222 |
+
scale = float((fin.readline().decode("UTF-8")).rstrip())
|
223 |
if scale < 0: # little-endian
|
224 |
+
data_type = "<f"
|
225 |
else:
|
226 |
+
data_type = ">f" # big-endian
|
227 |
data_string = fin.read()
|
228 |
data = np.fromstring(data_string, data_type)
|
229 |
shape = (height, width, 3) if color else (height, width)
|
230 |
data = np.reshape(data, shape)
|
231 |
data = np.flip(data, 0)
|
232 |
+
return torch.from_numpy(data.copy()).float()
|
third_party/ASpanFormer/src/utils/metrics.py
CHANGED
@@ -9,6 +9,7 @@ from kornia.geometry.conversions import convert_points_to_homogeneous
|
|
9 |
|
10 |
# --- METRICS ---
|
11 |
|
|
|
12 |
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
|
13 |
# angle error between 2 vectors
|
14 |
t_gt = T_0to1[:3, 3]
|
@@ -21,7 +22,7 @@ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
|
|
21 |
# angle error between 2 rotation matrices
|
22 |
R_gt = T_0to1[:3, :3]
|
23 |
cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
|
24 |
-
cos = np.clip(cos, -1
|
25 |
R_err = np.rad2deg(np.abs(np.arccos(cos)))
|
26 |
|
27 |
return t_err, R_err
|
@@ -43,93 +44,108 @@ def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
|
|
43 |
p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
|
44 |
Etp1 = pts1 @ E # [N, 3]
|
45 |
|
46 |
-
d = p1Ep0**2 * (
|
|
|
|
|
|
|
47 |
return d
|
48 |
|
49 |
|
50 |
def compute_symmetrical_epipolar_errors(data):
|
51 |
-
"""
|
52 |
Update:
|
53 |
data (dict):{"epi_errs": [M]}
|
54 |
"""
|
55 |
-
Tx = numeric.cross_product_matrix(data[
|
56 |
-
E_mat = Tx @ data[
|
57 |
|
58 |
-
m_bids = data[
|
59 |
-
pts0 = data[
|
60 |
-
pts1 = data[
|
61 |
|
62 |
epi_errs = []
|
63 |
for bs in range(Tx.size(0)):
|
64 |
mask = m_bids == bs
|
65 |
epi_errs.append(
|
66 |
-
symmetric_epipolar_distance(
|
|
|
|
|
|
|
67 |
epi_errs = torch.cat(epi_errs, dim=0)
|
68 |
|
69 |
-
data.update({
|
|
|
70 |
|
71 |
def compute_symmetrical_epipolar_errors_offset(data):
|
72 |
-
"""
|
73 |
Update:
|
74 |
data (dict):{"epi_errs": [M]}
|
75 |
"""
|
76 |
-
Tx = numeric.cross_product_matrix(data[
|
77 |
-
E_mat = Tx @ data[
|
78 |
|
79 |
-
m_bids = data[
|
80 |
-
l_ids=data[
|
81 |
-
pts0 = data[
|
82 |
-
pts1 = data[
|
83 |
|
84 |
epi_errs = []
|
85 |
-
layer_num=data[
|
86 |
-
|
87 |
for bs in range(Tx.size(0)):
|
88 |
for ls in range(layer_num):
|
89 |
mask_b = m_bids == bs
|
90 |
mask_l = l_ids == ls
|
91 |
-
mask=mask_b&mask_l
|
92 |
epi_errs.append(
|
93 |
-
symmetric_epipolar_distance(
|
|
|
|
|
|
|
94 |
epi_errs = torch.cat(epi_errs, dim=0)
|
95 |
|
96 |
-
data.update({
|
|
|
97 |
|
98 |
def compute_symmetrical_epipolar_errors_offset_bidirectional(data):
|
99 |
-
"""
|
100 |
Update
|
101 |
data (dict):{"epi_errs": [M]}
|
102 |
"""
|
103 |
-
_compute_symmetrical_epipolar_errors_offset(data,
|
104 |
-
_compute_symmetrical_epipolar_errors_offset(data,
|
105 |
|
106 |
|
107 |
-
def _compute_symmetrical_epipolar_errors_offset(data,side):
|
108 |
-
"""
|
109 |
Update
|
110 |
data (dict):{"epi_errs": [M]}
|
111 |
"""
|
112 |
-
assert side==
|
113 |
|
114 |
-
Tx = numeric.cross_product_matrix(data[
|
115 |
-
E_mat = Tx @ data[
|
116 |
|
117 |
-
m_bids = data[
|
118 |
-
l_ids=data[
|
119 |
-
pts0 = data[
|
120 |
-
pts1 = data[
|
121 |
|
122 |
epi_errs = []
|
123 |
-
layer_num=data[
|
124 |
for bs in range(Tx.size(0)):
|
125 |
for ls in range(layer_num):
|
126 |
mask_b = m_bids == bs
|
127 |
mask_l = l_ids == ls
|
128 |
-
mask=mask_b&mask_l
|
129 |
epi_errs.append(
|
130 |
-
symmetric_epipolar_distance(
|
|
|
|
|
|
|
131 |
epi_errs = torch.cat(epi_errs, dim=0)
|
132 |
-
data.update({
|
|
|
133 |
|
134 |
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
|
135 |
if len(kpts0) < 5:
|
@@ -143,7 +159,8 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
|
|
143 |
|
144 |
# compute pose with cv2
|
145 |
E, mask = cv2.findEssentialMat(
|
146 |
-
kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC
|
|
|
147 |
if E is None:
|
148 |
print("\nE is None while trying to recover pose.\n")
|
149 |
return None
|
@@ -161,7 +178,7 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
|
|
161 |
|
162 |
|
163 |
def compute_pose_errors(data, config):
|
164 |
-
"""
|
165 |
Update:
|
166 |
data (dict):{
|
167 |
"R_errs" List[float]: [N]
|
@@ -171,33 +188,36 @@ def compute_pose_errors(data, config):
|
|
171 |
"""
|
172 |
pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
|
173 |
conf = config.TRAINER.RANSAC_CONF # 0.99999
|
174 |
-
data.update({
|
175 |
|
176 |
-
m_bids = data[
|
177 |
-
pts0 = data[
|
178 |
-
pts1 = data[
|
179 |
-
K0 = data[
|
180 |
-
K1 = data[
|
181 |
-
T_0to1 = data[
|
182 |
|
183 |
for bs in range(K0.shape[0]):
|
184 |
mask = m_bids == bs
|
185 |
-
ret = estimate_pose(
|
|
|
|
|
186 |
|
187 |
if ret is None:
|
188 |
-
data[
|
189 |
-
data[
|
190 |
-
data[
|
191 |
else:
|
192 |
R, t, inliers = ret
|
193 |
t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
|
194 |
-
data[
|
195 |
-
data[
|
196 |
-
data[
|
197 |
|
198 |
|
199 |
# --- METRIC AGGREGATION ---
|
200 |
|
|
|
201 |
def error_auc(errors, thresholds):
|
202 |
"""
|
203 |
Args:
|
@@ -211,14 +231,14 @@ def error_auc(errors, thresholds):
|
|
211 |
thresholds = [5, 10, 20]
|
212 |
for thr in thresholds:
|
213 |
last_index = np.searchsorted(errors, thr)
|
214 |
-
y = recall[:last_index] + [recall[last_index-1]]
|
215 |
x = errors[:last_index] + [thr]
|
216 |
aucs.append(np.trapz(y, x) / thr)
|
217 |
|
218 |
-
return {f
|
219 |
|
220 |
|
221 |
-
def epidist_prec(errors, thresholds, ret_dict=False,offset=False):
|
222 |
precs = []
|
223 |
for thr in thresholds:
|
224 |
prec_ = []
|
@@ -227,34 +247,47 @@ def epidist_prec(errors, thresholds, ret_dict=False,offset=False):
|
|
227 |
prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
|
228 |
precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
|
229 |
if ret_dict:
|
230 |
-
return
|
|
|
|
|
|
|
|
|
231 |
else:
|
232 |
return precs
|
233 |
|
234 |
|
235 |
def aggregate_metrics(metrics, epi_err_thr=5e-4):
|
236 |
-
"""
|
237 |
(This method should be called once per dataset)
|
238 |
1. AUC of the pose error (angular) at the threshold [5, 10, 20]
|
239 |
2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
|
240 |
"""
|
241 |
# filter duplicates
|
242 |
-
unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics[
|
243 |
unq_ids = list(unq_ids.values())
|
244 |
-
logger.info(f
|
245 |
|
246 |
# pose auc
|
247 |
angular_thresholds = [5, 10, 20]
|
248 |
-
pose_errors = np.max(np.stack([metrics[
|
|
|
|
|
249 |
aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
|
250 |
|
251 |
# matching precision
|
252 |
dist_thresholds = [epi_err_thr]
|
253 |
-
precs = epidist_prec(
|
254 |
-
|
255 |
-
#
|
|
|
|
|
256 |
try:
|
257 |
-
precs_offset = epidist_prec(
|
258 |
-
|
|
|
|
|
|
|
|
|
|
|
259 |
except:
|
260 |
return {**aucs, **precs}
|
|
|
9 |
|
10 |
# --- METRICS ---
|
11 |
|
12 |
+
|
13 |
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
|
14 |
# angle error between 2 vectors
|
15 |
t_gt = T_0to1[:3, 3]
|
|
|
22 |
# angle error between 2 rotation matrices
|
23 |
R_gt = T_0to1[:3, :3]
|
24 |
cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
|
25 |
+
cos = np.clip(cos, -1.0, 1.0) # handle numercial errors
|
26 |
R_err = np.rad2deg(np.abs(np.arccos(cos)))
|
27 |
|
28 |
return t_err, R_err
|
|
|
44 |
p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
|
45 |
Etp1 = pts1 @ E # [N, 3]
|
46 |
|
47 |
+
d = p1Ep0**2 * (
|
48 |
+
1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2)
|
49 |
+
+ 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2)
|
50 |
+
) # N
|
51 |
return d
|
52 |
|
53 |
|
54 |
def compute_symmetrical_epipolar_errors(data):
|
55 |
+
"""
|
56 |
Update:
|
57 |
data (dict):{"epi_errs": [M]}
|
58 |
"""
|
59 |
+
Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
|
60 |
+
E_mat = Tx @ data["T_0to1"][:, :3, :3]
|
61 |
|
62 |
+
m_bids = data["m_bids"]
|
63 |
+
pts0 = data["mkpts0_f"]
|
64 |
+
pts1 = data["mkpts1_f"]
|
65 |
|
66 |
epi_errs = []
|
67 |
for bs in range(Tx.size(0)):
|
68 |
mask = m_bids == bs
|
69 |
epi_errs.append(
|
70 |
+
symmetric_epipolar_distance(
|
71 |
+
pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
|
72 |
+
)
|
73 |
+
)
|
74 |
epi_errs = torch.cat(epi_errs, dim=0)
|
75 |
|
76 |
+
data.update({"epi_errs": epi_errs})
|
77 |
+
|
78 |
|
79 |
def compute_symmetrical_epipolar_errors_offset(data):
|
80 |
+
"""
|
81 |
Update:
|
82 |
data (dict):{"epi_errs": [M]}
|
83 |
"""
|
84 |
+
Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
|
85 |
+
E_mat = Tx @ data["T_0to1"][:, :3, :3]
|
86 |
|
87 |
+
m_bids = data["offset_bids"]
|
88 |
+
l_ids = data["offset_lids"]
|
89 |
+
pts0 = data["offset_kpts0_f"]
|
90 |
+
pts1 = data["offset_kpts1_f"]
|
91 |
|
92 |
epi_errs = []
|
93 |
+
layer_num = data["predict_flow"][0].shape[0]
|
94 |
+
|
95 |
for bs in range(Tx.size(0)):
|
96 |
for ls in range(layer_num):
|
97 |
mask_b = m_bids == bs
|
98 |
mask_l = l_ids == ls
|
99 |
+
mask = mask_b & mask_l
|
100 |
epi_errs.append(
|
101 |
+
symmetric_epipolar_distance(
|
102 |
+
pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
|
103 |
+
)
|
104 |
+
)
|
105 |
epi_errs = torch.cat(epi_errs, dim=0)
|
106 |
|
107 |
+
data.update({"epi_errs_offset": epi_errs}) # [b*l*n]
|
108 |
+
|
109 |
|
110 |
def compute_symmetrical_epipolar_errors_offset_bidirectional(data):
|
111 |
+
"""
|
112 |
Update
|
113 |
data (dict):{"epi_errs": [M]}
|
114 |
"""
|
115 |
+
_compute_symmetrical_epipolar_errors_offset(data, "left")
|
116 |
+
_compute_symmetrical_epipolar_errors_offset(data, "right")
|
117 |
|
118 |
|
119 |
+
def _compute_symmetrical_epipolar_errors_offset(data, side):
|
120 |
+
"""
|
121 |
Update
|
122 |
data (dict):{"epi_errs": [M]}
|
123 |
"""
|
124 |
+
assert side == "left" or side == "right", "invalid side"
|
125 |
|
126 |
+
Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
|
127 |
+
E_mat = Tx @ data["T_0to1"][:, :3, :3]
|
128 |
|
129 |
+
m_bids = data["offset_bids_" + side]
|
130 |
+
l_ids = data["offset_lids_" + side]
|
131 |
+
pts0 = data["offset_kpts0_f_" + side]
|
132 |
+
pts1 = data["offset_kpts1_f_" + side]
|
133 |
|
134 |
epi_errs = []
|
135 |
+
layer_num = data["predict_flow"][0].shape[0]
|
136 |
for bs in range(Tx.size(0)):
|
137 |
for ls in range(layer_num):
|
138 |
mask_b = m_bids == bs
|
139 |
mask_l = l_ids == ls
|
140 |
+
mask = mask_b & mask_l
|
141 |
epi_errs.append(
|
142 |
+
symmetric_epipolar_distance(
|
143 |
+
pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
|
144 |
+
)
|
145 |
+
)
|
146 |
epi_errs = torch.cat(epi_errs, dim=0)
|
147 |
+
data.update({"epi_errs_offset_" + side: epi_errs}) # [b*l*n]
|
148 |
+
|
149 |
|
150 |
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
|
151 |
if len(kpts0) < 5:
|
|
|
159 |
|
160 |
# compute pose with cv2
|
161 |
E, mask = cv2.findEssentialMat(
|
162 |
+
kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC
|
163 |
+
)
|
164 |
if E is None:
|
165 |
print("\nE is None while trying to recover pose.\n")
|
166 |
return None
|
|
|
178 |
|
179 |
|
180 |
def compute_pose_errors(data, config):
|
181 |
+
"""
|
182 |
Update:
|
183 |
data (dict):{
|
184 |
"R_errs" List[float]: [N]
|
|
|
188 |
"""
|
189 |
pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
|
190 |
conf = config.TRAINER.RANSAC_CONF # 0.99999
|
191 |
+
data.update({"R_errs": [], "t_errs": [], "inliers": []})
|
192 |
|
193 |
+
m_bids = data["m_bids"].cpu().numpy()
|
194 |
+
pts0 = data["mkpts0_f"].cpu().numpy()
|
195 |
+
pts1 = data["mkpts1_f"].cpu().numpy()
|
196 |
+
K0 = data["K0"].cpu().numpy()
|
197 |
+
K1 = data["K1"].cpu().numpy()
|
198 |
+
T_0to1 = data["T_0to1"].cpu().numpy()
|
199 |
|
200 |
for bs in range(K0.shape[0]):
|
201 |
mask = m_bids == bs
|
202 |
+
ret = estimate_pose(
|
203 |
+
pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf
|
204 |
+
)
|
205 |
|
206 |
if ret is None:
|
207 |
+
data["R_errs"].append(np.inf)
|
208 |
+
data["t_errs"].append(np.inf)
|
209 |
+
data["inliers"].append(np.array([]).astype(np.bool))
|
210 |
else:
|
211 |
R, t, inliers = ret
|
212 |
t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
|
213 |
+
data["R_errs"].append(R_err)
|
214 |
+
data["t_errs"].append(t_err)
|
215 |
+
data["inliers"].append(inliers)
|
216 |
|
217 |
|
218 |
# --- METRIC AGGREGATION ---
|
219 |
|
220 |
+
|
221 |
def error_auc(errors, thresholds):
|
222 |
"""
|
223 |
Args:
|
|
|
231 |
thresholds = [5, 10, 20]
|
232 |
for thr in thresholds:
|
233 |
last_index = np.searchsorted(errors, thr)
|
234 |
+
y = recall[:last_index] + [recall[last_index - 1]]
|
235 |
x = errors[:last_index] + [thr]
|
236 |
aucs.append(np.trapz(y, x) / thr)
|
237 |
|
238 |
+
return {f"auc@{t}": auc for t, auc in zip(thresholds, aucs)}
|
239 |
|
240 |
|
241 |
+
def epidist_prec(errors, thresholds, ret_dict=False, offset=False):
|
242 |
precs = []
|
243 |
for thr in thresholds:
|
244 |
prec_ = []
|
|
|
247 |
prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
|
248 |
precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
|
249 |
if ret_dict:
|
250 |
+
return (
|
251 |
+
{f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
|
252 |
+
if not offset
|
253 |
+
else {f"prec_flow@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
|
254 |
+
)
|
255 |
else:
|
256 |
return precs
|
257 |
|
258 |
|
259 |
def aggregate_metrics(metrics, epi_err_thr=5e-4):
|
260 |
+
"""Aggregate metrics for the whole dataset:
|
261 |
(This method should be called once per dataset)
|
262 |
1. AUC of the pose error (angular) at the threshold [5, 10, 20]
|
263 |
2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
|
264 |
"""
|
265 |
# filter duplicates
|
266 |
+
unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics["identifiers"]))
|
267 |
unq_ids = list(unq_ids.values())
|
268 |
+
logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...")
|
269 |
|
270 |
# pose auc
|
271 |
angular_thresholds = [5, 10, 20]
|
272 |
+
pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[
|
273 |
+
unq_ids
|
274 |
+
]
|
275 |
aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
|
276 |
|
277 |
# matching precision
|
278 |
dist_thresholds = [epi_err_thr]
|
279 |
+
precs = epidist_prec(
|
280 |
+
np.array(metrics["epi_errs"], dtype=object)[unq_ids], dist_thresholds, True
|
281 |
+
) # (prec@err_thr)
|
282 |
+
|
283 |
+
# offset precision
|
284 |
try:
|
285 |
+
precs_offset = epidist_prec(
|
286 |
+
np.array(metrics["epi_errs_offset"], dtype=object)[unq_ids],
|
287 |
+
[2e-3],
|
288 |
+
True,
|
289 |
+
offset=True,
|
290 |
+
)
|
291 |
+
return {**aucs, **precs, **precs_offset}
|
292 |
except:
|
293 |
return {**aucs, **precs}
|
third_party/ASpanFormer/src/utils/misc.py
CHANGED
@@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_only
|
|
11 |
import cv2
|
12 |
import numpy as np
|
13 |
|
|
|
14 |
def lower_config(yacs_cfg):
|
15 |
if not isinstance(yacs_cfg, CN):
|
16 |
return yacs_cfg
|
@@ -25,7 +26,7 @@ def upper_config(dict_cfg):
|
|
25 |
|
26 |
def log_on(condition, message, level):
|
27 |
if condition:
|
28 |
-
assert level in [
|
29 |
logger.log(level, message)
|
30 |
|
31 |
|
@@ -35,32 +36,35 @@ def get_rank_zero_only_logger(logger: _Logger):
|
|
35 |
else:
|
36 |
for _level in logger._core.levels.keys():
|
37 |
level = _level.lower()
|
38 |
-
setattr(logger, level,
|
39 |
-
lambda x: None)
|
40 |
logger._log = lambda x: None
|
41 |
return logger
|
42 |
|
43 |
|
44 |
def setup_gpus(gpus: Union[str, int]) -> int:
|
45 |
-
"""
|
46 |
gpus = str(gpus)
|
47 |
gpu_ids = []
|
48 |
-
|
49 |
-
if
|
50 |
n_gpus = int(gpus)
|
51 |
return n_gpus if n_gpus != -1 else torch.cuda.device_count()
|
52 |
else:
|
53 |
-
gpu_ids = [i.strip() for i in gpus.split(
|
54 |
-
|
55 |
# setup environment variables
|
56 |
-
visible_devices = os.getenv(
|
57 |
if visible_devices is None:
|
58 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
59 |
-
os.environ["CUDA_VISIBLE_DEVICES"] =
|
60 |
-
visible_devices = os.getenv(
|
61 |
-
logger.warning(
|
|
|
|
|
62 |
else:
|
63 |
-
logger.warning(
|
|
|
|
|
64 |
return len(gpu_ids)
|
65 |
|
66 |
|
@@ -71,11 +75,11 @@ def flattenList(x):
|
|
71 |
@contextlib.contextmanager
|
72 |
def tqdm_joblib(tqdm_object):
|
73 |
"""Context manager to patch joblib to report into tqdm progress bar given as argument
|
74 |
-
|
75 |
Usage:
|
76 |
with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
|
77 |
Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
|
78 |
-
|
79 |
When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing)
|
80 |
ret_vals = Parallel(n_jobs=args.world_size)(
|
81 |
delayed(lambda x: _compute_cov_score(pid, *x))(param)
|
@@ -84,6 +88,7 @@ def tqdm_joblib(tqdm_object):
|
|
84 |
total=len(image_ids)*(len(image_ids)-1)/2))
|
85 |
Src: https://stackoverflow.com/a/58936697
|
86 |
"""
|
|
|
87 |
class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
|
88 |
def __init__(self, *args, **kwargs):
|
89 |
super().__init__(*args, **kwargs)
|
@@ -101,39 +106,79 @@ def tqdm_joblib(tqdm_object):
|
|
101 |
tqdm_object.close()
|
102 |
|
103 |
|
104 |
-
def draw_points(img,points,color=(0,255,0),radius=3):
|
105 |
dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
|
106 |
for i in range(points.shape[0]):
|
107 |
-
cv2.circle(img, dp[i],radius=radius,color=color)
|
108 |
return img
|
109 |
-
|
110 |
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
if resize is not None:
|
113 |
-
scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
assert len(corr1) == len(corr2)
|
120 |
|
121 |
draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
|
122 |
if color is None:
|
123 |
-
color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
|
124 |
-
if len(color)==1:
|
125 |
-
display = cv2.drawMatches(
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
else:
|
131 |
-
height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
|
132 |
-
display=np.zeros([height,width,3],np.uint8)
|
133 |
-
display[:img1.shape[0]
|
134 |
-
display[:img2.shape[0],img1.shape[1]:]=img2
|
135 |
for i in range(len(corr1)):
|
136 |
-
left_x,left_y,right_x,right_y=
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
return display
|
|
|
11 |
import cv2
|
12 |
import numpy as np
|
13 |
|
14 |
+
|
15 |
def lower_config(yacs_cfg):
|
16 |
if not isinstance(yacs_cfg, CN):
|
17 |
return yacs_cfg
|
|
|
26 |
|
27 |
def log_on(condition, message, level):
|
28 |
if condition:
|
29 |
+
assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]
|
30 |
logger.log(level, message)
|
31 |
|
32 |
|
|
|
36 |
else:
|
37 |
for _level in logger._core.levels.keys():
|
38 |
level = _level.lower()
|
39 |
+
setattr(logger, level, lambda x: None)
|
|
|
40 |
logger._log = lambda x: None
|
41 |
return logger
|
42 |
|
43 |
|
44 |
def setup_gpus(gpus: Union[str, int]) -> int:
|
45 |
+
"""A temporary fix for pytorch-lighting 1.3.x"""
|
46 |
gpus = str(gpus)
|
47 |
gpu_ids = []
|
48 |
+
|
49 |
+
if "," not in gpus:
|
50 |
n_gpus = int(gpus)
|
51 |
return n_gpus if n_gpus != -1 else torch.cuda.device_count()
|
52 |
else:
|
53 |
+
gpu_ids = [i.strip() for i in gpus.split(",") if i != ""]
|
54 |
+
|
55 |
# setup environment variables
|
56 |
+
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
|
57 |
if visible_devices is None:
|
58 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
59 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
|
60 |
+
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
|
61 |
+
logger.warning(
|
62 |
+
f"[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}"
|
63 |
+
)
|
64 |
else:
|
65 |
+
logger.warning(
|
66 |
+
"[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process."
|
67 |
+
)
|
68 |
return len(gpu_ids)
|
69 |
|
70 |
|
|
|
75 |
@contextlib.contextmanager
|
76 |
def tqdm_joblib(tqdm_object):
|
77 |
"""Context manager to patch joblib to report into tqdm progress bar given as argument
|
78 |
+
|
79 |
Usage:
|
80 |
with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
|
81 |
Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
|
82 |
+
|
83 |
When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing)
|
84 |
ret_vals = Parallel(n_jobs=args.world_size)(
|
85 |
delayed(lambda x: _compute_cov_score(pid, *x))(param)
|
|
|
88 |
total=len(image_ids)*(len(image_ids)-1)/2))
|
89 |
Src: https://stackoverflow.com/a/58936697
|
90 |
"""
|
91 |
+
|
92 |
class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
|
93 |
def __init__(self, *args, **kwargs):
|
94 |
super().__init__(*args, **kwargs)
|
|
|
106 |
tqdm_object.close()
|
107 |
|
108 |
|
109 |
+
def draw_points(img, points, color=(0, 255, 0), radius=3):
|
110 |
dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
|
111 |
for i in range(points.shape[0]):
|
112 |
+
cv2.circle(img, dp[i], radius=radius, color=color)
|
113 |
return img
|
|
|
114 |
|
115 |
+
|
116 |
+
def draw_match(
|
117 |
+
img1,
|
118 |
+
img2,
|
119 |
+
corr1,
|
120 |
+
corr2,
|
121 |
+
inlier=[True],
|
122 |
+
color=None,
|
123 |
+
radius1=1,
|
124 |
+
radius2=1,
|
125 |
+
resize=None,
|
126 |
+
):
|
127 |
if resize is not None:
|
128 |
+
scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [
|
129 |
+
img2.shape[1] / resize[0],
|
130 |
+
img2.shape[0] / resize[1],
|
131 |
+
]
|
132 |
+
img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize(
|
133 |
+
img2, resize, interpolation=cv2.INTER_AREA
|
134 |
+
)
|
135 |
+
corr1, corr2 = (
|
136 |
+
corr1 / np.asarray(scale1)[np.newaxis],
|
137 |
+
corr2 / np.asarray(scale2)[np.newaxis],
|
138 |
+
)
|
139 |
+
corr1_key = [
|
140 |
+
cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])
|
141 |
+
]
|
142 |
+
corr2_key = [
|
143 |
+
cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])
|
144 |
+
]
|
145 |
|
146 |
assert len(corr1) == len(corr2)
|
147 |
|
148 |
draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
|
149 |
if color is None:
|
150 |
+
color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier]
|
151 |
+
if len(color) == 1:
|
152 |
+
display = cv2.drawMatches(
|
153 |
+
img1,
|
154 |
+
corr1_key,
|
155 |
+
img2,
|
156 |
+
corr2_key,
|
157 |
+
draw_matches,
|
158 |
+
None,
|
159 |
+
matchColor=color[0],
|
160 |
+
singlePointColor=color[0],
|
161 |
+
flags=4,
|
162 |
+
)
|
163 |
else:
|
164 |
+
height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]
|
165 |
+
display = np.zeros([height, width, 3], np.uint8)
|
166 |
+
display[: img1.shape[0], : img1.shape[1]] = img1
|
167 |
+
display[: img2.shape[0], img1.shape[1] :] = img2
|
168 |
for i in range(len(corr1)):
|
169 |
+
left_x, left_y, right_x, right_y = (
|
170 |
+
int(corr1[i][0]),
|
171 |
+
int(corr1[i][1]),
|
172 |
+
int(corr2[i][0] + img1.shape[1]),
|
173 |
+
int(corr2[i][1]),
|
174 |
+
)
|
175 |
+
cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2]))
|
176 |
+
cv2.line(
|
177 |
+
display,
|
178 |
+
(left_x, left_y),
|
179 |
+
(right_x, right_y),
|
180 |
+
cur_color,
|
181 |
+
1,
|
182 |
+
lineType=cv2.LINE_AA,
|
183 |
+
)
|
184 |
return display
|
third_party/ASpanFormer/src/utils/plotting.py
CHANGED
@@ -4,38 +4,51 @@ import matplotlib.pyplot as plt
|
|
4 |
import matplotlib
|
5 |
from copy import deepcopy
|
6 |
|
|
|
7 |
def _compute_conf_thresh(data):
|
8 |
-
dataset_name = data[
|
9 |
-
if dataset_name ==
|
10 |
thr = 5e-4
|
11 |
-
elif dataset_name ==
|
12 |
thr = 1e-4
|
13 |
else:
|
14 |
-
raise ValueError(f
|
15 |
return thr
|
16 |
|
17 |
|
18 |
# --- VISUALIZATION --- #
|
19 |
|
|
|
20 |
def make_matching_figure(
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# draw image pair
|
24 |
-
assert
|
|
|
|
|
25 |
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
|
26 |
-
axes[0].imshow(img0, cmap=
|
27 |
-
axes[1].imshow(img1, cmap=
|
28 |
-
for i in range(2):
|
29 |
axes[i].get_yaxis().set_ticks([])
|
30 |
axes[i].get_xaxis().set_ticks([])
|
31 |
for spine in axes[i].spines.values():
|
32 |
spine.set_visible(False)
|
33 |
plt.tight_layout(pad=1)
|
34 |
-
|
35 |
if kpts0 is not None:
|
36 |
assert kpts1 is not None
|
37 |
-
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c=
|
38 |
-
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c=
|
39 |
|
40 |
# draw matches
|
41 |
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
@@ -43,164 +56,181 @@ def make_matching_figure(
|
|
43 |
transFigure = fig.transFigure.inverted()
|
44 |
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
45 |
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
|
46 |
-
fig.lines = [
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
|
52 |
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
|
53 |
|
54 |
# put txts
|
55 |
-
txt_color =
|
56 |
fig.text(
|
57 |
-
0.01,
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
# save or return figure
|
61 |
if path:
|
62 |
-
plt.savefig(str(path), bbox_inches=
|
63 |
plt.close()
|
64 |
else:
|
65 |
return fig
|
66 |
|
67 |
|
68 |
-
def _make_evaluation_figure(data, b_id, alpha=
|
69 |
-
b_mask = data[
|
70 |
conf_thr = _compute_conf_thresh(data)
|
71 |
-
|
72 |
-
img0 = (data[
|
73 |
-
img1 = (data[
|
74 |
-
kpts0 = data[
|
75 |
-
kpts1 = data[
|
76 |
-
|
77 |
# for megadepth, we visualize matches on the resized image
|
78 |
-
if
|
79 |
-
kpts0 = kpts0 / data[
|
80 |
-
kpts1 = kpts1 / data[
|
81 |
-
epi_errs = data[
|
82 |
correct_mask = epi_errs < conf_thr
|
83 |
precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
|
84 |
n_correct = np.sum(correct_mask)
|
85 |
-
n_gt_matches = int(data[
|
86 |
recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
|
87 |
# recall might be larger than 1, since the calculation of conf_matrix_gt
|
88 |
# uses groundtruth depths and camera poses, but epipolar distance is used here.
|
89 |
|
90 |
# matching info
|
91 |
-
if alpha ==
|
92 |
alpha = dynamic_alpha(len(correct_mask))
|
93 |
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
|
94 |
-
|
95 |
text = [
|
96 |
-
f
|
97 |
-
f
|
98 |
-
f
|
99 |
]
|
100 |
-
|
101 |
# make the figure
|
102 |
-
figure = make_matching_figure(img0, img1, kpts0, kpts1,
|
103 |
-
color, text=text)
|
104 |
return figure
|
105 |
|
106 |
-
def _make_evaluation_figure_offset(data, b_id, alpha='dynamic',side=''):
|
107 |
-
layer_num=data['predict_flow'][0].shape[0]
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
116 |
for layer_index in range(layer_num):
|
117 |
-
l_mask=data[
|
118 |
-
mask=l_mask&b_mask
|
119 |
-
kpts0 = data[
|
120 |
-
kpts1 = data[
|
121 |
-
|
122 |
-
epi_errs = data[
|
123 |
correct_mask = epi_errs < conf_thr
|
124 |
-
|
125 |
precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
|
126 |
n_correct = np.sum(correct_mask)
|
127 |
-
n_gt_matches = int(data[
|
128 |
recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
|
129 |
# recall might be larger than 1, since the calculation of conf_matrix_gt
|
130 |
# uses groundtruth depths and camera poses, but epipolar distance is used here.
|
131 |
|
132 |
# matching info
|
133 |
-
if alpha ==
|
134 |
alpha = dynamic_alpha(len(correct_mask))
|
135 |
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
|
136 |
-
|
137 |
text = [
|
138 |
-
f
|
139 |
-
f
|
140 |
-
f
|
141 |
]
|
142 |
-
|
143 |
# make the figure
|
144 |
-
#import pdb;pdb.set_trace()
|
145 |
-
figure = make_matching_figure(
|
146 |
-
|
|
|
147 |
figure_list.append(figure)
|
148 |
return figure
|
149 |
|
|
|
150 |
def _make_confidence_figure(data, b_id):
|
151 |
# TODO: Implement confidence figure
|
152 |
raise NotImplementedError()
|
153 |
|
154 |
|
155 |
-
def make_matching_figures(data, config, mode=
|
156 |
-
"""
|
157 |
-
|
158 |
Args:
|
159 |
data (Dict): a batch updated by PL_LoFTR.
|
160 |
config (Dict): matcher config
|
161 |
Returns:
|
162 |
figures (Dict[str, List[plt.figure]]
|
163 |
"""
|
164 |
-
assert mode in [
|
165 |
figures = {mode: []}
|
166 |
-
for b_id in range(data[
|
167 |
-
if mode ==
|
168 |
fig = _make_evaluation_figure(
|
169 |
-
data, b_id,
|
170 |
-
|
171 |
-
elif mode ==
|
172 |
fig = _make_confidence_figure(data, b_id)
|
173 |
else:
|
174 |
-
raise ValueError(f
|
175 |
figures[mode].append(fig)
|
176 |
return figures
|
177 |
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
181 |
Args:
|
182 |
data (Dict): a batch updated by PL_LoFTR.
|
183 |
config (Dict): matcher config
|
184 |
Returns:
|
185 |
figures (Dict[str, List[plt.figure]]
|
186 |
"""
|
187 |
-
assert mode in [
|
188 |
figures = {mode: []}
|
189 |
-
for b_id in range(data[
|
190 |
-
if mode ==
|
191 |
fig = _make_evaluation_figure_offset(
|
192 |
-
data, b_id,
|
193 |
-
|
194 |
-
elif mode ==
|
195 |
fig = _make_evaluation_figure_offset(data, b_id)
|
196 |
else:
|
197 |
-
raise ValueError(f
|
198 |
figures[mode].append(fig)
|
199 |
return figures
|
200 |
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
204 |
if n_matches == 0:
|
205 |
return 1.0
|
206 |
ranges = list(zip(alphas, alphas[1:] + [None]))
|
@@ -209,11 +239,15 @@ def dynamic_alpha(n_matches,
|
|
209 |
if _range[1] is None:
|
210 |
return _range[0]
|
211 |
return _range[1] + (milestones[loc + 1] - n_matches) / (
|
212 |
-
milestones[loc + 1] - milestones[loc]
|
|
|
213 |
|
214 |
|
215 |
def error_colormap(err, thr, alpha=1.0):
|
216 |
assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
|
217 |
x = 1 - np.clip(err / (thr * 2), 0, 1)
|
218 |
return np.clip(
|
219 |
-
np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1),
|
|
|
|
|
|
|
|
4 |
import matplotlib
|
5 |
from copy import deepcopy
|
6 |
|
7 |
+
|
8 |
def _compute_conf_thresh(data):
|
9 |
+
dataset_name = data["dataset_name"][0].lower()
|
10 |
+
if dataset_name == "scannet":
|
11 |
thr = 5e-4
|
12 |
+
elif dataset_name == "megadepth" or dataset_name == "gl3d":
|
13 |
thr = 1e-4
|
14 |
else:
|
15 |
+
raise ValueError(f"Unknown dataset: {dataset_name}")
|
16 |
return thr
|
17 |
|
18 |
|
19 |
# --- VISUALIZATION --- #
|
20 |
|
21 |
+
|
22 |
def make_matching_figure(
|
23 |
+
img0,
|
24 |
+
img1,
|
25 |
+
mkpts0,
|
26 |
+
mkpts1,
|
27 |
+
color,
|
28 |
+
kpts0=None,
|
29 |
+
kpts1=None,
|
30 |
+
text=[],
|
31 |
+
dpi=75,
|
32 |
+
path=None,
|
33 |
+
):
|
34 |
# draw image pair
|
35 |
+
assert (
|
36 |
+
mkpts0.shape[0] == mkpts1.shape[0]
|
37 |
+
), f"mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}"
|
38 |
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
|
39 |
+
axes[0].imshow(img0, cmap="gray")
|
40 |
+
axes[1].imshow(img1, cmap="gray")
|
41 |
+
for i in range(2): # clear all frames
|
42 |
axes[i].get_yaxis().set_ticks([])
|
43 |
axes[i].get_xaxis().set_ticks([])
|
44 |
for spine in axes[i].spines.values():
|
45 |
spine.set_visible(False)
|
46 |
plt.tight_layout(pad=1)
|
47 |
+
|
48 |
if kpts0 is not None:
|
49 |
assert kpts1 is not None
|
50 |
+
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=2)
|
51 |
+
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=2)
|
52 |
|
53 |
# draw matches
|
54 |
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
|
|
56 |
transFigure = fig.transFigure.inverted()
|
57 |
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
58 |
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
|
59 |
+
fig.lines = [
|
60 |
+
matplotlib.lines.Line2D(
|
61 |
+
(fkpts0[i, 0], fkpts1[i, 0]),
|
62 |
+
(fkpts0[i, 1], fkpts1[i, 1]),
|
63 |
+
transform=fig.transFigure,
|
64 |
+
c=color[i],
|
65 |
+
linewidth=1,
|
66 |
+
)
|
67 |
+
for i in range(len(mkpts0))
|
68 |
+
]
|
69 |
+
|
70 |
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
|
71 |
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
|
72 |
|
73 |
# put txts
|
74 |
+
txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
|
75 |
fig.text(
|
76 |
+
0.01,
|
77 |
+
0.99,
|
78 |
+
"\n".join(text),
|
79 |
+
transform=fig.axes[0].transAxes,
|
80 |
+
fontsize=15,
|
81 |
+
va="top",
|
82 |
+
ha="left",
|
83 |
+
color=txt_color,
|
84 |
+
)
|
85 |
|
86 |
# save or return figure
|
87 |
if path:
|
88 |
+
plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
|
89 |
plt.close()
|
90 |
else:
|
91 |
return fig
|
92 |
|
93 |
|
94 |
+
def _make_evaluation_figure(data, b_id, alpha="dynamic"):
|
95 |
+
b_mask = data["m_bids"] == b_id
|
96 |
conf_thr = _compute_conf_thresh(data)
|
97 |
+
|
98 |
+
img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
|
99 |
+
img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
|
100 |
+
kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
|
101 |
+
kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
|
102 |
+
|
103 |
# for megadepth, we visualize matches on the resized image
|
104 |
+
if "scale0" in data:
|
105 |
+
kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
|
106 |
+
kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
|
107 |
+
epi_errs = data["epi_errs"][b_mask].cpu().numpy()
|
108 |
correct_mask = epi_errs < conf_thr
|
109 |
precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
|
110 |
n_correct = np.sum(correct_mask)
|
111 |
+
n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
|
112 |
recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
|
113 |
# recall might be larger than 1, since the calculation of conf_matrix_gt
|
114 |
# uses groundtruth depths and camera poses, but epipolar distance is used here.
|
115 |
|
116 |
# matching info
|
117 |
+
if alpha == "dynamic":
|
118 |
alpha = dynamic_alpha(len(correct_mask))
|
119 |
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
|
120 |
+
|
121 |
text = [
|
122 |
+
f"#Matches {len(kpts0)}",
|
123 |
+
f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
|
124 |
+
f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
|
125 |
]
|
126 |
+
|
127 |
# make the figure
|
128 |
+
figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
|
|
|
129 |
return figure
|
130 |
|
|
|
|
|
131 |
|
132 |
+
def _make_evaluation_figure_offset(data, b_id, alpha="dynamic", side=""):
|
133 |
+
layer_num = data["predict_flow"][0].shape[0]
|
134 |
+
|
135 |
+
b_mask = data["offset_bids" + side] == b_id
|
136 |
+
conf_thr = 2e-3 # hardcode for scannet(coarse level)
|
137 |
+
img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
|
138 |
+
img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
|
139 |
+
|
140 |
+
figure_list = []
|
141 |
+
# draw offset matches in different layers
|
142 |
for layer_index in range(layer_num):
|
143 |
+
l_mask = data["offset_lids" + side] == layer_index
|
144 |
+
mask = l_mask & b_mask
|
145 |
+
kpts0 = data["offset_kpts0_f" + side][mask].cpu().numpy()
|
146 |
+
kpts1 = data["offset_kpts1_f" + side][mask].cpu().numpy()
|
147 |
+
|
148 |
+
epi_errs = data["epi_errs_offset" + side][mask].cpu().numpy()
|
149 |
correct_mask = epi_errs < conf_thr
|
150 |
+
|
151 |
precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
|
152 |
n_correct = np.sum(correct_mask)
|
153 |
+
n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
|
154 |
recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
|
155 |
# recall might be larger than 1, since the calculation of conf_matrix_gt
|
156 |
# uses groundtruth depths and camera poses, but epipolar distance is used here.
|
157 |
|
158 |
# matching info
|
159 |
+
if alpha == "dynamic":
|
160 |
alpha = dynamic_alpha(len(correct_mask))
|
161 |
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
|
162 |
+
|
163 |
text = [
|
164 |
+
f"#Matches {len(kpts0)}",
|
165 |
+
f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
|
166 |
+
f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
|
167 |
]
|
168 |
+
|
169 |
# make the figure
|
170 |
+
# import pdb;pdb.set_trace()
|
171 |
+
figure = make_matching_figure(
|
172 |
+
deepcopy(img0), deepcopy(img1), kpts0, kpts1, color, text=text
|
173 |
+
)
|
174 |
figure_list.append(figure)
|
175 |
return figure
|
176 |
|
177 |
+
|
178 |
def _make_confidence_figure(data, b_id):
|
179 |
# TODO: Implement confidence figure
|
180 |
raise NotImplementedError()
|
181 |
|
182 |
|
183 |
+
def make_matching_figures(data, config, mode="evaluation"):
|
184 |
+
"""Make matching figures for a batch.
|
185 |
+
|
186 |
Args:
|
187 |
data (Dict): a batch updated by PL_LoFTR.
|
188 |
config (Dict): matcher config
|
189 |
Returns:
|
190 |
figures (Dict[str, List[plt.figure]]
|
191 |
"""
|
192 |
+
assert mode in ["evaluation", "confidence"] # 'confidence'
|
193 |
figures = {mode: []}
|
194 |
+
for b_id in range(data["image0"].size(0)):
|
195 |
+
if mode == "evaluation":
|
196 |
fig = _make_evaluation_figure(
|
197 |
+
data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
|
198 |
+
)
|
199 |
+
elif mode == "confidence":
|
200 |
fig = _make_confidence_figure(data, b_id)
|
201 |
else:
|
202 |
+
raise ValueError(f"Unknown plot mode: {mode}")
|
203 |
figures[mode].append(fig)
|
204 |
return figures
|
205 |
|
206 |
+
|
207 |
+
def make_matching_figures_offset(data, config, mode="evaluation", side=""):
|
208 |
+
"""Make matching figures for a batch.
|
209 |
+
|
210 |
Args:
|
211 |
data (Dict): a batch updated by PL_LoFTR.
|
212 |
config (Dict): matcher config
|
213 |
Returns:
|
214 |
figures (Dict[str, List[plt.figure]]
|
215 |
"""
|
216 |
+
assert mode in ["evaluation", "confidence"] # 'confidence'
|
217 |
figures = {mode: []}
|
218 |
+
for b_id in range(data["image0"].size(0)):
|
219 |
+
if mode == "evaluation":
|
220 |
fig = _make_evaluation_figure_offset(
|
221 |
+
data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA, side=side
|
222 |
+
)
|
223 |
+
elif mode == "confidence":
|
224 |
fig = _make_evaluation_figure_offset(data, b_id)
|
225 |
else:
|
226 |
+
raise ValueError(f"Unknown plot mode: {mode}")
|
227 |
figures[mode].append(fig)
|
228 |
return figures
|
229 |
|
230 |
+
|
231 |
+
def dynamic_alpha(
|
232 |
+
n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
|
233 |
+
):
|
234 |
if n_matches == 0:
|
235 |
return 1.0
|
236 |
ranges = list(zip(alphas, alphas[1:] + [None]))
|
|
|
239 |
if _range[1] is None:
|
240 |
return _range[0]
|
241 |
return _range[1] + (milestones[loc + 1] - n_matches) / (
|
242 |
+
milestones[loc + 1] - milestones[loc]
|
243 |
+
) * (_range[0] - _range[1])
|
244 |
|
245 |
|
246 |
def error_colormap(err, thr, alpha=1.0):
|
247 |
assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
|
248 |
x = 1 - np.clip(err / (thr * 2), 0, 1)
|
249 |
return np.clip(
|
250 |
+
np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
|
251 |
+
0,
|
252 |
+
1,
|
253 |
+
)
|
third_party/ASpanFormer/src/utils/profiler.py
CHANGED
@@ -7,7 +7,7 @@ from pytorch_lightning.utilities import rank_zero_only
|
|
7 |
class InferenceProfiler(SimpleProfiler):
|
8 |
"""
|
9 |
This profiler records duration of actions with cuda.synchronize()
|
10 |
-
Use this in test time.
|
11 |
"""
|
12 |
|
13 |
def __init__(self):
|
@@ -28,12 +28,13 @@ class InferenceProfiler(SimpleProfiler):
|
|
28 |
|
29 |
|
30 |
def build_profiler(name):
|
31 |
-
if name ==
|
32 |
return InferenceProfiler()
|
33 |
-
elif name ==
|
34 |
from pytorch_lightning.profiler import PyTorchProfiler
|
|
|
35 |
return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
|
36 |
elif name is None:
|
37 |
return PassThroughProfiler()
|
38 |
else:
|
39 |
-
raise ValueError(f
|
|
|
7 |
class InferenceProfiler(SimpleProfiler):
|
8 |
"""
|
9 |
This profiler records duration of actions with cuda.synchronize()
|
10 |
+
Use this in test time.
|
11 |
"""
|
12 |
|
13 |
def __init__(self):
|
|
|
28 |
|
29 |
|
30 |
def build_profiler(name):
|
31 |
+
if name == "inference":
|
32 |
return InferenceProfiler()
|
33 |
+
elif name == "pytorch":
|
34 |
from pytorch_lightning.profiler import PyTorchProfiler
|
35 |
+
|
36 |
return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
|
37 |
elif name is None:
|
38 |
return PassThroughProfiler()
|
39 |
else:
|
40 |
+
raise ValueError(f"Invalid profiler: {name}")
|
third_party/ASpanFormer/test.py
CHANGED
@@ -10,33 +10,52 @@ from src.lightning.data import MultiSceneDataModule
|
|
10 |
from src.lightning.lightning_aspanformer import PL_ASpanFormer
|
11 |
import torch
|
12 |
|
|
|
13 |
def parse_args():
|
14 |
# init a costum parser which will be added into pl.Trainer parser
|
15 |
# check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
|
16 |
-
parser = argparse.ArgumentParser(
|
17 |
-
|
18 |
-
|
19 |
-
parser.add_argument(
|
20 |
-
|
21 |
-
parser.add_argument(
|
22 |
-
'--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
|
23 |
-
parser.add_argument(
|
24 |
-
'--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir")
|
25 |
parser.add_argument(
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
parser.add_argument(
|
28 |
-
|
|
|
|
|
|
|
|
|
29 |
parser.add_argument(
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
parser.add_argument(
|
32 |
-
|
|
|
|
|
|
|
|
|
33 |
parser.add_argument(
|
34 |
-
|
|
|
|
|
|
|
|
|
35 |
parser = pl.Trainer.add_argparse_args(parser)
|
36 |
return parser.parse_args()
|
37 |
|
38 |
|
39 |
-
if __name__ ==
|
40 |
# parse arguments
|
41 |
args = parse_args()
|
42 |
pprint.pprint(vars(args))
|
@@ -55,7 +74,12 @@ if __name__ == '__main__':
|
|
55 |
|
56 |
# lightning module
|
57 |
profiler = build_profiler(args.profiler_name)
|
58 |
-
model = PL_ASpanFormer(
|
|
|
|
|
|
|
|
|
|
|
59 |
loguru_logger.info(f"ASpanFormer-lightning initialized!")
|
60 |
|
61 |
# lightning data
|
@@ -63,7 +87,9 @@ if __name__ == '__main__':
|
|
63 |
loguru_logger.info(f"DataModule initialized!")
|
64 |
|
65 |
# lightning trainer
|
66 |
-
trainer = pl.Trainer.from_argparse_args(
|
|
|
|
|
67 |
|
68 |
loguru_logger.info(f"Start testing!")
|
69 |
trainer.test(model, datamodule=data_module, verbose=False)
|
|
|
10 |
from src.lightning.lightning_aspanformer import PL_ASpanFormer
|
11 |
import torch
|
12 |
|
13 |
+
|
14 |
def parse_args():
|
15 |
# init a costum parser which will be added into pl.Trainer parser
|
16 |
# check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
|
17 |
+
parser = argparse.ArgumentParser(
|
18 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
19 |
+
)
|
20 |
+
parser.add_argument("data_cfg_path", type=str, help="data config path")
|
21 |
+
parser.add_argument("main_cfg_path", type=str, help="main config path")
|
|
|
|
|
|
|
|
|
22 |
parser.add_argument(
|
23 |
+
"--ckpt_path",
|
24 |
+
type=str,
|
25 |
+
default="weights/indoor_ds.ckpt",
|
26 |
+
help="path to the checkpoint",
|
27 |
+
)
|
28 |
parser.add_argument(
|
29 |
+
"--dump_dir",
|
30 |
+
type=str,
|
31 |
+
default=None,
|
32 |
+
help="if set, the matching results will be dump to dump_dir",
|
33 |
+
)
|
34 |
parser.add_argument(
|
35 |
+
"--profiler_name",
|
36 |
+
type=str,
|
37 |
+
default=None,
|
38 |
+
help="options: [inference, pytorch], or leave it unset",
|
39 |
+
)
|
40 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch_size per gpu")
|
41 |
+
parser.add_argument("--num_workers", type=int, default=2)
|
42 |
parser.add_argument(
|
43 |
+
"--thr",
|
44 |
+
type=float,
|
45 |
+
default=None,
|
46 |
+
help="modify the coarse-level matching threshold.",
|
47 |
+
)
|
48 |
parser.add_argument(
|
49 |
+
"--mode",
|
50 |
+
type=str,
|
51 |
+
default="vanilla",
|
52 |
+
help="modify the coarse-level matching threshold.",
|
53 |
+
)
|
54 |
parser = pl.Trainer.add_argparse_args(parser)
|
55 |
return parser.parse_args()
|
56 |
|
57 |
|
58 |
+
if __name__ == "__main__":
|
59 |
# parse arguments
|
60 |
args = parse_args()
|
61 |
pprint.pprint(vars(args))
|
|
|
74 |
|
75 |
# lightning module
|
76 |
profiler = build_profiler(args.profiler_name)
|
77 |
+
model = PL_ASpanFormer(
|
78 |
+
config,
|
79 |
+
pretrained_ckpt=args.ckpt_path,
|
80 |
+
profiler=profiler,
|
81 |
+
dump_dir=args.dump_dir,
|
82 |
+
)
|
83 |
loguru_logger.info(f"ASpanFormer-lightning initialized!")
|
84 |
|
85 |
# lightning data
|
|
|
87 |
loguru_logger.info(f"DataModule initialized!")
|
88 |
|
89 |
# lightning trainer
|
90 |
+
trainer = pl.Trainer.from_argparse_args(
|
91 |
+
args, replace_sampler_ddp=False, logger=False
|
92 |
+
)
|
93 |
|
94 |
loguru_logger.info(f"Start testing!")
|
95 |
trainer.test(model, datamodule=data_module, verbose=False)
|
third_party/ASpanFormer/tools/extract.py
CHANGED
@@ -5,43 +5,77 @@ from tqdm import tqdm
|
|
5 |
from multiprocessing import Pool
|
6 |
from functools import partial
|
7 |
|
8 |
-
scannet_dir=
|
9 |
-
dump_dir=
|
10 |
-
num_process=32
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
os.system(cmd)
|
19 |
|
20 |
-
|
|
|
21 |
if not os.path.exists(dump_dir):
|
22 |
os.mkdir(dump_dir)
|
23 |
-
os.mkdir(os.path.join(dump_dir,
|
24 |
-
os.mkdir(os.path.join(dump_dir,
|
25 |
|
26 |
-
train_seq_list=[
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
extract_train=partial(
|
30 |
-
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
num_train_iter=
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
pool = Pool(num_process)
|
36 |
for index in tqdm(range(num_train_iter)):
|
37 |
-
seq_list=train_seq_list[
|
38 |
-
|
|
|
|
|
39 |
pool.close()
|
40 |
pool.join()
|
41 |
-
|
42 |
pool = Pool(num_process)
|
43 |
for index in tqdm(range(num_test_iter)):
|
44 |
-
seq_list=test_seq_list[
|
45 |
-
|
|
|
|
|
46 |
pool.close()
|
47 |
-
pool.join()
|
|
|
5 |
from multiprocessing import Pool
|
6 |
from functools import partial
|
7 |
|
8 |
+
scannet_dir = "/root/data/ScanNet-v2-1.0.0/data/raw"
|
9 |
+
dump_dir = "/root/data/scannet_dump"
|
10 |
+
num_process = 32
|
11 |
+
|
12 |
+
|
13 |
+
def extract(seq, scannet_dir, split, dump_dir):
|
14 |
+
assert split == "train" or split == "test"
|
15 |
+
if not os.path.exists(os.path.join(dump_dir, split, seq)):
|
16 |
+
os.mkdir(os.path.join(dump_dir, split, seq))
|
17 |
+
cmd = (
|
18 |
+
"python reader.py --filename "
|
19 |
+
+ os.path.join(
|
20 |
+
scannet_dir,
|
21 |
+
"scans" if split == "train" else "scans_test",
|
22 |
+
seq,
|
23 |
+
seq + ".sens",
|
24 |
+
)
|
25 |
+
+ " --output_path "
|
26 |
+
+ os.path.join(dump_dir, split, seq)
|
27 |
+
+ " --export_depth_images --export_color_images --export_poses --export_intrinsics"
|
28 |
+
)
|
29 |
os.system(cmd)
|
30 |
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
if not os.path.exists(dump_dir):
|
34 |
os.mkdir(dump_dir)
|
35 |
+
os.mkdir(os.path.join(dump_dir, "train"))
|
36 |
+
os.mkdir(os.path.join(dump_dir, "test"))
|
37 |
|
38 |
+
train_seq_list = [
|
39 |
+
seq.split("/")[-1]
|
40 |
+
for seq in glob.glob(os.path.join(scannet_dir, "scans", "scene*"))
|
41 |
+
]
|
42 |
+
test_seq_list = [
|
43 |
+
seq.split("/")[-1]
|
44 |
+
for seq in glob.glob(os.path.join(scannet_dir, "scans_test", "scene*"))
|
45 |
+
]
|
46 |
|
47 |
+
extract_train = partial(
|
48 |
+
extract, scannet_dir=scannet_dir, split="train", dump_dir=dump_dir
|
49 |
+
)
|
50 |
+
extract_test = partial(
|
51 |
+
extract, scannet_dir=scannet_dir, split="test", dump_dir=dump_dir
|
52 |
+
)
|
53 |
|
54 |
+
num_train_iter = (
|
55 |
+
len(train_seq_list) // num_process
|
56 |
+
if len(train_seq_list) % num_process == 0
|
57 |
+
else len(train_seq_list) // num_process + 1
|
58 |
+
)
|
59 |
+
num_test_iter = (
|
60 |
+
len(test_seq_list) // num_process
|
61 |
+
if len(test_seq_list) % num_process == 0
|
62 |
+
else len(test_seq_list) // num_process + 1
|
63 |
+
)
|
64 |
|
65 |
pool = Pool(num_process)
|
66 |
for index in tqdm(range(num_train_iter)):
|
67 |
+
seq_list = train_seq_list[
|
68 |
+
index * num_process : min((index + 1) * num_process, len(train_seq_list))
|
69 |
+
]
|
70 |
+
pool.map(extract_train, seq_list)
|
71 |
pool.close()
|
72 |
pool.join()
|
73 |
+
|
74 |
pool = Pool(num_process)
|
75 |
for index in tqdm(range(num_test_iter)):
|
76 |
+
seq_list = test_seq_list[
|
77 |
+
index * num_process : min((index + 1) * num_process, len(test_seq_list))
|
78 |
+
]
|
79 |
+
pool.map(extract_test, seq_list)
|
80 |
pool.close()
|
81 |
+
pool.join()
|