diff --git "a/pipeline.py" "b/pipeline.py" --- "a/pipeline.py" +++ "b/pipeline.py" @@ -1,3011 +1,14 @@ -import collections -import itertools -import math from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple, Union import cv2 import numpy as np -import nvdiffrast.torch as dr import rembg import torch -import torch.nn as nn import torch.nn.functional as F -import xatlas -from diffusers import ConfigMixin, DiffusionPipeline, ModelMixin +from diffusers import DiffusionPipeline from diffusers.utils import BaseOutput from PIL import Image from torchvision.transforms import v2 -from transformers import PreTrainedModel, ViTConfig, ViTImageProcessor -from transformers.activations import ACT2FN -from transformers.modeling_outputs import (BaseModelOutput, - BaseModelOutputWithPooling) -from transformers.pytorch_utils import (find_pruneable_heads_and_indices, - prune_linear_layer) - - -def generate_planes(): - """ - Defines planes by the three vectors that form the "axes" of the - plane. Should work with arbitrary number of planes and planes of - arbitrary orientation. - - Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 - """ - return torch.tensor([[[1, 0, 0], - [0, 1, 0], - [0, 0, 1]], - [[1, 0, 0], - [0, 0, 1], - [0, 1, 0]], - [[0, 0, 1], - [0, 1, 0], - [1, 0, 0]]], dtype=torch.float32) - -def project_onto_planes(planes, coordinates): - """ - Does a projection of a 3D point onto a batch of 2D planes, - returning 2D plane coordinates. - - Takes plane axes of shape n_planes, 3, 3 - # Takes coordinates of shape N, M, 3 - # returns projections of shape N*n_planes, M, 2 - """ - N, M, C = coordinates.shape - n_planes, _, _ = planes.shape - coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) - inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) - projections = torch.bmm(coordinates, inv_planes) - return projections[..., :2] - -def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): - assert padding_mode == 'zeros' - N, n_planes, C, H, W = plane_features.shape - _, M, _ = coordinates.shape - plane_features = plane_features.view(N*n_planes, C, H, W) - dtype = plane_features.dtype - - coordinates = (2/box_warp) * coordinates # add specific box bounds - - projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) - output_features = torch.nn.functional.grid_sample( - plane_features, - projected_coordinates.to(dtype), - mode=mode, - padding_mode=padding_mode, - align_corners=False, - ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) - return output_features - - -class OSGDecoder(nn.Module): - """ - Triplane decoder that gives RGB and sigma values from sampled features. - Using ReLU here instead of Softplus in the original implementation. - - Reference: - EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 - """ - def __init__(self, n_features: int, - hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): - super().__init__() - - self.net_sdf = nn.Sequential( - nn.Linear(3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 1), - ) - self.net_rgb = nn.Sequential( - nn.Linear(3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 3), - ) - self.net_deformation = nn.Sequential( - nn.Linear(3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 3), - ) - self.net_weight = nn.Sequential( - nn.Linear(8 * 3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 21), - ) - - # init all bias to zero - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.zeros_(m.bias) - - def get_geometry_prediction(self, sampled_features, flexicubes_indices): - _N, n_planes, _M, _C = sampled_features.shape - sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) - - sdf = self.net_sdf(sampled_features) - deformation = self.net_deformation(sampled_features) - - grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) - grid_features = grid_features.reshape( - sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) - weight = self.net_weight(grid_features) * 0.1 - - return sdf, deformation, weight - - def get_texture_prediction(self, sampled_features): - _N, n_planes, _M, _C = sampled_features.shape - sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) - - rgb = self.net_rgb(sampled_features) - rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF - - return rgb - - -class TriplaneSynthesizer(nn.Module): - """ - Synthesizer that renders a triplane volume with planes and a camera. - - Reference: - EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 - """ - - DEFAULT_RENDERING_KWARGS = { - 'ray_start': 'auto', - 'ray_end': 'auto', - 'box_warp': 2., - 'white_back': True, - 'disparity_space_sampling': False, - 'clamp_mode': 'softplus', - 'sampler_bbox_min': -1., - 'sampler_bbox_max': 1., - } - - def __init__(self, triplane_dim: int, samples_per_ray: int): - super().__init__() - - # attributes - self.triplane_dim = triplane_dim - self.rendering_kwargs = { - **self.DEFAULT_RENDERING_KWARGS, - 'depth_resolution': samples_per_ray // 2, - 'depth_resolution_importance': samples_per_ray // 2, - } - - # modules - self.plane_axes = generate_planes() - self.decoder = OSGDecoder(n_features=triplane_dim) - - def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): - plane_axes = self.plane_axes.to(planes.device) - sampled_features = sample_from_planes( - plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) - - sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) - return sdf, deformation, weight - - def get_texture_prediction(self, planes, sample_coordinates): - plane_axes = self.plane_axes.to(planes.device) - sampled_features = sample_from_planes( - plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) - - rgb = self.decoder.get_texture_prediction(sampled_features) - return rgb - - - -dmc_table = [ -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] -] -num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, -2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, -1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, -1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, -2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, -3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, -2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, -1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, -1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] -check_table = [ -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 194], -[1, -1, 0, 0, 193], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 164], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 161], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 152], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 145], -[1, 0, 0, 1, 144], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 137], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 133], -[1, 0, 1, 0, 132], -[1, 1, 0, 0, 131], -[1, 1, 0, 0, 130], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 100], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 98], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 96], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 88], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 82], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 74], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 72], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 70], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 67], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 65], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 56], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 52], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 44], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 40], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 38], -[1, 0, -1, 0, 37], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 33], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 28], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 26], -[1, 0, 0, -1, 25], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 20], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 18], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 9], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 6], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0] -] -tet_table = [ -[-1, -1, -1, -1, -1, -1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, -1], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, -1], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, -1, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, -1, 2, 4, 4, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, 5, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, -1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[-1, 1, 1, 4, 4, 1], -[0, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[8, 8, 8, 8, 8, 8], -[1, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 4, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 5, 5, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[6, 6, 6, 6, 6, 6], -[6, -1, 0, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 4, -1, 6, 4, 6], -[6, 4, 0, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 2, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 1, 1, 6, -1, 6], -[6, 1, 1, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 4], -[2, 2, 2, 2, 2, 2], -[6, 1, 1, 6, 4, 6], -[6, 1, 1, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 5, 0, 5, 0, 5], -[5, 5, 5, 5, 5, 5], -[5, 5, 5, 5, 5, 5], -[0, 5, 0, 5, 0, 5], -[-1, 5, 0, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[4, 5, -1, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[4, 5, 0, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 6, 6, 6, 6, 6], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, -1, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[2, 5, 2, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 4], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 6, 2, 6, 6, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 1, 4, 1], -[0, 1, 1, 1, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 0, 0, 6, 0, 6], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[5, 5, 5, 5, 5, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 4, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[4, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[8, 8, 8, 8, 8, 8], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 1, 1, 4, 4, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 4, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[12, 12, 12, 12, 12, 12] -] - - -class FlexiCubes: - def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): - self.device = device - self.dmc_table = torch.tensor( - dmc_table, dtype=torch.long, device=device, requires_grad=False - ) - self.num_vd_table = torch.tensor( - num_vd_table, dtype=torch.long, device=device, requires_grad=False - ) - self.check_table = torch.tensor( - check_table, dtype=torch.long, device=device, requires_grad=False - ) - - self.tet_table = torch.tensor( - tet_table, dtype=torch.long, device=device, requires_grad=False - ) - self.quad_split_1 = torch.tensor( - [0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False - ) - self.quad_split_2 = torch.tensor( - [0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False - ) - self.quad_split_train = torch.tensor( - [0, 1, 1, 2, 2, 3, 3, 0], - dtype=torch.long, - device=device, - requires_grad=False, - ) - - self.cube_corners = torch.tensor( - [ - [0, 0, 0], - [1, 0, 0], - [0, 1, 0], - [1, 1, 0], - [0, 0, 1], - [1, 0, 1], - [0, 1, 1], - [1, 1, 1], - ], - dtype=torch.float, - device=device, - ) - self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) - self.cube_edges = torch.tensor( - [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4], - dtype=torch.long, - device=device, - requires_grad=False, - ) - - self.edge_dir_table = torch.tensor( - [0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], dtype=torch.long, device=device - ) - self.dir_faces_table = torch.tensor( - [ - [[5, 4], [3, 2], [4, 5], [2, 3]], - [[5, 4], [1, 0], [4, 5], [0, 1]], - [[3, 2], [1, 0], [2, 3], [0, 1]], - ], - dtype=torch.long, - device=device, - ) - self.adj_pairs = torch.tensor( - [0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device - ) - self.qef_reg_scale = qef_reg_scale - self.weight_scale = weight_scale - - def construct_voxel_grid(self, res): - """ - Generates a voxel grid based on the specified resolution. - - Args: - res (int or list[int]): The resolution of the voxel grid. If an integer - is provided, it is used for all three dimensions. If a list or tuple - of 3 integers is provided, they define the resolution for the x, - y, and z dimensions respectively. - - Returns: - (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the - cube corners (index into vertices) of the constructed voxel grid. - The vertices are centered at the origin, with the length of each - dimension in the grid being one. - """ - base_cube_f = torch.arange(8).to(self.device) - if isinstance(res, int): - res = (res, res, res) - voxel_grid_template = torch.ones(res, device=self.device) - - res = torch.tensor([res], dtype=torch.float, device=self.device) - coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 - verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape( - -1, 3 - ) - cubes = ( - base_cube_f.unsqueeze(0) - + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8 - ).reshape(-1) - - verts_rounded = torch.round(verts * 10**5) / (10**5) - verts_unique, inverse_indices = torch.unique( - verts_rounded, dim=0, return_inverse=True - ) - cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) - - return verts_unique - 0.5, cubes - - def __call__( - self, - x_nx3, - s_n, - cube_fx8, - res, - beta_fx12=None, - alpha_fx8=None, - gamma_f=None, - training=False, - output_tetmesh=False, - grad_func=None, - ): - r""" - Main function for mesh extraction from scalar field using FlexiCubes. This function converts - discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, - to triangle or tetrahedral meshes using a differentiable operation as described in - `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances - mesh quality and geometric fidelity by adjusting the surface representation based on gradient - optimization. The output surface is differentiable with respect to the input vertex positions, - scalar field values, and weight parameters. - - If you intend to extract a surface mesh from a fixed Signed Distance Field without the - optimization of parameters, it is suggested to provide the "grad_func" which should - return the surface gradient at any given 3D position. When grad_func is provided, the process - to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as - described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. - Please note, this approach is non-differentiable. - - For more details and example usage in optimization, refer to the - `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. - - Args: - x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. - s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values - denote that the corresponding vertex resides inside the isosurface. This affects - the directions of the extracted triangle faces and volume to be tetrahedralized. - cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. - res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it - is used for all three dimensions. If a list or tuple of 3 integers is provided, they - specify the resolution for the x, y, and z dimensions respectively. - beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual - vertices positioning. Defaults to uniform value for all edges. - alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual - vertices positioning. Defaults to uniform value for all vertices. - gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of - quadrilaterals into triangles. Defaults to uniform value for all cubes. - training (bool, optional): If set to True, applies differentiable quad splitting for - training. Defaults to False. - output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, - outputs a triangular mesh. Defaults to False. - grad_func (callable, optional): A function to compute the surface gradient at specified - 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 - tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. - - Returns: - (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: - - Vertices for the extracted triangular/tetrahedral mesh. - - Faces for the extracted triangular/tetrahedral mesh. - - Regularizer L_dev, computed per dual vertex. - - .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: - https://research.nvidia.com/labs/toronto-ai/flexicubes/ - .. _Manifold Dual Contouring: - https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf - """ - - surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) - if surf_cubes.sum() == 0: - return ( - torch.zeros((0, 3), device=self.device), - ( - torch.zeros((0, 4), dtype=torch.long, device=self.device) - if output_tetmesh - else torch.zeros((0, 3), dtype=torch.long, device=self.device) - ), - torch.zeros((0), device=self.device), - ) - beta_fx12, alpha_fx8, gamma_f = self._normalize_weights( - beta_fx12, alpha_fx8, gamma_f, surf_cubes - ) - - case_ids = self._get_case_id(occ_fx8, surf_cubes, res) - - surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( - s_n, cube_fx8, surf_cubes - ) - - vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( - x_nx3, - cube_fx8[surf_cubes], - surf_edges, - s_n, - case_ids, - beta_fx12, - alpha_fx8, - gamma_f, - idx_map, - grad_func, - ) - vertices, faces, s_edges, edge_indices = self._triangulate( - s_n, - surf_edges, - vd, - vd_gamma, - edge_counts, - idx_map, - vd_idx_map, - surf_edges_mask, - training, - grad_func, - ) - if not output_tetmesh: - return vertices, faces, L_dev - else: - vertices, tets = self._tetrahedralize( - x_nx3, - s_n, - cube_fx8, - vertices, - faces, - surf_edges, - s_edges, - vd_idx_map, - case_ids, - edge_indices, - surf_cubes, - training, - ) - return vertices, tets, L_dev - - def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): - """ - Regularizer L_dev as in Equation 8 - """ - dist = torch.norm( - ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1 - ) - mean_l2 = torch.zeros_like(vd[:, 0]) - mean_l2 = (mean_l2).index_add_( - 0, edge_group_to_vd, dist - ) / vd_num_edges.squeeze(1).float() - mad = ( - dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0) - ).abs() - return mad - - def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): - """ - Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. - """ - n_cubes = surf_cubes.shape[0] - - if beta_fx12 is not None: - beta_fx12 = torch.tanh(beta_fx12) * self.weight_scale + 1 - else: - beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) - - if alpha_fx8 is not None: - alpha_fx8 = torch.tanh(alpha_fx8) * self.weight_scale + 1 - else: - alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) - - if gamma_f is not None: - gamma_f = ( - torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale) / 2 - ) - else: - gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) - - return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] - - @torch.no_grad() - def _get_case_id(self, occ_fx8, surf_cubes, res): - """ - Obtains the ID of topology cases based on cell corner occupancy. This function resolves the - ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the - supplementary material. It should be noted that this function assumes a regular grid. - """ - case_ids = ( - occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0) - ).sum(-1) - - problem_config = self.check_table.to(self.device)[case_ids] - to_check = problem_config[..., 0] == 1 - problem_config = problem_config[to_check] - if not isinstance(res, (list, tuple)): - res = [res, res, res] - - # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, - # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). - # This allows efficient checking on adjacent cubes. - problem_config_full = torch.zeros( - list(res) + [5], device=self.device, dtype=torch.long - ) - vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 - vol_idx_problem = vol_idx[surf_cubes][to_check] - problem_config_full[ - vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2] - ] = problem_config - vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] - - within_range = ( - (vol_idx_problem_adj[..., 0] >= 0) - & (vol_idx_problem_adj[..., 0] < res[0]) - & (vol_idx_problem_adj[..., 1] >= 0) - & (vol_idx_problem_adj[..., 1] < res[1]) - & (vol_idx_problem_adj[..., 2] >= 0) - & (vol_idx_problem_adj[..., 2] < res[2]) - ) - - vol_idx_problem = vol_idx_problem[within_range] - vol_idx_problem_adj = vol_idx_problem_adj[within_range] - problem_config = problem_config[within_range] - problem_config_adj = problem_config_full[ - vol_idx_problem_adj[..., 0], - vol_idx_problem_adj[..., 1], - vol_idx_problem_adj[..., 2], - ] - # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. - to_invert = problem_config_adj[..., 0] == 1 - idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][ - within_range - ][to_invert] - case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) - return case_ids - - @torch.no_grad() - def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): - """ - Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge - can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge - and marks the cube edges with this index. - """ - occ_n = s_n < 0 - all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique( - all_edges, dim=0, return_inverse=True, return_counts=True - ) - - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 - - surf_edges_mask = mask_edges[_idx_map] - counts = counts[_idx_map] - - mapping = ( - torch.ones( - (unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device - ) - * -1 - ) - mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) - # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index - # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. - idx_map = mapping[_idx_map] - surf_edges = unique_edges[mask_edges] - return surf_edges, idx_map, counts, surf_edges_mask - - @torch.no_grad() - def _identify_surf_cubes(self, s_n, cube_fx8): - """ - Identifies grid cubes that intersect with the underlying surface by checking if the signs at - all corners are not identical. - """ - occ_n = s_n < 0 - occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) - _occ_sum = torch.sum(occ_fx8, -1) - surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) - return surf_cubes, occ_fx8 - - def _linear_interp(self, edges_weight, edges_x): - """ - Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. - """ - edge_dim = edges_weight.dim() - 2 - assert edges_weight.shape[edge_dim] == 2 - edges_weight = torch.cat( - [ - torch.index_select( - input=edges_weight, - index=torch.tensor(1, device=self.device), - dim=edge_dim, - ), - -torch.index_select( - input=edges_weight, - index=torch.tensor(0, device=self.device), - dim=edge_dim, - ), - ], - edge_dim, - ) - denominator = edges_weight.sum(edge_dim) - ue = (edges_x * edges_weight).sum(edge_dim) / denominator - return ue - - def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): - p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) - norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) - c_bx3 = c_bx3.reshape(-1, 3) - A = norm_bxnx3 - B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) - - A_reg = ( - (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale) - .unsqueeze(0) - .repeat(p_bxnx3.shape[0], 1, 1) - ) - B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) - A = torch.cat([A, A_reg], 1) - B = torch.cat([B, B_reg], 1) - dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) - return dual_verts - - def _compute_vd( - self, - x_nx3, - surf_cubes_fx8, - surf_edges, - s_n, - case_ids, - beta_fx12, - alpha_fx8, - gamma_f, - idx_map, - grad_func, - ): - """ - Computes the location of dual vertices as described in Section 4.2 - """ - alpha_nx12x2 = torch.index_select( - input=alpha_fx8, index=self.cube_edges, dim=1 - ).reshape(-1, 12, 2) - surf_edges_x = torch.index_select( - input=x_nx3, index=surf_edges.reshape(-1), dim=0 - ).reshape(-1, 2, 3) - surf_edges_s = torch.index_select( - input=s_n, index=surf_edges.reshape(-1), dim=0 - ).reshape(-1, 2, 1) - zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) - - idx_map = idx_map.reshape(-1, 12) - num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) - edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = ( - [], - [], - [], - [], - [], - ) - - total_num_vd = 0 - vd_idx_map = torch.zeros( - (case_ids.shape[0], 12), - dtype=torch.long, - device=self.device, - requires_grad=False, - ) - if grad_func is not None: - normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) - vd = [] - for num in torch.unique(num_vd): - cur_cubes = ( - num_vd == num - ) # consider cubes with the same numbers of vd emitted (for batching) - curr_num_vd = cur_cubes.sum() * num - curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape( - -1, num * 7 - ) - curr_edge_group_to_vd = ( - torch.arange(curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) - + total_num_vd - ) - total_num_vd += curr_num_vd - curr_edge_group_to_cube = ( - torch.arange(idx_map.shape[0], device=self.device)[cur_cubes] - .unsqueeze(-1) - .repeat(1, num * 7) - .reshape_as(curr_edge_group) - ) - - curr_mask = curr_edge_group != -1 - edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) - edge_group_to_vd.append( - torch.masked_select( - curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask - ) - ) - edge_group_to_cube.append( - torch.masked_select(curr_edge_group_to_cube, curr_mask) - ) - vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) - vd_gamma.append( - torch.masked_select(gamma_f, cur_cubes) - .unsqueeze(-1) - .repeat(1, num) - .reshape(-1) - ) - - if grad_func is not None: - with torch.no_grad(): - cube_e_verts_idx = idx_map[cur_cubes] - curr_edge_group[~curr_mask] = 0 - - verts_group_idx = torch.gather( - input=cube_e_verts_idx, dim=1, index=curr_edge_group - ) - verts_group_idx[verts_group_idx == -1] = 0 - verts_group_pos = torch.index_select( - input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0 - ).reshape(-1, num.item(), 7, 3) - v0 = ( - x_nx3[surf_cubes_fx8[cur_cubes][:, 0]] - .reshape(-1, 1, 1, 3) - .repeat(1, num.item(), 1, 1) - ) - curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) - verts_centroid = (verts_group_pos * curr_mask).sum(2) / ( - curr_mask.sum(2) - ) - - normals_bx7x3 = torch.index_select( - input=normals, index=verts_group_idx.reshape(-1), dim=0 - ).reshape(-1, num.item(), 7, 3) - curr_mask = curr_mask.squeeze(2) - vd.append( - self._solve_vd_QEF( - (verts_group_pos - v0) * curr_mask, - normals_bx7x3 * curr_mask, - verts_centroid - v0.squeeze(2), - ) - + v0.reshape(-1, 3) - ) - edge_group = torch.cat(edge_group) - edge_group_to_vd = torch.cat(edge_group_to_vd) - edge_group_to_cube = torch.cat(edge_group_to_cube) - vd_num_edges = torch.cat(vd_num_edges) - vd_gamma = torch.cat(vd_gamma) - - if grad_func is not None: - vd = torch.cat(vd) - L_dev = torch.zeros([1], device=self.device) - else: - vd = torch.zeros((total_num_vd, 3), device=self.device) - beta_sum = torch.zeros((total_num_vd, 1), device=self.device) - - idx_group = torch.gather( - input=idx_map.reshape(-1), - dim=0, - index=edge_group_to_cube * 12 + edge_group, - ) - - x_group = torch.index_select( - input=surf_edges_x, index=idx_group.reshape(-1), dim=0 - ).reshape(-1, 2, 3) - s_group = torch.index_select( - input=surf_edges_s, index=idx_group.reshape(-1), dim=0 - ).reshape(-1, 2, 1) - - zero_crossing_group = torch.index_select( - input=zero_crossing, index=idx_group.reshape(-1), dim=0 - ).reshape(-1, 3) - - alpha_group = torch.index_select( - input=alpha_nx12x2.reshape(-1, 2), - dim=0, - index=edge_group_to_cube * 12 + edge_group, - ).reshape(-1, 2, 1) - ue_group = self._linear_interp(s_group * alpha_group, x_group) - - beta_group = torch.gather( - input=beta_fx12.reshape(-1), - dim=0, - index=edge_group_to_cube * 12 + edge_group, - ).reshape(-1, 1) - beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) - vd = ( - vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) - / beta_sum - ) - L_dev = self._compute_reg_loss( - vd, zero_crossing_group, edge_group_to_vd, vd_num_edges - ) - - v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd - - vd_idx_map = (vd_idx_map.reshape(-1)).scatter( - dim=0, - index=edge_group_to_cube * 12 + edge_group, - src=v_idx[edge_group_to_vd], - ) - - return vd, L_dev, vd_gamma, vd_idx_map - - def _triangulate( - self, - s_n, - surf_edges, - vd, - vd_gamma, - edge_counts, - idx_map, - vd_idx_map, - surf_edges_mask, - training, - grad_func, - ): - """ - Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into - triangles based on the gamma parameter, as described in Section 4.3. - """ - with torch.no_grad(): - group_mask = ( - edge_counts == 4 - ) & surf_edges_mask # surface edges shared by 4 cubes. - group = idx_map.reshape(-1)[group_mask] - vd_idx = vd_idx_map[group_mask] - edge_indices, indices = torch.sort(group, stable=True) - quad_vd_idx = vd_idx[indices].reshape(-1, 4) - - # Ensure all face directions point towards the positive SDF to maintain consistent winding. - s_edges = s_n[ - surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1) - ].reshape(-1, 2) - flip_mask = s_edges[:, 0] > 0 - quad_vd_idx = torch.cat( - ( - quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], - quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]], - ) - ) - if grad_func is not None: - # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. - with torch.no_grad(): - vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) - quad_gamma = torch.index_select( - input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0 - ).reshape(-1, 4, 3) - gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) - gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) - else: - quad_gamma = torch.index_select( - input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0 - ).reshape(-1, 4) - gamma_02 = torch.index_select( - input=quad_gamma, index=torch.tensor(0, device=self.device), dim=1 - ) * torch.index_select( - input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1 - ) - gamma_13 = torch.index_select( - input=quad_gamma, index=torch.tensor(1, device=self.device), dim=1 - ) * torch.index_select( - input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1 - ) - if not training: - mask = (gamma_02 > gamma_13).squeeze(1) - faces = torch.zeros( - (quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device - ) - faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] - faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] - faces = faces.reshape(-1, 3) - else: - vd_quad = torch.index_select( - input=vd, index=quad_vd_idx.reshape(-1), dim=0 - ).reshape(-1, 4, 3) - vd_02 = ( - torch.index_select( - input=vd_quad, index=torch.tensor(0, device=self.device), dim=1 - ) - + torch.index_select( - input=vd_quad, index=torch.tensor(2, device=self.device), dim=1 - ) - ) / 2 - vd_13 = ( - torch.index_select( - input=vd_quad, index=torch.tensor(1, device=self.device), dim=1 - ) - + torch.index_select( - input=vd_quad, index=torch.tensor(3, device=self.device), dim=1 - ) - ) / 2 - weight_sum = (gamma_02 + gamma_13) + 1e-8 - vd_center = ( - (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) - / weight_sum.unsqueeze(-1) - ).squeeze(1) - vd_center_idx = ( - torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] - ) - vd = torch.cat([vd, vd_center]) - faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) - faces = torch.cat( - [faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1 - ).reshape(-1, 3) - return vd, faces, s_edges, edge_indices - - def _tetrahedralize( - self, - x_nx3, - s_n, - cube_fx8, - vertices, - faces, - surf_edges, - s_edges, - vd_idx_map, - case_ids, - edge_indices, - surf_cubes, - training, - ): - """ - Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. - """ - occ_n = s_n < 0 - occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) - occ_sum = torch.sum(occ_fx8, -1) - - inside_verts = x_nx3[occ_n] - mapping_inside_verts = ( - torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 - ) - mapping_inside_verts[occ_n] = ( - torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] - ) - """ - For each grid edge connecting two grid vertices with different - signs, we first form a four-sided pyramid by connecting one - of the grid vertices with four mesh vertices that correspond - to the grid edge and then subdivide the pyramid into two tetrahedra - """ - inside_verts_idx = mapping_inside_verts[ - surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[s_edges < 0] - ] - if not training: - inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) - else: - inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) - - tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) - """ - For each grid edge connecting two grid vertices with the - same sign, the tetrahedron is formed by the two grid vertices - and two vertices in consecutive adjacent cells - """ - inside_cubes = occ_sum == 8 - inside_cubes_center = ( - x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) - ) - inside_cubes_center_idx = ( - torch.arange(inside_cubes_center.shape[0], device=inside_cubes.device) - + vertices.shape[0] - + inside_verts.shape[0] - ) - - surface_n_inside_cubes = surf_cubes | inside_cubes - edge_center_vertex_idx = ( - torch.ones( - ((surface_n_inside_cubes).sum(), 13), - dtype=torch.long, - device=x_nx3.device, - ) - * -1 - ) - surf_cubes = surf_cubes[surface_n_inside_cubes] - inside_cubes = inside_cubes[surface_n_inside_cubes] - edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) - edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx - - all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique( - all_edges, dim=0, return_inverse=True, return_counts=True - ) - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 - mask = mask_edges[_idx_map] - counts = counts[_idx_map] - mapping = ( - torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) - * -1 - ) - mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) - idx_map = mapping[_idx_map] - - group_mask = (counts == 4) & mask - group = idx_map.reshape(-1)[group_mask] - edge_indices, indices = torch.sort(group) - cube_idx = ( - torch.arange( - (_idx_map.shape[0] // 12), dtype=torch.long, device=self.device - ) - .unsqueeze(1) - .expand(-1, 12) - .reshape(-1)[group_mask] - ) - edge_idx = ( - torch.arange((12), dtype=torch.long, device=self.device) - .unsqueeze(0) - .expand(_idx_map.shape[0] // 12, -1) - .reshape(-1)[group_mask] - ) - # Identify the face shared by the adjacent cells. - cube_idx_4 = cube_idx[indices].reshape(-1, 4) - edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] - shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) - cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) - # Identify an edge of the face with different signs and - # select the mesh vertex corresponding to the identified edge. - case_ids_expand = ( - torch.ones( - (surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device - ) - * 255 - ) - case_ids_expand[surf_cubes] = case_ids - cases = case_ids_expand[cube_idx_4x2] - quad_edge = edge_center_vertex_idx[ - cube_idx_4x2, self.tet_table[cases, shared_faces_4x2] - ].reshape(-1, 2) - mask = (quad_edge == -1).sum(-1) == 0 - inside_edge = mapping_inside_verts[ - unique_edges[mask_edges][edge_indices].reshape(-1) - ].reshape(-1, 2) - tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] - - tets = torch.cat([tets_surface, tets_inside]) - vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) - return vertices, tets - - -def get_center_boundary_index(grid_res, device): - v = torch.zeros( - (grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device - ) - v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True - center_indices = torch.nonzero(v.reshape(-1)) - - v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False - v[:2, ...] = True - v[-2:, ...] = True - v[:, :2, ...] = True - v[:, -2:, ...] = True - v[:, :, :2] = True - v[:, :, -2:] = True - boundary_indices = torch.nonzero(v.reshape(-1)) - return center_indices, boundary_indices - - -class Geometry: - def __init__(self): - pass - - def forward(self): - pass - - -class FlexiCubesGeometry(Geometry): - def __init__( - self, - grid_res=64, - scale=2.0, - device="cuda", - renderer=None, - render_type="neural_render", - args=None, - ): - super(FlexiCubesGeometry, self).__init__() - self.grid_res = grid_res - self.device = device - self.args = args - self.fc = FlexiCubes(device, weight_scale=0.5) - self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) - if isinstance(scale, list): - self.verts[:, 0] = self.verts[:, 0] * scale[0] - self.verts[:, 1] = self.verts[:, 1] * scale[1] - self.verts[:, 2] = self.verts[:, 2] * scale[1] - else: - self.verts = self.verts * scale - - all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) - self.all_edges = torch.unique(all_edges, dim=0) - - # Parameters used for fix boundary sdf - self.center_indices, self.boundary_indices = get_center_boundary_index( - self.grid_res, device - ) - self.renderer = renderer - self.render_type = render_type - - def getAABB(self): - return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values - - def get_mesh( - self, - v_deformed_nx3, - sdf_n, - weight_n=None, - with_uv=False, - indices=None, - is_training=False, - ): - if indices is None: - indices = self.indices - - verts, faces, v_reg_loss = self.fc( - v_deformed_nx3, - sdf_n, - indices, - self.grid_res, - beta_fx12=weight_n[:, :12], - alpha_fx8=weight_n[:, 12:20], - gamma_f=weight_n[:, 20], - training=is_training, - ) - return verts, faces, v_reg_loss - - def render_mesh( - self, - mesh_v_nx3, - mesh_f_fx3, - camera_mv_bx4x4, - resolution=256, - hierarchical_mask=False, - ): - return_value = dict() - if self.render_type == "neural_render": - tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = ( - self.renderer.render_mesh( - mesh_v_nx3.unsqueeze(dim=0), - mesh_f_fx3.int(), - camera_mv_bx4x4, - mesh_v_nx3.unsqueeze(dim=0), - resolution=resolution, - device=self.device, - hierarchical_mask=hierarchical_mask, - ) - ) - - return_value["tex_pos"] = tex_pos - return_value["mask"] = mask - return_value["hard_mask"] = hard_mask - return_value["rast"] = rast - return_value["v_pos_clip"] = v_pos_clip - return_value["mask_pyramid"] = mask_pyramid - return_value["depth"] = depth - return_value["normal"] = normal - else: - raise NotImplementedError - - return return_value - - def render( - self, - v_deformed_bxnx3=None, - sdf_bxn=None, - camera_mv_bxnviewx4x4=None, - resolution=256, - ): - # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 - v_list = [] - f_list = [] - n_batch = v_deformed_bxnx3.shape[0] - all_render_output = [] - for i_batch in range(n_batch): - verts_nx3, faces_fx3 = self.get_mesh( - v_deformed_bxnx3[i_batch], sdf_bxn[i_batch] - ) - v_list.append(verts_nx3) - f_list.append(faces_fx3) - render_output = self.render_mesh( - verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution - ) - all_render_output.append(render_output) - - # Concatenate all render output - return_keys = all_render_output[0].keys() - return_value = dict() - for k in return_keys: - value = [v[k] for v in all_render_output] - return_value[k] = value - # We can do concatenation outside of the render - return return_value - - -def interpolate(attr, rast, attr_idx, rast_db=None): - return dr.interpolate( - attr.contiguous(), - rast, - attr_idx, - rast_db=rast_db, - diff_attrs=None if rast_db is None else "all", - ) - - -def xfm_points(points, matrix, use_python=True): - """Transform points. - Args: - points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] - matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] - use_python: Use PyTorch's torch.matmul (for validation) - Returns: - Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. - """ - out = torch.matmul( - torch.nn.functional.pad(points, pad=(0, 1), mode="constant", value=1.0), - torch.transpose(matrix, 1, 2), - ) - if torch.is_anomaly_enabled(): - assert torch.all( - torch.isfinite(out) - ), "Output of xfm_points contains inf or NaN" - return out - - -def dot(x, y): - return torch.sum(x * y, -1, keepdim=True) - - -def compute_vertex_normal(v_pos, t_pos_idx): - i0 = t_pos_idx[:, 0] - i1 = t_pos_idx[:, 1] - i2 = t_pos_idx[:, 2] - - v0 = v_pos[i0, :] - v1 = v_pos[i1, :] - v2 = v_pos[i2, :] - - face_normals = torch.cross(v1 - v0, v2 - v0) - - # Splat face normals to vertices - v_nrm = torch.zeros_like(v_pos) - v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) - v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) - v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) - - # Normalize, replace zero (degenerated) normals with some default value - v_nrm = torch.where( - dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) - ) - v_nrm = F.normalize(v_nrm, dim=1) - assert torch.all(torch.isfinite(v_nrm)) - - return v_nrm - - -class Renderer: - def __init__(self): - pass - - def forward(self): - pass - - -class NeuralRender(Renderer): - def __init__(self, device="cuda", camera_model=None): - super(NeuralRender, self).__init__() - self.device = device - self.ctx = dr.RasterizeCudaContext(device=device) - self.projection_mtx = None - self.camera = camera_model - - def render_mesh( - self, - mesh_v_pos_bxnx3, - mesh_t_pos_idx_fx3, - camera_mv_bx4x4, - mesh_v_feat_bxnxd, - resolution=256, - spp=1, - device="cuda", - hierarchical_mask=False, - ): - assert not hierarchical_mask - - mtx_in = ( - torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) - if not torch.is_tensor(camera_mv_bx4x4) - else camera_mv_bx4x4 - ) - v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates - v_pos_clip = self.camera.project(v_pos) # Projection in the camera - - v_nrm = compute_vertex_normal( - mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long() - ) # vertex normals in world coordinates - - # Render the image, - # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render - num_layers = 1 - mask_pyramid = None - assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes - mesh_v_feat_bxnxd = torch.cat( - [mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1 - ) # Concatenate the pos - - with dr.DepthPeeler( - self.ctx, - v_pos_clip, - mesh_t_pos_idx_fx3, - [resolution * spp, resolution * spp], - ) as peeler: - for _ in range(num_layers): - rast, db = peeler.rasterize_next_layer() - gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) - - hard_mask = torch.clamp(rast[..., -1:], 0, 1) - antialias_mask = dr.antialias( - hard_mask.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3 - ) - - depth = gb_feat[..., -2:-1] - ori_mesh_feature = gb_feat[..., :-4] - - normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) - normal = dr.antialias( - normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3 - ) - normal = F.normalize(normal, dim=-1) - normal = torch.lerp( - torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float() - ) # black background - - return ( - ori_mesh_feature, - antialias_mask, - hard_mask, - rast, - v_pos_clip, - mask_pyramid, - depth, - normal, - ) - - -def projection(x=0.1, n=1.0, f=50.0, near_plane=None): - if near_plane is None: - near_plane = n - return np.array( - [ - [n / x, 0, 0, 0], - [0, n / -x, 0, 0], - [ - 0, - 0, - -(f + near_plane) / (f - near_plane), - -(2 * f * near_plane) / (f - near_plane), - ], - [0, 0, -1, 0], - ] - ).astype(np.float32) - - -class Camera(nn.Module): - def __init__(self): - super(Camera, self).__init__() - pass - - -class PerspectiveCamera(Camera): - def __init__(self, fovy=49.0, device="cuda"): - super(PerspectiveCamera, self).__init__() - self.device = device - focal = np.tan(fovy / 180.0 * np.pi * 0.5) - self.proj_mtx = ( - torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)) - .to(self.device) - .unsqueeze(dim=0) - ) - - def project(self, points_bxnx4): - out = torch.matmul(points_bxnx4, torch.transpose(self.proj_mtx, 1, 2)) - return out - - -class ViTEmbeddings(nn.Module): - def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: - super().__init__() - - self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.mask_token = ( - nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - if use_mask_token - else None - ) - self.patch_embeddings = ViTPatchEmbeddings(config) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter( - torch.randn(1, num_patches + 1, config.hidden_size) - ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.config = config - - def interpolate_pos_encoding( - self, embeddings: torch.Tensor, height: int, width: int - ) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - - num_patches = embeddings.shape[1] - 1 - num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: - return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] - patch_pos_embed = self.position_embeddings[:, 1:] - dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape( - 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim - ) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), - mode="bicubic", - align_corners=False, - ) - assert ( - int(h0) == patch_pos_embed.shape[-2] - and int(w0) == patch_pos_embed.shape[-1] - ) - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) - - def forward( - self, - pixel_values: torch.Tensor, - bool_masked_pos: Optional[torch.BoolTensor] = None, - interpolate_pos_encoding: bool = False, - ) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - embeddings = self.patch_embeddings( - pixel_values, interpolate_pos_encoding=interpolate_pos_encoding - ) - - if bool_masked_pos is not None: - seq_length = embeddings.shape[1] - mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) - # replace the masked visual tokens by mask_tokens - mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) - embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - - # add the [CLS] token to the embedded patch tokens - cls_tokens = self.cls_token.expand(batch_size, -1, -1) - embeddings = torch.cat((cls_tokens, embeddings), dim=1) - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding( - embeddings, height, width - ) - else: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings - - -class ViTPatchEmbeddings(nn.Module): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config): - super().__init__() - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = ( - image_size - if isinstance(image_size, collections.abc.Iterable) - else (image_size, image_size) - ) - patch_size = ( - patch_size - if isinstance(patch_size, collections.abc.Iterable) - else (patch_size, patch_size) - ) - num_patches = (image_size[1] // patch_size[1]) * ( - image_size[0] // patch_size[0] - ) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = nn.Conv2d( - num_channels, hidden_size, kernel_size=patch_size, stride=patch_size - ) - - def forward( - self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False - ) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - f" Expected {self.num_channels} but got {num_channels}." - ) - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return embeddings - - -class ViTSelfAttention(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr( - config, "embedding_size" - ): - raise ValueError( - f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " - f"heads {config.num_attention_heads}." - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear( - config.hidden_size, self.all_head_size, bias=config.qkv_bias - ) - self.key = nn.Linear( - config.hidden_size, self.all_head_size, bias=config.qkv_bias - ) - self.value = nn.Linear( - config.hidden_size, self.all_head_size, bias=config.qkv_bias - ) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = ( - (context_layer, attention_probs) if output_attentions else (context_layer,) - ) - - return outputs - - -class ViTSelfOutput(nn.Module): - """ - The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the - layernorm applied before each block. - """ - - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward( - self, hidden_states: torch.Tensor, input_tensor: torch.Tensor - ) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - - return hidden_states - - -class ViTAttention(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.attention = ViTSelfAttention(config) - self.output = ViTSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads: Set[int]) -> None: - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, - self.attention.num_attention_heads, - self.attention.attention_head_size, - self.pruned_heads, - ) - - # Prune linear layers - self.attention.query = prune_linear_layer(self.attention.query, index) - self.attention.key = prune_linear_layer(self.attention.key, index) - self.attention.value = prune_linear_layer(self.attention.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.attention.num_attention_heads = self.attention.num_attention_heads - len( - heads - ) - self.attention.all_head_size = ( - self.attention.attention_head_size * self.attention.num_attention_heads - ) - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states: torch.Tensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - self_outputs = self.attention(hidden_states, head_mask, output_attentions) - - attention_output = self.output(self_outputs[0], hidden_states) - - outputs = (attention_output,) + self_outputs[ - 1: - ] # add attentions if we output them - return outputs - - -class ViTIntermediate(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - -class ViTOutput(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward( - self, hidden_states: torch.Tensor, input_tensor: torch.Tensor - ) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states + input_tensor - - return hidden_states - - -def modulate(x, shift, scale): - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - -class ViTLayer(nn.Module): - """This corresponds to the Block class in the timm implementation.""" - - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = ViTAttention(config) - self.intermediate = ViTIntermediate(config) - self.output = ViTOutput(config) - self.layernorm_before = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - self.layernorm_after = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True) - ) - nn.init.constant_(self.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.adaLN_modulation[-1].bias, 0) - - def forward( - self, - hidden_states: torch.Tensor, - adaln_input: torch.Tensor = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation( - adaln_input - ).chunk(4, dim=1) - - self_attention_outputs = self.attention( - modulate( - self.layernorm_before(hidden_states), shift_msa, scale_msa - ), # in ViT, layernorm is applied before self-attention - head_mask, - output_attentions=output_attentions, - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[ - 1: - ] # add self attentions if we output attention weights - - # first residual connection - hidden_states = attention_output + hidden_states - - # in ViT, layernorm is also applied after self-attention - layer_output = modulate( - self.layernorm_after(hidden_states), shift_mlp, scale_mlp - ) - layer_output = self.intermediate(layer_output) - - # second residual connection is done here - layer_output = self.output(layer_output, hidden_states) - - outputs = (layer_output,) + outputs - - return outputs - - -class ViTEncoder(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.config = config - self.layer = nn.ModuleList( - [ViTLayer(config) for _ in range(config.num_hidden_layers)] - ) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - adaln_input: torch.Tensor = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ) -> Union[tuple, BaseModelOutput]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - adaln_input, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, adaln_input, layer_head_mask, output_attentions - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attentions] - if v is not None - ) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class ViTPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ViTConfig - base_model_prefix = "vit" - main_input_name = "pixel_values" - supports_gradient_checkpointing = True - _no_split_modules = ["ViTEmbeddings", "ViTLayer"] - - def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ViTEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - - -class ViTModel(ViTPreTrainedModel): - def __init__( - self, - config: ViTConfig, - add_pooling_layer: bool = True, - use_mask_token: bool = False, - ): - super().__init__(config) - self.config = config - - self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) - self.encoder = ViTEncoder(config) - - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = ViTPooler(config) if add_pooling_layer else None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> ViTPatchEmbeddings: - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - adaln_input: Optional[torch.Tensor] = None, - bool_masked_pos: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) - expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype - if pixel_values.dtype != expected_dtype: - pixel_values = pixel_values.to(expected_dtype) - - embedding_output = self.embeddings( - pixel_values, - bool_masked_pos=bool_masked_pos, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - encoder_outputs = self.encoder( - embedding_output, - adaln_input=adaln_input, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = ( - self.pooler(sequence_output) if self.pooler is not None else None - ) - - if not return_dict: - head_outputs = ( - (sequence_output, pooled_output) - if pooled_output is not None - else (sequence_output,) - ) - return head_outputs + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class ViTPooler(nn.Module): - def __init__(self, config: ViTConfig): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class DinoWrapper(nn.Module): - def __init__(self, model_name: str, freeze: bool = True): - super().__init__() - self.model, self.processor = self._build_dino(model_name) - self.camera_embedder = nn.Sequential( - nn.Linear(16, self.model.config.hidden_size, bias=True), - nn.SiLU(), - nn.Linear( - self.model.config.hidden_size, self.model.config.hidden_size, bias=True - ), - ) - if freeze: - self._freeze() - - def forward(self, image, camera): - if image.ndim == 5: - image = image.view(-1, *image.shape[2:]) - dtype = image.dtype - inputs = ( - self.processor( - images=image.float(), - return_tensors="pt", - do_rescale=False, - do_resize=False, - ) - .to(self.model.device) - .to(dtype) - ) - # embed camera - camera_embeddings = self.camera_embedder(camera) - camera_embeddings = camera_embeddings.view(-1, camera_embeddings.shape[-1]) - embeddings = camera_embeddings - # This resampling of positional embedding uses bicubic interpolation - outputs = self.model( - **inputs, adaln_input=embeddings, interpolate_pos_encoding=True - ) - last_hidden_states = outputs.last_hidden_state - return last_hidden_states - - def _freeze(self): - self.model.eval() - for name, param in self.model.named_parameters(): - param.requires_grad = False - - @staticmethod - def _build_dino( - model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5 - ): - import requests - - try: - model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) - processor = ViTImageProcessor.from_pretrained(model_name) - return model, processor - except requests.exceptions.ProxyError as err: - if proxy_error_retries > 0: - print( - f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds..." - ) - import time - - time.sleep(proxy_error_cooldown) - return DinoWrapper._build_dino( - model_name, proxy_error_retries - 1, proxy_error_cooldown - ) - else: - raise err - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - inner_dim: int, - cond_dim: int, - num_heads: int, - eps: float, - attn_drop: float = 0.0, - attn_bias: bool = False, - mlp_ratio: float = 4.0, - mlp_drop: float = 0.0, - ): - super().__init__() - - self.norm1 = nn.LayerNorm(inner_dim) - self.cross_attn = nn.MultiheadAttention( - embed_dim=inner_dim, - num_heads=num_heads, - kdim=cond_dim, - vdim=cond_dim, - dropout=attn_drop, - bias=attn_bias, - batch_first=True, - ) - self.norm2 = nn.LayerNorm(inner_dim) - self.self_attn = nn.MultiheadAttention( - embed_dim=inner_dim, - num_heads=num_heads, - dropout=attn_drop, - bias=attn_bias, - batch_first=True, - ) - self.norm3 = nn.LayerNorm(inner_dim) - self.mlp = nn.Sequential( - nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), - nn.GELU(), - nn.Dropout(mlp_drop), - nn.Linear(int(inner_dim * mlp_ratio), inner_dim), - nn.Dropout(mlp_drop), - ) - - def forward(self, x, cond): - x = x + self.cross_attn(self.norm1(x), cond, cond)[0] - before_sa = self.norm2(x) - x = x + self.self_attn(before_sa, before_sa, before_sa)[0] - x = x + self.mlp(self.norm3(x)) - return x - - -class TriplaneTransformer(nn.Module): - def __init__( - self, - inner_dim: int, - image_feat_dim: int, - triplane_low_res: int, - triplane_high_res: int, - triplane_dim: int, - num_layers: int, - num_heads: int, - eps: float = 1e-6, - ): - super().__init__() - - self.triplane_low_res = triplane_low_res - self.triplane_high_res = triplane_high_res - self.triplane_dim = triplane_dim - - self.pos_embed = nn.Parameter( - torch.randn(1, 3 * triplane_low_res**2, inner_dim) - * (1.0 / inner_dim) ** 0.5 - ) - self.layers = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim=inner_dim, - cond_dim=image_feat_dim, - num_heads=num_heads, - eps=eps, - ) - for _ in range(num_layers) - ] - ) - self.norm = nn.LayerNorm(inner_dim, eps=eps) - self.deconv = nn.ConvTranspose2d( - inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0 - ) - - def forward(self, image_feats): - - N = image_feats.shape[0] - H = W = self.triplane_low_res - - x = self.pos_embed.repeat(N, 1, 1) - for layer in self.layers: - x = layer(x, image_feats) - x = self.norm(x) - - x = x.view(N, 3, H, W, -1) - x = torch.einsum("nihwd->indhw", x) - x = x.contiguous().view(3 * N, -1, H, W) - x = self.deconv(x) - x = x.view(3, N, *x.shape[-3:]) - x = torch.einsum("indhw->nidhw", x) - x = x.contiguous() - - return x - - -def interpolate_atlas(attr, rast, attr_idx, rast_db=None): - return dr.interpolate( - attr.contiguous(), - rast, - attr_idx, - rast_db=rast_db, - diff_attrs=None if rast_db is None else "all", - ) - - -def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): - _, indices, uvs = xatlas.parametrize( - mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy() - ) - - indices_int64 = indices.astype(np.uint64, casting="same_kind").view(np.int64) - - uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) - mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) - uv_clip = uvs[None, ...] * 2.0 - 1.0 - - uv_clip4 = torch.cat( - ( - uv_clip, - torch.zeros_like(uv_clip[..., 0:1]), - torch.ones_like(uv_clip[..., 0:1]), - ), - dim=-1, - ) - - rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) - - gb_pos, _ = interpolate_atlas(mesh_v[None, ...], rast, mesh_pos_idx.int()) - mask = rast[..., 3:4] > 0 - return uvs, mesh_tex_idx, gb_pos, mask def pad_camera_extrinsics_4x4(extrinsics): @@ -3082,374 +85,6 @@ def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): return cameras.unsqueeze(0).repeat(batch_size, 1, 1) -class LRM(ModelMixin, ConfigMixin): - def __init__( - self, - encoder_freeze: bool = False, - encoder_model_name: str = "facebook/dino-vitb16", - encoder_feat_dim: int = 768, - transformer_dim: int = 1024, - transformer_layers: int = 16, - transformer_heads: int = 16, - triplane_low_res: int = 32, - triplane_high_res: int = 64, - triplane_dim: int = 80, - rendering_samples_per_ray: int = 128, - grid_res: int = 128, - grid_scale: float = 2.1, - ): - super().__init__() - - self.grid_res = grid_res - self.grid_scale = grid_scale - self.deformation_multiplier = 4.0 - - self.encoder = DinoWrapper( - model_name=encoder_model_name, - freeze=encoder_freeze, - ) - - self.transformer = TriplaneTransformer( - inner_dim=transformer_dim, - num_layers=transformer_layers, - num_heads=transformer_heads, - image_feat_dim=encoder_feat_dim, - triplane_low_res=triplane_low_res, - triplane_high_res=triplane_high_res, - triplane_dim=triplane_dim, - ) - - self.synthesizer = TriplaneSynthesizer( - triplane_dim=triplane_dim, - samples_per_ray=rendering_samples_per_ray, - ) - - def init_flexicubes_geometry(self, device, fovy=50.0): - camera = PerspectiveCamera(fovy=fovy, device=device) - renderer = NeuralRender(device, camera_model=camera) - self.geometry = FlexiCubesGeometry( - grid_res=self.grid_res, - scale=self.grid_scale, - renderer=renderer, - render_type="neural_render", - device=device, - ) - - def forward_planes(self, images, cameras): - B = images.shape[0] - - image_feats = self.encoder(images, cameras) - image_feats = image_feats.view(B, -1, image_feats.shape[-1]) - - planes = self.transformer(image_feats) - - return planes - - def get_sdf_deformation_prediction(self, planes): - init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1) - - sdf, deformation, weight = torch.utils.checkpoint.checkpoint( - self.synthesizer.get_geometry_prediction, - planes, - init_position, - self.geometry.indices, - use_reentrant=False, - ) - - deformation = ( - 1.0 - / (self.grid_res * self.deformation_multiplier) - * torch.tanh(deformation) - ) - sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32) - - sdf_bxnxnxn = sdf.reshape( - (sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1) - ) - sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) - pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) - neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) - zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) - if torch.sum(zero_surface).item() > 0: - update_sdf = torch.zeros_like(sdf[0:1]) - max_sdf = sdf.max() - min_sdf = sdf.min() - update_sdf[:, self.geometry.center_indices] += 1.0 - min_sdf - update_sdf[:, self.geometry.boundary_indices] += -1 - max_sdf - new_sdf = torch.zeros_like(sdf) - for i_batch in range(zero_surface.shape[0]): - if zero_surface[i_batch]: - new_sdf[i_batch : i_batch + 1] += update_sdf - update_mask = (new_sdf == 0).float() - sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) - sdf_reg_loss = sdf_reg_loss * zero_surface.float() - sdf = sdf * update_mask + new_sdf * (1 - update_mask) - - final_sdf = [] - final_def = [] - for i_batch in range(zero_surface.shape[0]): - if zero_surface[i_batch]: - final_sdf.append(sdf[i_batch : i_batch + 1].detach()) - final_def.append(deformation[i_batch : i_batch + 1].detach()) - else: - final_sdf.append(sdf[i_batch : i_batch + 1]) - final_def.append(deformation[i_batch : i_batch + 1]) - sdf = torch.cat(final_sdf, dim=0) - deformation = torch.cat(final_def, dim=0) - return sdf, deformation, sdf_reg_loss, weight - - def get_geometry_prediction(self, planes=None): - sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction( - planes - ) - v_deformed = ( - self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) - + deformation - ) - tets = self.geometry.indices - n_batch = planes.shape[0] - v_list = [] - f_list = [] - flexicubes_surface_reg_list = [] - - for i_batch in range(n_batch): - verts, faces, flexicubes_surface_reg = self.geometry.get_mesh( - v_deformed[i_batch], - sdf[i_batch].squeeze(dim=-1), - with_uv=False, - indices=tets, - weight_n=weight[i_batch].squeeze(dim=-1), - is_training=self.training, - ) - flexicubes_surface_reg_list.append(flexicubes_surface_reg) - v_list.append(verts) - f_list.append(faces) - - flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() - flexicubes_weight_reg = (weight**2).mean() - - return ( - v_list, - f_list, - sdf, - deformation, - v_deformed, - (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg), - ) - - def get_texture_prediction(self, planes, tex_pos, hard_mask=None): - tex_pos = torch.cat(tex_pos, dim=0) - if hard_mask is not None: - tex_pos = tex_pos * hard_mask.float() - batch_size = tex_pos.shape[0] - tex_pos = tex_pos.reshape(batch_size, -1, 3) - if hard_mask is not None: - n_point_list = torch.sum( - hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1 - ) - sample_tex_pose_list = [] - max_point = n_point_list.max() - expanded_hard_mask = ( - hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 - ) - for i in range(tex_pos.shape[0]): - tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) - if tex_pos_one_shape.shape[1] < max_point: - tex_pos_one_shape = torch.cat( - [ - tex_pos_one_shape, - torch.zeros( - 1, - max_point - tex_pos_one_shape.shape[1], - 3, - device=tex_pos_one_shape.device, - dtype=torch.float32, - ), - ], - dim=1, - ) - sample_tex_pose_list.append(tex_pos_one_shape) - tex_pos = torch.cat(sample_tex_pose_list, dim=0) - - tex_feat = torch.utils.checkpoint.checkpoint( - self.synthesizer.get_texture_prediction, - planes, - tex_pos, - use_reentrant=False, - ) - - if hard_mask is not None: - final_tex_feat = torch.zeros( - planes.shape[0], - hard_mask.shape[1] * hard_mask.shape[2], - tex_feat.shape[-1], - device=tex_feat.device, - ) - expanded_hard_mask = ( - hard_mask.reshape(hard_mask.shape[0], -1, 1).expand( - -1, -1, final_tex_feat.shape[-1] - ) - > 0.5 - ) - for i in range(planes.shape[0]): - final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][ - : n_point_list[i] - ].reshape(-1) - tex_feat = final_tex_feat - - return tex_feat.reshape( - planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1] - ) - - def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256): - return_value_list = [] - for i_mesh in range(len(mesh_v)): - return_value = self.geometry.render_mesh( - mesh_v[i_mesh], - mesh_f[i_mesh].int(), - cam_mv[i_mesh], - resolution=render_size, - hierarchical_mask=False, - ) - return_value_list.append(return_value) - - return_keys = return_value_list[0].keys() - return_value = dict() - for k in return_keys: - value = [v[k] for v in return_value_list] - return_value[k] = value - - mask = torch.cat(return_value["mask"], dim=0) - hard_mask = torch.cat(return_value["hard_mask"], dim=0) - tex_pos = return_value["tex_pos"] - depth = torch.cat(return_value["depth"], dim=0) - normal = torch.cat(return_value["normal"], dim=0) - return mask, hard_mask, tex_pos, depth, normal - - def forward_geometry(self, planes, render_cameras, render_size=256): - B, NV = render_cameras.shape[:2] - - mesh_v, mesh_f, sdf, _, _, sdf_reg_loss = self.get_geometry_prediction(planes) - - cam_mv = render_cameras - run_n_view = cam_mv.shape[1] - antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh( - mesh_v, mesh_f, cam_mv, render_size=render_size - ) - - tex_hard_mask = hard_mask - tex_pos = [ - torch.cat([pos[i_view : i_view + 1] for i_view in range(run_n_view)], dim=2) - for pos in tex_pos - ] - tex_hard_mask = torch.cat( - [ - torch.cat( - [ - tex_hard_mask[ - i * run_n_view + i_view : i * run_n_view + i_view + 1 - ] - for i_view in range(run_n_view) - ], - dim=2, - ) - for i in range(planes.shape[0]) - ], - dim=0, - ) - - tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask) - background_feature = torch.ones_like(tex_feat) - - img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) - - img_feat = torch.cat( - [ - torch.cat( - [ - img_feat[ - i : i + 1, - :, - render_size * i_view : render_size * (i_view + 1), - ] - for i_view in range(run_n_view) - ], - dim=0, - ) - for i in range(len(tex_pos)) - ], - dim=0, - ) - - img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) - antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) - depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) - normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) - - out = { - "img": img, - "mask": antilias_mask, - "depth": depth, - "normal": normal, - "sdf": sdf, - "mesh_v": mesh_v, - "mesh_f": mesh_f, - "sdf_reg_loss": sdf_reg_loss, - } - return out - - def forward(self, images, cameras, render_cameras, render_size: int): - planes = self.forward_planes(images, cameras) - out = self.forward_geometry(planes, render_cameras, render_size=render_size) - - return {"planes": planes, **out} - - def extract_mesh( - self, - planes: torch.Tensor, - use_texture_map: bool = False, - texture_resolution: int = 1024, - **kwargs, - ): - """ - Extract a 3D mesh from FlexiCubes. Only support batch_size 1. - :param planes: triplane features - :param use_texture_map: use texture map or vertex color - :param texture_resolution: the resolution of texure map - """ - assert planes.shape[0] == 1 - - # predict geometry first - mesh_v, mesh_f, _, _, _, _ = self.get_geometry_prediction(planes) - vertices, faces = mesh_v[0], mesh_f[0] - - if not use_texture_map: - # query vertex colors - vertices_tensor = vertices.unsqueeze(0) - vertices_colors = ( - self.synthesizer.get_texture_prediction(planes, vertices_tensor) - .clamp(0, 1) - .squeeze(0) - .cpu() - .numpy() - ) - vertices_colors = (vertices_colors * 255).astype(np.uint8) - - return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors - - uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( - self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution - ) - tex_hard_mask = tex_hard_mask.float() - - tex_feat = self.get_texture_prediction(planes, [gb_pos], tex_hard_mask) - background_feature = torch.zeros_like(tex_feat) - img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) - texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) - - return vertices, faces, uvs, mesh_tex_idx, texture_map - - @dataclass class InstantMeshPipelineOutput(BaseOutput): vertices: np.ndarray @@ -3459,7 +94,7 @@ class InstantMeshPipelineOutput(BaseOutput): class InstantMeshPipeline(DiffusionPipeline): - def __init__(self, lrm: LRM): + def __init__(self, lrm): super().__init__() self.lrm = lrm self.register_modules(lrm=self.lrm)