"""Run MeshAnythingV2 inference on point clouds (.npy) or meshes (.ply/.obj)."""
import os
import argparse
import importlib
import torch
import time
import trimesh
import numpy as np
from MeshAnything.models.meshanything_v2 import MeshAnythingV2
import datetime
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DistributedDataParallelKwargs
from safetensors.torch import load_model
from mesh_to_pc import process_mesh_to_pc
from huggingface_hub import hf_hub_download

# Number of points randomly sampled (without replacement) from each
# ``pc_normal`` input file.
NUM_SAMPLED_POINTS = 8192
# File extensions collected from ``--input_dir`` for each ``--input_type``.
PC_NORMAL_EXTENSIONS = ('.npy',)
MESH_EXTENSIONS = ('.ply', '.obj')


def _uid_from_path(input_path):
    """Return the file name of *input_path* without directory or final extension.

    Uses ``os.path`` so that multi-dot names (``a.v1.ply`` -> ``a.v1``) keep
    distinct uids and do not overwrite each other's outputs.
    """
    return os.path.splitext(os.path.basename(input_path))[0]


class Dataset:
    """In-memory inference dataset of point clouds with per-point normals.

    Each stored sample is a dict ``{'pc_normal': array, 'uid': str}`` where the
    array holds xyz coordinates in its first three columns and normals in the
    remaining columns.

    Args:
        input_type: ``'pc_normal'`` (``.npy`` point clouds) or ``'mesh'``
            (files loadable by ``trimesh.load``).
        input_list: Paths of the input files.
        mc: For meshes, run marching cubes before sampling the point cloud
            (forwarded to ``process_mesh_to_pc``).

    Raises:
        ValueError: If ``input_type`` is unknown, or a point cloud has fewer
            than ``NUM_SAMPLED_POINTS`` points.
    """

    def __init__(self, input_type, input_list, mc=False):
        super().__init__()
        self.data = []
        if input_type == 'pc_normal':
            for input_path in input_list:
                cur_data = np.load(input_path)
                num_points = cur_data.shape[0]
                if num_points < NUM_SAMPLED_POINTS:
                    raise ValueError(
                        f"input pc_normal should have at least {NUM_SAMPLED_POINTS} "
                        f"points, got {num_points} in {input_path}"
                    )
                idx = np.random.choice(num_points, NUM_SAMPLED_POINTS, replace=False)
                self.data.append({'pc_normal': cur_data[idx], 'uid': _uid_from_path(input_path)})
        elif input_type == 'mesh':
            mesh_list = [trimesh.load(input_path) for input_path in input_list]
            if mc:
                print("First Marching Cubes and then sample point cloud, need several minutes...")
            pc_list, _ = process_mesh_to_pc(mesh_list, marching_cubes=mc)
            for input_path, cur_data in zip(input_list, pc_list):
                self.data.append({'pc_normal': cur_data, 'uid': _uid_from_path(input_path)})
        else:
            raise ValueError(f"unknown input_type {input_type!r}; expected 'pc_normal' or 'mesh'")
        print(f"dataset total data samples: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample *idx* with coordinates normalized, as float16.

        Coordinates are centered on their bounding-box midpoint and scaled so
        the largest absolute coordinate is 0.9995; normals are left unchanged.

        Raises:
            ValueError: If all points coincide (scale would be zero) or any
                normal is not (approximately) unit length.
        """
        pc_normal = self.data[idx]['pc_normal']
        pc_coor = pc_normal[:, :3]
        normals = pc_normal[:, 3:]

        bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
        pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
        scale = np.abs(pc_coor).max()
        if scale == 0:
            # All points coincide; dividing by zero would silently yield NaNs.
            raise ValueError(f"point cloud {self.data[idx]['uid']!r} is degenerate")
        pc_coor = pc_coor / scale * 0.9995

        if not (np.linalg.norm(normals, axis=-1) > 0.99).all():
            raise ValueError("normals should be unit vectors, something wrong")

        return {
            'pc_normal': np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16),
            'uid': self.data[idx]['uid'],
        }


def get_args():
    """Parse the command-line arguments for inference."""
    parser = argparse.ArgumentParser("MeshAnything", add_help=False)
    parser.add_argument('--input_dir', default=None, type=str)
    parser.add_argument('--input_path', default=None, type=str)
    parser.add_argument('--out_dir', default="inference_out", type=str)
    parser.add_argument(
        '--input_type',
        choices=['mesh', 'pc_normal'],
        # argparse does not validate defaults against ``choices``; the default
        # must itself be one of them or Dataset receives an unknown type.
        default='pc_normal',
        help="Type of the asset to process (default: pc_normal)",
    )
    parser.add_argument("--batchsize_per_gpu", default=1, type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--mc", default=False, action="store_true")
    parser.add_argument("--sampling", default=False, action="store_true")
    return parser.parse_args()


def load_v2():
    """Build MeshAnythingV2 and load the ``350m.pth`` weights from the HF Hub."""
    model = MeshAnythingV2()
    print("load model over!!!")
    ckpt_path = hf_hub_download(
        repo_id="Yiwen-ntu/MeshAnythingV2",
        filename="350m.pth",
    )
    load_model(model, ckpt_path)
    print("load weights over!!!")
    return model


def _collect_input_list(args):
    """Return the input file paths selected by ``--input_dir`` / ``--input_path``.

    With ``--input_dir``, files are sorted by name and filtered by the
    extensions matching ``--input_type``.

    Raises:
        ValueError: If neither ``--input_dir`` nor ``--input_path`` is given.
    """
    if args.input_dir is not None:
        extensions = PC_NORMAL_EXTENSIONS if args.input_type == 'pc_normal' else MESH_EXTENSIONS
        return [
            os.path.join(args.input_dir, name)
            for name in sorted(os.listdir(args.input_dir))
            if name.endswith(extensions)
        ]
    if args.input_path is not None:
        return [args.input_path]
    raise ValueError("input_dir or input_path must be provided.")


def _save_generated_mesh(recon_mesh, save_path):
    """Clean up one generated face soup and export it to *save_path*.

    Faces whose 9 coordinates contain any NaN are dropped (presumably padding
    emitted by the model - confirm against MeshAnythingV2). The remaining
    faces are turned into a mesh, duplicate vertices are merged, degenerate
    and duplicate faces are removed, and every face is colored orange.
    """
    valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
    recon_mesh = recon_mesh[valid_mask]  # nvalid_face x 3 x 3
    vertices = recon_mesh.reshape(-1, 3).float().cpu().numpy()
    # Each consecutive vertex triple forms one face: 0, 1, ..., 3 x face.
    triangles = np.arange(len(vertices)).reshape(-1, 3)

    scene_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh", merge_primitives=True)
    scene_mesh.merge_vertices()
    scene_mesh.update_faces(scene_mesh.nondegenerate_faces())
    scene_mesh.update_faces(scene_mesh.unique_faces())
    scene_mesh.remove_unreferenced_vertices()
    scene_mesh.fix_normals()

    orange_rgba = np.array([255, 165, 0, 255], dtype=np.uint8)
    scene_mesh.visual.face_colors = np.tile(orange_rgba, (len(scene_mesh.faces), 1))
    scene_mesh.export(save_path)


def main():
    """Generate a mesh for every input and write ``<uid>_gen.obj`` files."""
    args = get_args()
    # Validate inputs before creating output dirs or downloading weights.
    input_list = _collect_input_list(args)

    cur_time = datetime.datetime.now().strftime("%d_%H-%M-%S")
    checkpoint_dir = os.path.join(args.out_dir, cur_time)
    os.makedirs(checkpoint_dir, exist_ok=True)

    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision="fp16",
        project_dir=checkpoint_dir,
        kwargs_handlers=[kwargs],
    )

    model = load_v2()
    # Inference only: disable train-mode layers (e.g. dropout) if any exist.
    model.eval()

    # Seed after model construction so the RNG state used for point sampling
    # and generation matches the original ordering.
    set_seed(args.seed)
    dataset = Dataset(args.input_type, input_list, args.mc)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batchsize_per_gpu,
        drop_last=False,
        shuffle=False,
    )

    if accelerator.state.num_processes > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    dataloader, model = accelerator.prepare(dataloader, model)

    begin_time = time.time()
    print("Generation Start!!!")
    with torch.no_grad(), accelerator.autocast():
        for batch_data_label in dataloader:
            outputs = model(batch_data_label['pc_normal'], sampling=args.sampling)
            for batch_id in range(outputs.shape[0]):
                save_path = os.path.join(checkpoint_dir, f'{batch_data_label["uid"][batch_id]}_gen.obj')
                _save_generated_mesh(outputs[batch_id], save_path)
                print(f"{save_path} Over!!")

    end_time = time.time()
    print(f"Total time: {end_time - begin_time}")


if __name__ == "__main__":
    main()