|
import math |
|
from typing import Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from jaxtyping import Float, Integer |
|
from torch import Tensor |
|
|
|
from sf3d.models.utils import dot, triangle_intersection_2d |
|
|
|
|
|
def _box_assign_vertex_to_cube_face(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    bbox: Float[Tensor, "2 3"],
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf"]]:
    """Project every triangle onto the cube face its averaged vertex normal points to.

    Positions are normalized into ``[-1, 1]`` with ``bbox``. Each triangle picks the
    cube face (+x, -x, +y, -y, +z, -z) whose direction has the largest dot product
    with the normalized sum of its three vertex normals, and its corners are
    projected onto that face:

    * ±x: ``u = y``, ``v = -z``
    * ±y: ``u = x``, ``v = -z``
    * +z: ``u = x``, ``v = y``
    * -z: ``u = x``, ``v = -y``

    ``u`` and ``v`` are divided by the largest absolute projection-axis coordinate
    found at the same corner slot over all triangles. They are then mapped from
    ``[-1, 1]`` to ``[0, 1]`` and clipped.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        bbox: Row 0 holds the per-axis minimum, row 1 the per-axis maximum.

    Returns:
        uv: Per-corner UV coordinates in ``[0, 1]``, shape ``(Nf, 3, 2)``.
        index: Selected cube face per triangle (0..5, order +x, -x, +y, -y, +z, -z).
    """
    # Normalize positions into [-1, 1]. An axis with zero extent (e.g. a flat mesh)
    # would compute 0 / 0 and turn every UV into NaN, so a unit extent is used for
    # such an axis; the flat coordinate then maps to -1.
    extent = bbox[1:] - bbox[:1]
    extent = torch.where(extent > 0, extent, torch.ones_like(extent))
    v_pos_normalized = (vertex_positions - bbox[:1]) / extent
    v_pos_normalized = 2.0 * v_pos_normalized - 1.0

    # Corner positions and normals per triangle: (Nf, 3 corners, 3 coords).
    tri_stack = v_pos_normalized[triangle_idxs]
    tri_stack_nrm = vertex_normals[triangle_idxs]

    face_normal = F.normalize(tri_stack_nrm.sum(1), eps=1e-6, dim=-1)

    # Pick the cube face whose direction best matches the triangle's normal.
    axis = torch.tensor(
        [
            [1, 0, 0],
            [-1, 0, 0],
            [0, 1, 0],
            [0, -1, 0],
            [0, 0, 1],
            [0, 0, -1],
        ],
        device=face_normal.device,
        dtype=face_normal.dtype,
    )
    face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
    index = face_normal_axis.argmax(-1)

    # Per cube face: coordinate used for u, coordinate used for v, and the sign
    # applied to v (faces ordered +x, -x, +y, -y, +z, -z).
    device = tri_stack.device
    u_axis = torch.tensor([1, 1, 0, 0, 0, 0], device=device)
    v_axis = torch.tensor([2, 2, 2, 2, 1, 1], device=device)
    v_sign = torch.tensor(
        [-1.0, -1.0, -1.0, -1.0, 1.0, -1.0], device=device, dtype=tri_stack.dtype
    )

    def _component(axis_per_face: Tensor) -> Tensor:
        # Coordinate ``axis_per_face[f]`` of all three corners of triangle f -> (Nf, 3).
        gather_idx = axis_per_face[:, None, None].expand(-1, 3, 1)
        return tri_stack.gather(-1, gather_idx).squeeze(-1)

    # |coordinate along the projection axis| (x for faces 0/1, y for 2/3, z for 4/5).
    max_axis = _component(torch.div(index, 2, rounding_mode="floor")).abs()
    uc = _component(u_axis[index])
    vc = v_sign[index][:, None] * _component(v_axis[index])

    # Scale by the largest projection-axis magnitude per corner slot, map to [0, 1].
    max_dim_div = max_axis.max(dim=0, keepdim=True).values
    uc = ((uc / max_dim_div + 1.0) * 0.5).clip(0, 1)
    vc = ((vc / max_dim_div + 1.0) * 0.5).clip(0, 1)

    uv = torch.stack([uc, vc], dim=-1)

    return uv, index
|
|
|
|
|
def _assign_faces_uv_to_atlas_index(
    vertex_positions: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
    face_index: Integer[Tensor, "Nf"],
) -> Integer[Tensor, "Nf"]:
    """Push faces whose UVs overlap within the same atlas slot into later slots.

    Three passes are run; pass ``k`` inspects slots ``6k .. 6k + 5`` (slot ``i``
    belongs to cube face ``i % 6``). Within a slot:

    1. Triangle pairs whose UV centers lie within 3x a triangle's UV radius are
       selected as candidates.
    2. Each candidate pair is tested with ``triangle_intersection_2d``.
    3. For every intersecting pair, one triangle is flagged, chosen by comparing
       the triangles' 3D centroids along the cube face's axis. For even cube faces
       the one with the larger coordinate is flagged; for odd faces the one with
       the smaller coordinate.

    Flagged faces move ``+6`` slots. Finally all indices are clamped to ``[0, 12]``.

    Args:
        vertex_positions: Vertex positions used for the centroid comparison.
        triangle_idxs: Vertex indices of each triangle.
        face_uv: Per-corner UVs from the box projection.
        face_index: Initial cube face per triangle (0..5).

    Returns:
        Atlas slot per face in ``[0, 12]``.
    """
    triangle_pos = vertex_positions[triangle_idxs]

    assign_idx = face_index.clone()
    for overlap_step in range(3):
        overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
        for i in range(overlap_step * 6, (overlap_step + 1) * 6):
            mask = assign_idx == i
            if not mask.any():
                continue

            # All triangles currently assigned to slot i.
            uv_triangle = face_uv[mask]
            cur_triangle_pos = triangle_pos[mask]

            center_uv = uv_triangle.mean(dim=1, keepdim=True)
            uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values

            potentially_overlapping_mask = (
                # Pairwise distances between UV centers ...
                (center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
                # ... with a large diagonal so a triangle never pairs with itself.
                + torch.eye(
                    uv_triangle.shape[0],
                    device=uv_triangle.device,
                    dtype=uv_triangle.dtype,
                ).unsqueeze(-1)
                * 1000
            )
            # Cheap pre-filter before the exact intersection test.
            potentially_overlapping_mask = (
                potentially_overlapping_mask
                <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
            ).squeeze(-1)
            overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)

            # (a, b) and (b, a) describe the same pair: keep each pair once.
            f = torch.min(overlap_coords, dim=-1).values
            s = torch.max(overlap_coords, dim=-1).values
            overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
            first, second = overlap_coords.unbind(-1)

            tri_1 = uv_triangle[first]
            tri_2 = uv_triangle[second]

            its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)

            # Projection axis of cube face i % 6 (0/1 -> x, 2/3 -> y, 4/5 -> z).
            # Deriving it from i itself would select z for every slot >= 4, i.e.
            # for all +-x / +-y slots of the second and third passes.
            cube_face = i % 6
            ax = cube_face // 2
            use_max = cube_face % 2 == 1

            tri1_c = cur_triangle_pos[first].mean(dim=1)
            tri2_c = cur_triangle_pos[second].mean(dim=1)

            # After this, ``first`` holds the triangle of each pair that gets flagged.
            mark_first = (
                (tri1_c[..., ax] > tri2_c[..., ax])
                if use_max
                else (tri1_c[..., ax] < tri2_c[..., ax])
            )
            first[mark_first] = second[mark_first]

            # A triangle can appear in several pairs; flag it if any of its
            # intersection tests was positive.
            unique_idx, rev_idx = torch.unique(first, return_inverse=True)

            add = torch.zeros_like(unique_idx, dtype=torch.float32)
            add.index_add_(0, rev_idx, its.float())
            its_mask = add > 0

            # Map slot-local indices back to global face indices.
            idx = torch.where(mask)[0][unique_idx]
            overlapping_indicator[idx] = its_mask

        assign_idx[overlapping_indicator] += 6

    # Everything beyond the second slice shares the "remaining" slot 12.
    max_idx = 6 * 2
    return assign_idx.clamp(0, max_idx)
