flamehaze1115 committed on
Commit
080cd4b
·
verified ·
1 Parent(s): e5c095e

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +4 -3
gradio_app.py CHANGED
@@ -57,25 +57,26 @@ if not hasattr(Image, 'Resampling'):
57
 
58
 
59
  def sam_init():
60
- model = SamModel.from_pretrained("facebook/sam-vit-huge")
61
  processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
62
  return model, processor
63
 
64
  def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
65
  bbox = torch.tensor(bbox_coords, dtype=torch.float32)
66
- bbox = box_tensor = bbox.unsqueeze(0).unsqueeze(0) # (1, 1, 4)
67
  image = np.asarray(input_image)
68
 
69
  start_time = time.time()
70
 
71
  inputs = sam_processor(input_image, input_boxes=bbox, return_tensors="pt").to("cuda")
72
  outputs = sam_model(**inputs)
 
73
 
74
  print(f"SAM Time: {time.time() - start_time:.3f}s")
75
  out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
76
  out_image[:, :, :3] = image
77
  out_image_bbox = out_image.copy()
78
- out_image_bbox[:, :, 3] = outputs["masks"][-1].astype(np.uint8) * 255
79
  torch.cuda.empty_cache()
80
  return Image.fromarray(out_image_bbox, mode='RGBA')
81
 
 
57
 
58
 
59
  def sam_init():
60
+ model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
61
  processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
62
  return model, processor
63
 
64
  def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
65
  bbox = torch.tensor(bbox_coords, dtype=torch.float32)
66
+ bbox = bbox.unsqueeze(0)
67
  image = np.asarray(input_image)
68
 
69
  start_time = time.time()
70
 
71
  inputs = sam_processor(input_image, input_boxes=bbox, return_tensors="pt").to("cuda")
72
  outputs = sam_model(**inputs)
73
+ masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
74
 
75
  print(f"SAM Time: {time.time() - start_time:.3f}s")
76
  out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
77
  out_image[:, :, :3] = image
78
  out_image_bbox = out_image.copy()
79
+ out_image_bbox[:, :, 3] = masks[-1].astype(np.uint8) * 255
80
  torch.cuda.empty_cache()
81
  return Image.fromarray(out_image_bbox, mode='RGBA')
82