"""Gradio app exposing SAM (Segment Anything) box-prompted segmentation.

Tab 1 computes and stores the image embedding for an uploaded image.
Tab 2 takes a bounding box as text and returns a polygon traced around
the mask that SAM predicts for that box.
"""
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from transformers import SamModel, SamProcessor

# Everything runs on the CPU; the model itself is never moved off it.
DEVICE = 'cpu'

model = SamModel.from_pretrained('facebook/sam-vit-huge')
processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')


def set_predictor(image, request: gr.Request):
    """Compute the SAM image embedding for ``image``.

    Args:
        image: Image received from the ``gr.Image`` input component.
        request: Request object injected by Gradio; may be falsy.

    Returns:
        A three-item list ``[image, image_embedding, status]`` matching the
        ``[image, embedding, output_status]`` outputs of the click handler.
        ``status`` is the request's headers, query params, path params and
        URL joined with ``"|"``, or ``'done'`` when no request is available.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error('Please provide an image before sending.')

    if request:
        result = "|".join(
            str(part)
            for part in (
                request.headers,
                request.query_params,
                request.path_params,
                request.url,
            )
        )
    else:
        result = 'done'

    inputs = processor(image, return_tensors='pt').to(DEVICE)
    # Inference only: the stored embedding never needs gradients.
    with torch.no_grad():
        image_embedding = model.get_image_embeddings(inputs['pixel_values'])
    return [image, image_embedding, result]


def get_polygon(points, image, image_embedding):
    """Return the polygon outlining the SAM mask for a bounding box.

    Args:
        points: Bounding box as text, four comma-separated integers
            (e.g. ``"10,20,200,300"``).
        image: The image previously stored by ``set_predictor``.
        image_embedding: The embedding previously stored by
            ``set_predictor`` for that image.

    Returns:
        A list of ``[x, y]`` integer pairs describing the largest contour
        found in the predicted mask, or an empty list when the mask
        contains no contour.

    Raises:
        gr.Error: If no image/embedding is stored yet, or if ``points`` is
            not four comma-separated integers. Any other exception raised
            by the model or OpenCV propagates to Gradio unchanged.
    """
    if image is None or image_embedding is None:
        raise gr.Error('Send an image first so that its embedding is available.')

    try:
        box = [int(value) for value in points.split(',')]
    except ValueError as err:
        raise gr.Error('bbox must be comma-separated integers.') from err
    if len(box) != 4:
        raise gr.Error('bbox must contain exactly four integers.')

    # The processor expects boxes nested as (batch, boxes per image, 4).
    inputs = processor(image, input_boxes=[[box]], return_tensors="pt").to(DEVICE)

    # Pop the pixel_values as they are not needed: the precomputed
    # embedding is passed to the model instead.
    inputs.pop("pixel_values", None)
    inputs.update({"image_embeddings": image_embedding})

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # After squeeze() the leading axis presumably indexes SAM's candidate
    # masks for the single box (TODO confirm); the first one is used.
    candidate_masks = masks[0].squeeze().numpy()
    binary_mask = candidate_masks[0].astype(np.uint8)

    contours, _ = cv2.findContours(
        binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        # Empty mask: there is nothing to outline.
        return []

    # Use the largest contour rather than whichever one happens to be
    # listed first, so small stray regions do not replace the main outline.
    largest = max(contours, key=cv2.contourArea)
    return [[int(x), int(y)] for x, y in largest.reshape(-1, 2)]


with gr.Blocks() as app:
    # Per-session state shared between the two tabs.
    image = gr.State()
    embedding = gr.State()

    with gr.Tab('Get embedding'):
        input_image = gr.Image(label='Image')
        output_status = gr.Textbox(label='Status')
        predictor_button = gr.Button('Send Image')

    with gr.Tab('Get points'):
        bbox = gr.Textbox(label="bbox")
        polygon = [gr.Textbox(label='Polygon')]
        points_button = gr.Button('Send bounding box')

    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )

    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        polygon,
    )

app.launch(debug=True)