|
|
|
|
|
def _find_slice_offset_and_scale(
    index: Integer[Tensor, "Nf"],
) -> Tuple[
    Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"]
]:
    """Compute the atlas offset and down-scale divisor of every face's slot.

    Layout (with ``face = index % 6`` placed in a 3-column x 2-row grid):

    * slots 0..5: cell ``(face % 3, face // 3)`` of a 3x2 grid of 1/3-sized cells,
      divisor 3 on both axes;
    * slots 6..11: the same grid at half size (1/6 cells) in the bottom third
      (y offset 2/3), divisor 6;
    * slots >= 12: one block at x offset 0.5, y offset 2/3 with divisors 2 (x)
      and 3 (y).

    Args:
        index: Atlas slot per face; expected to be non-negative.

    Returns:
        offset_x, offset_y, div_x, div_y as float32 tensors shaped like ``index``.
        An empty ``index`` yields empty tensors.
    """
    off = 1 / 3
    dupl_off = 1 / 6
    offset_x_vals = [0, 1, 2, 0, 1, 2]
    offset_y_vals = [0, 0, 0, 1, 1, 1]

    def x_offset_calc(x, i):
        offset_calc = i // 6
        if offset_calc == 0:
            # First slice: plain 3x2 grid.
            return off * x
        # Half-size grid; slots >= 12 are shifted to the right half.
        return dupl_off * x + min(offset_calc - 1, 1) * 0.5

    def y_offset_calc(y, i):
        if i // 6 == 0:
            return off * y
        # Later slices live in the lowest third.
        return dupl_off * y + off * 2

    # Offsets are computed once per slot in Python (double precision) and gathered,
    # instead of one full masked pass over ``index`` per slot. ``max()`` of an empty
    # tensor raises, so an empty input simply uses an empty table.
    num_slots = int(index.max().item()) + 1 if index.numel() > 0 else 0
    table_x = torch.tensor(
        [x_offset_calc(offset_x_vals[i % 6], i) for i in range(num_slots)],
        dtype=torch.float32,
        device=index.device,
    )
    table_y = torch.tensor(
        [y_offset_calc(offset_y_vals[i % 6], i) for i in range(num_slots)],
        dtype=torch.float32,
        device=index.device,
    )
    offset_x = table_x[index]
    offset_y = table_y[index]

    div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
    # All overlap slices are stored at half scale.
    div_x[index >= 6] = 6
    div_y = div_x.clone()
    # The remaining faces share one block: half the width, a third of the height.
    div_x[index >= 12] = 2
    div_y[index >= 12] = 3

    return offset_x, offset_y, div_x, div_y
