Spaces:
Runtime error
Runtime error
import re | |
from typing import List | |
import cv2 | |
import numpy as np | |
import supervision as sv | |
class Visualizer:
    """
    Annotates an image with boxes, masks, polygons and index labels taken from
    `sv.Detections`.

    Every annotator is built with `sv.ColorLookup.INDEX`, so each detection's
    color is chosen by its position in `detections`; the label drawn for a
    detection is that same position rendered as a string.
    """

    def __init__(
        self,
        line_thickness: int = 2,
        mask_opacity: float = 0.1,
        text_scale: float = 0.5
    ) -> None:
        """
        Parameters:
            line_thickness (int): Stroke width for bounding boxes and polygons.
            mask_opacity (float): Opacity passed to the mask annotator.
            text_scale (float): Text scale passed to the label annotator.
        """
        # supervision 0.22 deprecated `BoundingBoxAnnotator` in favour of the
        # renamed `BoxAnnotator` and announced its removal. Keep using the old
        # name while it exists (older releases ship a legacy `BoxAnnotator`
        # with an incompatible signature) and fall back to the new name once
        # the old one is gone.
        box_annotator_cls = (
            getattr(sv, "BoundingBoxAnnotator", None) or sv.BoxAnnotator
        )
        self.box_annotator = box_annotator_cls(
            color_lookup=sv.ColorLookup.INDEX,
            thickness=line_thickness)
        self.mask_annotator = sv.MaskAnnotator(
            color_lookup=sv.ColorLookup.INDEX,
            opacity=mask_opacity)
        self.polygon_annotator = sv.PolygonAnnotator(
            color_lookup=sv.ColorLookup.INDEX,
            thickness=line_thickness)
        # Labels are placed at each detection's center of mass.
        self.label_annotator = sv.LabelAnnotator(
            color_lookup=sv.ColorLookup.INDEX,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=text_scale)

    def visualize(
        self,
        image: np.ndarray,
        detections: sv.Detections,
        with_box: bool,
        with_mask: bool,
        with_polygon: bool,
        with_label: bool
    ) -> np.ndarray:
        """
        Draws the requested overlays on a copy of `image`.

        Overlays are applied in a fixed order: boxes, masks, polygons, labels.

        Parameters:
            image (np.ndarray): Image to annotate; it is copied, not modified.
            detections (sv.Detections): Detections to draw.
            with_box (bool): Draw bounding boxes.
            with_mask (bool): Draw masks (NOTE(review): presumably requires
                `detections.mask` to be set - confirm).
            with_polygon (bool): Draw polygons via `sv.PolygonAnnotator`.
            with_label (bool): Draw each detection's index as its label.

        Returns:
            np.ndarray: The annotated copy of `image`.
        """
        annotated_image = image.copy()
        if with_box:
            annotated_image = self.box_annotator.annotate(
                scene=annotated_image, detections=detections)
        if with_mask:
            annotated_image = self.mask_annotator.annotate(
                scene=annotated_image, detections=detections)
        if with_polygon:
            annotated_image = self.polygon_annotator.annotate(
                scene=annotated_image, detections=detections)
        if with_label:
            # The label is the detection's index, matching the INDEX color lookup.
            labels = [str(index) for index in range(len(detections))]
            annotated_image = self.label_annotator.annotate(
                scene=annotated_image, detections=detections, labels=labels)
        return annotated_image
def refine_mask(
    mask: np.ndarray,
    area_threshold: float,
    mode: str = 'islands'
) -> np.ndarray:
    """
    Refines a mask by removing small islands or filling small holes based on an
    area threshold.

    Parameters:
        mask (np.ndarray): Input binary mask (2D); it is converted to an 8-bit
            copy, so the caller's array is not modified.
        area_threshold (float): Threshold for the relative area below which a
            feature is removed or filled. In 'islands' mode each external
            contour's area is divided by the mask's non-zero pixel count; in
            'holes' mode every contour's area is divided by the total number of
            pixels in the mask.
        mode (str): Operation mode: 'islands' removes small islands, 'holes'
            fills small holes. Any value other than 'islands' is treated as
            'holes'.

    Returns:
        np.ndarray: Refined binary mask of dtype bool.
    """
    remove_islands = mode == 'islands'
    # OpenCV's contour functions work on uint8 images; `mask * 255` also makes
    # a fresh array, so drawing below never touches the caller's mask.
    canvas = np.uint8(mask * 255)
    # RETR_EXTERNAL yields only outer boundaries (islands); RETR_CCOMP also
    # yields inner boundaries (holes).
    retrieval = cv2.RETR_EXTERNAL if remove_islands else cv2.RETR_CCOMP
    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); indexing with [-2] works on both.
    contours = cv2.findContours(canvas, retrieval, cv2.CHAIN_APPROX_SIMPLE)[-2]
    total_area = cv2.countNonZero(canvas) if remove_islands else canvas.size
    fill_value = 0 if remove_islands else 255
    for contour in contours:
        relative_area = cv2.contourArea(contour) / total_area
        if relative_area < area_threshold:
            # Thickness -1 fills the contour's interior.
            cv2.drawContours(canvas, [contour], -1, fill_value, -1)
    return canvas > 0
