Yuxiao319 committed on
Commit 5ac6f37 · 1 Parent(s): 8ca9794

wonder3d_plus

Files changed (1)
  1. gradio_app.py +39 -18
gradio_app.py CHANGED
@@ -23,14 +23,17 @@ from typing import Dict, Optional, Tuple, List
 from dataclasses import dataclass
 import huggingface_hub
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel
-from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
-from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline
+from mv_diffusion_30.models.unet_mv2d_condition import UNetMV2DConditionModel
+from mv_diffusion_30.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
+from mv_diffusion_30.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline
 from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
 from einops import rearrange
 import numpy as np
 from transformers import SamModel, SamProcessor
 
+
+
+
 def save_image(tensor):
     ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
     # pdb.set_trace()
@@ -48,6 +51,7 @@ Generate consistent multi-view normals maps and color images.
 <div>
 The demo does not include the mesh reconstruction part, please visit <a href="https://github.com/xxlong0/Wonder3D/">our github repo</a> to get a textured mesh.
 </div>
+<span style="font-weight: bold; color: #d9534f;">- 2024.11.5 We shift our ckpt to the a more powerful model [Wonder3D_Plus] that supports both orthogonal and perspective camera settings and further improves generalizability.</span>
 '''
 _GPU_ID = 0
 
@@ -57,30 +61,34 @@ if not hasattr(Image, 'Resampling'):
 
 
 def sam_init():
-    model = SamModel.from_pretrained("facebook/sam-vit-large").to("cuda")
-    processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
+    model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
     return model, processor
 
 def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
+    input_points = [[[bbox_coords[2] - bbox_coords[0], bbox_coords[3] - bbox_coords[1]]]]
     bbox = torch.tensor(bbox_coords, dtype=torch.float32)
     bbox = bbox.unsqueeze(0).unsqueeze(0)
     image = np.asarray(input_image)
 
     start_time = time.time()
 
-    inputs = sam_processor(input_image.convert('RGB'), input_boxes=bbox, return_tensors="pt", do_resize=False).to("cuda")
+    inputs = sam_processor(input_image, input_boxes=bbox, return_tensors="pt", do_resize=False).to("cuda")
 
     outputs = sam_model(**inputs, multimask_output=False)
-    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
+    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(),
+                                                             inputs["original_sizes"].cpu(),
+                                                             inputs["reshaped_input_sizes"].cpu(), )
 
     print(f"SAM Time: {time.time() - start_time:.3f}s")
     out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
     out_image[:, :, :3] = image
     out_image_bbox = out_image.copy()
 
-    out_image_bbox[:, :, 3] = masks[-1].cpu().detach().numpy().astype(np.uint8) * 255
+    foreground_mask = masks[-1][-1, -1, ...] * 1.
+    out_image_bbox[:, :, 3] = foreground_mask.cpu().detach().numpy().astype(np.uint8) * 255
     torch.cuda.empty_cache()
-    return Image.fromarray(out_image_bbox, mode='RGBA')
+    return Image.fromarray(out_image_bbox, mode='RGBA')
 
 def expand2square(pil_img, background_color):
     width, height = pil_img.size
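
For reviewers who want to exercise the box-prompted segmentation path on its own, the sketch below uses the same transformers SAM API that sam_segment relies on (SamModel, SamProcessor, post_process_masks) with the new facebook/sam-vit-huge checkpoint. The image path and box coordinates are placeholders, and a CUDA device is assumed, as in the app.

# Standalone sketch of box-prompted SAM segmentation (illustrative values only).
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

input_image = Image.open("example.png").convert("RGB")   # placeholder path
box = [30, 40, 480, 500]                                  # placeholder (x0, y0, x1, y1) in pixels

# One image, one box: input_boxes has shape [batch, num_boxes, 4].
inputs = processor(input_image, input_boxes=[[box]], return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
# masks[0] is [num_boxes, num_masks, H, W]; with one box and multimask_output=False
# the single predicted mask becomes the alpha channel, as in sam_segment.
alpha = masks[0][0, 0].numpy().astype(np.uint8) * 255
rgba = np.dstack([np.asarray(input_image), alpha])
Image.fromarray(rgba, mode="RGBA").save("segmented.png")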
@@ -142,7 +150,7 @@ def load_wonder3d_pipeline(cfg):
     feature_extractor = CLIPImageProcessor.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision)
     vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision)
     unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
-    unet.enable_xformers_memory_efficient_attention()
+    # unet.enable_xformers_memory_efficient_attention()
 
     # Move text_encode and vae to gpu and cast to weight_dtype
     image_encoder.to(dtype=weight_dtype)
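
The xformers call is commented out rather than made conditional; if memory-efficient attention is still wanted on machines where xformers happens to be installed, a guarded variant such as the sketch below (not part of this commit) is a common pattern with diffusers models.

# Sketch only: enable xformers attention when available, otherwise keep defaults.
from diffusers.utils import is_xformers_available

def maybe_enable_xformers(model) -> None:
    """Call enable_xformers_memory_efficient_attention() only if xformers is installed."""
    if is_xformers_available():
        model.enable_xformers_memory_efficient_attention()

# Usage in load_wonder3d_pipeline would then be: maybe_enable_xformers(unet)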
@@ -160,24 +168,28 @@ def load_wonder3d_pipeline(cfg):
     # sys.main_lock = threading.Lock()
     return pipeline
 
-from mvdiffusion.data.single_image_dataset import SingleImageDataset
-def prepare_data(single_image, crop_size):
+from mv_diffusion_30.data.single_image_dataset import SingleImageDataset
+def prepare_data(single_image, crop_size, input_camera_type):
     dataset = SingleImageDataset(
         root_dir = None,
         num_views = 6,
         img_wh=[256, 256],
         bg_color='white',
         crop_size=crop_size,
-        single_image=single_image
+        single_image=single_image,
+        load_cam_type=True,
+        cam_types=[input_camera_type]
     )
     return dataset[0]
 
 
-def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size):
+def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, input_camera_type):
     import pdb
     # pdb.set_trace()
 
-    batch = prepare_data(single_image, crop_size)
+
+
+    batch = prepare_data(single_image, crop_size, input_camera_type)
 
     pipeline.set_progress_bar_config(disable=True)
     seed = int(seed)
@@ -244,13 +256,14 @@ class TestConfig:
 
     cond_on_normals: bool
     cond_on_colors: bool
+    load_task: bool
 
 
 def run_demo():
     from utils.misc import load_config
     from omegaconf import OmegaConf
     # parse YAML config to OmegaConf
-    cfg = load_config("./configs/mvdiffusion-joint-ortho-6views.yaml")
+    cfg = load_config("./configs/mvdiffusion-joint-plus.yaml")
     # print(cfg)
     schema = OmegaConf.structured(TestConfig)
     cfg = OmegaConf.merge(schema, cfg)
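
The config handling above follows the usual OmegaConf pattern: load the YAML, build a structured schema from the TestConfig dataclass, and merge so that missing or mistyped keys (including the new load_task field) fail loudly. A self-contained sketch with stand-in names, not the real TestConfig or mvdiffusion-joint-plus.yaml:

# Illustrative OmegaConf structured-config merge (names are stand-ins).
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class DemoConfig:                      # stand-in for TestConfig
    pretrained_model_name_or_path: str = "path/to/model"
    num_views: int = 6
    load_task: bool = False            # field added by this commit

yaml_cfg = OmegaConf.create({"num_views": 6, "load_task": True})

schema = OmegaConf.structured(DemoConfig)
cfg = OmegaConf.merge(schema, yaml_cfg)  # type-checks values against the dataclass
print(cfg.load_task)                     # True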
@@ -302,7 +315,7 @@ def run_demo():
             output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
             with gr.Row():
                 with gr.Column():
-                    scale_slider = gr.Slider(1, 5, value=3, step=1,
+                    scale_slider = gr.Slider(1, 5, value=2, step=1,
                                              label='Classifier Free Guidance Scale')
                 with gr.Column():
                     steps_slider = gr.Slider(15, 100, value=50, step=1,
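
For context on the slider: the classifier-free guidance scale weights how strongly the denoiser follows the image condition. A generic sketch of the standard combination used by diffusers-style pipelines (illustrative, not code from this repo):

# Generic classifier-free guidance combination (illustrative).
import torch

def apply_cfg(noise_uncond: torch.Tensor, noise_cond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # scale = 1.0 disables guidance; larger values push harder toward the condition.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

uncond = torch.zeros(1, 4, 32, 32)
cond = torch.randn(1, 4, 32, 32)
print(apply_cfg(uncond, cond, guidance_scale=2.0).shape)  # torch.Size([1, 4, 32, 32])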
@@ -312,6 +325,14 @@
                     seed = gr.Number(42, label='Seed')
                 with gr.Column():
                     crop_size = gr.Number(192, label='Crop size')
+            with gr.Row():
+                camera_type = gr.Radio(
+                    choices=[("Orthogonal Camera", "ortho"), ("Perspective Camera", "persp")],
+                    value="ortho",
+                    label="Camera Type"
+                )
+
+
             # crop_size = 192
             run_btn = gr.Button('Generate', variant='primary', interactive=True)
             with gr.Row():
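
The camera selector uses Gradio's (display label, value) tuple form for choices, so downstream callbacks receive the short strings "ortho" / "persp" rather than the labels. A minimal standalone sketch (hypothetical demo, assuming a Gradio version that supports tuple choices, as the app's gr.Radio call does):

# Minimal sketch: tuple choices pass the value string, not the display label.
import gradio as gr

def describe(camera_type: str) -> str:
    return f"Selected camera type: {camera_type}"

with gr.Blocks() as demo:
    camera_type = gr.Radio(
        choices=[("Orthogonal Camera", "ortho"), ("Perspective Camera", "persp")],
        value="ortho",
        label="Camera Type",
    )
    out = gr.Textbox(label="Result")
    camera_type.change(describe, inputs=camera_type, outputs=out)

if __name__ == "__main__":
    demo.launch()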
@@ -338,7 +359,7 @@
                 inputs=[input_image, input_processing],
                 outputs=[processed_image_highres, processed_image], queue=True
             ).success(fn=partial(run_pipeline, pipeline, cfg),
-                      inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size],
+                      inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size, camera_type],
                       outputs=[view_1, view_2, view_3, view_4, view_5, view_6,
                                normal_1, normal_2, normal_3, normal_4, normal_5, normal_6,
                                view_gallery, normal_gallery]
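
On the wiring: partial(run_pipeline, pipeline, cfg) pre-binds the first two parameters, and Gradio supplies the remaining ones positionally in the order of the inputs list, which is why camera_type must be appended both to that list and to run_pipeline's signature. A stripped-down sketch with dummy stand-ins:

# Illustrative argument flow behind .success(fn=partial(run_pipeline, pipeline, cfg), inputs=[...]).
from functools import partial

def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, input_camera_type):
    return f"{input_camera_type}: {steps} steps, CFG {guidance_scale}, seed {seed}, crop {crop_size}"

pipeline, cfg = object(), object()     # stand-ins for the loaded pipeline and config
fn = partial(run_pipeline, pipeline, cfg)

# Gradio effectively calls fn(processed_image, scale, steps, seed, crop_size, camera_type):
print(fn("image.png", 2, 50, 42, 192, "ortho"))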
 