Spaces:
Sleeping
Sleeping
""" | |
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
""" | |
import PIL.Image | |
import cv2 | |
import numpy as np | |
class CV2TrimapGenerator:
    """Build a trimap from a predicted object mask using OpenCV morphology.

    The mask is first eroded and then dilated; the band between the two
    results is left as the "unknown" region of the trimap.
    """

    def __init__(self, kernel_size: int = 30, erosion_iters: int = 1):
        """
        Initialize a new CV2TrimapGenerator instance

        Args:
            kernel_size: The size of the offset from the object mask
                in pixels when an unknown area is detected in the trimap
            erosion_iters: The number of iterations of erosion that
                the object's mask will be subjected to before forming an unknown area
        """
        self.kernel_size = kernel_size
        self.erosion_iters = erosion_iters

    def __call__(
        self, original_image: PIL.Image.Image, mask: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Generates trimap based on predicted object mask to refine object mask borders.
        Based on cv2 erosion algorithm.

        The resulting pixel values are 0, 255, or a value in the range
        127..199 for everything in between.

        Args:
            original_image: Original image. Only its size is used, to check
                that it matches the mask.
            mask: Predicted object mask, must be in "L" mode.

        Returns:
            Generated trimap for image, as an "L" mode image.

        Raises:
            ValueError: If the mask is not in "L" mode, if the mask and image
                sizes differ, or if ``kernel_size`` is negative.
        """
        if mask.mode != "L":
            raise ValueError("Input mask has wrong color mode.")
        if mask.size != original_image.size:
            raise ValueError("Sizes of input image and predicted mask doesn't equal")
        if self.kernel_size < 0:
            # Without this check the failure surfaces inside np.ones() below as
            # a generic "negative dimensions" error that does not name the
            # misconfigured parameter.
            raise ValueError(
                f"kernel_size must be non-negative, got {self.kernel_size}"
            )

        # noinspection PyTypeChecker
        mask_array = np.array(mask)

        # Square dilation kernel reaching kernel_size pixels in every direction.
        pixels = 2 * self.kernel_size + 1
        kernel = np.ones((pixels, pixels), np.uint8)

        if self.erosion_iters > 0:
            erosion_kernel = np.ones((3, 3), np.uint8)
            erode = cv2.erode(mask_array, erosion_kernel, iterations=self.erosion_iters)
            # Erosion only decides which pixels are dropped (set to 0); the
            # surviving pixels keep their original mask value.
            erode = np.where(erode == 0, 0, mask_array)
        else:
            erode = mask_array.copy()

        dilation = cv2.dilate(erode, kernel, iterations=1)

        # Fully white pixels of the dilated mask become the "unknown" gray.
        dilation = np.where(dilation == 255, 127, dilation)
        # Pixels that are still bright after erosion are tagged with the
        # temporary marker value 200; everything else takes the dilated value.
        trimap = np.where(erode > 127, 200, dilation)

        # Anything darker than the unknown gray, or brighter than the marker,
        # is forced to 0.
        trimap = np.where(trimap < 127, 0, trimap)
        trimap = np.where(trimap > 200, 0, trimap)
        # Finally turn the temporary marker into 255.
        # NOTE(review): a non-core pixel whose dilated value is exactly 200 is
        # also promoted to 255 here, because 200 doubles as the marker value.
        trimap = np.where(trimap == 200, 255, trimap)

        return PIL.Image.fromarray(trimap).convert("L")