SAM_test / app.py
vmoras's picture
Update app.py
4076f7d
raw
history blame
2.34 kB
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from transformers import SamModel, SamProcessor
# Hugging Face checkpoint used for both the SAM model and its matching processor.
SAM_CHECKPOINT = 'facebook/sam-vit-huge'

# Loaded once at module import; set_predictor and get_polygon use these globals.
processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
model = SamModel.from_pretrained(SAM_CHECKPOINT)
def set_predictor(image):
    """
    Compute the SAM image embedding for an uploaded image.

    The embedding is returned so the UI can keep it in session state;
    get_polygon later feeds it to the model in place of pixel_values, so the
    image encoder does not have to run again for each bounding box.

    Args:
        image: The image from the input component, or None when the button
            is clicked before anything was uploaded.

    Returns:
        A list ``[image, image_embedding, status]`` matching the outputs
        (image state, embedding state, status textbox). When ``image`` is
        None, both states are cleared and the status reports the problem
        instead of raising inside the processor.
    """
    if image is None:
        return [None, None, 'No image provided']
    device = 'cpu'
    inputs = processor(image, return_tensors='pt').to(device)
    image_embedding = model.get_image_embeddings(inputs['pixel_values'])
    return [image, image_embedding, 'Done']
def _mask_to_polygon(mask):
    """Return the largest external contour of a uint8 binary mask as [[x, y], ...] ([] if none)."""
    # [-2] selects the contours in both OpenCV 3 (3 return values) and 4 (2 return values).
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return []
    # Several disconnected blobs are possible; keep the dominant one.
    largest = max(contours, key=cv2.contourArea)
    # OpenCV contours have shape (N, 1, 2); flatten to N (x, y) pairs.
    return [[int(x), int(y)] for x, y in largest.reshape(-1, 2)]


def get_polygon(points, image, image_embedding):
    """
    Segment the object inside a bounding box and return its outline polygon.

    Args:
        points: Bounding box as a comma-separated string "x1,y1,x2,y2"
            (presumably in the input image's pixel coordinates).
        image: The image previously passed to set_predictor.
        image_embedding: The embedding set_predictor computed for ``image``.

    Returns:
        A list of ``[x, y]`` integer vertices tracing the largest external
        contour of the best-scoring predicted mask, or ``[]`` if the mask is
        empty. If any step raises (malformed bbox, missing image or
        embedding, model error), the formatted traceback string is returned
        instead so it is shown in the output textbox.
    """
    try:
        box = [int(value) for value in points.split(',')]
        if len(box) != 4:
            raise ValueError(
                f'Expected 4 comma-separated integers "x1,y1,x2,y2", got {len(box)}'
            )
        device = 'cpu'
        # input_boxes nesting is [image][box][x1, y1, x2, y2]: one image, one box.
        inputs = processor(image, input_boxes=[[box]], return_tensors='pt').to(device)
        # pixel_values are not needed: the precomputed embedding replaces them.
        inputs.pop('pixel_values', None)
        inputs.update({'image_embeddings': image_embedding})
        with torch.no_grad():
            outputs = model(**inputs)
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs['original_sizes'].cpu(),
            inputs['reshaped_input_sizes'].cpu(),
        )
        # SAM proposes several candidate masks per box; keep the one with the
        # highest predicted IoU rather than always the first candidate.
        best = int(outputs.iou_scores[0, 0].argmax())
        # masks[0] is for our single image; index the single box, then the chosen candidate.
        mask = masks[0][0, best].numpy().astype(np.uint8)
        return _mask_to_polygon(mask)
    except Exception:
        # UI boundary: surface the error in the output textbox as documented.
        return traceback.format_exc()
with gr.Blocks() as app:
    # Per-session state shared by both tabs: the uploaded image and its SAM
    # embedding, both written by set_predictor and read by get_polygon.
    image = gr.State()
    embedding = gr.State()
    with gr.Tab('Get embedding'):
        input_image = gr.Image(label='Image')
        output_status = gr.Textbox(label='Status')
        predictor_button = gr.Button('Send Image')
    with gr.Tab('Get points'):
        bbox = gr.Textbox(label="bbox", placeholder='x1,y1,x2,y2')
        polygon = [gr.Textbox(label='Polygon')]
        points_button = gr.Button('Send bounding box')

    # Event listeners must be registered inside the Blocks context.
    predictor_button.click(
        fn=set_predictor,
        inputs=input_image,
        outputs=[image, embedding, output_status],
    )
    points_button.click(
        fn=get_polygon,
        inputs=[bbox, image, embedding],
        outputs=polygon,
    )

# Enable request queuing on the app object; launching is left to the script entry point.
app.queue()

if __name__ == '__main__':
    # Only start the server when run as a script, not when imported.
    app.launch(debug=True)