|
import numpy as np |
|
import mediapipe as mp |
|
import uuid |
|
import cv2 |
|
|
|
from PIL import Image, ImageColor |
|
from mediapipe.tasks import python |
|
from mediapipe.tasks.python import vision |
|
from scipy.ndimage import binary_dilation |
|
from croper import Croper |
|
|
|
# Location of the TFLite segmentation model loaded by MediaPipe.
segment_model = "checkpoints/selfie_multiclass_256x256.tflite"

# One segmenter is created at import time and shared by the functions in this
# module. `output_category_mask=True` asks for a per-pixel category mask,
# which the mask helpers compare against integer labels.
base_options = python.BaseOptions(model_asset_path=segment_model)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)
|
|
|
def restore_result(croper, category, generated_image):
    """Paste a generated square image back into the original image and save it.

    The generated image is resized to ``croper.square_length``, cropped with
    the cropper's ``square_*`` box (presumably the part of the square that
    overlaps the original image -- confirm against ``Croper``), and pasted
    onto a copy of ``croper.input_image`` at ``origin_start_x/y`` through the
    mask produced by ``get_restore_mask_image``.

    Args:
        croper: Cropper object that produced the square crop; supplies the
            original image and the crop geometry.
        category: Segmentation category name, forwarded to
            ``get_restore_mask_image``.
        generated_image: PIL image generated from the square crop.

    Returns:
        Tuple ``(restored_image, path)``: the composited PIL image and the
        path it was written to (``output/<uuid4>.<png|jpg>``).
    """
    # Function-scope import so this block does not depend on the file header.
    import os

    side = croper.square_length
    resized_image = generated_image.resize((side, side))

    crop_box = (
        croper.square_start_x,
        croper.square_start_y,
        croper.square_end_x,
        croper.square_end_y,
    )
    cropped_generated_image = resized_image.crop(crop_box)
    paste_mask = get_restore_mask_image(croper, category, cropped_generated_image)

    # Work on a copy so the cropper's original image is left untouched.
    restored_image = croper.input_image.copy()
    restored_image.paste(
        cropped_generated_image,
        (croper.origin_start_x, croper.origin_start_y),
        paste_mask,
    )

    # JPEG cannot store an alpha channel, so RGBA results are written as PNG.
    extension = 'png' if restored_image.mode == 'RGBA' else 'jpg'

    # Create the output directory on first use; save() would otherwise raise
    # FileNotFoundError when it is missing.
    os.makedirs("output", exist_ok=True)
    path = f"output/{uuid.uuid4()}.{extension}"
    restored_image.save(path)

    return restored_image, path
|
|
|
def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
    """Segment ``input_image`` and return the square crop around a category.

    Args:
        input_image: PIL image to segment; also handed to ``Croper``.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value is
            treated like ``"face"``.
        input_size: Mask size passed to ``Croper`` (converted with ``int``).
        mask_expansion: Expansion passed to ``Croper`` (converted with ``int``).
        mask_dilation: Number of binary-dilation iterations applied to the
            category mask (converted with ``int``).

    Returns:
        Tuple ``(origin_area_image, croper)``: the cropper's
        ``resized_square_image`` and the ``Croper`` instance itself.
    """
    mask_size = int(input_size)
    mask_expansion = int(mask_expansion)
    # Cast like the other numeric arguments: scipy's binary_dilation requires
    # an integer iteration count and rejects e.g. a float from a UI slider.
    mask_dilation = int(mask_dilation)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, mask_dilation)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, mask_dilation)
    else:
        # "face" and every unrecognised category share the face mask.
        target_mask = get_face_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image

    return origin_area_image, croper
|
|
|
def segment_image_withmask(input_image, category, generate_size, mask_expansion, mask_dilation):
    """Segment ``input_image`` and return the square crop plus its mask.

    Args:
        input_image: PIL image to segment; also handed to ``Croper``.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value is
            treated like ``"face"``.
        generate_size: Mask size passed to ``Croper`` (converted with ``int``).
        mask_expansion: Expansion passed to ``Croper`` (converted with ``int``).
        mask_dilation: Number of binary-dilation iterations applied to the
            category mask (converted with ``int``).

    Returns:
        Tuple ``(origin_area_image, mask_image, croper)``: the cropper's
        ``resized_square_image``, its ``resized_square_mask_image`` and the
        ``Croper`` instance itself.
    """
    mask_size = int(generate_size)
    mask_expansion = int(mask_expansion)
    # Cast like the other numeric arguments: scipy's binary_dilation requires
    # an integer iteration count and rejects e.g. a float from a UI slider.
    mask_dilation = int(mask_dilation)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, mask_dilation)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, mask_dilation)
    else:
        # "face" and every unrecognised category share the face mask.
        target_mask = get_face_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image
    mask_image = croper.resized_square_mask_image

    return origin_area_image, mask_image, croper
