Yuxiao319 commited on
Commit
8ca9794
·
1 Parent(s): 28468ea

sam_segment

Browse files
Files changed (1) hide show
  1. gradio_app.py +7 -5
gradio_app.py CHANGED
@@ -57,8 +57,8 @@ if not hasattr(Image, 'Resampling'):
57
 
58
 
59
  def sam_init():
60
- model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
61
- processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
62
  return model, processor
63
 
64
  def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
@@ -68,15 +68,17 @@ def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
68
 
69
  start_time = time.time()
70
 
71
- inputs = sam_processor(input_image, input_boxes=bbox, return_tensors="pt").to("cuda")
72
- outputs = sam_model(**inputs)
 
73
  masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
74
 
75
  print(f"SAM Time: {time.time() - start_time:.3f}s")
76
  out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
77
  out_image[:, :, :3] = image
78
  out_image_bbox = out_image.copy()
79
- out_image_bbox[:, :, 3] = masks[-1].astype(np.uint8) * 255
 
80
  torch.cuda.empty_cache()
81
  return Image.fromarray(out_image_bbox, mode='RGBA')
82
 
 
57
 
58
 
59
  def sam_init():
60
+ model = SamModel.from_pretrained("facebook/sam-vit-large").to("cuda")
61
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
62
  return model, processor
63
 
64
  def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
 
68
 
69
  start_time = time.time()
70
 
71
+ inputs = sam_processor(input_image.convert('RGB'), input_boxes=bbox, return_tensors="pt", do_resize=False).to("cuda")
72
+
73
+ outputs = sam_model(**inputs, multimask_output=False)
74
  masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
75
 
76
  print(f"SAM Time: {time.time() - start_time:.3f}s")
77
  out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
78
  out_image[:, :, :3] = image
79
  out_image_bbox = out_image.copy()
80
+
81
+ out_image_bbox[:, :, 3] = masks[-1].cpu().detach().numpy().astype(np.uint8) * 255
82
  torch.cuda.empty_cache()
83
  return Image.fromarray(out_image_bbox, mode='RGBA')
84