"""Gradio demo that segments an image with Segment Anything (SAM).

Masks produced by ``SamAutomaticMaskGenerator`` are filtered by their
``predicted_iou`` and ``stability_score`` values and drawn over the image.
"""
import os
from functools import lru_cache
from random import randint
from typing import Dict, List

import cv2
import gradio as gr
import numpy as np
import PIL
import PIL.Image  # a bare ``import PIL`` does not load the ``Image`` submodule
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"
MAX_WIDTH = MAX_HEIGHT = 800
THRESHOLD = 0.05
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example images shipped next to this file (paths relative to its directory).
_EXAMPLE_IMAGES = (
    "examples/dog.jpg",
    "examples/city.jpg",
    "examples/food.jpg",
    "examples/horse.jpg",
)


@lru_cache
def load_mask_generator(model_size: str = "large") -> SamAutomaticMaskGenerator:
    """Build the SAM automatic mask generator; repeated calls reuse the cached one.

    Args:
        model_size: Currently unused. The model is always picked by
            ``MODEL_TYPE`` and loaded from ``CHECKPOINT_PATH``.

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the model placed on ``device``.
    """
    sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
    return SamAutomaticMaskGenerator(sam)


def adjust_image_size(image: np.ndarray) -> np.ndarray:
    """Shrink ``image`` so its longer side fits ``MAX_HEIGHT`` / ``MAX_WIDTH``.

    The aspect ratio is kept (up to integer truncation). Images that already
    fit are returned unchanged, without a redundant resize.

    Args:
        image: Array whose first two dimensions are (height, width).

    Returns:
        The resized image, or ``image`` itself when no resize is needed.
    """
    height, width = image.shape[:2]
    new_height, new_width = height, width
    if height > width:
        if height > MAX_HEIGHT:
            new_height, new_width = MAX_HEIGHT, int(MAX_HEIGHT / height * width)
    elif width > MAX_WIDTH:
        new_height, new_width = int(MAX_WIDTH / width * height), MAX_WIDTH

    if (new_height, new_width) == (height, width):
        return image

    # Very elongated images could truncate the short side to 0 pixels, which
    # cv2.resize rejects; keep at least one pixel.
    new_height, new_width = max(new_height, 1), max(new_width, 1)
    return cv2.resize(image, (new_width, new_height))


def filter_masks(
    masks: List[Dict[str, np.ndarray]],
    predicted_iou_threshold: float,
    stability_score_threshold: float,
    query: str,
    clip_threshold: float,
) -> List[np.ndarray]:
    """Keep the masks whose quality scores reach both thresholds.

    Args:
        masks: Mask records with ``"predicted_iou"``, ``"stability_score"``
            and ``"segmentation"`` entries.
        predicted_iou_threshold: Masks with a lower ``predicted_iou`` are dropped.
        stability_score_threshold: Masks with a lower ``stability_score`` are dropped.
        query: Currently unused.
        clip_threshold: Currently unused.

    Returns:
        The ``"segmentation"`` arrays of the surviving masks, in input order.
    """
    return [
        mask["segmentation"]
        for mask in masks
        if not (
            mask["predicted_iou"] < predicted_iou_threshold
            or mask["stability_score"] < stability_score_threshold
        )
    ]


def draw_masks(
    image: np.ndarray, masks: List[np.ndarray], alpha: float = 0.7
) -> np.ndarray:
    """Blend each mask over ``image`` in a random bright color and outline it.

    Args:
        image: Image to draw on; it is not modified in place.
        masks: Per-pixel masks with the same height and width as ``image``.
        alpha: Weight of the colored overlay in the blend.

    Returns:
        A new image with all masks drawn, or ``image`` itself if ``masks``
        is empty.
    """
    for mask in masks:
        color = [randint(127, 255) for _ in range(3)]

        # draw mask overlay: paint the masked pixels, then blend with the image
        image_overlay = image.copy()
        image_overlay[mask.astype(bool)] = color
        image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)

        # draw contour
        contours, _ = cv2.findContours(
            np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
    return image


def segment(
    predicted_iou_threshold: float,
    stability_score_threshold: float,
    clip_threshold: float,
    image_path: str,
    query: str,
) -> PIL.Image.Image:
    """Segment the image at ``image_path`` and return it with masks drawn.

    Args:
        predicted_iou_threshold: Minimum ``predicted_iou`` for a mask to be kept.
        stability_score_threshold: Minimum ``stability_score`` for a mask to be kept.
        clip_threshold: Forwarded to ``filter_masks`` (currently unused there).
        image_path: Path of the input image file.
        query: Forwarded to ``filter_masks`` (currently unused there).

    Returns:
        An RGB ``PIL.Image.Image`` with the filtered masks overlaid.

    Raises:
        gr.Error: If no image was provided or the file cannot be decoded.
    """
    if not image_path:
        raise gr.Error("Please provide an input image.")
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise gr.Error(f"Could not read image: {image_path}")

    mask_generator = load_mask_generator()
    # reduce the size to save gpu memory
    image = adjust_image_size(image)
    # cv2.imread yields BGR while SAM assumes RGB input, so convert for the
    # model only; drawing below still happens on the BGR image.
    masks = mask_generator.generate(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    masks = filter_masks(
        masks, predicted_iou_threshold, stability_score_threshold, query, clip_threshold
    )
    image = draw_masks(image, masks)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(np.uint8(image)).convert("RGB")


demo = gr.Interface(
    fn=segment,
    inputs=[
        gr.Slider(0, 1, value=0.9, label="predicted_iou_threshold"),
        gr.Slider(0, 1, value=0.8, label="stability_score_threshold"),
        gr.Slider(0, 1, value=0.05, label="clip_threshold"),
        gr.Image(type="filepath"),
        "text",
    ],
    outputs="image",
    allow_flagging="never",
    title="Segment Anything with CLIP",
    # One example row per bundled image, all with the default slider values
    # and an empty query.
    examples=[
        [0.9, 0.8, 0.05, os.path.join(os.path.dirname(__file__), name), ""]
        for name in _EXAMPLE_IMAGES
    ],
)

if __name__ == "__main__":
    demo.launch()