File size: 6,155 Bytes
9c3a994 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# -*- coding: utf-8 -*-
import os
import time
from collections import OrderedDict
from typing import Optional, List
import argparse
from functools import partial
from einops import repeat, rearrange
import numpy as np
from PIL import Image
import trimesh
import cv2
import torch
import pytorch_lightning as pl
from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
from michelangelo.models.tsal.inference_utils import extract_geometry
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
from michelangelo.utils.visualizers import html_util
def load_model(args):
    """Instantiate the model described by ``args.config_path`` with weights from ``args.ckpt_path``.

    Args:
        args: Namespace providing ``config_path`` and ``ckpt_path``.

    Returns:
        The instantiated model, moved to CUDA and switched to eval mode.
    """
    config = get_config_from_file(args.config_path)
    # Some configs nest the model definition under a ``model`` attribute;
    # otherwise the whole config is the model definition.
    model_config = config.model if hasattr(config, "model") else config
    return instantiate_from_config(model_config, ckpt_path=args.ckpt_path).cuda().eval()
def load_surface(fp, num_points=4096, device="cuda", rng=None):
    """Load a point cloud with normals from an ``.npz`` archive and subsample it.

    Args:
        fp: Path (or anything ``np.load`` accepts) of an ``.npz`` archive that
            contains ``points`` and ``normals`` arrays indexed by point along
            axis 0 (presumably ``(N, 3)`` each -- confirm against the data).
        num_points: Number of points to sample (default 4096, the value that
            was previously hard-coded).
        device: Device the returned tensor is moved to (default ``"cuda"``,
            equivalent to the previous ``.cuda()`` call).
        rng: Optional ``np.random.Generator``. When omitted, NumPy's global
            legacy RNG is used; ``pl.seed_everything`` seeds that RNG, so the
            script's ``--seed`` flag makes the sampling reproducible. (An
            unseeded ``np.random.default_rng()`` would not be.)

    Returns:
        Float32 tensor of shape ``(1, num_points, D)``: the sampled points
        concatenated with their normals along the last axis.

    Raises:
        ValueError: If the archive contains no points.
    """
    # Read from ``fp`` itself; the previous version read the global
    # ``args.pointcloud_path`` and ignored this argument entirely.
    with np.load(fp) as input_pc:
        points = input_pc['points']
        normals = input_pc['normals']

    num_available = points.shape[0]
    if num_available == 0:
        raise ValueError(f"Point cloud {fp!r} contains no points")

    sampler = np.random if rng is None else rng
    # Sample without replacement when possible; for clouds smaller than
    # ``num_points`` fall back to sampling with replacement instead of raising.
    ind = sampler.choice(num_available, num_points, replace=num_available < num_points)

    points_t = torch.as_tensor(points[ind], dtype=torch.float32)
    normals_t = torch.as_tensor(normals[ind], dtype=torch.float32)
    # Same indices for both arrays, so each normal stays paired with its point.
    surface = torch.cat([points_t, normals_t], dim=-1).unsqueeze(0)
    return surface.to(device)
def prepare_image(args, number_samples=2):
    """Load ``args.image_path`` as a normalized RGB batch for image conditioning.

    Args:
        args: Namespace providing ``image_path``.
        number_samples: Batch size; the single image is repeated this many times.

    Returns:
        Float tensor of shape ``(number_samples, 3, H, W)`` with values scaled
        from ``[0, 255]`` to ``[-1, 1]``.

    Raises:
        ValueError: If ``args.image_path`` is not set.
        FileNotFoundError: If the image cannot be read.
    """
    image_path = args.image_path
    if not image_path:
        raise ValueError("--image_path is required for image-conditioned generation")
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR channel order; the model expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_pt = torch.tensor(image).float()
    image_pt = image_pt / 255 * 2 - 1
    image_pt = rearrange(image_pt, "h w c -> c h w")
    image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples)
    return image_pt