|
|
|
|
|
def rotation_flip_matrix_2d(
    rad: float, flip_x: bool = False, flip_y: bool = False
) -> Float[Tensor, "2 2"]:
    """Build ``flip @ rotation`` as a float32 2x2 matrix.

    ``rotation`` is ``[[cos, -sin], [sin, cos]]`` for the angle ``rad`` (radians) and
    ``flip`` is a diagonal matrix whose entries are -1 for each requested flip and 1
    otherwise.

    Args:
        rad: Rotation angle in radians.
        flip_x: Negate the first row of the result.
        flip_y: Negate the second row of the result.

    Returns:
        The combined ``(2, 2)`` matrix.
    """
    cos_a, sin_a = math.cos(rad), math.sin(rad)
    rotation = torch.tensor([[cos_a, -sin_a], [sin_a, cos_a]], dtype=torch.float32)

    signs = [-1.0 if flip_x else 1.0, -1.0 if flip_y else 1.0]
    flip = torch.diag(torch.tensor(signs, dtype=torch.float32))

    return flip @ rotation
|
|
|
|
|
def calculate_tangents(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
) -> Float[Tensor, "Nv 3"]:
    """Compute a per-vertex tangent from the triangles' UV parameterization.

    For each triangle:

    * ``tang = (dpos1 * dv2 - dpos2 * dv1) / det``
    * ``det = du1 * dv2 - dv1 * du2``
    * ``dpos`` / ``duv`` are the edges from corner 0.

    The per-triangle tangents are accumulated on their three vertices and averaged.
    They are not normalized per triangle first, so larger triangles weigh more. The
    result is then normalized and ``dot(t, n) * n`` is removed so it is orthogonal
    to the vertex normal.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle (also used as normal indices).
        face_uv: Per-corner UV coordinates.

    Returns:
        Tangent per vertex, shape ``(Nv, 3)``. Vertices not referenced by any
        triangle get a zero vector.
    """
    pos0, pos1, pos2 = (vertex_positions[triangle_idxs[:, i]] for i in range(3))
    tex0, tex1, tex2 = face_uv.unbind(1)

    duv1 = tex1 - tex0
    duv2 = tex2 - tex0
    dpos1 = pos1 - pos0
    dpos2 = pos2 - pos0

    tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
    denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

    # Keep |denom| >= eps to avoid dividing by zero on degenerate UVs, but keep its
    # sign: clamping negative determinants (mirrored UV triangles) to +eps would
    # flip the tangent and blow it up by up to 1/eps.
    eps = 1e-6
    denom_safe = torch.where(denom < 0, denom.clamp(max=-eps), denom.clamp(min=eps))
    tang = tng_nom / denom_safe

    # Accumulate every triangle's tangent on its three vertices.
    tangents = torch.zeros_like(vertex_normals)
    counts = vertex_normals.new_zeros((vertex_normals.shape[0], 1))
    ones = torch.ones_like(tang[:, :1])
    for corner in range(3):
        vertex_idx = triangle_idxs[:, corner]
        tangents.index_add_(0, vertex_idx, tang)
        counts.index_add_(0, vertex_idx, ones)

    # Average; unreferenced vertices (count 0) stay zero instead of becoming 0 / 0.
    tangents = tangents / counts.clamp_min(1.0)

    # Normalize and make the tangent perpendicular to the normal.
    tangents = F.normalize(tangents, dim=1)
    tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)

    return tangents
