from __future__ import division, print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

import voxelize_cuda


class VoxelizationFunction(Function):
    """
    Autograd function wrapping the CUDA semantic voxelization kernel.

    Currently implemented only for cuda Tensors.

    NOTE(review): only ``forward`` is defined in this class; no ``backward``
    is visible here, so gradients cannot flow through this op as written.
    """

    @staticmethod
    def forward(
        ctx,
        smpl_vertices,
        smpl_face_center,
        smpl_face_normal,
        smpl_vertex_code,
        smpl_face_code,
        smpl_tetrahedrons,
        volume_res,
        sigma,
        smooth_kernel_size,
    ):
        """
        Run the forward voxelization pass.

        Args:
            ctx: autograd context; the run configuration is stored on it.
            smpl_vertices: vertex tensor; dim 0 is the batch, dim 1 the
                vertex index.
            smpl_face_center: per-face centers. Only shape-checked and made
                contiguous here; not passed to the kernel.
            smpl_face_normal: per-face normals. Only shape-checked and made
                contiguous here; not passed to the kernel.
            smpl_vertex_code: per-vertex code; dim 1 must match
                ``smpl_vertices``.
            smpl_face_code: per-face code. Only shape-checked and made
                contiguous here; not passed to the kernel.
            smpl_tetrahedrons: tetrahedron vertex positions handed to the
                kernel.
            volume_res: edge length (in voxels) of the cubic output volume.
            sigma: scalar forwarded unchanged to the CUDA kernel.
            smooth_kernel_size: stored on ``ctx``; not used by this pass.

        Returns:
            The semantic volume in the format
            (batch_size, z_dims, y_dims, x_dims, channel_num), with
            channel_num == 3.
        """
        assert smpl_vertices.size()[1] == smpl_vertex_code.size()[1]
        assert smpl_face_center.size()[1] == smpl_face_normal.size()[1]
        assert smpl_face_center.size()[1] == smpl_face_code.size()[1]

        ctx.batch_size = smpl_vertices.size()[0]
        ctx.volume_res = volume_res
        ctx.sigma = sigma
        ctx.smooth_kernel_size = smooth_kernel_size
        ctx.smpl_vertex_num = smpl_vertices.size()[1]
        ctx.device = smpl_vertices.device

        smpl_vertices = smpl_vertices.contiguous()
        smpl_face_center = smpl_face_center.contiguous()
        smpl_face_normal = smpl_face_normal.contiguous()
        smpl_vertex_code = smpl_vertex_code.contiguous()
        smpl_face_code = smpl_face_code.contiguous()
        smpl_tetrahedrons = smpl_tetrahedrons.contiguous()

        # Allocate the working volumes on the same device as the input. The
        # legacy torch.cuda.FloatTensor constructor always used the *current*
        # CUDA device, which can differ from the input's device on multi-GPU
        # setups.
        grid_shape = (
            ctx.batch_size,
            ctx.volume_res,
            ctx.volume_res,
            ctx.volume_res,
        )
        alloc = dict(dtype=torch.float32, device=ctx.device)

        # occ_volume        [B, volume_res, volume_res, volume_res]
        # semantic_volume   [B, volume_res, volume_res, volume_res, 3]
        # weight_sum_volume [B, volume_res, volume_res, volume_res]
        occ_volume = torch.zeros(grid_shape, **alloc)
        semantic_volume = torch.zeros(grid_shape + (3,), **alloc)
        # Starts at 1e-3 rather than 0; presumably to keep a later
        # normalisation by the weight sum away from division by zero
        # -- TODO confirm against the kernel.
        weight_sum_volume = torch.full(grid_shape, 1e-3, **alloc)

        (
            occ_volume,
            semantic_volume,
            weight_sum_volume,
        ) = voxelize_cuda.forward_semantic_voxelization(
            smpl_vertices,
            smpl_vertex_code,
            smpl_tetrahedrons,
            occ_volume,
            semantic_volume,
            weight_sum_volume,
            sigma,
        )

        return semantic_volume


