ThomasSimonini HF staff commited on
Commit
26e79c0
Β·
verified Β·
1 Parent(s): 4cea4ee

Change TripoSR to InstantMesh

Browse files

* Replace generative model
* added instant-mesh utils

Files changed (2) hide show
  1. app.py +141 -28
  2. instant-mesh/utils.py +178 -0
app.py CHANGED
@@ -1,6 +1,7 @@
 
1
  import gradio as gr
 
2
  import numpy as np
3
- import spaces
4
  import torch
5
  import rembg
6
  from PIL import Image
@@ -12,11 +13,51 @@ import shlex
12
  import subprocess
13
  import tempfile
14
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
 
 
 
 
16
  subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
17
 
18
  from tsr.system import TSR
19
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
 
20
 
21
  HEADER = """
22
  # Generate 3D Assets for Roblox
@@ -37,7 +78,7 @@ We wrote a tutorial here
37
  STEP1_HEADER = """
38
  ## Step 1: Generate the 3D Mesh
39
 
40
- For this step, we use TripoSR, an open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
41
 
42
  During this step, you need to upload an image of what you want to generate a 3D Model from.
43
 
@@ -46,10 +87,7 @@ During this step, you need to upload an image of what you want to generate a 3D
46
 
47
  - If there's a background, βœ… Remove background.
48
 
49
- - If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
50
-
51
- - To know more about what is the Marching Cubes Resolution check this : https://huggingface.co/learn/ml-for-3d-course/en/unit4/marching-cubes#marching-cubes
52
-
53
  """
54
 
55
  STEP2_HEADER = """
@@ -86,7 +124,8 @@ STEP4_HEADER = """
86
 
87
  """
88
 
89
-
 
90
  # These part of the code (check_input_image and preprocess were taken from https://huggingface.co/spaces/stabilityai/TripoSR/blob/main/app.py)
91
  if torch.cuda.is_available():
92
  device = "cuda:0"
@@ -142,6 +181,61 @@ def generate(image, mc_resolution, formats=["obj", "glb"]):
142
  mesh.export(mesh_path_obj.name)
143
 
144
  return mesh_path_obj.name, mesh_path_glb.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
 
147
  with gr.Blocks() as demo:
@@ -155,30 +249,41 @@ with gr.Blocks() as demo:
155
  image_mode = "RGBA",
156
  sources = "upload",
157
  type="pil",
158
- elem_id="content_image")
159
- processed_image = gr.Image(label="Processed Image", interactive=False)
 
 
 
 
 
160
  with gr.Row():
161
  with gr.Group():
162
  do_remove_background = gr.Checkbox(
163
  label="Remove Background",
164
  value=True)
165
- foreground_ratio = gr.Slider(
166
- label="Foreground Ratio",
167
- minimum=0.5,
168
- maximum=1.0,
169
- value=0.85,
170
- step=0.05,
171
- )
172
- mc_resolution = gr.Slider(
173
- label="Marching Cubes Resolution",
174
- minimum=32,
175
- maximum=320,
176
- value=256,
177
- step=32
178
- )
179
  with gr.Row():
180
  step1_submit = gr.Button("Generate", elem_id="generate", variant="primary")
181
-
 
 
 
 
 
 
 
 
182
  with gr.Column():
183
  with gr.Tab("OBJ"):
184
  output_model_obj = gr.Model3D(
@@ -192,15 +297,23 @@ with gr.Blocks() as demo:
192
  interactive=False,
193
  )
194
  gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
 
 
 
 
195
 
196
  step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
197
  fn=preprocess,
198
- inputs=[input_image, do_remove_background, foreground_ratio],
199
  outputs=[processed_image],
200
  ).success(
201
- fn=generate,
202
- inputs=[processed_image, mc_resolution],
203
- outputs=[output_model_obj, output_model_glb],
 
 
 
 
204
  )
205
  gr.Markdown(STEP2_HEADER)
