"""Gradio demo: run Segment Anything's automatic mask generator on an image.

The uploaded image is downscaled, segmented, and returned with every mask
drawn as a translucent coloured overlay plus an outline.
"""

import os
from functools import lru_cache
from random import randint
from typing import Any, Dict, List

import cv2
import gradio as gr
import numpy as np
import PIL
import PIL.Image  # explicit: ``import PIL`` alone does not load submodules
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"
# Longest allowed side (in pixels) of the image handed to the model.
MAX_WIDTH = MAX_HEIGHT = 800

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@lru_cache
def load_mask_generator(model_size: str = "large") -> SamAutomaticMaskGenerator:
    """Build the SAM automatic mask generator, cached across calls.

    Args:
        model_size: Currently unused by the body; it only acts as the cache
            key. The model is always selected by ``MODEL_TYPE`` and loaded
            from ``CHECKPOINT_PATH``.

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the model moved to ``device``.
    """
    sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
    return SamAutomaticMaskGenerator(sam)


def adjust_image_size(image: np.ndarray) -> np.ndarray:
    """Downscale ``image`` so its longer side fits MAX_HEIGHT / MAX_WIDTH.

    The aspect ratio is preserved (up to integer truncation). Images that
    already fit are returned unchanged, without copying.

    Args:
        image: Array whose first two axes are (height, width).

    Returns:
        The resized image, or ``image`` itself when no resize is needed.
    """
    height, width = image.shape[:2]
    new_height, new_width = height, width
    if height > width:
        if height > MAX_HEIGHT:
            new_height = MAX_HEIGHT
            # Clamp to 1 so extreme aspect ratios never yield a zero-sized
            # dimension, which cv2.resize rejects.
            new_width = max(1, int(MAX_HEIGHT / height * width))
    elif width > MAX_WIDTH:
        new_height = max(1, int(MAX_WIDTH / width * height))
        new_width = MAX_WIDTH

    if (new_height, new_width) == (height, width):
        return image
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, (new_width, new_height))


def draw_masks(
    image: np.ndarray, masks: List[Dict[str, Any]], alpha: float = 0.7
) -> np.ndarray:
    """Overlay every mask on ``image`` with a random colour and an outline.

    Args:
        image: uint8 image of shape (H, W, 3) to draw on.
        masks: Mask records; each must provide a ``"segmentation"`` entry
            holding an (H, W) binary mask.
        alpha: Opacity of the colour overlay inside each mask.

    Returns:
        The annotated image. When ``masks`` is empty the input array itself
        is returned; otherwise a new array is returned and the input is left
        untouched.
    """
    for mask in masks:
        color = [randint(127, 255) for _ in range(3)]
        segmentation = np.asarray(mask["segmentation"], dtype=bool)

        # Paint the masked pixels on a copy, then blend. Pixels outside the
        # mask are identical in both images, so blending leaves them as-is.
        overlay = image.copy()
        overlay[segmentation] = color
        image = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)

        # findContours returns (contours, hierarchy) on OpenCV 4 and
        # (image, contours, hierarchy) on OpenCV 3; [-2] works for both.
        contours = cv2.findContours(
            np.uint8(segmentation), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
    return image


def segment(image_path: str, query: str) -> PIL.Image.Image:
    """Segment the image at ``image_path`` and return it with masks drawn.

    Args:
        image_path: Filesystem path of the input image.
        query: Text from the UI; not used by the current implementation.

    Returns:
        An RGB PIL image with all generated masks overlaid.

    Raises:
        gr.Error: If no image was supplied or the file cannot be decoded.
    """
    if not image_path:
        raise gr.Error("Please provide an image to segment.")
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise gr.Error("Could not read the provided image.")

    # reduce the size to save gpu memory
    image_bgr = adjust_image_size(image_bgr)

    mask_generator = load_mask_generator()
    # OpenCV decodes to BGR, but SAM expects RGB input.
    masks = mask_generator.generate(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

    # Draw in BGR space (contour colour is specified in BGR), then convert
    # once for PIL.
    annotated = draw_masks(image_bgr, masks)
    annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(np.uint8(annotated)).convert("RGB")


demo = gr.Interface(
    fn=segment,
    inputs=[gr.Image(type="filepath"), "text"],
    outputs="image",
    allow_flagging="never",
    title="Segment Anything with CLIP",
    # Each example is [image path, query]; the query is left empty.
    examples=[
        [os.path.join(os.path.dirname(__file__), relative_path), ""]
        for relative_path in (
            "examples/dog.jpg",
            "examples/city.jpg",
            "examples/food.jpg",
            "examples/horse.jpg",
        )
    ],
)

if __name__ == "__main__":
    demo.launch()