import math

import torch
import kaolin
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer


class CameraModule:
    """Camera-side rendering utilities.

    Renders meshes with kaolin's DIB-R differentiable rasterizer
    (:meth:`render`) and 3D Gaussians with the diff_gaussian_rasterization
    backend (:meth:`render_gaussian`).
    """

    def __init__(self):
        # Default background fed to the Gaussian rasterizer; 32 channels —
        # presumably matching the rasterizer's feature dimension (TODO confirm).
        self.bg_color = torch.tensor([1.0] * 32).float()
        # Global multiplier applied to Gaussian scales at rasterization time.
        self.scale_modifier = 1.0

    def perspective_camera(self, points, camera_proj):
        """Project camera-space points onto the image plane.

        Args:
            points: (B, N, 3) camera-space points.
            camera_proj: (B, 3, 3) projection matrices.

        Returns:
            (B, N, 2) projected points after the perspective divide by depth.
        """
        projected_points = torch.bmm(points, camera_proj.permute(0, 2, 1))
        projected_2d_points = projected_points[:, :, :2] / projected_points[:, :, 2:3]
        return projected_2d_points

    def prepare_vertices(self, vertices, faces, camera_proj, camera_rot=None,
                         camera_trans=None, camera_transform=None):
        """Transform vertices to camera space, project them, and gather per-face data.

        Exactly one of the two camera parameterizations must be supplied:
        either ``camera_transform`` alone, or both ``camera_rot`` and
        ``camera_trans``.

        Returns:
            (face_vertices_camera, face_vertices_image, face_normals) —
            per-face camera-space vertices, image-space vertices, and unit
            face normals.
        """
        if camera_transform is None:
            assert camera_trans is not None and camera_rot is not None, \
                "camera_transform or camera_trans and camera_rot must be defined"
            vertices_camera = kaolin.render.camera.rotate_translate_points(
                vertices, camera_rot, camera_trans)
        else:
            assert camera_trans is None and camera_rot is None, \
                "camera_trans and camera_rot must be None when camera_transform is defined"
            # Pad to homogeneous coordinates so a single matmul applies
            # rotation and translation together.
            padded_vertices = torch.nn.functional.pad(
                vertices, (0, 1), mode='constant', value=1.)
            vertices_camera = (padded_vertices @ camera_transform)
        # Project the vertices onto the camera image plane.
        vertices_image = self.perspective_camera(vertices_camera, camera_proj)
        face_vertices_camera = kaolin.ops.mesh.index_vertices_by_faces(vertices_camera, faces)
        face_vertices_image = kaolin.ops.mesh.index_vertices_by_faces(vertices_image, faces)
        face_normals = kaolin.ops.mesh.face_normals(face_vertices_camera, unit=True)
        return face_vertices_camera, face_vertices_image, face_normals

    def render(self, data, resolution):
        """Rasterize each batch element's mesh with kaolin's DIB-R rasterizer.

        Args:
            data: dict providing 'verts_list', 'faces_list',
                'verts_color_list', 'intrinsics', 'extrinsics' (one entry per
                batch element; each intrinsics/extrinsics entry holds a stack
                of views).
            resolution: output image height and width (square).

        Returns:
            The same ``data`` dict augmented with 'render_images',
            'render_soft_masks', 'render_depths', 'render_normals' and
            'face_normals_list'.
        """
        verts_list = data['verts_list']
        faces_list = data['faces_list']
        verts_color_list = data['verts_color_list']

        batch_size = len(verts_list)
        render_images = []
        render_soft_masks = []
        render_depths = []
        render_normals = []
        face_normals_list = []
        for b in range(batch_size):
            intrinsics = data['intrinsics'][b]
            extrinsics = data['extrinsics'][b]
            camera_proj = intrinsics
            camera_transform = extrinsics.permute(0, 2, 1)

            # Replicate the single mesh across all views of this batch element.
            num_views = intrinsics.shape[0]
            verts = verts_list[b].unsqueeze(0).repeat(num_views, 1, 1)
            faces = faces_list[b]
            verts_color = verts_color_list[b].unsqueeze(0).repeat(num_views, 1, 1)
            faces_color = verts_color[:, faces]

            face_vertices_camera, face_vertices_image, face_normals = \
                self.prepare_vertices(
                    verts, faces, camera_proj, camera_transform=camera_transform)
            # Flip the image-space y axis and the normals' y/z components to
            # match the rasterizer's image conventions.
            face_vertices_image[:, :, :, 1] = -face_vertices_image[:, :, :, 1]
            face_normals[:, :, 1:] = -face_normals[:, :, 1:]

            ### Perform Rasterization ###
            # Attributes the DIB-R rasterizer will interpolate:
            # per-face vertex colors, a constant-one channel (hard mask),
            # per-vertex depth, and the flat face normal repeated per vertex.
            face_attributes = [
                faces_color,
                torch.ones((faces_color.shape[0], faces_color.shape[1], 3, 1),
                           device=verts.device),
                face_vertices_camera[:, :, :, 2:],
                face_normals.unsqueeze(-2).repeat(1, 1, 3, 1),
            ]
            # If nvdiffrast is installed, rast_backend can be changed to
            # 'nvdiffrast' or 'nvdiffrast_fwd'.
            image_features, soft_masks, face_idx = kaolin.render.mesh.dibr_rasterization(
                resolution, resolution,
                -face_vertices_camera[:, :, :, -1],
                face_vertices_image,
                face_attributes,
                face_normals[:, :, -1],
                rast_backend='cuda')
            # image_features is a tuple of the interpolated face_attributes.
            images, masks, depths, normals = image_features
            images = torch.clamp(images * masks, 0., 1.)
            depths = depths * masks
            normals = normals * masks

            render_images.append(images)
            render_soft_masks.append(soft_masks)
            render_depths.append(depths)
            render_normals.append(normals)
            face_normals_list.append(face_normals)

        data['render_images'] = torch.stack(render_images, 0)
        data['render_soft_masks'] = torch.stack(render_soft_masks, 0)
        data['render_depths'] = torch.stack(render_depths, 0)
        data['render_normals'] = torch.stack(render_normals, 0)
        data['verts_list'] = verts_list
        data['faces_list'] = faces_list
        data['face_normals_list'] = face_normals_list
        return data

    def render_gaussian(self, data, resolution):
        """Render 3D Gaussians for every batch element.

        Background tensor (bg_color) must be on GPU!

        Args:
            data: dict providing 'xyz', 'color', 'opacity', 'scales',
                'rotation', 'fovx', 'fovy', 'world_view_transform',
                'full_proj_transform', 'camera_center' and optionally
                'bg_color'.
            resolution: output image height and width (square).

        Returns:
            The same ``data`` dict augmented with 'render_images',
            'viewspace_points', 'visibility_filter' and 'radii'.
        """
        B = data['xyz'].shape[0]
        xyz = data['xyz']
        colors_precomp = data['color']
        opacity = data['opacity']
        scales = data['scales']
        rotations = data['rotation']
        fovx = data['fovx']
        fovy = data['fovy']
        bg_color = data.get('bg_color', self.bg_color)
        world_view_transform = data['world_view_transform']
        full_proj_transform = data['full_proj_transform']
        camera_center = data['camera_center']

        # Zero tensor used so PyTorch returns gradients of the 2D
        # (screen-space) means.
        screenspace_points = torch.zeros_like(
            xyz, dtype=xyz.dtype, requires_grad=True, device=xyz.device) + 0
        try:
            screenspace_points.retain_grad()
        except RuntimeError:
            # retain_grad is best-effort (e.g. under no_grad); rendering
            # still works without screen-space gradients.
            pass

        render_images = []
        radii = []
        for b in range(B):
            tanfovx = math.tan(fovx[b] * 0.5)
            tanfovy = math.tan(fovy[b] * 0.5)

            # Set up rasterization configuration for this view.
            raster_settings = GaussianRasterizationSettings(
                image_height=int(resolution),
                image_width=int(resolution),
                tanfovx=tanfovx,
                tanfovy=tanfovy,
                bg=bg_color.to(xyz.device),
                scale_modifier=self.scale_modifier,
                viewmatrix=world_view_transform[b],
                projmatrix=full_proj_transform[b],
                sh_degree=0,
                campos=camera_center[b],
                prefiltered=False,
                debug=False
            )
            rasterizer = GaussianRasterizer(raster_settings=raster_settings)

            means3D = xyz[b]
            means2D = screenspace_points[b]

            # Rasterize visible Gaussians to an image and obtain their
            # on-screen radii.
            render_images_b, radii_b, _ = rasterizer(
                means3D=means3D,
                means2D=means2D,
                colors_precomp=colors_precomp[b],
                opacities=opacity[b],
                scales=scales[b],
                rotations=rotations[b])
            render_images.append(render_images_b)
            radii.append(radii_b)

        render_images = torch.stack(render_images)
        radii = torch.stack(radii)

        data['render_images'] = render_images
        data['viewspace_points'] = screenspace_points
        data['visibility_filter'] = radii > 0
        data['radii'] = radii
        return data