"""Gradio demo comparing SlimSAM and SAM segmentation masks for box and point prompts."""

import gradio as gr
import numpy as np
import spaces
import torch
from gradio_image_prompter import ImagePrompter
from PIL import Image
from transformers import SamModel, SamProcessor

# Single device used for BOTH the models and the processor outputs; moving the
# models to a hard-coded "cuda" while inputs go to `device` breaks on CPU hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")


def get_processor_and_model(slim: bool):
    """Return the ``(processor, model)`` pair to use.

    Args:
        slim: ``True`` selects SlimSAM ("nielsr/slimsam-50-uniform"),
            ``False`` selects SAM ("facebook/sam-vit-huge").
    """
    if slim:
        return slimsam_processor, slimsam_model
    return sam_processor, sam_model


def _predict_mask(processor, model, inputs):
    """Run ``model`` on prepared ``inputs`` and return an AnnotatedImage annotation list.

    Args:
        processor: The SamProcessor that produced ``inputs``.
        model: The matching SamModel.
        inputs: Processor output already moved to ``device``.

    Returns:
        A one-element list ``[(mask, "mask")]`` suitable for ``gr.AnnotatedImage``.
    """
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Take the first image, first prompt, first predicted mask (axis meaning
    # presumably (image)[prompt, candidate, H, W] -- see post_process_masks docs).
    mask = masks[0][0][0].numpy()
    # Leading singleton axis kept as in the original implementation.
    return [(mask[np.newaxis, ...], "mask")]


@spaces.GPU
def sam_box_inference(image, x_min, y_min, x_max, y_max, *, slim=False):
    """Segment the object inside a bounding box.

    Args:
        image: The input image as a NumPy array (converted with ``Image.fromarray``).
        x_min, y_min, x_max, y_max: Box corners in pixel coordinates.
        slim: Use SlimSAM instead of SAM.

    Returns:
        ``[(mask, "mask")]`` for ``gr.AnnotatedImage``.
    """
    processor, model = get_processor_and_model(slim)
    inputs = processor(
        Image.fromarray(image),
        input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
        return_tensors="pt",
    ).to(device)
    return _predict_mask(processor, model, inputs)


@spaces.GPU
def sam_point_inference(image, x, y, *, slim=False):
    """Segment the object under a single point prompt.

    Args:
        image: The input image (PIL image, as passed by ``infer_point``).
        x, y: Point coordinates in pixels.
        slim: Use SlimSAM instead of SAM.

    Returns:
        ``[(mask, "mask")]`` for ``gr.AnnotatedImage``.
    """
    processor, model = get_processor_and_model(slim)
    inputs = processor(image, input_points=[[[x, y]]], return_tensors="pt").to(device)
    return _predict_mask(processor, model, inputs)


def infer_point(img):
    """Gradio handler: segment the object under the centroid of the drawn point.

    Args:
        img: ImageEditor value with keys "background" (original image),
            "layers" (first layer holds the user's point) and "composite".

    Returns:
        ``((image, slimsam_annotations), (image, sam_annotations))``.

    Raises:
        gr.Error: If no image is uploaded or no point has been drawn.
    """
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    image = img["background"].convert("RGB")
    layers = img["layers"]
    if not layers:
        raise gr.Error("Please select a point on top of the image.")

    drawing = np.array(layers[0])
    # A pixel counts as drawn if any of its channels is non-zero.
    drawn = drawing.any(axis=-1) if drawing.ndim == 3 else drawing != 0
    rows, cols = np.nonzero(drawn)
    if rows.size == 0:
        raise gr.Error("Please select a point on top of the image.")

    center_x = int(cols.mean())
    center_y = int(rows.mean())
    return (
        (image, sam_point_inference(image, center_x, center_y, slim=True)),
        (image, sam_point_inference(image, center_x, center_y)),
    )


def infer_box(prompts):
    """Gradio handler: segment the object inside the first drawn box.

    Args:
        prompts: ImagePrompter value with keys "image" and "points"; each entry
            of "points" is indexed as [x1, y1, _, x2, y2, ...] here (format
            assumed from the original code -- TODO confirm for point clicks).

    Returns:
        ``((image, slimsam_annotations), (image, sam_annotations))``.

    Raises:
        gr.Error: If no image is uploaded or no box has been drawn.
    """
    if prompts is None or prompts["image"] is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    image = prompts["image"]

    points = prompts["points"]
    if not points:
        raise gr.Error("Please draw a box before submitting.")

    box = points[0]
    # Order the corners so boxes dragged in any direction give min <= max.
    x_min, x_max = sorted((box[0], box[3]))
    y_min, y_max = sorted((box[1], box[4]))
    return (
        (image, sam_box_inference(image, x_min, y_min, x_max, y_max, slim=True)),
        (image, sam_box_inference(image, x_min, y_min, x_max, y_max)),
    )


with gr.Blocks(title="SlimSAM") as demo:
    gr.Markdown("# SlimSAM")
    gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
    gr.Markdown("In this demo, you can compare SlimSAM and SAM outputs in point and box prompts.")

    with gr.Tab("Box Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("To try box prompting, simply upload an image and draw a box on it.")
        with gr.Row():
            with gr.Column():
                box_prompter = ImagePrompter()
                btn = gr.Button("Submit")
            with gr.Column():
                output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_box_sam = gr.AnnotatedImage(label="SAM Output")
        btn.click(
            infer_box,
            inputs=box_prompter,
            outputs=[output_box_slimsam, output_box_sam],
        )

    with gr.Tab("Point Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("To try point prompting, simply upload an image and leave a dot on it.")
        with gr.Row():
            with gr.Column():
                point_editor = gr.ImageEditor(type="pil")
            with gr.Column():
                output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_sam = gr.AnnotatedImage(label="SAM Output")
        # Re-run on every edit so the masks follow the latest point.
        point_editor.change(
            infer_point,
            inputs=point_editor,
            outputs=[output_slimsam, output_sam],
        )

demo.launch(debug=True)