vmoras commited on
Commit
0ad7527
1 Parent(s): 5c7957c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  from transformers import SamModel, SamProcessor
7
 
8
 
9
- model = SamModel.from_pretrained('facebook/sam-vit-huge')
10
  processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')
11
 
12
 
@@ -14,8 +14,7 @@ def set_predictor(image):
14
  """
15
  Creates a Sam predictor object based on a given image and model.
16
  """
17
- device = 'cpu'
18
- inputs = processor(image, return_tensors='pt').to(device)
19
  image_embedding = model.get_image_embeddings(inputs['pixel_values'])
20
 
21
  return [image, image_embedding, 'Done']
@@ -28,8 +27,7 @@ def get_polygon(points, image, image_embedding):
28
  """
29
  points = [int(w) for w in points.split(',')]
30
 
31
- device = 'cpu'
32
- inputs = processor(image, input_boxes=[points], return_tensors="pt").to(device)
33
 
34
  # pop the pixel_values as they are not neded
35
  inputs.pop("pixel_values", None)
@@ -49,8 +47,11 @@ def get_polygon(points, image, image_embedding):
49
  img = mask.astype(np.uint8)[0]
50
 
51
  contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
52
- points = contours[0]
53
 
 
 
 
 
54
  polygon = []
55
  for point in points:
56
  for x, y in point:
@@ -73,6 +74,7 @@ with gr.Blocks() as app:
73
  with gr.Tab('Get points'):
74
  bbox = gr.Textbox(label="bbox")
75
  polygon = [gr.Textbox(label='Polygon')]
 
76
  points_button = gr.Button('Send bounding box')
77
 
78
 
@@ -85,7 +87,7 @@ with gr.Blocks() as app:
85
  points_button.click(
86
  get_polygon,
87
  [bbox, image, embedding],
88
- polygon,
89
  )
90
 
91
  app.launch(debug=True)
 
6
  from transformers import SamModel, SamProcessor
7
 
8
 
9
+ model = SamModel.from_pretrained('facebook/sam-vit-huge').to('cuda')
10
  processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')
11
 
12
 
 
14
  """
15
  Creates a Sam predictor object based on a given image and model.
16
  """
17
+ inputs = processor(image, return_tensors='pt').to('cuda')
 
18
  image_embedding = model.get_image_embeddings(inputs['pixel_values'])
19
 
20
  return [image, image_embedding, 'Done']
 
27
  """
28
  points = [int(w) for w in points.split(',')]
29
 
30
+ inputs = processor(image, input_boxes=[points], return_tensors="pt").to('cuda')
 
31
 
32
  # pop the pixel_values as they are not neded
33
  inputs.pop("pixel_values", None)
 
47
  img = mask.astype(np.uint8)[0]
48
 
49
  contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
 
50
 
51
+ if len(contours) == 0:
52
+ return [0], img
53
+
54
+ points = contours[0]
55
  polygon = []
56
  for point in points:
57
  for x, y in point:
 
74
  with gr.Tab('Get points'):
75
  bbox = gr.Textbox(label="bbox")
76
  polygon = [gr.Textbox(label='Polygon')]
77
+ mask = gr.Image(label='Mask')
78
  points_button = gr.Button('Send bounding box')
79
 
80
 
 
87
  points_button.click(
88
  get_polygon,
89
  [bbox, image, embedding],
90
+ [polygon, mask],
91
  )
92
 
93
  app.launch(debug=True)