"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from carvekit.ml.wrap.fba_matting import FBAMatting | |
from typing import Union, List | |
from PIL import Image | |
from pathlib import Path | |
from carvekit.trimap.cv_gen import CV2TrimapGenerator | |
from carvekit.trimap.generator import TrimapGenerator | |
from carvekit.utils.mask_utils import apply_mask | |
from carvekit.utils.pool_utils import thread_pool_processing | |
from carvekit.utils.image_utils import load_image, convert_image | |
__all__ = ["MattingMethod"]


class MattingMethod:
    """
    Refines the edges of an object mask using a matting neural network and a trimap generator.

    The matting network performs accurate object edge detection with the help of a
    special map called a trimap. A trimap marks three regions: the already known
    general object area, the background, and an unknown band around the object
    boundary that the network scans to resolve the exact edge.
    """

    def __init__(
        self,
        matting_module: FBAMatting,
        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
        device="cpu",
    ):
        """
        Initializes the Matting Method class.

        Args:
            matting_module: Initialized matting neural network class
            trimap_generator: Initialized trimap generator class
            device: Processing device used when applying the refined mask to the image
        """
        self.device = device
        self.matting_module = matting_module
        self.trimap_generator = trimap_generator

    def __call__(
        self,
        images: List[Union[str, Path, Image.Image]],
        masks: List[Union[str, Path, Image.Image]],
    ):
        """
        Refines each mask with the matting network and applies it to its image.

        Pipeline: load/convert the images and masks (masks are converted to mode "L"),
        build one trimap per (image, mask) pair, run the matting module on the
        whole batch, then apply each resulting alpha to its image via ``apply_mask``.

        Args:
            images: list of images (paths or PIL images)
            masks: list of masks (paths or PIL images), one per image, in the same
                order as ``images``

        Returns:
            list of images with the refined mask applied, one per input image

        Raises:
            ValueError: if ``images`` and ``masks`` have different lengths
        """
        if len(images) != len(masks):
            raise ValueError("Images and Masks lists should have same length!")
        if not images:
            # Nothing to refine: avoid running the matting network on an empty batch.
            return []

        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
        masks = thread_pool_processing(
            lambda x: convert_image(load_image(x), mode="L"), masks
        )
        # Pair each image with its mask directly instead of indexing by position.
        trimaps = thread_pool_processing(
            lambda pair: self.trimap_generator(original_image=pair[0], mask=pair[1]),
            list(zip(images, masks)),
        )
        alpha = self.matting_module(images=images, trimaps=trimaps)
        # Index into ``alpha`` (rather than zip) so a short network output still
        # raises instead of silently dropping images.
        return [
            apply_mask(image=image, mask=alpha[idx], device=self.device)
            for idx, image in enumerate(images)
        ]