ThomasSimonini HF staff commited on
Commit
3044d56
1 Parent(s): 1e78183

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -2
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import spaces
2
 
3
  import os
 
 
4
  import imageio
5
  import numpy as np
6
  import torch
@@ -31,6 +33,13 @@ from huggingface_hub import hf_hub_download
31
 
32
  import gradio as gr
33
 
 
 
 
 
 
 
 
34
 
35
  ###############################################################################
36
  # Configuration for InstantMesh
@@ -212,6 +221,175 @@ def make3d(images):
212
  return mesh_fpath, mesh_glb_fpath
213
 
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  ###############################################################################
216
  # Gradio
217
  ###############################################################################
@@ -318,7 +496,8 @@ with gr.Blocks() as demo:
318
  gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
319
  with gr.Row():
320
  gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
321
-
 
322
  mv_images = gr.State()
323
 
324
  step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
@@ -334,7 +513,58 @@ with gr.Blocks() as demo:
334
  inputs=[mv_images],
335
  outputs=[output_model_obj, output_model_glb]
336
  )
337
- gr.Markdown(STEP2_HEADER)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  gr.Markdown(STEP3_HEADER)
339
  gr.Markdown(STEP4_HEADER)
340
 
 
1
  import spaces
2
 
3
  import os
4
+ import time
5
+
6
  import imageio
7
  import numpy as np
8
  import torch
 
33
 
34
  import gradio as gr
35
 
36
+ # Imports for MeshAnythingv2
37
+ from accelerate.utils import set_seed
38
+ from accelerate import Accelerator
39
+ from main import load_v2
40
+ from mesh_to_pc import process_mesh_to_pc
41
+ import matplotlib.pyplot as plt
42
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
43
 
44
  ###############################################################################
45
  # Configuration for InstantMesh
 
221
  return mesh_fpath, mesh_glb_fpath
222
 
223
 
224
+ ###############################################################################
225
+ # Configuration for MeshAnythingv2
226
+ ###############################################################################
227
+ model = load_v2()
228
+ device = torch.device('cuda')
229
+ accelerator = Accelerator(
230
+ mixed_precision="fp16",
231
+ )
232
+ model = accelerator.prepare(model)
233
+ model.eval()
234
+ print("Model loaded to device")
235
+
236
+ def wireframe_render(mesh):
237
+ views = [
238
+ (90, 20), (270, 20)
239
+ ]
240
+ mesh.vertices = mesh.vertices[:, [0, 2, 1]]
241
+
242
+ bounding_box = mesh.bounds
243
+ center = mesh.centroid
244
+ scale = np.ptp(bounding_box, axis=0).max()
245
+
246
+ fig = plt.figure(figsize=(10, 10))
247
+
248
+ # Function to render and return each view as an image
249
+ def render_view(mesh, azimuth, elevation):
250
+ ax = fig.add_subplot(111, projection='3d')
251
+ ax.set_axis_off()
252
+
253
+ # Extract vertices and faces for plotting
254
+ vertices = mesh.vertices
255
+ faces = mesh.faces
256
+
257
+ # Plot faces
258
+ ax.add_collection3d(Poly3DCollection(
259
+ vertices[faces],
260
+ facecolors=(0.8, 0.5, 0.2, 1.0), # Brownish yellow
261
+ edgecolors='k',
262
+ linewidths=0.5,
263
+ ))
264
+
265
+ # Set limits and center the view on the object
266
+ ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
267
+ ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2)
268
+ ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2)
269
+
270
+ # Set view angle
271
+ ax.view_init(elev=elevation, azim=azimuth)
272
+
273
+ # Save the figure to a buffer
274
+ buf = io.BytesIO()
275
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
276
+ plt.clf()
277
+ buf.seek(0)
278
+
279
+ return Image.open(buf)
280
+
281
+ # Render each view and store in a list
282
+ images = [render_view(mesh, az, el) for az, el in views]
283
+
284
+ # Combine images horizontally
285
+ widths, heights = zip(*(i.size for i in images))
286
+ total_width = sum(widths)
287
+ max_height = max(heights)
288
+
289
+ combined_image = Image.new('RGBA', (total_width, max_height))
290
+
291
+ x_offset = 0
292
+ for img in images:
293
+ combined_image.paste(img, (x_offset, 0))
294
+ x_offset += img.width
295
+
296
+ # Save the combined image
297
+ save_path = f"combined_mesh_view_{int(time.time())}.png"
298
+ combined_image.save(save_path)
299
+
300
+ plt.close(fig)
301
+ return save_path
302
+
303
+ @spaces.GPU(duration=360)
304
+ def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False):
305
+ set_seed(sample_seed)
306
+ print("Seed value:", sample_seed)
307
+
308
+ input_mesh = trimesh.load(input_3d)
309
+ pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes)
310
+ pc_normal = pc_list[0] # 4096, 6
311
+ mesh = mesh_list[0]
312
+ vertices = mesh.vertices
313
+
314
+ pc_coor = pc_normal[:, :3]
315
+ normals = pc_normal[:, 3:]
316
+
317
+ bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
318
+ # scale mesh and pc
319
+ vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
320
+ vertices = vertices / (bounds[1] - bounds[0]).max()
321
+ mesh.vertices = vertices
322
+ pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
323
+ pc_coor = pc_coor / (bounds[1] - bounds[0]).max()
324
+
325
+ mesh.merge_vertices()
326
+ mesh.update_faces(mesh.nondegenerate_faces())
327
+ mesh.update_faces(mesh.unique_faces())
328
+ mesh.remove_unreferenced_vertices()
329
+ mesh.fix_normals()
330
+ try:
331
+ if mesh.visual.vertex_colors is not None:
332
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
333
+
334
+ mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
335
+ else:
336
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
337
+ mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
338
+ except Exception as e:
339
+ print(e)
340
+ input_save_name = f"processed_input_{int(time.time())}.obj"
341
+ mesh.export(input_save_name)
342
+ input_render_res = wireframe_render(mesh)
343
+
344
+ pc_coor = pc_coor / np.abs(pc_coor).max() * 0.99 # input should be from -1 to 1
345
+
346
+ assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong"
347
+ normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
348
+
349
+ input = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
350
+ print("Data loaded")
351
+
352
+ # with accelerator.autocast():
353
+ with accelerator.autocast():
354
+ outputs = model(input, do_sampling)
355
+ print("Model inference done")
356
+ recon_mesh = outputs[0]
357
+
358
+ valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
359
+ recon_mesh = recon_mesh[valid_mask] # nvalid_face x 3 x 3
360
+ vertices = recon_mesh.reshape(-1, 3).cpu()
361
+ vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face
362
+ triangles = vertices_index.reshape(-1, 3)
363
+
364
+ artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh",
365
+ merge_primitives=True)
366
+
367
+ artist_mesh.merge_vertices()
368
+ artist_mesh.update_faces(artist_mesh.nondegenerate_faces())
369
+ artist_mesh.update_faces(artist_mesh.unique_faces())
370
+ artist_mesh.remove_unreferenced_vertices()
371
+ artist_mesh.fix_normals()
372
+
373
+ if artist_mesh.visual.vertex_colors is not None:
374
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
375
+
376
+ artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))
377
+ else:
378
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
379
+ artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))
380
+
381
+ num_faces = len(artist_mesh.faces)
382
+
383
+ brown_color = np.array([165, 42, 42, 255], dtype=np.uint8)
384
+ face_colors = np.tile(brown_color, (num_faces, 1))
385
+
386
+ artist_mesh.visual.face_colors = face_colors
387
+ # add time stamp to avoid cache
388
+ save_name = f"output_{int(time.time())}.obj"
389
+ artist_mesh.export(save_name)
390
+ output_render = wireframe_render(artist_mesh)
391
+ return input_save_name, input_render_res, save_name, output_render
392
+
393
  ###############################################################################
