flamehaze1115 committed
Commit ae84133 · verified · 1 Parent(s): 8475a4a

Update gradio_app.py

Files changed (1): gradio_app.py (+11 -9)
gradio_app.py CHANGED
@@ -29,7 +29,7 @@ from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePip
 from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
 from einops import rearrange
 import numpy as np
-from transformers import pipeline
+from transformers import SamModel, SamProcessor
 
 def save_image(tensor):
     ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
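Note: the dropped import fed transformers' automatic "mask-generation" pipeline, which runs SAM over a grid of point prompts and returns every mask it finds, with no way to pass a single bounding-box prompt. A rough sketch of that older path, for context only (the model id and image path are illustrative; the removed call relied on defaults):

    # Sketch of the removed usage: automatic mask generation, no box prompt.
    # "facebook/sam-vit-huge" and "example.png" are illustrative placeholders.
    from PIL import Image
    from transformers import pipeline

    generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=0)
    image = Image.open("example.png").convert("RGB")
    result = generator(image, points_per_batch=256)
    print(len(result["masks"]), "masks found")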
@@ -57,16 +57,18 @@ if not hasattr(Image, 'Resampling'):
 
 
 def sam_init():
-    predictor = pipeline("mask-generation", device = f"cuda:{_GPU_ID}", points_per_batch = 256)
-    return predictor
+    model = SamModel.from_pretrained("facebook/sam-vit-huge")
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+    return model, processor
 
-def sam_segment(predictor, input_image, *bbox_coords):
+def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
     bbox = np.array(bbox_coords)
     image = np.asarray(input_image)
 
     start_time = time.time()
 
-    outputs = predictor(input_image, points_per_batch = 256)
+    inputs = sam_processor(raw_image, input_boxes=bbox, return_tensors="pt").to("cuda")
+    outputs = sam_model(**inputs)
 
     print(f"SAM Time: {time.time() - start_time:.3f}s")
     out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
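The new sam_segment above leaves two details implicit: raw_image is presumably the same PIL image passed in as input_image, and SamProcessor expects input_boxes nested per image rather than a flat array. A minimal, self-contained sketch of box-prompted segmentation with SamModel/SamProcessor, ending with the kind of RGBA cut-out the surrounding out_image buffer points at (the function name and post-processing details are assumptions, not the committed code):

    # Minimal sketch of box-prompted SAM via transformers, assuming a PIL RGB image and a
    # single (x_min, y_min, x_max, y_max) prompt. Post-processing and the RGBA compositing
    # mirror what sam_segment appears to do; the committed file may differ in detail.
    import numpy as np
    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    def segment_with_box(input_image, x_min, y_min, x_max, y_max):
        image = np.asarray(input_image)
        # input_boxes is nested as [images][boxes per image][4 coordinates]
        inputs = processor(
            input_image,
            input_boxes=[[[float(x_min), float(y_min), float(x_max), float(y_max)]]],
            return_tensors="pt",
        ).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)
        # Upscale the low-resolution mask logits back to the original image size
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]  # boolean tensor, shape (num_boxes, num_masks, H, W)
        mask = masks[0, 0].numpy()
        # Composite the foreground into an RGBA cut-out, as the out_image buffer suggests
        out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
        out_image[:, :, :3] = image
        out_image[:, :, 3] = mask.astype(np.uint8) * 255
        return Image.fromarray(out_image, mode="RGBA")

With multimask_output=False the model returns a single mask per box, which is usually what a foreground cut-out needs; leaving it at the default yields three candidate masks plus IoU scores to choose from.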
@@ -89,7 +91,7 @@ def expand2square(pil_img, background_color):
         result.paste(pil_img, ((height - width) // 2, 0))
         return result
 
-def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
+def preprocess(sam_model, sam_processor, input_image, chk_group=None, segment=True, rescale=False):
     RES = 1024
     input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
     if chk_group is not None:
@@ -105,7 +107,7 @@ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
         y_min = int(y_nonzero[0].min())
         x_max = int(x_nonzero[0].max())
         y_max = int(y_nonzero[0].max())
-        input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
+        input_image = sam_segment(sam_model, sam_processor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
     # Rescale and recenter
     if rescale:
         image_arr = np.array(input_image)
@@ -253,7 +255,7 @@ def run_demo():
     torch.set_grad_enabled(False)
     pipeline.to(f'cuda:{_GPU_ID}')
 
-    predictor = sam_init()
+    sam_model, sam_processor = sam_init()
 
     custom_theme = gr.themes.Soft(primary_hue="blue").set(
         button_secondary_background_fill="*neutral_100",
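One thing the hunks above do not show is where the SAM weights end up: sam_init loads the model without choosing a device, while sam_segment moves its inputs to "cuda". A small sketch of a variant that pins both to the demo's configured GPU (the _GPU_ID value below is an assumption standing in for the global defined elsewhere in gradio_app.py):

    # Sketch only: load SAM once and keep it on the GPU the rest of the demo uses, so the
    # tensors sam_segment sends to CUDA land on the same device as the weights.
    from transformers import SamModel, SamProcessor

    _GPU_ID = 0  # assumption: mirrors the global used elsewhere in gradio_app.py

    def sam_init():
        model = SamModel.from_pretrained("facebook/sam-vit-huge").to(f"cuda:{_GPU_ID}")
        model.eval()
        processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        return model, processor

    sam_model, sam_processor = sam_init()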
@@ -328,7 +330,7 @@ def run_demo():
         normal_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
 
 
-        run_btn.click(fn=partial(preprocess, predictor),
+        run_btn.click(fn=partial(preprocess, sam_model, sam_processor),
                       inputs=[input_image, input_processing],
                       outputs=[processed_image_highres, processed_image], queue=True
                       ).success(fn=partial(run_pipeline, pipeline, cfg),
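Because preprocess now takes the SAM model and processor ahead of the Gradio-supplied values, functools.partial pre-binds them and Gradio appends the component values on click. A stripped-down sketch of that wiring (component labels and the placeholder preprocess body are illustrative, not the app's real layout):

    # Sketch of the event wiring: partial fixes the first two arguments, so a click calls
    # preprocess(sam_model, sam_processor, <image value>, <checkbox value>).
    from functools import partial

    import gradio as gr
    from transformers import SamModel, SamProcessor

    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    def preprocess(sam_model, sam_processor, input_image, chk_group=None):
        return input_image  # placeholder; the real preprocess segments and rescales

    with gr.Blocks() as demo:
        input_image = gr.Image(type="pil")
        input_processing = gr.CheckboxGroup(choices=["Background Removal", "Rescale"])
        processed_image = gr.Image(type="pil")
        run_btn = gr.Button("Run")
        run_btn.click(fn=partial(preprocess, sam_model, sam_processor),
                      inputs=[input_image, input_processing],
                      outputs=[processed_image])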
 