# fork for Ezsynth (https://github.com/Trentonom0r3/Ezsynth)
import os
import sys
from ctypes import *
from pathlib import Path

import cv2
import numpy as np
# Handle to the ebsynth shared library; loaded lazily on the first call to run().
libebsynth = None
# Output buffers reused across calls, keyed by (height, width, channels),
# so repeated runs at the same resolution do not reallocate.
cached_buffer = {}

# Backend selectors passed as the first argument of ebsynthRun.
EBSYNTH_BACKEND_CPU = 0x0001
EBSYNTH_BACKEND_CUDA = 0x0002
EBSYNTH_BACKEND_AUTO = 0x0000

# Hard limits imposed by the ebsynth library.
EBSYNTH_MAX_STYLE_CHANNELS = 8
EBSYNTH_MAX_GUIDE_CHANNELS = 24

# Vote modes for patch aggregation.
EBSYNTH_VOTEMODE_PLAIN = 0x0001     # weight = 1
EBSYNTH_VOTEMODE_WEIGHTED = 0x0002  # weight = 1/(1+error)
def _normalize_img_shape(img): | |
img_len = len(img.shape) | |
if img_len == 2: | |
sh, sw = img.shape | |
sc = 0 | |
elif img_len == 3: | |
sh, sw, sc = img.shape | |
if sc == 0: | |
sc = 1 | |
img = img[..., np.newaxis] | |
return img | |
def run(img_style, guides,
        patch_size=5,
        num_pyramid_levels=-1,
        num_search_vote_iters=6,
        num_patch_match_iters=4,
        stop_threshold=5,
        uniformity_weight=3500.0,
        extraPass3x3=False,
        ):
    """Synthesize an image in the style of *img_style*, steered by *guides*.

    Parameters
    ----------
    img_style : np.ndarray
        uint8 style image, (H, W) or (H, W, C) with C <= 8.
    guides : list of (source_guide, target_guide, weight)
        Source guides must match the style resolution; all target guides must
        share one resolution (which becomes the output resolution); each
        source/target pair must have the same channel count.
    patch_size : int
        Odd patch size >= 3 (5 -> 5x5 patches).
    num_pyramid_levels : int
        -1 selects the maximum usable pyramid depth automatically.
    extraPass3x3 : bool
        Run an extra polishing pass with 3x3 patches at the finest level.

    Returns
    -------
    np.ndarray
        uint8 result of shape (target_h, target_w, style_channels).  If the
        native library is unavailable (non-Windows), *img_style* is returned
        unchanged as a best-effort fallback.

    Raises
    ------
    ValueError
        On invalid patch size, empty guides, or mismatched/oversized channels
        and resolutions.
    """
    if patch_size < 3:
        raise ValueError("patch_size is too small")
    if patch_size % 2 == 0:
        raise ValueError("patch_size must be an odd number")
    if len(guides) == 0:
        raise ValueError("at least one guide must be specified")

    # Lazily load the native library once per process.
    global libebsynth
    if libebsynth is None:
        if sys.platform[0:3] == 'win':
            libebsynth_path = str(Path(__file__).parent / 'ebsynth.dll')
            libebsynth = CDLL(libebsynth_path)
        else:
            # todo: implement for linux
            pass

        if libebsynth is not None:
            # Declare the full ebsynthRun signature so ctypes marshals
            # ints/floats/pointers correctly on 64-bit Python.
            libebsynth.ebsynthRun.argtypes = (
                c_int,              # backend
                c_int,              # numStyleChannels
                c_int,              # numGuideChannels
                c_int,              # sourceWidth
                c_int,              # sourceHeight
                c_void_p,           # sourceStyleData
                c_void_p,           # sourceGuideData
                c_int,              # targetWidth
                c_int,              # targetHeight
                c_void_p,           # targetGuideData
                c_void_p,           # targetModulationData
                POINTER(c_float),   # styleWeights
                POINTER(c_float),   # guideWeights
                c_float,            # uniformityWeight
                c_int,              # patchSize
                c_int,              # voteMode
                c_int,              # numPyramidLevels
                POINTER(c_int),     # numSearchVoteItersPerLevel
                POINTER(c_int),     # numPatchMatchItersPerLevel
                POINTER(c_int),     # stopThresholdPerLevel
                c_int,              # extraPass3x3
                c_void_p,           # outputNnfData
                c_void_p,           # outputImageData
            )

    if libebsynth is None:
        # No native backend on this platform: degrade gracefully.
        return img_style

    img_style = _normalize_img_shape(img_style)
    sh, sw, sc = img_style.shape
    t_h, t_w, t_c = 0, 0, 0

    if sc > EBSYNTH_MAX_STYLE_CHANNELS:
        raise ValueError(f"error: too many style channels {sc}, maximum number is {EBSYNTH_MAX_STYLE_CHANNELS}")

    # Validate every guide triple and accumulate per-channel weights.
    guides_source = []
    guides_target = []
    guides_weights = []
    for i in range(len(guides)):
        source_guide, target_guide, guide_weight = guides[i]
        source_guide = _normalize_img_shape(source_guide)
        target_guide = _normalize_img_shape(target_guide)
        s_h, s_w, s_c = source_guide.shape
        nt_h, nt_w, nt_c = target_guide.shape

        if s_h != sh or s_w != sw:
            raise ValueError("guide source and style resolution must match style resolution.")
        if t_c == 0:
            # First guide fixes the target (output) resolution.
            t_h, t_w, t_c = nt_h, nt_w, nt_c
        elif nt_h != t_h or nt_w != t_w:
            raise ValueError("guides target resolutions must be equal")
        if s_c != nt_c:
            raise ValueError("guide source and target channels must match exactly.")

        guides_source.append(source_guide)
        guides_target.append(target_guide)
        # Spread the guide's weight evenly across its channels.
        guides_weights += [guide_weight / s_c] * s_c

    guides_source = np.concatenate(guides_source, axis=-1)
    guides_target = np.concatenate(guides_target, axis=-1)
    if guides_source.shape[-1] > EBSYNTH_MAX_GUIDE_CHANNELS:
        raise ValueError(f"error: too many guide channels {guides_source.shape[-1]}, maximum number is {EBSYNTH_MAX_GUIDE_CHANNELS}")
    guides_weights = (c_float * len(guides_weights))(*guides_weights)

    styleWeight = 1.0
    style_weights = [styleWeight / sc for i in range(sc)]
    style_weights = (c_float * sc)(*style_weights)

    # Deepest pyramid level whose coarsest image still fits a full patch.
    maxPyramidLevels = 0
    for level in range(32, -1, -1):
        if min(min(sh, t_h) * pow(2.0, -level),
               min(sw, t_w) * pow(2.0, -level)) >= (2 * patch_size + 1):
            maxPyramidLevels = level + 1
            break

    if num_pyramid_levels == -1:
        num_pyramid_levels = maxPyramidLevels
    num_pyramid_levels = min(num_pyramid_levels, maxPyramidLevels)

    # Per-level iteration schedules (same value at every level).
    num_search_vote_iters_per_level = (c_int * num_pyramid_levels)(*[num_search_vote_iters] * num_pyramid_levels)
    num_patch_match_iters_per_level = (c_int * num_pyramid_levels)(*[num_patch_match_iters] * num_pyramid_levels)
    stop_threshold_per_level = (c_int * num_pyramid_levels)(*[stop_threshold] * num_pyramid_levels)

    # Reuse one output buffer per (h, w, c) to avoid reallocating each call.
    buffer = cached_buffer.get((t_h, t_w, sc), None)
    if buffer is None:
        buffer = create_string_buffer(t_h * t_w * sc)
        cached_buffer[(t_h, t_w, sc)] = buffer

    libebsynth.ebsynthRun(EBSYNTH_BACKEND_AUTO,         # backend
                          sc,                           # numStyleChannels
                          guides_source.shape[-1],      # numGuideChannels
                          sw,                           # sourceWidth
                          sh,                           # sourceHeight
                          img_style.tobytes(),          # sourceStyleData, scan-line order
                          guides_source.tobytes(),      # sourceGuideData, scan-line order
                          t_w,                          # targetWidth
                          t_h,                          # targetHeight
                          guides_target.tobytes(),      # targetGuideData, scan-line order
                          None,                         # targetModulationData; NULL disables modulation
                          style_weights,                # styleWeights (numStyleChannels floats)
                          guides_weights,               # guideWeights (numGuideChannels floats)
                          uniformity_weight,            # reasonable range 500-15000, 3500 is a good default
                          patch_size,                   # odd sizes only
                          EBSYNTH_VOTEMODE_WEIGHTED,    # weighted vote gives a sharper result
                          num_pyramid_levels,
                          num_search_vote_iters_per_level,   # coarse first, fine last
                          num_patch_match_iters_per_level,   # coarse first, fine last
                          stop_threshold_per_level,          # stop when per-pixel change falls below this
                          1 if extraPass3x3 else 0,     # extra 3x3 polish at the finest level
                          None,                         # outputNnfData; NULL to ignore
                          buffer                        # outputImageData
                          )

    # Copy so the cached buffer can be safely reused by the next call.
    return np.frombuffer(buffer, dtype=np.uint8).reshape((t_h, t_w, sc)).copy()
