|
import gradio as gr |
|
import ast |
|
import cv2 |
|
import torch |
|
import traceback |
|
import numpy as np |
|
from itertools import chain |
|
from transformers import SamModel, SamProcessor |
|
|
|
|
|
# Hugging Face checkpoint used for both the model weights and the preprocessor.
SAM_CHECKPOINT = 'facebook/sam-vit-huge'

# Prefer the GPU but fall back to CPU so the demo can still start on machines
# without CUDA (the original hard-coded 'cuda' failed at import time there).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = SamModel.from_pretrained(SAM_CHECKPOINT).to(DEVICE)

processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
|
|
|
|
|
def set_predictor(image):
    """
    Compute the SAM image embedding for ``image`` so box prompts can reuse it.

    The embedding is computed once per image here, so each later bounding-box
    request only has to run the prompt/mask decoder.

    Args:
        image: The image from the ``gr.Image`` input, or ``None`` when nothing
            has been uploaded yet.

    Returns:
        A ``[image, image_embedding, status]`` list: the image itself, the
        tensor returned by ``model.get_image_embeddings`` and the status
        string ``'Done'``.

    Raises:
        gr.Error: If no image has been provided.
    """
    if image is None:
        raise gr.Error('Upload an image before sending it.')

    # Send inputs to wherever the model actually lives instead of assuming 'cuda'.
    inputs = processor(image, return_tensors='pt').to(model.device)

    # Inference only: without no_grad the encoder's autograd graph would stay
    # referenced by the embedding kept in state, holding extra GPU memory.
    with torch.no_grad():
        image_embedding = model.get_image_embeddings(inputs['pixel_values'])

    return [image, image_embedding, 'Done']
|
|
|
|
|
def _corners_to_box(corners):
    """Return ``[x_min, y_min, x_max, y_max]`` for two opposite corners given in any order."""
    (x_a, y_a), (x_b, y_b) = corners
    return [min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)]


def _largest_contour(binary_mask):
    """Return the largest external contour of a uint8 mask, or ``None`` if the mask is empty."""
    # Indexing with [-2] selects the contour list under both the OpenCV 3
    # (3-tuple) and OpenCV 4 (2-tuple) return conventions of findContours.
    contours = cv2.findContours(
        binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    if not contours:
        return None
    # The first contour OpenCV reports is not necessarily the object (it can
    # be a stray speck), so keep the one enclosing the largest area.
    return max(contours, key=cv2.contourArea)


def get_polygon(points, image, image_embedding):
    """
    Segment the object inside a bounding box and return its outline.

    The box prompt is decoded by SAM using the embedding cached by
    ``set_predictor``; ``pixel_values`` are dropped so the image encoder is
    not re-run for every box.

    Args:
        points: Two clicked corners ``[[x1, y1], [x2, y2]]`` in pixel
            coordinates, in any order. ``[0, 0]`` marks an unset corner.
        image: The image array the embedding was computed from; presumably
            the RGB numpy array delivered by ``gr.Image``.
        image_embedding: The embedding tensor produced by ``set_predictor``.

    Returns:
        ``(polygon, overlay)``. ``polygon`` is a list of integer ``[x, y]``
        vertices outlining the largest segmented region, or ``[]`` when SAM
        produced no foreground. ``overlay`` is a uint8 array with the same
        shape as ``image`` in which that region is filled with ``(0, 255, 0)``.

    Raises:
        gr.Error: If no image/embedding has been sent yet, or a box corner
            has not been clicked.
    """
    if image is None or image_embedding is None:
        raise gr.Error('Click "Send Image" before sending a bounding box.')
    if any(list(corner) == [0, 0] for corner in points):
        raise gr.Error('Click two corners of the bounding box on the image first.')

    # SAM expects boxes as x_min, y_min, x_max, y_max regardless of click order.
    box = _corners_to_box(points)

    inputs = processor(image, input_boxes=[box], return_tensors='pt').to(model.device)

    # Reuse the cached embedding instead of re-encoding the image.
    inputs.pop('pixel_values', None)
    inputs.update({'image_embeddings': image_embedding})

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs['original_sizes'].cpu(),
        inputs['reshaped_input_sizes'].cpu(),
    )

    # One image and one box were sent, so flattening the leading dimensions
    # leaves one entry per candidate mask. Use the candidate with the highest
    # predicted IoU rather than always the first one.
    candidates = masks[0].reshape(-1, *masks[0].shape[-2:])
    best = int(outputs.iou_scores.reshape(-1).argmax())
    binary_mask = candidates[best].numpy().astype(np.uint8)

    # Always return an overlay shaped like the input image, even when empty.
    overlay = np.zeros(image.shape, dtype=np.uint8)

    contour = _largest_contour(binary_mask)
    if contour is None:
        return [], overlay

    # findContours yields int32 points of shape (N, 1, 2), which is both the
    # format fillPoly requires and easy to flatten into [x, y] pairs.
    polygon = contour.reshape(-1, 2).tolist()
    cv2.fillPoly(overlay, [contour], (0, 255, 0))

    return polygon, overlay
