Spaces:
Runtime error
Runtime error
jhj0517
committed on
Commit
•
1b5d47b
1
Parent(s):
ee4969b
Add `invert_mask` parameter to the functions
Browse files
- modules/sam_inference.py +30 -3
modules/sam_inference.py
CHANGED
@@ -16,6 +16,7 @@ from modules.model_downloader import (
|
|
16 |
from modules.paths import (MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR, MODEL_CONFIGS, OUTPUT_DIR)
|
17 |
from modules.constants import (BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT)
|
18 |
from modules.mask_utils import (
|
|
|
19 |
save_psd_with_masks,
|
20 |
create_mask_combined_images,
|
21 |
create_mask_gallery,
|
@@ -129,6 +130,7 @@ class SamInference:
|
|
129 |
def generate_mask(self,
|
130 |
image: np.ndarray,
|
131 |
model_type: str,
|
|
|
132 |
**params) -> List[Dict[str, Any]]:
|
133 |
"""
|
134 |
Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
|
@@ -136,6 +138,7 @@ class SamInference:
|
|
136 |
Args:
|
137 |
image (np.ndarray): The input image.
|
138 |
model_type (str): The model type to load.
|
|
|
139 |
**params: The hyperparameters for the mask generator.
|
140 |
|
141 |
Returns:
|
@@ -154,6 +157,11 @@ class SamInference:
|
|
154 |
except Exception as e:
|
155 |
logger.exception(f"Error while auto generating masks : {e}")
|
156 |
raise RuntimeError(f"Failed to generate masks") from e
|
|
|
|
|
|
|
|
|
|
|
157 |
return generated_masks
|
158 |
|
159 |
def predict_image(self,
|
@@ -162,6 +170,7 @@ class SamInference:
|
|
162 |
box: Optional[np.ndarray] = None,
|
163 |
point_coords: Optional[np.ndarray] = None,
|
164 |
point_labels: Optional[np.ndarray] = None,
|
|
|
165 |
**params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
166 |
"""
|
167 |
Predict image with prompt data.
|
@@ -172,6 +181,7 @@ class SamInference:
|
|
172 |
box (np.ndarray): The box prompt data.
|
173 |
point_coords (np.ndarray): The point coordinates prompt data.
|
174 |
point_labels (np.ndarray): The point labels prompt data.
|
|
|
175 |
**params: The hyperparameters for the mask generator.
|
176 |
|
177 |
Returns:
|
@@ -195,6 +205,10 @@ class SamInference:
|
|
195 |
except Exception as e:
|
196 |
logger.exception(f"Error while predicting image with prompt: {str(e)}")
|
197 |
raise RuntimeError(f"Failed to predict image with prompt") from e
|
|
|
|
|
|
|
|
|
198 |
return masks, scores, logits
|
199 |
|
200 |
def add_prediction_to_frame(self,
|
@@ -291,6 +305,7 @@ class SamInference:
|
|
291 |
frame_idx: int,
|
292 |
pixel_size: Optional[int] = None,
|
293 |
color_hex: Optional[str] = None,
|
|
|
294 |
):
|
295 |
"""
|
296 |
Add filter to the preview image with the prompt data. Specially made for gradio app.
|
@@ -302,6 +317,7 @@ class SamInference:
|
|
302 |
frame_idx (int): The frame index of the video.
|
303 |
pixel_size (int): The pixel size for the pixelize filter.
|
304 |
color_hex (str): The color hex code for the solid color filter.
|
|
|
305 |
|
306 |
Returns:
|
307 |
np.ndarray: The filtered image output.
|
@@ -332,6 +348,9 @@ class SamInference:
|
|
332 |
box=box
|
333 |
)
|
334 |
masks = (logits[0] > 0.0).cpu().numpy()
|
|
|
|
|
|
|
335 |
generated_masks = self.format_to_auto_result(masks)
|
336 |
|
337 |
if filter_mode == COLOR_FILTER:
|
@@ -347,7 +366,8 @@ class SamInference:
|
|
347 |
filter_mode: str,
|
348 |
frame_idx: int,
|
349 |
pixel_size: Optional[int] = None,
|
350 |
-
color_hex: Optional[str] = None
|
|
|
351 |
):
|
352 |
"""
|
353 |
Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
|
@@ -359,6 +379,7 @@ class SamInference:
|
|
359 |
frame_idx (int): The frame index of the video.
|
360 |
pixel_size (int): The pixel size for the pixelize filter.
|
361 |
color_hex (str): The color hex code for the solid color filter.
|
|
|
362 |
|
363 |
Returns:
|
364 |
str: The output video path.
|
@@ -390,12 +411,14 @@ class SamInference:
|
|
390 |
inference_state=self.video_inference_state,
|
391 |
points=point_coords,
|
392 |
labels=point_labels,
|
393 |
-
box=box
|
394 |
)
|
395 |
|
396 |
video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
|
397 |
for frame_index, info in video_segments.items():
|
398 |
orig_image, masks = info["image"], info["mask"]
|
|
|
|
|
399 |
masks = self.format_to_auto_result(masks)
|
400 |
|
401 |
if filter_mode == COLOR_FILTER:
|
@@ -423,6 +446,7 @@ class SamInference:
|
|
423 |
image_prompt_input_data: Dict,
|
424 |
input_mode: str,
|
425 |
model_type: str,
|
|
|
426 |
*params):
|
427 |
"""
|
428 |
Divide the layer with the given prompt data and save psd file.
|
@@ -432,6 +456,7 @@ class SamInference:
|
|
432 |
image_prompt_input_data (Dict): The image prompt data.
|
433 |
input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
|
434 |
model_type (str): The model type to load.
|
|
|
435 |
*params: The hyperparameters for the mask generator.
|
436 |
|
437 |
Returns:
|
@@ -463,6 +488,7 @@ class SamInference:
|
|
463 |
generated_masks = self.generate_mask(
|
464 |
image=image,
|
465 |
model_type=model_type,
|
|
|
466 |
**hparams
|
467 |
)
|
468 |
|
@@ -481,7 +507,8 @@ class SamInference:
|
|
481 |
box=box,
|
482 |
point_coords=point_coords,
|
483 |
point_labels=point_labels,
|
484 |
-
multimask_output=hparams["multimask_output"]
|
|
|
485 |
)
|
486 |
generated_masks = self.format_to_auto_result(predicted_masks)
|
487 |
|
|
|
16 |
from modules.paths import (MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR, MODEL_CONFIGS, OUTPUT_DIR)
|
17 |
from modules.constants import (BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT)
|
18 |
from modules.mask_utils import (
|
19 |
+
invert_masks,
|
20 |
save_psd_with_masks,
|
21 |
create_mask_combined_images,
|
22 |
create_mask_gallery,
|
|
|
130 |
def generate_mask(self,
|
131 |
image: np.ndarray,
|
132 |
model_type: str,
|
133 |
+
invert_mask: bool = False,
|
134 |
**params) -> List[Dict[str, Any]]:
|
135 |
"""
|
136 |
Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
|
|
|
138 |
Args:
|
139 |
image (np.ndarray): The input image.
|
140 |
model_type (str): The model type to load.
|
141 |
+
invert_mask (bool): Invert the mask output - used for background masking.
|
142 |
**params: The hyperparameters for the mask generator.
|
143 |
|
144 |
Returns:
|
|
|
157 |
except Exception as e:
|
158 |
logger.exception(f"Error while auto generating masks : {e}")
|
159 |
raise RuntimeError(f"Failed to generate masks") from e
|
160 |
+
|
161 |
+
if invert_mask:
|
162 |
+
generated_masks = [{'segmentation': invert_masks(mask['segmentation']),
|
163 |
+
'area': mask['area']} for mask in generated_masks]
|
164 |
+
|
165 |
return generated_masks
|
166 |
|
167 |
def predict_image(self,
|
|
|
170 |
box: Optional[np.ndarray] = None,
|
171 |
point_coords: Optional[np.ndarray] = None,
|
172 |
point_labels: Optional[np.ndarray] = None,
|
173 |
+
invert_mask: bool = False,
|
174 |
**params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
175 |
"""
|
176 |
Predict image with prompt data.
|
|
|
181 |
box (np.ndarray): The box prompt data.
|
182 |
point_coords (np.ndarray): The point coordinates prompt data.
|
183 |
point_labels (np.ndarray): The point labels prompt data.
|
184 |
+
invert_mask (bool): Invert the mask output - used for background masking.
|
185 |
**params: The hyperparameters for the mask generator.
|
186 |
|
187 |
Returns:
|
|
|
205 |
except Exception as e:
|
206 |
logger.exception(f"Error while predicting image with prompt: {str(e)}")
|
207 |
raise RuntimeError(f"Failed to predict image with prompt") from e
|
208 |
+
|
209 |
+
if invert_mask:
|
210 |
+
masks = invert_masks(masks)
|
211 |
+
|
212 |
return masks, scores, logits
|
213 |
|
214 |
def add_prediction_to_frame(self,
|
|
|
305 |
frame_idx: int,
|
306 |
pixel_size: Optional[int] = None,
|
307 |
color_hex: Optional[str] = None,
|
308 |
+
invert_mask: bool = False
|
309 |
):
|
310 |
"""
|
311 |
Add filter to the preview image with the prompt data. Specially made for gradio app.
|
|
|
317 |
frame_idx (int): The frame index of the video.
|
318 |
pixel_size (int): The pixel size for the pixelize filter.
|
319 |
color_hex (str): The color hex code for the solid color filter.
|
320 |
+
invert_mask (bool): Invert the mask output - used for background masking.
|
321 |
|
322 |
Returns:
|
323 |
np.ndarray: The filtered image output.
|
|
|
348 |
box=box
|
349 |
)
|
350 |
masks = (logits[0] > 0.0).cpu().numpy()
|
351 |
+
if invert_mask:
|
352 |
+
masks = invert_masks(masks)
|
353 |
+
|
354 |
generated_masks = self.format_to_auto_result(masks)
|
355 |
|
356 |
if filter_mode == COLOR_FILTER:
|
|
|
366 |
filter_mode: str,
|
367 |
frame_idx: int,
|
368 |
pixel_size: Optional[int] = None,
|
369 |
+
color_hex: Optional[str] = None,
|
370 |
+
invert_mask: bool = False
|
371 |
):
|
372 |
"""
|
373 |
Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
|
|
|
379 |
frame_idx (int): The frame index of the video.
|
380 |
pixel_size (int): The pixel size for the pixelize filter.
|
381 |
color_hex (str): The color hex code for the solid color filter.
|
382 |
+
invert_mask (bool): Invert the mask output - used for background masking.
|
383 |
|
384 |
Returns:
|
385 |
str: The output video path.
|
|
|
411 |
inference_state=self.video_inference_state,
|
412 |
points=point_coords,
|
413 |
labels=point_labels,
|
414 |
+
box=box,
|
415 |
)
|
416 |
|
417 |
video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
|
418 |
for frame_index, info in video_segments.items():
|
419 |
orig_image, masks = info["image"], info["mask"]
|
420 |
+
if invert_mask:
|
421 |
+
masks = invert_masks(masks)
|
422 |
masks = self.format_to_auto_result(masks)
|
423 |
|
424 |
if filter_mode == COLOR_FILTER:
|
|
|
446 |
image_prompt_input_data: Dict,
|
447 |
input_mode: str,
|
448 |
model_type: str,
|
449 |
+
invert_mask: bool = False,
|
450 |
*params):
|
451 |
"""
|
452 |
Divide the layer with the given prompt data and save psd file.
|
|
|
456 |
image_prompt_input_data (Dict): The image prompt data.
|
457 |
input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
|
458 |
model_type (str): The model type to load.
|
459 |
+
invert_mask (bool): Invert the mask output.
|
460 |
*params: The hyperparameters for the mask generator.
|
461 |
|
462 |
Returns:
|
|
|
488 |
generated_masks = self.generate_mask(
|
489 |
image=image,
|
490 |
model_type=model_type,
|
491 |
+
invert_mask=invert_mask,
|
492 |
**hparams
|
493 |
)
|
494 |
|
|
|
507 |
box=box,
|
508 |
point_coords=point_coords,
|
509 |
point_labels=point_labels,
|
510 |
+
multimask_output=hparams["multimask_output"],
|
511 |
+
invert_mask=invert_mask
|
512 |
)
|
513 |
generated_masks = self.format_to_auto_result(predicted_masks)
|
514 |
|