Spaces:
Running
on
Zero
Running
on
Zero
Adding Depth Anything model
Browse files- depth_anything/blocks.py +153 -0
- depth_anything/dpt.py +187 -0
- depth_anything/util/transform.py +248 -0
depth_anything/blocks.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
5 |
+
scratch = nn.Module()
|
6 |
+
|
7 |
+
out_shape1 = out_shape
|
8 |
+
out_shape2 = out_shape
|
9 |
+
out_shape3 = out_shape
|
10 |
+
if len(in_shape) >= 4:
|
11 |
+
out_shape4 = out_shape
|
12 |
+
|
13 |
+
if expand:
|
14 |
+
out_shape1 = out_shape
|
15 |
+
out_shape2 = out_shape*2
|
16 |
+
out_shape3 = out_shape*4
|
17 |
+
if len(in_shape) >= 4:
|
18 |
+
out_shape4 = out_shape*8
|
19 |
+
|
20 |
+
scratch.layer1_rn = nn.Conv2d(
|
21 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
22 |
+
)
|
23 |
+
scratch.layer2_rn = nn.Conv2d(
|
24 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
25 |
+
)
|
26 |
+
scratch.layer3_rn = nn.Conv2d(
|
27 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
28 |
+
)
|
29 |
+
if len(in_shape) >= 4:
|
30 |
+
scratch.layer4_rn = nn.Conv2d(
|
31 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
32 |
+
)
|
33 |
+
|
34 |
+
return scratch
|
35 |
+
|
36 |
+
|
37 |
+
class ResidualConvUnit(nn.Module):
|
38 |
+
"""Residual convolution module.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, features, activation, bn):
|
42 |
+
"""Init.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
features (int): number of features
|
46 |
+
"""
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
self.bn = bn
|
50 |
+
|
51 |
+
self.groups=1
|
52 |
+
|
53 |
+
self.conv1 = nn.Conv2d(
|
54 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
55 |
+
)
|
56 |
+
|
57 |
+
self.conv2 = nn.Conv2d(
|
58 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
59 |
+
)
|
60 |
+
|
61 |
+
if self.bn==True:
|
62 |
+
self.bn1 = nn.BatchNorm2d(features)
|
63 |
+
self.bn2 = nn.BatchNorm2d(features)
|
64 |
+
|
65 |
+
self.activation = activation
|
66 |
+
|
67 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
"""Forward pass.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
x (tensor): input
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
tensor: output
|
77 |
+
"""
|
78 |
+
|
79 |
+
out = self.activation(x)
|
80 |
+
out = self.conv1(out)
|
81 |
+
if self.bn==True:
|
82 |
+
out = self.bn1(out)
|
83 |
+
|
84 |
+
out = self.activation(out)
|
85 |
+
out = self.conv2(out)
|
86 |
+
if self.bn==True:
|
87 |
+
out = self.bn2(out)
|
88 |
+
|
89 |
+
if self.groups > 1:
|
90 |
+
out = self.conv_merge(out)
|
91 |
+
|
92 |
+
return self.skip_add.add(out, x)
|
93 |
+
|
94 |
+
|
95 |
+
class FeatureFusionBlock(nn.Module):
|
96 |
+
"""Feature fusion block.
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
|
100 |
+
"""Init.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
features (int): number of features
|
104 |
+
"""
|
105 |
+
super(FeatureFusionBlock, self).__init__()
|
106 |
+
|
107 |
+
self.deconv = deconv
|
108 |
+
self.align_corners = align_corners
|
109 |
+
|
110 |
+
self.groups=1
|
111 |
+
|
112 |
+
self.expand = expand
|
113 |
+
out_features = features
|
114 |
+
if self.expand==True:
|
115 |
+
out_features = features//2
|
116 |
+
|
117 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
118 |
+
|
119 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
120 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
121 |
+
|
122 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
123 |
+
|
124 |
+
self.size=size
|
125 |
+
|
126 |
+
def forward(self, *xs, size=None):
|
127 |
+
"""Forward pass.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
tensor: output
|
131 |
+
"""
|
132 |
+
output = xs[0]
|
133 |
+
|
134 |
+
if len(xs) == 2:
|
135 |
+
res = self.resConfUnit1(xs[1])
|
136 |
+
output = self.skip_add.add(output, res)
|
137 |
+
|
138 |
+
output = self.resConfUnit2(output)
|
139 |
+
|
140 |
+
if (size is None) and (self.size is None):
|
141 |
+
modifier = {"scale_factor": 2}
|
142 |
+
elif size is None:
|
143 |
+
modifier = {"size": self.size}
|
144 |
+
else:
|
145 |
+
modifier = {"size": size}
|
146 |
+
|
147 |
+
output = nn.functional.interpolate(
|
148 |
+
output, **modifier, mode="bilinear", align_corners=self.align_corners
|
149 |
+
)
|
150 |
+
|
151 |
+
output = self.out_conv(output)
|
152 |
+
|
153 |
+
return output
|
depth_anything/dpt.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
6 |
+
|
7 |
+
from depth_anything.blocks import FeatureFusionBlock, _make_scratch
|
8 |
+
|
9 |
+
|
10 |
+
def _make_fusion_block(features, use_bn, size = None):
|
11 |
+
return FeatureFusionBlock(
|
12 |
+
features,
|
13 |
+
nn.ReLU(False),
|
14 |
+
deconv=False,
|
15 |
+
bn=use_bn,
|
16 |
+
expand=False,
|
17 |
+
align_corners=True,
|
18 |
+
size=size,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class DPTHead(nn.Module):
|
23 |
+
def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
|
24 |
+
super(DPTHead, self).__init__()
|
25 |
+
|
26 |
+
self.nclass = nclass
|
27 |
+
self.use_clstoken = use_clstoken
|
28 |
+
|
29 |
+
self.projects = nn.ModuleList([
|
30 |
+
nn.Conv2d(
|
31 |
+
in_channels=in_channels,
|
32 |
+
out_channels=out_channel,
|
33 |
+
kernel_size=1,
|
34 |
+
stride=1,
|
35 |
+
padding=0,
|
36 |
+
) for out_channel in out_channels
|
37 |
+
])
|
38 |
+
|
39 |
+
self.resize_layers = nn.ModuleList([
|
40 |
+
nn.ConvTranspose2d(
|
41 |
+
in_channels=out_channels[0],
|
42 |
+
out_channels=out_channels[0],
|
43 |
+
kernel_size=4,
|
44 |
+
stride=4,
|
45 |
+
padding=0),
|
46 |
+
nn.ConvTranspose2d(
|
47 |
+
in_channels=out_channels[1],
|
48 |
+
out_channels=out_channels[1],
|
49 |
+
kernel_size=2,
|
50 |
+
stride=2,
|
51 |
+
padding=0),
|
52 |
+
nn.Identity(),
|
53 |
+
nn.Conv2d(
|
54 |
+
in_channels=out_channels[3],
|
55 |
+
out_channels=out_channels[3],
|
56 |
+
kernel_size=3,
|
57 |
+
stride=2,
|
58 |
+
padding=1)
|
59 |
+
])
|
60 |
+
|
61 |
+
if use_clstoken:
|
62 |
+
self.readout_projects = nn.ModuleList()
|
63 |
+
for _ in range(len(self.projects)):
|
64 |
+
self.readout_projects.append(
|
65 |
+
nn.Sequential(
|
66 |
+
nn.Linear(2 * in_channels, in_channels),
|
67 |
+
nn.GELU()))
|
68 |
+
|
69 |
+
self.scratch = _make_scratch(
|
70 |
+
out_channels,
|
71 |
+
features,
|
72 |
+
groups=1,
|
73 |
+
expand=False,
|
74 |
+
)
|
75 |
+
|
76 |
+
self.scratch.stem_transpose = None
|
77 |
+
|
78 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
79 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
80 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
81 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
82 |
+
|
83 |
+
head_features_1 = features
|
84 |
+
head_features_2 = 32
|
85 |
+
|
86 |
+
if nclass > 1:
|
87 |
+
self.scratch.output_conv = nn.Sequential(
|
88 |
+
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
|
89 |
+
nn.ReLU(True),
|
90 |
+
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
94 |
+
|
95 |
+
self.scratch.output_conv2 = nn.Sequential(
|
96 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
97 |
+
nn.ReLU(True),
|
98 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
99 |
+
nn.ReLU(True),
|
100 |
+
nn.Identity(),
|
101 |
+
)
|
102 |
+
|
103 |
+
def forward(self, out_features, patch_h, patch_w):
|
104 |
+
out = []
|
105 |
+
for i, x in enumerate(out_features):
|
106 |
+
if self.use_clstoken:
|
107 |
+
x, cls_token = x[0], x[1]
|
108 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
109 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
110 |
+
else:
|
111 |
+
x = x[0]
|
112 |
+
|
113 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
114 |
+
|
115 |
+
x = self.projects[i](x)
|
116 |
+
x = self.resize_layers[i](x)
|
117 |
+
|
118 |
+
out.append(x)
|
119 |
+
|
120 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
121 |
+
|
122 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
123 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
124 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
125 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
126 |
+
|
127 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
128 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
129 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
130 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
131 |
+
|
132 |
+
out = self.scratch.output_conv1(path_1)
|
133 |
+
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
134 |
+
out = self.scratch.output_conv2(out)
|
135 |
+
|
136 |
+
return out
|
137 |
+
|
138 |
+
|
139 |
+
class DPT_DINOv2(nn.Module):
|
140 |
+
def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True):
|
141 |
+
super(DPT_DINOv2, self).__init__()
|
142 |
+
|
143 |
+
assert encoder in ['vits', 'vitb', 'vitl']
|
144 |
+
|
145 |
+
# in case the Internet connection is not stable, please load the DINOv2 locally
|
146 |
+
if localhub:
|
147 |
+
self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
|
148 |
+
else:
|
149 |
+
self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
|
150 |
+
|
151 |
+
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
152 |
+
|
153 |
+
self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
h, w = x.shape[-2:]
|
157 |
+
|
158 |
+
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
|
159 |
+
|
160 |
+
patch_h, patch_w = h // 14, w // 14
|
161 |
+
|
162 |
+
depth = self.depth_head(features, patch_h, patch_w)
|
163 |
+
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
|
164 |
+
depth = F.relu(depth)
|
165 |
+
|
166 |
+
return depth.squeeze(1)
|
167 |
+
|
168 |
+
|
169 |
+
class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
|
170 |
+
def __init__(self, config):
|
171 |
+
super().__init__(**config)
|
172 |
+
|
173 |
+
|
174 |
+
if __name__ == '__main__':
|
175 |
+
parser = argparse.ArgumentParser()
|
176 |
+
parser.add_argument(
|
177 |
+
"--encoder",
|
178 |
+
default="vits",
|
179 |
+
type=str,
|
180 |
+
choices=["vits", "vitb", "vitl"],
|
181 |
+
)
|
182 |
+
args = parser.parse_args()
|
183 |
+
|
184 |
+
model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
|
185 |
+
|
186 |
+
print(model)
|
187 |
+
|
depth_anything/util/transform.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from PIL import Image, ImageOps, ImageFilter
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
import math
|
10 |
+
|
11 |
+
|
12 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
13 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
sample (dict): sample
|
17 |
+
size (tuple): image size
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
tuple: new size
|
21 |
+
"""
|
22 |
+
shape = list(sample["disparity"].shape)
|
23 |
+
|
24 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
25 |
+
return sample
|
26 |
+
|
27 |
+
scale = [0, 0]
|
28 |
+
scale[0] = size[0] / shape[0]
|
29 |
+
scale[1] = size[1] / shape[1]
|
30 |
+
|
31 |
+
scale = max(scale)
|
32 |
+
|
33 |
+
shape[0] = math.ceil(scale * shape[0])
|
34 |
+
shape[1] = math.ceil(scale * shape[1])
|
35 |
+
|
36 |
+
# resize
|
37 |
+
sample["image"] = cv2.resize(
|
38 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
39 |
+
)
|
40 |
+
|
41 |
+
sample["disparity"] = cv2.resize(
|
42 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
43 |
+
)
|
44 |
+
sample["mask"] = cv2.resize(
|
45 |
+
sample["mask"].astype(np.float32),
|
46 |
+
tuple(shape[::-1]),
|
47 |
+
interpolation=cv2.INTER_NEAREST,
|
48 |
+
)
|
49 |
+
sample["mask"] = sample["mask"].astype(bool)
|
50 |
+
|
51 |
+
return tuple(shape)
|
52 |
+
|
53 |
+
|
54 |
+
class Resize(object):
|
55 |
+
"""Resize sample to given size (width, height).
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
width,
|
61 |
+
height,
|
62 |
+
resize_target=True,
|
63 |
+
keep_aspect_ratio=False,
|
64 |
+
ensure_multiple_of=1,
|
65 |
+
resize_method="lower_bound",
|
66 |
+
image_interpolation_method=cv2.INTER_AREA,
|
67 |
+
):
|
68 |
+
"""Init.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
width (int): desired output width
|
72 |
+
height (int): desired output height
|
73 |
+
resize_target (bool, optional):
|
74 |
+
True: Resize the full sample (image, mask, target).
|
75 |
+
False: Resize image only.
|
76 |
+
Defaults to True.
|
77 |
+
keep_aspect_ratio (bool, optional):
|
78 |
+
True: Keep the aspect ratio of the input sample.
|
79 |
+
Output sample might not have the given width and height, and
|
80 |
+
resize behaviour depends on the parameter 'resize_method'.
|
81 |
+
Defaults to False.
|
82 |
+
ensure_multiple_of (int, optional):
|
83 |
+
Output width and height is constrained to be multiple of this parameter.
|
84 |
+
Defaults to 1.
|
85 |
+
resize_method (str, optional):
|
86 |
+
"lower_bound": Output will be at least as large as the given size.
|
87 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
88 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
89 |
+
Defaults to "lower_bound".
|
90 |
+
"""
|
91 |
+
self.__width = width
|
92 |
+
self.__height = height
|
93 |
+
|
94 |
+
self.__resize_target = resize_target
|
95 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
96 |
+
self.__multiple_of = ensure_multiple_of
|
97 |
+
self.__resize_method = resize_method
|
98 |
+
self.__image_interpolation_method = image_interpolation_method
|
99 |
+
|
100 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
101 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
if max_val is not None and y > max_val:
|
104 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
105 |
+
|
106 |
+
if y < min_val:
|
107 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
108 |
+
|
109 |
+
return y
|
110 |
+
|
111 |
+
def get_size(self, width, height):
|
112 |
+
# determine new height and width
|
113 |
+
scale_height = self.__height / height
|
114 |
+
scale_width = self.__width / width
|
115 |
+
|
116 |
+
if self.__keep_aspect_ratio:
|
117 |
+
if self.__resize_method == "lower_bound":
|
118 |
+
# scale such that output size is lower bound
|
119 |
+
if scale_width > scale_height:
|
120 |
+
# fit width
|
121 |
+
scale_height = scale_width
|
122 |
+
else:
|
123 |
+
# fit height
|
124 |
+
scale_width = scale_height
|
125 |
+
elif self.__resize_method == "upper_bound":
|
126 |
+
# scale such that output size is upper bound
|
127 |
+
if scale_width < scale_height:
|
128 |
+
# fit width
|
129 |
+
scale_height = scale_width
|
130 |
+
else:
|
131 |
+
# fit height
|
132 |
+
scale_width = scale_height
|
133 |
+
elif self.__resize_method == "minimal":
|
134 |
+
# scale as least as possbile
|
135 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
136 |
+
# fit width
|
137 |
+
scale_height = scale_width
|
138 |
+
else:
|
139 |
+
# fit height
|
140 |
+
scale_width = scale_height
|
141 |
+
else:
|
142 |
+
raise ValueError(
|
143 |
+
f"resize_method {self.__resize_method} not implemented"
|
144 |
+
)
|
145 |
+
|
146 |
+
if self.__resize_method == "lower_bound":
|
147 |
+
new_height = self.constrain_to_multiple_of(
|
148 |
+
scale_height * height, min_val=self.__height
|
149 |
+
)
|
150 |
+
new_width = self.constrain_to_multiple_of(
|
151 |
+
scale_width * width, min_val=self.__width
|
152 |
+
)
|
153 |
+
elif self.__resize_method == "upper_bound":
|
154 |
+
new_height = self.constrain_to_multiple_of(
|
155 |
+
scale_height * height, max_val=self.__height
|
156 |
+
)
|
157 |
+
new_width = self.constrain_to_multiple_of(
|
158 |
+
scale_width * width, max_val=self.__width
|
159 |
+
)
|
160 |
+
elif self.__resize_method == "minimal":
|
161 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
162 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
163 |
+
else:
|
164 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
165 |
+
|
166 |
+
return (new_width, new_height)
|
167 |
+
|
168 |
+
def __call__(self, sample):
|
169 |
+
width, height = self.get_size(
|
170 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
171 |
+
)
|
172 |
+
|
173 |
+
# resize sample
|
174 |
+
sample["image"] = cv2.resize(
|
175 |
+
sample["image"],
|
176 |
+
(width, height),
|
177 |
+
interpolation=self.__image_interpolation_method,
|
178 |
+
)
|
179 |
+
|
180 |
+
if self.__resize_target:
|
181 |
+
if "disparity" in sample:
|
182 |
+
sample["disparity"] = cv2.resize(
|
183 |
+
sample["disparity"],
|
184 |
+
(width, height),
|
185 |
+
interpolation=cv2.INTER_NEAREST,
|
186 |
+
)
|
187 |
+
|
188 |
+
if "depth" in sample:
|
189 |
+
sample["depth"] = cv2.resize(
|
190 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
191 |
+
)
|
192 |
+
|
193 |
+
if "semseg_mask" in sample:
|
194 |
+
# sample["semseg_mask"] = cv2.resize(
|
195 |
+
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
|
196 |
+
# )
|
197 |
+
sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
|
198 |
+
|
199 |
+
if "mask" in sample:
|
200 |
+
sample["mask"] = cv2.resize(
|
201 |
+
sample["mask"].astype(np.float32),
|
202 |
+
(width, height),
|
203 |
+
interpolation=cv2.INTER_NEAREST,
|
204 |
+
)
|
205 |
+
# sample["mask"] = sample["mask"].astype(bool)
|
206 |
+
|
207 |
+
# print(sample['image'].shape, sample['depth'].shape)
|
208 |
+
return sample
|
209 |
+
|
210 |
+
|
211 |
+
class NormalizeImage(object):
|
212 |
+
"""Normlize image by given mean and std.
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self, mean, std):
|
216 |
+
self.__mean = mean
|
217 |
+
self.__std = std
|
218 |
+
|
219 |
+
def __call__(self, sample):
|
220 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
221 |
+
|
222 |
+
return sample
|
223 |
+
|
224 |
+
|
225 |
+
class PrepareForNet(object):
|
226 |
+
"""Prepare sample for usage as network input.
|
227 |
+
"""
|
228 |
+
|
229 |
+
def __init__(self):
|
230 |
+
pass
|
231 |
+
|
232 |
+
def __call__(self, sample):
|
233 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
234 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
235 |
+
|
236 |
+
if "mask" in sample:
|
237 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
238 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
239 |
+
|
240 |
+
if "depth" in sample:
|
241 |
+
depth = sample["depth"].astype(np.float32)
|
242 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
243 |
+
|
244 |
+
if "semseg_mask" in sample:
|
245 |
+
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
|
246 |
+
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
|
247 |
+
|
248 |
+
return sample
|