import argparse
import math
import multiprocessing as mp
import os
import shutil

import cairosvg
import numpy as np

from common_utils import affine_shear, affine_rotate, affine_scale, trans2_white_bg
from data_utils.svg_utils import clockwise, render


def render_svg(svg_str, font_dir, char_idx, aug_idx, img_size):
    """Render one augmented glyph sequence to a PNG and load it back.

    The sequence is converted to SVG markup with ``render``, written to
    ``{font_dir}/aug_svgs/{char_idx}.svg``, rasterized by cairosvg to
    ``{font_dir}/aug_imgs/{char_idx}_{aug_idx}.png`` at ``img_size`` x
    ``img_size`` and finally passed through ``trans2_white_bg``.

    Args:
        svg_str: Glyph drawing sequence accepted by ``render``.
        font_dir: Font directory; ``aug_svgs`` and ``aug_imgs`` must already
            exist inside it.
        char_idx: Index of the character inside the font.
        aug_idx: Index of the augmentation that produced ``svg_str``.
        img_size: Output width and height in pixels.

    Returns:
        Whatever ``trans2_white_bg`` returns for the written PNG (presumably
        an image array on a white background -- confirm against common_utils).
    """
    svg_file = f'{font_dir}/aug_svgs/{str(char_idx)}.svg'
    png_file = f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png'

    # The SVG file is reused for every augmentation of the same character;
    # only the PNG name carries the augmentation index.
    with open(svg_file, 'w') as svg_handle:
        svg_handle.write(render(svg_str))

    cairosvg.svg2png(url=svg_file, write_to=png_file,
                     output_width=img_size, output_height=img_size)
    return trans2_white_bg(png_file)


def aug_rules(char_seq, aug_idx):
    """Apply the augmentation selected by ``aug_idx`` to a glyph sequence.

    Index 0 shears with ``dx=0.2``, 1 shears with ``dy=-0.1``, 2 scales by
    0.8, 3 rotates by ``theta=5`` and every other index rotates by
    ``theta=-5``.

    Returns:
        The ``'sequence'`` entry of ``clockwise`` applied to the transformed
        sequence.
    """
    if aug_idx == 0:
        transformed = affine_shear(char_seq, dx=0.2)
    elif aug_idx == 1:
        transformed = affine_shear(char_seq, dy=-0.1)
    elif aug_idx == 2:
        transformed = affine_scale(char_seq, 0.8)
    elif aug_idx == 3:
        transformed = affine_rotate(char_seq, theta=5)
    else:
        # Any index >= 4 (or negative) falls back to the negative rotation.
        transformed = affine_rotate(char_seq, theta=-5)
    return clockwise(transformed)['sequence']


def copy_others(dir_src, dir_tgt):
    """Copy the per-font metadata arrays that augmentation leaves untouched."""
    for item in ['class.npy', 'font_id.npy', 'seq_len.npy']:
        shutil.copy(f'{dir_src}/{item}', f'{dir_tgt}/{item}')


def _augment_font(font_dir, opts):
    """Create the ``opts.n_aug`` augmented copies of a single font directory."""
    font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(
        opts.n_chars, opts.max_len, -1)

    # One sibling output directory per augmentation: <font_dir>_<k>.
    aug_dirs = [font_dir + '_' + str(k) for k in range(opts.n_aug)]
    for aug_dir in aug_dirs:
        os.makedirs(aug_dir, exist_ok=True)
    ret_seq_list = [[] for _ in range(opts.n_aug)]
    ret_img_list = [[] for _ in range(opts.n_aug)]

    # Scratch directories used by render_svg for intermediate SVG/PNG files.
    os.makedirs(f'{font_dir}/aug_svgs', exist_ok=True)
    os.makedirs(f'{font_dir}/aug_imgs', exist_ok=True)

    for j in range(opts.n_chars):
        char_seq = font_seq[j]  # default as [71, 12]
        for k in range(opts.n_aug):
            char_seq_aug = aug_rules(char_seq, k)
            ret_seq_list[k].append(char_seq_aug)
            img_arr = render_svg(char_seq_aug, font_dir, j, aug_idx=k,
                                 img_size=opts.img_size)
            ret_img_list[k].append(img_arr)

    for k, aug_dir in enumerate(aug_dirs):
        # NOTE(review): the factor 10 assumes clockwise() yields 10 values per
        # sequence step -- confirm against data_utils.svg_utils.
        seq_arr = np.array(ret_seq_list[k]).reshape(opts.n_chars, opts.max_len * 10)
        img_arr = np.array(ret_img_list[k]).reshape(
            opts.n_chars, opts.img_size, opts.img_size)
        np.save(os.path.join(aug_dir, 'sequence.npy'), seq_arr)
        np.save(os.path.join(aug_dir, f'rendered_{opts.img_size}.npy'), img_arr)
        copy_others(font_dir, aug_dir)


def _process_font_range(process_id, opts, data_path, font_dirs, fonts_per_process):
    """Worker entry point: augment the contiguous font slice owned by ``process_id``."""
    start = process_id * fonts_per_process
    stop = min(start + fonts_per_process, len(font_dirs))
    for i in range(start, stop):
        _augment_font(os.path.join(data_path, font_dirs[i]), opts)


def apply_aug(opts):
    """Augment every font found under ``output_path/language/split``.

    Directory entries whose name contains ``'_'`` are skipped, which also
    excludes the ``<font>_<k>`` directories written by an earlier run. The
    remaining fonts are sorted and split into contiguous chunks that are
    processed by separate worker processes.

    Args:
        opts: Parsed command-line options; reads ``output_path``,
            ``language``, ``split``, ``n_chars``, ``max_len``, ``n_aug`` and
            ``img_size``.
    """
    data_path = os.path.join(opts.output_path, opts.language, opts.split)
    font_dirs = sorted(
        name for name in os.listdir(data_path)
        if '_' not in name.split('/')[-1]
    )
    num_fonts = len(font_dirs)
    print(f"Number {opts.split} fonts before processing", num_fonts)

    # Leave two cores free, but never drop below one worker: on machines with
    # two or fewer CPUs the unclamped value would be <= 0 and the division
    # below would fail (or no worker would be started at all).
    num_processes = max(1, mp.cpu_count() - 2)
    fonts_per_process = num_fonts // num_processes + 1

    # The worker is a module-level function with explicit arguments so that it
    # can be pickled under the "spawn" start method as well as "fork".
    processes = [
        mp.Process(
            target=_process_font_range,
            args=(pid, opts, data_path, font_dirs, fonts_per_process),
        )
        for pid in range(num_processes)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()


def _str2bool(value):
    """Parse a command-line boolean; plain ``type=bool`` treats "False" as True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main():
    """Parse the command-line options and run the augmentation."""
    parser = argparse.ArgumentParser(description="relax representation")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/',
                        help="Path to write the database to")
    parser.add_argument('--max_len', type=int, default=71,
                        help="by default, 51 for english and 71 for chinese")
    parser.add_argument('--n_aug', type=int, default=5,
                        help="for each font, augment it for n_aug times")
    parser.add_argument('--n_chars', type=int, default=52)
    parser.add_argument('--img_size', type=int, default=64,
                        help="the height and width of glyph images")
    parser.add_argument("--split", type=str, default='train')
    parser.add_argument('--debug', type=_str2bool, default=True)
    opts = parser.parse_args()
    apply_aug(opts)


if __name__ == "__main__":
    main()