# NOTE: removed non-Python page residue ("Spaces: / Paused") left over from a copy/paste.
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_depth_expectation(prob, depth_values):
    """Collapse a per-bin probability volume into an expected depth map.

    Args:
        prob: tensor whose dim 1 indexes depth bins (e.g. [B, D, H, W]).
        depth_values: per-bin depth anchors (e.g. [B, D]); two trailing
            singleton axes are appended so they broadcast over the
            spatial dims of ``prob``.

    Returns:
        Tensor with the bin dimension summed out (the expectation).
    """
    anchors = depth_values.reshape(depth_values.shape + (1, 1))
    return (prob * anchors).sum(dim=1)
class ConvBlock(nn.Module):
    """Single conv + ELU block.

    ``kernel_size=3`` uses reflection padding (keeps spatial size without
    zero-pad border artifacts); ``kernel_size=1`` is a plain pointwise conv.
    Attribute names (``conv``, ``nonlin``) are kept for checkpoint
    compatibility.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvBlock, self).__init__()
        if kernel_size == 3:
            self.conv = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
            )
        elif kernel_size == 1:
            self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        # conv -> ELU, fused into a single expression.
        return self.nonlin(self.conv(x))
class ConvBlock_double(nn.Module):
    """Two-stage conv block: a 3x3 (reflection-padded) or 1x1 conv, then a
    pointwise 1x1 conv, each followed by ELU.

    Attribute names (``conv``, ``nonlin``, ``conv_2``, ``nonlin_2``) are kept
    for checkpoint compatibility.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvBlock_double, self).__init__()
        if kernel_size == 3:
            self.conv = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
            )
        elif kernel_size == 1:
            self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
        self.nonlin = nn.ELU(inplace=True)
        # The second stage is always pointwise, regardless of kernel_size.
        self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1)
        self.nonlin_2 = nn.ELU(inplace=True)

    def forward(self, x):
        y = self.nonlin(self.conv(x))
        return self.nonlin_2(self.conv_2(y))
class DecoderFeature(nn.Module):
    """Top-down decoder that fuses a 4-level feature pyramid into one
    high-resolution feature map (coarsest level first, nearest-neighbor
    2x upsampling + skip concatenation at each step).

    Fix: the ``num_ch_dec`` default is now an immutable tuple instead of a
    shared mutable list (mutable-default-argument pitfall); indexing and all
    layer shapes are unchanged.

    Args:
        feat_channels: encoder channel counts per level, finest first
            (levels 0..3).
        num_ch_dec: decoder channel counts per level.
    """

    def __init__(self, feat_channels, num_ch_dec=(64, 64, 128, 256)):
        super(DecoderFeature, self).__init__()
        self.num_ch_dec = num_ch_dec
        self.feat_channels = feat_channels
        # Level 3 (coarsest): 1x1 projection, then fuse with level-2 skip.
        self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1)
        self.upconv_3_1 = ConvBlock_double(
            self.feat_channels[2] + self.num_ch_dec[3],
            self.num_ch_dec[3],
            kernel_size=1)
        # Levels 2 and 1 use 3x3 convs (reflection-padded inside ConvBlock).
        self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3)
        self.upconv_2_1 = ConvBlock_double(
            self.feat_channels[1] + self.num_ch_dec[2],
            self.num_ch_dec[2],
            kernel_size=3)
        self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3)
        self.upconv_1_1 = ConvBlock_double(
            self.feat_channels[0] + self.num_ch_dec[1],
            self.num_ch_dec[1],
            kernel_size=3)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, ref_feature):
        """Fuse the pyramid.

        Args:
            ref_feature: sequence of 4 feature maps, finest first; each level
                is assumed to be half the spatial size of the previous one
                (required for the cat after 2x upsampling) — confirm at caller.

        Returns:
            Feature map with ``num_ch_dec[1]`` channels at level-0 resolution.
        """
        x = ref_feature[3]
        x = self.upconv_3_0(x)
        x = torch.cat((self.upsample(x), ref_feature[2]), 1)
        x = self.upconv_3_1(x)
        x = self.upconv_2_0(x)
        x = torch.cat((self.upsample(x), ref_feature[1]), 1)
        x = self.upconv_2_1(x)
        x = self.upconv_1_0(x)
        x = torch.cat((self.upsample(x), ref_feature[0]), 1)
        x = self.upconv_1_1(x)
        return x
