jhj0517 committed on
Commit
41938cd
·
1 Parent(s): 3c09bbc

Add point prompt

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +19 -6
modules/sam_inference.py CHANGED
@@ -1,7 +1,7 @@
1
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
2
  from sam2.build_sam import build_sam2
3
  from sam2.sam2_image_predictor import SAM2ImagePredictor
4
- from typing import Dict, List
5
  import torch
6
  import os
7
  from datetime import datetime
@@ -83,7 +83,9 @@ class SamInference:
83
  def predict_image(self,
84
  image: np.ndarray,
85
  model_type: str,
86
- box: np.ndarray,
 
 
87
  **params):
88
  if self.model is None or self.model_type != model_type:
89
  self.model_type = model_type
@@ -94,6 +96,8 @@ class SamInference:
94
  try:
95
  masks, scores, logits = self.image_predictor.predict(
96
  box=box,
 
 
97
  multimask_output=params["multimask_output"],
98
  )
99
  except Exception as e:
@@ -136,15 +140,24 @@ class SamInference:
136
  elif input_mode == BOX_PROMPT_MODE:
137
  image = image_prompt_input_data["image"]
138
  image = np.array(image.convert("RGB"))
139
- box = image_prompt_input_data["points"]
140
- if len(box) == 0:
141
  return [image], []
142
- box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in box])
 
 
 
 
 
 
 
143
 
144
  predicted_masks, scores, logits = self.predict_image(
145
  image=image,
146
  model_type=model_type,
147
- box=box,
 
 
148
  multimask_output=hparams["multimask_output"]
149
  )
150
  generated_masks = self.format_to_auto_result(predicted_masks)
 
1
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
2
  from sam2.build_sam import build_sam2
3
  from sam2.sam2_image_predictor import SAM2ImagePredictor
4
+ from typing import Dict, List, Optional
5
  import torch
6
  import os
7
  from datetime import datetime
 
83
  def predict_image(self,
84
  image: np.ndarray,
85
  model_type: str,
86
+ box: Optional[np.ndarray] = None,
87
+ point_coords: Optional[np.ndarray] = None,
88
+ point_labels: Optional[np.ndarray] = None,
89
  **params):
90
  if self.model is None or self.model_type != model_type:
91
  self.model_type = model_type
 
96
  try:
97
  masks, scores, logits = self.image_predictor.predict(
98
  box=box,
99
+ point_coords=point_coords,
100
+ point_labels=point_labels,
101
  multimask_output=params["multimask_output"],
102
  )
103
  except Exception as e:
 
140
  elif input_mode == BOX_PROMPT_MODE:
141
  image = image_prompt_input_data["image"]
142
  image = np.array(image.convert("RGB"))
143
+ prompt = image_prompt_input_data["points"]
144
+ if len(prompt) == 0:
145
  return [image], []
146
+
147
+ is_prompt_point = prompt[0][-1] == 4.0
148
+
149
+ if is_prompt_point:
150
+ point_labels = np.array([1 if is_left_click else 0 for x1, y1, is_left_click, x2, y2, _ in prompt])
151
+ prompt = np.array([[x1, y1] for x1, y1, is_left_click, x2, y2, _ in prompt])
152
+ else:
153
+ prompt = np.array([[x1, y1, x2, y2] for x1, y1, is_left_click, x2, y2, _ in prompt])
154
 
155
  predicted_masks, scores, logits = self.predict_image(
156
  image=image,
157
  model_type=model_type,
158
+ box=prompt if not is_prompt_point else None,
159
+ point_coords=prompt if is_prompt_point else None,
160
+ point_labels=point_labels if is_prompt_point else None,
161
  multimask_output=hparams["multimask_output"]
162
  )
163
  generated_masks = self.format_to_auto_result(predicted_masks)