def save_output(args, mesh_outputs):
    """Export each generated mesh to ``<args.output_dir>/<i>_out_mesh.obj``.

    The vertex order of every face is reversed (``mesh_f[:, ::-1]``, flipping
    the winding) before export. The flip is applied to a local copy, so the
    caller's ``mesh_outputs`` are left unchanged. The previous version
    overwrote ``mesh.mesh_f`` in place, which made any later use of the
    outputs (or a second save) see flipped faces.

    Args:
        args: Namespace providing ``output_dir`` (created if missing).
        mesh_outputs: Iterable of objects exposing ``mesh_v`` and ``mesh_f``
            arrays (presumably ``Latent2MeshOutput`` from ``model.sample`` --
            confirm).

    Returns:
        0, matching the other task functions.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    for i, mesh in enumerate(mesh_outputs):
        # Contiguous copy: flips the winding without aliasing the caller's array.
        flipped_faces = np.ascontiguousarray(mesh.mesh_f[:, ::-1])
        mesh_output = trimesh.Trimesh(mesh.mesh_v, flipped_faces)
        name = f"{i}_out_mesh.obj"
        mesh_output.export(os.path.join(args.output_dir, name), include_normals=True)
    print(f'-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {args.output_dir}')
    print(f'-----------------------------------------------------------------------------')
    return 0
def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
    """Pass ``args.pointcloud_path`` through the shape autoencoder and save the decoded mesh.

    Args:
        args: Namespace providing ``pointcloud_path`` and ``output_dir``.
        model: Loaded model exposing ``model.model.encode_shape_embed`` and
            ``model.model.shape_model`` (see ``load_model``).
        bounds: Bounding box forwarded to ``extract_geometry``; the tuple order
            presumably is (xmin, ymin, zmin, xmax, ymax, zmax) -- confirm.
        octree_depth: Octree depth forwarded to ``extract_geometry``.
        num_chunks: Chunk count forwarded to ``extract_geometry``.

    Returns:
        0 after writing ``<args.output_dir>/reconstruction.obj``.
    """
    surface = load_surface(args.pointcloud_path)

    # Pure inference: disable autograd so intermediate activations are not
    # kept for a backward pass (``model.eval()`` alone does not do this).
    with torch.no_grad():
        # encoding
        shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
        shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
        # decoding
        latents = model.model.shape_model.decode(shape_zq)
        geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
        # reconstruction: query the decoded geometry inside ``bounds`` and mesh it
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=surface.device,
            batch_size=surface.shape[0],
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
        )

    # load_surface returns a batch of one; mesh_v_f[0] is that item's
    # (vertices, faces) pair.
    # NOTE(review): has_surface is not checked; confirm what extract_geometry
    # returns in mesh_v_f[0] when no surface is found.
    recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])

    # save
    output_path = os.path.join(args.output_dir, 'reconstruction.obj')
    os.makedirs(args.output_dir, exist_ok=True)
    recon_mesh.export(output_path)
    print(f'-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {output_path}')
    print(f'-----------------------------------------------------------------------------')
    return 0
def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7, num_samples=2):
    """Generate meshes conditioned on the image at ``args.image_path`` and save them.

    Args:
        args: Namespace providing ``image_path`` and ``output_dir``.
        model: Loaded model exposing ``sample`` (see ``load_model``).
        guidance_scale: Forwarded to ``model.sample``.
        box_v: Half-extent of the cubic bounding box ``[-box_v, box_v]^3``
            passed as ``bounds``.
        octree_depth: Forwarded to ``model.sample``.
        num_samples: How many meshes to generate (the image is repeated this
            many times). Appended last, with the previous implicit default of
            2, so existing positional calls keep working; mirrors
            ``text2mesh``'s ``num_samples``.

    Returns:
        0 after ``save_output`` has written the meshes.
    """
    sample_inputs = {
        "image": prepare_image(args, number_samples=num_samples)
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]
    save_output(args, mesh_outputs)
    return 0
def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
    """Generate meshes conditioned on the prompt ``args.text`` and save them.

    Args:
        args: Namespace providing ``text`` and ``output_dir``.
        model: Loaded model exposing ``sample`` (see ``load_model``).
        num_samples: How many meshes to generate (the prompt is repeated this
            many times).
        guidance_scale: Forwarded to ``model.sample``.
        box_v: Half-extent of the cubic bounding box ``[-box_v, box_v]^3``
            passed as ``bounds``.
        octree_depth: Forwarded to ``model.sample``.

    Returns:
        0 after ``save_output`` has written the meshes.

    Raises:
        ValueError: If ``args.text`` is missing or empty. Without this check,
            ``[None] * num_samples`` would be sent to the model.
    """
    if not args.text:
        raise ValueError("--text is required for text-conditioned generation")
    sample_inputs = {
        "text": [args.text] * num_samples
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]
    save_output(args, mesh_outputs)
    return 0
# Maps each CLI task name to the function that runs it.
task_dict = {
    'reconstruction': reconstruction,
    'image2mesh': image2mesh,
    'text2mesh': text2mesh,
}
# Backward-compatible alias for the original, misspelled name.
task_dick = task_dict

if __name__ == "__main__":
    # Supported tasks:
    #   1. reconstruction -- reconstruct a point cloud
    #   2. image2mesh     -- image-conditioned generation
    #   3. text2mesh      -- text-conditioned generation
    # ``args`` stays module-level on purpose: other code in this file may read
    # it as a global.
    parser = argparse.ArgumentParser()
    # Derive the choices from the dispatch table so the two cannot drift apart.
    parser.add_argument("--task", type=str, choices=list(task_dict), required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
    parser.add_argument("--image_path", type=str, help='Path to the input image')
    parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
    parser.add_argument("--output_dir", type=str, default='./output')
    parser.add_argument("-s", "--seed", type=int, default=0)
    args = parser.parse_args()

    # Check task-specific inputs before the expensive model load.
    if args.task == 'image2mesh' and not args.image_path:
        parser.error("--image_path is required when --task is image2mesh")
    if args.task == 'text2mesh' and not args.text:
        parser.error("--text is required when --task is text2mesh")

    pl.seed_everything(args.seed)
    print(f'-----------------------------------------------------------------------------')
    print(f'>>> Running {args.task}')
    # Each task writes into its own subdirectory of --output_dir.
    args.output_dir = os.path.join(args.output_dir, args.task)
    print(f'>>> Output directory: {args.output_dir}')
    print(f'-----------------------------------------------------------------------------')
    task_dict[args.task](args, load_model(args))