Spaces:
Running
on
L4
Running
on
L4
sam_segment
Browse files- gradio_app.py +7 -5
gradio_app.py
CHANGED
@@ -57,8 +57,8 @@ if not hasattr(Image, 'Resampling'):
|
|
57 |
|
58 |
|
59 |
def sam_init():
|
60 |
-
model = SamModel.from_pretrained("facebook/sam-vit-
|
61 |
-
processor = SamProcessor.from_pretrained("facebook/sam-vit-
|
62 |
return model, processor
|
63 |
|
64 |
def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
|
@@ -68,15 +68,17 @@ def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
|
|
68 |
|
69 |
start_time = time.time()
|
70 |
|
71 |
-
inputs = sam_processor(input_image, input_boxes=bbox, return_tensors="pt").to("cuda")
|
72 |
-
|
|
|
73 |
masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
74 |
|
75 |
print(f"SAM Time: {time.time() - start_time:.3f}s")
|
76 |
out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
77 |
out_image[:, :, :3] = image
|
78 |
out_image_bbox = out_image.copy()
|
79 |
-
|
|
|
80 |
torch.cuda.empty_cache()
|
81 |
return Image.fromarray(out_image_bbox, mode='RGBA')
|
82 |
|
|
|
57 |
|
58 |
|
59 |
def sam_init():
|
60 |
+
model = SamModel.from_pretrained("facebook/sam-vit-large").to("cuda")
|
61 |
+
processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
|
62 |
return model, processor
|
63 |
|
64 |
def sam_segment(sam_model, sam_processor, input_image, *bbox_coords):
|
|
|
68 |
|
69 |
start_time = time.time()
|
70 |
|
71 |
+
inputs = sam_processor(input_image.convert('RGB'), input_boxes=bbox, return_tensors="pt", do_resize=False).to("cuda")
|
72 |
+
|
73 |
+
outputs = sam_model(**inputs, multimask_output=False)
|
74 |
masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
75 |
|
76 |
print(f"SAM Time: {time.time() - start_time:.3f}s")
|
77 |
out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
78 |
out_image[:, :, :3] = image
|
79 |
out_image_bbox = out_image.copy()
|
80 |
+
|
81 |
+
out_image_bbox[:, :, 3] = masks[-1].cpu().detach().numpy().astype(np.uint8) * 255
|
82 |
torch.cuda.empty_cache()
|
83 |
return Image.fromarray(out_image_bbox, mode='RGBA')
|
84 |
|