|
|
|
def segment_image_with_gray(input_image, category, input_size, mask_expansion, mask_dilation):
    """Segment ``input_image`` and grey out the category inside the crop.

    The square crop from ``Croper`` is converted to greyscale and that
    greyscale version is pasted back over the crop through the crop's mask
    image, so only the masked region loses its colour.

    Args:
        input_image: PIL image to segment; also handed to ``Croper``.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value is
            treated like ``"face"``.
        input_size: Mask size passed to ``Croper`` (converted with ``int``).
        mask_expansion: Expansion passed to ``Croper`` (converted with ``int``).
        mask_dilation: Number of binary-dilation iterations applied to the
            category mask (converted with ``int``).

    Returns:
        Tuple ``(origin_area_image, croper)``: the partially greyed square
        crop and the ``Croper`` instance.
    """
    mask_size = int(input_size)
    mask_expansion = int(mask_expansion)
    # Cast like the other numeric arguments: scipy's binary_dilation requires
    # an integer iteration count and rejects e.g. a float from a UI slider.
    mask_dilation = int(mask_dilation)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, mask_dilation)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, mask_dilation)
    else:
        # "face" and every unrecognised category share the face mask.
        target_mask = get_face_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image
    mask_image = croper.resized_square_mask_image

    gray_area_image = origin_area_image.convert('L')
    # paste() edits the image in place, so the object obtained from
    # croper.resized_square_image is the one that gets modified.
    origin_area_image.paste(gray_area_image, (0, 0), mask_image)

    return origin_area_image, croper
|
|
|
def segment_image_with_color(input_image, color, alpha, category, input_size, mask_expansion, mask_dilation):
    """Recolour a category in ``input_image`` and return the square crop.

    Pixels of the selected category get the hue and saturation of ``color``
    while keeping their own HSV value channel; the result is blended with a
    greyscale version of the image and written back into the masked region
    only. The recoloured image is then cropped with ``Croper``.

    Args:
        input_image: PIL image to segment and recolour.
        color: Colour string understood by ``PIL.ImageColor.getcolor``.
        alpha: Weight of the recoloured pixels in ``cv2.addWeighted``; the
            greyscale image gets ``1 - alpha``.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value is
            treated like ``"face"``.
        input_size: Mask size passed to ``Croper`` (converted with ``int``).
        mask_expansion: Expansion passed to ``Croper`` (converted with ``int``).
        mask_dilation: Number of binary-dilation iterations applied to the
            category mask (converted with ``int``).

    Returns:
        Tuple ``(origin_area_image, croper)``: the cropper's
        ``resized_square_image`` (built from the recoloured image) and the
        ``Croper`` instance.
    """
    mask_size = int(input_size)
    mask_expansion = int(mask_expansion)
    # Cast like the other numeric arguments: scipy's binary_dilation requires
    # an integer iteration count and rejects e.g. a float from a UI slider.
    mask_dilation = int(mask_dilation)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, mask_dilation)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, mask_dilation)
    else:
        # "face" and every unrecognised category share the face mask.
        target_mask = get_face_mask(category_mask_np, mask_dilation)

    # Black out everything outside the target region, then move to HSV.
    cv2_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    cv2_target_image = np.copy(cv2_image)
    cv2_target_image[~target_mask] = 0
    cv2_target_hsv = cv2.cvtColor(cv2_target_image, cv2.COLOR_BGR2HSV)

    # Express the requested colour as a single HSV pixel.
    target_rgb = ImageColor.getcolor(color, "RGB")
    target_hsv = cv2.cvtColor(np.array([[target_rgb]], dtype=np.uint8), cv2.COLOR_RGB2HSV)[0][0]

    # Overwrite hue and saturation everywhere; the value channel is kept.
    # Pixels outside the mask are overwritten too but are never copied back.
    cv2_target_hsv[..., 0] = target_hsv[0]
    cv2_target_hsv[..., 1] = target_hsv[1]

    cv2_target_bgr = cv2.cvtColor(cv2_target_hsv, cv2.COLOR_HSV2BGR)
    gray_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
    gray_image_3d = cv2.merge([gray_image] * 3)
    cv2_image_w = cv2.addWeighted(cv2_target_bgr, alpha, gray_image_3d, 1 - alpha, 0)
    # Only the target region takes the blended pixels.
    cv2_image[target_mask] = cv2_image_w[target_mask]

    input_image = Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image

    return origin_area_image, croper
