---
library_name: transformers
license: mit
---

# RobustSAM: Segment Anything Robustly on Degraded Images (CVPR 2024 Highlight)

# Model Card for ViT Large (ViT-L) version

[![Huggingfaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/robustsam/robustsam/tree/main)

Official repository for RobustSAM: Segment Anything Robustly on Degraded Images

[Project Page](https://robustsam.github.io/) | [Paper](https://arxiv.org/abs/2406.09627) | [Dataset](https://huggingface.co/robustsam/robustsam/tree/main/dataset)

## Introduction

Segment Anything Model (SAM) has emerged as a transformative approach in image segmentation, acclaimed for its robust zero-shot segmentation capabilities and flexible prompting system. Nonetheless, its performance is challenged by images with degraded quality. Addressing this limitation, we propose the Robust Segment Anything Model (RobustSAM), which enhances SAM's performance on low-quality images while preserving its promptability and zero-shot generalization.

Our method leverages the pre-trained SAM model with only marginal parameter increments and computational requirements. The additional parameters of RobustSAM can be optimized within 30 hours on eight GPUs, demonstrating its feasibility and practicality for typical research laboratories. We also introduce the Robust-Seg dataset, a collection of 688K image-mask pairs with different degradations designed to train and evaluate our model optimally. Extensive experiments across various segmentation tasks and datasets confirm RobustSAM's superior performance, especially under zero-shot conditions, underscoring its potential for extensive real-world application. Additionally, our method has been shown to effectively improve the performance of SAM-based downstream tasks such as single image dehazing and deblurring.

**Disclaimer**: Content from **this** model card has been written by the Hugging Face team, and parts of it were copied from the original [SAM model card](https://github.com/facebookresearch/segment-anything).

# Model Details

The RobustSAM model is made up of four modules:

- The `VisionEncoder`: a ViT-based image encoder. It computes the image embeddings using attention on patches of the image. Relative positional embedding is used.
- The `PromptEncoder`: generates embeddings for points and bounding boxes.
- The `MaskDecoder`: a two-way transformer that performs cross-attention between the image embedding and the point embeddings (->) and between the point embeddings and the image embeddings. The outputs are fed to the `Neck`.
- The `Neck`: predicts the output masks based on the contextualized masks produced by the `MaskDecoder`.

# Usage

## Prompted-Mask-Generation

```python
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, AutoModelForMaskGeneration

# use a GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the RobustSAM model and processor
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-large")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-large").to(device)

# load an image from a URL
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# define an input point (2D location of an object in the image)
input_points = [[[450, 600]]]  # example point
```

```python
# process the image and input points, on the same device as the model
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)

# generate masks using the model
with torch.no_grad():
    outputs = model(**inputs)

# rescale the predicted masks to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
scores = outputs.iou_scores
```

Among other arguments to generate masks, you can pass 2D locations on the approximate position of
your object of interest, a bounding box wrapping the object of interest (the format should be the x, y coordinates of the top-left and bottom-right corners of the bounding box), or a segmentation mask. At the time of writing, passing text as input is not supported by the official model, according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).

For more details, refer to this notebook, which shows a walkthrough of how to use the model, with a visual example!

## Automatic-Mask-Generation

The model can be used for generating segmentation masks in a "zero-shot" fashion, given an input image. The model is automatically prompted with a grid of `1024` points, which are all fed to the model.

The pipeline is made for automatic mask generation. The following snippet demonstrates how easily you can run it (on any device! Simply feed the appropriate `points_per_batch` argument):

```python
from transformers import pipeline

# initialize the pipeline for mask generation
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-large", device=0, points_per_batch=256)

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
```

Now to display the generated masks on the image:

```python
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import requests

# simple function to display the mask
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # get the height and width from the mask
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# load the image that was passed to the pipeline
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

# display the original image
plt.imshow(np.array(raw_image))
ax = plt.gca()

# loop through the masks and display each one
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")

# show the image with the masks
plt.show()
```

## Visual Comparison
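
To build a rough comparison of your own, one option is to segment the same image before and after a synthetic degradation and measure how much the masks change. The sketch below is a minimal illustration under stated assumptions: `degrade` and `mask_iou` are hypothetical helpers that are not part of RobustSAM or `transformers`, and the box blur plus Gaussian noise only stands in for real degradations; it is not the Robust-Seg degradation pipeline.

```python
import numpy as np


def degrade(image, noise_std=25.0, blur_size=5, seed=0):
    """Return a blurred, noisy copy of an HxWx3 uint8 image.

    Hypothetical helper (assumption): a box blur followed by additive
    Gaussian noise, used here only as a stand-in for real degradation.
    """
    img = image.astype(np.float32)

    # box blur: sum the shifted copies of the edge-padded image, then average
    pad = blur_size // 2
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="edge")
    blurred = np.zeros_like(img)
    for dy in range(blur_size):
        for dx in range(blur_size):
            blurred += padded[dy:dy + img.shape[0], dx:dx + img.shape[1]]
    blurred /= blur_size ** 2

    # additive Gaussian noise, seeded so the result is reproducible
    rng = np.random.default_rng(seed)
    noisy = blurred + rng.normal(0.0, noise_std, size=img.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)


def mask_iou(mask_a, mask_b):
    """Intersection-over-union of two boolean masks with the same shape."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 1.0  # both masks are empty, so treat them as identical
    return float(np.logical_and(a, b).sum() / union)
```

Assuming the `processor`, `model`, and `raw_image` from the prompted example above, the degraded input could be built with `Image.fromarray(degrade(np.array(raw_image)))` and run with the same point prompt. The masks predicted for the clean and degraded images can then be compared with `mask_iou`, or overlaid side by side with `show_mask`.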