|
import os |
|
from dataclasses import dataclass, field |
|
from typing import Any, List, Optional, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import trimesh |
|
from einops import rearrange |
|
from huggingface_hub import hf_hub_download |
|
from jaxtyping import Float |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
from safetensors.torch import load_model |
|
from torch import Tensor |
|
|
|
from sf3d.models.isosurface import MarchingTetrahedraHelper |
|
from sf3d.models.mesh import Mesh |
|
from sf3d.models.utils import ( |
|
BaseModule, |
|
ImageProcessor, |
|
convert_data, |
|
dilate_fill, |
|
dot, |
|
find_class, |
|
float32_to_uint8_np, |
|
normalize, |
|
scale_tensor, |
|
) |
|
from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w |
|
|
|
|
|
class SF3D(BaseModule):
    """Image-conditioned model that produces textured meshes.

    All learned submodules (tokenizers, backbone, post-processor, decoder and
    estimators) are built in :meth:`configure` from the ``*_cls`` class paths
    and matching kwargs dicts stored in :class:`Config`.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Side length (pixels) the conditioning image is resized to.
        cond_image_size: int
        # Tet-grid resolution; also selects the "<res>_tets.npz" grid file.
        isosurface_resolution: int
        # Subtracted from the decoded density to form the SDF for extraction.
        isosurface_threshold: float = 10.0
        # Half-extent of the cubic bounding box [-radius, radius]^3.
        radius: float = 1.0
        # RGB composited under transparent pixels of the conditioning image.
        background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
        # Passed to create_intrinsic_from_fov_deg for the conditioning camera.
        default_fovy_deg: float = 40.0
        # Passed to default_cond_c2w for the conditioning camera pose.
        default_distance: float = 1.6

        # Each "<name>_cls" is resolved with find_class and the resulting class
        # is constructed with the matching "<name>" dict (see configure()).
        camera_embedder_cls: str = ""
        camera_embedder: dict = field(default_factory=dict)

        image_tokenizer_cls: str = ""
        image_tokenizer: dict = field(default_factory=dict)

        tokenizer_cls: str = ""
        tokenizer: dict = field(default_factory=dict)

        backbone_cls: str = ""
        backbone: dict = field(default_factory=dict)

        post_processor_cls: str = ""
        post_processor: dict = field(default_factory=dict)

        decoder_cls: str = ""
        decoder: dict = field(default_factory=dict)

        image_estimator_cls: str = ""
        image_estimator: dict = field(default_factory=dict)

        global_estimator_cls: str = ""
        global_estimator: dict = field(default_factory=dict)

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        """Build a model from an OmegaConf config and load safetensors weights.

        Args:
            pretrained_model_name_or_path: A local directory holding both files;
                any other value is treated as a Hugging Face Hub repo id and the
                files are fetched with ``hf_hub_download``.
            config_name: File name of the config within that location.
            weight_name: File name of the weights within that location.

        Returns:
            The constructed model with its weights loaded.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        # Resolve interpolations in place before the config reaches the model.
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        load_model(model, weight_path)
        return model

    @property
    def device(self):
        """Device of the first registered parameter of this module."""
        return next(self.parameters()).device

    def configure(self):
        """Instantiate submodules, the ``bbox`` buffer and the isosurface helper."""
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
            self.cfg.camera_embedder
        )
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.image_estimator = find_class(self.cfg.image_estimator_cls)(
            self.cfg.image_estimator
        )
        self.global_estimator = find_class(self.cfg.global_estimator_cls)(
            self.cfg.global_estimator
        )

        # Row 0 is the min corner, row 1 the max corner of the bounding cube.
        self.bbox: Float[Tensor, "2 3"]
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
                    [self.cfg.radius, self.cfg.radius, self.cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )
        # The tet grid is loaded from <package>/../load/tets/<res>_tets.npz.
        self.isosurface_helper = MarchingTetrahedraHelper(
            self.cfg.isosurface_resolution,
            os.path.join(
                os.path.dirname(__file__),
                "..",
                "load",
                "tets",
                f"{self.cfg.isosurface_resolution}_tets.npz",
            ),
        )

        self.image_processor = ImageProcessor()

    def triplane_to_meshes(
        self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
    ) -> list[Mesh]:
        """Extract one mesh per batch element from its triplane.

        The decoder's density (minus ``isosurface_threshold``) is used as the
        SDF on the tet grid, optionally deformed by the decoded vertex offsets.

        Args:
            triplanes: Batched triplane scene codes.

        Returns:
            One :class:`Mesh` per batch element, with vertex positions mapped
            from the helper's ``points_range`` into ``self.bbox``.
        """
        # The tet grid is identical for every batch element, so map it into
        # the bounding box once rather than once per element.
        grid_vertices = scale_tensor(
            self.isosurface_helper.grid_vertices.to(triplanes.device),
            self.isosurface_helper.points_range,
            self.bbox,
        )

        meshes = []
        for i in range(triplanes.shape[0]):
            triplane = triplanes[i]

            values = self.query_triplane(grid_vertices, triplane)
            decoded = self.decoder(values, include=["vertex_offset", "density"])
            sdf = decoded["density"] - self.cfg.isosurface_threshold

            # Check for a missing offset BEFORE reshaping it; squeezing first
            # would crash on None and make the fallback unreachable.
            deform = decoded["vertex_offset"]
            if deform is not None:
                deform = deform.squeeze(0).view(-1, 3)

            mesh: Mesh = self.isosurface_helper(sdf.view(-1, 1), deform)
            mesh.v_pos = scale_tensor(
                mesh.v_pos, self.isosurface_helper.points_range, self.bbox
            )

            meshes.append(mesh)

        return meshes

    def query_triplane(
        self,
        positions: Float[Tensor, "*B N 3"],
        triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
    ) -> Float[Tensor, "*B N F"]:
        """Bilinearly sample triplane features at 3D positions.

        Positions are rescaled from ``[-radius, radius]`` to grid_sample's
        ``[-1, 1]`` range, projected onto the (0,1), (0,2) and (1,2) coordinate
        planes, sampled from the matching plane, and the three per-plane feature
        vectors are concatenated for each point.

        Args:
            positions: ``(B, N, 3)`` points, or unbatched ``(N, 3)``.
            triplanes: ``(B, 3, Cp, Hp, Wp)``, or unbatched ``(3, Cp, Hp, Wp)``.
                Batched-ness is decided from ``positions`` alone, so both
                arguments must agree.

        Returns:
            ``(B, N, 3 * Cp)`` float32 features. Unbatched inputs are NOT
            squeezed back: the result keeps a leading batch dim of size 1.
        """
        batched = positions.ndim == 3
        if not batched:
            # Add the missing batch dimension to both inputs.
            triplanes = triplanes[None, ...]
            positions = positions[None, ...]
        assert triplanes.ndim == 5 and positions.ndim == 3

        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
            (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
            dim=-3,
        ).to(triplanes.dtype)
        # Fold the plane axis into the batch so each plane is sampled with its
        # own 2D coordinates; the query points form a 1 x N "image".
        out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
            rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
            rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
            align_corners=True,
            mode="bilinear",
        )
        out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)

        return out

    def get_scene_codes(
        self, batch
    ) -> Tuple[Float[Tensor, "B 3 C H W"], Tensor]:
        """Encode the conditioning views in ``batch`` into triplane scene codes.

        When ``batch["rgb_cond"]`` is 4-D (no view axis), a view axis of size 1
        is inserted IN PLACE on the caller's dict for every conditioning entry.
        ``rgb_cond`` is then expected as ``(B, Nv, H, W, C)``.

        Returns:
            ``(scene_codes, direct_codes)``: the post-processed codes and the
            raw detokenized codes they were computed from.
        """
        if len(batch["rgb_cond"].shape) == 4:
            batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
            batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
            batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
            batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
            batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
        batch_size, n_input_views = batch["rgb_cond"].shape[:2]

        camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
        camera_embeds = self.camera_embedder(**batch)

        input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
            rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
            modulation_cond=camera_embeds,
        )

        # Flatten all views' tokens into one sequence for cross-attention.
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
        )

        tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)

        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
            modulation_cond=None,
        )

        direct_codes = self.tokenizer.detokenize(tokens)
        scene_codes = self.post_processor(direct_codes)
        return scene_codes, direct_codes

    def run_image(
        self,
        image: Image.Image,
        bake_resolution: int,
        estimate_illumination: bool = False,
    ) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
        """Reconstruct a textured mesh from a single RGBA image.

        The image is resized to ``cond_image_size``, its RGB is blended over
        ``background_color`` using the alpha channel as weight, and it is paired
        with the default conditioning camera before :meth:`generate_mesh`.

        Args:
            image: Input PIL image; must be in RGBA mode.
            bake_resolution: Side length in texels of the baked textures.
            estimate_illumination: Also run the global estimator.

        Returns:
            The first generated mesh and the dict of global estimates.

        Raises:
            ValueError: If ``image`` is not in RGBA mode.
        """
        if image.mode != "RGBA":
            raise ValueError("Image must be in RGBA mode")
        img_cond = (
            torch.from_numpy(
                np.asarray(
                    image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
                ).astype(np.float32)
                / 255.0
            )
            .float()
            .clip(0, 1)
            .to(self.device)
        )
        mask_cond = img_cond[:, :, -1:]
        # lerp(bg, rgb, alpha): alpha=1 keeps the pixel, alpha=0 gives bg.
        rgb_cond = torch.lerp(
            torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
            img_cond[:, :, :3],
            mask_cond,
        )

        c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
        intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
            self.cfg.default_fovy_deg,
            self.cfg.cond_image_size,
            self.cfg.cond_image_size,
        )

        batch = {
            "rgb_cond": rgb_cond,
            "mask_cond": mask_cond,
            "c2w_cond": c2w_cond.unsqueeze(0),
            "intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
            "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
        }

        meshes, global_dict = self.generate_mesh(
            batch, bake_resolution, estimate_illumination
        )
        return meshes[0], global_dict

    def generate_mesh(
        self,
        batch,
        bake_resolution: int,
        estimate_illumination: bool = False,
    ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
        """Generate UV-textured trimesh meshes for every element of ``batch``.

        Note: ``batch["rgb_cond"]`` and ``batch["mask_cond"]`` are replaced in
        the caller's dict by their ``image_processor`` outputs (and may be
        further modified in place by :meth:`get_scene_codes`).

        Args:
            batch: Conditioning dict with ``rgb_cond``, ``mask_cond``,
                ``c2w_cond``, ``intrinsic_cond`` and ``intrinsic_normed_cond``.
            bake_resolution: Side length in texels of the baked textures.
            estimate_illumination: Also run the global estimator on the raw
                (non-post-processed) scene codes.

        Returns:
            ``(meshes, global_dict)``. An empty extraction yields an empty
            ``trimesh.Trimesh()`` for that element. Entries of ``global_dict``
            whose key starts with ``"decoder_"`` override the decoder output of
            the same name (prefix removed), indexed per batch element.
        """
        from .texture_baker import TextureBaker

        baker = TextureBaker()
        batch["rgb_cond"] = self.image_processor(
            batch["rgb_cond"], self.cfg.cond_image_size
        )
        batch["mask_cond"] = self.image_processor(
            batch["mask_cond"], self.cfg.cond_image_size
        )
        scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)

        global_dict = {}
        if self.image_estimator is not None:
            global_dict.update(
                self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
            )
        if self.global_estimator is not None and estimate_illumination:
            global_dict.update(self.global_estimator(non_postprocessed_codes))

        with torch.no_grad():
            with torch.autocast(device_type="cuda", enabled=False):
                meshes = self.triplane_to_meshes(scene_codes)

                rets = []
                for i, mesh in enumerate(meshes):
                    # Nothing was extracted: emit an empty placeholder mesh.
                    if mesh.v_pos.shape[0] == 0:
                        rets.append(trimesh.Trimesh())
                        continue

                    mesh.unwrap_uv()

                    # Rasterize in UV space; bake_mask selects texels covered
                    # by some triangle.
                    rast = baker.rasterize(mesh.v_tex, mesh.t_pos_idx, bake_resolution)
                    bake_mask = baker.get_mask(rast)

                    pos_bake = baker.interpolate(
                        mesh.v_pos,
                        rast,
                        mesh.t_pos_idx,
                        mesh.v_tex,
                    )
                    gb_pos = pos_bake[bake_mask]

                    tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
                    decoded = self.decoder(
                        tri_query, exclude=["density", "vertex_offset"]
                    )

                    nrm = baker.interpolate(
                        mesh.v_nrm,
                        rast,
                        mesh.t_pos_idx,
                        mesh.v_tex,
                    )
                    gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
                    decoded["normal"] = gb_nrm

                    # Global "decoder_*" estimates override per-texel outputs.
                    for k, v in global_dict.items():
                        if k.startswith("decoder_"):
                            decoded[k.removeprefix("decoder_")] = v[i]

                    mat_out = {
                        "albedo": decoded["features"],
                        "roughness": decoded["roughness"],
                        "metallic": decoded["metallic"],
                        "normal": normalize(decoded["perturb_normal"]),
                        "bump": None,
                    }

                    # Values are reassigned while iterating; only existing keys
                    # are written, so the live items() view stays valid. When
                    # "bump" is reached it already holds the texture baked in
                    # the "normal" step, which the shape check below skips.
                    for k, v in mat_out.items():
                        if v is None:
                            continue
                        if v.shape[0] == 1:
                            # A single (global) value: keep it as a scalar
                            # factor instead of baking a texture.
                            mat_out[k] = v[0]
                        else:
                            f = torch.zeros(
                                bake_resolution,
                                bake_resolution,
                                v.shape[-1],
                                dtype=v.dtype,
                                device=v.device,
                            )
                            if v.shape == f.shape:
                                # Already a full-resolution texture.
                                continue
                            if k == "normal":
                                # Express the perturbed normal in the tangent
                                # frame (tangent, bitangent, geometric normal)
                                # to produce a tangent-space normal map.
                                tng = baker.interpolate(
                                    mesh.v_tng,
                                    rast,
                                    mesh.t_pos_idx,
                                    mesh.v_tex,
                                )
                                gb_tng = tng[bake_mask]
                                gb_tng = F.normalize(gb_tng, dim=-1)
                                gb_btng = F.normalize(
                                    torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
                                )
                                normal = F.normalize(mat_out["normal"], dim=-1)

                                bump = torch.cat(
                                    (
                                        dot(normal, gb_tng),
                                        dot(normal, gb_btng),
                                        # Clamp the normal-axis component so
                                        # the map never tilts past 0.3.
                                        dot(normal, gb_nrm).clip(0.3, 1),
                                    ),
                                    -1,
                                )
                                # Map [-1, 1] components to [0, 1] texel values.
                                bump = (bump * 0.5 + 0.5).clamp(0, 1)

                                f[bake_mask] = bump.view(-1, 3)
                                mat_out["bump"] = f
                            else:
                                f[bake_mask] = v.view(-1, v.shape[-1])
                                mat_out[k] = f

                    def uv_padding(arr):
                        # Presumably grows baked texels outward past bake_mask
                        # to hide UV seams (see dilate_fill); 1-D arrays are
                        # returned unchanged.
                        if arr.ndim == 1:
                            return arr
                        return (
                            dilate_fill(
                                arr.permute(2, 0, 1)[None, ...],
                                bake_mask.unsqueeze(0).unsqueeze(0),
                                iterations=bake_resolution // 150,
                            )
                            .squeeze(0)
                            .permute(1, 2, 0)
                        )

                    verts_np = convert_data(mesh.v_pos)
                    faces = convert_data(mesh.t_pos_idx)
                    uvs = convert_data(mesh.v_tex)

                    basecolor_tex = Image.fromarray(
                        float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
                    ).convert("RGB")
                    basecolor_tex.format = "JPEG"

                    metallic = mat_out["metallic"].squeeze().cpu().item()
                    roughness = mat_out["roughness"].squeeze().cpu().item()

                    if mat_out["bump"] is not None:
                        bump_np = convert_data(uv_padding(mat_out["bump"]))
                        # Flat tangent-space normal (0.5, 0.5, 1); texels equal
                        # to it form the dither_mask (presumably left
                        # undithered — confirm against float32_to_uint8_np).
                        bump_up = np.ones_like(bump_np)
                        bump_up[..., :2] = 0.5
                        bump_tex = Image.fromarray(
                            float32_to_uint8_np(
                                bump_np,
                                dither=True,
                                dither_mask=np.all(
                                    bump_np == bump_up, axis=-1, keepdims=True
                                ).astype(np.float32),
                            )
                        ).convert("RGB")
                        bump_tex.format = "JPEG"
                    else:
                        bump_tex = None

                    material = trimesh.visual.material.PBRMaterial(
                        baseColorTexture=basecolor_tex,
                        roughnessFactor=roughness,
                        metallicFactor=metallic,
                        normalTexture=bump_tex,
                    )

                    tmesh = trimesh.Trimesh(
                        vertices=verts_np,
                        faces=faces,
                        visual=trimesh.visual.texture.TextureVisuals(
                            uv=uvs, material=material
                        ),
                    )
                    # Re-orient the mesh: -90 degrees about X, then +90 about Y.
                    rot = trimesh.transformations.rotation_matrix(
                        np.radians(-90), [1, 0, 0]
                    )
                    tmesh.apply_transform(rot)
                    tmesh.apply_transform(
                        trimesh.transformations.rotation_matrix(
                            np.radians(90), [0, 1, 0]
                        )
                    )

                    # Reverse face winding.
                    tmesh.invert()

                    rets.append(tmesh)

        return rets, global_dict
|
|