import gradio as gr
import torch
import cv2
import traceback
import numpy as np
from transformers import SamModel, SamProcessor

# Load the SAM model and processor once at start-up; they are shared by every request.
model = SamModel.from_pretrained('facebook/sam-vit-huge')
processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')
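
# Optional sketch, not in the original script: if a GPU is available, the model could
# be moved to it, e.g. model.to('cuda' if torch.cuda.is_available() else 'cpu'), with
# the hard-coded device = 'cpu' in the functions below adjusted to match.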


def set_predictor(image):
    """
    Computes the image embedding for the given image with the SAM encoder, so that
    later box prompts can reuse it without re-encoding the image.
    """
    device = 'cpu'
    inputs = processor(image, return_tensors='pt').to(device)
    image_embedding = model.get_image_embeddings(inputs['pixel_values'])

    return [image, image_embedding, 'Done']


def get_polygon(points, image, image_embedding):
    """
    Returns the points of the polygon outlining the mask SAM predicts for a given
    bounding box, or the traceback as a string if an exception is raised.
    """
    try:
        # The bounding box arrives as a comma-separated string: "x_min,y_min,x_max,y_max".
        points = [int(w) for w in points.split(',')]

        device = 'cpu'
        # The processor expects boxes nested as [image][box][coordinates].
        inputs = processor(image, input_boxes=[[points]], return_tensors='pt').to(device)

        # Reuse the precomputed image embedding instead of the raw pixels.
        inputs.pop('pixel_values', None)
        inputs.update({'image_embeddings': image_embedding})

        with torch.no_grad():
            outputs = model(**inputs)

        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs['original_sizes'].cpu(),
            inputs['reshaped_input_sizes'].cpu(),
        )

        # Keep the first predicted mask as a binary uint8 image.
        mask = masks[0].squeeze().numpy()
        img = mask.astype(np.uint8)[0]

        # Trace the mask outline and flatten the first contour into [x, y] pairs.
        contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        points = contours[0]

        polygon = []
        for point in points:
            for x, y in point:
                polygon.append([int(x), int(y)])

        return polygon
    except Exception:
        return traceback.format_exc()
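

# Hedged usage sketch, not part of the original app: the two functions above can be
# exercised directly in Python, without the UI. The file name 'example.jpg' and the
# box coordinates below are placeholders, not values taken from this script.
#
#   from PIL import Image
#   img = np.array(Image.open('example.jpg').convert('RGB'))
#   img, emb, status = set_predictor(img)
#   poly = get_polygon('100,100,400,400', img, emb)
#   print(status, len(poly))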


with gr.Blocks() as app:
    # Per-session state shared between the two tabs: the uploaded image and its embedding.
    image = gr.State()
    embedding = gr.State()

    with gr.Tab('Get embedding'):
        input_image = gr.Image(label='Image')
        output_status = gr.Textbox(label='Status')
        predictor_button = gr.Button('Send Image')

    with gr.Tab('Get points'):
        bbox = gr.Textbox(label="bbox")
        polygon = [gr.Textbox(label='Polygon')]
        points_button = gr.Button('Send bounding box')

    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )

    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        polygon,
    )

app.launch(debug=True, share=True)
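
# Hedged sketch, not part of the original script: once the app is running it can also be
# called from another process with gradio_client. Endpoint names and the exact client
# signature depend on the Gradio version, so treat this as an illustration only.
#
#   from gradio_client import Client
#   client = Client('http://127.0.0.1:7860/')   # or the public share URL
#   print(client.view_api())                    # lists the callable endpoints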