|
from typing import Optional, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from jaxtyping import Float, Integer |
|
from torch import Tensor |
|
|
|
from .mesh import Mesh |
|
|
|
|
|
class IsosurfaceHelper(nn.Module):
    """Base module for helpers that extract a surface from values sampled on a grid.

    Concrete helpers supply the sample positions through ``grid_vertices``.
    """

    # (low, high) coordinate interval associated with the grid points; a
    # subclass may read it when scaling per-vertex quantities.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def requires_instance_per_batch(self) -> bool:
        """Flag exposed to callers; this base implementation always answers False.

        NOTE(review): judging by the name, True would mean a separate helper
        instance is needed for every batch element — confirm against callers.
        """
        return False

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Positions of the grid sample points; must be provided by a subclass."""
        raise NotImplementedError
|
|
|
|
|
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Extract a triangle mesh from a scalar field sampled on a tetrahedral grid.

    The grid is loaded from a NumPy ``.npz`` archive holding a ``"vertices"``
    array (grid vertex positions) and an ``"indices"`` array (the vertex
    indices of every tetrahedron). Calling the module with per-vertex ``level``
    values runs marching tetrahedra. Vertices with ``level > 0`` count as
    occupied. Every tetrahedron with some, but not all, corners occupied emits
    one or two triangles. Their corners sit on the linearly interpolated zero
    crossings of the tetrahedron's edges.
    """

    def __init__(self, resolution: int, tets_path: str):
        super().__init__()
        # Used to bound the magnitude of deformation offsets
        # (see ``normalize_grid_deformation``).
        self.resolution = resolution
        self.tets_path = tets_path

        # Row ``c`` lists, for occupancy code ``c`` (bit i set <=> tet corner i
        # occupied), up to two triangles as local edge ids 0..5 into the tet's
        # six edges (ordered as in ``base_tet_edges``); -1 marks unused slots.
        self.triangle_table: Integer[Tensor, "16 6"]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of valid triangles in each ``triangle_table`` row.
        self.num_triangles_table: Integer[Tensor, "16"]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # The six edges of a tetrahedron as flattened local corner pairs:
        # (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
        self.base_tet_edges: Integer[Tensor, "12"]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        # ``np.load`` on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it once both arrays have been read.
        with np.load(self.tets_path) as tets:
            grid_vertices = torch.from_numpy(tets["vertices"]).float()
            indices = torch.from_numpy(tets["indices"]).long()

        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer("_grid_vertices", grid_vertices, persistent=False)
        self.indices: Integer[Tensor, "..."]
        self.register_buffer("indices", indices, persistent=False)

        # Lazily computed by ``all_edges``. It is a plain attribute rather than
        # a buffer, so ``all_edges`` re-checks its device against ``indices``.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Find the vertex nearest the origin and the vertices on the grid's extreme values.

        Args:
            verts: vertex positions; presumably (Nv, 3).

        Returns:
            center_idx: 0-dim long tensor, the index of the vertex with the
                smallest squared norm.
            boundary_idx: 1-D long tensor of indices of vertices having at least
                one coordinate equal to the global minimum or global maximum
                taken over *all* coordinates of ``verts``. This matches the
                bounding-box faces only when every axis spans the same interval.
        """
        squared_norm = torch.sum(verts**2, dim=-1)
        center_idx = torch.argmin(squared_norm)

        at_min = verts == verts.min()
        at_max = verts == verts.max()
        on_boundary = torch.bitwise_or(at_min, at_max).any(dim=-1)

        boundary_idx = torch.nonzero(on_boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw per-vertex offsets into the open interval (-step, step).

        ``step = (points_range[1] - points_range[0]) / resolution``; ``tanh``
        keeps every component strictly inside that bound.
        """
        step = (self.points_range[1] - self.points_range[0]) / self.resolution
        return step * torch.tanh(grid_vertex_offsets)

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed grid vertex positions loaded from ``tets_path``."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique edges of all tetrahedra, each as an ascending vertex-index pair.

        The result is cached on first use. Because the cache is not a
        registered buffer, ``Module.to``/``.cuda()`` do not move it. It is
        therefore rebuilt whenever its device no longer matches ``indices``.
        """
        if self._all_edges is None or self._all_edges.device != self.indices.device:
            edges = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            edges = torch.sort(edges, dim=1)[0]
            self._all_edges = torch.unique(edges, dim=0)
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each edge's two endpoints so the smaller vertex index comes first.

        Args:
            edges_ex2: (E, 2) integer tensor of vertex-index pairs.

        Returns:
            Tensor of shape (E, 1, 2). The extra middle axis comes from stacking
            the two (E, 1) gathers; callers flatten it with ``reshape``.
        """
        with torch.no_grad():
            # 1 where the first endpoint is the larger one, i.e. the pair must be swapped.
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra on one grid.

        Args:
            pos_nx3: (Nv, 3) vertex positions.
            sdf_n: per-vertex scalar values, (Nv,) or (Nv, 1). Vertices with a
                value > 0 are treated as occupied.
            tet_fx4: tetrahedra as vertex indices; presumably (F, 4).

        Returns:
            verts: (K, 3) surface vertices, one per grid edge whose endpoints
                differ in occupancy, placed at the linear zero crossing.
            faces: (T, 3) long tensor of indices into ``verts``. Triangles from
                one-triangle tets come first, then those from two-triangle tets.
        """
        # Topology (which tets and edges cross the surface) is discrete, so it
        # is computed without autograd. Gradients still reach pos_nx3/sdf_n
        # through the interpolation below.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # A tet crosses the surface iff some, but not all, corners are occupied.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)

            # Collect the 6 edges of every crossing tet and deduplicate them;
            # idx_map[k] is the unique-edge index of the k-th collected edge.
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # Only edges with exactly one occupied endpoint carry a surface vertex.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            # unique-edge index -> surface-vertex index, or -1 for non-crossing edges.
            mapping = torch.full(
                (unique_edges.shape[0],), -1, dtype=torch.long, device=pos_nx3.device
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]

            interp_v = unique_edges[mask_edges]

        # Differentiable part: linear zero crossing along each crossing edge.
        # With endpoint values s0, s1 the weights become [-s1, s0] / (s0 - s1),
        # i.e. p0 + (p1 - p0) * s0 / (s0 - s1). The denominator is never zero
        # because exactly one endpoint has a value > 0.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        # One row per crossing tet: surface-vertex index (or -1) for each of its 6 edges.
        idx_map = idx_map.reshape(-1, 6)

        # 4-bit occupancy code per crossing tet (bit i <=> corner i occupied).
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Translate the table's local edge ids into surface-vertex indices.
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the surface separating grid vertices with ``level > 0`` from the rest.

        Args:
            level: per-grid-vertex scalar values.
            deformation: optional raw per-vertex offsets. They are squashed by
                ``normalize_grid_deformation`` before being added to the grid
                vertices.

        Returns:
            A ``Mesh`` with the extracted vertices and faces. It also carries the
            (possibly deformed) grid vertices, the unique grid edges, ``level``
            and the raw ``deformation``.
        """
        if deformation is None:
            grid_vertices = self.grid_vertices
        else:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        return Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )
|
|