jhj0517 commited on
Commit
6df697d
·
1 Parent(s): 60def5b

Add `add_filter_to_preview()`

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +115 -36
modules/sam_inference.py CHANGED
@@ -1,6 +1,7 @@
1
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
2
  from sam2.build_sam import build_sam2, build_sam2_video_predictor
3
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
4
  from typing import Dict, List, Optional
5
  import torch
6
  import os
@@ -13,11 +14,13 @@ from modules.model_downloader import (
13
  download_sam_model_url
14
  )
15
  from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
16
- from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE
17
  from modules.mask_utils import (
18
  save_psd_with_masks,
19
  create_mask_combined_images,
20
- create_mask_gallery
 
 
21
  )
22
  from modules.logger_util import get_logger
23
 
@@ -37,7 +40,7 @@ class SamInference:
37
  ):
38
  self.model = None
39
  self.available_models = list(AVAILABLE_MODELS.keys())
40
- self.model_type = DEFAULT_MODEL_TYPE
41
  self.model_dir = model_dir
42
  self.output_dir = output_dir
43
  self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
@@ -48,14 +51,18 @@ class SamInference:
48
  self.video_inference_state = None
49
 
50
  def load_model(self,
 
51
  load_video_predictor: bool = False):
52
- config = MODEL_CONFIGS[self.model_type]
53
- filename, url = AVAILABLE_MODELS[self.model_type]
 
 
 
54
  model_path = os.path.join(self.model_dir, filename)
55
 
56
- if not is_sam_exist(self.model_type):
57
- logger.info(f"No SAM2 model found, downloading {self.model_type} model...")
58
- download_sam_model_url(self.model_type)
59
  logger.info(f"Applying configs to model..")
60
 
61
  if load_video_predictor:
@@ -81,22 +88,26 @@ class SamInference:
81
  raise f"Error while loading SAM2 model!: {e}"
82
 
83
  def init_video_inference_state(self,
 
84
  vid_input: str):
85
- if self.video_predictor is None:
86
- self.load_model(load_video_predictor=True)
 
 
87
 
88
  if self.video_inference_state is not None:
89
  self.video_predictor.reset_state(self.video_inference_state)
 
90
 
91
- self.video_predictor.init_state(video_path=vid_input)
92
 
93
  def generate_mask(self,
94
  image: np.ndarray,
95
  model_type: str,
96
  **params):
97
- if self.model is None or self.model_type != model_type:
98
- self.model_type = model_type
99
- self.load_model()
100
  self.mask_generator = SAM2AutomaticMaskGenerator(
101
  model=self.model,
102
  **params
@@ -115,9 +126,9 @@ class SamInference:
115
  point_coords: Optional[np.ndarray] = None,
116
  point_labels: Optional[np.ndarray] = None,
117
  **params):
118
- if self.model is None or self.model_type != model_type:
119
- self.model_type = model_type
120
- self.load_model()
121
  self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
122
  self.image_predictor.set_image(image)
123
 
@@ -137,34 +148,78 @@ class SamInference:
137
  frame_idx: int,
138
  obj_id: int,
139
  inference_state: Dict,
140
- points: np.ndarray,
141
- labels: np.ndarray):
142
- if self.video_inference_state is None:
 
143
  logger.exception("Error while predicting frame from video, load video predictor first")
144
  raise f"Error while predicting frame from video"
145
 
146
  try:
147
- out_masks, out_obj_ids, out_mask_logits = self.video_predictor.add_new_points_or_box(
148
  inference_state=inference_state,
149
  frame_idx=frame_idx,
150
  obj_id=obj_id,
151
  points=points,
152
  labels=labels,
 
153
  )
154
  except Exception as e:
155
  logger.exception("Error while predicting frame with prompt")
156
- raise f"Error while predicting frame with prompt: {str(e)}"
 
157
 
158
- return out_masks, out_obj_ids, out_mask_logits
159
 
160
  def predict_video(self,
161
  video_input):
162
  pass
163
 
164
  def add_filter_to_preview(self,
165
- image: np.ndarray,
 
 
 
 
166
  ):