class UNet(nn.Module):
    """Small hourglass (UNet-style) refinement head built from ConvBnReLU
    blocks, with additive skip connections and a final 1x1 projection.

    Fix: an unknown ``channel_mode`` previously left ``channels`` unbound and
    surfaced later as a confusing ``NameError``; it now raises ``ValueError``
    immediately.

    Args:
        inp_ch: input channel count.
        output_chal: output channel count of the final 1x1 conv.
        down_sample_times: number of stride-2 encoder stages (and matching
            decoder stages); the channel table must be long enough for it.
        channel_mode: 'v0' (halving channels) or 'v1' (constant channels).
    """

    def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'):
        super(UNet, self).__init__()
        basic_block = ConvBnReLU
        num_depth = 128
        self.conv0 = basic_block(inp_ch, num_depth)
        if channel_mode == 'v0':
            channels = [num_depth, num_depth // 2, num_depth // 4, num_depth // 8, num_depth // 8]
        elif channel_mode == 'v1':
            channels = [num_depth] * 6
        else:
            raise ValueError(
                "Unknown channel_mode %r (expected 'v0' or 'v1')" % (channel_mode,))
        self.down_sample_times = down_sample_times
        # Encoder: each stage halves the resolution then refines in place.
        for i in range(down_sample_times):
            setattr(
                self, 'conv_%d' % i,
                nn.Sequential(
                    basic_block(channels[i], channels[i + 1], stride=2),
                    basic_block(channels[i + 1], channels[i + 1])
                )
            )
        # Decoder: transposed convs mirror the encoder stages.
        for i in range(down_sample_times - 1, -1, -1):
            setattr(self, 'deconv_%d' % i,
                    nn.Sequential(
                        nn.ConvTranspose2d(
                            channels[i + 1],
                            channels[i],
                            kernel_size=3,
                            padding=1,
                            output_padding=1,
                            stride=2,
                            bias=False),
                        nn.BatchNorm2d(channels[i]),
                        nn.ReLU(inplace=True)
                    )
                    )
        self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0)

    def forward(self, x):
        """Run the hourglass; returns a tensor with ``output_chal`` channels
        at the input resolution (input H/W must be divisible by
        2**down_sample_times for the skip additions to line up)."""
        features = {}
        conv0 = self.conv0(x)
        x = conv0
        features[0] = conv0
        for i in range(self.down_sample_times):
            x = getattr(self, 'conv_%d' % i)(x)
            features[i + 1] = x
        for i in range(self.down_sample_times - 1, -1, -1):
            # Additive skip connection with the matching encoder feature.
            x = features[i] + getattr(self, 'deconv_%d' % i)(x)
        x = self.prob(x)
        return x
class ConvBnReLU(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU, self).__init__()
        # Bias is redundant before BatchNorm, so the conv omits it.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return F.relu(y, inplace=True)
