flamehaze1115 committed on
Commit
6bd0055
·
verified ·
1 Parent(s): 56861f3

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +4 -11
gradio_app.py CHANGED
@@ -29,7 +29,7 @@ from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePip
29
  from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
30
  from einops import rearrange
31
  import numpy as np
32
- from transformers import SamModel
33
 
34
  def save_image(tensor):
35
  ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
@@ -57,11 +57,7 @@ if not hasattr(Image, 'Resampling'):
57
 
58
 
59
  def sam_init():
60
- # sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
61
- # model_type = "vit_h"
62
- # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
63
- sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device=f"cuda:{_GPU_ID}")
64
- predictor = SamPredictor(sam)
65
  return predictor
66
 
67
  def sam_segment(predictor, input_image, *bbox_coords):
@@ -71,16 +67,13 @@ def sam_segment(predictor, input_image, *bbox_coords):
71
  start_time = time.time()
72
  predictor.set_image(image)
73
 
74
- masks_bbox, scores_bbox, logits_bbox = predictor.predict(
75
- box=bbox,
76
- multimask_output=True
77
- )
78
 
79
  print(f"SAM Time: {time.time() - start_time:.3f}s")
80
  out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
81
  out_image[:, :, :3] = image
82
  out_image_bbox = out_image.copy()
83
- out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
84
  torch.cuda.empty_cache()
85
  return Image.fromarray(out_image_bbox, mode='RGBA')
86
 
 
29
  from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
30
  from einops import rearrange
31
  import numpy as np
32
+ from transformers import SamModel, SamProcessor
33
 
34
  def save_image(tensor):
35
  ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
 
57
 
58
 
59
  def sam_init():
60
+ predictor = pipeline("mask-generation", device = f"cuda:{_GPU_ID}", points_per_batch = 256)
 
 
 
 
61
  return predictor
62
 
63
  def sam_segment(predictor, input_image, *bbox_coords):
 
67
  start_time = time.time()
68
  predictor.set_image(image)
69
 
70
+ generator(image_url, points_per_batch = 256)
 
 
 
71
 
72
  print(f"SAM Time: {time.time() - start_time:.3f}s")
73
  out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
74
  out_image[:, :, :3] = image
75
  out_image_bbox = out_image.copy()
76
+ out_image_bbox[:, :, 3] = outputs["masks"][-1].astype(np.uint8) * 255
77
  torch.cuda.empty_cache()
78
  return Image.fromarray(out_image_bbox, mode='RGBA')
79