# InstantMesh / pipeline.py
# Uploaded by dylanebert (HF staff) in commit 167816e ("update pipeline"), 6.79 kB.
from dataclasses import dataclass
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
def pad_camera_extrinsics_4x4(extrinsics):
    """Append the homogeneous row ``[0, 0, 0, 1]`` to 3x4 extrinsics.

    Args:
        extrinsics: Tensor of shape ``(..., 3, 4)`` or ``(..., 4, 4)``; any
            number of leading batch dimensions is accepted.

    Returns:
        Tensor of shape ``(..., 4, 4)`` with the same dtype and device as the
        input. Inputs whose second-to-last dimension is already 4 are returned
        unchanged (the same object).
    """
    if extrinsics.shape[-2] == 4:
        return extrinsics
    # Match the input's dtype/device so torch.cat does not fail on CUDA or
    # on non-float32 inputs.
    bottom_row = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    # Broadcast the single row over every leading batch dimension. The old
    # code only handled zero or one leading dimension.
    bottom_row = bottom_row.expand(*extrinsics.shape[:-2], 1, 4)
    return torch.cat([extrinsics, bottom_row], dim=-2)
def center_looking_at_camera_pose(
    camera_position: torch.Tensor,
    look_at: torch.Tensor = None,
    up_world: torch.Tensor = None,
):
    """Build camera-to-world matrices for cameras aimed at a target point.

    Args:
        camera_position: Camera centres, shape ``(3,)`` or ``(N, 3)``.
        look_at: Point(s) the cameras face, shape ``(3,)`` or the same shape
            as ``camera_position``. Defaults to the origin.
        up_world: World up direction(s), same shape rules as ``look_at``.
            Defaults to ``+z``.

    Returns:
        ``(4, 4)`` or ``(N, 4, 4)`` matrices whose first three columns are the
        normalised camera x, y and z axes and whose fourth column is the
        camera position. The camera z axis points from ``look_at`` towards
        the camera.

    Note:
        If the viewing direction is parallel to ``up_world`` the cross product
        vanishes, and ``F.normalize`` then yields zero x/y axes.
    """
    device = camera_position.device
    # Create the defaults on the caller's device. Previously they were
    # always on CPU, which broke CUDA inputs with a device mismatch.
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
    if camera_position.ndim == 2:
        # Broadcast single (3,) vectors across the batch. Inputs that are
        # already (N, 3) pass through unchanged, whereas the old
        # unsqueeze/repeat crashed on them.
        look_at = look_at.expand_as(camera_position)
        up_world = up_world.expand_as(camera_position)

    z_axis = F.normalize(camera_position - look_at, dim=-1).float()
    x_axis = F.normalize(torch.linalg.cross(up_world, z_axis, dim=-1), dim=-1).float()
    y_axis = F.normalize(torch.linalg.cross(z_axis, x_axis, dim=-1), dim=-1).float()

    # Axes and position become the columns of a 3x4 matrix.
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    return pad_camera_extrinsics_4x4(extrinsics)
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """Place cameras on a sphere around the origin, all facing the origin.

    Args:
        azimuths: Azimuth angles in degrees, measured in the xy-plane from
            +x towards +y.
        elevations: Elevation angles in degrees above the xy-plane.
        radius: Distance of every camera from the origin.

    Returns:
        Camera-to-world matrices built by ``center_looking_at_camera_pose``
        from the float32 camera positions.
    """
    az = np.deg2rad(azimuths)
    el = np.deg2rad(elevations)
    # Radius of the horizontal circle at each camera's elevation.
    ring = radius * np.cos(el)
    positions = np.stack(
        [ring * np.cos(az), ring * np.sin(az), radius * np.sin(el)],
        axis=-1,
    )
    return center_looking_at_camera_pose(torch.from_numpy(positions).float())
def FOV_to_intrinsics(fov, device="cpu"):
    """Build a normalised pinhole intrinsics matrix from a field of view.

    The image plane is normalised to the unit square. The principal point is
    therefore (0.5, 0.5) and the focal length is ``0.5 / tan(fov / 2)``.

    Args:
        fov: Full field of view in degrees, used for both axes.
        device: Device on which the tensor is created.

    Returns:
        The ``(3, 3)`` tensor ``[[f, 0, 0.5], [0, f, 0.5], [0, 0, 1]]``.
    """
    half_angle = np.deg2rad(fov) * 0.5
    f = 0.5 / np.tan(half_angle)
    rows = [
        [f, 0, 0.5],
        [0, f, 0.5],
        [0, 0, 1],
    ]
    return torch.tensor(rows, device=device)
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """Return the fixed six-view camera set used for Zero123++ inputs.

    The views sit at azimuths 30, 90, ..., 330 degrees, with elevations
    alternating between 20 and -10 degrees, at distance ``radius``. Each
    camera is encoded as 16 numbers:
      * the top three rows of its 4x4 camera-to-world matrix, flattened
        row-major (12 values);
      * the normalised intrinsics ``fx, fy, cx, cy`` for a ``fov``-degree
        field of view.

    Returns:
        float32 tensor of shape ``(batch_size, 6, 16)``.
    """
    azimuths = np.array([30, 90, 150, 210, 270, 330], dtype=float)
    elevations = np.array([20, -10, 20, -10, 20, -10], dtype=float)
    num_views = len(azimuths)

    c2ws = spherical_camera_pose(azimuths, elevations, radius).float()
    # Drop the constant [0, 0, 0, 1] row; keep rows 0-2 flattened row-major.
    extrinsics = c2ws[:, :3, :].reshape(num_views, 12)

    K = FOV_to_intrinsics(fov).float()
    fx_fy_cx_cy = torch.stack([K[0, 0], K[1, 1], K[0, 2], K[1, 2]])
    intrinsics = fx_fy_cx_cy.unsqueeze(0).expand(num_views, -1)

    cameras = torch.cat([extrinsics, intrinsics], dim=-1)
    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
