|
import os
import traceback
from functools import lru_cache
from itertools import chain

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from segment_anything import SamPredictor, sam_model_registry
|
|
|
|
|
# SAM configuration: checkpoint file name and the registry key used to build
# the model from it.
sam_checkpoint = "sam_vit_h.pth"
model_type = "vit_h"

# Fetch the checkpoint from the Hugging Face Hub into the current working
# directory. The access token is read from the ``model_token`` environment
# variable (``None`` if it is not set).
hf_hub_download(
    repo_id="vmoras/sam_api",
    filename=sam_checkpoint,
    token=os.environ.get('model_token'),
    local_dir="./",
)

# Run on the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_sam_model():
    """
    Build the SAM model from ``sam_checkpoint`` once and move it to ``device``.

    Building the model reads the whole checkpoint, so the result is cached and
    shared by every predictor instead of being rebuilt on each 'Send Image'.
    """
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    return sam


def set_predictor(image):
    """
    Creates a Sam predictor object for the given image.

    The SAM model itself is loaded only once (see ``_load_sam_model``); each
    call just wraps it in a new ``SamPredictor`` and computes the embedding of
    ``image`` with ``set_image``.

    Args:
        image: the image from the input ``gr.Image`` component, or ``None``
            when nothing has been uploaded yet.

    Returns:
        ``[image, predictor, status]``: the image and predictor to store in
        the app's state, plus a message for the 'Status' textbox. When
        ``image`` is ``None`` both stored values are ``None``.
    """
    if image is None:
        # The button can be pressed before an image is uploaded; report it
        # instead of letting ``set_image`` fail on ``None``.
        return [None, None, 'Please upload an image first.']

    predictor = SamPredictor(_load_sam_model())
    predictor.set_image(image)

    return [image, predictor, 'Done']
|
|
|
|
|
def get_polygon(points, image, predictor):
    """
    Returns the points of the polygon given a bounding box and a prediction
    made by Sam.

    Args:
        points: the two clicked bbox corners, ``[[x0, y0], [x1, y1]]``. The
            corners may have been clicked in either order.
        image: the image the predictor was built from; its shape is used for
            the returned display mask.
        predictor: the ``SamPredictor`` for ``image``, or ``None`` if no image
            has been sent yet.

    Returns:
        ``(polygon, mask)``.
        ``polygon`` is a list of ``[x, y]`` int pairs outlining the largest
        region SAM segmented inside the box, or ``[]`` if none was found.
        ``mask`` is a ``uint8`` array of ``image.shape`` with that region
        filled with colour ``(0, 255, 0)``, and all zeros when nothing was
        found. Both are ``[]`` and ``None`` when ``predictor`` is ``None``.
    """
    if predictor is None:
        # 'Send bounding box' was pressed before 'Send Image'.
        return [], None

    # SAM takes an XYXY box (top-left corner, then bottom-right corner), so
    # order the clicked corners regardless of the order they were clicked in.
    corners = np.asarray(points).reshape(-1, 2)
    x0, y0 = corners.min(axis=0)
    x1, y1 = corners.max(axis=0)
    input_box = np.array([x0, y0, x1, y1])

    masks, _, _ = predictor.predict(
        box=input_box[None, :],
        multimask_output=False,
    )

    binary = masks[0].astype(np.uint8)
    # Only outer boundaries are wanted for the polygon; inner holes are skipped.
    contours, _ = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    mask = np.zeros(image.shape, dtype='uint8')
    if not contours:
        return [], mask

    # The mask can contain several disjoint regions; keep the largest one
    # rather than whichever contour OpenCV happened to list first.
    contour = max(contours, key=cv2.contourArea)
    # Contours have shape (N, 1, 2); flatten them to a list of [x, y] ints.
    polygon = contour.reshape(-1, 2).tolist()

    cv2.fillPoly(mask, [contour], (0, 255, 0))

    return polygon, mask
|
|
|
|
|
def add_bbox(bbox, evt: gr.SelectData):
    """
    Record the clicked pixel as a corner of the bounding box.

    While the first corner is still the ``[0, 0]`` placeholder, the click sets
    it; every later click overwrites the second corner. ``bbox`` is modified
    in place and returned twice: once for the bbox state and once for the
    'bbox' textbox.
    """
    slot = 0 if bbox[0] == [0, 0] else 1
    bbox[slot] = [evt.index[0], evt.index[1]]
    return bbox, bbox
|
|
|
|
|
def clear_bbox(bbox):
    """
    Reset the bounding box to its unset placeholder.

    The incoming ``bbox`` is ignored and left untouched. A new
    ``[[0, 0], [0, 0]]`` list is returned twice: for the bbox state and for
    the 'bbox' textbox.
    """
    reset = [[0, 0] for _ in range(2)]
    return reset, reset
|
|
|
|
|
# Gradio interface. Event wiring:
#   'Send Image'         -> set_predictor: stores the image and its SamPredictor
#                           in the ``image`` / ``embedding`` states.
#   click on the image   -> add_bbox: records a bbox corner.
#   'Clear bbox'         -> clear_bbox: resets the bbox.
#   'Send bounding box'  -> get_polygon: shows the mask and the polygon points.
with gr.Blocks() as demo:
    # Raw string: the Markdown escape '\*' must reach the renderer unchanged
    # without Python flagging it as an invalid escape sequence.
    gr.Markdown(
    r"""
    # Instructions
    1. Upload the image and press 'Send Image'.
    2. Wait until the word 'Done' appears on the 'Status' box.
    3. Click on the image where the upper left corner of the bbox should be.
    4. Click on the image where the lower right corner of the bbox should be.
    5. Check the coordinates using the 'bbox' box.
    6. Click on 'Send bounding box'.
    7. On the right side you will see the binary mask '\*'.
    8. On the lower side you will see the points that made up the polygon '\*'.
    9. Click on 'Clear bbox' to send another bounding box and repeat the steps from the third step.
    10. Repeat steps 3 to 9 until all the segments for this image are done.
    11. Click on the right corner of the image to remove it and repeat all the steps with the next
    image.

    '\*' If the binary mask is all black and the polygon is an empty list, it means the program did
    not find any segment in the bbox. Make the bbox a little bigger if that happens.
    """)

    # Per-session state: the sent image, its SamPredictor, and the two bbox
    # corners ([0, 0] marks a corner that has not been clicked yet).
    image = gr.State()
    embedding = gr.State()
    bbox = gr.State([[0, 0], [0, 0]])

    with gr.Row():
        input_image = gr.Image(label='Image')
        mask = gr.Image(label='Mask')

    with gr.Row():
        with gr.Column():
            output_status = gr.Textbox(label='Status')
        with gr.Column():
            predictor_button = gr.Button('Send Image')

    with gr.Row():
        with gr.Column():
            bbox_box = gr.Textbox(label="bbox")
        with gr.Column():
            bbox_button = gr.Button('Clear bbox')

    with gr.Row():
        with gr.Column():
            polygon = gr.Textbox(label='Polygon')
        with gr.Column():
            points_button = gr.Button('Send bounding box')

    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )

    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        [polygon, mask],
    )

    bbox_button.click(
        clear_bbox,
        bbox,
        [bbox, bbox_box],
    )

    input_image.select(
        add_bbox,
        bbox,
        [bbox, bbox_box],
    )

# Login credentials come from the ``user`` / ``password`` environment
# variables; launching fails with KeyError if either is missing.
demo.launch(debug=True, auth=(os.environ['user'], os.environ['password']))