def filter_masks_by_relative_area(
    masks: np.ndarray,
    min_relative_area: float = 0.02,
    max_relative_area: float = 1.0
) -> np.ndarray:
    """
    Filters out masks based on their relative area.

    The relative area of a mask is the sum of its values (its foreground pixel
    count for a binary mask) divided by H * W. A mask is kept when
    ``min_relative_area <= relative area <= max_relative_area`` (both bounds
    inclusive).

    Parameters:
        masks (np.ndarray): A 3D numpy array of shape (N, H, W) where each slice
            along the first axis (``masks[i]``) is one mask.
        min_relative_area (float): Minimum relative area for keeping a mask.
        max_relative_area (float): Maximum relative area for keeping a mask.

    Returns:
        np.ndarray: The kept masks, shape (K, H, W) with K <= N, in their
        original order.
    """
    mask_areas = masks.sum(axis=(1, 2))
    total_area = masks.shape[1] * masks.shape[2]
    relative_areas = mask_areas / total_area
    keep = (relative_areas >= min_relative_area) & (
        relative_areas <= max_relative_area)
    return masks[keep]
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Computes the Intersection over Union (IoU) of two masks.

    Parameters:
        mask1, mask2 (np.ndarray): Two mask arrays of broadcast-compatible
            shape; non-zero entries count as foreground.

    Returns:
        float: |mask1 AND mask2| / |mask1 OR mask2|, or 0.0 when both masks are
        empty (the union is zero). Always a Python float.
    """
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        # Two empty masks: define IoU as 0 rather than dividing by zero.
        return 0.0
    return float(intersection / union)
def filter_highly_overlapping_masks(
    masks: np.ndarray,
    iou_threshold: float
) -> np.ndarray:
    """
    Greedily drops masks that overlap an earlier, still-kept mask.

    Masks are visited in their existing order along axis 0. Each kept mask
    suppresses every later, still-kept mask whose IoU with it (via
    `compute_iou`) is strictly greater than `iou_threshold`. A mask that has
    already been suppressed does not suppress anything itself.

    Parameters:
        masks (np.ndarray): Array of shape (N, H, W) holding N masks of size
            H x W.
        iou_threshold (float): IoU strictly above which two masks count as
            overlapping.

    Returns:
        np.ndarray: The surviving masks, in their original order.
    """
    count = masks.shape[0]
    survivors = np.ones(count, dtype=bool)
    for anchor in range(count):
        if not survivors[anchor]:
            # Suppressed masks never suppress later ones.
            continue
        for other in range(anchor + 1, count):
            if not survivors[other]:
                continue
            if compute_iou(masks[anchor], masks[other]) > iou_threshold:
                survivors[other] = False
    return masks[survivors]
def postprocess_masks(
    detections: sv.Detections,
    area_threshold: float = 0.01,
    min_relative_area: float = 0.01,
    max_relative_area: float = 1.0,
    iou_threshold: float = 0.9
) -> sv.Detections:
    """
    Post-processes the masks of detection objects.

    Each mask first has small islands removed and small holes filled. Masks
    whose relative area falls outside [min_relative_area, max_relative_area]
    are then dropped. Finally, masks that highly overlap an earlier kept mask
    are dropped.

    Parameters:
        detections (sv.Detections): Detection objects whose masks are
            processed; the input object is not modified.
        area_threshold (float): Threshold for relative area to remove or fill
            features.
        min_relative_area (float): Minimum relative area threshold for
            detections.
        max_relative_area (float): Maximum relative area threshold for
            detections.
        iou_threshold (float): The IoU threshold above which masks will be
            considered as overlapping.

    Returns:
        sv.Detections: New detections built only from the surviving masks.
        `xyxy` is recomputed from each mask with `sv.mask_to_xyxy`; other
        fields of the input detections are not carried over.

    Raises:
        ValueError: If `detections.mask` is None.
    """
    if detections.mask is None:
        raise ValueError("postprocess_masks requires detections with masks.")
    # Copy so that refining masks below does not mutate the caller's detections.
    masks = detections.mask.copy()
    for index, mask in enumerate(masks):
        without_islands = refine_mask(
            mask=mask,
            area_threshold=area_threshold,
            mode='islands'
        )
        masks[index] = refine_mask(
            mask=without_islands,
            area_threshold=area_threshold,
            mode='holes'
        )
    masks = filter_masks_by_relative_area(
        masks=masks,
        min_relative_area=min_relative_area,
        max_relative_area=max_relative_area)
    masks = filter_highly_overlapping_masks(
        masks=masks,
        iou_threshold=iou_threshold)
    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks
    )
def extract_numbers_in_brackets(text: str) -> List[int]:
    """
    Collects every integer written as ``[<digits>]`` inside ``text``.

    Args:
        text (str): The string to scan.

    Returns:
        List[int]: The bracketed integers in order of appearance. Brackets
        holding anything besides digits (spaces, signs, letters) are skipped.
    """
    return [int(match.group(1)) for match in re.finditer(r'\[(\d+)\]', text)]