|
|
|
|
|
def _rotate_uv_slices_consistent_space(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],
):
    """Rotate each cube-face slice of ``uv`` toward a shared reference tangent frame.

    Reference tangent per vertex: ``normalize(cross(n, cross(p', n)))`` with
    ``p' = (-y, x, 0)`` built from the vertex position. For every cube face
    (``index % 6``) the mean actual tangent ``a`` (from ``calculate_tangents`` on the
    current UVs) and the mean reference tangent ``e`` are compared.

    The angle ``atan2(a_x * e_y - a_y * e_x, a . e)`` is then used to rotate that
    face's UVs, which are first recentered from ``[0, 1]`` to ``[-1, 1]``. Finally
    the slice is rescaled to ``[0, 1]`` using one min/max over both coordinates.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        uv: Per-corner UVs; modified in place.
        index: Cube face / atlas slot per triangle.

    Returns:
        ``uv`` (the same tensor, updated in place).
    """
    tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
    pos_stack = torch.stack(
        [
            -vertex_positions[..., 1],
            vertex_positions[..., 0],
            torch.zeros_like(vertex_positions[..., 0]),
        ],
        dim=-1,
    )
    expected_tangents = F.normalize(
        torch.linalg.cross(
            vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
        ),
        -1,
    )

    # Per-corner tangents: (Nf, 3, 3).
    actual_tangents = tangents[triangle_idxs]
    expected_tangents = expected_tangents[triangle_idxs]

    def rotation_matrix_2d(theta):
        # Stacking keeps theta's device and dtype; torch.tensor() on a list of
        # 0-d tensors would force a float32 CPU tensor (and a host sync), which
        # breaks the einsum below for float64 UVs.
        c, s = torch.cos(theta), torch.sin(theta)
        return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])

    index_mod = index % 6
    for i in range(6):
        mask = index_mod == i
        if not mask.any():
            continue

        actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
        expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))

        dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
        cross_product = (
            actual_mean_tangent[0] * expected_mean_tangent[1]
            - actual_mean_tangent[1] * expected_mean_tangent[0]
        )
        angle = torch.atan2(cross_product, dot_product)

        rot_matrix = rotation_matrix_2d(angle).to(device=uv.device, dtype=uv.dtype)

        # Center the slice around 0 before rotating.
        uv_cur = uv[mask] * 2 - 1
        rotated = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)

        # Rescale back into [0, 1]; guard a zero range (all UVs identical).
        lo, hi = rotated.min(), rotated.max()
        uv[mask] = (rotated - lo) / (hi - lo).clamp_min(1e-12)

    return uv
