Spaces:
Sleeping
Sleeping
import argparse | |
import multiprocessing as mp | |
import os | |
import numpy as np | |
import math | |
import cairosvg | |
import shutil | |
from data_utils.svg_utils import clockwise, render | |
from common_utils import affine_shear, affine_rotate, affine_scale, trans2_white_bg | |
def render_svg(svg_str, font_dir, char_idx, aug_idx, img_size): | |
svg_html = render(svg_str) | |
svg_path = open(f'{font_dir}/aug_svgs/{str(char_idx)}.svg', 'w') | |
svg_path.write(svg_html) | |
svg_path.close() | |
cairosvg.svg2png(url=f'{font_dir}/aug_svgs/{str(char_idx)}.svg', | |
write_to=f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png', output_width=img_size, output_height=img_size) | |
img_arr = trans2_white_bg(f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png') | |
return img_arr | |
def aug_rules(char_seq, aug_idx): | |
if aug_idx == 0: | |
return clockwise(affine_shear(char_seq, dx=0.2))['sequence'] | |
elif aug_idx == 1: | |
return clockwise(affine_shear(char_seq, dy=-0.1))['sequence'] | |
elif aug_idx == 2: | |
return clockwise(affine_scale(char_seq, 0.8))['sequence'] | |
elif aug_idx == 3: | |
return clockwise(affine_rotate(char_seq, theta=5))['sequence'] | |
else: | |
return clockwise(affine_rotate(char_seq, theta=-5))['sequence'] | |
def copy_others(dir_src, dir_tgt): | |
for item in ['class.npy', 'font_id.npy', 'seq_len.npy']: | |
shutil.copy(f'{dir_src}/{item}', f'{dir_tgt}/{item}') | |
def apply_aug(opts): | |
""" | |
applying data augmentation for Chinese fonts | |
""" | |
data_path = os.path.join(opts.output_path, opts.language, opts.split) | |
font_dirs_ = os.listdir(data_path) | |
font_dirs = [] | |
for idx in range(len(font_dirs_)): | |
if '_' not in font_dirs_[idx].split('/')[-1]: | |
font_dirs.append(font_dirs_[idx]) | |
font_dirs.sort() | |
num_fonts = len(font_dirs) | |
print(f"Number {opts.split} fonts before processing", num_fonts) | |
num_processes = mp.cpu_count() - 2 | |
fonts_per_process = num_fonts // num_processes + 1 | |
def process(process_id): | |
for i in range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process): | |
if i >= num_fonts: | |
break | |
font_dir = os.path.join(data_path, font_dirs[i]) | |
font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.n_chars, opts.max_len, -1) | |
ret_seq_list = [] | |
ret_img_list = [] | |
for k in range(opts.n_aug): | |
os.makedirs(font_dir + '_' + str(k), exist_ok=True) | |
ret_seq_list.append([]) | |
ret_img_list.append([]) | |
os.makedirs(f'{font_dir}/aug_svgs', exist_ok=True) | |
os.makedirs(f'{font_dir}/aug_imgs', exist_ok=True) | |
for j in range(opts.n_chars): | |
char_seq = font_seq[j] # default as [71, 12] | |
for k in range(opts.n_aug): | |
char_seq_aug = aug_rules(char_seq, k) | |
ret_seq_list[k].append(char_seq_aug) | |
img_arr = render_svg(char_seq_aug, font_dir, j, aug_idx=k, img_size=opts.img_size) | |
ret_img_list[k].append(img_arr) | |
for k in range(opts.n_aug): | |
ret_seq_list[k] = np.array(ret_seq_list[k]).reshape(opts.n_chars, opts.max_len * 10) | |
ret_img_list[k] = np.array(ret_img_list[k]).reshape(opts.n_chars, opts.img_size, opts.img_size) | |
np.save(os.path.join(font_dir + '_' + str(k), f'sequence.npy'), ret_seq_list[k]) | |
np.save(os.path.join(font_dir + '_' + str(k), f'rendered_{opts.img_size}.npy'), ret_img_list[k]) | |
copy_others(font_dir, font_dir + '_' + str(k)) | |
processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)] | |
for p in processes: | |
p.start() | |
for p in processes: | |
p.join() | |
def main(): | |
parser = argparse.ArgumentParser(description="relax representation") | |
parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha']) | |
parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to") | |
parser.add_argument('--max_len', type=int, default=71, help="by default, 51 for english and 71 for chinese") | |
parser.add_argument('--n_aug', type=int, default=5, help="for each font, augment it for n_aug times") | |
parser.add_argument('--n_chars', type=int, default=52) | |
parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images") | |
parser.add_argument("--split", type=str, default='train') | |
parser.add_argument('--debug', type=bool, default=True) | |
opts = parser.parse_args() | |
apply_aug(opts) | |
if __name__ == "__main__": | |
main() | |