"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import pathlib | |
from typing import List, Union | |
import PIL.Image | |
import numpy as np | |
import torch | |
from PIL import Image | |
from carvekit.ml.arch.u2net.u2net import U2NETArchitecture | |
from carvekit.ml.files.models_loc import u2net_full_pretrained | |
from carvekit.utils.image_utils import load_image, convert_image | |
from carvekit.utils.pool_utils import thread_pool_processing, batch_generator | |
__all__ = ["U2NET"] | |
class U2NET(U2NETArchitecture):
    """U^2-Net model interface.

    Adds image loading, preprocessing, batched inference and conversion of the
    network output into PIL masks on top of ``U2NETArchitecture``.
    """

    def __init__(
        self,
        layers_cfg="full",
        device="cpu",
        input_image_size: Union[List[int], int] = 320,
        batch_size: int = 10,
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the U2NET model

        Args:
            layers_cfg: neural network layers configuration
            device: processing device
            input_image_size: input image size. Either a single int (used for
                both sides) or a list/tuple whose first two items are used.
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: load the weights located by ``u2net_full_pretrained()``
            fp16: use fp16 precision // not supported at this moment
                (the flag is accepted but not used).
        """
        super().__init__(cfg_type=layers_cfg, out_ch=1)
        self.device = device
        self.batch_size = batch_size
        # Accept tuples as well as lists; a tuple previously fell through to
        # the scalar branch and produced a nested ((w, h), (w, h)) size.
        if isinstance(input_image_size, (list, tuple)):
            self.input_image_size = tuple(input_image_size[:2])
        else:
            self.input_image_size = (input_image_size, input_image_size)
        self.to(device)
        if load_pretrained:
            self.load_state_dict(
                torch.load(u2net_full_pretrained(), map_location=self.device)
            )
        # Inference-only wrapper: always switch to eval mode.
        self.eval()

    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image; must have at least three channels, since the
                first three are read individually

        Returns:
            input for neural network: a float tensor with a leading batch axis
            of size 1 followed by the three channels
        """
        # resample=3 is Pillow's bicubic filter code.
        resized = data.resize(self.input_image_size, resample=3)
        # noinspection PyTypeChecker
        resized_arr = np.array(resized, dtype=float)
        # Scale by the image's own maximum; skip for all-black images to
        # avoid dividing by zero.
        peak = np.max(resized_arr)
        if peak != 0:
            resized_arr /= peak
        # Per-channel mean/std normalization (the constants match the widely
        # used ImageNet statistics).
        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
        # HWC -> CHW, then add the batch axis.
        temp_image = temp_image.transpose((2, 0, 1))
        temp_image = np.expand_dims(temp_image, 0)
        return torch.from_numpy(temp_image).type(torch.FloatTensor)

    @staticmethod
    def data_postprocessing(
        data: torch.Tensor, original_image: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Declared as a static method: it takes no ``self`` and is invoked as
        ``self.data_postprocessing(mask, image)`` from ``__call__``.

        Args:
            data: output data from neural network for a single image
            original_image: input image which was used for predicted data

        Returns:
            Segmentation mask as PIL Image instance, resized to the size of
            ``original_image``
        """
        data = data.unsqueeze(0)
        mask = data[:, 0, :, :]
        ma = torch.max(mask)  # Normalizes prediction
        mi = torch.min(mask)
        spread = ma - mi
        if spread == 0:
            # Uniform prediction: min-max scaling is undefined (0 / 0 -> NaN),
            # so fall back to an all-zero mask instead of emitting NaNs.
            predict = torch.zeros_like(mask).squeeze()
        else:
            predict = ((mask - mi) / spread).squeeze()
        predict_np = predict.detach().cpu().numpy() * 255
        mask = Image.fromarray(predict_np).convert("L")
        # resample=3 is Pillow's bicubic filter code.
        mask = mask.resize(original_image.size, resample=3)
        return mask

    def __call__(
        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
    ) -> List[PIL.Image.Image]:
        """
        Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances
        """
        collect_masks = []
        for image_batch in batch_generator(images, self.batch_size):
            # Use a separate local name so the ``images`` argument is not
            # rebound while it is still being iterated in batches.
            loaded = thread_pool_processing(
                lambda x: convert_image(load_image(x)), image_batch
            )
            batches = torch.vstack(
                thread_pool_processing(self.data_preprocessing, loaded)
            )
            with torch.no_grad():
                batches = batches.to(self.device)
                # Only the first of the seven network outputs is used.
                masks, d2, d3, d4, d5, d6, d7 = super().__call__(batches)
                masks_cpu = masks.cpu()
                # Drop references to the unused outputs and device tensors
                # before post-processing.
                del d2, d3, d4, d5, d6, d7, batches, masks
            batch_masks = thread_pool_processing(
                lambda x: self.data_postprocessing(masks_cpu[x], loaded[x]),
                range(len(loaded)),
            )
            collect_masks += batch_masks
        return collect_masks