class Voxelization(nn.Module):
    """
    Wrapper around the autograd function VoxelizationFunction.

    ``update_param`` must be called before ``forward``: it registers the
    batched code/index buffers that ``forward`` reads.
    """

    def __init__(
        self,
        smpl_vertex_code,
        smpl_face_code,
        smpl_face_indices,
        smpl_tetraderon_indices,
        volume_res,
        sigma,
        smooth_kernel_size,
        batch_size,
    ):
        """
        Store the mesh topology and voxelization settings.

        Args:
            smpl_vertex_code: per-vertex code, tiled per batch in
                ``update_param``.
            smpl_face_code: per-face code, tiled per batch in
                ``update_param``.
            smpl_face_indices: 2-D index array with 3 columns (triangles).
            smpl_tetraderon_indices: 2-D index array with 4 columns
                (tetrahedrons).
            volume_res: edge length (in voxels) of the output volume.
            sigma: scalar forwarded to the CUDA kernel.
            smooth_kernel_size: forwarded to ``VoxelizationFunction``.
            batch_size: number of copies made when tiling the buffers.
        """
        super(Voxelization, self).__init__()
        assert len(smpl_face_indices.shape) == 2
        assert len(smpl_tetraderon_indices.shape) == 2
        assert smpl_face_indices.shape[1] == 3
        assert smpl_tetraderon_indices.shape[1] == 4

        self.volume_res = volume_res
        self.sigma = sigma
        self.smooth_kernel_size = smooth_kernel_size
        self.batch_size = batch_size
        self.device = None

        self.smpl_vertex_code = smpl_vertex_code
        self.smpl_face_code = smpl_face_code
        self.smpl_face_indices = smpl_face_indices
        self.smpl_tetraderon_indices = smpl_tetraderon_indices

    def update_param(self, voxel_faces):
        """
        Replace the tetrahedron indices and (re)build the batched buffers.

        The vertex codes, face codes and face indices are tiled
        ``batch_size`` times along a new leading dimension; ``voxel_faces``
        is used as given (not tiled). All four tensors are moved to
        ``voxel_faces.device`` and registered as module buffers.

        Args:
            voxel_faces: tensor of tetrahedron vertex indices.
        """
        self.device = voxel_faces.device
        self.smpl_tetraderon_indices = voxel_faces

        reps = (self.batch_size, 1, 1)
        smpl_vertex_code_batch = torch.tile(self.smpl_vertex_code, reps)
        smpl_face_code_batch = torch.tile(self.smpl_face_code, reps)
        smpl_face_indices_batch = torch.tile(self.smpl_face_indices, reps)

        buffers = (
            ("smpl_vertex_code_batch", smpl_vertex_code_batch),
            ("smpl_face_code_batch", smpl_face_code_batch),
            ("smpl_face_indices_batch", smpl_face_indices_batch),
            ("smpl_tetraderon_indices_batch", self.smpl_tetraderon_indices),
        )
        for buffer_name, tensor in buffers:
            self.register_buffer(
                buffer_name, tensor.contiguous().to(self.device)
            )

    def forward(self, smpl_vertices):
        """
        Generate semantic volumes from SMPL vertices.

        Args:
            smpl_vertices: float32 CUDA tensor of vertices, indexed as
                (batch, vertex, xyz).

        Returns:
            The semantic volume permuted from (b, z, y, x, c) to
            (b, c, d, h, w).

        Raises:
            TypeError: if ``smpl_vertices`` is not a float32 CUDA tensor.
        """
        self.check_input(smpl_vertices)
        smpl_faces = self.vertices_to_faces(smpl_vertices)
        smpl_tetrahedrons = self.vertices_to_tetrahedrons(smpl_vertices)
        smpl_face_center = self.calc_face_centers(smpl_faces)
        smpl_face_normal = self.calc_face_normals(smpl_faces)

        # Only the first N vertices are passed on, where N is the number of
        # vertex codes; any extra vertices are still reachable through the
        # tetrahedrons gathered above.
        smpl_surface_vertex_num = self.smpl_vertex_code_batch.size()[1]
        smpl_vertices_surface = smpl_vertices[:, :smpl_surface_vertex_num, :]

        vol = VoxelizationFunction.apply(
            smpl_vertices_surface,
            smpl_face_center,
            smpl_face_normal,
            self.smpl_vertex_code_batch,
            self.smpl_face_code_batch,
            smpl_tetrahedrons,
            self.volume_res,
            self.sigma,
            self.smooth_kernel_size,
        )
        return vol.permute((0, 4, 1, 2, 3))  # (bzyxc --> bcdhw)

    @staticmethod
    def _gather_by_index(vertices, indices):
        """Gather per-batch vertex positions for every index in ``indices``."""
        assert vertices.ndimension() == 3
        bs, nv = vertices.shape[:2]
        # Per-sample offset into the flattened (bs * nv, 3) vertex table.
        # Built on the index buffer's own device so it stays correct after
        # ``module.to(...)``, when ``self.device`` may be stale.
        batch_offsets = (
            torch.arange(bs, dtype=torch.int32, device=indices.device) * nv
        )
        flat_index = indices + batch_offsets[:, None, None]
        flat_vertices = vertices.reshape((bs * nv, 3))
        return flat_vertices[flat_index.long()]

    def vertices_to_faces(self, vertices):
        """
        Look up the three corner positions of every triangle.

        Args:
            vertices: tensor indexed as (batch, vertex, xyz).

        Returns:
            Tensor indexed as (batch, face, corner, xyz).
        """
        return self._gather_by_index(vertices, self.smpl_face_indices_batch)

    def vertices_to_tetrahedrons(self, vertices):
        """
        Look up the four corner positions of every tetrahedron.

        Args:
            vertices: tensor indexed as (batch, vertex, xyz).

        Returns:
            Tensor of corner positions; the index buffer is broadcast
            against a per-batch offset, so the result carries a leading
            batch dimension.
        """
        return self._gather_by_index(
            vertices, self.smpl_tetraderon_indices_batch
        )

    def calc_face_centers(self, face_verts):
        """
        Compute the centroid of each triangle.

        Args:
            face_verts: tensor of shape (bs, nf, 3, 3).

        Returns:
            Tensor of shape (bs, nf, 3).
        """
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        corner_sum = (
            face_verts[:, :, 0, :]
            + face_verts[:, :, 1, :]
            + face_verts[:, :, 2, :]
        )
        return (corner_sum / 3.0).reshape((bs, nf, 3))

    def calc_face_normals(self, face_verts):
        """
        Compute the unit normal of each triangle.

        Args:
            face_verts: tensor of shape (bs, nf, 3, 3).

        Returns:
            Tensor of shape (bs, nf, 3) with L2-normalised normals.
        """
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_verts = face_verts.reshape((bs * nf, 3, 3))
        v10 = face_verts[:, 0] - face_verts[:, 1]
        v12 = face_verts[:, 2] - face_verts[:, 1]
        # dim must be explicit: without it torch.cross picks the first
        # dimension of size 3, which is the face axis when bs * nf == 3.
        normals = F.normalize(torch.cross(v10, v12, dim=1), dim=1, eps=1e-5)
        return normals.reshape((bs, nf, 3))

    def check_input(self, x):
        """
        Validate that ``x`` is a float32 CUDA tensor.

        Raises:
            TypeError: if ``x`` is not on a CUDA device, or is not a
                ``torch.cuda.FloatTensor``.
        """
        # Comparing x.device to the string "cpu" is not a reliable CPU test;
        # is_cuda states the requirement directly.
        if not x.is_cuda:
            raise TypeError("Voxelization module supports only cuda tensors")
        if x.type() != "torch.cuda.FloatTensor":
            raise TypeError("Voxelization module supports only float32 tensors")