Mov2Mov / ebsynth /_ebsynth.py
AmyTheKamiwazaGirl2001's picture
Upload 37 files
99c376f verified
# fork for Ezsynth(https://github.com/Trentonom0r3/Ezsynth)
import os
import sys
from ctypes import *
from pathlib import Path
import cv2
import numpy as np
libebsynth = None
cached_buffer = {}
EBSYNTH_BACKEND_CPU = 0x0001
EBSYNTH_BACKEND_CUDA = 0x0002
EBSYNTH_BACKEND_AUTO = 0x0000
EBSYNTH_MAX_STYLE_CHANNELS = 8
EBSYNTH_MAX_GUIDE_CHANNELS = 24
EBSYNTH_VOTEMODE_PLAIN = 0x0001 # weight = 1
EBSYNTH_VOTEMODE_WEIGHTED = 0x0002 # weight = 1/(1+error)
def _normalize_img_shape(img):
img_len = len(img.shape)
if img_len == 2:
sh, sw = img.shape
sc = 0
elif img_len == 3:
sh, sw, sc = img.shape
if sc == 0:
sc = 1
img = img[..., np.newaxis]
return img
def run(img_style, guides,
patch_size=5,
num_pyramid_levels=-1,
num_search_vote_iters=6,
num_patch_match_iters=4,
stop_threshold=5,
uniformity_weight=3500.0,
extraPass3x3=False,
):
if patch_size < 3:
raise ValueError("patch_size is too small")
if patch_size % 2 == 0:
raise ValueError("patch_size must be an odd number")
if len(guides) == 0:
raise ValueError("at least one guide must be specified")
global libebsynth
if libebsynth is None:
if sys.platform[0:3] == 'win':
libebsynth_path = str(Path(__file__).parent / 'ebsynth.dll')
libebsynth = CDLL(libebsynth_path)
else:
# todo: implement for linux
pass
if libebsynth is not None:
libebsynth.ebsynthRun.argtypes = ( \
c_int,
c_int,
c_int,
c_int,
c_int,
c_void_p,
c_void_p,
c_int,
c_int,
c_void_p,
c_void_p,
POINTER(c_float),
POINTER(c_float),
c_float,
c_int,
c_int,
c_int,
POINTER(c_int),
POINTER(c_int),
POINTER(c_int),
c_int,
c_void_p,
c_void_p
)
if libebsynth is None:
return img_style
img_style = _normalize_img_shape(img_style)
sh, sw, sc = img_style.shape
t_h, t_w, t_c = 0, 0, 0
if sc > EBSYNTH_MAX_STYLE_CHANNELS:
raise ValueError(f"error: too many style channels {sc}, maximum number is {EBSYNTH_MAX_STYLE_CHANNELS}")
guides_source = []
guides_target = []
guides_weights = []
for i in range(len(guides)):
source_guide, target_guide, guide_weight = guides[i]
source_guide = _normalize_img_shape(source_guide)
target_guide = _normalize_img_shape(target_guide)
s_h, s_w, s_c = source_guide.shape
nt_h, nt_w, nt_c = target_guide.shape
if s_h != sh or s_w != sw:
raise ValueError("guide source and style resolution must match style resolution.")
if t_c == 0:
t_h, t_w, t_c = nt_h, nt_w, nt_c
elif nt_h != t_h or nt_w != t_w:
raise ValueError("guides target resolutions must be equal")
if s_c != nt_c:
raise ValueError("guide source and target channels must match exactly.")
guides_source.append(source_guide)
guides_target.append(target_guide)
guides_weights += [guide_weight / s_c] * s_c
guides_source = np.concatenate(guides_source, axis=-1)
guides_target = np.concatenate(guides_target, axis=-1)
guides_weights = (c_float * len(guides_weights))(*guides_weights)
styleWeight = 1.0
style_weights = [styleWeight / sc for i in range(sc)]
style_weights = (c_float * sc)(*style_weights)
maxPyramidLevels = 0
for level in range(32, -1, -1):
if min(min(sh, t_h) * pow(2.0, -level), \
min(sw, t_w) * pow(2.0, -level)) >= (2 * patch_size + 1):
maxPyramidLevels = level + 1
break
if num_pyramid_levels == -1:
num_pyramid_levels = maxPyramidLevels
num_pyramid_levels = min(num_pyramid_levels, maxPyramidLevels)
num_search_vote_iters_per_level = (c_int * num_pyramid_levels)(*[num_search_vote_iters] * num_pyramid_levels)
num_patch_match_iters_per_level = (c_int * num_pyramid_levels)(*[num_patch_match_iters] * num_pyramid_levels)
stop_threshold_per_level = (c_int * num_pyramid_levels)(*[stop_threshold] * num_pyramid_levels)
buffer = cached_buffer.get((t_h, t_w, sc), None)
if buffer is None:
buffer = create_string_buffer(t_h * t_w * sc)
cached_buffer[(t_h, t_w, sc)] = buffer
libebsynth.ebsynthRun(EBSYNTH_BACKEND_AUTO, # backend
sc, # numStyleChannels
guides_source.shape[-1], # numGuideChannels
sw, # sourceWidth
sh, # sourceHeight
img_style.tobytes(),
# sourceStyleData (width * height * numStyleChannels) bytes, scan-line order
guides_source.tobytes(),
# sourceGuideData (width * height * numGuideChannels) bytes, scan-line order
t_w, # targetWidth
t_h, # targetHeight
guides_target.tobytes(),
# targetGuideData (width * height * numGuideChannels) bytes, scan-line order
None,
# targetModulationData (width * height * numGuideChannels) bytes, scan-line order; pass NULL to switch off the modulation
style_weights, # styleWeights (numStyleChannels) floats
guides_weights, # guideWeights (numGuideChannels) floats
uniformity_weight,
# uniformityWeight reasonable values are between 500-15000, 3500 is a good default
patch_size, # patchSize odd sizes only, use 5 for 5x5 patch, 7 for 7x7, etc.
EBSYNTH_VOTEMODE_WEIGHTED, # voteMode use VOTEMODE_WEIGHTED for sharper result
num_pyramid_levels, # numPyramidLevels
num_search_vote_iters_per_level,
# numSearchVoteItersPerLevel how many search/vote iters to perform at each level (array of ints, coarse first, fine last)
num_patch_match_iters_per_level,
# numPatchMatchItersPerLevel how many Patch-Match iters to perform at each level (array of ints, coarse first, fine last)
stop_threshold_per_level,
# stopThresholdPerLevel stop improving pixel when its change since last iteration falls under this threshold
1 if extraPass3x3 else 0,
# extraPass3x3 perform additional polishing pass with 3x3 patches at the finest level, use 0 to disable
None, # outputNnfData (width * height * 2) ints, scan-line order; pass NULL to ignore
buffer # outputImageData (width * height * numStyleChannels) bytes, scan-line order
)
return np.frombuffer(buffer, dtype=np.uint8).reshape((t_h, t_w, sc)).copy()
# transfer color from source to target
def color_transfer(img_source, img_target):
guides = [(cv2.cvtColor(img_source, cv2.COLOR_BGR2GRAY),
cv2.cvtColor(img_target, cv2.COLOR_BGR2GRAY),
1)]
h, w, c = img_source.shape
result = []
for i in range(c):
result += [
run(img_source[..., i:i + 1], guides=guides,
patch_size=11,
num_pyramid_levels=40,
num_search_vote_iters=6,
num_patch_match_iters=4,
stop_threshold=5,
uniformity_weight=500.0,
extraPass3x3=True,
)
]
return np.concatenate(result, axis=-1)
def task(img_style, guides):
return run(img_style,
guides,
patch_size=5,
num_pyramid_levels=6,
num_search_vote_iters=12,
num_patch_match_iters=6,
uniformity_weight=3500.0,
extraPass3x3=False
)