Spaces:
Runtime error
Runtime error
flamehaze1115
commited on
Update gradio_app.py
Browse files- gradio_app.py +11 -9
gradio_app.py
CHANGED
@@ -29,7 +29,7 @@ from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePip
|
|
29 |
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
|
30 |
from einops import rearrange
|
31 |
import numpy as np
|
32 |
-
from transformers import
|
33 |
|
34 |
def save_image(tensor):
|
35 |
ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
@@ -57,16 +57,18 @@ if not hasattr(Image, 'Resampling'):
|
|
57 |
|
58 |
|
59 |
def sam_init():
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
-
def sam_segment(
|
64 |
bbox = np.array(bbox_coords)
|
65 |
image = np.asarray(input_image)
|
66 |
|
67 |
start_time = time.time()
|
68 |
|
69 |
-
|
|
|
70 |
|
71 |
print(f"SAM Time: {time.time() - start_time:.3f}s")
|
72 |
out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
@@ -89,7 +91,7 @@ def expand2square(pil_img, background_color):
|
|
89 |
result.paste(pil_img, ((height - width) // 2, 0))
|
90 |
return result
|
91 |
|
92 |
-
def preprocess(
|
93 |
RES = 1024
|
94 |
input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
|
95 |
if chk_group is not None:
|
@@ -105,7 +107,7 @@ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=Fal
|
|
105 |
y_min = int(y_nonzero[0].min())
|
106 |
x_max = int(x_nonzero[0].max())
|
107 |
y_max = int(y_nonzero[0].max())
|
108 |
-
input_image = sam_segment(
|
109 |
# Rescale and recenter
|
110 |
if rescale:
|
111 |
image_arr = np.array(input_image)
|
@@ -253,7 +255,7 @@ def run_demo():
|
|
253 |
torch.set_grad_enabled(False)
|
254 |
pipeline.to(f'cuda:{_GPU_ID}')
|
255 |
|
256 |
-
|
257 |
|
258 |
custom_theme = gr.themes.Soft(primary_hue="blue").set(
|
259 |
button_secondary_background_fill="*neutral_100",
|
@@ -328,7 +330,7 @@ def run_demo():
|
|
328 |
normal_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
|
329 |
|
330 |
|
331 |
-
run_btn.click(fn=partial(preprocess,
|
332 |
inputs=[input_image, input_processing],
|
333 |
outputs=[processed_image_highres, processed_image], queue=True
|
334 |
).success(fn=partial(run_pipeline, pipeline, cfg),
|
|
|
29 |
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
|
30 |
from einops import rearrange
|
31 |
import numpy as np
|
32 |
+
from transformers import SamModel, SamProcessor
|
33 |
|
34 |
def save_image(tensor):
|
35 |
ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
|
|
57 |
|
58 |
|
59 |
def sam_init():
|
60 |
+
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
61 |
+
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
62 |
+
return model, processor
|
63 |
|
64 |
+
def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
|
65 |
bbox = np.array(bbox_coords)
|
66 |
image = np.asarray(input_image)
|
67 |
|
68 |
start_time = time.time()
|
69 |
|
70 |
+
inputs = sam_processor(raw_image, input_boxes=bbox, return_tensors="pt").to("cuda")
|
71 |
+
outputs = sam_model(**inputs)
|
72 |
|
73 |
print(f"SAM Time: {time.time() - start_time:.3f}s")
|
74 |
out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
|
|
91 |
result.paste(pil_img, ((height - width) // 2, 0))
|
92 |
return result
|
93 |
|
94 |
+
def preprocess(sam_model, sam_processor, input_image, chk_group=None, segment=True, rescale=False):
|
95 |
RES = 1024
|
96 |
input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
|
97 |
if chk_group is not None:
|
|
|
107 |
y_min = int(y_nonzero[0].min())
|
108 |
x_max = int(x_nonzero[0].max())
|
109 |
y_max = int(y_nonzero[0].max())
|
110 |
+
input_image = sam_segment(sam_model, sam_processor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
|
111 |
# Rescale and recenter
|
112 |
if rescale:
|
113 |
image_arr = np.array(input_image)
|
|
|
255 |
torch.set_grad_enabled(False)
|
256 |
pipeline.to(f'cuda:{_GPU_ID}')
|
257 |
|
258 |
+
sam_model, sam_processor = sam_init()
|
259 |
|
260 |
custom_theme = gr.themes.Soft(primary_hue="blue").set(
|
261 |
button_secondary_background_fill="*neutral_100",
|
|
|
330 |
normal_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
|
331 |
|
332 |
|
333 |
+
run_btn.click(fn=partial(preprocess, sam_model, sam_processor),
|
334 |
inputs=[input_image, input_processing],
|
335 |
outputs=[processed_image_highres, processed_image], queue=True
|
336 |
).success(fn=partial(run_pipeline, pipeline, cfg),
|