class HourglassDecoder(nn.Module):
    """Monocular depth decoder: FPN-style feature fusion (DecoderFeature) +
    hourglass refinement (UNet), with a discretized log-depth regression head
    and a learned residual upsampling branch.

    Fixes versus the original:
      * depth-bin anchors and the pixel mesh grid are created on the module's
        own device instead of a hard-coded ``'cuda'`` (CPU inference works);
      * the cached ``depth_expectation_anchor`` buffer is re-registered when a
        batch larger than the first one arrives (previously the ``[:B]`` slice
        silently returned an undersized anchor).

    Args:
        cfg: config node providing
            cfg.model.decode_head.in_channels      e.g. [256, 512, 1024, 2048]
            cfg.model.decode_head.decoder_channel  e.g. [64, 64, 128, 256]
            cfg.data_basic.depth_normalize         (min_depth, max_depth);
                min_depth must be > 0 (log-spaced bins).
    """

    def __init__(self, cfg):
        super(HourglassDecoder, self).__init__()
        self.inchannels = cfg.model.decode_head.in_channels  # [256, 512, 1024, 2048]
        self.decoder_channels = cfg.model.decode_head.decoder_channel  # [64, 64, 128, 256]
        self.min_val = cfg.data_basic.depth_normalize[0]
        self.max_val = cfg.data_basic.depth_normalize[1]
        self.num_ch_dec = self.decoder_channels
        self.num_depth_regressor_anchor = 512  # number of discrete depth bins
        self.feat_channels = self.inchannels
        unet_in_channel = self.num_ch_dec[1]
        unet_out_channel = 256
        self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec)
        # +1 output channel carries the confidence map next to the depth features.
        self.conv_out_2 = UNet(inp_ch=unet_in_channel,
                               output_chal=unet_out_channel + 1,
                               down_sample_times=3,
                               channel_mode='v0',
                               )
        # Per-pixel logits over the depth-bin anchors.
        self.depth_regressor_2 = nn.Sequential(
            nn.Conv2d(unet_out_channel,
                      self.num_depth_regressor_anchor,
                      kernel_size=3,
                      padding=1,
                      ),
            nn.BatchNorm2d(self.num_depth_regressor_anchor),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                self.num_depth_regressor_anchor,
                self.num_depth_regressor_anchor,
                kernel_size=1,
            )
        )
        self.residual_channel = 16
        # Residual branch input: depth (1) + mesh grid (2) + depth features.
        self.conv_up_2 = nn.Sequential(
            nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1),
            nn.BatchNorm2d(self.residual_channel),
            nn.ReLU(),
            nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
            nn.Upsample(scale_factor=4),
            nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.residual_channel, 1, 1, padding=0),
        )

    def get_bins(self, bins_num):
        """Return ``bins_num`` depth anchors spaced uniformly in log-depth,
        created on this module's device (fix: was hard-coded to 'cuda')."""
        device = next(self.parameters()).device
        depth_bins_vec = torch.linspace(
            math.log(self.min_val), math.log(self.max_val), bins_num, device=device)
        return torch.exp(depth_bins_vec)

    def register_depth_expectation_anchor(self, bins_num, B):
        """Cache per-batch depth anchors ([B, bins_num]) as a non-persistent
        buffer (excluded from the state dict)."""
        depth_bins_vec = self.get_bins(bins_num)
        depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
        self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)

    def upsample(self, x, scale_factor=2):
        """Nearest-neighbor spatial upsampling."""
        return F.interpolate(x, scale_factor=scale_factor, mode='nearest')

    def regress_depth_2(self, feature_map_d):
        """Softmax over depth bins, then expectation against the anchors.

        Returns a [B, 1, H, W] depth map.
        """
        prob = self.depth_regressor_2(feature_map_d).softmax(dim=1)
        B = prob.shape[0]
        # Fix: re-register when the cached anchor is missing OR smaller than
        # the current batch (previously a larger batch silently sliced an
        # undersized buffer).
        if ('depth_expectation_anchor' not in self._buffers
                or self._buffers['depth_expectation_anchor'].shape[0] < B):
            self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
        d = compute_depth_expectation(
            prob,
            self.depth_expectation_anchor[:B, ...]
        ).unsqueeze(1)
        return d

    def create_mesh_grid(self, height, width, batch, device=None, set_buffer=True):
        """Pixel-coordinate grid of shape [batch, 2, height, width], channel
        order (x, y).

        Fix: ``device=None`` (the new default) resolves to this module's
        device instead of the previous hard-coded 'cuda'. ``set_buffer`` is
        accepted for backward compatibility but unused (as in the original).
        """
        if device is None:
            device = next(self.parameters()).device
        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
                               torch.arange(0, width, dtype=torch.float32, device=device)],
                              indexing='ij')
        meshgrid = torch.stack((x, y))
        meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
        return meshgrid

    def forward(self, features_mono, **kwargs):
        """Predict depth and confidence from reference-view encoder features.

        Args:
            features_mono: list of 4 feature maps of the reference view,
                finest first ([ref_f1, ref_f2, ref_f3, ref_f4]).

        Returns:
            dict with 'prediction' ([B, 1, 4H, 4W] depth), 'confidence'
            ([B, 1, 4H, 4W]), and 'pred_logit' (None).
        """
        ref_feat = features_mono
        feature_map_mono = self.decoder_mono(ref_feat)
        feature_map_mono_pred = self.conv_out_2(feature_map_mono)
        # Last channel is confidence; the rest are depth features.
        confidence_map_2 = feature_map_mono_pred[:, -1:, :, :]
        feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :]
        depth_pred_2 = self.regress_depth_2(feature_map_d_2)
        B, _, H, W = depth_pred_2.shape
        meshgrid = self.create_mesh_grid(H, W, B, device=depth_pred_2.device)
        # Coarse prediction upsampled 4x plus a small (0.1-weighted) learned residual.
        depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \
            self.conv_up_2(
                torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1)
            )
        confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4)
        outputs = dict(
            prediction=depth_pred_mono,
            confidence=confidence_map_mono,
            pred_logit=None,
        )
        return outputs