# Text-to-3D / image-to-3d — Chao Xu
# SparseNeuS and elevation estimation (commit 854f0d0)
import torch
import torch.nn as nn
import torchsparse
import torchsparse.nn as spnn
from torchsparse.tensor import PointTensor
from tsparse.torchsparse_utils import *
# __all__ = ['SPVCNN', 'SConv3d', 'SparseConvGRU']
class ConvBnReLU(nn.Module):
    """2D convolution followed by batch normalization and in-place ReLU.

    The convolution has no bias since BatchNorm2d supplies the affine shift.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, pad=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.activation(out)
class ConvBnReLU3D(nn.Module):
    """3D convolution followed by batch normalization and in-place ReLU.

    The convolution has no bias since BatchNorm3d supplies the affine shift.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, pad=1):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.activation(out)
################################### feature net ######################################
class FeatureNet(nn.Module):
    """FPN-style image encoder that outputs features at three scales.

    Given an RGB image (B, 3, H, W), returns [feat2, feat1, feat0] — coarser
    to finer: (B, 32, H/4, W/4), (B, 16, H/2, W/2), (B, 8, H, W).
    Assumes H and W are divisible by 4.
    """

    def __init__(self):
        super().__init__()
        # bottom-up backbone: full, half and quarter resolution
        self.conv0 = nn.Sequential(
            ConvBnReLU(3, 8, 3, 1, 1),
            ConvBnReLU(8, 8, 3, 1, 1))

        self.conv1 = nn.Sequential(
            ConvBnReLU(8, 16, 5, 2, 2),
            ConvBnReLU(16, 16, 3, 1, 1),
            ConvBnReLU(16, 16, 3, 1, 1))

        self.conv2 = nn.Sequential(
            ConvBnReLU(16, 32, 5, 2, 2),
            ConvBnReLU(32, 32, 3, 1, 1),
            ConvBnReLU(32, 32, 3, 1, 1))

        # 1x1 lateral projections into the common 32-channel FPN space
        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        # 3x3 convs that reduce channel count of the merged FPN outputs
        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        # upsample coarse map 2x and merge with the lateral feature
        up = torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True)
        return up + y

    def forward(self, x):
        # bottom-up pass
        c0 = self.conv0(x)   # (B, 8, H, W)
        c1 = self.conv1(c0)  # (B, 16, H/2, W/2)
        c2 = self.conv2(c1)  # (B, 32, H/4, W/4)

        # top-down pass with lateral connections
        feat2 = self.toplayer(c2)                         # (B, 32, H/4, W/4)
        feat1 = self._upsample_add(feat2, self.lat1(c1))  # (B, 32, H/2, W/2)
        feat0 = self._upsample_add(feat1, self.lat0(c0))  # (B, 32, H, W)

        # shrink the finer levels to their final channel widths
        feat1 = self.smooth1(feat1)  # (B, 16, H/2, W/2)
        feat0 = self.smooth0(feat0)  # (B, 8, H, W)

        return [feat2, feat1, feat0]  # coarser to finer features
class BasicSparseConvolutionBlock(nn.Module):
    """Sparse Conv3d -> BatchNorm -> ReLU building block (torchsparse)."""

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        conv = spnn.Conv3d(inc, outc, kernel_size=ks,
                           dilation=dilation, stride=stride)
        self.net = nn.Sequential(conv,
                                 spnn.BatchNorm(outc),
                                 spnn.ReLU(True))

    def forward(self, x):
        return self.net(x)
class BasicSparseDeconvolutionBlock(nn.Module):
    """Transposed sparse Conv3d -> BatchNorm -> ReLU upsampling block."""

    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        deconv = spnn.Conv3d(inc, outc, kernel_size=ks,
                             stride=stride, transposed=True)
        self.net = nn.Sequential(deconv,
                                 spnn.BatchNorm(outc),
                                 spnn.ReLU(True))

    def forward(self, x):
        return self.net(x)
