from typing import Any, List

import gradio as gr

from grounded_sam.inference import grounded_segmentation
from grounded_sam.plot import plot_detections, plot_detections_plotly


def _parse_labels(text: str) -> List[str]:
    """Turn the newline-separated label textbox content into a list of labels.

    Each non-blank line becomes one label, stripped of surrounding whitespace
    and terminated with a period (the same ``"."`` suffix the demo has always
    appended; presumably Grounding DINO's phrase separator). Blank or
    whitespace-only lines are skipped so they never become a bare ``"."``.
    """
    parsed = []
    for line in (text or "").splitlines():
        label = line.strip()
        if not label:
            continue
        parsed.append(label if label.endswith(".") else label + ".")
    return parsed


def app_fn(
    image: Any,
    labels: str,
    threshold: float,
    bounding_box_selection: bool,
) -> Any:
    """Run Grounded SAM on *image* and return a figure of the detections.

    Args:
        image: The uploaded image as delivered by ``gr.Image(type="pil")``;
            ``None`` when the user has not uploaded anything.
        labels: Newline-separated text labels to detect.
        threshold: Box threshold, forwarded to ``grounded_segmentation``.
        bounding_box_selection: Value of the "Allow Box Selection" checkbox,
            forwarded to ``plot_detections_plotly``.

    Returns:
        The figure returned by ``plot_detections_plotly``, rendered by the
        ``gr.Plot`` output component.

    Raises:
        gr.Error: If no image was provided or *labels* has no non-blank line;
            Gradio shows the message to the user instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    parsed_labels = _parse_labels(labels)
    if not parsed_labels:
        raise gr.Error("Please enter at least one label.")

    # The trailing positional ``True`` is kept from the original demo;
    # presumably it enables the polygon-based mask refinement described in the
    # UI notes -- confirm against grounded_segmentation's signature.
    image_array, detections = grounded_segmentation(image, parsed_labels, threshold, True)
    return plot_detections_plotly(image_array, detections, bounding_box_selection)


if __name__ == "__main__":
    title = "Grounding SAM - Text-to-Segmentation Model"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            Grounded SAM is a text-to-segmentation model that generates segmentation masks from natural language descriptions.
            This demo uses Grounding DINO in tandem with SAM to generate segmentation masks from text.
            The workflow is as follows:

            1. Select text labels to generate bounding boxes with Grounding DINO.
            2. Prompt the SAM model to generate segmentation masks from the bounding boxes.
            3. Refine the masks if needed.
            4. Visualize the segmentation masks.

            ### Notes

            - To pass multiple labels, separate them by a new line.
            - The model may take a few seconds to generate the segmentation masks as we need to run through two models.
            - The refinement is done by default by converting the mask to a polygon and back to a mask with openCV.
            - I use in here a concise implementation, but you can find the full code at [GitHub](https://github.com/EduardoPach/grounded-sam)
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="Box Threshold")
            labels = gr.Textbox(lines=2, max_lines=5, label="Labels")
            bounding_box_selection = gr.Checkbox(label="Allow Box Selection")
            btn = gr.Button()
        with gr.Row():
            img = gr.Image(type="pil")
            fig = gr.Plot(label="Segmentation Mask")

        btn.click(fn=app_fn, inputs=[img, labels, threshold, bounding_box_selection], outputs=[fig])
        gr.Examples(
            [
                ["input_image.jpeg", "a person.\na mountain.", 0.3, False],
            ],
            inputs=[img, labels, threshold, bounding_box_selection],
            outputs=[fig],
            fn=app_fn,
            cache_examples=True,
            label="Try this example input!",
        )
    demo.launch()