Spaces: wonder3d_plus (Running on L4)

gradio_app.py CHANGED (+39 -18)
@@ -23,14 +23,17 @@ from typing import Dict, Optional, Tuple, List
 from dataclasses import dataclass
 import huggingface_hub
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-from …
-from …
-from …
+from mv_diffusion_30.models.unet_mv2d_condition import UNetMV2DConditionModel
+from mv_diffusion_30.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
+from mv_diffusion_30.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline
 from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
 from einops import rearrange
 import numpy as np
 from transformers import SamModel, SamProcessor
 
+
+
+
 def save_image(tensor):
     ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
     # pdb.set_trace()
@@ -48,6 +51,7 @@ Generate consistent multi-view normals maps and color images.
 <div>
 The demo does not include the mesh reconstruction part, please visit <a href="https://github.com/xxlong0/Wonder3D/">our github repo</a> to get a textured mesh.
 </div>
+<span style="font-weight: bold; color: #d9534f;">- 2024.11.5 We shift our ckpt to a more powerful model [Wonder3D_Plus] that supports both orthogonal and perspective camera settings and further improves generalizability.</span>
 '''
 _GPU_ID = 0
 
@@ -57,30 +61,34 @@ if not hasattr(Image, 'Resampling'):
 
 
 def sam_init():
-    model = SamModel.from_pretrained("facebook/sam-vit-…
-    processor = SamProcessor.from_pretrained("facebook/sam-vit-…
+    model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
     return model, processor
 
 def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
+    input_points = [[[bbox_coords[2] - bbox_coords[0], bbox_coords[3] - bbox_coords[1]]]]
     bbox = torch.tensor(bbox_coords, dtype=torch.float32)
     bbox = bbox.unsqueeze(0).unsqueeze(0)
     image = np.asarray(input_image)
 
     start_time = time.time()
 
-    inputs = sam_processor(input_image…
+    inputs = sam_processor(input_image, input_boxes=bbox, return_tensors="pt", do_resize=False).to("cuda")
 
     outputs = sam_model(**inputs, multimask_output=False)
-    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(),…
+    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(),
+                                                             inputs["original_sizes"].cpu(),
+                                                             inputs["reshaped_input_sizes"].cpu(), )
 
     print(f"SAM Time: {time.time() - start_time:.3f}s")
     out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
     out_image[:, :, :3] = image
     out_image_bbox = out_image.copy()
 
-    …
+    foreground_mask = masks[-1][-1, -1, ...] * 1.
+    out_image_bbox[:, :, 3] = foreground_mask.cpu().detach().numpy().astype(np.uint8) * 255
     torch.cuda.empty_cache()
-    return Image.fromarray(out_image_bbox, mode='RGBA')
+    return Image.fromarray(out_image_bbox, mode='RGBA')
 
 def expand2square(pil_img, background_color):
     width, height = pil_img.size
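The hunk above moves background removal onto the transformers SAM classes with a box prompt. For reference only, the snippet below is a minimal, self-contained sketch of that flow: the input path, the example box coordinates, and the CPU fallback are assumptions, and it uses the processor's default resizing rather than the demo's do_resize=False call.

# Standalone sketch of box-prompted segmentation with transformers' SAM classes,
# mirroring the updated sam_segment(); "input.png" and the box values are made up.
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

image = Image.open("input.png").convert("RGB")
input_boxes = [[[64.0, 64.0, 448.0, 448.0]]]  # one xyxy box for the single input image

inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# Upscale the low-resolution mask logits back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

# Use the boolean mask as an alpha channel, as the demo does for background removal.
alpha = masks[0][0, 0].numpy().astype(np.uint8) * 255
rgba = np.dstack([np.asarray(image), alpha])
Image.fromarray(rgba, mode="RGBA").save("segmented.png")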
@@ -142,7 +150,7 @@ def load_wonder3d_pipeline(cfg):
     feature_extractor = CLIPImageProcessor.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision)
     vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision)
     unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
-    unet.enable_xformers_memory_efficient_attention()
+    # unet.enable_xformers_memory_efficient_attention()
 
     # Move text_encode and vae to gpu and cast to weight_dtype
     image_encoder.to(dtype=weight_dtype)
@@ -160,24 +168,28 @@ def load_wonder3d_pipeline(cfg):
     # sys.main_lock = threading.Lock()
     return pipeline
 
-from …
-def prepare_data(single_image, crop_size):
+from mv_diffusion_30.data.single_image_dataset import SingleImageDataset
+def prepare_data(single_image, crop_size, input_camera_type):
     dataset = SingleImageDataset(
         root_dir = None,
         num_views = 6,
         img_wh=[256, 256],
         bg_color='white',
         crop_size=crop_size,
-        single_image=single_image
+        single_image=single_image,
+        load_cam_type=True,
+        cam_types=[input_camera_type]
     )
     return dataset[0]
 
 
-def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size):
+def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, input_camera_type):
     import pdb
     # pdb.set_trace()
 
-    …
+
+
+    batch = prepare_data(single_image, crop_size, input_camera_type)
 
     pipeline.set_progress_bar_config(disable=True)
     seed = int(seed)
@@ -244,13 +256,14 @@ class TestConfig:
 
     cond_on_normals: bool
     cond_on_colors: bool
+    load_task: bool
 
 
 def run_demo():
     from utils.misc import load_config
     from omegaconf import OmegaConf
     # parse YAML config to OmegaConf
-    cfg = load_config("./configs/mvdiffusion-joint-…
+    cfg = load_config("./configs/mvdiffusion-joint-plus.yaml")
     # print(cfg)
     schema = OmegaConf.structured(TestConfig)
     cfg = OmegaConf.merge(schema, cfg)
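The new load_task field matters because run_demo() validates the YAML against a structured schema built from TestConfig, and OmegaConf rejects keys the dataclass does not declare. The sketch below illustrates that pattern with a cut-down stand-in for TestConfig and an inline config, neither of which is the demo's real configuration.

# Illustration of the OmegaConf structured-schema merge used in run_demo();
# this TestConfig stand-in and the config values are simplified, not the real ones.
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class TestConfig:
    pretrained_model_name_or_path: str = "some/diffusion-model"
    cond_on_normals: bool = False
    cond_on_colors: bool = False
    load_task: bool = False  # newly declared; an undeclared YAML key would make the merge fail

# Stand-in for the parsed YAML file.
yaml_cfg = OmegaConf.create({"pretrained_model_name_or_path": "./ckpts", "load_task": True})

schema = OmegaConf.structured(TestConfig)
cfg = OmegaConf.merge(schema, yaml_cfg)  # keys and types are checked against the dataclass
print(cfg.load_task, cfg.cond_on_normals)  # True False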
@@ -302,7 +315,7 @@ def run_demo():
             output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
             with gr.Row():
                 with gr.Column():
-                    scale_slider = gr.Slider(1, 5, value=…
+                    scale_slider = gr.Slider(1, 5, value=2, step=1,
                                              label='Classifier Free Guidance Scale')
                 with gr.Column():
                     steps_slider = gr.Slider(15, 100, value=50, step=1,
@@ -312,6 +325,14 @@ def run_demo():
                     seed = gr.Number(42, label='Seed')
                 with gr.Column():
                     crop_size = gr.Number(192, label='Crop size')
+            with gr.Row():
+                camera_type = gr.Radio(
+                    choices=[("Orthogonal Camera", "ortho"), ("Perspective Camera", "persp")],
+                    value="ortho",
+                    label="Camera Type"
+                )
+
+
             # crop_size = 192
             run_btn = gr.Button('Generate', variant='primary', interactive=True)
             with gr.Row():
@@ -338,7 +359,7 @@ def run_demo():
                      inputs=[input_image, input_processing],
                      outputs=[processed_image_highres, processed_image], queue=True
         ).success(fn=partial(run_pipeline, pipeline, cfg),
-                  inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size],
+                  inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size, camera_type],
                   outputs=[view_1, view_2, view_3, view_4, view_5, view_6,
                            normal_1, normal_2, normal_3, normal_4, normal_5, normal_6,
                            view_gallery, normal_gallery]
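The UI hunks follow one pattern: a gr.Radio built from (label, value) tuples passes only the value string ("ortho" or "persp") into the callback, while functools.partial pins the non-widget arguments such as the pipeline and config. The sketch below reproduces that wiring with a stub callback in place of run_pipeline; every name in it is illustrative rather than taken from the Space.

# Minimal Gradio sketch of the (label, value) Radio plus functools.partial wiring;
# the callback is a stub standing in for run_pipeline, not the demo's implementation.
from functools import partial
import gradio as gr

def fake_pipeline(pipeline, cfg, image, guidance_scale, camera_type):
    # The real app forwards camera_type into prepare_data(..., cam_types=[camera_type]).
    return f"pipeline={pipeline}, cfg={cfg}, cfg_scale={guidance_scale}, camera={camera_type}"

with gr.Blocks() as demo:
    input_image = gr.Image(type="pil", label="Input image")
    scale_slider = gr.Slider(1, 5, value=2, step=1, label="Classifier Free Guidance Scale")
    camera_type = gr.Radio(
        choices=[("Orthogonal Camera", "ortho"), ("Perspective Camera", "persp")],
        value="ortho",
        label="Camera Type",
    )
    result = gr.Textbox(label="Result")
    run_btn = gr.Button("Generate", variant="primary")
    # partial() pins the non-widget arguments, just as the demo pins pipeline and cfg.
    run_btn.click(
        fn=partial(fake_pipeline, "dummy-pipeline", "dummy-cfg"),
        inputs=[input_image, scale_slider, camera_type],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()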