"""Gradio demo: box-prompted segmentation with SAM, returned as a polygon.

Workflow: upload an image and press "Send Image" to compute the SAM image
embedding once; click two corners on the image to define a bounding box;
press "Send bounding box" to segment the boxed object and get its outline
polygon plus a filled overlay.
"""
import gradio as gr
import ast
import cv2
import torch
import traceback
import numpy as np
from itertools import chain
from transformers import SamModel, SamProcessor

# Inference device; the model itself is never moved, so this stays 'cpu'.
DEVICE = 'cpu'

model = SamModel.from_pretrained('facebook/sam-vit-huge')
processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')


def set_predictor(image):
    """
    Compute the SAM image embedding for ``image``.

    The embedding is computed once per image here and stored in Gradio state
    so that later box prompts only need to run the mask decoder.

    Returns:
        ``[image, image_embedding, 'Done']`` for the image state, the
        embedding state and the status textbox.

    Raises:
        gr.Error: if no image has been uploaded yet.
    """
    if image is None:
        raise gr.Error('Upload an image before sending it.')
    inputs = processor(image, return_tensors='pt').to(DEVICE)
    # Inference only: avoid building an autograd graph for the cached tensor.
    with torch.no_grad():
        image_embedding = model.get_image_embeddings(inputs['pixel_values'])
    return [image, image_embedding, 'Done']


def get_polygon(points, image, image_embedding):
    """
    Segment the object inside a bounding box and return its outline.

    Args:
        points: the two clicked corners ``[[x0, y0], [x1, y1]]`` in image
            pixel coordinates, in either order.
        image: the image previously sent through ``set_predictor``.
        image_embedding: the embedding returned by ``set_predictor``.

    Returns:
        ``(polygon, overlay)``. ``polygon`` is a list of ``[x, y]`` int
        vertices of the largest outer contour of the predicted mask.
        ``overlay`` is a uint8 array shaped like ``image`` with that polygon
        filled with (0, 255, 0). If the mask is empty, ``polygon`` is ``[]``
        and ``overlay`` is all zeros.

    Raises:
        gr.Error: if no image/embedding is available, or the box is unset or
            has zero width/height.
    """
    if image is None or image_embedding is None:
        raise gr.Error('Send an image before sending a bounding box.')

    (x0, y0), (x1, y1) = points
    # SAM expects [x_min, y_min, x_max, y_max]; corners may be clicked in
    # any order.
    box = [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)]
    # [0, 0] is the "unset" placeholder used by add_bbox / clear_bbox.
    if (x1, y1) == (0, 0) or box[0] == box[2] or box[1] == box[3]:
        raise gr.Error('Click two distinct corners to define the bounding box.')

    inputs = processor(image, input_boxes=[box], return_tensors='pt').to(DEVICE)
    # Reuse the cached embedding instead of re-encoding the pixels.
    inputs.pop('pixel_values', None)
    inputs.update({'image_embeddings': image_embedding})

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs['original_sizes'].cpu(),
        inputs['reshaped_input_sizes'].cpu(),
    )
    # SAM proposes several masks per box; keep the one with the highest
    # predicted IoU instead of an arbitrary first proposal.
    best = int(outputs.iou_scores[0, 0].argmax())
    binary = masks[0][0, best].numpy().astype(np.uint8)

    overlay = np.zeros(image.shape, dtype=np.uint8)
    # [-2] selects the contour list on both OpenCV 3 and OpenCV 4.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return [], overlay

    # The first contour is not necessarily the object; take the largest.
    outline = max(contours, key=cv2.contourArea)
    polygon = outline.reshape(-1, 2).tolist()
    # `outline` is already int32, which fillPoly requires (int64 is rejected).
    cv2.fillPoly(overlay, [outline], (0, 255, 0))
    return polygon, overlay


def add_bbox(bbox, evt: gr.SelectData):
    """
    Record a click on the image as a bounding-box corner.

    The first click fills the first corner while it is still the [0, 0]
    placeholder. Every later click overwrites the second corner until the
    box is cleared. The state list is updated in place and returned twice:
    once for the state and once for the display textbox.
    """
    corner = [evt.index[0], evt.index[1]]
    if bbox[0] == [0, 0]:
        bbox[0] = corner
    else:
        bbox[1] = corner
    return bbox, bbox


def clear_bbox(bbox):
    """Reset both corners to the [0, 0] placeholder (state and textbox)."""
    updated_bbox = [[0, 0], [0, 0]]
    return updated_bbox, updated_bbox


with gr.Blocks() as demo:
    # Per-session state: the sent image, its SAM embedding, and the box.
    image = gr.State()
    embedding = gr.State()
    bbox = gr.State([[0, 0], [0, 0]])

    with gr.Row():
        input_image = gr.Image(label='Image')
        mask = gr.Image(label='Mask')

    with gr.Row():
        with gr.Column():
            output_status = gr.Textbox(label='Status')
        with gr.Column():
            predictor_button = gr.Button('Send Image')

    with gr.Row():
        with gr.Column():
            bbox_box = gr.Textbox(label="bbox")
        with gr.Column():
            bbox_button = gr.Button('Clear bbox')

    with gr.Row():
        with gr.Column():
            polygon = gr.Textbox(label='Polygon')
        with gr.Column():
            points_button = gr.Button('Send bounding box')

    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )

    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        [polygon, mask],
    )

    bbox_button.click(
        clear_bbox,
        bbox,
        [bbox, bbox_box],
    )

    # Each click on the input image sets a bounding-box corner.
    input_image.select(
        add_bbox,
        bbox,
        [bbox, bbox_box]
    )

demo.launch(debug=True)