from typing import Optional

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from gradio_image_prompter import ImagePrompter

from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
    MASK_GENERATION_MODE, BOX_PROMPT_MODE

MARKDOWN = """ # Segment Anything Model 2 🔥
Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable visual segmentation in both images and videos. **Video segmentation will be available soon.** """

# Each example row matches the `inputs` of gr.Examples below:
# [checkpoint, mode, image (URL), image prompter value].
EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-4.jpeg", None],
]

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Colour each detection by its index so neighbouring masks are distinguishable.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
# Models are loaded once at import time and shared by every request.
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)


def _process_box_prompt(checkpoint_dropdown, image_prompter_input) -> Optional[Image.Image]:
    """Segment the objects inside user-drawn boxes and overlay their masks."""
    # The prompter value can be empty if the user submits before uploading.
    if not image_prompter_input:
        return None
    image_input = image_prompter_input["image"]
    prompt = image_prompter_input["points"]
    if image_input is None or not prompt:
        # Nothing to segment: echo back whatever image we have (or None).
        return image_input

    model = IMAGE_PREDICTORS[checkpoint_dropdown]
    rgb_image = image_input.convert("RGB")
    image = np.array(rgb_image)
    # Each prompter entry has six values; only the two corner coordinate
    # pairs (indices 0, 1, 3, 4) are used as an xyxy box.
    box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])

    model.set_image(image)
    masks, _, _ = model.predict(box=box, multimask_output=False)
    # The predictor can return a 4-D (N, 1, H, W) stack (presumably when
    # several boxes are given — TODO confirm). Drop only the singleton
    # mask-channel axis; a bare np.squeeze would also collapse N == 1 or
    # any 1-pixel dimension and break mask_to_xyxy.
    if masks.ndim == 4:
        masks = np.squeeze(masks, axis=1)
    masks = masks.astype(bool)

    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks),
        mask=masks
    )
    # Annotate the same RGB copy the model saw, leaving the user's input
    # image untouched.
    return MASK_ANNOTATOR.annotate(rgb_image, detections)


def _process_mask_generation(checkpoint_dropdown, image_input) -> Optional[Image.Image]:
    """Generate masks for the whole image and overlay them."""
    if image_input is None:
        return None
    model = MASK_GENERATORS[checkpoint_dropdown]
    rgb_image = image_input.convert("RGB")
    image = np.array(rgb_image)
    result = model.generate(image)
    detections = sv.Detections.from_sam(result)
    return MASK_ANNOTATOR.annotate(rgb_image, detections)


def process(
    checkpoint_dropdown,
    mode_dropdown,
    image_input,
    image_prompter_input
) -> Optional[Image.Image]:
    """Run SAM 2 in the selected mode and return the annotated image.

    Args:
        checkpoint_dropdown: Checkpoint name, used as the key into
            IMAGE_PREDICTORS / MASK_GENERATORS.
        mode_dropdown: BOX_PROMPT_MODE or MASK_GENERATION_MODE.
        image_input: Image from the plain upload component (used in mask
            generation mode); may be None if nothing was uploaded.
        image_prompter_input: Value of the ImagePrompter component, a dict
            with "image" and "points" keys (used in box prompt mode); may be
            None.

    Returns:
        The image with masks overlaid. In box mode with no prompts, the
        prompter image is returned unannotated. Returns None when there is
        no image to process or the mode is not recognised.
    """
    if mode_dropdown == BOX_PROMPT_MODE:
        return _process_box_prompt(checkpoint_dropdown, image_prompter_input)
    if mode_dropdown == MASK_GENERATION_MODE:
        return _process_mask_generation(checkpoint_dropdown, image_input)
    return None


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            label="Checkpoint",
            info="Select a SAM2 checkpoint to use.",
            interactive=True
        )
        mode_dropdown_component = gr.Dropdown(
            choices=MODE_NAMES,
            value=MODE_NAMES[0],
            label="Mode",
            info="Select a mode to use. `box prompt` if you want to generate masks for "
                 "selected objects, `mask generation` if you want to generate masks "
                 "for the whole image.",
            interactive=True
        )
    with gr.Row():
        with gr.Column():
            # Only one of these two inputs is visible at a time; see
            # on_mode_dropdown_change below.
            image_input_component = gr.Image(
                type='pil', label='Upload image', visible=False)
            image_prompter_input_component = ImagePrompter(
                type='pil', label='Image prompt')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
    with gr.Row():
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[
                checkpoint_dropdown_component,
                mode_dropdown_component,
                image_input_component,
                image_prompter_input_component,
            ],
            outputs=[image_output_component],
            run_on_click=True
        )

    def on_mode_dropdown_change(text):
        """Show the plain upload for mask generation, the prompter for boxes."""
        return [
            gr.Image(visible=text == MASK_GENERATION_MODE),
            ImagePrompter(visible=text == BOX_PROMPT_MODE)
        ]

    mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[
            image_input_component,
            image_prompter_input_component
        ]
    )
    submit_button_component.click(
        fn=process,
        inputs=[
            checkpoint_dropdown_component,
            mode_dropdown_component,
            image_input_component,
            image_prompter_input_component,
        ],
        outputs=[image_output_component]
    )

demo.launch(debug=False, show_error=True, max_threads=1)