jhj0517 commited on
Commit
1b5d47b
1 Parent(s): ee4969b

Add `invert_mask` parameter to the functions

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +30 -3
modules/sam_inference.py CHANGED
@@ -16,6 +16,7 @@ from modules.model_downloader import (
16
  from modules.paths import (MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR, MODEL_CONFIGS, OUTPUT_DIR)
17
  from modules.constants import (BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT)
18
  from modules.mask_utils import (
 
19
  save_psd_with_masks,
20
  create_mask_combined_images,
21
  create_mask_gallery,
@@ -129,6 +130,7 @@ class SamInference:
129
  def generate_mask(self,
130
  image: np.ndarray,
131
  model_type: str,
 
132
  **params) -> List[Dict[str, Any]]:
133
  """
134
  Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
@@ -136,6 +138,7 @@ class SamInference:
136
  Args:
137
  image (np.ndarray): The input image.
138
  model_type (str): The model type to load.
 
139
  **params: The hyperparameters for the mask generator.
140
 
141
  Returns:
@@ -154,6 +157,11 @@ class SamInference:
154
  except Exception as e:
155
  logger.exception(f"Error while auto generating masks : {e}")
156
  raise RuntimeError(f"Failed to generate masks") from e
 
 
 
 
 
157
  return generated_masks
158
 
159
  def predict_image(self,
@@ -162,6 +170,7 @@ class SamInference:
162
  box: Optional[np.ndarray] = None,
163
  point_coords: Optional[np.ndarray] = None,
164
  point_labels: Optional[np.ndarray] = None,
 
165
  **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
166
  """
167
  Predict image with prompt data.
@@ -172,6 +181,7 @@ class SamInference:
172
  box (np.ndarray): The box prompt data.
173
  point_coords (np.ndarray): The point coordinates prompt data.
174
  point_labels (np.ndarray): The point labels prompt data.
 
175
  **params: The hyperparameters for the mask generator.
176
 
177
  Returns:
@@ -195,6 +205,10 @@ class SamInference:
195
  except Exception as e:
196
  logger.exception(f"Error while predicting image with prompt: {str(e)}")
197
  raise RuntimeError(f"Failed to predict image with prompt") from e
 
 
 
 
198
  return masks, scores, logits
199
 
200
  def add_prediction_to_frame(self,
@@ -291,6 +305,7 @@ class SamInference:
291
  frame_idx: int,
292
  pixel_size: Optional[int] = None,
293
  color_hex: Optional[str] = None,
 
294
  ):
295
  """
296
  Add filter to the preview image with the prompt data. Specially made for gradio app.
@@ -302,6 +317,7 @@ class SamInference:
302
  frame_idx (int): The frame index of the video.
303
  pixel_size (int): The pixel size for the pixelize filter.
304
  color_hex (str): The color hex code for the solid color filter.
 
305
 
306
  Returns:
307
  np.ndarray: The filtered image output.
@@ -332,6 +348,9 @@ class SamInference:
332
  box=box
333
  )
334
  masks = (logits[0] > 0.0).cpu().numpy()
 
 
 
335
  generated_masks = self.format_to_auto_result(masks)
336
 
337
  if filter_mode == COLOR_FILTER:
@@ -347,7 +366,8 @@ class SamInference:
347
  filter_mode: str,
348
  frame_idx: int,
349
  pixel_size: Optional[int] = None,
350
- color_hex: Optional[str] = None
 
351
  ):
352
  """
353
  Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
@@ -359,6 +379,7 @@ class SamInference:
359
  frame_idx (int): The frame index of the video.
360
  pixel_size (int): The pixel size for the pixelize filter.
361
  color_hex (str): The color hex code for the solid color filter.
 
362
 
363
  Returns:
364
  str: The output video path.
@@ -390,12 +411,14 @@ class SamInference:
390
  inference_state=self.video_inference_state,
391
  points=point_coords,
392
  labels=point_labels,
393
- box=box
394
  )
395
 
396
  video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
397
  for frame_index, info in video_segments.items():
398
  orig_image, masks = info["image"], info["mask"]
 
 
399
  masks = self.format_to_auto_result(masks)
400
 
401
  if filter_mode == COLOR_FILTER:
@@ -423,6 +446,7 @@ class SamInference:
423
  image_prompt_input_data: Dict,
424
  input_mode: str,
425
  model_type: str,
 
426
  *params):
427
  """
428
  Divide the layer with the given prompt data and save psd file.
@@ -432,6 +456,7 @@ class SamInference:
432
  image_prompt_input_data (Dict): The image prompt data.
433
  input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
434
  model_type (str): The model type to load.
 
435
  *params: The hyperparameters for the mask generator.
436
 
437
  Returns:
@@ -463,6 +488,7 @@ class SamInference:
463
  generated_masks = self.generate_mask(
464
  image=image,
465
  model_type=model_type,
 
466
  **hparams
467
  )
468
 
@@ -481,7 +507,8 @@ class SamInference:
481
  box=box,
482
  point_coords=point_coords,
483
  point_labels=point_labels,
484
- multimask_output=hparams["multimask_output"]
 
485
  )
486
  generated_masks = self.format_to_auto_result(predicted_masks)
487
 
 
16
  from modules.paths import (MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR, MODEL_CONFIGS, OUTPUT_DIR)
17
  from modules.constants import (BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT)
18
  from modules.mask_utils import (
19
+ invert_masks,
20
  save_psd_with_masks,
21
  create_mask_combined_images,
22
  create_mask_gallery,
 
130
  def generate_mask(self,
131
  image: np.ndarray,
132
  model_type: str,
133
+ invert_mask: bool = False,
134
  **params) -> List[Dict[str, Any]]:
135
  """