|
|
|
|
|
def add_bbox(bbox, evt: gr.SelectData):
    """
    Record a clicked pixel as one corner of the bounding box.

    While the first corner still holds the ``[0, 0]`` sentinel, the click is
    stored there; otherwise it overwrites the second corner. ``bbox`` is
    modified in place and returned twice, once for the state and once for
    the textbox that displays it.
    """
    clicked = [evt.index[0], evt.index[1]]
    slot = 0 if bbox[0] == [0, 0] else 1
    bbox[slot] = clicked
    return bbox, bbox
|
|
|
|
|
def clear_bbox(bbox):
    """
    Reset the bounding-box selection to its unset sentinel.

    ``bbox`` is accepted only because the event wiring passes the state in;
    its value is ignored. A fresh ``[[0, 0], [0, 0]]`` list is returned twice:
    once for the state and once for the textbox that displays it.
    """
    unset = [[0, 0] for _corner in range(2)]
    return unset, unset
|
|
|
|
|
# Gradio UI: upload an image, encode it with SAM via "Send Image", click two
# corners on the image to form a bounding box, then request the segmentation
# polygon and mask for that box.
with gr.Blocks() as demo:
    # State values passed between the event handlers wired up below.
    image = gr.State()  # image stored by set_predictor
    embedding = gr.State()  # SAM image embedding produced by set_predictor
    bbox = gr.State([[0, 0], [0, 0]])  # two clicked corners; [0, 0] marks an unset corner (see add_bbox)

    with gr.Row():
        input_image = gr.Image(label='Image')
        mask = gr.Image(label='Mask')  # receives the mask image returned by get_polygon

    with gr.Row():
        with gr.Column():
            output_status = gr.Textbox(label='Status')

        with gr.Column():
            predictor_button = gr.Button('Send Image')

    with gr.Row():
        with gr.Column():
            bbox_box = gr.Textbox(label="bbox")  # mirrors the bbox state for display

        with gr.Column():
            bbox_button = gr.Button('Clear bbox')

    with gr.Row():
        with gr.Column():
            polygon = gr.Textbox(label='Polygon')

        with gr.Column():
            points_button = gr.Button('Send bounding box')

    # Encode the uploaded image and keep the image plus its embedding in state.
    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )

    # Segment the current bbox using the stored image and embedding.
    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        [polygon, mask],
    )

    # Reset both corners and show the reset value in the bbox textbox.
    bbox_button.click(
        clear_bbox,
        bbox,
        [bbox, bbox_box],
    )

    # Each click on the image records a corner: the first click fills corner 0,
    # later clicks overwrite corner 1 (see add_bbox).
    input_image.select(
        add_bbox,
        bbox,
        [bbox, bbox_box]
    )
|
|
|
|
|
if __name__ == '__main__':
    # Start the server only when this file is executed as a script, so that
    # importing the module (for example to reuse ``demo``) has no such side effect.
    demo.launch(debug=True)