@dataclass
class InstantMeshPipelineOutput(BaseOutput):
    """Textured mesh returned by ``InstantMeshPipeline``.

    ``faces`` indexes ``vertices`` and ``uvs`` with the same indices. The
    pipeline merges position and UV indices so that each output vertex
    carries exactly one UV coordinate.
    """

    # Per-vertex positions, one row per merged vertex
    # (presumably shape (V, 3) -- TODO confirm against lrm.extract_mesh).
    vertices: np.ndarray
    # Triangle corner indices, shape (F, 3), into both `vertices` and `uvs`.
    faces: np.ndarray
    # Per-vertex texture coordinates, one row per merged vertex
    # (presumably shape (V, 2) -- TODO confirm).
    uvs: np.ndarray
    # uint8 texture image, channel-last, rows flipped vertically relative to
    # the map produced by lrm.extract_mesh (presumably (H, W, 3)).
    texture: np.ndarray
class InstantMeshPipeline(DiffusionPipeline):
    """Reconstruct a textured mesh from six Zero123++-style views.

    Wraps a single registered module, ``lrm``. The pipeline calls its
    ``init_flexicubes_geometry``, ``forward_planes`` and ``extract_mesh``
    methods. Multi-view generation and background removal are not performed
    here; callers supply the six views directly.
    """

    def __init__(self, lrm):
        super().__init__()
        # register_modules records `lrm` in the pipeline config and also
        # assigns `self.lrm`, so no separate attribute assignment is needed.
        self.register_modules(lrm=lrm)

    @torch.no_grad()
    def __call__(self, images: torch.Tensor, *, texture_resolution: int = 1024):
        """Run the reconstruction model and return a textured mesh.

        Args:
            images: The six input views, ordered like the cameras from
                ``get_zero123plus_input_cameras``: azimuths 30, 90, ..., 330
                degrees, with elevations alternating 20 / -10 degrees.
                NOTE(review): an earlier version of this method prepared them
                as a ``(1, 6, C, H, W)`` float tensor clamped to [0, 1], with
                the shorter side resized to 320 px. Confirm this layout
                against ``lrm.forward_planes``.
            texture_resolution: Texture-map resolution passed to
                ``lrm.extract_mesh``. Keyword-only; defaults to the previous
                hard-coded value of 1024.

        Returns:
            InstantMeshPipelineOutput containing:
              * merged vertices and uvs;
              * faces of shape (F, 3) indexing both;
              * a uint8 texture image.
        """
        device = self._execution_device
        # fovy must match the fov used by get_zero123plus_input_cameras
        # (its default is 30.0).
        self.lrm.init_flexicubes_geometry(device, fovy=30.0)
        cameras = get_zero123plus_input_cameras().to(device)
        planes = self.lrm.forward_planes(images, cameras)
        vertices, vertex_indices, uvs, uv_indices, texture = self.lrm.extract_mesh(
            planes,
            use_texture_map=True,
            texture_resolution=texture_resolution,
        )
        vertices, uvs, faces = self._merge_vertex_uv_indices(
            vertices.cpu().numpy(),
            vertex_indices.cpu().numpy(),
            uvs.cpu().numpy(),
            uv_indices.cpu().numpy(),
        )
        # (C, H, W) -> (H, W, C) before converting to an image array.
        texture = self._postprocess_texture(texture.permute(1, 2, 0).cpu().numpy())
        return InstantMeshPipelineOutput(
            vertices=vertices,
            faces=faces,
            uvs=uvs,
            texture=texture,
        )

    @staticmethod
    def _merge_vertex_uv_indices(vertices, vertex_indices, uvs, uv_indices):
        """Re-index a mesh so every triangle corner uses one shared index.

        ``extract_mesh`` indexes positions and UVs separately. Each distinct
        (vertex index, uv index) pair becomes one output vertex, so a
        position used with several UVs (a UV seam) is duplicated.

        Returns:
            ``(vertices, uvs, faces)``, where ``faces`` has shape (F, 3) and
            indexes both returned arrays.
        """
        corner_pairs = np.stack(
            [vertex_indices.reshape(-1), uv_indices.reshape(-1)], axis=1
        )
        # `inverse` maps every corner to the row of its pair in `unique_pairs`.
        unique_pairs, inverse = np.unique(corner_pairs, axis=0, return_inverse=True)
        faces = inverse.reshape(-1, 3)
        return vertices[unique_pairs[:, 0]], uvs[unique_pairs[:, 1]], faces

    @staticmethod
    def _postprocess_texture(texture):
        """Convert a channel-last float texture to a flipped uint8 image.

        Processing steps:
          * Values are assumed to lie in [0, 1]; they are scaled to 0-255
            and clipped.
          * Near-black texels (channel sum <= 3) are replaced by the 3x3
            dilation, i.e. the per-channel local maximum, of the image.
            Presumably these are texels outside every UV chart; filling them
            reduces dark seams when the texture is sampled with filtering.
          * Rows are flipped vertically, presumably to switch between
            top-left and bottom-left UV origin conventions (TODO confirm).
        """
        img = (np.asarray(texture, dtype=np.float32) * 255.0).clip(0, 255)
        empty = np.sum(img, axis=-1, keepdims=True) <= 3.0
        dilated = cv2.dilate(img, np.ones((3, 3), np.uint8), iterations=1)
        img = np.where(empty, dilated, img)
        return np.ascontiguousarray(img.clip(0, 255).astype(np.uint8)[::-1])