136
  Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
 
138
  Args:
139
  image (np.ndarray): The input image.
140
  model_type (str): The model type to load.
141
+ invert_mask (bool): Invert the mask output - used for background masking.
142
  **params: The hyperparameters for the mask generator.
143
 
144
  Returns:
 
157
  except Exception as e:
158
  logger.exception(f"Error while auto generating masks : {e}")
159
  raise RuntimeError(f"Failed to generate masks") from e
160
+
161
+ if invert_mask:
162
+ generated_masks = [{'segmentation': invert_masks(mask['segmentation']),
163
+ 'area': mask['area']} for mask in generated_masks]
164
+
165
  return generated_masks
166
 
167
  def predict_image(self,
 
170
  box: Optional[np.ndarray] = None,
171
  point_coords: Optional[np.ndarray] = None,
172
  point_labels: Optional[np.ndarray] = None,
173
+ invert_mask: bool = False,
174
  **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
175
  """
176
  Predict image with prompt data.
 
181
  box (np.ndarray): The box prompt data.
182
  point_coords (np.ndarray): The point coordinates prompt data.
183
  point_labels (np.ndarray): The point labels prompt data.
184
+ invert_mask (bool): Invert the mask output - used for background masking.
185
  **params: The hyperparameters for the mask generator.
186
 
187
  Returns:
 
205
  except Exception as e:
206
  logger.exception(f"Error while predicting image with prompt: {str(e)}")
207
  raise RuntimeError(f"Failed to predict image with prompt") from e
208
+
209
+ if invert_mask:
210
+ masks = invert_masks(masks)
211
+
212
  return masks, scores, logits
213
 
214
  def add_prediction_to_frame(self,
 
305
  frame_idx: int,
306
  pixel_size: Optional[int] = None,
307
  color_hex: Optional[str] = None,
308
+ invert_mask: bool = False
309
  ):
310
  """
311
  Add filter to the preview image with the prompt data. Specially made for gradio app.
 
317
  frame_idx (int): The frame index of the video.
318
  pixel_size (int): The pixel size for the pixelize filter.
319
  color_hex (str): The color hex code for the solid color filter.
320
+ invert_mask (bool): Invert the mask output - used for background masking.
321
 
322
  Returns:
323
  np.ndarray: The filtered image output.
 
348
  box=box
349
  )
350
  masks = (logits[0] > 0.0).cpu().numpy()
351
+ if invert_mask:
352
+ masks = invert_masks(masks)
353
+
354
  generated_masks = self.format_to_auto_result(masks)
355
 
356
  if filter_mode == COLOR_FILTER:
 
366
  filter_mode: str,
367
  frame_idx: int,
368
  pixel_size: Optional[int] = None,
369
+ color_hex: Optional[str] = None,
370
+ invert_mask: bool = False
371
  ):
372
  """
373
  Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
 
379
  frame_idx (int): The frame index of the video.
380
  pixel_size (int): The pixel size for the pixelize filter.
381
  color_hex (str): The color hex code for the solid color filter.
382
+ invert_mask (bool): Invert the mask output - used for background masking.
383
 
384
  Returns:
385
  str: The output video path.
 
411
  inference_state=self.video_inference_state,
412
  points=point_coords,
413
  labels=point_labels,
414
+ box=box,
415
  )
416
 
417
  video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
418
  for frame_index, info in video_segments.items():
419
  orig_image, masks = info["image"], info["mask"]
420
+ if invert_mask:
421
+ masks = invert_masks(masks)
422
  masks = self.format_to_auto_result(masks)
423
 
424
  if filter_mode == COLOR_FILTER:
 
446
  image_prompt_input_data: Dict,
447
  input_mode: str,
448
  model_type: str,
449
+ invert_mask: bool = False,
450
  *params):
451
  """
452
  Divide the layer with the given prompt data and save psd file.
 
456
  image_prompt_input_data (Dict): The image prompt data.
457
  input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
458
  model_type (str): The model type to load.
459
+ invert_mask (bool): Invert the mask output.
460
  *params: The hyperparameters for the mask generator.
461
 
462
  Returns:
 
488
  generated_masks = self.generate_mask(
489
  image=image,
490
  model_type=model_type,
491
+ invert_mask=invert_mask,
492
  **hparams
493
  )
494
 
 
507
  box=box,
508
  point_coords=point_coords,
509
  point_labels=point_labels,
510
+ multimask_output=hparams["multimask_output"],
511
+ invert_mask=invert_mask
512
  )
513
  generated_masks = self.format_to_auto_result(predicted_masks)
514