vmoras commited on
Commit
ff605cf
1 Parent(s): c0feeb2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import traceback
5
+ import numpy as np
6
+ from transformers import SamModel, SamProcessor
7
+
8
+
9
+ model = SamModel.from_pretrained('facebook/sam-vit-huge')
10
+ processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')
11
+
12
+
13
+ def set_predictor(image):
14
+ """
15
+ Creates a Sam predictor object based on a given image and model.
16
+ """
17
+ device = 'cpu'
18
+ inputs = processor(image, return_tensors='pt').to(device)
19
+ image_embedding = model.get_image_embeddings(inputs['pixel_values'])
20
+
21
+ return [image, image_embedding, 'Done']
22
+
23
+
24
+ def get_polygon(points, image, image_embedding):
25
+ """
26
+ Returns the points of the polygon given a bounding box and a prediction
27
+ made by Sam, or if an exception was triggered, it returns such exception.
28
+ """
29
+ points = [int(w) for w in points.split(',')]
30
+
31
+ device = 'cpu'
32
+ inputs = processor(image, input_boxes=[points], return_tensors="pt").to(device)
33
+
34
+ # pop the pixel_values as they are not neded
35
+ inputs.pop("pixel_values", None)
36
+ inputs.update({"image_embeddings": image_embedding})
37
+
38
+ with torch.no_grad():
39
+ outputs = model(**inputs)
40
+
41
+ masks = processor.image_processor.post_process_masks(
42
+ outputs.pred_masks.cpu(),
43
+ inputs["original_sizes"].cpu(),
44
+ inputs["reshaped_input_sizes"].cpu()
45
+ )
46
+
47
+ mask = masks[0].squeeze().numpy()
48
+
49
+ img = mask.astype(np.uint8)[0]
50
+
51
+ contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
52
+ points = contours[0]
53
+
54
+ polygon = []
55
+ for point in points:
56
+ for x, y in point:
57
+ polygon.append([int(x), int(y)])
58
+
59
+ return polygon
60
+
61
+
62
+
63
+ with gr.Blocks() as app:
64
+ image = gr.State()
65
+ embedding = gr.State()
66
+
67
+ with gr.Tab('Get embedding'):
68
+ input_image = gr.Image(label='Image')
69
+ output_status = gr.Textbox(label='Status')
70
+ predictor_button = gr.Button('Send Image')
71
+
72
+
73
+ with gr.Tab('Get points'):
74
+ bbox = gr.Textbox(label="bbox")
75
+ polygon = [gr.Textbox(label='Polygon')]
76
+ points_button = gr.Button('Send bounding box')
77
+
78
+
79
+ predictor_button.click(
80
+ set_predictor,
81
+ input_image,
82
+ [image, embedding, output_status],
83
+ )
84
+
85
+ points_button.click(
86
+ get_polygon,
87
+ [bbox, image, embedding],
88
+ polygon,
89
+ )
90
+
91
+ app.queue()
92
+ app.launch(debug=True)