# Relax the vector-font sequence representation and precompute auxiliary
# points along each drawing command (see accompanying paper for details).
import argparse
import multiprocessing as mp
import os
import numpy as np
from dataclasses import dataclass
from tqdm import tqdm
def numericalize(cmd, n=64):
    """Quantize coordinates from the normalized [0, 30] range onto an
    n-bin integer grid (values outside the range are clamped to [0, n-1]).

    NOTE: shall only be called after normalization.
    """
    scaled = cmd / 30 * n
    return np.clip(scaled.round(), 0, n - 1).astype(int)
def denumericalize(cmd, n=64):
    """Map integer grid indices back to coordinates in the [0, 30] range
    (inverse of :func:`numericalize`, up to quantization error)."""
    fraction = cmd / n
    return fraction * 30
def cal_aux_bezier_pts(font_seq, opts):
    """
    Calculate auxiliary points along each drawing command of every glyph.

    For each of the ``opts.char_num`` glyphs and each of its
    ``opts.max_seq_len`` commands, three points at t = 0.25 / 0.5 / 0.75 are
    sampled on the command's geometry: linear interpolation p0 -> p3 for
    move/line commands, cubic Bezier evaluation for curve commands, and six
    zeros for the class-0 (padding) command.

    NOTE: the coordinate columns of ``font_seq`` are quantized in place
    (``denumericalize(numericalize(...))``), i.e. the input array is mutated.

    Args:
        font_seq: float array of shape (opts.char_num, opts.max_seq_len, 12);
            columns 0-3 are a one-hot command class, columns 4-11 hold the
            relaxed coordinates p0, p1, p2, p3 (two values each).
        opts: namespace providing ``char_num`` and ``max_seq_len``.

    Returns:
        np.ndarray of shape (opts.char_num, opts.max_seq_len, 6).
    """
    sample_ts = (0.25, 0.5, 0.75)
    pts_aux_all = []
    for j in range(opts.char_num):
        char_seq = font_seq[j]  # shape: (opts.max_seq_len, 12)
        pts_aux_char = []
        for k in range(opts.max_seq_len):
            stroke_seq = char_seq[k]
            stroke_cmd = np.argmax(stroke_seq[:4], -1)
            # Snap coordinates onto the discrete grid (in-place mutation).
            stroke_seq[4:] = denumericalize(numericalize(stroke_seq[4:]))
            p0, p1, p2, p3 = stroke_seq[4:6], stroke_seq[6:8], stroke_seq[8:10], stroke_seq[10:12]
            pts_aux_stroke = []
            if stroke_cmd == 0:  # padding command: no geometry
                pts_aux_stroke = [0] * 6
            elif stroke_cmd in (1, 2):  # move or line: both interpolate p0 -> p3
                for t in sample_ts:
                    coord_t = p0 + t*(p3-p0)
                    pts_aux_stroke.append(coord_t[0])
                    pts_aux_stroke.append(coord_t[1])
            elif stroke_cmd == 3:  # cubic Bezier curve
                for t in sample_ts:
                    coord_t = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
                    pts_aux_stroke.append(coord_t[0])
                    pts_aux_stroke.append(coord_t[1])
            pts_aux_char.append(np.array(pts_aux_stroke))
        pts_aux_all.append(np.array(pts_aux_char))
    return np.array(pts_aux_all)
def relax_rep(opts):
    """
    relaxing the sequence representation, details are shown in paper

    Reads 'sequence.npy' / 'seq_len.npy' from every font directory under
    <output_path>/<language>/<split>, prepends each command's start point to
    its argument vector (making 8 argument values per command), and writes
    'sequence_relaxed.npy' plus auxiliary Bezier points ('pts_aux.npy') back
    into the same directory. Fonts are partitioned across worker processes.
    """
    data_path = os.path.join(opts.output_path, opts.language, opts.split)
    font_dirs = os.listdir(data_path)
    font_dirs.sort()
    num_fonts = len(font_dirs)
    print(f"Number {opts.split} fonts before processing", num_fonts)
    # One process per core, minus one to keep the machine responsive.
    num_processes = mp.cpu_count() - 1
    # num_processes = 1
    fonts_per_process = num_fonts // num_processes + 1

    def process(process_id):
        # Each worker handles one contiguous slice of the sorted font list.
        for i in tqdm(range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process)):
            if i >= num_fonts:
                break
            font_dir = os.path.join(data_path, font_dirs[i])
            # Last dim after reshape: presumably 4 one-hot command bits
            # followed by 6 coordinate args — TODO confirm against the writer.
            font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.char_num, opts.max_seq_len, -1)
            font_len = np.load(os.path.join(font_dir, 'seq_len.npy')).reshape(-1)
            cmd = font_seq[:, :, :4]
            args = font_seq[:, :, 4:]
            ret = []
            for j in range(opts.char_num):
                char_cmds = cmd[j]
                char_args = args[j]
                char_len = font_len[j]
                new_args = []
                for k in range(char_len):
                    cur_cls = np.argmax(char_cmds[k], -1)
                    cur_arg = char_args[k]
                    if k - 1 > -1:
                        pre_arg = char_args[k - 1]
                    if cur_cls == 1:  # when k == 0, cur_cls == 1
                        # 'move': its start point is its own end point.
                        cur_arg = np.concatenate((np.array([cur_arg[-2], cur_arg[-1]]), cur_arg), -1)
                    else:
                        # line/curve: start point is the previous command's end point.
                        cur_arg = np.concatenate((np.array([pre_arg[-2], pre_arg[-1]]), cur_arg), -1)
                    new_args.append(cur_arg)
                # Pad to max_seq_len with all-zero 8-dim argument vectors.
                while(len(new_args)) < opts.max_seq_len:
                    new_args.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
                new_args = np.array(new_args)
                new_seq = np.concatenate((char_cmds, new_args),-1)
                ret.append(new_seq)
            ret = np.array(ret)
            # write relaxed version of sequence.npy
            np.save(os.path.join(font_dir, 'sequence_relaxed.npy'), ret.reshape(opts.char_num, -1))
            pts_aux = cal_aux_bezier_pts(ret, opts)
            np.save(os.path.join(font_dir, 'pts_aux.npy'), pts_aux)

    processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
def main():
    """Parse command-line options and run the relaxation pass."""
    parser = argparse.ArgumentParser(description="relax representation")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
    parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
    parser.add_argument("--split", type=str, default='train')
    # BUGFIX: relax_rep() and cal_aux_bezier_pts() read opts.char_num and
    # opts.max_seq_len; without these arguments the script crashed with
    # AttributeError. Defaults follow the common 52-glyph English setup —
    # adjust per dataset/language.
    parser.add_argument("--char_num", type=int, default=52, help="number of glyphs per font")
    parser.add_argument("--max_seq_len", type=int, default=71, help="maximum number of drawing commands per glyph")
    opts = parser.parse_args()
    relax_rep(opts)


if __name__ == "__main__":
    main()