from dataclasses import dataclass

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
|
|
def pad_camera_extrinsics_4x4(extrinsics):
    # Pad (..., 3, 4) extrinsics to (..., 4, 4) by appending a [0, 0, 0, 1] row.
    if extrinsics.shape[-2] == 4:
        return extrinsics
    padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    if extrinsics.ndim == 3:
        padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
    extrinsics = torch.cat([extrinsics, padding], dim=-2)
    return extrinsics
|
|
def center_looking_at_camera_pose(
    camera_position: torch.Tensor,
    look_at: torch.Tensor = None,
    up_world: torch.Tensor = None,
):
    # Build camera-to-world extrinsics from camera positions. By default the
    # camera looks at the origin with world up along +z.
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
    # Match the device of camera_position so GPU inputs do not fail below.
    look_at = look_at.to(camera_position.device)
    up_world = up_world.to(camera_position.device)
    if camera_position.ndim == 2:
        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

    # Right-handed basis: z points from the look-at target toward the camera.
    z_axis = camera_position - look_at
    z_axis = F.normalize(z_axis, dim=-1).float()
    x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
    x_axis = F.normalize(x_axis, dim=-1).float()
    y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
    y_axis = F.normalize(y_axis, dim=-1).float()

    # Columns are [x | y | z | t]; pad to a 4x4 homogeneous matrix.
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    extrinsics = pad_camera_extrinsics_4x4(extrinsics)
    return extrinsics
|
|
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    # Place cameras on a sphere of the given radius and aim them at the origin.
    azimuths = np.deg2rad(azimuths)
    elevations = np.deg2rad(elevations)

    xs = radius * np.cos(elevations) * np.cos(azimuths)
    ys = radius * np.cos(elevations) * np.sin(azimuths)
    zs = radius * np.sin(elevations)

    cam_locations = np.stack([xs, ys, zs], axis=-1)
    cam_locations = torch.from_numpy(cam_locations).float()

    c2ws = center_looking_at_camera_pose(cam_locations)
    return c2ws
|
|
def FOV_to_intrinsics(fov, device="cpu"):
    # Normalized pinhole intrinsics for a square image: focal length and
    # principal point are expressed in units of the image size.
    focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
    intrinsics = torch.tensor(
        [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device
    )
    return intrinsics
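
# A minimal usage sketch (illustrative only): scaling the normalized intrinsics
# above to pixel units for a hypothetical 320x320 render. `image_size` is an
# assumption for this example, not part of the module.
#
#   K = FOV_to_intrinsics(30.0)
#   image_size = 320
#   K_pixels = K.clone()
#   K_pixels[:2] *= image_size  # fx, fy, cx, cy now in pixels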
|
|
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    # The six fixed Zero123++ viewpoints: azimuth/elevation pairs in degrees.
    azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
    elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)

    c2ws = spherical_camera_pose(azimuths, elevations, radius)
    c2ws = c2ws.float().flatten(-2)

    Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)

    # Pack the first 12 extrinsic entries (the 3x4 pose) with [fx, fy, cx, cy].
    extrinsics = c2ws[:, :12]
    intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)
    cameras = torch.cat([extrinsics, intrinsics], dim=-1)

    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
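
# Shape sketch (illustrative): one row of 16 values per view.
#
#   cameras = get_zero123plus_input_cameras(batch_size=2)
#   cameras.shape  # torch.Size([2, 6, 16]) -- 12 pose values + fx, fy, cx, cy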
|
|
@dataclass
class InstantMeshPipelineOutput(BaseOutput):
    # Mesh buffers produced by the pipeline, all as numpy arrays.
    vertices: np.ndarray
    faces: np.ndarray
    uvs: np.ndarray
    texture: np.ndarray
|
|
class InstantMeshPipeline(DiffusionPipeline):
    def __init__(self, lrm):
        super().__init__()
        self.lrm = lrm
        self.register_modules(lrm=self.lrm)

    @torch.no_grad()
    def __call__(self, images: torch.Tensor):
        # NOTE: the string below is a disabled preprocessing path kept for
        # reference. It assumes `rembg`, `PIL.Image`,
        # `torchvision.transforms.v2`, and a `multi_view_diffusion` component
        # that this pipeline does not register, so it cannot run as-is.
        """if remove_bg:
            image = rembg.remove(image)

        image = np.array(image)
        alpha = np.where(image[..., 3] > 0)
        y1, y2, x1, x2 = (
            alpha[0].min(),
            alpha[0].max(),
            alpha[1].min(),
            alpha[1].max(),
        )
        fg = image[y1:y2, x1:x2]
        size = max(fg.shape[0], fg.shape[1])
        ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
        ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
        image = np.pad(
            fg,
            ((ph0, ph1), (pw0, pw1), (0, 0)),
            mode="constant",
            constant_values=((0, 0), (0, 0), (0, 0)),
        )

        new_size = int(image.shape[0] / 0.85)
        ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
        ph1, pw1 = new_size - size - ph0, new_size - size - pw0
        image = np.pad(
            image,
            ((ph0, ph1), (pw0, pw1), (0, 0)),
            mode="constant",
            constant_values=((0, 0), (0, 0), (0, 0)),
        )
        image = Image.fromarray(image)

        self.multi_view_diffusion = self.multi_view_diffusion.to(self._execution_device)
        images = self.multi_view_diffusion(image).images[0]

        images = np.asarray(images, dtype=np.float32) / 255.0
        images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()

        n, m = 3, 2
        c, h, w = images.shape
        images = (
            images.view(c, n, h // n, m, w // m).permute(1, 3, 0, 2, 4).contiguous()
        )
        images = images.view(n * m, c, h // n, w // m)

        images = images.unsqueeze(0)
        images = v2.functional.resize(
            images, 320, interpolation=3, antialias=True
        ).clamp(0, 1)"""
|
        # Set up the FlexiCubes isosurface grid on the execution device, then
        # predict triplane features from the input views and their cameras.
        self.lrm.init_flexicubes_geometry(self._execution_device, fovy=30.0)
        cameras = get_zero123plus_input_cameras().to(self._execution_device)
        planes = self.lrm.forward_planes(images, cameras)
        mesh_out = self.lrm.extract_mesh(
            planes,
            use_texture_map=True,
            texture_resolution=1024,
        )
        vertices, vertex_indices, uvs, uv_indices, texture = mesh_out
|
        vertices = vertices.cpu().numpy()
        vertex_indices = vertex_indices.cpu().numpy()
        uvs = uvs.cpu().numpy()
        uv_indices = uv_indices.cpu().numpy()
        texture = texture.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
|
        # Weld geometry: each unique (vertex index, uv index) pair becomes one
        # output vertex, so the faces can share a single index list for both
        # positions and texture coordinates.
        vertex_indices_flat = vertex_indices.reshape(-1)
        uv_indices_flat = uv_indices.reshape(-1)
        vertex_uv_pairs = np.stack([vertex_indices_flat, uv_indices_flat], axis=1)
        unique_pairs, unique_indices = np.unique(
            vertex_uv_pairs, axis=0, return_inverse=True
        )

        vertices = vertices[unique_pairs[:, 0]]
        uvs = uvs[unique_pairs[:, 1]]
        faces = unique_indices.reshape(-1, 3)
|
        # Convert the texture to 8-bit and inpaint empty (near-black) texels by
        # dilation, so sampling near UV seams does not bleed in background color.
        lo, hi = 0, 1
        img = np.asarray(texture, dtype=np.float32)
        img = (img - lo) * (255 / (hi - lo))
        img = img.clip(0, 255)
        mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
        mask = (mask <= 3.0).astype(np.float32)
        kernel = np.ones((3, 3), "uint8")
        dilate_img = cv2.dilate(img, kernel, iterations=1)
        img = img * (1 - mask) + dilate_img * mask
        img = img.clip(0, 255).astype(np.uint8)
        texture = np.ascontiguousarray(img[::-1, :, :])  # flip vertically (UV origin)
|
        return InstantMeshPipelineOutput(
            vertices=vertices,
            faces=faces,
            uvs=uvs,
            texture=texture,
        )
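
# A minimal usage sketch, not an official entry point: it assumes a compatible
# LRM model class and checkpoint (the names below are placeholders) and a batch
# of six preprocessed input views shaped (1, 6, 3, 320, 320) with values in [0, 1].
#
#   lrm = MyInstantMeshLRM.from_pretrained("path/to/lrm-checkpoint")  # hypothetical
#   pipeline = InstantMeshPipeline(lrm).to("cuda")
#   out = pipeline(images)  # images: torch.Tensor of shape (1, 6, 3, 320, 320)
#   print(out.vertices.shape, out.faces.shape, out.uvs.shape, out.texture.shape)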
|
|