File size: 6,162 Bytes
f53b39e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
"""Copyright (C) 2024 Apple Inc. All Rights Reserved.
Dense Prediction Transformer Decoder architecture.
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
"""
from __future__ import annotations
from typing import Iterable
import torch
from torch import nn
class MultiresConvDecoder(nn.Module):
    """Decoder for multi-resolution encodings.

    Every encoder level is first projected to ``dim_decoder`` channels, then
    the levels are fused one by one, starting at the lowest resolution
    (last level) and finishing at the highest resolution (level 0).
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        dim_decoder: int,
    ):
        """Initialize multiresolution convolutional decoder.

        Args:
        ----
            dims_encoder: Expected dims at each level from the encoder.
            dim_decoder: Dim of decoder features.

        """
        super().__init__()
        self.dims_encoder = list(dims_encoder)
        self.dim_decoder = dim_decoder
        self.dim_out = dim_decoder

        num_encoders = len(self.dims_encoder)

        # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
        # when the dimensions mismatch. Otherwise we do not do anything, which is
        # the default behavior of monodepth.
        conv0 = (
            nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
            if self.dims_encoder[0] != dim_decoder
            else nn.Identity()
        )

        # All remaining levels are always projected with a 3x3 convolution.
        convs = [conv0]
        for i in range(1, num_encoders):
            convs.append(
                nn.Conv2d(
                    self.dims_encoder[i],
                    dim_decoder,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            )
        self.convs = nn.ModuleList(convs)

        # One fusion block per level; every level except level 0 is built
        # with deconv=True.
        fusions = []
        for i in range(num_encoders):
            fusions.append(
                FeatureFusionBlock2d(
                    num_features=dim_decoder,
                    deconv=(i != 0),
                    batch_norm=False,
                )
            )
        self.fusions = nn.ModuleList(fusions)

    def forward(
        self, encodings: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode the multi-resolution encodings.

        Args:
        ----
            encodings: One encoding per level, ordered from the highest
                resolution (index 0) to the lowest resolution (last index).
                Its length must equal ``len(self.dims_encoder)``.

        Returns:
        -------
            A pair ``(features, lowres_features)``: the result of fusing all
            levels down to level 0, and the projected lowest-resolution
            encoding before any fusion was applied.

        Raises:
        ------
            ValueError: If the number of encodings does not match the number
                of encoder levels this decoder was built for.

        """
        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)

        if num_levels != num_encoders:
            # Report the level count that the check above actually requires.
            raise ValueError(
                f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
            )

        # Project features of different encoder dims to the same decoder dim.
        # Fuse features from the lowest resolution (num_levels-1)
        # to the highest (0).
        features = self.convs[-1](encodings[-1])
        lowres_features = features
        features = self.fusions[-1](features)
        for i in range(num_levels - 2, -1, -1):
            features_i = self.convs[i](encodings[i])
            features = self.fusions[i](features, features_i)
        return features, lowres_features
class ResidualBlock(nn.Module):
    """Residual wrapper: adds a branch output to the (optionally mapped) input.

    Follows the generic residual formulation from
    He et al. - Identity Mappings in Deep Residual Networks (2016),
    https://arxiv.org/abs/1603.05027
    Both the residual branch and the optional shortcut are supplied by the
    caller, so factory functions can customize the block.
    """

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Store the residual branch and the optional shortcut.

        Args:
        ----
            residual: Module that computes the residual term from the input.
            shortcut: Optional module applied to the input before the
                addition. When None, the input itself is used.

        """
        super().__init__()
        self.residual = residual
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shortcut path plus the residual branch evaluated on x."""
        # The residual branch is evaluated first, then the shortcut.
        branch = self.residual(x)
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + branch
class FeatureFusionBlock2d(nn.Module):
    """Feature fusion block for the DPT decoder.

    Optionally merges a second feature map into the first, refines the result
    with a residual block, optionally applies a stride-2 transposed
    convolution, and finishes with a 1x1 output convolution.
    """

    def __init__(
        self,
        num_features: int,
        deconv: bool = False,
        batch_norm: bool = False,
    ):
        """Build the sub-modules of the fusion block.

        Args:
        ----
            num_features: Channel count of both the input and the output.
            deconv: If True, a transposed convolution runs before the final
                output convolution.
            batch_norm: If True, the residual blocks contain batch
                normalization layers.

        """
        super().__init__()

        # resnet1 processes the skip input, resnet2 processes the merged features.
        self.resnet1 = self._residual_block(num_features, batch_norm)
        self.resnet2 = self._residual_block(num_features, batch_norm)

        self.use_deconv = deconv
        if deconv:
            self.deconv = nn.ConvTranspose2d(
                in_channels=num_features,
                out_channels=num_features,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )

        self.out_conv = nn.Conv2d(
            num_features, num_features, kernel_size=1, stride=1, padding=0, bias=True
        )

        # The addition goes through a FloatFunctional module rather than a bare "+".
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Fuse x1 (when given) into x0 and return the processed features."""
        fused = x0

        if x1 is not None:
            fused = self.skip_add.add(fused, self.resnet1(x1))

        fused = self.resnet2(fused)

        if self.use_deconv:
            fused = self.deconv(fused)

        return self.out_conv(fused)

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool):
        """Create a residual block made of two ReLU -> 3x3 conv (-> BN) stages."""
        stages: list[nn.Module] = []
        for _ in range(2):
            stages.append(nn.ReLU(False))
            stages.append(
                nn.Conv2d(
                    num_features,
                    num_features,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    # The conv bias is dropped when a BatchNorm layer follows.
                    bias=not batch_norm,
                )
            )
            if batch_norm:
                stages.append(nn.BatchNorm2d(num_features))

        return ResidualBlock(nn.Sequential(*stages))
|