|
|
|
def get_face_mask(category_mask_np, dilation=1):
    """Return a boolean mask of the pixels labelled 3 (face skin).

    Args:
        category_mask_np: Array of per-pixel category labels.
        dilation: Number of binary-dilation iterations used to grow the mask.
            Converted with ``int``; values of 0 or less leave the mask as is.

    Returns:
        Boolean array with the same shape as ``category_mask_np``.
    """
    face_skin_mask = category_mask_np == 3

    # binary_dilation only accepts integer iteration counts, so coerce values
    # such as 2.0 instead of letting them raise TypeError.
    dilation = int(dilation)
    if dilation > 0:
        face_skin_mask = binary_dilation(face_skin_mask, iterations=dilation)

    return face_skin_mask
|
|
|
def get_clothes_mask(category_mask_np, dilation=1):
    """Return a boolean mask of the pixels labelled 2 (body skin) or 4 (clothes).

    The combined mask is always grown by 4 dilation iterations; ``dilation``
    adds further iterations on top of that.

    Args:
        category_mask_np: Array of per-pixel category labels.
        dilation: Extra binary-dilation iterations. Converted with ``int``;
            values of 0 or less add nothing beyond the fixed 4 iterations.

    Returns:
        Boolean array with the same shape as ``category_mask_np``.
    """
    body_skin_mask = category_mask_np == 2
    clothes_mask = category_mask_np == 4
    combined_mask = np.logical_or(body_skin_mask, clothes_mask)

    # Fixed base growth applied regardless of the requested dilation.
    combined_mask = binary_dilation(combined_mask, iterations=4)

    # binary_dilation only accepts integer iteration counts, so coerce values
    # such as 2.0 instead of letting them raise TypeError.
    dilation = int(dilation)
    if dilation > 0:
        combined_mask = binary_dilation(combined_mask, iterations=dilation)

    return combined_mask
|
|
|
def get_hair_mask(category_mask_np, dilation=1):
    """Return a boolean mask of the pixels labelled 1 (hair).

    Args:
        category_mask_np: Array of per-pixel category labels.
        dilation: Number of binary-dilation iterations used to grow the mask.
            Converted with ``int``; values of 0 or less leave the mask as is.

    Returns:
        Boolean array with the same shape as ``category_mask_np``.
    """
    hair_mask = category_mask_np == 1

    # binary_dilation only accepts integer iteration counts, so coerce values
    # such as 2.0 instead of letting them raise TypeError.
    dilation = int(dilation)
    if dilation > 0:
        hair_mask = binary_dilation(hair_mask, iterations=dilation)

    return hair_mask
|
|
|
def get_restore_mask_image(croper, category, generated_image):
    """Build the paste mask used when restoring a generated image.

    The generated image is segmented again, the mask for ``category`` is taken
    without extra dilation, and it is OR-ed with ``croper.corp_mask``.

    Args:
        croper: Cropper object providing ``corp_mask``; it must be
            broadcast-compatible with the segmentation mask of
            ``generated_image`` for ``np.logical_or`` to succeed.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value is
            treated like ``"face"``, matching the segment_* functions.
        generated_image: PIL image to segment.

    Returns:
        PIL image built from a uint8 array holding 255 where either mask is
        set and 0 elsewhere.
    """
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, 0)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, 0)
    else:
        # Fall back to the face mask so an unrecognised category no longer
        # leaves target_mask unbound (previously an UnboundLocalError).
        target_mask = get_face_mask(category_mask_np, 0)

    combined_mask = np.logical_or(target_mask, croper.corp_mask)
    mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
    return mask_image