ThomasSimonini committed (verified)
Commit ef24b41 · Parent(s): 63c5f26

Update app.py

Files changed (1):
  app.py  +201 -201

app.py CHANGED
@@ -22,7 +22,7 @@ from src.utils.camera_util import (
     FOV_to_intrinsics,
     get_zero123plus_input_cameras,
     get_circular_camera_poses,
-)
+)
 from src.utils.mesh_util import save_obj, save_glb
 from src.utils.infer_util import remove_background, resize_foreground, images_to_video

@@ -57,23 +57,23 @@ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexi
     intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
     cameras = torch.cat([extrinsics, intrinsics], dim=-1)
     cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
-    return cameras
+    return cameras


-def images_to_video(images, output_path, fps=30):
+def images_to_video(images, output_path, fps=30):
     # images: (N, C, H, W)
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
     frames = []
     for i in range(images.shape[0]):
         frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
         assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
-            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
         assert frame.min() >= 0 and frame.max() <= 255, \
-            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
         frames.append(frame)
-    imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')

-def find_cuda():
+def find_cuda():
     # Check if CUDA_HOME or CUDA_PATH environment variables are set
     cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')

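The `images_to_video` helper above expects a float tensor of frames in `(N, C, H, W)` layout with values in [0, 1]. A minimal smoke-test sketch, assuming `torch` and imageio's ffmpeg backend are installed (the frame count and output path are arbitrary):

```python
import torch

# 24 random RGB frames, 320x320, values in [0, 1], layout (N, C, H, W).
frames = torch.rand(24, 3, 320, 320)

# Writes a one-second H.264 clip. Note the helper calls
# os.path.dirname(output_path), so the path must include a directory.
images_to_video(frames, "tmp/preview.mp4", fps=24)
```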
@@ -88,24 +88,24 @@ def find_cuda():
         cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
         return cuda_path

-    return None
+    return None

-cuda_path = find_cuda()
+cuda_path = find_cuda()

-if cuda_path:
-    print(f"CUDA installation found at: {cuda_path}")
-else:
-    print("CUDA installation not found")
+if cuda_path:
+    print(f"CUDA installation found at: {cuda_path}")
+else:
+    print("CUDA installation not found")

-config_path = 'configs/instant-mesh-large.yaml'
-config = OmegaConf.load(config_path)
-config_name = os.path.basename(config_path).replace('.yaml', '')
-model_config = config.model_config
-infer_config = config.infer_config
+config_path = 'configs/instant-mesh-large.yaml'
+config = OmegaConf.load(config_path)
+config_name = os.path.basename(config_path).replace('.yaml', '')
+model_config = config.model_config
+infer_config = config.infer_config

-IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
+IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False

-device = torch.device('cuda')
+device = torch.device('cuda')

 # load diffusion model
 print('Loading diffusion model ...')
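For context, `OmegaConf.load` turns the YAML file into a dot-accessible config, which is how `config.model_config` and `config.infer_config` resolve above. A minimal sketch of the pattern; the keys here are illustrative stand-ins, not the actual contents of `configs/instant-mesh-large.yaml`:

```python
from omegaconf import OmegaConf

# Hypothetical stand-in for configs/instant-mesh-large.yaml.
yaml_text = """
model_config:
  target: src.models.lrm_mesh.InstantMesh
infer_config:
  texture_resolution: 1024
"""
config = OmegaConf.create(yaml_text)

# Dot access mirrors the code above.
print(config.model_config.target)              # src.models.lrm_mesh.InstantMesh
print(config.infer_config.texture_resolution)  # 1024
```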
@@ -113,10 +113,10 @@ pipeline = DiffusionPipeline.from_pretrained(
     "sudo-ai/zero123plus-v1.2",
     custom_pipeline="zero123plus",
     torch_dtype=torch.float16,
-)
+)
 pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
     pipeline.scheduler.config, timestep_spacing='trailing'
-)
+)

 # load custom white-background UNet
 unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
@@ -143,27 +143,27 @@ def check_input_image(input_image):
         raise gr.Error("No image uploaded!")


-def preprocess(input_image, do_remove_background):
+def preprocess(input_image, do_remove_background):

-    rembg_session = rembg.new_session() if do_remove_background else None
+    rembg_session = rembg.new_session() if do_remove_background else None

-    if do_remove_background:
-        input_image = remove_background(input_image, rembg_session)
-        input_image = resize_foreground(input_image, 0.85)
+    if do_remove_background:
+        input_image = remove_background(input_image, rembg_session)
+        input_image = resize_foreground(input_image, 0.85)

-    return input_image
+    return input_image


-@spaces.GPU
-def generate_mvs(input_image, sample_steps, sample_seed):
+@spaces.GPU
+def generate_mvs(input_image, sample_steps, sample_seed):

-    seed_everything(sample_seed)
-
+    seed_everything(sample_seed)
+
     # sampling
     z123_image = pipeline(
         input_image,
         num_inference_steps=sample_steps
-    ).images[0]
+    ).images[0]

     show_image = np.asarray(z123_image, dtype=np.uint8)
     show_image = torch.from_numpy(show_image)  # (960, 640, 3)
@@ -174,15 +174,15 @@ def generate_mvs(input_image, sample_steps, sample_seed):
     return z123_image, show_image


-@spaces.GPU
-def make3d(images):
+@spaces.GPU
+def make3d(images):

-    global model
-    if IS_FLEXICUBES:
-        model.init_flexicubes_geometry(device, use_renderer=False)
-    model = model.eval()
+    global model
+    if IS_FLEXICUBES:
+        model.init_flexicubes_geometry(device, use_renderer=False)
+    model = model.eval()

-    images = np.asarray(images, dtype=np.float32) / 255.0
+    images = np.asarray(images, dtype=np.float32) / 255.0
     images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
     images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)  # (6, 3, 320, 320)

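The `rearrange` call above slices the 960x640 multi-view sheet returned by zero123plus into six 320x320 views (3 rows x 2 columns). A quick way to verify the shape math:

```python
import torch
from einops import rearrange

grid = torch.rand(3, 960, 640)  # (C, H, W): 3 rows x 2 columns of 320x320 views
tiles = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
print(tiles.shape)  # torch.Size([6, 3, 320, 320])

# Views come out in row-major order over (n, m): tile 0 is the top-left crop.
assert torch.equal(tiles[0], grid[:, :320, :320])
```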
@@ -208,7 +208,7 @@ def make3d(images):
         planes,
         use_texture_map=False,
         **infer_config,
-    )
+    )

     vertices, faces, vertex_colors = mesh_out
     vertices = vertices[:, [1, 2, 0]]
@@ -218,7 +218,7 @@ def make3d(images):

     print(f"Mesh saved to {mesh_fpath}")

-    return mesh_fpath, mesh_glb_fpath
+    return mesh_fpath, mesh_glb_fpath


 ###############################################################################
@@ -228,14 +228,14 @@ model = load_v2()
 device = torch.device('cuda')
 accelerator = Accelerator(
     mixed_precision="fp16",
-)
+)
 model = accelerator.prepare(model)
 model.eval()
 print("Model loaded to device")

 def wireframe_render(mesh):
     views = [
-        (90, 20), (270, 20)
+        (90, 20), (270, 20)
     ]
     mesh.vertices = mesh.vertices[:, [0, 2, 1]]

@@ -260,7 +260,7 @@ def wireframe_render(mesh):
             facecolors=(0.8, 0.5, 0.2, 1.0),  # Brownish yellow
             edgecolors='k',
             linewidths=0.5,
-        ))
+        ))

         # Set limits and center the view on the object
         ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
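`wireframe_render` draws the mesh faces with matplotlib's `Poly3DCollection`, as in the hunk above. A self-contained sketch of the same technique on a toy tetrahedron, reusing the diff's face styling (a headless backend is assumed, as on a Space):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless rendering
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)
faces = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.add_collection3d(Poly3DCollection(
    vertices[faces],  # (F, 3, 3): one 3-vertex polygon per face
    facecolors=(0.8, 0.5, 0.2, 1.0),
    edgecolors='k',
    linewidths=0.5,
))
ax.set_xlim(0, 1); ax.set_ylim(0, 1); ax.set_zlim(0, 1)
fig.savefig("wireframe_demo.png")
plt.close(fig)
```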
@@ -300,13 +300,13 @@ def wireframe_render(mesh):
     plt.close(fig)
     return save_path

-@spaces.GPU(duration=360)
-def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False):
-    set_seed(sample_seed)
-    print("Seed value:", sample_seed)
+@spaces.GPU(duration=360)
+def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False):
+    set_seed(sample_seed)
+    print("Seed value:", sample_seed)

-    input_mesh = trimesh.load(input_3d)
-    pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes)
+    input_mesh = trimesh.load(input_3d)
+    pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes)
     pc_normal = pc_list[0]  # 4096, 6
     mesh = mesh_list[0]
     vertices = mesh.vertices
@@ -330,16 +330,16 @@ def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=F
     try:
         if mesh.visual.vertex_colors is not None:
             orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
-
+
             mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
         else:
             orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
             mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
-    except Exception as e:
-        print(e)
-    input_save_name = f"processed_input_{int(time.time())}.obj"
-    mesh.export(input_save_name)
-    input_render_res = wireframe_render(mesh)
+    except Exception as e:
+        print(e)
+    input_save_name = f"processed_input_{int(time.time())}.obj"
+    mesh.export(input_save_name)
+    input_render_res = wireframe_render(mesh)

     pc_coor = pc_coor / np.abs(pc_coor).max() * 0.99  # input should be from -1 to 1

@@ -352,17 +352,17 @@ def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=F
     # with accelerator.autocast():
     with accelerator.autocast():
         outputs = model(input, do_sampling)
-    print("Model inference done")
-    recon_mesh = outputs[0]
+    print("Model inference done")
+    recon_mesh = outputs[0]

-    valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
+    valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
     recon_mesh = recon_mesh[valid_mask]  # nvalid_face x 3 x 3
     vertices = recon_mesh.reshape(-1, 3).cpu()
     vertices_index = np.arange(len(vertices))  # 0, 1, ..., 3 x face
     triangles = vertices_index.reshape(-1, 3)

     artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh",
-                                  merge_primitives=True)
+                                  merge_primitives=True)

     artist_mesh.merge_vertices()
     artist_mesh.update_faces(artist_mesh.nondegenerate_faces())
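The masking above drops any generated face whose flattened 3x3 vertex block contains a NaN, then rebuilds a triangle soup for trimesh. A small sketch on toy data (shapes follow the diff's comments):

```python
import numpy as np
import torch
import trimesh

# Two valid faces plus one NaN-poisoned face, shaped (F, 3, 3).
recon_mesh = torch.tensor([
    [[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]],
    [[0., 0., 1.], [1., 0., 1.], [0., 1., 1.]],
    [[float('nan')] * 3] * 3,
])

valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
recon_mesh = recon_mesh[valid_mask]  # (2, 3, 3): the NaN face is dropped

# Triangle soup: every face gets its own three vertices.
vertices = recon_mesh.reshape(-1, 3).cpu()
triangles = np.arange(len(vertices)).reshape(-1, 3)
mesh = trimesh.Trimesh(vertices=vertices, faces=triangles)
mesh.merge_vertices()  # welds duplicated vertices, as in the diff
print(len(mesh.faces))  # 2
```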
@@ -378,12 +378,12 @@ def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=F
         orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
         artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))

-    num_faces = len(artist_mesh.faces)
+    num_faces = len(artist_mesh.faces)

-    brown_color = np.array([165, 42, 42, 255], dtype=np.uint8)
-    face_colors = np.tile(brown_color, (num_faces, 1))
+    brown_color = np.array([165, 42, 42, 255], dtype=np.uint8)
+    face_colors = np.tile(brown_color, (num_faces, 1))

-    artist_mesh.visual.face_colors = face_colors
+    artist_mesh.visual.face_colors = face_colors
     # add time stamp to avoid cache
     save_name = f"output_{int(time.time())}.obj"
     artist_mesh.export(save_name)
@@ -395,18 +395,18 @@ output_model_obj = gr.Model3D(
     label="Generated Mesh (OBJ Format)",
     display_mode="wireframe",
     clear_color=[1, 1, 1, 1],
-)
+)
 preprocess_model_obj = gr.Model3D(
     label="Processed Input Mesh (OBJ Format)",
     display_mode="wireframe",
     clear_color=[1, 1, 1, 1],
-)
+)
 input_image_render = gr.Image(
     label="Wireframe Render of Processed Input Mesh",
-)
+)
 output_image_render = gr.Image(
     label="Wireframe Render of Generated Mesh",
-)
+)

 ###############################################################################
 # Gradio
@@ -454,137 +454,137 @@ STEP4_HEADER = """
 """

 with gr.Blocks() as demo:
-    gr.Markdown(HEADER)
-    gr.Markdown(STEP1_HEADER)
-    with gr.Row(variant = "panel"):
-        with gr.Column():
-            with gr.Row():
-                input_image = gr.Image(
-                    label = "Input Image",
-                    image_mode = "RGBA",
-                    sources = "upload",
-                    type="pil",
-                    elem_id="content_image"
-                )
-                processed_image = gr.Image(label="Processed Image",
-                    image_mode="RGBA",
-                    type="pil",
-                    interactive=False
-                )
-            with gr.Row():
-                with gr.Group():
-                    do_remove_background = gr.Checkbox(
-                        label="Remove Background",
-                        value=True)
-                    sample_seed = gr.Number(
-                        value=42,
-                        label="Seed Value",
-                        precision=0
-                    )
-                    sample_steps = gr.Slider(
-                        label="Sample Steps",
-                        minimum=30,
-                        maximum=75,
-                        value=75,
-                        step=5
+    gr.Markdown(HEADER)
+    gr.Markdown(STEP1_HEADER)
+    with gr.Row(variant = "panel"):
+        with gr.Column():
+            with gr.Row():
+                input_image = gr.Image(
+                    label = "Input Image",
+                    image_mode = "RGBA",
+                    sources = "upload",
+                    type="pil",
+                    elem_id="content_image"
                 )
-            with gr.Row():
-                step1_submit = gr.Button("Generate", elem_id="generate", variant="primary")
-        with gr.Column():
-            with gr.Row():
-                with gr.Column():
-                    mv_show_images = gr.Image(
-                        label="Generated Multi-views",
-                        type="pil",
-                        width=379,
-                        interactive=False
+                processed_image = gr.Image(label="Processed Image",
+                    image_mode="RGBA",
+                    type="pil",
+                    interactive=False
                 )
-                with gr.Column():
-                    with gr.Tab("OBJ"):
-                        output_model_obj = gr.Model3D(
-                            label = "Output Model (OBJ Format)",
-                            interactive = False,
-                        )
-                        gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
-                    with gr.Tab("GLB"):
-                        output_model_glb = gr.Model3D(
-                            label="Output Model (GLB Format)",
-                            interactive=False,
-                        )
-                        gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
-        gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
-    with gr.Row():
-        gr.Markdown(STEP2_HEADER)
-    with gr.Row(variant="panel"):
-        with gr.Column():
-            with gr.Row():
-                input_3d = gr.Model3D(
-                    label="Input Mesh",
-                    display_mode="wireframe",
-                    clear_color=[1,1,1,1],
-                )
-
-            with gr.Row():
-                with gr.Group():
-                    do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False)
-                    do_sampling = gr.Checkbox(label="Random Sampling", value=False)
-                    sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
-
-            with gr.Row():
-                step2_submit = gr.Button("Generate", elem_id="generate", variant="primary")
-
-            with gr.Row(variant="panel"):
-                mesh_examples = gr.Examples(
-                    examples=[
-                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
-                    ],
-                    inputs=input_3d,
-                    outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
-                    fn=do_inference,
-                    cache_examples = False,
-                    examples_per_page=10
-                )
-
-        with gr.Column():
-            with gr.Row():
-                input_image_render.render()
-            with gr.Row():
-                with gr.Tab("OBJ"):
-                    preprocess_model_obj.render()
-            with gr.Row():
-                output_image_render.render()
-            with gr.Row():
-                with gr.Tab("OBJ"):
-                    output_model_obj.render()
-            with gr.Row():
-                gr.Markdown('''Try click random sampling and different <b>Seed Value</b> if the result is unsatisfying''')
-
-    gr.Markdown(STEP3_HEADER)
-    gr.Markdown(STEP4_HEADER)
-
-    mv_images = gr.State()
-
-    step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
-        fn=preprocess,
-        inputs=[input_image, do_remove_background],
-        outputs=[processed_image],
-    ).success(
-        fn=generate_mvs,
-        inputs=[processed_image, sample_steps, sample_seed],
-        outputs=[mv_images, mv_show_images],
-    ).success(
-        fn=make3d,
-        inputs=[mv_images],
-        outputs=[output_model_obj, output_model_glb]
-    )
-
-    step2_submit.click(
-        fn=do_inference,
-        inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes],
-        outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
-    )
-
-
-
-demo.queue(max_size=10)
-demo.launch()
+            with gr.Row():
+                with gr.Group():
+                    do_remove_background = gr.Checkbox(
+                        label="Remove Background",
+                        value=True)
+                    sample_seed = gr.Number(
+                        value=42,
+                        label="Seed Value",
+                        precision=0
+                    )
+                    sample_steps = gr.Slider(
+                        label="Sample Steps",
+                        minimum=30,
+                        maximum=75,
+                        value=75,
+                        step=5
+                    )
+            with gr.Row():
+                step1_submit = gr.Button("Generate", elem_id="generate", variant="primary")
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    mv_show_images = gr.Image(
+                        label="Generated Multi-views",
+                        type="pil",
+                        width=379,
+                        interactive=False
+                    )
+                with gr.Column():
+                    with gr.Tab("OBJ"):
+                        output_model_obj = gr.Model3D(
+                            label = "Output Model (OBJ Format)",
+                            interactive = False,
+                        )
+                        gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
+                    with gr.Tab("GLB"):
+                        output_model_glb = gr.Model3D(
+                            label="Output Model (GLB Format)",
+                            interactive=False,
+                        )
+                        gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
+        gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
+    with gr.Row():
+        gr.Markdown(STEP2_HEADER)
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                input_3d = gr.Model3D(
+                    label="Input Mesh",
+                    display_mode="wireframe",
+                    clear_color=[1,1,1,1],
+                )
+
+            with gr.Row():
+                with gr.Group():
+                    do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False)
+                    do_sampling = gr.Checkbox(label="Random Sampling", value=False)
+                    sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
+
+            with gr.Row():
+                step2_submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+            with gr.Row(variant="panel"):
+                mesh_examples = gr.Examples(
+                    examples=[
+                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+                    ],
+                    inputs=input_3d,
+                    outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
+                    fn=do_inference,
+                    cache_examples = False,
+                    examples_per_page=10
+                )
+
+        with gr.Column():
+            with gr.Row():
+                input_image_render.render()
+            with gr.Row():
+                with gr.Tab("OBJ"):
+                    preprocess_model_obj.render()
+            with gr.Row():
+                output_image_render.render()
+            with gr.Row():
+                with gr.Tab("OBJ"):
+                    output_model_obj.render()
+            with gr.Row():
+                gr.Markdown('''Try click random sampling and different <b>Seed Value</b> if the result is unsatisfying''')
+
+    gr.Markdown(STEP3_HEADER)
+    gr.Markdown(STEP4_HEADER)
+
+    mv_images = gr.State()
+
+    step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=preprocess,
+        inputs=[input_image, do_remove_background],
+        outputs=[processed_image],
+    ).success(
+        fn=generate_mvs,
+        inputs=[processed_image, sample_steps, sample_seed],
+        outputs=[mv_images, mv_show_images],
+    ).success(
+        fn=make3d,
+        inputs=[mv_images],
+        outputs=[output_model_obj, output_model_glb]
+    )
+
+    step2_submit.click(
+        fn=do_inference,
+        inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes],
+        outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
+    )
+
+
+
+demo.queue(max_size=10)
+demo.launch()
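The `step1_submit.click(...).success(...)` chain above relies on Gradio's event chaining: each `.success()` handler runs only if the previous one returned without raising, which is how `check_input_image` gates the whole image-to-mesh pipeline with `gr.Error`. A minimal, self-contained sketch of the pattern with placeholder functions (not the app's own):

```python
import gradio as gr

def check(text):
    if not text:
        raise gr.Error("No input!")  # aborts the chain, like check_input_image

def step_a(text):
    return text.upper()

def step_b(text):
    return f"processed: {text}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    mid = gr.Textbox(label="Step A")
    out = gr.Textbox(label="Step B")
    btn = gr.Button("Run")

    # Each .success() fires only after the previous handler succeeds,
    # mirroring check_input_image -> preprocess -> generate_mvs -> make3d.
    btn.click(fn=check, inputs=[inp]).success(
        fn=step_a, inputs=[inp], outputs=[mid],
    ).success(
        fn=step_b, inputs=[mid], outputs=[out],
    )

if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch()
```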