class SparseResidualBlock(nn.Module):
    """Two-layer sparse residual block with an optional projection shortcut.

    The shortcut is the identity when channels and stride match; otherwise a
    1x1x1 sparse conv + BatchNorm projects the input to the output shape.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks,
                        dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc, outc, kernel_size=ks,
                        dilation=dilation, stride=1),
            spnn.BatchNorm(outc))

        if inc == outc and stride == 1:
            self.downsample = nn.Sequential()  # identity shortcut
        else:
            self.downsample = nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1,
                            dilation=1, stride=stride),
                spnn.BatchNorm(outc))

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        residual = self.downsample(x)
        return self.relu(self.net(x) + residual)
class SPVCNN(nn.Module):
    """Sparse Point-Voxel CNN: a two-level sparse U-Net over voxelized points.

    Expected kwargs:
        in_channels (int): per-point input feature size.
        dropout (bool): if truthy, apply Dropout(0.3) to bottleneck features.
        cr (float, optional): channel-width multiplier, default 1.0.
        pres, vres (optional): point / voxel resolutions passed to
            ``initial_voxelize`` — required by ``forward`` in practice.

    forward(z) takes a PointTensor ``z`` and returns the per-point feature
    matrix (a dense tensor, ``z3.F``).

    Fix vs. original: removed the dead no-op statement ``z0.F = z0.F``
    in ``forward``.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # NOTE: self.dropout starts out as the boolean flag and, when enabled,
        # is replaced at the end of __init__ by the actual nn.Dropout module.
        # Both the flag and the module are truthy, so `if self.dropout:` in
        # forward works in either state; kept for backward compatibility.
        self.dropout = kwargs['dropout']

        cr = kwargs.get('cr', 1.0)
        cs = [32, 64, 128, 96, 96]          # base channel widths per level
        cs = [int(cr * x) for x in cs]

        if 'pres' in kwargs and 'vres' in kwargs:
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        # stem: lift raw point features to cs[0] channels at full resolution
        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]), spnn.ReLU(True)
        )

        # encoder: two downsampling stages (stride-2 conv + residual blocks)
        self.stage1 = nn.Sequential(
            BasicSparseConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
            SparseResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
            SparseResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
        )

        self.stage2 = nn.Sequential(
            BasicSparseConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
            SparseResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
            SparseResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
        )

        # decoder: transposed convs with skip connections from the encoder
        self.up1 = nn.ModuleList([
            BasicSparseDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2),
            nn.Sequential(
                SparseResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1,
                                    dilation=1),
                SparseResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
            )
        ])

        self.up2 = nn.ModuleList([
            BasicSparseDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2),
            nn.Sequential(
                SparseResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1,
                                    dilation=1),
                SparseResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
            )
        ])

        # point-wise MLPs used as skip transforms between point feature levels
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(cs[0], cs[2]),
                nn.BatchNorm1d(cs[2]),
                nn.ReLU(True),
            ),
            nn.Sequential(
                nn.Linear(cs[2], cs[4]),
                nn.BatchNorm1d(cs[4]),
                nn.ReLU(True),
            )
        ])

        self.weight_initialization()

        if self.dropout:
            # replaces the boolean flag with the module (see NOTE above)
            self.dropout = nn.Dropout(0.3, True)

    def weight_initialization(self):
        # initialize BatchNorm1d affine params to identity; presumably the
        # spnn.BatchNorm layers are covered too if they subclass nn.BatchNorm1d
        # — TODO confirm against the torchsparse version in use
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, z):
        # z: PointTensor of raw point features
        x0 = initial_voxelize(z, self.pres, self.vres)  # points -> voxels
        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)       # devoxelize stem feats

        # encoder
        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)    # 1/2 resolution
        x2 = self.stage2(x1)    # 1/4 resolution
        z1 = voxel_to_point(x2, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)    # point-wise skip

        # decoder with encoder skip connections
        y3 = point_to_voxel(x2, z1)
        if self.dropout:
            y3.F = self.dropout(y3.F)
        y3 = self.up1[0](y3)
        y3 = torchsparse.cat([y3, x1])                  # skip from stage1
        y3 = self.up1[1](y3)

        y4 = self.up2[0](y3)
        y4 = torchsparse.cat([y4, x0])                  # skip from stem
        y4 = self.up2[1](y4)
        z3 = voxel_to_point(y4, z1)
        z3.F = z3.F + self.point_transforms[1](z1.F)    # point-wise skip

        return z3.F
class SparseCostRegNet(nn.Module):
    """Sparse cost-volume regularization network (sparse U-Net).

    Three stride-2 sparse convolutions downsample the input, three
    transposed sparse convolutions bring it back up, with additive skip
    connections at each resolution. Requires sparse tensors as input.
    """

    def __init__(self, d_in, d_out=8):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out

        # encoder
        self.conv0 = BasicSparseConvolutionBlock(d_in, d_out)
        self.conv1 = BasicSparseConvolutionBlock(d_out, 16, stride=2)
        self.conv2 = BasicSparseConvolutionBlock(16, 16)
        self.conv3 = BasicSparseConvolutionBlock(16, 32, stride=2)
        self.conv4 = BasicSparseConvolutionBlock(32, 32)
        self.conv5 = BasicSparseConvolutionBlock(32, 64, stride=2)
        self.conv6 = BasicSparseConvolutionBlock(64, 64)

        # decoder
        self.conv7 = BasicSparseDeconvolutionBlock(64, 32, ks=3, stride=2)
        self.conv9 = BasicSparseDeconvolutionBlock(32, 16, ks=3, stride=2)
        self.conv11 = BasicSparseDeconvolutionBlock(16, d_out, ks=3, stride=2)

    def forward(self, x):
        """Regularize a sparse cost volume.

        :param x: sparse tensor
        :return: dense feature matrix (``.F``) of the output sparse tensor
        """
        skip0 = self.conv0(x)
        skip2 = self.conv2(self.conv1(skip0))
        skip4 = self.conv4(self.conv3(skip2))
        bottom = self.conv6(self.conv5(skip4))

        # upsample, adding the encoder feature at each matching resolution
        up = skip4 + self.conv7(bottom)
        del skip4
        up = skip2 + self.conv9(up)
        del skip2
        up = skip0 + self.conv11(up)
        del skip0

        return up.F
class SConv3d(nn.Module):
    """Sparse 3D conv over a PointTensor via voxelize/devoxelize round-trip.

    The input points are voxelized, convolved, and gathered back to the
    points; a linear transform of the original point features is added as a
    residual-style shortcut.
    """

    def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = spnn.Conv3d(inc, outc, kernel_size=ks,
                               dilation=dilation, stride=stride)
        # linear shortcut applied directly to the raw point features
        self.point_transforms = nn.Sequential(
            nn.Linear(inc, outc),
        )
        self.pres = pres
        self.vres = vres

    def forward(self, z):
        voxels = initial_voxelize(z, self.pres, self.vres)
        voxels = self.net(voxels)
        points = voxel_to_point(voxels, z, nearest=False)
        points.F = points.F + self.point_transforms(z.F)
        return points