vmoras committed on
Commit b808c95
1 Parent(s): a86e443

Update app.py

Files changed (1)
  1. app.py +70 -21
app.py CHANGED
@@ -1,12 +1,14 @@
 import gradio as gr
-import torch
+import ast
 import cv2
+import torch
 import traceback
 import numpy as np
+from itertools import chain
 from transformers import SamModel, SamProcessor


-model = SamModel.from_pretrained('facebook/sam-vit-huge').to('cpu')
+model = SamModel.from_pretrained('facebook/sam-vit-huge').to('cuda')
 processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')


@@ -14,7 +16,8 @@ def set_predictor(image):
     """
     Creates a Sam predictor object based on a given image and model.
     """
-    inputs = processor(image, return_tensors='pt').to('cpu')
+    device = 'cuda'
+    inputs = processor(image, return_tensors='pt').to(device)
     image_embedding = model.get_image_embeddings(inputs['pixel_values'])

     return [image, image_embedding, 'Done']
@@ -23,11 +26,14 @@ def set_predictor(image):
 def get_polygon(points, image, image_embedding):
     """
     Returns the points of the polygon given a bounding box and a prediction
-    made by Sam, or if an exception was triggered, it returns such exception.
+    made by Sam.
     """
-    points = [int(w) for w in points.split(',')]
+    # points = [int(w) for w in points.split(',')]
+    points = list(chain.from_iterable(points))
+    print(points)

-    inputs = processor(image, input_boxes=[points], return_tensors="pt").to('cpu')
+    device = 'cuda'
+    inputs = processor(image, input_boxes=[points], return_tensors="pt").to(device)

     # pop the pixel_values as they are not needed
     inputs.pop("pixel_values", None)
@@ -43,39 +49,69 @@ def get_polygon(points, image, image_embedding):
     )

     mask = masks[0].squeeze().numpy()
-
     img = mask.astype(np.uint8)[0]
-
     contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

     if len(contours) == 0:
-        return [0], img
-
+        return [], img
+
     points = contours[0]
+
     polygon = []
     for point in points:
         for x, y in point:
             polygon.append([int(x), int(y)])

-    return polygon
+    mask = np.zeros(image.shape, dtype='uint8')
+    poly = np.array(polygon)
+    cv2.fillPoly(mask, [poly], (0, 255, 0))
+
+    return polygon, mask
+
+
+def add_bbox(bbox, evt: gr.SelectData):
+    if bbox[0] == [0, 0]:
+        bbox[0] = [evt.index[0], evt.index[1]]
+        return bbox, bbox
+
+    bbox[1] = [evt.index[0], evt.index[1]]
+    return bbox, bbox


+def clear_bbox(bbox):
+    updated_bbox = [[0, 0], [0, 0]]
+    return updated_bbox, updated_bbox

-with gr.Blocks() as app:
+
+with gr.Blocks() as demo:
     image = gr.State()
     embedding = gr.State()
+    bbox = gr.State([[0, 0], [0, 0]])

-    with gr.Tab('Get embedding'):
+    with gr.Row():
         input_image = gr.Image(label='Image')
-        output_status = gr.Textbox(label='Status')
-        predictor_button = gr.Button('Send Image')
+        mask = gr.Image(label='Mask')

-    with gr.Tab('Get points'):
-        bbox = gr.Textbox(label="bbox")
-        polygon = [gr.Textbox(label='Polygon')]
-        mask = gr.Image(label='Mask')
-        points_button = gr.Button('Send bounding box')
+    with gr.Row():
+        with gr.Column():
+            output_status = gr.Textbox(label='Status')
+
+        with gr.Column():
+            predictor_button = gr.Button('Send Image')
+
+    with gr.Row():
+        with gr.Column():
+            bbox_box = gr.Textbox(label="bbox")
+
+        with gr.Column():
+            bbox_button = gr.Button('Clear bbox')
+
+    with gr.Row():
+        with gr.Column():
+            polygon = gr.Textbox(label='Polygon')
+
+        with gr.Column():
+            points_button = gr.Button('Send bounding box')


     predictor_button.click(
@@ -90,4 +126,17 @@ with gr.Blocks() as app:
         [polygon, mask],
     )

-app.launch(debug=True)
+    bbox_button.click(
+        clear_bbox,
+        bbox,
+        [bbox, bbox_box],
+    )
+
+    input_image.select(
+        add_bbox,
+        bbox,
+        [bbox, bbox_box]
+    )
+
+
+demo.launch(debug=True)
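
For reference, the box-prompted inference in get_polygon follows the standard transformers SAM recipe; the model call and mask post-processing sit in an unchanged stretch that the diff does not show, so the sketch below fills that part in from the documented SamModel/SamProcessor usage. A minimal standalone version; the image path, box coordinates, and the no_grad wrapper are illustrative assumptions, not part of the commit:

# Minimal sketch of box-prompted SAM inference with transformers.
# Assumptions (not from the commit): a CUDA device is available, and
# 'example.png' plus the box coordinates are placeholders.
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

device = 'cuda'
model = SamModel.from_pretrained('facebook/sam-vit-huge').to(device)
processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')

image = Image.open('example.png').convert('RGB')
box = [100, 100, 400, 400]  # x0, y0, x1, y1

# Embed the image once, as set_predictor does.
inputs = processor(image, return_tensors='pt').to(device)
image_embedding = model.get_image_embeddings(inputs['pixel_values'])

# Prompt with the bounding box, reusing the cached embedding.
inputs = processor(image, input_boxes=[[box]], return_tensors='pt').to(device)
inputs.pop('pixel_values', None)  # not needed once embeddings are supplied
inputs.update({'image_embeddings': image_embedding})

with torch.no_grad():
    outputs = model(**inputs)

# Resize the low-resolution predicted masks back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs['original_sizes'].cpu(),
    inputs['reshaped_input_sizes'].cpu(),
)
mask = masks[0].squeeze().numpy().astype(np.uint8)[0]  # first proposed mask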
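
The new two-click box selection hangs off input_image.select: Gradio invokes the handler with a gr.SelectData payload whose .index is the clicked (x, y) pixel, so the first click records one corner and the second click the other. A stripped-down sketch of that flow, with placeholder component names:

# Stripped-down sketch of the two-click bounding-box flow.
# Names here are placeholders; behavior mirrors add_bbox/clear_bbox above.
import gradio as gr

def add_point(bbox, evt: gr.SelectData):
    # First click sets the top-left corner, the second the bottom-right.
    if bbox[0] == [0, 0]:
        bbox[0] = [evt.index[0], evt.index[1]]
    else:
        bbox[1] = [evt.index[0], evt.index[1]]
    return bbox, str(bbox)

def reset(bbox):
    cleared = [[0, 0], [0, 0]]
    return cleared, str(cleared)

with gr.Blocks() as demo:
    bbox = gr.State([[0, 0], [0, 0]])
    img = gr.Image(label='Image')
    bbox_text = gr.Textbox(label='bbox')
    clear_btn = gr.Button('Clear bbox')

    # .select fires on a click inside the image; evt.index is the pixel hit.
    img.select(add_point, bbox, [bbox, bbox_text])
    clear_btn.click(reset, bbox, [bbox, bbox_text])

demo.launch()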