# transfer color from source to target
def color_transfer(img_source, img_target):
    """Recolor *img_target* with the palette of *img_source*.

    Both images are expected as BGR uint8 arrays (OpenCV convention).  The
    grayscale structure of source and target serves as the single guide pair;
    each color channel of the source is then synthesized independently onto
    the target's geometry and the channels are restacked.
    """
    guides = [(cv2.cvtColor(img_source, cv2.COLOR_BGR2GRAY),
               cv2.cvtColor(img_target, cv2.COLOR_BGR2GRAY),
               1)]
    h, w, c = img_source.shape
    result = []
    for i in range(c):
        # Style = one source channel; keep it 3-D via the i:i+1 slice.
        result.append(
            run(img_source[..., i:i + 1], guides=guides,
                patch_size=11,
                num_pyramid_levels=40,
                num_search_vote_iters=6,
                num_patch_match_iters=4,
                stop_threshold=5,
                uniformity_weight=500.0,
                extraPass3x3=True,
                )
        )
    return np.concatenate(result, axis=-1)
def task(img_style, guides):
    """Run ebsynth on (*img_style*, *guides*) with the project's default
    quality settings (5x5 patches, 6 pyramid levels, heavier iteration
    counts than run()'s own defaults).  Thin wrapper intended as a
    parallelizable unit of work.
    """
    return run(img_style,
               guides,
               patch_size=5,
               num_pyramid_levels=6,
               num_search_vote_iters=12,
               num_patch_match_iters=6,
               uniformity_weight=3500.0,
               extraPass3x3=False
               )