206
  gr.Markdown(STEP3_HEADER)
 
1
+ import spaces
2
  import gradio as gr
3
+
4
  import numpy as np
 
5
  import torch
6
  import rembg
7
  from PIL import Image
 
13
  import subprocess
14
  import tempfile
15
  import time
16
+ from PIL import Image
17
+ from torchvision.transforms import v2
18
+ from pytorch_lightning import seed_everything
19
+ from omegaconf import OmegaConf
20
+ from einops import rearrange, repeat
21
+ from tqdm import tqdm
22
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
23
+
24
+ import os
25
+ import imageio
26
+ import numpy as np
27
+ import torch
28
+ import rembg
29
+ from PIL import Image
30
+ from torchvision.transforms import v2
31
+ from pytorch_lightning import seed_everything
32
+ from omegaconf import OmegaConf
33
+ from einops import rearrange, repeat
34
+ from tqdm import tqdm
35
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
36
+
37
+ from src.utils.train_util import instantiate_from_config
38
+ from src.utils.camera_util import (
39
+ FOV_to_intrinsics,
40
+ get_zero123plus_input_cameras,
41
+ get_circular_camera_poses,
42
+ )
43
+ from src.utils.mesh_util import save_obj, save_glb
44
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
45
+
46
+ import tempfile
47
+ from functools import partial
48
+
49
+ from huggingface_hub import hf_hub_download
50
 
51
+ from instant-mesh import get_render_cameras, find_cuda, check_input_image, generate_mvs, make3d
52
+
53
+
54
+ # This was the code needed for TripoSR
55
+ """
56
  subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
57
 
58
  from tsr.system import TSR
59
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
60
+ """
61
 
62
  HEADER = """
63
  # Generate 3D Assets for Roblox
 
78
  STEP1_HEADER = """
79
  ## Step 1: Generate the 3D Mesh
80
 
81
+ For this step, we use <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>InstantMesh</a>, an open-source model for **fast** feedforward 3D mesh generation from a single image.
82
 
83
  During this step, you need to upload an image of what you want to generate a 3D Model from.
84
 
 
87
 
88
  - If there's a background, βœ… Remove background.
89
 
90
+ - The 3D mesh generation results highly depend on the quality of generated multi-view images. Please try a different **seed value** if the result is unsatisfying (Default: 42).
 
 
 
91
  """
92
 
93
  STEP2_HEADER = """
 
124
 
125
  """
126
 
127
+ # Code for TripoSR
128
+ """
129
  # These part of the code (check_input_image and preprocess were taken from https://huggingface.co/spaces/stabilityai/TripoSR/blob/main/app.py)
130
  if torch.cuda.is_available():
131
  device = "cuda:0"
 
181
  mesh.export(mesh_path_obj.name)
182
 
183
  return mesh_path_obj.name, mesh_path_glb.name
