vmoras committed on
Commit
052cf0e
1 Parent(s): 9d5a832

Update app.py

Files changed (1)
  1. app.py +14 -23
app.py CHANGED
@@ -5,22 +5,25 @@ import torch
 import traceback
 import numpy as np
 from itertools import chain
-from transformers import SamModel, SamProcessor
+from segment_anything import SamPredictor, sam_model_registry


-model = SamModel.from_pretrained('facebook/sam-vit-huge')
-processor = SamProcessor.from_pretrained('facebook/sam-vit-huge')
+sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
+model_type = "vit_h"


 def set_predictor(image):
     """
     Creates a Sam predictor object based on a given image and model.
     """
-    device = 'cpu'
-    inputs = processor(image, return_tensors='pt').to(device)
-    image_embedding = model.get_image_embeddings(inputs['pixel_values'])

-    return [image, image_embedding, 'Done']
+    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+    sam.to(device=device)
+
+    predictor = SamPredictor(sam)
+    predictor.set_image(image)
+
+    return [image, predictor, 'Done']


 def get_polygon(points, image, image_embedding):
@@ -30,24 +33,12 @@ def get_polygon(points, image, image_embedding):
     """
     points = list(chain.from_iterable(points))

-    device = 'cpu'
-    inputs = processor(image, input_boxes=[points], return_tensors="pt").to(device)
-
-    # pop the pixel_values as they are not neded
-    inputs.pop("pixel_values", None)
-    inputs.update({"image_embeddings": image_embedding})
-
-    with torch.no_grad():
-        outputs = model(**inputs)
-
-    masks = processor.image_processor.post_process_masks(
-        outputs.pred_masks.cpu(),
-        inputs["original_sizes"].cpu(),
-        inputs["reshaped_input_sizes"].cpu()
+    masks, _, _ = predictor.predict(
+        box=input_box[None, :],
+        multimask_output=False,
     )

-    mask = masks[0].squeeze().numpy()
-    img = mask.astype(np.uint8)[0]
+    img = masks[0].astype(np.uint8)
     contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

     if len(contours) == 0:
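
For reference, a self-contained sketch of how the updated segment_anything-based flow could fit together. This is an illustration, not the committed code: the hunks above reference `device` and `input_box` without defining them, and keep the old `image_embedding` parameter name while calling `predictor`, so those pieces are filled in here as assumptions (a CPU device, a box built from the flattened click points, and the predictor returned by `set_predictor` passed as the third argument).

import cv2
import numpy as np
from itertools import chain
from segment_anything import SamPredictor, sam_model_registry

sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cpu"  # assumption: `device` is not defined inside the hunks shown


def set_predictor(image):
    """Build a SamPredictor and precompute the embedding for an RGB image."""
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    predictor = SamPredictor(sam)
    predictor.set_image(image)  # image: H x W x 3 uint8 array; embedding computed once here

    return [image, predictor, 'Done']


def get_polygon(points, image, predictor):
    """Predict a mask inside the box given by `points` and extract its contours."""
    # assumption: the box prompt is the flattened corner points in XYXY order
    input_box = np.array(list(chain.from_iterable(points)))

    masks, _, _ = predictor.predict(
        box=input_box[None, :],   # shape (1, 4)
        multimask_output=False,   # single mask for the box prompt
    )

    img = masks[0].astype(np.uint8)  # boolean mask -> 0/1 image for OpenCV
    contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # ... the rest of the contour handling lies outside the hunks shown above

With `set_image` the expensive image embedding is computed once per image, so each subsequent box prompt passed to `predictor.predict` is cheap; `multimask_output=False` yields a single (1, H, W) boolean mask, whose first slice is cast to uint8 for `cv2.findContours`.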