# Hugging Face Space file header (copied from the web file view):
#   Author: ThomasSimonini
#   Commit 26e79c0 (verified): "Change TripoSR to InstantMesh"
#   File size: 5.65 kB
import os
import imageio
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange, repeat
from tqdm import tqdm
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
get_zero123plus_input_cameras,
get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground, images_to_video
import tempfile
from functools import partial
from huggingface_hub import hf_hub_download
import gradio as gr
import shutil
import spaces
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False, fov=50.0):
    """
    Get the rendering camera parameters.

    ``M`` camera-to-world poses are generated on a circle by
    ``get_circular_camera_poses`` and repeated ``batch_size`` times along a new
    leading dimension.

    Args:
        batch_size: Number of copies stacked along the leading dimension.
        M: Number of poses sampled on the circle.
        radius: Circle radius, forwarded to ``get_circular_camera_poses``.
        elevation: Camera elevation, forwarded to ``get_circular_camera_poses``.
        is_flexicubes: If True, return the matrix inverse of each pose
            (world-to-camera). Otherwise, return each pose flattened and
            concatenated with the flattened intrinsics.
        fov: Field of view passed to ``FOV_to_intrinsics`` on the
            non-FlexiCubes path (ignored otherwise). Defaults to 50.0, the
            value that was previously hard-coded here.

    Returns:
        torch.Tensor: presumably ``(batch_size, M, 4, 4)`` for FlexiCubes and
        ``(batch_size, M, 16 + 9)`` otherwise, assuming 4x4 poses and 3x3
        intrinsics -- TODO confirm against the camera utilities.
    """
    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
    if is_flexicubes:
        # Inverting camera-to-world poses yields world-to-camera matrices.
        cameras = torch.linalg.inv(c2ws)
        return cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    extrinsics = c2ws.flatten(-2)
    # One copy of the shared intrinsics per pose, flattened to line up with the extrinsics.
    intrinsics = FOV_to_intrinsics(fov).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
    cameras = torch.cat([extrinsics, intrinsics], dim=-1)
    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
import shutil
def find_cuda():
    """Locate a CUDA toolkit installation directory.

    Lookup order:
      1. ``CUDA_HOME``, then ``CUDA_PATH``. The first variable that is set to
         an existing path wins. Each variable is checked on its own, so a stale
         ``CUDA_HOME`` no longer hides a valid ``CUDA_PATH``.
      2. The directory two levels above the ``nvcc`` executable found on
         ``PATH`` (``<root>/bin/nvcc`` -> ``<root>``).

    Returns:
        The installation path as a string, or ``None`` if nothing was found.
    """
    for var in ('CUDA_HOME', 'CUDA_PATH'):
        candidate = os.environ.get(var)
        if candidate and os.path.exists(candidate):
            return candidate

    # Fall back to the nvcc executable on the system's PATH.
    nvcc_path = shutil.which('nvcc')
    if nvcc_path:
        # Strip the trailing 'bin/nvcc' to get the installation root.
        return os.path.dirname(os.path.dirname(nvcc_path))
    return None
def check_input_image(input_image):
    """Fail fast with a user-facing error when no image was provided.

    Args:
        input_image: The uploaded image; only a value of ``None`` is rejected.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is not None:
        return
    raise gr.Error("No image uploaded!")
def preprocess(input_image, do_remove_background):
    """Optionally strip the background from the conditioning image.

    Args:
        input_image: The image to preprocess. It is returned unchanged when
            ``do_remove_background`` is falsy.
        do_remove_background: When truthy, run rembg background removal and
            then rescale the foreground with ``resize_foreground(..., 0.85)``.

    Returns:
        The processed image, or the original image if no removal was requested.
    """
    if not do_remove_background:
        return input_image

    # Only create the rembg session when it is actually needed.
    rembg_session = rembg.new_session()
    input_image = remove_background(input_image, rembg_session)
    # NOTE(review): the resize is scoped to the removal branch; it presumably
    # relies on the alpha channel produced by remove_background -- confirm in
    # src/utils/infer_util.py.
    input_image = resize_foreground(input_image, 0.85)
    return input_image
@spaces.GPU
def generate_mvs(input_image, sample_steps, sample_seed):
    """Run the multi-view diffusion pipeline on a single conditioning image.

    Args:
        input_image: The conditioning image passed to ``pipeline``.
        sample_steps: Forwarded as ``num_inference_steps``.
        sample_seed: Global seed set via ``seed_everything`` before sampling.

    Returns:
        A tuple ``(z123_image, show_image)``. ``z123_image`` is the first image
        produced by the pipeline, a 3-row x 2-column grid of views.
        ``show_image`` is the same six views re-laid out as 2 rows x 3 columns
        for display.
    """
    seed_everything(sample_seed)

    # sampling
    result = pipeline(input_image, num_inference_steps=sample_steps)
    z123_image = result.images[0]

    grid = torch.from_numpy(np.asarray(z123_image, dtype=np.uint8))  # (960, 640, 3)
    # Cut the 3x2 grid into six tiles, then re-tile them as 2 rows x 3 columns.
    tiles = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
    display = rearrange(tiles, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
    show_image = Image.fromarray(display.numpy())

    return z123_image, show_image
@spaces.GPU
def make3d(images):
    """Reconstruct a vertex-colored mesh from a multi-view image grid.

    Only the mesh is exported; no turntable video is rendered.

    Args:
        images: The 3-row x 2-column grid of views (e.g. the first value
            returned by ``generate_mvs``). It is converted with ``np.asarray``
            and scaled by 1/255, so presumably an 8-bit RGB image of shape
            (960, 640, 3) -- TODO confirm for other input sizes.

    Returns:
        A tuple ``(mesh_fpath, mesh_glb_fpath)``: paths to the saved ``.obj``
        and ``.glb`` files. Both share one temp-file basename in the temp
        directory.
    """
    global model
    if IS_FLEXICUBES:
        model.init_flexicubes_geometry(device, use_renderer=False)
    model = model.eval()

    # (H, W, 3) in [0, 255] -> (3, H, W) in [0, 1] -> six (3, h, w) views.
    images = np.asarray(images, dtype=np.float32) / 255.0
    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)    # (6, 3, 320, 320)

    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)

    images = images.unsqueeze(0).to(device)
    # interpolation=3 is the legacy integer code for bicubic resampling.
    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)

    # Reserve a unique .obj path and close the handle immediately; previously
    # the open file object was discarded without being closed.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as tmp:
        mesh_fpath = tmp.name
    print(mesh_fpath)
    mesh_glb_fpath = os.path.splitext(mesh_fpath)[0] + ".glb"

    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        # get mesh
        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=False,
            **infer_config,
        )
        vertices, faces, vertex_colors = mesh_out
        # Permute vertex axes (x, y, z) -> (y, z, x) before export; presumably
        # to match the exporters' axis convention -- confirm.
        vertices = vertices[:, [1, 2, 0]]

        save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
        save_obj(vertices, faces, vertex_colors, mesh_fpath)
        print(f"Mesh saved to {mesh_fpath}")

    return mesh_fpath, mesh_glb_fpath