167
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  def divide_layer(self,
170
  image_input: np.ndarray,
@@ -207,21 +262,14 @@ class SamInference:
207
  if len(prompt) == 0:
208
  return [image], []
209
 
210
- point_labels, point_coords, box = [], [], []
211
-
212
- for x1, y1, left_click_indicator, x2, y2, point_indicator in prompt:
213
- if point_indicator == 4.0:
214
- point_labels.append(left_click_indicator)
215
- point_coords.append([x1, y1])
216
- else:
217
- box.append([x1, y1, x2, y2])
218
 
219
  predicted_masks, scores, logits = self.predict_image(
220
  image=image,
221
  model_type=model_type,
222
- box=np.array(box) if box else None,
223
- point_coords=np.array(point_coords) if point_coords else None,
224
- point_labels=np.array(point_labels) if point_labels else None,
225
  multimask_output=hparams["multimask_output"]
226
  )
227
  generated_masks = self.format_to_auto_result(predicted_masks)
@@ -242,3 +290,34 @@ class SamInference:
242
  masks = np.expand_dims(masks, axis=0)
243
  result = [{"segmentation": mask[0], "area": place_holder} for mask in masks]
244
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
2
  from sam2.build_sam import build_sam2, build_sam2_video_predictor
3
  from sam2.sam2_image_predictor import SAM2ImagePredictor
4
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
5
  from typing import Dict, List, Optional
6
  import torch
7
  import os
 
14
  download_sam_model_url
15
  )
16
  from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
17
+ from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER
18
  from modules.mask_utils import (
19
  save_psd_with_masks,
20
  create_mask_combined_images,
21
+ create_mask_gallery,
22
+ create_mask_pixelized_image,
23
+ create_solid_color_mask_image
24
  )
25
  from modules.logger_util import get_logger
26
 
 
40
  ):
41
  self.model = None
42
  self.available_models = list(AVAILABLE_MODELS.keys())
43
+ self.current_model_type = DEFAULT_MODEL_TYPE
44
  self.model_dir = model_dir
45
  self.output_dir = output_dir
46
  self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
 
51
  self.video_inference_state = None
52
 
53
  def load_model(self,
54
+ model_type: Optional[str] = None,
55
  load_video_predictor: bool = False):
56
+ if model_type is None:
57
+ model_type = DEFAULT_MODEL_TYPE
58
+
59
+ config = MODEL_CONFIGS[model_type]
60
+ filename, url = AVAILABLE_MODELS[model_type]
61
  model_path = os.path.join(self.model_dir, filename)
62
 
63
+ if not is_sam_exist(model_type):
64
+ logger.info(f"No SAM2 model found, downloading {model_type} model...")
65
+ download_sam_model_url(model_type)
66
  logger.info(f"Applying configs to model..")
67
 
68
  if load_video_predictor:
 
88
  raise f"Error while loading SAM2 model!: {e}"
89
 
90
  def init_video_inference_state(self,
91
+ model_type: str,
92
  vid_input: str):
93
+
94
+ if self.video_predictor is None or model_type != self.current_model_type:
95
+ self.current_model_type = model_type
96
+ self.load_model(model_type=model_type, load_video_predictor=True)
97
 
98
  if self.video_inference_state is not None:
99
  self.video_predictor.reset_state(self.video_inference_state)
100
+ self.video_inference_state = None
101
 
102
+ self.video_inference_state = self.video_predictor.init_state(video_path=vid_input)
103
 
104
  def generate_mask(self,
105
  image: np.ndarray,
106
  model_type: str,
107
  **params):
108
+ if self.model is None or self.current_model_type != model_type:
109
+ self.current_model_type = model_type
110
+ self.load_model(model_type=model_type)
111
  self.mask_generator = SAM2AutomaticMaskGenerator(
112
  model=self.model,
113
  **params
 
126
  point_coords: Optional[np.ndarray] = None,
127
  point_labels: Optional[np.ndarray] = None,
128
  **params):
129
+ if self.model is None or self.current_model_type != model_type:
130
+ self.current_model_type = model_type
131
+ self.load_model(model_type=model_type)
132
  self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
133
  self.image_predictor.set_image(image)
134
 
 
148
  frame_idx: int,
149
  obj_id: int,
150
  inference_state: Dict,
151
+ points: Optional[np.ndarray] = None,
152
+ labels: Optional[np.ndarray] = None,
153
+ box: Optional[np.ndarray] = None):
154
+ if self.video_predictor is None or self.video_inference_state is None:
155
  logger.exception("Error while predicting frame from video, load video predictor first")
156
  raise f"Error while predicting frame from video"
157
 
158
  try:
159
+ out_frame_idx, out_obj_ids, out_mask_logits = self.video_predictor.add_new_points_or_box(
160
  inference_state=inference_state,
161
  frame_idx=frame_idx,
162
  obj_id=obj_id,
163
  points=points,
164
  labels=labels,
165
+ box=box
166
  )
167
  except Exception as e:
168
  logger.exception("Error while predicting frame with prompt")
169
+ print(e)
170
+ raise f"Error while predicting frame with prompt"
171
 
172
+ return out_frame_idx, out_obj_ids, out_mask_logits
173
 
174
  def predict_video(self,
175
  video_input):
176
  pass
177
 
178
  def add_filter_to_preview(self,
179
+ image_prompt_input_data: Dict,
180
+ filter_mode: str,
181
+ frame_idx: int,
182
+ pixel_size: Optional[int] = None,
183
+ color_hex: Optional[str] = None,
184
  ):
185
+ if self.video_predictor is None or self.video_inference_state is None:
186
+ logger.exception("Error while adding filter to preview, load video predictor first")
187
+ raise f"Error while adding filter to preview"
188
+
189
+ image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
190
+ image = np.array(image.convert("RGB"))
191
+
192
+ point_labels, point_coords, box = self.handle_prompt_data(prompt)
193
+
194
+ if filter_mode == COLOR_FILTER:
195
+ idx, scores, logits = self.predict_frame(
196
+ frame_idx=frame_idx,
197
+ obj_id=0,
198
+ inference_state=self.video_inference_state,
199
+ points=point_coords,
200
+ labels=point_labels,
201
+ box=box
202
+ )
203
+ masks = (logits[0] > 0.0).cpu().numpy()
204
+ generated_masks = self.format_to_auto_result(masks)
205
+ image = create_solid_color_mask_image(image, generated_masks, color_hex)
206
+
207
+ elif filter_mode == PIXELIZE_FILTER:
208
+ idx, scores, logits = self.predict_frame(
209
+ frame_idx=frame_idx,
210
+ obj_id=0,
211
+ inference_state=self.video_inference_state,
212
+ points=point_coords,
213
+ labels=point_labels,
214
+ box=box
215
+ )
216
+ print("before", logits)
217
+ masks = (logits[0] > 0.0).cpu().numpy()
218
+ generated_masks = self.format_to_auto_result(masks)
219
+ print("after", generated_masks)
220
+ image = create_mask_pixelized_image(image, generated_masks, pixel_size)
221
+ #
222
+ return image
223
 
224
  def divide_layer(self,
225
  image_input: np.ndarray,
 
262
  if len(prompt) == 0:
263
  return [image], []
264
 
265
+ point_labels, point_coords, box = self.handle_prompt_data(prompt)
 
 
 
 
 
 
 
266
 
267
  predicted_masks, scores, logits = self.predict_image(
268
  image=image,
269
  model_type=model_type,
270
+ box=box,
271
+ point_coords=point_coords,
272
+ point_labels=point_labels,
273
  multimask_output=hparams["multimask_output"]
274
  )
275
  generated_masks = self.format_to_auto_result(predicted_masks)
 
290
  masks = np.expand_dims(masks, axis=0)
291
  result = [{"segmentation": mask[0], "area": place_holder} for mask in masks]
292
  return result
293
+
294
+ @staticmethod
295
+ def handle_prompt_data(
296
+ prompt_data: List
297
+ ):
298
+ """
299
+ Handle data from ImageInputPrompter.
300
+
301
+ Args:
302
+ prompt_data (Dict): A dictionary containing the 'prompt' key with a list of prompts.
303
+
304
+ Returns:
305
+ point_labels (List): list of points labels.
306
+ point_coords (List): list of points coords.
307
+ box (List): list of box datas.
308
+ """
309
+ point_labels, point_coords, box = [], [], []
310
+
311
+ for x1, y1, left_click_indicator, x2, y2, point_indicator in prompt_data:
312
+ is_point = point_indicator == 4.0
313
+ if is_point:
314
+ point_labels.append(left_click_indicator)
315
+ point_coords.append([x1, y1])
316
+ else:
317
+ box.append([x1, y1, x2, y2])
318
+
319
+ point_labels = np.array(point_labels) if point_labels else None
320
+ point_coords = np.array(point_coords) if point_coords else None
321
+ box = np.array(box) if box else None
322
+
323
+ return point_labels, point_coords, box