|
|
|
|
|
def _handle_slice_uvs(
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],
    island_padding: float,
    max_index: int = 6 * 2,
) -> Float[Tensor, "Nf 3 2"]:
    """Stretch every overlap slice to fill its cell, then pad all UVs.

    For each slot in ``6 .. max_index - 1`` the u and v coordinates of its faces are
    shifted to start at 0 and divided by their range. The range is clamped to at
    least 0.5, so a slice is enlarged at most 2x. Afterwards every face (of any
    slot) is mapped to ``[island_padding, 1 - island_padding]`` and clipped to
    ``[0, 1]``.

    Args:
        uv: Per-corner UVs. Not modified.
        index: Atlas slot per face.
        island_padding: Padding applied on each side.
        max_index: Exclusive upper bound of the slots that get stretched.

    Returns:
        A new ``(Nf, 3, 2)`` tensor with the processed UVs.
    """
    # ``unbind`` returns views, so work on a copy: the masked writes below would
    # otherwise modify the caller's ``uv`` tensor.
    uc, vc = uv.clone().unbind(-1)

    for slot in range(6, max_index):
        fi = index == slot
        if not fi.any():
            continue
        uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
        vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)

    uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
    vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)

    return torch.stack([uc_padded, vc_padded], dim=-1)
|
|
|
|
|
def _handle_remaining_uvs(
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],
    island_padding: float,
) -> Float[Tensor, "Nf 3 2"]:
    """Lay out every leftover face (slot >= 12) in its own cell of a regular grid.

    The grid shape is derived from the number of leftover faces and the region's
    ``0.5 * 1/3`` area. Each leftover triangle is:

    1. normalized by its own extent, clamped below at ``1.5 * min(cell_w, cell_h)``;
    2. padded;
    3. shrunk to one cell;
    4. translated to cell ``(k % n_cols, k // n_cols)``, where ``k`` is its order
       among the leftover faces.

    A final padding is applied around the whole grid.

    Args:
        uv: Per-corner UVs; the leftover faces are overwritten in place.
        index: Atlas slot per face.
        island_padding: Padding amount.

    Returns:
        ``uv`` itself (unchanged when there are no leftover faces).
    """
    leftover = index >= 6 * 2
    n_leftover = leftover.sum()
    if n_leftover == 0:
        return uv

    u_all, v_all = uv.unbind(-1)
    u = u_all[leftover]
    v = v_all[leftover]

    # Grid dimensions for a region 0.5 wide and 1/3 high.
    region_area = 0.5 * (1 / 3)
    grid_scale = math.sqrt(n_leftover / region_area)
    n_cols = int(math.ceil(0.5 * grid_scale))
    n_rows = int(math.ceil(n_leftover / n_cols))
    cell_w = 1 / n_cols
    cell_h = 1 / n_rows

    # Lower bound on the extent used for normalization, so tiny triangles are not
    # magnified far beyond the resolution of the first slice.
    min_extent = min(cell_w, cell_h) * 1.5

    def _fit_to_cell(coords, cells_along_axis, cell_size):
        # Per-triangle normalization, inner padding, then shrink to one cell.
        extent = coords.amax(dim=1, keepdim=True) - coords.amin(dim=1, keepdim=True)
        unit = (coords - coords.min(dim=1, keepdim=True).values) / extent.clip(
            min_extent
        )
        pad = island_padding * cells_along_axis
        return (unit * (1 - pad * 0.5) + pad * 0.25).clip(0, 1) * cell_size

    u = _fit_to_cell(u, n_cols, cell_w)
    v = _fit_to_cell(v, n_rows, cell_h)

    # Move each triangle into its own cell, filling the grid row by row.
    slot = torch.arange(u.shape[0], device=u.device, dtype=torch.int32)
    u = u + (slot % n_cols)[:, None] * cell_w
    v = v + (slot // n_cols)[:, None] * cell_h

    # Outer padding around the whole grid.
    u = (u * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
    v = (v * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)

    uv[leftover] = torch.stack([u, v], dim=-1)

    return uv
|
|
|
|
|
def _distribute_individual_uvs_in_atlas(
    face_uv: Float[Tensor, "Nf 3 2"],
    assigned_faces: Integer[Tensor, "Nf"],
    offset_x: Float[Tensor, "Nf"],
    offset_y: Float[Tensor, "Nf"],
    div_x: Float[Tensor, "Nf"],
    div_y: Float[Tensor, "Nf"],
    island_padding: float,
):
    """Place every face's local UVs into its slot of the atlas.

    The UVs are first processed by ``_handle_slice_uvs`` and
    ``_handle_remaining_uvs``. Every corner is then divided by its face's
    ``(div_x, div_y)`` and shifted by its face's ``(offset_x, offset_y)``.

    Returns:
        Flattened per-corner atlas UVs of shape ``(Nf * 3, 2)``.
    """
    local_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
    local_uv = _handle_remaining_uvs(local_uv, assigned_faces, island_padding)

    # (Nf, 1, 2): all three corners of a face share that face's scale and shift.
    scale = torch.stack([div_x, div_y], dim=-1).unsqueeze(1)
    shift = torch.stack([offset_x, offset_y], dim=-1).unsqueeze(1)

    return (local_uv / scale + shift).reshape(-1, 2)
|
|
|
|
|
def _get_unique_face_uv(
    uv: Float[Tensor, "*batch 2"],
) -> Tuple[Float[Tensor, "Utex 2"], Integer[Tensor, "Nf 3"]]:
    """Deduplicate per-corner UVs and index each face corner into the unique set.

    Args:
        uv: Per-corner UVs, either flattened ``(Nf * 3, 2)`` or ``(Nf, 3, 2)``. The
            corners of face ``f`` must be rows ``3f .. 3f + 2`` after flattening.

    Returns:
        unique_uv: The distinct UV pairs in ``torch.unique``'s sorted order,
            shape ``(Utex, 2)``.
        vtex_idx: Index into ``unique_uv`` for every face corner, shape ``(Nf, 3)``.
    """
    # Flatten first so torch.unique(dim=0) always deduplicates individual UV
    # pairs (on a (Nf, 3, 2) input it would deduplicate whole faces instead).
    flat_uv = uv.reshape(-1, 2)
    unique_uv, unique_idx = torch.unique(flat_uv, return_inverse=True, dim=0)

    vtex_idx = unique_idx.view(-1, 3)

    return unique_uv, vtex_idx
|
|
|
|
|
def _align_mesh_with_main_axis(
    vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
    """Rotate the mesh so its principal axes line up with the canonical axes.

    Algorithm:

    1. Take the first two PCA directions (``torch.pca_lowrank`` with ``q=2``).
       Orthonormalize the second against the first; the third is their cross
       product.
    2. Map each direction to the canonical axis of its largest absolute component.
       When two directions collide, the third (and then the second) direction is
       reassigned to an unused axis.
    3. Apply the resulting matrix (rows = directions) to positions and normals.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.

    Returns:
        The rotated positions and normals.

    Raises:
        ValueError: If no one-to-one assignment to the three canonical axes is found.
    """
    # pca_lowrank is randomized, so it is seeded for repeatable results. The seeding
    # happens inside fork_rng so the caller's CPU/CUDA RNG state is restored
    # afterwards instead of being silently reset to seed 0.
    with torch.random.fork_rng():
        torch.manual_seed(0)
        _, _, v = torch.pca_lowrank(vertex_positions, q=2)
    main_axis, seconday_axis = v[:, 0], v[:, 1]

    main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)

    # Remove the main-axis component from the second axis (Gram-Schmidt step).
    seconday_axis: Float[Tensor, "3"] = F.normalize(
        seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
    )

    # The third axis completes the frame. torch.linalg.cross is used because
    # torch.cross without an explicit dim is deprecated.
    third_axis: Float[Tensor, "3"] = F.normalize(
        torch.linalg.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
    )

    # Canonical axis each direction is closest to.
    main_axis_max_idx = main_axis.abs().argmax().item()
    seconday_axis_max_idx = seconday_axis.abs().argmax().item()
    third_axis_max_idx = third_axis.abs().argmax().item()

    # Resolve collisions: first move the third axis, then the secondary axis, to
    # a canonical axis that is not yet used.
    all_possible_axis = {0, 1, 2}
    cur_index = 1
    while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
        missing_axis = all_possible_axis - set(
            [main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
        )
        missing_axis = missing_axis.pop()

        if cur_index == 1:
            third_axis_max_idx = missing_axis
        elif cur_index == 2:
            seconday_axis_max_idx = missing_axis
        else:
            raise ValueError("Could not find 3 unique axis")
        cur_index += 1

    if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
        raise ValueError("Could not find 3 unique axis")

    axes = [None] * 3
    axes[main_axis_max_idx] = main_axis
    axes[seconday_axis_max_idx] = seconday_axis
    axes[third_axis_max_idx] = third_axis

    # Row k of the matrix is the direction assigned to canonical axis k.
    rot_mat = torch.stack(axes, dim=1).T

    vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
    vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)

    return vertex_positions, vertex_normals
|
|
|
|
|
def box_projection_uv_unwrap(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    island_padding: float,
) -> Tuple[Float[Tensor, "Utex 2"], Integer[Tensor, "Nf 3"]]:
    """Compute a box-projection UV atlas for a triangle mesh.

    Pipeline:

    1. Align the mesh with its principal axes.
    2. Project each triangle onto the cube face matching its normal.
    3. Rotate each cube-face slice into a consistent tangent frame.
    4. Move overlapping faces into the overlap slices / remaining block.
    5. Compute per-slot offsets and scales.
    6. Pack all faces into the atlas.
    7. Deduplicate the per-corner UVs.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        island_padding: Padding used around UV islands and grid cells.

    Returns:
        unique_uv: Distinct atlas UV coordinates, shape ``(Utex, 2)``.
        vtex_idx: Per-face-corner index into ``unique_uv``, shape ``(Nf, 3)``.
    """
    vertex_positions, vertex_normals = _align_mesh_with_main_axis(
        vertex_positions, vertex_normals
    )

    bbox: Float[Tensor, "2 3"] = torch.stack(
        [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
    )

    face_uv, face_index = _box_assign_vertex_to_cube_face(
        vertex_positions, vertex_normals, triangle_idxs, bbox
    )

    # Note: updates face_uv in place and returns it.
    face_uv = _rotate_uv_slices_consistent_space(
        vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
    )

    assigned_atlas_index = _assign_faces_uv_to_atlas_index(
        vertex_positions, triangle_idxs, face_uv, face_index
    )

    offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
        assigned_atlas_index
    )

    placed_uv = _distribute_individual_uvs_in_atlas(
        face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
    )

    return _get_unique_face_uv(placed_uv)
|
|