import argparse | |
import multiprocessing as mp | |
import os | |
import pickle | |
import numpy as np | |
from data_utils import svg_utils | |
from tqdm import tqdm | |
def exist_empty_imgs(imgs_array, num_chars):
    """Return True if any of the first ``num_chars`` glyph images is all zeros.

    Args:
        imgs_array: indexable collection where ``imgs_array[i]`` is the image
            array of character ``i``.
        num_chars: number of leading characters to check.

    Returns:
        True as soon as one checked image has a maximum value of 0,
        otherwise False (also False when ``num_chars`` is 0).
    """
    # any() short-circuits on the first empty image, like the original loop.
    return any(np.max(imgs_array[char_id]) == 0 for char_id in range(num_chars))
def create_db(opts, output_path, log_path):
    """Convert per-glyph SFD files into one directory of ``.npy`` arrays per font.

    Fonts are read from ``<opts.sfd_path>/<opts.language>/<opts.split>/<font_id>/``.
    Each font directory must contain ``imgs_<img_size>.npy`` and, for every
    character of the charset, ``<font_id>_<char_id>.sfd`` plus a ``.txt``
    descriptor. The descriptor's first five lines are read in order as
    ``uni``, ``width``, ``vwidth``, char index and font index.

    A font is kept only if both conditions hold:
    - every character parses into a valid glyph and a valid path via
      ``svg_utils``;
    - its first two rendered images differ.

    Kept fonts are written to ``<output_path>/<zero-padded font index>/`` as
    ``sequence.npy``, ``seq_len.npy``, ``class.npy``, ``font_id.npy`` and
    ``rendered_<img_size>.npy``.

    Fonts are split evenly across ``cpu_count() - 1`` worker processes (at
    least one). Worker ``k`` logs rejected glyphs to
    ``<log_path>/log_<split>_<k>.txt``.

    Args:
        opts: parsed CLI options; reads ``data_path``, ``language``,
            ``sfd_path``, ``split`` and ``img_size``.
        output_path: directory receiving the per-font sub-directories.
        log_path: directory (must already exist) receiving the per-worker logs.
    """
    with open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r') as charset_f:
        charset = charset_f.read()
    print("Process sfd to npy files in dirs....")
    sdf_path = os.path.join(opts.sfd_path, opts.language, opts.split)
    all_font_ids = sorted(os.listdir(sdf_path))
    num_fonts = len(all_font_ids)
    num_fonts_w = len(str(num_fonts))
    print(f"Number {opts.split} fonts before processing", num_fonts)
    # Leave one core free, but never fall to zero workers on a single-core
    # machine (``num_fonts // 0`` would raise ZeroDivisionError).
    num_processes = max(1, mp.cpu_count() - 1)
    fonts_per_process = num_fonts // num_processes + 1
    num_chars = len(charset)
    num_chars_w = len(str(num_chars))
    imgs_name = 'imgs_' + str(opts.img_size) + '.npy'

    def glyph_file(font_dir, font_id, char_id, ext):
        # Per-glyph files are named "<font_id>_<char_id zero-padded>.<ext>".
        return os.path.join(font_dir, f'{font_id}_{char_id:0{num_chars_w}}.{ext}')

    def process(process_id):
        # Worker: handles the contiguous slice of fonts assigned to process_id.
        valid_chars = []
        invalid_path = []
        invalid_glypts = []
        log_file = os.path.join(log_path, f'log_{opts.split}_{process_id}.txt')
        first = process_id * fonts_per_process
        last = min(first + fonts_per_process, num_fonts)
        with open(log_file, 'w') as cur_process_log_file:
            for i in tqdm(range(first, last)):
                font_id = all_font_ids[i]
                cur_font_sfd_dir = os.path.join(sdf_path, font_id)
                imgs_file = os.path.join(cur_font_sfd_dir, imgs_name)
                if not os.path.exists(imgs_file):
                    continue

                # A whole font is one entry: stop at the first missing or
                # invalid glyph, which disqualifies the font below.
                cur_font_glyphs = []
                for char_id in range(num_chars):
                    sfd_file = glyph_file(cur_font_sfd_dir, font_id, char_id, 'sfd')
                    if not os.path.exists(sfd_file):
                        break
                    txt_file = glyph_file(cur_font_sfd_dir, font_id, char_id, 'txt')
                    with open(txt_file, 'r') as char_desp_f:
                        char_desp = char_desp_f.readlines()
                    with open(sfd_file, 'r') as sfd_f:
                        sfd = sfd_f.read()

                    char_idx = char_desp[3].strip()
                    font_idx = char_desp[4].strip()
                    cur_glyph = {
                        'uni': int(char_desp[0].strip()),
                        'width': int(char_desp[1].strip()),
                        'vwidth': int(char_desp[2].strip()),
                        'sfd': sfd,
                        'id': char_idx,
                        'binary_fp': font_idx,
                    }
                    char_record = [font_idx, int(char_idx), charset[int(char_idx)]]

                    if not svg_utils.is_valid_glyph(cur_glyph):
                        invalid_glypts.append(char_record)
                        cur_process_log_file.write(f"font {font_idx}, char {char_idx} is not a valid glyph\n")
                        break
                    pathunibfp = svg_utils.convert_to_path(cur_glyph)
                    if not svg_utils.is_valid_path(pathunibfp):
                        invalid_path.append(char_record)
                        cur_process_log_file.write(f"font {font_idx}, char {char_idx}'s sfd is not a valid path\n")
                        break
                    valid_chars.append(char_record)
                    cur_font_glyphs.append(svg_utils.create_example(pathunibfp))

                # Only fonts whose glyphs are all valid are written out.
                if len(cur_font_glyphs) != num_chars:
                    continue
                rendered = np.load(imgs_file)
                # Skip fonts whose first two rendered glyph images are identical
                # (presumably degenerate/placeholder renders — TODO confirm).
                if (rendered[0] == rendered[1]).all():
                    continue

                sequence = [g['sequence'] for g in cur_font_glyphs]
                seq_len = [g['seq_len'] for g in cur_font_glyphs]
                char_class = [g['class'] for g in cur_font_glyphs]
                # Every glyph carries its font's id; the last glyph's value is stored.
                binaryfp = cur_font_glyphs[-1]['binary_fp'] if cur_font_glyphs else []

                font_out_dir = os.path.join(output_path, f'{i:0{num_fonts_w}}')
                os.makedirs(font_out_dir, exist_ok=True)
                np.save(os.path.join(font_out_dir, 'sequence.npy'), np.array(sequence))
                np.save(os.path.join(font_out_dir, 'seq_len.npy'), np.array(seq_len))
                np.save(os.path.join(font_out_dir, 'class.npy'), np.array(char_class))
                np.save(os.path.join(font_out_dir, 'font_id.npy'), np.array(binaryfp))
                np.save(os.path.join(font_out_dir, 'rendered_' + str(opts.img_size) + '.npy'), rendered)

        print("valid_chars", len(valid_chars))
        print("invalid_path:", invalid_path)
        print("invalid_glypts:", invalid_glypts)

    # ``process`` is a closure over this function's locals. Only the 'fork'
    # start method can hand it to a child without pickling it (spawn and
    # forkserver cannot pickle local functions), so request fork explicitly
    # rather than relying on the platform default.
    ctx = mp.get_context('fork')
    processes = [ctx.Process(target=process, args=[pid]) for pid in range(num_processes)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    failed = [pid for pid, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        print("WARNING: these workers exited abnormally, some fonts may be missing:", failed)
    print("Finished processing all sfd files, logs (invalid glyphs and paths) are saved to", log_path)
def cal_mean_stddev(opts, output_path):
    """Compute mean/stddev over all glyph sequences and save them for training.

    Every immediate sub-directory of ``output_path`` is treated as one font
    produced by ``create_db``. Its ``seq_len.npy`` and ``sequence.npy`` are
    fed, one character at a time, into a ``svg_utils.MeanStddev``
    accumulator. Fonts are split across ``cpu_count() - 1`` worker processes
    (at least one), and their partial accumulators are merged.

    The first 4 entries of the result are overwritten with mean 0 / stddev 1,
    so those dimensions are left unnormalized. The arrays are saved as
    ``mean.npz`` and ``stdev.npz`` under ``<opts.output_path>/<opts.language>``.
    These files contain ``np.save`` (``.npy``) data; the ``.npz`` name is a
    legacy convention.

    Args:
        opts: parsed CLI options; reads ``data_path``, ``language`` and
            ``output_path``.
        output_path: directory holding one sub-directory per font.

    Raises:
        RuntimeError: if any worker exits abnormally. Its partial statistics
            would be missing, making the result silently wrong.
    """
    print("Calculating all glyphs' mean stddev ....")
    with open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r') as charset_f:
        charset = charset_f.read()
    # Only the top level: each sub-directory is one font written by create_db.
    font_paths = sorted(entry.path for entry in os.scandir(output_path) if entry.is_dir())
    num_fonts = len(font_paths)
    # Never zero workers (``num_fonts // 0``) on a single-core machine.
    num_processes = max(1, mp.cpu_count() - 1)
    fonts_per_process = num_fonts // num_processes + 1
    num_chars = len(charset)
    main_stddev_accum = svg_utils.MeanStddev()

    def process(process_id, return_dict):
        # Worker: accumulate statistics over its contiguous slice of fonts.
        mean_stddev_accum = svg_utils.MeanStddev()
        cur_sum_count = mean_stddev_accum.create_accumulator()
        first = process_id * fonts_per_process
        for cur_font_path in font_paths[first:first + fonts_per_process]:
            # Load each font's arrays once rather than once per character.
            seq_lens = np.load(os.path.join(cur_font_path, 'seq_len.npy')).tolist()
            sequences = np.load(os.path.join(cur_font_path, 'sequence.npy')).tolist()
            for charid in range(num_chars):
                cur_font_char = {'seq_len': seq_lens[charid], 'sequence': sequences[charid]}
                cur_sum_count = mean_stddev_accum.add_input(cur_sum_count, cur_font_char)
        return_dict[process_id] = cur_sum_count

    # ``process`` is a closure; only the 'fork' start method can run it in a
    # child without pickling it, so request fork explicitly.
    ctx = mp.get_context('fork')
    with ctx.Manager() as manager:
        return_dict = manager.dict()
        processes = [ctx.Process(target=process, args=[pid, return_dict]) for pid in range(num_processes)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()
        failed = [pid for pid, p in enumerate(processes) if p.exitcode != 0]
        if failed:
            raise RuntimeError(f"mean/stddev workers {failed} failed; statistics would be incomplete")
        # Copy out of the proxy before the manager shuts down.
        partial_sum_counts = list(return_dict.values())

    merged_sum_count = main_stddev_accum.merge_accumulators(partial_sum_counts)
    output = main_stddev_accum.extract_output(merged_sum_count)
    print('output :', output)
    mean = output['mean']
    stdev = output['stddev']
    print('mean :', mean)
    # Leave the first 4 dimensions unnormalized (mean 0, stddev 1).
    mean = np.concatenate((np.zeros([4]), mean[4:]), axis=0)
    stdev = np.concatenate((np.ones([4]), stdev[4:]), axis=0)

    # finally, save the mean and stddev files
    output_path_ = os.path.join(opts.output_path, opts.language)
    np.save(os.path.join(output_path_, 'mean'), mean)
    np.save(os.path.join(output_path_, 'stdev'), stdev)
    # np.save appends ".npy"; rename to the legacy ".npz" names. os.replace
    # overwrites existing files on every platform, so reruns do not fail.
    os.replace(os.path.join(output_path_, 'mean.npy'), os.path.join(output_path_, 'mean.npz'))
    os.replace(os.path.join(output_path_, 'stdev.npy'), os.path.join(output_path_, 'stdev.npz'))
def main():
    """Parse CLI options and run the requested dataset-building phase(s).

    ``--phase`` values:
    - 0 runs everything;
    - 1 only runs ``create_db`` (writes per-font npy directories);
    - 2 only runs ``cal_mean_stddev`` (normalization statistics).

    Statistics are computed only for the ``train`` split. Exits with a usage
    error if ``--sfd_path`` does not exist.
    """
    parser = argparse.ArgumentParser(description="LMDB creation")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
    parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
    parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
    parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
    parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images")
    parser.add_argument("--split", type=str, default='train')
    parser.add_argument("--phase", type=int, default=0, choices=[0, 1, 2],
                        help="0 all, 1 create db, 2 cal stddev")
    opts = parser.parse_args()

    # parser.error (not assert) so the check survives ``python -O``.
    if not os.path.exists(opts.sfd_path):
        parser.error("specified sfd glyphs path does not exist")

    output_path = os.path.join(opts.output_path, opts.language, opts.split)
    log_path = os.path.join(opts.sfd_path, opts.language, 'log')
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)

    # Select phases explicitly; ``phase <= 2`` was always true, so phase 1
    # wrongly ran the statistics step as well.
    if opts.phase in (0, 1):
        create_db(opts, output_path, log_path)
    if opts.phase in (0, 2) and opts.split == 'train':
        cal_mean_stddev(opts, output_path)
if __name__ == "__main__":
    # Entry point: build the dataset only when executed as a script, not on import.
    main()