|
|
|
import nvdiffrast.torch as dr |
|
import torch |
|
from typing import Tuple |
|
|
|
def _warmup(glctx, device=None):
    """Run one throwaway rasterization on ``glctx``.

    A single hard-coded triangle is rasterized at 256x256 and the result is
    discarded. Presumably this forces the context's one-time setup to happen
    before the first real render -- confirm against the nvdiffrast docs.

    Args:
        glctx: Rasterizer context passed straight to ``dr.rasterize``.
        device: Device on which the dummy tensors are created. ``None``
            selects ``'cuda'``.
    """
    if device is None:
        device = 'cuda'

    # One batch of three 4-component vertex positions (x, y, z, w).
    triangle_pos = torch.tensor(
        [[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]],
        dtype=torch.float32,
        device=device,
    )
    # A single face indexing the three vertices above.
    triangle_idx = torch.tensor([[0, 1, 2]], dtype=torch.int32, device=device)

    # The return value is deliberately ignored; only the call itself matters.
    dr.rasterize(glctx, triangle_pos, triangle_idx, resolution=[256, 256])
|
|
|
class NormalsRenderer:
    """Rasterize a mesh and colour it by its per-vertex normals.

    The renderer stores a combined model-view-projection matrix, a target
    resolution and an nvdiffrast GL rasterization context, and exposes a
    single ``render`` method.
    """

    # Class-level placeholder; every instance overwrites it in ``__init__``.
    _glctx: dr.RasterizeGLContext = None

    def __init__(
            self,
            mv: torch.Tensor,
            proj: torch.Tensor,
            image_size: Tuple[int, int],
            mvp=None,
            device=None,
            ):
        """Store the camera transform and create the rasterization context.

        Args:
            mv: Model-view matrix (or matrices). Only used when ``mvp`` is
                ``None``.
            proj: Projection matrix; left-multiplied onto ``mv`` when ``mvp``
                is ``None``.
            image_size: Resolution forwarded to ``dr.rasterize``.
            mvp: Optional precomputed model-view-projection matrix. When
                given, ``mv`` and ``proj`` are ignored.
            device: Device for the GL context and the warm-up draw call.
        """
        # An explicit MVP wins; otherwise compose it from the two parts.
        self._mvp = proj @ mv if mvp is None else mvp
        self._image_size = image_size
        self._glctx = dr.RasterizeGLContext(output_db=False, device=device)
        _warmup(self._glctx, device)

    def render(
            self,
            vertices: torch.Tensor,
            normals: torch.Tensor,
            faces: torch.Tensor,
            ) -> torch.Tensor:
        """Render the mesh with normals encoded as colours plus coverage.

        Args:
            vertices: 2-D tensor with one row per vertex (presumably
                ``(V, 3)`` positions -- TODO confirm with callers).
            normals: Per-vertex normals, remapped via ``(n + 1) / 2`` before
                interpolation (presumably ``(V, 3)`` in ``[-1, 1]`` --
                TODO confirm).
            faces: Triangle vertex indices; converted to ``int32`` here.

        Returns:
            The antialiased image: interpolated normal colours with one
            extra coverage channel appended on the last axis.
        """
        num_vertices = vertices.shape[0]
        tri_idx = faces.type(torch.int32)

        # Append a constant 1 column to make the positions homogeneous.
        w_column = torch.ones(num_vertices, 1, device=vertices.device)
        homogeneous = torch.cat((vertices, w_column), dim=-1)

        # Row-vector convention: multiply by the transposed MVP.
        # NOTE(review): no `ranges` are passed to rasterize, so the MVP is
        # presumably batched (C, 4, 4), making this (C, V, 4) -- confirm.
        clip_pos = homogeneous @ self._mvp.transpose(-2, -1)

        rast, _ = dr.rasterize(
            self._glctx,
            clip_pos,
            tri_idx,
            resolution=self._image_size,
            grad_db=False,
        )

        # Shift normals into a colour-like range before interpolating.
        vertex_colors = (normals + 1) / 2
        shaded, _ = dr.interpolate(vertex_colors, rast, tri_idx)

        # The last rasterizer channel is capped at 1 and used as alpha; per
        # the nvdiffrast docs it is a triangle id that is 0 where empty.
        coverage = rast[..., -1:].clamp(max=1)

        image = torch.cat((shaded, coverage), dim=-1)
        return dr.antialias(image, rast, clip_pos, tri_idx)
|
|