394
  # Gradio
395
  ###############################################################################
 
496
  gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
497
  with gr.Row():
498
  gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
499
+
500
+
501
  mv_images = gr.State()
502
 
503
  step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
 
513
  inputs=[mv_images],
514
  outputs=[output_model_obj, output_model_glb]
515
  )
516
+
517
+ gr.Markdown(STEP2_HEADER)
518
+ with gr.Row(variant="panel"):
519
+ with gr.Column():
520
+ with gr.Row():
521
+ input_3d = gr.Model3D(
522
+ label="Input Mesh",
523
+ display_mode="wireframe",
524
+ clear_color=[1,1,1,1],
525
+ )
526
+
527
+ with gr.Row():
528
+ with gr.Group():
529
+ do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False)
530
+ do_sampling = gr.Checkbox(label="Random Sampling", value=False)
531
+ sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
532
+
533
+ with gr.Row():
534
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
535
+
536
+ with gr.Row(variant="panel"):
537
+ mesh_examples = gr.Examples(
538
+ examples=[
539
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
540
+ ],
541
+ inputs=input_3d,
542
+ outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
543
+ fn=do_inference,
544
+ cache_examples = False,
545
+ examples_per_page=10
546
+ )
547
+
548
+ with gr.Column():
549
+ with gr.Row():
550
+ input_image_render.render()
551
+ with gr.Row():
552
+ with gr.Tab("OBJ"):
553
+ preprocess_model_obj.render()
554
+ with gr.Row():
555
+ output_image_render.render()
556
+ with gr.Row():
557
+ with gr.Tab("OBJ"):
558
+ output_model_obj.render()
559
+ with gr.Row():
560
+ gr.Markdown('''Try click random sampling and different <b>Seed Value</b> if the result is unsatisfying''')
561
+
562
+ submit.click(
563
+ fn=do_inference,
564
+ inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes],
565
+ outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
566
+ )
567
+
568
  gr.Markdown(STEP3_HEADER)
569
  gr.Markdown(STEP4_HEADER)
570