from typing import List

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel

MARKDOWN = """
# Segment Anything Model + MetaCLIP

This is a demo of Open Vocabulary Image Segmentation using the
[Segment Anything Model](https://github.com/facebookresearch/segment-anything) and
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP) combo.
"""

# Each row: [image URL, text prompt, CLIP confidence threshold].
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
]

# Masks whose area is not strictly greater than this fraction of the image
# area are dropped before the (per-mask, comparatively expensive) CLIP step.
MIN_AREA_THRESHOLD = 0.01

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)
CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX)
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX,
    opacity=1)


def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    """Generate class-agnostic masks for the whole image with SAM.

    Args:
        image_rgb_pil: Input RGB image.

    Returns:
        Detections whose ``mask`` holds the masks returned by the
        mask-generation pipeline and whose ``xyxy`` boxes are computed from
        those masks.
    """
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    # NOTE(review): presumably a list of (H, W) boolean masks, stacked here to
    # (N, H, W) — confirm against the pipeline version in use.
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)


def run_clip(image_rgb_pil: Image.Image, text: List[str]) -> np.ndarray:
    """Score a single image against a list of text prompts with CLIP.

    Args:
        image_rgb_pil: Image to score.
        text: Candidate text prompts.

    Returns:
        Softmax probabilities over ``text`` as a NumPy array with one row per
        image (one image here) and one column per prompt.
    """
    inputs = CLIP_PROCESSOR(
        text=text,
        images=image_rgb_pil,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)
    # Inference only: this runs once per candidate mask, so skip building an
    # autograd graph for every call.
    with torch.no_grad():
        outputs = CLIP_MODEL(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)
    return probs.cpu().numpy()


def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    """Keep ``image`` pixels where ``mask`` is set and paint the rest gray.

    Args:
        image: Image array; assumed (H, W, 3) uint8 — TODO confirm.
        mask: Mask with the same height/width as ``image``; truthy pixels
            are kept.
        gray_value: Intensity used for all three channels of the fill color.

    Returns:
        A new array where pixels outside the mask are
        ``(gray_value, gray_value, gray_value)``.
    """
    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
    # mask[..., None] adds a trailing axis so the mask broadcasts over channels.
    return np.where(mask[..., None], image, gray_color)


def annotate(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    annotator: sv.MaskAnnotator
) -> Image.Image:
    """Draw ``detections`` onto a copy of ``image_rgb_pil`` with ``annotator``.

    The image channels are reversed (RGB -> BGR) before annotation and
    reversed back afterwards. NOTE(review): presumably because supervision
    annotators follow the OpenCV BGR convention — confirm.

    Returns:
        The annotated image as an RGB PIL image.
    """
    img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
    annotated_bgr_image = annotator.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(annotated_bgr_image[:, :, ::-1])


def filter_detections(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    prompt: str,
    confidence: float
) -> sv.Detections:
    """Keep only the detections that CLIP associates with ``prompt``.

    Each detection is cropped to its bounding box, pixels outside its mask are
    grayed out, and CLIP compares the crop against ``prompt`` versus a generic
    "background" prompt. A detection is kept when the prompt's probability is
    strictly greater than ``confidence``.

    Returns:
        The subset of ``detections`` that passed the filter (possibly empty).
    """
    img_rgb_numpy = np.array(image_rgb_pil)
    text = [f"a picture of {prompt}", "a picture of background"]
    filtering_mask = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        masked_crop_pil = Image.fromarray(masked_crop)
        probs = run_clip(image_rgb_pil=masked_crop_pil, text=text)
        # Row 0 = the single image, column 0 = the user prompt.
        filtering_mask.append(probs[0][0] > confidence)
    # dtype=bool: with zero detections, np.array([]) would be float64 and
    # indexing Detections with it raises IndexError.
    filtering_mask = np.array(filtering_mask, dtype=bool)
    return detections[filtering_mask]


def inference(
    image_rgb_pil: Image.Image,
    prompt: str,
    confidence: float
) -> List[Image.Image]:
    """Segment everything matching ``prompt`` in the image.

    Pipeline: SAM masks -> drop masks with area fraction <= MIN_AREA_THRESHOLD
    -> CLIP prompt filter.

    Returns:
        Two images: the input with semi-transparent red masks, and the same
        masks drawn solid red on a black canvas.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_rgb_pil is None:
        raise gr.Error("Please upload an image.")
    width, height = image_rgb_pil.size
    area = width * height

    detections = run_sam(image_rgb_pil)
    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
        prompt=prompt,
        confidence=confidence)

    blank_image = Image.new("RGB", (width, height), "black")
    return [
        annotate(
            image_rgb_pil=image_rgb_pil,
            detections=detections,
            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
        annotate(
            image_rgb_pil=blank_image,
            detections=detections,
            annotator=SOLID_MASK_ANNOTATOR)
    ]


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                image_mode='RGB', type='pil', height=500)
            prompt_text = gr.Textbox(
                label="Prompt", value="dog")
            confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.5,
                maximum=1.0,
                step=0.05,
                value=0.6)
            submit_button = gr.Button("Submit")
        gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[input_image, prompt_text, confidence_slider],
            outputs=[gallery],
            cache_examples=True,
            run_on_click=True
        )

    submit_button.click(
        inference,
        inputs=[input_image, prompt_text, confidence_slider],
        outputs=gallery)

demo.launch(debug=False, show_error=True)