import torch
import torch.nn as nn
import torchsparse
import torchsparse.nn as spnn
from torchsparse.tensor import PointTensor

# provides initial_voxelize, point_to_voxel, voxel_to_point
from tsparse.torchsparse_utils import *

# __all__ = ['SPVCNN', 'SConv3d', 'SparseConvGRU']


class ConvBnReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))


class ConvBnReLU3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))


################################### feature net ######################################
class FeatureNet(nn.Module):
    """
    Outputs three levels of features using an FPN structure.
    """

    def __init__(self):
        super(FeatureNet, self).__init__()

        self.conv0 = nn.Sequential(
            ConvBnReLU(3, 8, 3, 1, 1),
            ConvBnReLU(8, 8, 3, 1, 1))

        self.conv1 = nn.Sequential(
            ConvBnReLU(8, 16, 5, 2, 2),
            ConvBnReLU(16, 16, 3, 1, 1),
            ConvBnReLU(16, 16, 3, 1, 1))

        self.conv2 = nn.Sequential(
            ConvBnReLU(16, 32, 5, 2, 2),
            ConvBnReLU(32, 32, 3, 1, 1),
            ConvBnReLU(32, 32, 3, 1, 1))

        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        # to reduce channel size of the outputs from FPN
        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True) + y

    def forward(self, x):
        # x: (B, 3, H, W)
        conv0 = self.conv0(x)      # (B, 8, H, W)
        conv1 = self.conv1(conv0)  # (B, 16, H//2, W//2)
        conv2 = self.conv2(conv1)  # (B, 32, H//4, W//4)

        feat2 = self.toplayer(conv2)                         # (B, 32, H//4, W//4)
        feat1 = self._upsample_add(feat2, self.lat1(conv1))  # (B, 32, H//2, W//2)
        feat0 = self._upsample_add(feat1, self.lat0(conv0))  # (B, 32, H, W)

        # reduce output channels
        feat1 = self.smooth1(feat1)  # (B, 16, H//2, W//2)
        feat0 = self.smooth0(feat0)  # (B, 8, H, W)

        # feats = {"level_0": feat0,
        #          "level_1": feat1,
        #          "level_2": feat2}

        return [feat2, feat1, feat0]  # coarser to finer features
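# --- Example (not part of the original module) ---
# A minimal sanity-check sketch for FeatureNet. The batch size and image
# dimensions below are illustrative assumptions; H and W must be divisible
# by 4 so the two upsample-add steps line up with the lateral features.
def _demo_feature_net():
    net = FeatureNet().eval()
    images = torch.randn(2, 3, 64, 80)  # (B, 3, H, W)
    with torch.no_grad():
        feat2, feat1, feat0 = net(images)
    assert feat2.shape == (2, 32, 16, 20)  # coarsest: (B, 32, H//4, W//4)
    assert feat1.shape == (2, 16, 32, 40)  # middle:   (B, 16, H//2, W//2)
    assert feat0.shape == (2, 8, 64, 80)   # finest:   (B, 8, H, W)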
class BasicSparseConvolutionBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True))

    def forward(self, x):
        out = self.net(x)
        return out


class BasicSparseDeconvolutionBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride, transposed=True),
            spnn.BatchNorm(outc),
            spnn.ReLU(True))

    def forward(self, x):
        return self.net(x)


class SparseResidualBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, stride=1),
            spnn.BatchNorm(outc))

        # identity shortcut when shapes match, otherwise a 1x1x1 projection
        self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \
            nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
                spnn.BatchNorm(outc)
            )

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        out = self.relu(self.net(x) + self.downsample(x))
        return out


class SPVCNN(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        self.dropout = kwargs['dropout']

        cr = kwargs.get('cr', 1.0)
        cs = [32, 64, 128, 96, 96]
        cs = [int(cr * x) for x in cs]

        if 'pres' in kwargs and 'vres' in kwargs:
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True)
        )

        self.stage1 = nn.Sequential(
            BasicSparseConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
            SparseResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
            SparseResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
        )

        self.stage2 = nn.Sequential(
            BasicSparseConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
            SparseResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
            SparseResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
        )

        self.up1 = nn.ModuleList([
            BasicSparseDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2),
            nn.Sequential(
                SparseResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1, dilation=1),
                SparseResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
            )
        ])

        self.up2 = nn.ModuleList([
            BasicSparseDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2),
            nn.Sequential(
                SparseResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1, dilation=1),
                SparseResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
            )
        ])

        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(cs[0], cs[2]),
                nn.BatchNorm1d(cs[2]),
                nn.ReLU(True),
            ),
            nn.Sequential(
                nn.Linear(cs[2], cs[4]),
                nn.BatchNorm1d(cs[4]),
                nn.ReLU(True),
            )
        ])

        self.weight_initialization()

        if self.dropout:
            self.dropout = nn.Dropout(0.3, True)

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, z):
        # z: PointTensor; x0..y4 are SparseTensor intermediates
        x0 = initial_voxelize(z, self.pres, self.vres)

        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)

        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)
        x2 = self.stage2(x1)
        z1 = voxel_to_point(x2, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)

        y3 = point_to_voxel(x2, z1)
        if self.dropout:
            y3.F = self.dropout(y3.F)
        y3 = self.up1[0](y3)
        y3 = torchsparse.cat([y3, x1])
        y3 = self.up1[1](y3)

        y4 = self.up2[0](y3)
        y4 = torchsparse.cat([y4, x0])
        y4 = self.up2[1](y4)
        z3 = voxel_to_point(y4, z1)
        z3.F = z3.F + self.point_transforms[1](z1.F)

        return z3.F
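# --- Example (not part of the original module) ---
# A hedged usage sketch for SPVCNN. Assumptions: PointTensor coordinates are
# float (x, y, z, batch_idx) with the batch index in the last column, and
# `pres`/`vres` are the point and voxel resolutions consumed by
# initial_voxelize from tsparse.torchsparse_utils; the concrete values below
# are illustrative. torchsparse kernels generally require a CUDA device.
def _demo_spvcnn():
    model = SPVCNN(in_channels=9, dropout=False, pres=1, vres=0.08).cuda()
    n = 2048
    coords = torch.rand(n, 4).cuda() * 50.0  # points in voxel units
    coords[:, 3] = 0                         # single sample in the batch
    feats = torch.randn(n, 9).cuda()
    z = PointTensor(feats, coords)
    out = model(z)  # (n, 96) per-point features when cr == 1.0
    return out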
class SparseCostRegNet(nn.Module):
    """
    Sparse cost regularization network;
    requires sparse tensors as input.
    """

    def __init__(self, d_in, d_out=8):
        super(SparseCostRegNet, self).__init__()
        self.d_in = d_in
        self.d_out = d_out

        self.conv0 = BasicSparseConvolutionBlock(d_in, d_out)

        self.conv1 = BasicSparseConvolutionBlock(d_out, 16, stride=2)
        self.conv2 = BasicSparseConvolutionBlock(16, 16)

        self.conv3 = BasicSparseConvolutionBlock(16, 32, stride=2)
        self.conv4 = BasicSparseConvolutionBlock(32, 32)

        self.conv5 = BasicSparseConvolutionBlock(32, 64, stride=2)
        self.conv6 = BasicSparseConvolutionBlock(64, 64)

        self.conv7 = BasicSparseDeconvolutionBlock(64, 32, ks=3, stride=2)
        self.conv9 = BasicSparseDeconvolutionBlock(32, 16, ks=3, stride=2)
        self.conv11 = BasicSparseDeconvolutionBlock(16, d_out, ks=3, stride=2)

    def forward(self, x):
        """
        :param x: sparse tensor
        :return: feature tensor (the .F features of the output sparse tensor)
        """
        conv0 = self.conv0(x)
        conv2 = self.conv2(self.conv1(conv0))
        conv4 = self.conv4(self.conv3(conv2))

        x = self.conv6(self.conv5(conv4))
        # decoder with additive skip connections; free encoder features
        # as soon as they are consumed
        x = conv4 + self.conv7(x)
        del conv4
        x = conv2 + self.conv9(x)
        del conv2
        x = conv0 + self.conv11(x)
        del conv0
        return x.F


class SConv3d(nn.Module):
    def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = spnn.Conv3d(inc, outc,
                               kernel_size=ks,
                               dilation=dilation,
                               stride=stride)
        self.point_transforms = nn.Sequential(
            nn.Linear(inc, outc),
        )
        self.pres = pres
        self.vres = vres

    def forward(self, z):
        x = initial_voxelize(z, self.pres, self.vres)
        x = self.net(x)
        out = voxel_to_point(x, z, nearest=False)
        out.F = out.F + self.point_transforms(z.F)
        return out
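# --- Example (not part of the original module) ---
# A hedged sketch for SparseCostRegNet. Assumptions: torchsparse's
# SparseTensor takes (feats, coords) with integer (x, y, z, batch_idx)
# coordinates, duplicate voxels should be removed before construction, and a
# CUDA device is available. The additive skip connections work because
# torchsparse's transposed convolutions map back onto the coordinates of the
# matching encoder stage.
def _demo_cost_reg_net():
    from torchsparse import SparseTensor

    net = SparseCostRegNet(d_in=24, d_out=8).cuda()
    coords = torch.randint(0, 64, (4096, 4), dtype=torch.int)
    coords[:, 3] = 0                                 # single sample in the batch
    coords = torch.unique(coords, dim=0).cuda()      # one entry per voxel
    feats = torch.randn(coords.shape[0], 24).cuda()
    out = net(SparseTensor(feats, coords))  # (num_voxels, 8) feature tensor
    return out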