184
+ """
185
+
186
+
187
+
188
+ ###############################################################################
189
+ # Configuration for InstantMesh
190
+ # All this code is from https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/app.py
191
+ ###############################################################################
192
+ cuda_path = find_cuda()
193
+
194
+ if cuda_path:
195
+ print(f"CUDA installation found at: {cuda_path}")
196
+ else:
197
+ print("CUDA installation not found")
198
+
199
+ config_path = 'configs/instant-mesh-large.yaml'
200
+ config = OmegaConf.load(config_path)
201
+ config_name = os.path.basename(config_path).replace('.yaml', '')
202
+ model_config = config.model_config
203
+ infer_config = config.infer_config
204
+
205
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
206
+
207
+ device = torch.device('cuda')
208
+
209
+ # load diffusion model
210
+ print('Loading diffusion model ...')
211
+ pipeline = DiffusionPipeline.from_pretrained(
212
+ "sudo-ai/zero123plus-v1.2",
213
+ custom_pipeline="zero123plus",
214
+ torch_dtype=torch.float16,
215
+ )
216
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
217
+ pipeline.scheduler.config, timestep_spacing='trailing'
218
+ )
219
+
220
+ # load custom white-background UNet
221
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
222
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
223
+ pipeline.unet.load_state_dict(state_dict, strict=True)
224
+
225
+ pipeline = pipeline.to(device)
226
+
227
+ # load reconstruction model
228
+ print('Loading reconstruction model ...')
229
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
230
+ model = instantiate_from_config(model_config)
231
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
232
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
233
+ model.load_state_dict(state_dict, strict=True)
234
+
235
+ model = model.to(device)
236
+
237
+ print('Loading Finished!')
238
+
239
 
240
 
241
  with gr.Blocks() as demo:
 
249
  image_mode = "RGBA",
250
  sources = "upload",
251
  type="pil",
252
+ elem_id="content_image"
253
+ )
254
+ processed_image = gr.Image(label="Processed Image",
255
+ image_mode="RGBA",
256
+ type="pil",
257
+ interactive=False
258
+ )
259
  with gr.Row():
260
  with gr.Group():
261
  do_remove_background = gr.Checkbox(
262
  label="Remove Background",
263
  value=True)
264
+ sample_seed = gr.Number(
265
+ value=42,
266
+ label="Seed Value",
267
+ precision=0
268
+ )
269
+ sample_steps = gr.Slider(
270
+ label="Sample Steps",
271
+ minimum=30,
272
+ maximum=75,
273
+ value=75,
274
+ step=5
275
+ )
 
 
276
  with gr.Row():
277
  step1_submit = gr.Button("Generate", elem_id="generate", variant="primary")
278
+ with gr.Column():
279
+ with gr.Row():
280
+ with gr.Column():
281
+ mv_show_images = gr.Image(
282
+ label="Generated Multi-views",
283
+ type="pil",
284
+ width=379,
285
+ interactive=False
286
+ )
287
  with gr.Column():
288
  with gr.Tab("OBJ"):
289
  output_model_obj = gr.Model3D(
 
297
  interactive=False,
298
  )
299
  gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
300
+ with gr.Row():
301
+ gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
302
+
303
+ mv_images = gr.State()
304
 
305
  step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
306
  fn=preprocess,
307
+ inputs=[input_image, do_remove_background],
308
  outputs=[processed_image],
309
  ).success(
310
+ fn=generate_mvs,
311
+ inputs=[processed_image, sample_steps, sample_seed],
312
+ outputs=[mv_images, mv_show_images],
313
+ ).success(
314
+ fn=make3d,
315
+ inputs=[mv_images],
316
+ outputs=[output_model_obj, output_model_glb]
317
  )
318
  gr.Markdown(STEP2_HEADER)
319
  gr.Markdown(STEP3_HEADER)
instant-mesh/utils.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ import torch
5
+ import rembg
6
+ from PIL import Image
7
+ from torchvision.transforms import v2
8
+ from pytorch_lightning import seed_everything
9
+ from omegaconf import OmegaConf
10
+ from einops import rearrange, repeat
11
+ from tqdm import tqdm
12
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
13
+
14
+ from src.utils.train_util import instantiate_from_config
15
+ from src.utils.camera_util import (
16
+ FOV_to_intrinsics,
17
+ get_zero123plus_input_cameras,
18
+ get_circular_camera_poses,
19
+ )
20
+ from src.utils.mesh_util import save_obj, save_glb
21
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
22
+
23
+ import tempfile
24
+ from functools import partial
25
+
26
+ from huggingface_hub import hf_hub_download
27
+
28
+ import gradio as gr
29
+ import shutil
30
+ import spaces
31
+
32
+
33
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
34
+ """
35
+ Get the rendering camera parameters.
36
+ """
37
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
38
+ if is_flexicubes:
39
+ cameras = torch.linalg.inv(c2ws)
40
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
41
+ else:
42
+ extrinsics = c2ws.flatten(-2)
43
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
44
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
45
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
46
+ return cameras
47
+
48
+
49
+ import shutil
50
+
51
+ def find_cuda():
52
+ # Check if CUDA_HOME or CUDA_PATH environment variables are set
53
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
54
+
55
+ if cuda_home and os.path.exists(cuda_home):
56
+ return cuda_home
57
+
58
+ # Search for the nvcc executable in the system's PATH
59
+ nvcc_path = shutil.which('nvcc')
60
+
61
+ if nvcc_path:
62
+ # Remove the 'bin/nvcc' part to get the CUDA installation path
63
+ cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
64
+ return cuda_path
65
+
66
+ return None
67
+
68
+ def check_input_image(input_image):
69
+ if input_image is None:
70
+ raise gr.Error("No image uploaded!")
71
+
72
+
73
+ def preprocess(input_image, do_remove_background):
74
+
75
+ rembg_session = rembg.new_session() if do_remove_background else None
76
+
77
+ if do_remove_background:
78
+ input_image = remove_background(input_image, rembg_session)
79
+ input_image = resize_foreground(input_image, 0.85)
80
+
81
+ return input_image
82
+
83
+
84
+ @spaces.GPU
85
+ def generate_mvs(input_image, sample_steps, sample_seed):
86
+
87
+ seed_everything(sample_seed)
88
+
89
+ # sampling
90
+ z123_image = pipeline(
91
+ input_image,
92
+ num_inference_steps=sample_steps
93
+ ).images[0]
94
+
95
+ show_image = np.asarray(z123_image, dtype=np.uint8)
96
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
97
+ show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
98
+ show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
99
+ show_image = Image.fromarray(show_image.numpy())
100
+
101
+ return z123_image, show_image
102
+
103
+
104
+ @spaces.GPU
105
+ def make3d(images):
106
+
107
+ global model
108
+ if IS_FLEXICUBES:
109
+ model.init_flexicubes_geometry(device, use_renderer=False)
110
+ model = model.eval()
111
+
112
+ images = np.asarray(images, dtype=np.float32) / 255.0
113
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
114
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
115
+
116
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
117
+ render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
118
+
119
+ images = images.unsqueeze(0).to(device)
120
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
121
+
122
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
123
+ print(mesh_fpath)
124
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
125
+ mesh_dirname = os.path.dirname(mesh_fpath)
126
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
127
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
128
+
129
+ with torch.no_grad():
130
+ # get triplane
131
+ planes = model.forward_planes(images, input_cameras)
132
+
133
+ # # get video
134
+ # chunk_size = 20 if IS_FLEXICUBES else 1
135
+ # render_size = 384
136
+
137
+ # frames = []
138
+ # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
139
+ # if IS_FLEXICUBES:
140
+ # frame = model.forward_geometry(
141
+ # planes,
142
+ # render_cameras[:, i:i+chunk_size],
143
+ # render_size=render_size,
144
+ # )['img']
145
+ # else:
146
+ # frame = model.synthesizer(
147
+ # planes,
148
+ # cameras=render_cameras[:, i:i+chunk_size],
149
+ # render_size=render_size,
150
+ # )['images_rgb']
151
+ # frames.append(frame)
152
+ # frames = torch.cat(frames, dim=1)
153
+
154
+ # images_to_video(
155
+ # frames[0],
156
+ # video_fpath,
157
+ # fps=30,
158
+ # )
159
+
160
+ # print(f"Video saved to {video_fpath}")
161
+
162
+ # get mesh
163
+ mesh_out = model.extract_mesh(
164
+ planes,
165
+ use_texture_map=False,
166
+ **infer_config,
167
+ )
168
+
169
+ vertices, faces, vertex_colors = mesh_out
170
+ vertices = vertices[:, [1, 2, 0]]
171
+
172
+ save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
173
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
174
+
175
+ print(f"Mesh saved to {mesh_fpath}")
176
+
177
+ return mesh_fpath, mesh_glb_fpath
178
+