import torch
import kaolin
import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
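
# CameraModule bundles two differentiable renderers: a kaolin DIB-R mesh
# rasterizer (render) and a 3D Gaussian splatting rasterizer (render_gaussian).
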
class CameraModule:
    def __init__(self):
        # Background "color" for the Gaussian rasterizer; the 32 channels
        # suggest this build splats 32-dimensional per-point features rather
        # than plain RGB.
        self.bg_color = torch.tensor([1.0] * 32).float()
        self.scale_modifier = 1.0

    def perspective_camera(self, points, camera_proj):
        # Project camera-space points (B, N, 3) with intrinsics (B, 3, 3),
        # then dehomogenize by the depth component.
        projected_points = torch.bmm(points, camera_proj.permute(0, 2, 1))
        projected_2d_points = projected_points[:, :, :2] / projected_points[:, :, 2:3]
        return projected_2d_points
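
    # For a pinhole matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], the
    # projection above computes K @ p = (fx*x + cx*z, fy*y + cy*z, z) for each
    # point p = (x, y, z) and divides by z, i.e. u = fx*x/z + cx, v = fy*y/z + cy.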

    def prepare_vertices(self, vertices, faces, camera_proj, camera_rot=None, camera_trans=None,
                         camera_transform=None):
        if camera_transform is None:
            assert camera_trans is not None and camera_rot is not None, \
                "camera_transform or camera_trans and camera_rot must be defined"
            vertices_camera = kaolin.render.camera.rotate_translate_points(vertices, camera_rot,
                                                                           camera_trans)
        else:
            assert camera_trans is None and camera_rot is None, \
                "camera_trans and camera_rot must be None when camera_transform is defined"
            padded_vertices = torch.nn.functional.pad(
                vertices, (0, 1), mode='constant', value=1.
            )
            vertices_camera = padded_vertices @ camera_transform
        # Project the vertices onto the camera image plane
        vertices_image = self.perspective_camera(vertices_camera, camera_proj)
        face_vertices_camera = kaolin.ops.mesh.index_vertices_by_faces(vertices_camera, faces)
        face_vertices_image = kaolin.ops.mesh.index_vertices_by_faces(vertices_image, faces)
        face_normals = kaolin.ops.mesh.face_normals(face_vertices_camera, unit=True)
        return face_vertices_camera, face_vertices_image, face_normals
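
    # Returned shapes, following kaolin's conventions: face_vertices_camera is
    # (batch, num_faces, 3, 3), face_vertices_image is (batch, num_faces, 3, 2),
    # and face_normals is (batch, num_faces, 3).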

    def render(self, data, resolution):
        verts_list = data['verts_list']
        faces_list = data['faces_list']
        verts_color_list = data['verts_color_list']
        B = len(verts_list)
        render_images = []
        render_soft_masks = []
        render_depths = []
        render_normals = []
        face_normals_list = []
        for b in range(B):
            intrinsics = data['intrinsics'][b]
            extrinsics = data['extrinsics'][b]
            # camera_proj = torch.stack([intrinsics[:, 0, 0] / intrinsics[:, 0, 2], intrinsics[:, 1, 1] / intrinsics[:, 1, 2], torch.ones_like(intrinsics[:, 0, 0])], -1).to(device)
            camera_proj = intrinsics
            camera_transform = extrinsics.permute(0, 2, 1)
            # Broadcast this element's single mesh to all of its views.
            verts = verts_list[b].unsqueeze(0).repeat(intrinsics.shape[0], 1, 1)
            faces = faces_list[b]
            verts_color = verts_color_list[b].unsqueeze(0).repeat(intrinsics.shape[0], 1, 1)
            faces_color = verts_color[:, faces]
            face_vertices_camera, face_vertices_image, face_normals = self.prepare_vertices(
                verts, faces, camera_proj, camera_transform=camera_transform
            )
            # Flip y in image space, and the y/z components of the normals, to
            # convert from the OpenCV-style camera convention to the one
            # kaolin's rasterizer expects.
            face_vertices_image[:, :, :, 1] = -face_vertices_image[:, :, :, 1]
            # face_vertices_camera[:, :, :, 1:] = -face_vertices_camera[:, :, :, 1:]
            face_normals[:, :, 1:] = -face_normals[:, :, 1:]
            ### Perform Rasterization ###
            # Construct the attributes that the DIB-R rasterizer will interpolate:
            # per-face vertex colors, a constant one (which becomes a hard
            # segmentation mask), per-vertex camera-space depth, and the face
            # normal repeated for each of the face's three vertices.
            face_attributes = [
                faces_color,
                torch.ones((faces_color.shape[0], faces_color.shape[1], 3, 1), device=verts.device),
                face_vertices_camera[:, :, :, 2:],
                face_normals.unsqueeze(-2).repeat(1, 1, 3, 1),
            ]
            # If you have nvdiffrast installed you can change rast_backend to
            # nvdiffrast or nvdiffrast_fwd
            image_features, soft_masks, face_idx = kaolin.render.mesh.dibr_rasterization(
                resolution, resolution, -face_vertices_camera[:, :, :, -1],
                face_vertices_image, face_attributes, face_normals[:, :, -1],
                rast_backend='cuda')
            # image_features is a tuple composed of the interpolated attributes
            # of face_attributes, in the same order.
            images, masks, depths, normals = image_features
            images = torch.clamp(images * masks, 0., 1.)
            depths = depths * masks
            normals = normals * masks
            render_images.append(images)
            render_soft_masks.append(soft_masks)
            render_depths.append(depths)
            render_normals.append(normals)
            face_normals_list.append(face_normals)
        # Stacking assumes every batch element has the same number of views.
        render_images = torch.stack(render_images, 0)
        render_soft_masks = torch.stack(render_soft_masks, 0)
        render_depths = torch.stack(render_depths, 0)
        render_normals = torch.stack(render_normals, 0)
        data['render_images'] = render_images
        data['render_soft_masks'] = render_soft_masks
        data['render_depths'] = render_depths
        data['render_normals'] = render_normals
        data['verts_list'] = verts_list
        data['faces_list'] = faces_list
        data['face_normals_list'] = face_normals_list
        return data

    def render_gaussian(self, data, resolution):
        """
        Render the scene with 3D Gaussian splatting.
        The background tensor (bg_color) must be on the GPU!
        """
        B = data['xyz'].shape[0]
        xyz = data['xyz']
        # shs = rearrange(data['shs'], 'b n (x y) -> b n x y', y=3)
        colors_precomp = data['color']
        opacity = data['opacity']
        scales = data['scales']
        rotations = data['rotation']
        fovx = data['fovx']
        fovy = data['fovy']
        bg_color = data.get('bg_color', self.bg_color)
        world_view_transform = data['world_view_transform']
        full_proj_transform = data['full_proj_transform']
        camera_center = data['camera_center']
        # Create a zero tensor. We will use it to make PyTorch return
        # gradients of the 2D (screen-space) means.
        screenspace_points = torch.zeros_like(xyz, dtype=xyz.dtype, requires_grad=True, device=xyz.device) + 0
        try:
            screenspace_points.retain_grad()
        except Exception:
            pass
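
        # In 3DGS-style training loops, the retained gradient of these
        # screen-space means is typically what drives densification decisions
        # (together with the visibility_filter and radii returned below).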
        render_images = []
        radii = []
        for b in range(B):
            tanfovx = math.tan(fovx[b] * 0.5)
            tanfovy = math.tan(fovy[b] * 0.5)
            # Set up rasterization configuration
            raster_settings = GaussianRasterizationSettings(
                image_height=int(resolution),
                image_width=int(resolution),
                tanfovx=tanfovx,
                tanfovy=tanfovy,
                bg=bg_color.to(xyz.device),
                scale_modifier=self.scale_modifier,
                viewmatrix=world_view_transform[b],
                projmatrix=full_proj_transform[b],
                sh_degree=0,
                campos=camera_center[b],
                prefiltered=False,
                debug=False
            )
            rasterizer = GaussianRasterizer(raster_settings=raster_settings)
            means3D = xyz[b]
            means2D = screenspace_points[b]
            # Rasterize visible Gaussians to the image and obtain their
            # on-screen radii.
            render_images_b, radii_b, _ = rasterizer(
                means3D=means3D,
                means2D=means2D,
                # shs=shs[b],
                colors_precomp=colors_precomp[b],
                opacities=opacity[b],
                scales=scales[b],
                rotations=rotations[b])
            render_images.append(render_images_b)
            radii.append(radii_b)
        render_images = torch.stack(render_images)
        radii = torch.stack(radii)
        data['render_images'] = render_images
        data['viewspace_points'] = screenspace_points
        data['visibility_filter'] = radii > 0
        data['radii'] = radii
        return data
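

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It is a smoke test
# under stated assumptions: a CUDA device with kaolin installed, and
# intrinsics already expressed in normalized device coordinates (the
# commented-out camera_proj line in render() hints at that convention).
# Shapes follow prepare_vertices(): intrinsics (num_views, 3, 3) and
# extrinsics (num_views, 3, 4), both indexed per batch element. A similar
# dict with the 3DGS camera matrices would drive render_gaussian().
if __name__ == "__main__":
    device = torch.device('cuda')
    cam = CameraModule()

    # One triangle, wound so that its camera-space normal survives the
    # normal flip in render() and faces the camera.
    verts = torch.tensor([[-0.5, -0.5, 0.0],
                          [0.5, -0.5, 0.0],
                          [0.0, 0.5, 0.0]], device=device)
    faces = torch.tensor([[0, 2, 1]], dtype=torch.long, device=device)
    verts_color = torch.rand(3, 3, device=device)  # random RGB per vertex

    # A single view: NDC-style intrinsics with focal length 2, and a
    # world-to-camera extrinsic that places the triangle 2 units in front
    # of the camera.
    intrinsics = torch.tensor([[[2.0, 0.0, 0.0],
                                [0.0, 2.0, 0.0],
                                [0.0, 0.0, 1.0]]], device=device)
    extrinsics = torch.tensor([[[1.0, 0.0, 0.0, 0.0],
                                [0.0, 1.0, 0.0, 0.0],
                                [0.0, 0.0, 1.0, 2.0]]], device=device)

    data = {
        'verts_list': [verts],
        'faces_list': [faces],
        'verts_color_list': [verts_color],
        'intrinsics': [intrinsics],
        'extrinsics': [extrinsics],
    }
    out = cam.render(data, resolution=256)
    print(out['render_images'].shape)      # expected: (1, 1, 256, 256, 3)
    print(out['render_soft_masks'].shape)  # expected: (1, 1, 256, 256)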