Spaces:
Sleeping
Sleeping
add files
Browse files- .gitattributes +1 -0
- .gitignore +92 -0
- LICENSE +21 -0
- README.md +5 -8
- __init__.py +0 -0
- app.py +113 -0
- data_utils/__init__.py +0 -0
- data_utils/augment.py +112 -0
- data_utils/common_utils.py +75 -0
- data_utils/convert_ttf_to_sfd.py +103 -0
- data_utils/relax_rep.py +136 -0
- data_utils/svg_utils.py +1083 -0
- data_utils/svg_utils_backup.py +1174 -0
- data_utils/write_data_to_dirs.py +231 -0
- data_utils/write_glyph_imgs.py +180 -0
- dataloader.py +67 -0
- font_sample/Athiti-Regular.ttf +0 -0
- font_sample/SaoChingcha-Bold.otf +0 -0
- font_sample/SaoChingcha-Light.otf +0 -0
- font_sample/SaoChingcha-Regular.otf +0 -0
- generate.py +143 -0
- models/__init__.py +0 -0
- models/image_decoder.py +48 -0
- models/image_encoder.py +42 -0
- models/modality_fusion.py +64 -0
- models/model_main.py +212 -0
- models/pos_enc.py +21 -0
- models/transformers.py +711 -0
- models/util_funcs.py +96 -0
- models/vgg_perceptual_loss.py +69 -0
- options.py +67 -0
- packages.txt +1 -0
- requirements.txt +121 -0
- test.py +60 -0
- test_few_shot.py +164 -0
- train.py +216 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data/
|
2 |
+
experiments/
|
3 |
+
inference_few_shot.py
|
4 |
+
flagged
|
5 |
+
inference
|
6 |
+
inference_model
|
7 |
+
Font_dataset
|
8 |
+
venv
|
9 |
+
|
10 |
+
# Byte-compiled / optimized / DLL files
|
11 |
+
__pycache__/
|
12 |
+
*.py[cod]
|
13 |
+
*$py.class
|
14 |
+
|
15 |
+
# C extensions
|
16 |
+
*.so
|
17 |
+
|
18 |
+
# Distribution / packaging
|
19 |
+
.Python
|
20 |
+
build/
|
21 |
+
develop-eggs/
|
22 |
+
dist/
|
23 |
+
downloads/
|
24 |
+
eggs/
|
25 |
+
.eggs/
|
26 |
+
lib/
|
27 |
+
lib64/
|
28 |
+
parts/
|
29 |
+
sdist/
|
30 |
+
var/
|
31 |
+
wheels/
|
32 |
+
share/python-wheels/
|
33 |
+
*.egg-info/
|
34 |
+
.installed.cfg
|
35 |
+
*.egg
|
36 |
+
MANIFEST
|
37 |
+
|
38 |
+
# PyInstaller
|
39 |
+
# Usually these files are written by a python script from a template
|
40 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
41 |
+
*.manifest
|
42 |
+
*.spec
|
43 |
+
|
44 |
+
# Installer logs
|
45 |
+
pip-log.txt
|
46 |
+
pip-delete-this-directory.txt
|
47 |
+
|
48 |
+
# Unit test / coverage reports
|
49 |
+
htmlcov/
|
50 |
+
.tox/
|
51 |
+
.nox/
|
52 |
+
.coverage
|
53 |
+
.coverage.*
|
54 |
+
.cache
|
55 |
+
nosetests.xml
|
56 |
+
coverage.xml
|
57 |
+
*.cover
|
58 |
+
*.py,cover
|
59 |
+
.hypothesis/
|
60 |
+
.pytest_cache/
|
61 |
+
cover/
|
62 |
+
|
63 |
+
# Translations
|
64 |
+
*.mo
|
65 |
+
*.pot
|
66 |
+
|
67 |
+
# Django stuff:
|
68 |
+
*.log
|
69 |
+
local_settings.py
|
70 |
+
db.sqlite3
|
71 |
+
db.sqlite3-journal
|
72 |
+
|
73 |
+
# Flask stuff:
|
74 |
+
instance/
|
75 |
+
.webassets-cache
|
76 |
+
|
77 |
+
# Scrapy stuff:
|
78 |
+
.scrapy
|
79 |
+
|
80 |
+
# Sphinx documentation
|
81 |
+
docs/_build/
|
82 |
+
|
83 |
+
# PyBuilder
|
84 |
+
.pybuilder/
|
85 |
+
target/
|
86 |
+
|
87 |
+
# Jupyter Notebook
|
88 |
+
.ipynb_checkpoints
|
89 |
+
|
90 |
+
# IPython
|
91 |
+
profile_default/
|
92 |
+
ipython_config.py
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Yizhi Wang
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,10 @@
|
|
1 |
---
|
2 |
title: ThaiVecFont
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: ThaiVecFont
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: gray
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.25.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
---
|
|
|
|
|
|
__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
import streamlit as st
|
3 |
+
from generate import ttf_to_image
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
|
7 |
+
LOADED_TTF_KEY = "loaded_ttf"
|
8 |
+
SET_IMG_KEY = "set_img"
|
9 |
+
OUTPUT_IMG_KEY = "output_img"
|
10 |
+
|
11 |
+
def get_ttf(key: str) -> Optional[object]:
    """Return the value stored in Streamlit's session state under ``key``.

    Args:
        key: Session-state key to look up.

    Returns:
        The stored value, or ``None`` when ``key`` is not in the session state.
    """
    # The previous annotation, ``Optional[any]``, referred to the builtin
    # function ``any`` rather than a type. ``object`` needs no extra import.
    if key in st.session_state:
        return st.session_state[key]
    return None
|
15 |
+
|
16 |
+
def get_img(key: str) -> Optional[Image.Image]:
    """Fetch the image cached in the Streamlit session under ``key``.

    Returns ``None`` when nothing has been stored under that key.
    """
    state = st.session_state
    return state[key] if key in state else None
|
20 |
+
|
21 |
+
def set_img(key: str, img: Image.Image) -> None:
    """Cache ``img`` in the Streamlit session state under ``key``."""
    st.session_state[key] = img
|
23 |
+
|
24 |
+
def ttf_uploader(prefix):
    """Show a TTF/OTF upload widget and return the selected font.

    Args:
        prefix: String used to build the widget key (``"<prefix>-uploader"``).

    Returns:
        The uploaded file when one is present; otherwise whatever
        ``get_ttf(LOADED_TTF_KEY)`` yields (``None`` if nothing is cached).
    """
    uploaded = st.file_uploader("TTF, OTF", ["ttf", "otf"], key=f"{prefix}-uploader")
    # A falsy upload (nothing chosen yet) falls through to the cached value.
    return uploaded or get_ttf(LOADED_TTF_KEY)
|
30 |
+
|
31 |
+
def generate_button(prefix, file_input, version, **kwargs):
    """Render the generation controls and run inference when the button is pressed.

    Args:
        prefix: String prepended to every widget key created here.
        file_input: Font input forwarded unchanged to ``ttf_to_image``.
        version: Model version forwarded unchanged to ``ttf_to_image``.
        **kwargs: Accepted for call compatibility; not used in this function.
    """
    samples_col, refs_col = st.columns(2)
    with samples_col:
        n_samples = st.slider(
            "Number of inference sample",
            min_value=1,
            max_value=200,
            value=20,
            key=f"{prefix}-inference-sample",
        )
    with refs_col:
        # Forwarded to ttf_to_image as the raw text typed by the user.
        ref_char_ids = st.text_area(
            "ref_char_ids",
            value="1,2,3,4,5,6,7,8",
            key=f"{prefix}-ref_char_ids",
        )

    # These two checkboxes are rendered, but their values are not read here.
    st.checkbox(
        "Enable attention slicing (enables higher resolutions but is slower)",
        key=f"{prefix}-attention-slicing",
    )
    st.checkbox(
        "Enable CPU offload (if you run out of memory, e.g. for XL model)",
        key=f"{prefix}-cpu-offload",
        value=False,
    )

    pressed = st.button("Generate image", key=f"{prefix}-btn")
    if pressed:
        with st.spinner("⏳ Generating image..."):
            image = ttf_to_image(file_input, n_samples, ref_char_ids, version)
        # Keep a copy in the session so the sidebar can show the latest output.
        set_img(OUTPUT_IMG_KEY, image.copy())
        st.image(image)

    # Free-text box; its value is not read anywhere in this function.
    st.text_area(
        "test font",
        value="กขคง",
        key=f"{prefix}-prompt",
    )
|
69 |
+
|
70 |
+
def generate_tab():
    """Lay out the font picker (left column) and the generation controls (right)."""
    prefix = "ttf2img"
    picker_col, controls_col = st.columns(2)

    with picker_col:
        choices = ["Custom"] + os.listdir("font_sample/")
        sample_choose = st.selectbox(
            "Choose Sample", choices, key=f"{prefix}-sample_choose"
        )
        if sample_choose != "Custom":
            # A bundled sample font: pass its path on disk.
            st.write("filename:", sample_choose)
            uploaded_file = os.path.join("font_sample", sample_choose)
        else:
            uploaded_file = ttf_uploader(prefix)
            if uploaded_file:
                st.write("filename:", uploaded_file.name)
                uploaded_file = uploaded_file.getbuffer()  # Send file as Buffer

    with controls_col:
        # The controls appear only once a font is available.
        if uploaded_file:
            version = st.selectbox(
                "Model version", ["TH2TH", "ENG2TH"], key=f"{prefix}-version"
            )
            generate_button(prefix, file_input=uploaded_file, version=version)
|
96 |
+
|
97 |
+
def main():
    """Build the Streamlit page: the generation tab plus a sidebar preview."""
    st.set_page_config(layout="wide")
    st.title("ThaiVecFont Playground")

    generate_tab()

    with st.sidebar:
        st.header("Latest Output Image")
        output_image = get_img(OUTPUT_IMG_KEY)
        if not output_image:
            st.markdown("No output generated yet")
        else:
            st.image(output_image)


if __name__ == "__main__":
    main()
|
data_utils/__init__.py
ADDED
File without changes
|
data_utils/augment.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import multiprocessing as mp
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import math
|
6 |
+
import cairosvg
|
7 |
+
import shutil
|
8 |
+
from data_utils.svg_utils import clockwise, render
|
9 |
+
from common_utils import affine_shear, affine_rotate, affine_scale, trans2_white_bg
|
10 |
+
|
11 |
+
def render_svg(svg_str, font_dir, char_idx, aug_idx, img_size):
    """Write one augmented glyph as SVG, rasterise it to PNG and return the pixels.

    Args:
        svg_str: Glyph sequence handed to ``render`` to produce the SVG text.
        font_dir: Font directory that contains ``aug_svgs/`` and ``aug_imgs/``.
        char_idx: Glyph index, used in both output file names.
        aug_idx: Augmentation index, used in the PNG file name.
        img_size: Width and height of the rasterised PNG in pixels.

    Returns:
        The array returned by ``trans2_white_bg`` for the written PNG.
    """
    svg_file = f'{font_dir}/aug_svgs/{str(char_idx)}.svg'
    png_file = f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png'

    svg_html = render(svg_str)
    # The context manager closes the handle even if write() raises; the
    # previous explicit open()/close() pair leaked it in that case.
    with open(svg_file, 'w') as svg_out:
        svg_out.write(svg_html)

    cairosvg.svg2png(url=svg_file, write_to=png_file,
                     output_width=img_size, output_height=img_size)
    return trans2_white_bg(png_file)
|
20 |
+
|
21 |
+
def aug_rules(char_seq, aug_idx):
    """Apply the augmentation selected by ``aug_idx`` to one glyph sequence.

    Index 0 shears with ``dx=0.2``, 1 shears with ``dy=-0.1``, 2 scales by
    0.8, 3 rotates by 5 and every other index rotates by -5. The transformed
    sequence is passed through ``clockwise`` and its ``'sequence'`` entry is
    returned.
    """
    if aug_idx == 0:
        transformed = affine_shear(char_seq, dx=0.2)
    elif aug_idx == 1:
        transformed = affine_shear(char_seq, dy=-0.1)
    elif aug_idx == 2:
        transformed = affine_scale(char_seq, 0.8)
    elif aug_idx == 3:
        transformed = affine_rotate(char_seq, theta=5)
    else:
        transformed = affine_rotate(char_seq, theta=-5)
    return clockwise(transformed)['sequence']
|
32 |
+
|
33 |
+
def copy_others(dir_src, dir_tgt):
    """Copy class.npy, font_id.npy and seq_len.npy from ``dir_src`` to ``dir_tgt``."""
    file_names = ('class.npy', 'font_id.npy', 'seq_len.npy')
    for file_name in file_names:
        source = f'{dir_src}/{file_name}'
        target = f'{dir_tgt}/{file_name}'
        shutil.copy(source, target)
|
36 |
+
|
37 |
+
def apply_aug(opts):
    """
    Apply data augmentation to every font directory of one dataset split.

    Font directories are read from ``<output_path>/<language>/<split>``.
    Directories whose name contains ``'_'`` are skipped. For each remaining
    font and each augmentation index ``k`` in ``range(opts.n_aug)`` the
    augmented sequences and rendered images are written to ``<font_dir>_<k>``,
    together with copies of the font's class/font_id/seq_len arrays.

    Args:
        opts: Parsed options; reads ``output_path``, ``language``, ``split``,
            ``n_chars``, ``max_len``, ``n_aug`` and ``img_size``.
    """
    data_path = os.path.join(opts.output_path, opts.language, opts.split)
    font_dirs = sorted(
        name for name in os.listdir(data_path) if '_' not in name.split('/')[-1]
    )
    num_fonts = len(font_dirs)
    print(f"Number {opts.split} fonts before processing", num_fonts)

    # cpu_count() - 2 is zero or negative on one- and two-core machines, which
    # made the division below fail (or produce a negative chunk size).
    num_processes = max(1, mp.cpu_count() - 2)
    fonts_per_process = num_fonts // num_processes + 1

    # NOTE(review): a nested function cannot be pickled, so it only works as a
    # Process target with the "fork" start method — confirm target platforms.
    def process(process_id):
        first = process_id * fonts_per_process
        last = min(first + fonts_per_process, num_fonts)
        for i in range(first, last):
            font_dir = os.path.join(data_path, font_dirs[i])
            font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.n_chars, opts.max_len, -1)

            ret_seq_list = []
            ret_img_list = []
            for k in range(opts.n_aug):
                os.makedirs(font_dir + '_' + str(k), exist_ok=True)
                ret_seq_list.append([])
                ret_img_list.append([])

            os.makedirs(f'{font_dir}/aug_svgs', exist_ok=True)
            os.makedirs(f'{font_dir}/aug_imgs', exist_ok=True)

            for j in range(opts.n_chars):
                # NOTE(review): the original comment gave the shape as [71, 12],
                # but the reshape below assumes 10 values per step — confirm.
                char_seq = font_seq[j]
                for k in range(opts.n_aug):
                    char_seq_aug = aug_rules(char_seq, k)
                    ret_seq_list[k].append(char_seq_aug)
                    img_arr = render_svg(char_seq_aug, font_dir, j, aug_idx=k, img_size=opts.img_size)
                    ret_img_list[k].append(img_arr)

            for k in range(opts.n_aug):
                aug_dir = font_dir + '_' + str(k)
                ret_seq_list[k] = np.array(ret_seq_list[k]).reshape(opts.n_chars, opts.max_len * 10)
                ret_img_list[k] = np.array(ret_img_list[k]).reshape(opts.n_chars, opts.img_size, opts.img_size)
                np.save(os.path.join(aug_dir, 'sequence.npy'), ret_seq_list[k])
                np.save(os.path.join(aug_dir, f'rendered_{opts.img_size}.npy'), ret_img_list[k])
                copy_others(font_dir, aug_dir)

    processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]

    for p in processes:
        p.start()
    for p in processes:
        p.join()
|
91 |
+
|
92 |
+
|
93 |
+
def main():
    """Parse the command-line options and run the augmentation."""

    def _str2bool(value):
        """Convert a command-line string to a bool.

        ``type=bool`` calls ``bool()`` on the raw string, so every non-empty
        value, including "False", parsed as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # The description previously read "relax representation", which belongs
    # to a different script.
    parser = argparse.ArgumentParser(description="data augmentation")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
    parser.add_argument('--max_len', type=int, default=71, help="by default, 51 for english and 71 for chinese")
    parser.add_argument('--n_aug', type=int, default=5, help="for each font, augment it for n_aug times")
    parser.add_argument('--n_chars', type=int, default=52)
    parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images")
    parser.add_argument("--split", type=str, default='train')
    parser.add_argument('--debug', type=_str2bool, default=True)
    opts = parser.parse_args()
    apply_aug(opts)


if __name__ == "__main__":
    main()
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
data_utils/common_utils.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
def trans2_white_bg(img_path):
    """Replace the image at ``img_path`` with its inverted fourth channel.

    The file is read, ``255 - channel[3]`` is computed, and the result is
    saved back to ``img_path`` and returned.
    """
    pixels = np.array(Image.open(img_path))
    inverted = 255 - pixels[:, :, 3]
    Image.fromarray(inverted).save(img_path)
    return inverted
|
11 |
+
|
12 |
+
def affine_shear(seq, dx=-0.3, dy=0.0):
    """Shear the three coordinate pairs stored in columns 4-9 of ``seq``.

    Each (x, y) pair is moved into a frame centred on (12, 12) with the y axis
    flipped, multiplied by ``[[1, dx], [dy, 1]]`` and moved back. Columns 0-3
    are copied unchanged, and entries that were zero in ``seq`` stay zero.
    """
    nonzero = seq != 0
    n_rows = seq.shape[0]

    centred = seq.copy()
    for x_col, y_col in ((4, 5), (6, 7), (8, 9)):
        centred[:, x_col] -= 12.0
        centred[:, y_col] = -centred[:, y_col] + 12

    coords = centred[:, 4:]
    # Stack the three point columns on top of each other, then transpose so
    # the 2x2 matrix can be applied to all points at once.
    stacked = np.concatenate([coords[:, :2], coords[:, 2:4], coords[:, 4:6]], 0).transpose(1, 0)
    shear = np.array([[1, dx],
                      [dy, 1]])
    sheared = np.dot(shear, stacked).transpose(1, 0)

    new_args = np.concatenate(
        [sheared[:n_rows], sheared[n_rows:n_rows * 2], sheared[n_rows * 2:]], -1
    )
    for x_col, y_col in ((0, 1), (2, 3), (4, 5)):
        new_args[:, x_col] += 12.0
        new_args[:, y_col] = -(new_args[:, y_col] - 12)

    return np.concatenate([seq[:, :4], new_args], 1) * nonzero
|
38 |
+
|
39 |
+
def affine_scale(seq, scale=0.8):
    """Scale columns 4 onward of ``seq`` about the value 12 by ``scale``.

    Columns 0-3 are copied unchanged, and entries that were zero in ``seq``
    stay zero in the result.
    """
    nonzero = seq != 0
    centred = seq[:, 4:] - 12.0
    centred *= scale
    rescaled = centred + 12.0
    return np.concatenate([seq[:, :4], rescaled], 1) * nonzero
|
47 |
+
|
48 |
+
def affine_rotate(seq, theta=-5):
    """Rotate the three coordinate pairs in columns 4-9 of ``seq`` by ``theta`` degrees.

    Each (x, y) pair is moved into a frame centred on (12, 12) with the y axis
    flipped, multiplied by the 2x2 rotation matrix for ``theta`` and moved
    back. Columns 0-3 are copied unchanged, and entries that were zero in
    ``seq`` stay zero.
    """
    nonzero = seq != 0
    n_rows = seq.shape[0]

    centred = seq.copy()
    for x_col, y_col in ((4, 5), (6, 7), (8, 9)):
        centred[:, x_col] -= 12.0
        centred[:, y_col] = -centred[:, y_col] + 12

    coords = centred[:, 4:]
    # Stack the three point columns vertically, then transpose so the matrix
    # acts on every point in a single product.
    stacked = np.concatenate([coords[:, :2], coords[:, 2:4], coords[:, 4:6]], 0).transpose(1, 0)

    theta = math.radians(theta)
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    rotated = np.dot(rotation, stacked).transpose(1, 0)

    new_args = np.concatenate(
        [rotated[:n_rows], rotated[n_rows:n_rows * 2], rotated[n_rows * 2:]], -1
    )
    for x_col, y_col in ((0, 1), (2, 3), (4, 5)):
        new_args[:, x_col] += 12.0
        new_args[:, y_col] = -(new_args[:, y_col] - 12)

    return np.concatenate([seq[:, :4], new_args], 1) * nonzero
|
75 |
+
|
data_utils/convert_ttf_to_sfd.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fontforge # noqa
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from tqdm import tqdm
|
5 |
+
import multiprocessing as mp
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
|
9 |
+
def convert_mp(opts):
    """Use multiprocessing to convert every glyph of every font to an sfd file.

    Fonts are read from ``<ttf_path>/<language>/<split>``. For each font and
    each character of the charset, a one-glyph font is saved as
    ``<font_id>_<char_id>.sfd`` next to a ``.txt`` description under
    ``<sfd_path>/<language>/<split>/<font_id>``.

    Args:
        opts: Parsed options; reads ``data_path``, ``language``, ``ttf_path``,
            ``sfd_path``, ``split`` and, when present, ``ref_nshot``.
    """
    with open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r') as charset_file:
        charset_th = charset_file.read()
    charset = charset_th
    # ref_nshot is not guaranteed to be on opts, so a missing value is treated
    # as "do not append the English charset" instead of raising AttributeError.
    if getattr(opts, 'ref_nshot', None) == 52:
        with open(f"{opts.data_path}/char_set/eng.txt", 'r') as charset_file:
            charset_eng = charset_file.read()
        charset = charset_th + charset_eng
    charset_lenw = len(str(len(charset)))
    fonts_file_path = os.path.join(opts.ttf_path, opts.language)
    sfd_path = os.path.join(opts.sfd_path, opts.language)

    split_dir = os.path.join(fonts_file_path, opts.split)
    print(split_dir)
    # Use only the files directly inside split_dir. The previous loop kept the
    # file list of the last directory visited by os.walk (wrong when there are
    # subdirectories) and left the name unbound when split_dir did not exist.
    ttf_fnames = next(os.walk(split_dir), (split_dir, [], []))[2]
    print(ttf_fnames)

    font_num = len(ttf_fnames)
    # cpu_count() - 1 is zero on a single-core machine, which made the
    # division below raise ZeroDivisionError.
    process_num = max(1, mp.cpu_count() - 1)
    font_num_per_process = font_num // process_num + 1

    def process(process_id, font_num_p_process):
        for i in tqdm(range(process_id * font_num_p_process, (process_id + 1) * font_num_p_process)):
            if i >= font_num:
                break

            font_name = ttf_fnames[i]
            font_id = font_name.split('.')[0]
            split = opts.split

            font_file_path = os.path.join(fonts_file_path, split, font_name)
            try:
                cur_font = fontforge.open(font_file_path)
            except Exception as e:
                print('Cannot open ', font_name)
                print(e)
                continue

            target_dir = os.path.join(sfd_path, split, "{}".format(font_id))
            os.makedirs(target_dir, exist_ok=True)

            for char_id, char in enumerate(charset):
                stem = '{}_{num:0{width}}'.format(font_id, num=char_id, width=charset_lenw)
                try:
                    # The description file is created before the glyph is
                    # copied, as before, but is now closed even when a later
                    # step raises.
                    with open(os.path.join(target_dir, stem + '.txt'), 'w') as char_description:
                        if char in charset_th:
                            # Select by a 'uniXXXX' name built from the escape.
                            char = 'uni' + char.encode("unicode_escape")[2:].decode("utf-8")

                        cur_font.selection.select(char)
                        cur_font.copy()

                        new_font_for_char = fontforge.font()
                        # The copied glyph is always pasted into the 'A' slot
                        # of the new font, and 'A' is what gets described.
                        char = 'A'

                        new_font_for_char.selection.select(char)
                        new_font_for_char.paste()
                        new_font_for_char.fontname = "{}_".format(font_id) + font_name

                        new_font_for_char.save(os.path.join(target_dir, stem + '.sfd'))

                        char_description.write(str(ord(char)) + '\n')
                        char_description.write(str(new_font_for_char[char].width) + '\n')
                        char_description.write(str(new_font_for_char[char].vwidth) + '\n')
                        char_description.write('{num:0{width}}'.format(num=char_id, width=charset_lenw) + '\n')
                        char_description.write('{}'.format(font_id))
                except Exception as e:
                    print("Found Error:", font_id, font_name, char_id, char)
                    print(e)

            cur_font.close()

    processes = [mp.Process(target=process, args=(pid, font_num_per_process)) for pid in range(process_num)]

    for p in processes:
        p.start()
    for p in processes:
        p.join()
|
90 |
+
|
91 |
+
|
92 |
+
def main():
    """Parse the command-line options and run the ttf-to-sfd conversion."""
    parser = argparse.ArgumentParser(description="Convert ttf fonts to sfd fonts")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
    parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
    parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
    parser.add_argument('--split', type=str, default='train')
    # convert_mp reads opts.ref_nshot, but the parser never defined it, so
    # every run failed with AttributeError. The default leaves the extra
    # English charset switched off.
    parser.add_argument('--ref_nshot', type=int, default=None,
                        help="when set to 52, the eng charset is appended to the language charset")
    opts = parser.parse_args()
    convert_mp(opts)


if __name__ == "__main__":
    main()
|
data_utils/relax_rep.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import multiprocessing as mp
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
def numericalize(cmd, n=64):
    """Quantise values from the 0-30 range onto integer bins ``0 .. n-1``.

    NOTE: shall only be called after normalization.
    """
    scaled = cmd / 30 * n
    binned = scaled.round().clip(min=0, max=n - 1)
    return binned.astype(int)
|
12 |
+
|
13 |
+
def denumericalize(cmd, n=64):
    """Map bin indices in ``0 .. n`` back onto the 0-30 coordinate range."""
    return cmd / n * 30
|
16 |
+
|
17 |
+
def cal_aux_bezier_pts(font_seq, opts):
    """
    Sample auxiliary points at t = 0.25, 0.5 and 0.75 along every drawing step.

    For each of ``opts.char_num`` glyphs and ``opts.max_seq_len`` steps, the
    command class is the argmax of the first four values. Class 0 yields six
    zeros, classes 1 and 2 sample the straight segment p0 -> p3, and class 3
    samples the cubic Bezier curve through p0, p1, p2, p3. The coordinates in
    ``font_seq`` are quantised in place before sampling.
    """
    sample_ts = [0.25, 0.5, 0.75]
    pts_aux_all = []

    for char_idx in range(opts.char_num):
        char_seq = font_seq[char_idx]
        pts_aux_char = []
        for step in range(opts.max_seq_len):
            stroke = char_seq[step]
            cmd_cls = np.argmax(stroke[:4], -1)
            # Round-trip through the integer grid; this writes into font_seq.
            stroke[4:] = denumericalize(numericalize(stroke[4:]))
            p0, p1, p2, p3 = stroke[4:6], stroke[6:8], stroke[8:10], stroke[10:12]

            samples = []
            if cmd_cls == 0:
                samples = [0] * 6
            elif cmd_cls in (1, 2):  # move or line
                for t in sample_ts:
                    point = p0 + t*(p3-p0)
                    samples.extend((point[0], point[1]))
            elif cmd_cls == 3:  # curve
                for t in sample_ts:
                    point = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
                    samples.extend((point[0], point[1]))

            pts_aux_char.append(np.array(samples))

        pts_aux_all.append(np.array(pts_aux_char))

    return np.array(pts_aux_all)
|
60 |
+
|
61 |
+
|
62 |
+
def relax_rep(opts):
    """
    Relax the sequence representation of every font in one dataset split.

    For each font directory under ``<output_path>/<language>/<split>`` the
    arguments of every drawing step are prefixed with a start point. A class-1
    step uses its own last two values; any other step uses the previous
    step's last two values. The result is written to ``sequence_relaxed.npy``
    and the auxiliary points from ``cal_aux_bezier_pts`` to ``pts_aux.npy``.

    Args:
        opts: Parsed options; reads ``output_path``, ``language``, ``split``,
            ``char_num`` and ``max_seq_len``.
    """
    data_path = os.path.join(opts.output_path, opts.language, opts.split)
    font_dirs = sorted(os.listdir(data_path))
    num_fonts = len(font_dirs)
    print(f"Number {opts.split} fonts before processing", num_fonts)
    # cpu_count() - 1 is zero on a single-core machine, which made the
    # division below raise ZeroDivisionError.
    num_processes = max(1, mp.cpu_count() - 1)
    fonts_per_process = num_fonts // num_processes + 1

    def process(process_id):

        for i in tqdm(range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process)):
            if i >= num_fonts:
                break

            font_dir = os.path.join(data_path, font_dirs[i])
            font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.char_num, opts.max_seq_len, -1)
            font_len = np.load(os.path.join(font_dir, 'seq_len.npy')).reshape(-1)
            cmd = font_seq[:, :, :4]
            args = font_seq[:, :, 4:]

            ret = []
            for j in range(opts.char_num):

                char_cmds = cmd[j]
                char_args = args[j]
                char_len = font_len[j]
                new_args = []
                for k in range(char_len):
                    cur_cls = np.argmax(char_cmds[k], -1)
                    cur_arg = char_args[k]
                    if k - 1 > -1:
                        pre_arg = char_args[k - 1]
                    if cur_cls == 1:  # when k == 0, cur_cls == 1
                        cur_arg = np.concatenate((np.array([cur_arg[-2], cur_arg[-1]]), cur_arg), -1)
                    else:
                        # NOTE(review): relies on the first step being class 1;
                        # otherwise pre_arg is unset or stale — confirm.
                        cur_arg = np.concatenate((np.array([pre_arg[-2], pre_arg[-1]]), cur_arg), -1)
                    new_args.append(cur_arg)

                # Pad with all-zero rows up to the fixed sequence length.
                while len(new_args) < opts.max_seq_len:
                    new_args.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))

                new_args = np.array(new_args)
                new_seq = np.concatenate((char_cmds, new_args), -1)
                ret.append(new_seq)
            ret = np.array(ret)
            # write relaxed version of sequence.npy
            np.save(os.path.join(font_dir, 'sequence_relaxed.npy'), ret.reshape(opts.char_num, -1))

            pts_aux = cal_aux_bezier_pts(ret, opts)
            np.save(os.path.join(font_dir, 'pts_aux.npy'), pts_aux)

    processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]

    for p in processes:
        p.start()
    for p in processes:
        p.join()
|
124 |
+
|
125 |
+
|
126 |
+
def main():
    """Parse the command-line options and run the representation relaxing."""
    parser = argparse.ArgumentParser(description="relax representation")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
    parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
    parser.add_argument("--split", type=str, default='train')
    # relax_rep and cal_aux_bezier_pts read opts.char_num and
    # opts.max_seq_len, which the parser never defined, so every run failed
    # with AttributeError.
    # NOTE(review): the defaults mirror --n_chars / --max_len in augment.py;
    # confirm they match the dataset being processed.
    parser.add_argument("--char_num", type=int, default=52, help="number of glyphs stored per font")
    parser.add_argument("--max_seq_len", type=int, default=71, help="number of drawing steps stored per glyph")
    opts = parser.parse_args()
    relax_rep(opts)


if __name__ == "__main__":
    main()
|
data_utils/svg_utils.py
ADDED
@@ -0,0 +1,1083 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The Magenta Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Lint as: python3
|
16 |
+
"""Defines the Material Design Icons Problem."""
|
17 |
+
import io
|
18 |
+
import numpy as np
|
19 |
+
import re
|
20 |
+
|
21 |
+
from PIL import Image
|
22 |
+
from itertools import zip_longest
|
23 |
+
from skimage import draw
|
24 |
+
|
25 |
+
|
26 |
+
# Fixed SVG wrapper pieces used when a path is serialised to an <svg> string.
SVG_PREFIX_BIG = (
    '<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="'
    'http://www.w3.org/1999/xlink" width="256px" height="256px"'
    ' style="-ms-transform: rotate(360deg); -webkit-transform:'
    ' rotate(360deg); transform: rotate(360deg);" '
    'preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 30">'
)
PATH_PREFIX_1 = '<path d="'
PATH_POSFIX_1 = '" fill="currentColor"/>'
SVG_POSFIX = '</svg>'

# Number of numeric arguments consumed by each SVG path command letter.
NUM_ARGS = {
    'v': 1, 'V': 1,
    'h': 1, 'H': 1,
    'a': 7, 'A': 7,
    'l': 2, 'L': 2,
    't': 2, 'T': 2,
    'c': 6, 'C': 6,
    'm': 2, 'M': 2,
    's': 4, 'S': 4,
    'q': 4, 'Q': 4,
    'z': 0,
}

# in order of arg complexity, with absolutes clustered
# recall we don't handle all commands (see docstring)
CMDS_LIST = 'zHVMLTSQCAhvmltsqca'  # was zhvmltsqcaHVMLTSQCA
# Command letter -> its position in CMDS_LIST.
CMD_MAPPING = dict(zip(CMDS_LIST, range(len(CMDS_LIST))))

FEATURE_DIM = 10

MAX_SEQ_LEN = 120
|
46 |
+
|
47 |
+
# Manually Change Max Sequence
|
48 |
+
def change_max_seq_len(param):
    """Overwrites the module-level MAX_SEQ_LEN with `param` and returns it."""
    global MAX_SEQ_LEN
    MAX_SEQ_LEN = param
    return param
|
52 |
+
|
53 |
+
############################### GENERAL UTILS #################################
|
54 |
+
def grouper(iterable, batch_size, fill_value=None):
    """Yields consecutive tuples of `batch_size` items taken from `iterable`.

    The last tuple is padded with `fill_value` when the input runs out early,
    e.g. grouper('ABCDEF', 3) -> ('A', 'B', 'C'), ('D', 'E', 'F').
    """
    # The same iterator object repeated batch_size times: zip_longest pulls
    # one item per slot, so each output tuple is a run of consecutive items.
    shared = iter(iterable)
    slots = [shared] * batch_size
    return zip_longest(*slots, fillvalue=fill_value)
|
59 |
+
|
60 |
+
|
61 |
+
def _map_uni_to_alphanum(uni):
|
62 |
+
"""Maps [0-9 A-Z a-z] to numbers 0-62."""
|
63 |
+
if 48 <= uni <= 57:
|
64 |
+
return uni - 48
|
65 |
+
elif 65 <= uni <= 90:
|
66 |
+
return uni - 65 + 10
|
67 |
+
return uni - 97 + 36
|
68 |
+
|
69 |
+
|
70 |
+
def _map_uni_to_alpha(uni):
|
71 |
+
"""Maps [A-Z a-z] to numbers 0-52."""
|
72 |
+
if 65 <= uni <= 90:
|
73 |
+
return uni - 65
|
74 |
+
return uni - 97 + 26
|
75 |
+
|
76 |
+
|
77 |
+
############# UTILS FOR CONVERTING SFD/SPLINESETS TO SVG PATHS ################
|
78 |
+
def _get_spline(sfd):
|
79 |
+
if 'SplineSet' not in sfd:
|
80 |
+
return ''
|
81 |
+
pro = sfd[sfd.index('SplineSet') + 10:] # 10 is the 'SplineSet'
|
82 |
+
pro = pro[:pro.index('EndSplineSet')]
|
83 |
+
return pro
|
84 |
+
|
85 |
+
|
86 |
+
def _spline_to_path_list(spline, height, replace_with_prev=False):
|
87 |
+
"""Converts SplineSet to a list of tokenized commands in svg path."""
|
88 |
+
path = []
|
89 |
+
prev_xy = []
|
90 |
+
for line in spline.splitlines():
|
91 |
+
if not line:
|
92 |
+
continue
|
93 |
+
tokens = line.split(' ')
|
94 |
+
cmd = tokens[-2]
|
95 |
+
if cmd not in 'cml':
|
96 |
+
# COMMAND NOT RECOGNIZED.
|
97 |
+
return []
|
98 |
+
# assert cmd in 'cml', 'Command not recognized: {}'.format(cmd)
|
99 |
+
args = tokens[:-2]
|
100 |
+
args = [float(x) for x in args if x]
|
101 |
+
|
102 |
+
if replace_with_prev and cmd in 'c':
|
103 |
+
args[:2] = prev_xy
|
104 |
+
prev_xy = args[-2:]
|
105 |
+
|
106 |
+
new_y_args = []
|
107 |
+
for i, a in enumerate(args):
|
108 |
+
if i % 2 == 1:
|
109 |
+
new_y_args.append((height - a))
|
110 |
+
else:
|
111 |
+
new_y_args.append((a))
|
112 |
+
|
113 |
+
path.append([cmd.upper()] + new_y_args)
|
114 |
+
return path
|
115 |
+
|
116 |
+
|
117 |
+
def _sfd_to_path_list(single, replace_with_prev=False):
    """Converts the given SFD glyph into a path.

    Reads the glyph text from single['sfd'] and uses single['vwidth'] as the
    height passed to _spline_to_path_list.
    """
    spline = _get_spline(single['sfd'])
    height = single['vwidth']
    return _spline_to_path_list(spline, height, replace_with_prev)
|
120 |
+
|
121 |
+
|
122 |
+
#################### UTILS FOR PROCESSING TOKENIZED PATHS #####################
|
123 |
+
def _add_missing_cmds(path, remove_zs=False):
    """Adds missing cmd tags to the commands in the svg.

    Some SVGs put several argument groups behind one tag, e.g. 'a' takes 7
    arguments but may be written as
      a 1 2 3 4 5 6 7 8 9 10 11 12 13 14
    which this function expands to the equivalent
      a 1 2 3 4 5 6 7 a 8 9 10 11 12 13 14

    Note: if remove_zs is True, every z/Z command is dropped as well.
    """
    expanded = []
    for cmd in path:
        if remove_zs and cmd[0] in 'Zz':
            continue
        expanded.extend(add_missing_cmd(cmd))
    return expanded
|
137 |
+
|
138 |
+
|
139 |
+
def add_missing_cmd(command_list):
    """Adds missing cmd tags to the given command list.

    E.g. given
      ['a', '0', '0', '0', '0', '0', '0', '0',
       '0', '0', '0', '0', '0', '0', '0']
    returns
      [['a', '0', '0', '0', '0', '0', '0', '0'],
       ['a', '0', '0', '0', '0', '0', '0', '0']]

    Args:
      command_list: [tag, arg, arg, ...] for a single path command.

    Returns:
      A list of [tag, arg, ...] lists, one per group of NUM_ARGS[tag]
      arguments. An incomplete last group is padded with None. A command
      that yields no groups comes back as [[tag]].

    Raises:
      KeyError: if the tag is not known to NUM_ARGS in either letter case.
    """
    cmd_tag = command_list[0]
    args = command_list[1:]

    try:
        args_per_cmd = NUM_ARGS[cmd_tag]
    except KeyError:
        # NUM_ARGS lists 'z' but not 'Z'. Both forms take the same number of
        # arguments, so fall back to the lower-case entry instead of failing
        # on an absolute close-path. Unknown tags still raise KeyError.
        args_per_cmd = NUM_ARGS[cmd_tag.lower()]

    final_cmds = [[cmd_tag] + list(arg_batch)
                  for arg_batch in grouper(args, args_per_cmd)]

    # command has no args (e.g.: 'z')
    return final_cmds or [[cmd_tag]]
|
160 |
+
|
161 |
+
|
162 |
+
def _normalize_args(arglist, norm, add=None, flip=False):
|
163 |
+
"""Normalize the given args with the given norm value."""
|
164 |
+
new_arglist = []
|
165 |
+
for i, arg in enumerate(arglist):
|
166 |
+
new_arg = float(arg)
|
167 |
+
|
168 |
+
if add is not None:
|
169 |
+
add_to_x, add_to_y = add
|
170 |
+
|
171 |
+
# This argument is an x-coordinate if even, y-coordinate if odd
|
172 |
+
# except when flip == True
|
173 |
+
if i % 2 == 0:
|
174 |
+
new_arg += add_to_y if flip else add_to_x
|
175 |
+
else:
|
176 |
+
new_arg += add_to_x if flip else add_to_y
|
177 |
+
|
178 |
+
new_arglist.append(str(24 * new_arg / norm))
|
179 |
+
return new_arglist
|
180 |
+
|
181 |
+
|
182 |
+
def _normalize_based_on_viewbox(path, viewbox):
    """Normalizes all args in a path to a standard 24x24 viewbox.

    Each SVG lives in a 2D plane and the viewbox selects the region of that
    plane that gets rendered. Designers use different viewboxes (24x24,
    100x100, ...), so the same shape can be written with very different
    numbers: a "lineTo 100 100" in a 100x100 viewbox is the same stroke as
    "lineTo 20 20" in a 20x20 one. To bring all values onto one scale, every
    command is rescaled to a 24x24 viewbox.

    Args:
      path: list of [tag, arg, ...] commands.
      viewbox: space-separated string whose last two fields are parsed as
        integers (presumably width then height; confirm against callers).

    Returns:
      A new list of commands with string arguments.
    """
    fields = viewbox.split(' ')
    last, second_last = int(fields[-1]), int(fields[-2])
    norm = max(last, second_last)

    # Pad the shorter side so the content is centred in a square.
    half_gap = abs(last - second_last) / 2
    if last > second_last:
        shift = (half_gap, 0)
    else:
        shift = (0, half_gap)

    new_path = []
    for command in path:
        tag = command[0]
        if tag == 'a':
            # radii and end point are scaled; rotation and flags are kept
            scaled = (_normalize_args(command[1:3], norm)
                      + command[3:6]
                      + _normalize_args(command[6:], norm))
        elif tag == 'A':
            scaled = (_normalize_args(command[1:3], norm)
                      + command[3:6]
                      + _normalize_args(command[6:], norm, add=shift))
        elif tag == 'V':
            # single y argument: flip so it receives the y shift
            scaled = _normalize_args(command[1:], norm, add=shift, flip=True)
        elif tag == tag.upper():
            scaled = _normalize_args(command[1:], norm, add=shift)
        elif tag in 'zZ':
            scaled = []
        else:
            scaled = _normalize_args(command[1:], norm)
        new_path.append([tag] + scaled)

    return new_path
|
228 |
+
|
229 |
+
|
230 |
+
def _convert_args(args, curr_pos, cmd):
|
231 |
+
"""Converts given args to relative values."""
|
232 |
+
# NOTE: glyphs only use a very small subset of commands (L, C, M, and Z -- I
|
233 |
+
# believe). So I'm not handling A and H for now.
|
234 |
+
if cmd in 'AH':
|
235 |
+
raise NotImplementedError('These commands have >6 args (not supported).')
|
236 |
+
|
237 |
+
new_args = []
|
238 |
+
for i, arg in enumerate(args):
|
239 |
+
x_or_y = i % 2
|
240 |
+
if cmd == 'H':
|
241 |
+
x_or_y = (i + 1) % 2
|
242 |
+
new_args.append(str(float(arg) - curr_pos[x_or_y]))
|
243 |
+
|
244 |
+
return new_args
|
245 |
+
|
246 |
+
|
247 |
+
def _update_curr_pos(curr_pos, cmd, start_of_path):
|
248 |
+
"""Calculate the position of the pen after cmd is applied."""
|
249 |
+
if cmd[0] in 'ml':
|
250 |
+
curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1] + float(cmd[2])]
|
251 |
+
if cmd[0] == 'm':
|
252 |
+
start_of_path = curr_pos
|
253 |
+
elif cmd[0] in 'z':
|
254 |
+
curr_pos = start_of_path
|
255 |
+
elif cmd[0] in 'h':
|
256 |
+
curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1]]
|
257 |
+
elif cmd[0] in 'v':
|
258 |
+
curr_pos = [curr_pos[0], curr_pos[1] + float(cmd[1])]
|
259 |
+
elif cmd[0] in 'ctsqa':
|
260 |
+
curr_pos = [curr_pos[0] + float(cmd[-2]), curr_pos[1] + float(cmd[-1])]
|
261 |
+
|
262 |
+
return curr_pos, start_of_path
|
263 |
+
|
264 |
+
|
265 |
+
def _make_relative(cmds):
    """Convert commands in a path to relative positioning.

    Commands whose tag is already lower-case are passed through unchanged.
    'Z' becomes ['z']. Every other command is lower-cased and its arguments
    are rewritten relative to the running pen position.
    """
    pen = (0.0, 0.0)
    subpath_start = (0.0, 0.0)
    relative = []
    for cmd in cmds:
        tag = cmd[0]
        lowered = tag.lower()
        if lowered == tag:
            converted = cmd
        elif lowered == 'z':
            converted = [lowered]
        else:
            converted = [lowered] + _convert_args(cmd[1:], pen, cmd=tag)
        relative.append(converted)
        pen, subpath_start = _update_curr_pos(pen, converted, subpath_start)
    return relative
|
280 |
+
|
281 |
+
|
282 |
+
def _is_to_left_of(pt1, pt2):
|
283 |
+
pt1_norm = (pt1[0]**2 + pt1[1]**2)
|
284 |
+
pt2_norm = (pt2[0]**2 + pt2[1]**2)
|
285 |
+
return pt1[1] < pt2[1] or (pt1_norm == pt2_norm and pt1[0] < pt2[0])
|
286 |
+
|
287 |
+
|
288 |
+
def _get_leftmost_point(path):
    """Returns the command end point that _is_to_left_of ranks first.

    Commands with no arguments are ignored.

    Returns:
      (point, index): the last two items of the winning command and that
      command's index in `path`. ((inf, inf), -1) if no command qualifies.
    """
    best_point = (float('inf'), float('inf'))
    best_idx = -1

    for position, cmd in enumerate(path):
        if len(cmd) <= 1:
            continue
        candidate = cmd[-2:]
        if _is_to_left_of(candidate, best_point):
            best_point, best_idx = candidate, position

    return best_point, best_idx
|
301 |
+
|
302 |
+
|
303 |
+
def _separate_substructures(path):
|
304 |
+
"""Returns a list of subpaths, each representing substructures the glyph."""
|
305 |
+
substructures = []
|
306 |
+
curr = []
|
307 |
+
for cmd in path:
|
308 |
+
if cmd[0] in 'mM' and curr:
|
309 |
+
substructures.append(curr)
|
310 |
+
curr = []
|
311 |
+
curr.append(cmd)
|
312 |
+
if curr:
|
313 |
+
substructures.append(curr)
|
314 |
+
return substructures
|
315 |
+
|
316 |
+
|
317 |
+
def _is_clockwise(subpath):
|
318 |
+
"""Returns whether the given subpath is clockwise-oriented."""
|
319 |
+
pts = [cmd[-2:] for cmd in subpath]
|
320 |
+
det = 0
|
321 |
+
for i in range(len(pts) - 1):
|
322 |
+
det += np.linalg.det(pts[i:i + 2])
|
323 |
+
return det > 0
|
324 |
+
|
325 |
+
|
326 |
+
def _make_clockwise(subpath):
|
327 |
+
"""Inverts the cardinality of the given subpath."""
|
328 |
+
new_path = [subpath[0]]
|
329 |
+
other_cmds = list(reversed(subpath[1:]))
|
330 |
+
for i, cmd in enumerate(other_cmds):
|
331 |
+
if i + 1 == len(other_cmds):
|
332 |
+
where_we_were = subpath[0][-2:]
|
333 |
+
else:
|
334 |
+
where_we_were = other_cmds[i + 1][-2:]
|
335 |
+
|
336 |
+
if len(cmd) > 3:
|
337 |
+
new_cmd = [cmd[0], cmd[3], cmd[4], cmd[1], cmd[2],
|
338 |
+
where_we_were[0], where_we_were[1]]
|
339 |
+
else:
|
340 |
+
new_cmd = [cmd[0], where_we_were[0], where_we_were[1]]
|
341 |
+
|
342 |
+
new_path.append(new_cmd)
|
343 |
+
return new_path
|
344 |
+
|
345 |
+
|
346 |
+
def _canonicalize(path):
    """Makes all paths start at top left, and go clockwise first.

    Args:
      path: list of [tag, arg, ...] commands with numeric or string args.

    Returns:
      The reordered path with all arguments converted back to strings.
    """
    # work on float arguments
    as_floats = [[cmd[0]] + [float(a) for a in cmd[1:]] for cmd in path]

    # rotate each subpath so that it starts at its selected point
    rotated = []
    for subpath in _separate_substructures(as_floats):
        anchor, anchor_idx = _get_leftmost_point(subpath)
        new_start = [['M', anchor[0], anchor[1]]]
        reordered = (new_start
                     + subpath[anchor_idx + 1:]
                     + subpath[1:anchor_idx + 1])
        rotated.append((reordered, anchor))

    # order the subpaths by their anchor (y first, then x)
    rotated = sorted(rotated, key=lambda item: (item[1][1], item[1][0]))

    canonical = []
    flip_all = False
    for position, (sp, _) in enumerate(rotated):
        if position == 0:
            # the first substructure decides whether the cardinality of the
            # whole icon gets flipped
            flip_all = not _is_clockwise(sp)
        if flip_all:
            sp = _make_clockwise(sp)
        canonical.extend(sp)

    # back to string arguments
    return [[cmd[0]] + [str(a) for a in cmd[1:]] for cmd in canonical]
|
376 |
+
|
377 |
+
|
378 |
+
# ######### UTILS FOR CONVERTING TOKENIZED PATHS TO VECTORS ###########
|
379 |
+
def _path_to_vector(path, categorical=False):
    """Converts path's commands to a series of vectors.

    Notes:
      - The SimpleSVG dataset does not have any 't', 'q', 'Z', 'T', or 'Q',
        so those are not handled here.
      - All 'z's were removed as well.
      - The x-axis-rotation argument of the 'a' commands is always 0 in this
        dataset, so it is ignored.

    Many commands have args that correspond to args in other commands:
      v  __,__ _______________ ______________,_________ __,__ __,__ _,y
      h  __,__ _______________ ______________,_________ __,__ __,__ x,_
      z  __,__ _______________ ______________,_________ __,__ __,__ _,_
      a  rx,ry x-axis-rotation large-arc-flag,sweepflag __,__ __,__ x,y
      l  __,__ _______________ ______________,_________ __,__ __,__ x,y
      c  __,__ _______________ ______________,_________ x1,y1 x2,y2 x,y
      m  __,__ _______________ ______________,_________ __,__ __,__ x,y
      s  __,__ _______________ ______________,_________ __,__ x2,y2 x,y

    So each command is converted to a vector whose dimension is the minimal
    number of arguments covering all commands:
      [rx, ry, large-arc-flag, sweepflag, x1, y1, x2, y2, x, y]
    An arg a command does not have is set to 0, e.g. "l 5,5" becomes
      [0, 0, 0, 0, 0, 0, 0, 0, 5, 5]

    The command itself is encoded in front of those ten values (a single
    index, or a one-hot block when `categorical` is True).
    """
    return [_cmd_to_vector(cmd, categorical=categorical) for cmd in path]
|
410 |
+
|
411 |
+
|
412 |
+
# For each lower-cased command letter: (slot in the 10-value argument vector,
# index into the command's own argument list) pairs, in fill order.
_CMD_ARG_SLOTS = {
    'h': ((8, 0),),                                    # x
    'v': ((9, 0),),                                    # y
    'm': ((8, 0), (9, 1)),                             # x, y
    'l': ((8, 0), (9, 1)),
    't': ((8, 0), (9, 1)),
    's': ((6, 0), (7, 1), (8, 2), (9, 3)),             # x2, y2, x, y
    'q': ((6, 0), (7, 1), (8, 2), (9, 3)),
    'c': ((4, 0), (5, 1), (6, 2), (7, 3), (8, 4), (9, 5)),  # x1..y
    # rx, ry, large-arc-flag, sweep-flag, x, y (x-axis-rotation is skipped;
    # a has no x1, y1, x2, y2)
    'a': ((0, 0), (1, 1), (2, 3), (3, 4), (8, 5), (9, 6)),
}


def _cmd_to_vector(cmd_list, categorical=False):
    """Converts the given command (given as a list) into a vector.

    Args:
      cmd_list: [tag, arg, ...] for one command; the tag must be a key of
        CMD_MAPPING.
      categorical: if False the command is encoded as one float (its index
        in CMD_MAPPING); if True as a one-hot list of len(CMDS_LIST) + 1
        entries whose slot 0 is reserved for EOS.

    Returns:
      The command encoding followed by the 10 argument values
      [rx, ry, large-arc-flag, sweep-flag, x1, y1, x2, y2, x, y].
    """
    cmd, args = cmd_list[0], cmd_list[1:]
    cmd_index = CMD_MAPPING[cmd]

    if categorical:
        # one hot + 1 dim for EOS.
        command = [0.0] * (len(CMDS_LIST) + 1)
        command[cmd_index + 1] = 1.0
    else:
        # integer, for MSE
        command = [float(cmd_index)]

    arguments = [0.0] * 10
    for slot, arg_idx in _CMD_ARG_SLOTS.get(cmd.lower(), ()):
        arguments[slot] = float(args[arg_idx])

    return command + arguments
|
458 |
+
|
459 |
+
|
460 |
+
################## UTILS FOR RENDERING PATH INTO IMAGE #################
|
461 |
+
def _cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=40):
|
462 |
+
"""Return n points along cubiz bezier with given control points."""
|
463 |
+
# from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
|
464 |
+
pts = []
|
465 |
+
for i in range(n + 1):
|
466 |
+
t = float(i) / float(n)
|
467 |
+
a = (1. - t)**3
|
468 |
+
b = 3. * t * (1. - t)**2
|
469 |
+
c = 3.0 * t**2 * (1.0 - t)
|
470 |
+
d = t**3
|
471 |
+
|
472 |
+
x = float(a * x0 + b * x1 + c * x2 + d * x3)
|
473 |
+
y = float(a * y0 + b * y1 + c * y2 + d * y3)
|
474 |
+
pts.append((x, y))
|
475 |
+
return list(zip(*pts))
|
476 |
+
|
477 |
+
|
478 |
+
def _update_pos(curr_pos, end_pos, absolute):
|
479 |
+
if absolute:
|
480 |
+
return end_pos
|
481 |
+
return curr_pos[0] + end_pos[0], curr_pos[1] + end_pos[1]
|
482 |
+
|
483 |
+
|
484 |
+
def constant_color(*unused_args):
    """Returns the array [255, 255, 255], whatever the arguments."""
    return np.array([255] * 3)
|
486 |
+
|
487 |
+
|
488 |
+
def _render_cubic(canvas, curr_pos, c_args, absolute, color):
    """Renders a cubic bezier curve in the given canvas.

    Args:
      canvas: array indexed as canvas[y, x, channel]; modified in place.
      curr_pos: (x, y) start point of the curve.
      c_args: [x1, y1, x2, y2, x, y]. For relative curves the list is
        updated in place to absolute coordinates.
      absolute: whether c_args are already absolute.
      color: value written to every pixel the curve touches.
    """
    if not absolute:
        # shift x slots by the pen's x and y slots by the pen's y
        for slot in range(6):
            c_args[slot] += curr_pos[slot % 2]

    curve_x, curve_y = _cubicbezier(curr_pos[0], curr_pos[1],
                                    c_args[0], c_args[1],
                                    c_args[2], c_args[3],
                                    c_args[4], c_args[5])
    limit = len(canvas)
    cols = [int(round(v)) for v in curve_x]
    rows = [int(round(v)) for v in curve_y]

    # keep only pixels that fall inside the canvas
    inside = [(col, row) for col, row in zip(cols, rows)
              if 0 <= col < limit and 0 <= row < limit]
    if not inside:
        return
    cols, rows = list(zip(*inside))
    canvas[rows, cols, :] = color
|
514 |
+
|
515 |
+
|
516 |
+
def _render_line(canvas, curr_pos, l_args, absolute, color):
    """Renders a line in the given canvas.

    Args:
      canvas: array indexed as canvas[y, x, channel]; modified in place.
      curr_pos: (x, y) start point of the line.
      l_args: [x, y] end point. For relative lines the list is updated in
        place to absolute coordinates.
      absolute: whether l_args is already absolute.
      color: value scaled by each pixel's anti-aliasing weight.
    """
    if not absolute:
        l_args[0] += curr_pos[0]
        l_args[1] += curr_pos[1]

    # x is passed in the row slot and y in the column slot; the swap is
    # undone when the canvas is indexed below.
    xs, ys, weights = draw.line_aa(int(curr_pos[0]), int(curr_pos[1]),
                                   int(l_args[0]), int(l_args[1]))

    limit = len(canvas)
    inside = [(x, y, w) for x, y, w in zip(xs, ys, weights)
              if 0 <= x < limit and 0 <= y < limit]
    if not inside:
        return
    xs, ys, weights = list(zip(*inside))
    shaded = [(w * color) for w in weights]
    canvas[ys, xs, :] = shaded
|
537 |
+
|
538 |
+
|
539 |
+
def _per_step_render(path, absolute=False, color=constant_color):
    """Render the icon's edges, given its path.

    Args:
      path: list of [tag, arg, ...] commands; empty entries are skipped and
        only m/M, c/C and l/L are drawn.
      absolute: whether the arguments are absolute coordinates.
      color: callable invoked as color(i, 55) for the i-th command.

    Returns:
      A (64, 64, 3) array with the rendered outline.
    """
    def scale(values):
        # 24-unit viewbox -> 64-pixel canvas
        return [float(v) * (64. / 24.) for v in values]

    canvas = np.zeros((64, 64, 3))
    pen = (0.0, 0.0)
    for i, cmd in enumerate(path):
        if not cmd:
            continue
        tag = cmd[0]
        if tag in 'mM':
            pen = _update_pos(pen, scale(cmd[-2:]), absolute)
        elif tag in 'cC':
            curve_args = scale(cmd[1:])
            _render_cubic(canvas, pen, curve_args, absolute, color(i, 55))
            pen = _update_pos(pen, scale(cmd[-2:]), absolute)
        elif tag in 'lL':
            line_args = scale(cmd[1:])
            _render_line(canvas, pen, line_args, absolute, color(i, 55))
            # fresh copy: the render call may have modified line_args
            pen = _update_pos(pen, scale(cmd[1:]), absolute)

    return canvas
|
559 |
+
|
560 |
+
|
561 |
+
def _zoom_out(path_list, add_baseline=0., per=22):
|
562 |
+
"""Makes glyph slightly smaller in viewbox, makes some descenders visible."""
|
563 |
+
# assumes tensor is already unnormalized, and in long form
|
564 |
+
new_path = []
|
565 |
+
for command in path_list:
|
566 |
+
args = []
|
567 |
+
is_even = False
|
568 |
+
for arg in command[1:]:
|
569 |
+
if is_even:
|
570 |
+
args.append(str(float(arg) - ((24. - per) / 24.) * 64. / 4.))
|
571 |
+
is_even = False
|
572 |
+
else:
|
573 |
+
args.append(str(float(arg) - add_baseline))
|
574 |
+
is_even = True
|
575 |
+
new_path.append([command[0]] + args)
|
576 |
+
return new_path
|
577 |
+
|
578 |
+
|
579 |
+
##################### UTILS FOR PROCESSING VECTORS ################
|
580 |
+
def _append_eos(sample, categorical, feature_dim):
|
581 |
+
if not categorical:
|
582 |
+
eos = -1 * np.ones(feature_dim)
|
583 |
+
else:
|
584 |
+
eos = np.zeros(feature_dim)
|
585 |
+
eos[0] = 1.0
|
586 |
+
sample.append(eos)
|
587 |
+
return sample
|
588 |
+
|
589 |
+
|
590 |
+
def _make_simple_cmds_long(out):
|
591 |
+
"""Converts svg decoder output to format required by some render functions."""
|
592 |
+
# out has 10 dims
|
593 |
+
# the first 4 are respectively dims 0, 4, 5, 9 of the full 20-dim onehot vec
|
594 |
+
# the latter 6 are the 6 last dims of the 10-dim arg vec
|
595 |
+
shape_minus_dim = list(np.shape(out))[:-1]
|
596 |
+
return np.concatenate([out[..., :1],
|
597 |
+
np.zeros(shape_minus_dim + [3]),
|
598 |
+
out[..., 1:3],
|
599 |
+
np.zeros(shape_minus_dim + [3]),
|
600 |
+
out[..., 3:4],
|
601 |
+
np.zeros(shape_minus_dim + [14]),
|
602 |
+
out[..., 4:]], -1)
|
603 |
+
|
604 |
+
|
605 |
+
################# UTILS FOR CONVERTING VECTORS TO SVGS ########################
|
606 |
+
def _vector_to_svg(vectors, stop_at_eos=False, categorical=False):
    """Transforms the given vectors into an svg string.

    Args:
      vectors: iterable of command vectors as produced by _cmd_to_vector.
      stop_at_eos: if True, conversion stops at the first EOS vector.
      categorical: whether the command part of each vector is one-hot.

    Returns:
      A complete <svg> string with a single <path> element.
    """
    pieces = []
    for vector in vectors:
        if stop_at_eos:
            if not categorical:
                reached_eos = vector[0] < -0.5
            else:
                try:
                    reached_eos = np.argmax(vector[:len(CMDS_LIST) + 1]) == 0
                except Exception:
                    raise Exception(vector)
            if reached_eos:
                break
        pieces.append(' '.join(_vector_to_cmd(vector, categorical=categorical)))

    return ''.join([SVG_PREFIX_BIG, PATH_PREFIX_1, ' '.join(pieces),
                    PATH_POSFIX_1, SVG_POSFIX])
|
624 |
+
|
625 |
+
|
626 |
+
# Slots of the 10-value argument vector each (upper-case) command emits,
# in output order. 'A' is handled separately because of its flags.
_CMD_OUTPUT_SLOTS = {
    'H': (8,),                  # x
    'V': (9,),                  # y
    'M': (8, 9),                # x, y
    'L': (8, 9),
    'T': (8, 9),
    'S': (6, 7, 8, 9),          # x2, y2, x, y
    'Q': (6, 7, 8, 9),
    'C': (4, 5, 6, 7, 8, 9),    # x1, y1, x2, y2, x, y
}


def _vector_to_cmd(vector, categorical=False, return_floats=False):
    """Does the inverse transformation as _cmd_to_vector().

    Args:
      vector: command encoding followed by the 10 argument values.
      categorical: whether the command part is one-hot (with an EOS slot).
      return_floats: if True the arguments are returned as floats, otherwise
        as strings.

    Returns:
      [CMD, arg, ...] with an upper-cased command letter, or [] for EOS.
    """
    cast_fn = float if return_floats else str
    if categorical:
        split_at = len(CMDS_LIST) + 1
        command, arguments = vector[:split_at], vector[split_at:]
        cmd_idx = np.argmax(command) - 1
    else:
        command, arguments = vector[:1], vector[1:]
        cmd_idx = int(round(command[0]))

    if cmd_idx < -0.5:
        # EOS
        return []
    # clamp indices past the end onto the last command
    cmd_idx = min(cmd_idx, len(CMDS_LIST) - 1)

    cmd = CMDS_LIST[cmd_idx].upper()

    if cmd == 'A':
        return [
            cmd,
            cast_fn(arguments[0]),  # rx
            cast_fn(arguments[1]),  # ry
            cast_fn('0'),  # x-axis-rotation is always 0
            # the following two flags are binary.
            cast_fn(1 if arguments[2] > 0.5 else 0),  # large-arc-flag
            cast_fn(1 if arguments[3] > 0.5 else 0),  # sweep-flag
            cast_fn(arguments[8]),  # x
            cast_fn(arguments[9]),  # y
        ]

    slots = _CMD_OUTPUT_SLOTS.get(cmd, ())
    return [cmd] + [cast_fn(arguments[slot]) for slot in slots]
|
678 |
+
|
679 |
+
|
680 |
+
############## UTILS FOR CONVERTING SVGS/VECTORS TO IMAGES ###################
|
681 |
+
|
682 |
+
# From Infer notebook
|
683 |
+
# Text that precedes / follows the path data in the generated svg strings.
start = (
    '<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.'
    'w3.org/1999/xlink" width="256px" height="256px" style="-ms-trans'
    'form: rotate(360deg); -webkit-transform: rotate(360deg); transfo'
    'rm: rotate(360deg);" preserveAspectRatio="xMidYMid meet" viewBox'
    '="0 0 24 30"><path d="'
)
end = '" fill="currentColor"/></svg>'

# Raw strings keep the backslash in `\.` as a regex escape instead of an
# invalid string escape sequence.
COMMAND_RX = re.compile(r"([MmLlHhVvCcSsQqTtAaZz])")
FLOAT_RX = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
|
692 |
+
|
693 |
+
|
694 |
+
def svg_html_to_path_string(svg):
    """Strips the module-level `start` and `end` wrappers from `svg`."""
    without_prefix = svg.replace(start, '')
    return without_prefix.replace(end, '')
|
696 |
+
|
697 |
+
|
698 |
+
def _tokenize(pathdef):
    """Yields each svg token (command letter or number) of a path string.

    E.g. 'm0.1-.5c0,6' -> 'm', '0.1', '-.5', 'c', '0', '6'
    """
    for chunk in COMMAND_RX.split(pathdef):
        if chunk != '' and chunk in 'MmLlHhVvCcSsQqTtAaZz':
            yield chunk
        # a command letter contains no numbers, so this only yields tokens
        # for the argument chunks
        yield from FLOAT_RX.findall(chunk)
|
706 |
+
|
707 |
+
|
708 |
+
def path_string_to_tokenized_commands(path):
    """Tokenizes the given path string.

    E.g.:
        Given M 0.5 0.5 l 0.25 0.25 z
        Returns [['M', '0.5', '0.5'], ['l', '0.25', '0.25'], ['z']]
    """
    commands = []
    pending = []
    for token in _tokenize(path):
        if pending and token in 'MmLlHhVvCcSsQqTtAaZz':
            # a new command letter closes the one being collected
            commands.append(pending)
            pending = [token]
        else:
            # first token overall, or one more argument
            pending.append(token)

    if pending:
        # flush the command still being collected
        commands.append(pending)

    return commands
|
735 |
+
|
736 |
+
|
737 |
+
def separate_substructures(tokenized_commands):
    """Returns a list of SVG substructures.

    Every moveTo command starts a new substructure. An SVG substructure is a
    subpath that closes on itself, such as the outer and the inner edge of
    the character `o`.
    """
    substructures = []
    current = []
    for cmd in tokenized_commands:
        is_move = cmd[0] in 'mM'
        if is_move and current:
            substructures.append(current)
            current = []
        current.append(cmd)
    if current:
        substructures.append(current)
    return substructures
|
752 |
+
|
753 |
+
|
754 |
+
def postprocess(svg, dist_thresh=2., skip=False):
    """Apply aesthetic clean-ups to the path of a generated SVG string.

    Pass 1 (always): for each substructure, the pen position is tracked with
    `_update_curr_pos`; when the final position lies within `dist_thresh` of
    the substructure's first end point, the offset (start - end) is added to
    the end point of the last command (and to its second control point when
    that command is a curve). The same offset is subtracted from the first
    command of the substructures that follow.

    Pass 2 (unless `skip`): for consecutive 'c' commands whose adjoining
    control vectors are nearly opposite (cosine strictly between -1 and
    -0.95), both vectors are rotated by half of the remaining angle so that
    they become collinear.

    Args:
        svg: full SVG markup; the path data is extracted with
            `svg_html_to_path_string`.
        dist_thresh: maximum start/end distance for pass 1 to snap.
        skip: when true, return right after pass 1.

    Returns:
        `svg` with its path data replaced by the processed path. An input
        without path data is returned unchanged.
    """
    path = svg_html_to_path_string(svg)
    tokenized_commands = path_string_to_tokenized_commands(path)

    def rebuild_svg(structs):
        # Substitute the new path data directly instead of going through
        # str.format: that broke on an empty path (one '{}' inserted between
        # every character) and on markup containing literal braces.
        if not path:
            return svg
        new_path = ' '.join(' '.join(' '.join(cmd) for cmd in s)
                            for s in structs)
        return svg.replace(path, new_path)

    def dist(a, b):
        return np.sqrt((float(a[0]) - float(b[0]))**2 +
                       (float(a[1]) - float(b[1]))**2)

    def are_close_together(a, b, t):
        return dist(a, b) < t

    # first, go through each start/end point and merge if they're close enough
    # together (that is, make end point the same as the start point).
    # TODO: there are better ways of doing this, in a way that propagates error
    # back (so if total error is 0.2, go through all N commands in this
    # substructure and fix each by 0.2/N (unless they have 0 vertical change))
    substructures = separate_substructures(tokenized_commands)

    # NOTE(review): this offset is never reset, so once one substructure has
    # been snapped the same correction is applied to every later
    # substructure, not only to the next one -- confirm this is intended.
    previous_substructure_endpoint = (0., 0.,)
    for substructure in substructures:
        # first, if the last substructure's endpoint was updated, we must
        # update the start point of this one to reflect the opposite update
        substructure[0][-2] = str(float(substructure[0][-2]) -
                                  previous_substructure_endpoint[0])
        substructure[0][-1] = str(float(substructure[0][-1]) -
                                  previous_substructure_endpoint[1])

        start = list(map(float, substructure[0][-2:]))
        curr_pos = (0., 0.)
        for cmd in substructure:
            curr_pos, _ = _update_curr_pos(curr_pos, cmd, (0., 0.))
        if are_close_together(start, curr_pos, dist_thresh):
            new_point = np.array(start)
            delta_x = new_point[0] - curr_pos[0]
            delta_y = new_point[1] - curr_pos[1]
            previous_substructure_endpoint = (delta_x, delta_y)
            substructure[-1][-2] = str(float(substructure[-1][-2]) + delta_x)
            substructure[-1][-1] = str(float(substructure[-1][-1]) + delta_y)
            if substructure[-1][0] in 'cC':
                # move the curve's second control point along with its end
                substructure[-1][-4] = str(float(substructure[-1][-4]) + delta_x)
                substructure[-1][-3] = str(float(substructure[-1][-3]) + delta_y)

    if skip:
        return rebuild_svg(substructures)

    def cosa(x, y):
        # cosine of the angle between x and y; NaN for a zero-length vector,
        # which makes the range test below fail without a runtime warning
        denom = np.sqrt(x[0]**2 + x[1]**2) * np.sqrt(y[0]**2 + y[1]**2)
        if denom == 0:
            return float('nan')
        return (x[0] * y[0] + x[1] * y[1]) / denom

    def rotate(a, x, y):
        return (x * np.cos(a) - y * np.sin(a), y * np.cos(a) + x * np.sin(a))

    # second, gotta find adjacent bezier curves and, if their control points
    # are well enough aligned, fully align them
    for substructure in substructures:
        curr_pos = (0., 0.)
        new_curr_pos, _ = _update_curr_pos((0., 0.,), substructure[0], (0., 0.))

        for cmd_idx in range(1, len(substructure)):
            prev_cmd = substructure[cmd_idx - 1]
            cmd = substructure[cmd_idx]

            new_new_curr_pos, _ = _update_curr_pos(new_curr_pos, cmd, (0., 0.))

            if cmd[0] == 'c' and prev_cmd[0] == 'c':
                # check the vectors and update if needed
                # previous control pt wrt new curr point
                prev_ctr_point = (
                    curr_pos[0] + float(prev_cmd[3]) - new_curr_pos[0],
                    curr_pos[1] + float(prev_cmd[4]) - new_curr_pos[1])
                ctr_point = (float(cmd[1]), float(cmd[2]))

                cos_angle = cosa(prev_ctr_point, ctr_point)
                if -1. < cos_angle < -0.95:
                    # half of what is missing for the vectors to be opposite
                    angle_diff = (np.pi - np.arccos(cos_angle)) / 2

                    # rotate each vector by angle/2 in the correct direction
                    # for each. The scalar 2-D cross product is written out
                    # because np.cross on 2-element vectors is deprecated.
                    sign = np.sign(prev_ctr_point[0] * ctr_point[1] -
                                   prev_ctr_point[1] * ctr_point[0])
                    new_ctr_point = rotate(sign * angle_diff, *ctr_point)
                    new_prev_ctr_point = rotate(-sign * angle_diff,
                                                *prev_ctr_point)

                    # override the previous control points
                    # (which has to be wrt previous curr position)
                    substructure[cmd_idx - 1][3] = str(
                        new_prev_ctr_point[0] - curr_pos[0] + new_curr_pos[0])
                    substructure[cmd_idx - 1][4] = str(
                        new_prev_ctr_point[1] - curr_pos[1] + new_curr_pos[1])
                    substructure[cmd_idx][1] = str(new_ctr_point[0])
                    substructure[cmd_idx][2] = str(new_ctr_point[1])

            curr_pos = new_curr_pos
            new_curr_pos = new_new_curr_pos

    return rebuild_svg(substructures)
|
852 |
+
|
853 |
+
|
854 |
+
# def get_means_stdevs(data_dir):
|
855 |
+
# """Returns the means and stdev saved in data_dir."""
|
856 |
+
# if data_dir not in means_stdevs:
|
857 |
+
# with tf.gfile.Open(os.path.join(data_dir, 'mean.npz'), 'r') as f:
|
858 |
+
# mean_npz = np.load(f)
|
859 |
+
# with tf.gfile.Open(os.path.join(data_dir, 'stdev.npz'), 'r') as f:
|
860 |
+
# stdev_npz = np.load(f)
|
861 |
+
# means_stdevs[data_dir] = (mean_npz, stdev_npz)
|
862 |
+
# return means_stdevs[data_dir]
|
863 |
+
|
864 |
+
|
865 |
+
def render(tensor, data_dir=None):
    """Converts SVG decoder output into HTML svg.

    `data_dir` is accepted but unused: the step that undid normalization
    with the saved mean/stdev is disabled.
    """
    # convert to html
    long_cmds = _make_simple_cmds_long(tensor)
    markup = _vector_to_svg(long_cmds, stop_at_eos=True, categorical=True)

    # some aesthetic postprocessing, then shrink the declared size
    markup = postprocess(markup)
    return markup.replace('256px', '50px')
|
881 |
+
|
882 |
+
###############
|
883 |
+
|
884 |
+
|
885 |
+
def convert_to_svg(decoder_output, categorical=False):
    """Run `_vector_to_svg` on each example and return the results as an array."""
    return np.array([
        _vector_to_svg(example, True, categorical=categorical)
        for example in decoder_output
    ])
|
890 |
+
|
891 |
+
|
892 |
+
def create_image_conversion_fn(max_outputs, categorical=False):
    """Binds the number of outputs to the image conversion fn (to svg or png).

    Returns:
        A function that converts the examples of `decoder_output` with
        `_vector_to_svg`, stopping once `max_outputs` have been converted.
    """
    def convert_to_svg(decoder_output):
        converted = []
        for done, example in enumerate(decoder_output):
            # `done` is the number of examples converted so far
            if done == max_outputs:
                break
            converted.append(
                _vector_to_svg(example, True, categorical=categorical))
        return np.array(converted)

    return convert_to_svg
|
903 |
+
|
904 |
+
|
905 |
+
################### UTILS FOR CREATING TF SUMMARIES ##########################
|
906 |
+
def _make_encoded_image(img_tensor):
    """Return `img_tensor` encoded as PNG bytes.

    The input is multiplied by 255, squeezed and cast to uint8 before being
    handed to PIL as a mode 'L' image.
    """
    pixels = np.squeeze(img_tensor * 255).astype(np.uint8)
    buff = io.BytesIO()
    Image.fromarray(pixels, mode='L').save(buff, format='png')
    return buff.getvalue()
|
912 |
+
|
913 |
+
|
914 |
+
################### CHECK GLYPH/PATH VALID ##############################################
|
915 |
+
def is_valid_glyph(g):
    """Tell whether glyph `g` is an ASCII letter or digit with non-zero size.

    `g['uni']` must fall in 48-57, 65-90 or 97-122, and both `g['width']`
    and `g['vwidth']` must be non-zero.
    """
    uni = g['uni']
    accepted_ranges = ((48, 57), (65, 90), (97, 122))
    is_alphanumeric = any(low <= uni <= high for low, high in accepted_ranges)
    has_valid_dims = g['width'] != 0 and g['vwidth'] != 0
    return is_alphanumeric and has_valid_dims
|
921 |
+
|
922 |
+
|
923 |
+
def is_valid_path(pathunibfp):
    """Check that the path (first item of the tuple) is usable.

    Returns the path itself when it is falsy (e.g. an empty list); otherwise
    whether its length does not exceed MAX_SEQ_LEN.
    """
    path = pathunibfp[0]
    if not path:
        return path
    return len(path) <= MAX_SEQ_LEN
|
925 |
+
|
926 |
+
|
927 |
+
################### DATASET PROCESSING #######################################
|
928 |
+
def convert_to_path(g):
    """Converts SplineSet in SFD font to str path.

    Returns:
        A tuple (path, g['uni'], g['binary_fp']) where path has been
        normalized against a '0 0 <width> <vwidth>' viewbox.
    """
    raw_path = _sfd_to_path_list(g)
    full_path = _add_missing_cmds(raw_path, remove_zs=False)
    viewbox = '0 0 {} {}'.format(g['width'], g['vwidth'])
    normalized = _normalize_based_on_viewbox(full_path, viewbox)
    return normalized, g['uni'], g['binary_fp']
|
934 |
+
|
935 |
+
|
936 |
+
def create_example(pathunibfp):
    """Bulk of dataset processing. Converts str path to np array.

    Args:
        pathunibfp: tuple (path, uni, binary_fp).

    Returns:
        A dict with flat python lists under 'rendered' (64 * 64 floats),
        'sequence' ((MAX_SEQ_LEN + 1) * 10 floats), 'class' and 'seq_len'
        (one int each), plus 'binary_fp' as a string.
    """
    path, uni, binary_fp = pathunibfp

    # zoom out, then make clockwise
    path = _canonicalize(_zoom_out(path))

    # render path for training
    rendered = _per_step_render(path, absolute=True)

    # convert to vector (the path is kept absolute), then keep the simple
    # columns: commands 0, 4, 5, 9 followed by the last six values
    full_vector = np.array(_path_to_vector(path, categorical=True))
    simple_vector = np.concatenate(
        [np.take(full_vector, [0, 4, 5, 9], axis=-1), full_vector[..., -6:]],
        axis=-1)

    # count some stats
    seq_len = np.shape(simple_vector)[0]
    # be advised that the class is useless bcz it is all 0
    glyph_class = int(_map_uni_to_alpha(uni))
    fingerprint = str(binary_fp)

    # append eos
    with_eos = _append_eos(simple_vector.tolist(), True, 10)

    # pad path to MAX_SEQ_LEN + 1 (with eos)
    padding = np.zeros(((MAX_SEQ_LEN - seq_len), 10))
    sequence = np.concatenate((with_eos, padding), 0)

    # make pure lists; only channel 0 of the rendering is used
    rendered_flat = np.reshape(
        rendered[..., 0], [64 * 64]).astype(np.float32).tolist()
    sequence_flat = np.reshape(
        sequence, [(MAX_SEQ_LEN + 1) * 10]).astype(np.float32).tolist()
    class_list = np.reshape(glyph_class, [1]).astype(np.int64).tolist()
    seq_len_list = np.reshape(seq_len, [1]).astype(np.int64).tolist()

    return {
        'rendered': rendered_flat,
        'seq_len': seq_len_list,
        'class': class_list,
        'binary_fp': fingerprint,
        'sequence': sequence_flat,
    }
|
976 |
+
|
977 |
+
|
978 |
+
def mean_to_example(mean_stdev):
    """Converts the found mean and stdev to example.

    The dict is updated in place and returned: 'mean', 'variance' and
    'stddev' become lists of 10 float32 values, 'count' a list of one int64.
    """
    layout = (
        ('mean', 10, np.float32),
        ('variance', 10, np.float32),
        ('stddev', 10, np.float32),
        ('count', 1, np.int64),
    )
    for key, size, dtype in layout:
        mean_stdev[key] = np.reshape(mean_stdev[key], [size]).astype(dtype).tolist()
    return mean_stdev
|
986 |
+
|
987 |
+
|
988 |
+
def convert_simple_vector_to_path(seq):
    """Turn a simple-vector sequence into a list of absolute path commands.

    For each row, the argmax of the first four entries selects the command:
    0 stops the conversion, 1 is 'M', 2 is 'L', 3 is 'C'. Entries 4:6, 6:8
    and 8:10 hold the two control points and the end point; 'M' and 'L' only
    use the end point. All coordinates are emitted as strings.
    """
    def as_strings(point):
        return [str(point[0]), str(point[1])]

    path = []
    for row_idx in range(seq.shape[0]):
        row = seq[row_idx]
        cmd = np.argmax(row[:4])
        ctrl_a, ctrl_b, end_pt = row[4:6], row[6:8], row[8:10]

        if cmd == 0:
            # end of sequence
            break
        if cmd == 1:
            tokens = ['M'] + as_strings(end_pt)
        elif cmd == 2:
            tokens = ['L'] + as_strings(end_pt)
        elif cmd == 3:
            tokens = (['C'] + as_strings(ctrl_a) + as_strings(ctrl_b)
                      + as_strings(end_pt))
        else:
            print("wrong!!! to path")
            tokens = []
        path.append(tokens)
    return path
|
1018 |
+
|
1019 |
+
def clockwise(seq):
    """Canonicalize a simple-vector sequence and re-encode it.

    Returns:
        A dict with 'seq_len' (number of commands before the eos row) and
        'sequence' (the eos-terminated vector, zero-padded to
        MAX_SEQ_LEN + 1 rows of 10 values).
    """
    canonical_path = _canonicalize(convert_simple_vector_to_path(seq))
    full_vector = np.array(_path_to_vector(canonical_path, categorical=True))
    simple_vector = np.concatenate(
        [np.take(full_vector, [0, 4, 5, 9], axis=-1), full_vector[..., -6:]],
        axis=-1)
    seq_len = np.shape(simple_vector)[0]
    with_eos = _append_eos(simple_vector.tolist(), True, 10)
    padding = np.zeros(((MAX_SEQ_LEN - seq_len), 10))
    return {
        'seq_len': seq_len,
        'sequence': np.concatenate((with_eos, padding), 0),
    }
|
1030 |
+
|
1031 |
+
################### CHECK VALID ##############################################
|
1032 |
+
class MeanStddev:
    """Accumulator to compute the mean/stdev of svg commands.

    An accumulator is a tuple (sum, sum of squares, count) where the two sums
    are arrays of 10 values (one per feature) and count is the number of
    sequence rows seen so far.
    """

    def create_accumulator(self):
        """Return an empty accumulator."""
        curr_sum = np.zeros([10])
        sum_sq = np.zeros([10])
        return (curr_sum, sum_sq, 0)  # x, x^2, count

    def add_input(self, sum_count, new_input):
        """Fold one example into the accumulator.

        Args:
            sum_count: accumulator tuple (sum, sum of squares, count).
            new_input: dict with keys 'seq_len' (list holding one int) and
                'sequence' (flat list whose length is a multiple of 10).

        Returns:
            The updated accumulator tuple.

        Raises:
            AssertionError: if the input does not have the expected types or
                holds fewer rows than 'seq_len' announces.
        """
        (curr_sum, sum_sq, count) = sum_count
        new_seq_len = new_input['seq_len'][0]  # 'seq_len' is a list of one int
        # The messages used to be print(...) calls, one of them print(type()),
        # which raises TypeError; they are plain strings now.
        assert isinstance(new_seq_len, int), (
            "'seq_len' must hold an int, got {}".format(type(new_seq_len)))
        assert isinstance(new_input['sequence'], list), (
            "'sequence' must be a list, got {}".format(
                type(new_input['sequence'])))

        # remove padding and eos from sequence
        new_input_np = np.reshape(np.array(new_input['sequence']), [-1, 10])
        assert new_input_np.shape[0] >= new_seq_len, (
            "'sequence' has {} rows but 'seq_len' is {}".format(
                new_input_np.shape[0], new_seq_len))
        new_input_np = new_input_np[:new_seq_len, :]

        # accumulate new_sum and new_sum_sq
        new_sum = np.add(curr_sum, np.sum(new_input_np, axis=0))
        new_sum_sq = np.add(sum_sq, np.sum(np.power(new_input_np, 2), axis=0))
        return new_sum, new_sum_sq, count + new_seq_len

    def merge_accumulators(self, accumulators):
        """Combine several accumulators into one.

        An empty collection yields an empty accumulator.
        """
        accumulators = list(accumulators)
        if not accumulators:
            return self.create_accumulator()
        curr_sums, sum_sqs, counts = list(zip(*accumulators))
        return np.sum(curr_sums, axis=0), np.sum(sum_sqs, axis=0), np.sum(counts)

    def extract_output(self, sum_count):
        """Compute mean, variance and stddev from an accumulator.

        Returns:
            A dict with keys 'mean', 'variance', 'stddev' and 'count'. With a
            count of zero the three statistics are NaN.
        """
        (curr_sum, curr_sum_sq, count) = sum_count
        if not count:
            return {
                'mean': float('NaN'),
                'variance': float('NaN'),
                'stddev': float('NaN'),
                'count': 0
            }
        mean = np.divide(curr_sum, count)
        variance = np.divide(curr_sum_sq, count) - np.power(mean, 2)
        # -ve value could happen due to rounding
        variance = np.maximum(variance, np.zeros(np.shape(variance)))
        stddev = np.sqrt(variance)
        return {
            'mean': mean,
            'variance': variance,
            'stddev': stddev,
            'count': count
        }
|
data_utils/svg_utils_backup.py
ADDED
@@ -0,0 +1,1174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Copyright 2020 The Magenta Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import pdb
|
16 |
+
# Lint as: python3
|
17 |
+
"""Defines the Material Design Icons Problem."""
|
18 |
+
import io
|
19 |
+
import numpy as np
|
20 |
+
import re
|
21 |
+
|
22 |
+
from PIL import Image
|
23 |
+
from itertools import zip_longest
|
24 |
+
from skimage import draw
|
25 |
+
import sys
|
26 |
+
|
27 |
+
# Fixed markup wrapped around the path data of a rendered SVG.
SVG_PREFIX_BIG = ('<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="'
                  'http://www.w3.org/1999/xlink" width="256px" height="256px"'
                  ' style="-ms-transform: rotate(360deg); -webkit-transform:'
                  ' rotate(360deg); transform: rotate(360deg);" '
                  'preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24">')
PATH_PREFIX_1 = '<path d="'
PATH_POSFIX_1 = '" fill="currentColor"/>'
SVG_POSFIX = '</svg>'

# Number of arguments taken by each SVG path command:
#   v, h: vertical / horizontal line
#   a: elliptical arc
#   l: lineto
#   t: smooth quadratic Bezier curveto
#   c: curveto
#   m: moveto
#   s: smooth curveto
#   q: quadratic Bezier curve
#   z: closepath
# 'Z' is listed next to 'z' so that looking up an upper-case closepath does
# not raise KeyError.
NUM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
            't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
            'q': 4, 'Q': 4, 'z': 0, 'Z': 0}

# in order of arg complexity, with absolutes clustered
# recall we don't handle all commands (see docstring)
# (an earlier ordering was 'zhvmltsqcaHVMLTSQCA')
CMDS_LIST = 'zHVMLTSQCAhvmltsqca'
CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}

FEATURE_DIM = 10
|
57 |
+
|
58 |
+
|
59 |
+
############################### GENERAL UTILS #################################
|
60 |
+
def grouper(iterable, batch_size, fill_value=None):
    """Yield consecutive tuples of `batch_size` items from `iterable`.

    The last tuple is padded with `fill_value` when the items run out.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]
    """
    cursor = iter(iterable)
    # The same iterator repeated batch_size times: each output tuple pulls
    # batch_size successive items from it.
    lanes = (cursor,) * batch_size
    return zip_longest(*lanes, fillvalue=fill_value)
|
65 |
+
|
66 |
+
|
67 |
+
def _map_uni_to_alphanum(uni):
|
68 |
+
"""Maps [0-9 A-Z a-z] to numbers 0-62."""
|
69 |
+
if 48 <= uni <= 57:
|
70 |
+
return uni - 48
|
71 |
+
elif 65 <= uni <= 90:
|
72 |
+
return uni - 65 + 10
|
73 |
+
return uni - 97 + 36
|
74 |
+
|
75 |
+
|
76 |
+
def _map_uni_to_alpha(uni):
|
77 |
+
"""Maps [A-Z a-z] to numbers 0-52."""
|
78 |
+
if 65 <= uni <= 90:
|
79 |
+
return uni - 65
|
80 |
+
return uni - 97 + 26
|
81 |
+
|
82 |
+
|
83 |
+
############# UTILS FOR CONVERTING SFD/SPLINESETS TO SVG PATHS ################
|
84 |
+
def _get_spline(sfd):
|
85 |
+
if 'SplineSet' not in sfd:
|
86 |
+
return ''
|
87 |
+
pro = sfd[sfd.index('SplineSet') + 10:] # 10 is the 'SplineSet'
|
88 |
+
pro = pro[:pro.index('EndSplineSet')]
|
89 |
+
return pro
|
90 |
+
|
91 |
+
|
92 |
+
def _spline_to_path_list(spline, height, replace_with_prev=False):
|
93 |
+
"""Converts SplineSet to a list of tokenized commands in svg path."""
|
94 |
+
path = []
|
95 |
+
prev_xy = []
|
96 |
+
for line in spline.splitlines():
|
97 |
+
if not line:
|
98 |
+
continue
|
99 |
+
tokens = line.split(' ')
|
100 |
+
cmd = tokens[-2]
|
101 |
+
if cmd not in 'cml':
|
102 |
+
# COMMAND NOT RECOGNIZED.
|
103 |
+
return []
|
104 |
+
# assert cmd in 'cml', 'Command not recognized: {}'.format(cmd)
|
105 |
+
args = tokens[:-2]
|
106 |
+
args = [float(x) for x in args if x]
|
107 |
+
|
108 |
+
if replace_with_prev and cmd in 'c':
|
109 |
+
args[:2] = prev_xy
|
110 |
+
prev_xy = args[-2:]
|
111 |
+
|
112 |
+
new_y_args = []
|
113 |
+
for i, a in enumerate(args):
|
114 |
+
if i % 2 == 1:
|
115 |
+
new_y_args.append((height - a))
|
116 |
+
else:
|
117 |
+
new_y_args.append((a))
|
118 |
+
|
119 |
+
path.append([cmd.upper()] + new_y_args)
|
120 |
+
return path
|
121 |
+
|
122 |
+
|
123 |
+
def _sfd_to_path_list(single, replace_with_prev=False):
    """Converts the given SFD glyph into a path.

    Reads the SplineSet out of `single['sfd']` and converts it using
    `single['vwidth']` as the height.
    """
    spline = _get_spline(single['sfd'])
    return _spline_to_path_list(spline, single['vwidth'], replace_with_prev)
|
126 |
+
|
127 |
+
|
128 |
+
#################### UTILS FOR PROCESSING TOKENIZED PATHS #####################
|
129 |
+
def _add_missing_cmds(path, remove_zs=False):
    """Adds missing cmd tags to the commands in the svg.

    For instance, the command 'a' takes 7 arguments, but some SVGs declare:
        a 1 2 3 4 5 6 7 8 9 10 11 12 13 14
    which is 14 arguments. This function converts the above to the equivalent:
        a 1 2 3 4 5 6 7 a 8 9 10 11 12 13 14

    Note: if remove_zs is True, this also removes any occurrences of z
    commands.
    """
    expanded = []
    for cmd in path:
        if remove_zs and cmd[0] in 'Zz':
            continue
        expanded.extend(add_missing_cmd(cmd))
    return expanded
|
143 |
+
|
144 |
+
|
145 |
+
def add_missing_cmd(command_list):
    """Adds missing cmd tags to the given command list.

    E.g.: given
        ['a', '0', '0', '0', '0', '0', '0', '0',
         '0', '0', '0', '0', '0', '0', '0']
    returns
        [['a', '0', '0', '0', '0', '0', '0', '0'],
         ['a', '0', '0', '0', '0', '0', '0', '0']]

    The arguments are cut into batches of NUM_ARGS[tag] items. A command
    that produces no batch (e.g. 'z') is returned as [[tag]].
    """
    cmd_tag, args = command_list[0], command_list[1:]
    batches = [[cmd_tag] + list(batch)
               for batch in grouper(args, NUM_ARGS[cmd_tag])]
    # command has no args (e.g.: 'z')
    return batches or [[cmd_tag]]
|
166 |
+
|
167 |
+
|
168 |
+
def _normalize_args(arglist, norm, add=None, flip=False):
|
169 |
+
"""Normalize the given args with the given norm value."""
|
170 |
+
new_arglist = []
|
171 |
+
for i, arg in enumerate(arglist):
|
172 |
+
new_arg = float(arg)
|
173 |
+
|
174 |
+
if add is not None:
|
175 |
+
add_to_x, add_to_y = add
|
176 |
+
|
177 |
+
# This argument is an x-coordinate if even, y-coordinate if odd
|
178 |
+
# except when flip == True
|
179 |
+
if i % 2 == 0:
|
180 |
+
new_arg += add_to_y if flip else add_to_x
|
181 |
+
else:
|
182 |
+
new_arg += add_to_x if flip else add_to_y
|
183 |
+
|
184 |
+
new_arglist.append(str(24 * new_arg / norm))
|
185 |
+
return new_arglist
|
186 |
+
|
187 |
+
|
188 |
+
def _normalize_based_on_viewbox(path, viewbox):
    """Normalizes all args in a path to a standard 24x24 viewbox.

    Each SVG lives in a 2D plane and the viewbox determines the region of
    that plane that gets rendered. The same drawing made in a 100x100 and in
    a 20x20 viewbox looks identical once rendered, but the numbers in its
    commands differ ("lineTo 100 100" versus "lineTo 20 20"). To bring all
    values to one scale, every command is rescaled to a 24x24 viewbox.

    Args:
        path: list of commands, each [tag, arg, ...].
        viewbox: space-separated string whose last two fields are integers
            giving the two viewbox dimensions.

    Returns:
        A new list of commands with string arguments. The larger dimension
        is mapped onto 24, and absolute commands are shifted by half the
        difference between the dimensions along the shorter one.
    """
    fields = viewbox.split(' ')
    last = int(fields[-1])
    second_last = int(fields[-2])
    norm = max(last, second_last)

    half_gap = abs(last - second_last) / 2
    if last > second_last:
        offset = (half_gap, 0)   # (add_to_x, add_to_y)
    else:
        offset = (0, half_gap)

    normalized = []
    for command in path:
        tag = command[0]
        if tag == 'a':
            # arguments 3-5 of an arc are passed through unscaled
            args = (_normalize_args(command[1:3], norm) + command[3:6]
                    + _normalize_args(command[6:], norm))
        elif tag == 'A':
            args = (_normalize_args(command[1:3], norm) + command[3:6]
                    + _normalize_args(command[6:], norm, add=offset))
        elif tag == 'V':
            # the single argument is a y value, hence flip
            args = _normalize_args(command[1:], norm, add=offset, flip=True)
        elif tag == tag.upper():
            args = _normalize_args(command[1:], norm, add=offset)
        elif tag in 'zZ':
            args = []
        else:
            args = _normalize_args(command[1:], norm)
        normalized.append([tag] + args)

    return normalized
|
234 |
+
|
235 |
+
|
236 |
+
def _convert_args(args, curr_pos, cmd):
|
237 |
+
"""Converts given args to relative values."""
|
238 |
+
# NOTE: glyphs only use a very small subset of commands (L, C, M, and Z -- I
|
239 |
+
# believe). So I'm not handling A and H for now.
|
240 |
+
if cmd in 'AH':
|
241 |
+
raise NotImplementedError('These commands have >6 args (not supported).')
|
242 |
+
|
243 |
+
new_args = []
|
244 |
+
for i, arg in enumerate(args):
|
245 |
+
x_or_y = i % 2
|
246 |
+
if cmd == 'H':
|
247 |
+
x_or_y = (i + 1) % 2
|
248 |
+
new_args.append(str(float(arg) - curr_pos[x_or_y]))
|
249 |
+
|
250 |
+
return new_args
|
251 |
+
|
252 |
+
|
253 |
+
def _update_curr_pos(curr_pos, cmd, start_of_path):
|
254 |
+
"""Calculate the position of the pen after cmd is applied."""
|
255 |
+
if cmd[0] in 'ml':
|
256 |
+
curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1] + float(cmd[2])]
|
257 |
+
if cmd[0] == 'm':
|
258 |
+
start_of_path = curr_pos
|
259 |
+
elif cmd[0] in 'z':
|
260 |
+
curr_pos = start_of_path
|
261 |
+
elif cmd[0] in 'h':
|
262 |
+
curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1]]
|
263 |
+
elif cmd[0] in 'v':
|
264 |
+
curr_pos = [curr_pos[0], curr_pos[1] + float(cmd[1])]
|
265 |
+
elif cmd[0] in 'ctsqa':
|
266 |
+
curr_pos = [curr_pos[0] + float(cmd[-2]), curr_pos[1] + float(cmd[-1])]
|
267 |
+
|
268 |
+
return curr_pos, start_of_path
|
269 |
+
|
270 |
+
|
271 |
+
def _make_relative(cmds):
    """Convert commands in a path to relative positioning.

    Commands that are already lower-case are kept as they are, 'Z' becomes
    ['z'], and any other command is lower-cased with its arguments made
    relative to the pen position reached so far.
    """
    pen = (0.0, 0.0)
    subpath_start = (0.0, 0.0)
    relative = []
    for cmd in cmds:
        tag = cmd[0]
        lowered = tag.lower()
        if lowered == tag:
            rel_cmd = cmd
        elif lowered == 'z':
            rel_cmd = [lowered]
        else:
            rel_cmd = [lowered] + _convert_args(cmd[1:], pen, cmd=tag)
        relative.append(rel_cmd)
        pen, subpath_start = _update_curr_pos(pen, rel_cmd, subpath_start)
    return relative
|
286 |
+
|
287 |
+
|
288 |
+
def _is_to_left_of(pt1, pt2):
|
289 |
+
pt1_norm = (pt1[0]**2 + pt1[1]**2)
|
290 |
+
pt2_norm = (pt2[0]**2 + pt2[1]**2)
|
291 |
+
return pt1[1] < pt2[1] or (pt1_norm == pt2_norm and pt1[0] < pt2[0])
|
292 |
+
|
293 |
+
|
294 |
+
def _get_leftmost_point(path):
    """Return the end point that `_is_to_left_of` ranks first, and its index.

    Only commands with more than one element are considered; their last two
    elements are taken as the end point.

    Returns:
        (point, index). When no command qualifies the result is
        ((inf, inf), -1).
    """
    best = (float('inf'), float('inf'))
    best_idx = -1

    for position, command in enumerate(path):
        if len(command) <= 1:
            continue
        candidate = command[-2:]
        if _is_to_left_of(candidate, best):
            best, best_idx = candidate, position

    return best, best_idx
|
307 |
+
|
308 |
+
|
309 |
+
def _separate_substructures(path):
|
310 |
+
"""Returns a list of subpaths, each representing substructures the glyph."""
|
311 |
+
substructures = []
|
312 |
+
curr = []
|
313 |
+
for cmd in path:
|
314 |
+
if cmd[0] in 'mM' and curr:
|
315 |
+
substructures.append(curr)
|
316 |
+
curr = []
|
317 |
+
curr.append(cmd)
|
318 |
+
if curr:
|
319 |
+
substructures.append(curr)
|
320 |
+
return substructures
|
321 |
+
|
322 |
+
|
323 |
+
def _is_clockwise(subpath):
|
324 |
+
"""Returns whether the given subpath is clockwise-oriented."""
|
325 |
+
pts = [cmd[-2:] for cmd in subpath]
|
326 |
+
det = 0
|
327 |
+
for i in range(len(pts) - 1):
|
328 |
+
det += np.linalg.det(pts[i:i + 2])
|
329 |
+
return det > 0
|
330 |
+
|
331 |
+
|
332 |
+
def _make_clockwise(subpath):
|
333 |
+
"""Inverts the cardinality of the given subpath."""
|
334 |
+
new_path = [subpath[0]]
|
335 |
+
other_cmds = list(reversed(subpath[1:]))
|
336 |
+
for i, cmd in enumerate(other_cmds):
|
337 |
+
if i + 1 == len(other_cmds):
|
338 |
+
where_we_were = subpath[0][-2:]
|
339 |
+
else:
|
340 |
+
where_we_were = other_cmds[i + 1][-2:]
|
341 |
+
|
342 |
+
if len(cmd) > 3:
|
343 |
+
new_cmd = [cmd[0], cmd[3], cmd[4], cmd[1], cmd[2],
|
344 |
+
where_we_were[0], where_we_were[1]]
|
345 |
+
else:
|
346 |
+
new_cmd = [cmd[0], where_we_were[0], where_we_were[1]]
|
347 |
+
|
348 |
+
new_path.append(new_cmd)
|
349 |
+
return new_path
|
350 |
+
|
351 |
+
|
352 |
+
def _canonicalize(path):
    """Give a path a canonical start point, subpath order and orientation.

    Steps:
      1. arguments are converted to floats;
      2. each subpath is rotated so that it starts (with an ``M``) at the
         point chosen by ``_get_leftmost_point``;
      3. subpaths are sorted by that point, by y first and then x;
      4. if the first sorted subpath is not clockwise according to
         ``_is_clockwise``, *every* subpath is reversed with
         ``_make_clockwise``;
      5. arguments are converted back to strings.

    Returns:
        The canonicalised path as a list of ``[letter, str_arg, ...]`` lists.
    """
    numeric = [[cmd[0]] + list(map(float, cmd[1:])) for cmd in path]

    # rotate every subpath so it begins at its chosen anchor point
    anchored = []
    for subpath in _separate_substructures(numeric):
        anchor, anchor_idx = _get_leftmost_point(subpath)
        rotated = ([['M', anchor[0], anchor[1]]]
                   + subpath[anchor_idx + 1:]
                   + subpath[1:anchor_idx + 1])
        anchored.append((rotated, anchor))

    ordered = sorted(anchored, key=lambda item: (item[1][1], item[1][0]))

    merged = []
    flip_all = False
    for position, (subpath, _) in enumerate(ordered):
        if position == 0:
            # the first subpath decides the orientation of the whole glyph
            flip_all = not _is_clockwise(subpath)
        merged.extend(_make_clockwise(subpath) if flip_all else subpath)

    return [[cmd[0]] + list(map(str, cmd[1:])) for cmd in merged]
|
389 |
+
|
390 |
+
|
391 |
+
# ######### UTILS FOR CONVERTING TOKENIZED PATHS TO VECTORS ###########
|
392 |
+
def _path_to_vector(path, categorical=False):
    """Convert every command of a path to a vector via ``_cmd_to_vector``."""
    # Notes (kept from the original implementation):
    # - The SimpleSVG dataset does not have any 't', 'q', 'Z', 'T', or 'Q',
    #   so those are not handled specially.
    # - All 'z's were removed as well.
    # - The x-axis-rotation argument of 'a' commands is always 0 in that
    #   dataset, so it is ignored.
    #
    # Many commands have args that correspond to args in other commands:
    #   v  __,__ _______________ ______________,_________ __,__ __,__ _,y
    #   h  __,__ _______________ ______________,_________ __,__ __,__ x,_
    #   z  __,__ _______________ ______________,_________ __,__ __,__ _,_
    #   a  rx,ry x-axis-rotation large-arc-flag,sweepflag __,__ __,__ x,y
    #   l  __,__ _______________ ______________,_________ __,__ __,__ x,y
    #   c  __,__ _______________ ______________,_________ x1,y1 x2,y2 x,y
    #   m  __,__ _______________ ______________,_________ __,__ __,__ x,y
    #   s  __,__ _______________ ______________,_________ __,__ x2,y2 x,y
    #
    # Each command therefore becomes a vector over the union of arguments:
    #   [rx, ry, large-arc-flag, sweepflag, x1, y1, x2, y2, x, y]
    # An argument a command does not have is set to 0, e.g.
    #   "l 5,5" becomes [0, 0, 0, 0, 0, 0, 0, 0, 5, 5]
    #
    # The argument vector is preceded by the command encoding (a single
    # number, or a one-hot block when ``categorical`` is set).
    return [_cmd_to_vector(cmd, categorical=categorical) for cmd in path]
|
423 |
+
|
424 |
+
|
425 |
+
def _cmd_to_vector(cmd_list, categorical=False):
    """Convert one command (given as a list) into a flat vector.

    The result is the command encoding followed by ten argument slots
    ``[rx, ry, large-arc-flag, sweep-flag, x1, y1, x2, y2, x, y]``; slots the
    command does not use stay 0.0 (see the notes in ``_path_to_vector``).

    The command encoding relies on the module-level ``CMDS_LIST`` and
    ``CMD_MAPPING``.  The previous docstring recorded them as
    ``CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'`` and
    ``CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}``; the
    module-level definitions are authoritative.

    Args:
        cmd_list: ``[letter, arg, ...]``.
        categorical: if False the command is encoded as one float (its index
            in ``CMD_MAPPING``); if True as a one-hot block of
            ``len(CMDS_LIST) + 1`` entries whose slot 0 is reserved for EOS.

    Returns:
        A list of floats: command encoding + 10 arguments.
    """
    cmd = cmd_list[0]
    args = cmd_list[1:]

    if categorical:
        # one hot + 1 dim for EOS (there are roughly 19 commands)
        command = [0.0] * (len(CMDS_LIST) + 1)
        command[CMD_MAPPING[cmd] + 1] = 1.0
    else:
        # single number, for MSE
        command = [float(CMD_MAPPING[cmd])]

    # (command letters, ((argument slot, index into args), ...))
    layouts = (
        ('hH', ((8, 0),)),                                   # x
        ('vV', ((9, 0),)),                                   # y
        ('mMlLtT', ((8, 0), (9, 1))),                        # x, y
        ('sSqQ', ((6, 0), (7, 1), (8, 2), (9, 3))),          # x2, y2, x, y
        ('cC', ((4, 0), (5, 1), (6, 2), (7, 3), (8, 4), (9, 5))),
        # rx, ry, large-arc-flag, sweep-flag, x, y; x-axis-rotation
        # (args[2]) is skipped and 'a' has no x1, y1, x2, y2
        ('aA', ((0, 0), (1, 1), (2, 3), (3, 4), (8, 5), (9, 6))),
    )

    arguments = [0.0] * 10
    for letters, layout in layouts:
        if cmd in letters:
            for slot, arg_idx in layout:
                arguments[slot] = float(args[arg_idx])
            break

    return command + arguments
|
478 |
+
|
479 |
+
|
480 |
+
################## UTILS FOR RENDERING PATH INTO IMAGE #################
|
481 |
+
def _cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=40):
|
482 |
+
"""Return n points along cubiz bezier with given control points."""
|
483 |
+
# from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
|
484 |
+
pts = []
|
485 |
+
for i in range(n + 1):
|
486 |
+
t = float(i) / float(n)
|
487 |
+
a = (1. - t)**3
|
488 |
+
b = 3. * t * (1. - t)**2
|
489 |
+
c = 3.0 * t**2 * (1.0 - t)
|
490 |
+
d = t**3
|
491 |
+
|
492 |
+
x = float(a * x0 + b * x1 + c * x2 + d * x3)
|
493 |
+
y = float(a * y0 + b * y1 + c * y2 + d * y3)
|
494 |
+
pts.append((x, y))
|
495 |
+
return list(zip(*pts))
|
496 |
+
|
497 |
+
|
498 |
+
def _update_pos(curr_pos, end_pos, absolute):
|
499 |
+
if absolute:
|
500 |
+
return end_pos
|
501 |
+
return curr_pos[0] + end_pos[0], curr_pos[1] + end_pos[1]
|
502 |
+
|
503 |
+
|
504 |
+
def constant_color(*unused_args):
    """Return the colour ``[255, 255, 255]`` regardless of the arguments."""
    return np.array([255] * 3)
|
506 |
+
|
507 |
+
|
508 |
+
def _render_cubic(canvas, curr_pos, c_args, absolute, color):
    """Draw a cubic bezier curve onto ``canvas`` in place.

    Args:
        canvas: image array indexed as ``canvas[y, x, channel]``; its first
            dimension is used as the bound for both x and y.
        curr_pos: pen position, the start of the curve.
        c_args: ``[x1, y1, x2, y2, x, y]``.  When ``absolute`` is False the
            values are offsets and this list is shifted by ``curr_pos``
            in place.
        absolute: whether ``c_args`` are absolute coordinates.
        color: value written to every pixel the sampled curve touches.
    """
    if not absolute:
        for k in range(6):
            # even slots are x values, odd slots are y values
            c_args[k] += curr_pos[k % 2]

    xs, ys = _cubicbezier(curr_pos[0], curr_pos[1],
                          c_args[0], c_args[1],
                          c_args[2], c_args[3],
                          c_args[4], c_args[5])
    limit = len(canvas)

    rounded = ((int(round(px)), int(round(py))) for px, py in zip(xs, ys))
    pixels = [(col, row) for col, row in rounded
              if 0 <= col < limit and 0 <= row < limit]
    if not pixels:
        # the whole curve lies outside the canvas
        return

    cols, rows = list(zip(*pixels))
    canvas[rows, cols, :] = color
|
534 |
+
|
535 |
+
|
536 |
+
def _render_line(canvas, curr_pos, l_args, absolute, color):
    """Draw an anti-aliased line onto ``canvas`` in place.

    Args:
        canvas: image array indexed as ``canvas[y, x, channel]``; its first
            dimension is used as the bound for both x and y.
        curr_pos: pen position, the start of the line.
        l_args: ``[x, y]`` end point.  When ``absolute`` is False the values
            are offsets and this list is shifted by ``curr_pos`` in place.
        absolute: whether ``l_args`` are absolute coordinates.
        color: colour that is scaled by the per-pixel coverage returned by
            ``draw.line_aa``.
    """
    if not absolute:
        l_args[0] += curr_pos[0]
        l_args[1] += curr_pos[1]

    # x values are passed first, so the first returned array holds x values
    xs, ys, coverage = draw.line_aa(int(curr_pos[0]), int(curr_pos[1]),
                                    int(l_args[0]), int(l_args[1]))
    limit = len(canvas)

    kept = [(px, py, cov) for px, py, cov in zip(xs, ys, coverage)
            if 0 <= px < limit and 0 <= py < limit]
    if not kept:
        # the whole line lies outside the canvas
        return

    xs, ys, coverage = list(zip(*kept))
    canvas[ys, xs, :] = [(cov * color) for cov in coverage]
|
557 |
+
|
558 |
+
|
559 |
+
def _per_step_render(path, absolute=False, color=constant_color):
    """Render the outline described by ``path`` onto a 64x64x3 canvas.

    Coordinates are multiplied by 64/24 before drawing.  Only move
    (``m``/``M``), cubic (``c``/``C``) and line (``l``/``L``) commands are
    interpreted; empty commands and any other letters are skipped.

    Args:
        path: list of commands ``[letter, arg, ...]``.
        absolute: whether the coordinates are absolute rather than offsets.
        color: callable invoked as ``color(step_index, 55)`` to obtain the
            colour of each drawn command.

    Returns:
        The canvas as a ``(64, 64, 3)`` array.
    """
    def scaled(values):
        return [float(v) * (64. / 24.) for v in values]

    canvas = np.zeros((64, 64, 3))
    pen = (0.0, 0.0)
    for step, cmd in enumerate(path):
        if not cmd:
            continue
        letter = cmd[0]
        if letter in 'mM':
            pen = _update_pos(pen, scaled(cmd[-2:]), absolute)
        elif letter in 'cC':
            # a fresh list is scaled for each call: the renderer may shift
            # its argument list in place
            _render_cubic(canvas, pen, scaled(cmd[1:]), absolute,
                          color(step, 55))
            pen = _update_pos(pen, scaled(cmd[-2:]), absolute)
        elif letter in 'lL':
            _render_line(canvas, pen, scaled(cmd[1:]), absolute,
                         color(step, 55))
            pen = _update_pos(pen, scaled(cmd[1:]), absolute)

    return canvas
|
579 |
+
|
580 |
+
|
581 |
+
def _zoom_out(path_list, add_baseline=0., per=22):
|
582 |
+
"""Makes glyph slightly smaller in viewbox, makes some descenders visible."""
|
583 |
+
# assumes tensor is already unnormalized, and in long form
|
584 |
+
new_path = []
|
585 |
+
for command in path_list:
|
586 |
+
args = []
|
587 |
+
is_even = False
|
588 |
+
for arg in command[1:]:
|
589 |
+
if is_even:
|
590 |
+
args.append(str(float(arg) - ((24. - per) / 24.) * 64. / 4.))
|
591 |
+
is_even = False
|
592 |
+
else:
|
593 |
+
args.append(str(float(arg) - add_baseline))
|
594 |
+
is_even = True
|
595 |
+
new_path.append([command[0]] + args)
|
596 |
+
return new_path
|
597 |
+
|
598 |
+
|
599 |
+
##################### UTILS FOR PROCESSING VECTORS ################
|
600 |
+
def _append_eos(sample, categorical, feature_dim):
|
601 |
+
if not categorical:
|
602 |
+
eos = -1 * np.ones(feature_dim)
|
603 |
+
else:
|
604 |
+
eos = np.zeros(feature_dim)
|
605 |
+
eos[0] = 1.0
|
606 |
+
sample.append(eos)
|
607 |
+
return sample
|
608 |
+
|
609 |
+
|
610 |
+
def _make_simple_cmds_long(out):
|
611 |
+
"""Converts svg decoder output to format required by some render functions."""
|
612 |
+
# out has 10 dims
|
613 |
+
# the first 4 are respectively dims 0, 4, 5, 9 of the full 20-dim onehot vec
|
614 |
+
# the latter 6 are the 6 last dims of the 10-dim arg vec
|
615 |
+
shape_minus_dim = list(np.shape(out))[:-1]
|
616 |
+
# print("make? ",shape_minus_dim ) # [51]
|
617 |
+
|
618 |
+
return np.concatenate([out[..., :1], # [51,1] 51个steps的第1维特征
|
619 |
+
np.zeros(shape_minus_dim + [3]),# [51,3]
|
620 |
+
out[..., 1:3], #[51,2]
|
621 |
+
np.zeros(shape_minus_dim + [3]),# [51,3]
|
622 |
+
out[..., 3:4],# [51,1]
|
623 |
+
np.zeros(shape_minus_dim + [14]),# [51,14]
|
624 |
+
out[..., 4:]], -1)# [51,6] # 最后的6个绘制参数
|
625 |
+
|
626 |
+
def render(tensor, data_dir=None):
    """Convert compact SVG decoder output into an HTML ``<svg>`` string.

    The tensor is expanded with ``_make_simple_cmds_long``, turned into SVG
    markup (stopping at EOS, categorical encoding), cleaned up by
    ``postprocess`` and finally every ``256px`` is replaced by ``50px``.

    Args:
        tensor: decoder output; no un-normalization is applied here.
        data_dir: currently unused (the normalization lookup that consumed it
            is disabled).
    """
    long_form = _make_simple_cmds_long(tensor)
    svg = _vector_to_svg(long_form, stop_at_eos=True, categorical=True)
    # some aesthetic postprocessing, then shrink the displayed size
    return postprocess(svg).replace('256px', '50px')
|
645 |
+
|
646 |
+
################# UTILS FOR CONVERTING VECTORS TO SVGS ########################
|
647 |
+
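# NOTE: transform the decoded trg_seq into the common (HTML) SVG format. The commands are order-dependent. The original note also says all positions are relative; however, _vector_to_cmd upper-cases every command letter, i.e. it emits absolute commands.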
|
648 |
+
def _vector_to_svg(vectors, stop_at_eos=False, categorical=False):
    """Transform a sequence of command vectors into an SVG string.

    Args:
        vectors: iterable of command vectors (see ``_cmd_to_vector``).
        stop_at_eos: stop at the first end-of-sequence vector.  EOS is slot 0
            of the one-hot block being the arg-max (categorical) or a first
            value below -0.5 (non-categorical).
        categorical: whether the command is one-hot encoded.

    Returns:
        ``SVG_PREFIX_BIG + PATH_PREFIX_1 + <path data> + PATH_POSFIX_1 +
        SVG_POSFIX`` where the path data are the commands joined by spaces.

    Raises:
        Exception: carrying the offending vector when the EOS test fails in
            categorical mode.
    """
    pieces = []
    for vector in vectors:
        if stop_at_eos:
            if not categorical:
                reached_eos = vector[0] < -0.5
            else:
                try:
                    reached_eos = np.argmax(vector[:len(CMDS_LIST) + 1]) == 0
                except Exception:
                    raise Exception(vector)

            if reached_eos:
                break
        tokens = _vector_to_cmd(vector, categorical=categorical)
        pieces.append(' '.join(tokens))
    # every command is separated from the next one by a single space
    path_data = ' '.join(pieces)
    return SVG_PREFIX_BIG + PATH_PREFIX_1 + path_data + PATH_POSFIX_1 + SVG_POSFIX
|
668 |
+
|
669 |
+
def _vector_to_path(vectors):
    """Convert categorical command vectors to a list of command lists.

    Unlike ``_vector_to_svg`` the commands are not joined into a string.
    """
    return [_vector_to_cmd(vector, categorical=True) for vector in vectors]
|
677 |
+
|
678 |
+
def _vector_to_cmd(vector, categorical=False, return_floats=False):
    """Do the inverse transformation of ``_cmd_to_vector``.

    The command letter is looked up in the module-level ``CMDS_LIST`` and is
    always emitted in upper case.  The previous docstring recorded
    ``CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'``; the module-level definition is
    authoritative.

    Args:
        vector: command encoding followed by the 10 argument slots
            ``[rx, ry, large-arc-flag, sweep-flag, x1, y1, x2, y2, x, y]``.
        categorical: if True the first ``len(CMDS_LIST) + 1`` entries are a
            one-hot block whose slot 0 means EOS; otherwise the first entry
            is the (rounded) command index.
        return_floats: emit arguments as floats instead of strings.

    Returns:
        ``[letter, arg, ...]``, or ``[]`` for an end-of-sequence vector.
    """
    cast_fn = float if return_floats else str
    if categorical:
        split_at = len(CMDS_LIST) + 1
        # first block: one-hot command; the rest: the 10 arguments
        # (a stray trailing comma used to wrap the command block in a tuple)
        command = vector[:split_at]
        arguments = vector[split_at:]
        cmd_idx = np.argmax(command) - 1  # slot 0 is EOS, hence the shift
    else:
        command, arguments = vector[:1], vector[1:]
        cmd_idx = int(round(command[0]))

    if cmd_idx < -0.5:
        # EOS
        return []
    if cmd_idx >= len(CMDS_LIST):
        # clamp out-of-range predictions to the last known command
        cmd_idx = len(CMDS_LIST) - 1

    cmd = CMDS_LIST[cmd_idx].upper()
    cmd_list = [cmd]

    if cmd in 'hH':        # horizontal line
        slots = (8,)                    # x
    elif cmd in 'vV':      # vertical line
        slots = (9,)                    # y
    elif cmd in 'mMlLtT':
        slots = (8, 9)                  # x, y
    elif cmd in 'sSqQ':
        slots = (6, 7, 8, 9)            # x2, y2, x, y
    elif cmd in 'cC':
        slots = (4, 5, 6, 7, 8, 9)      # x1, y1, x2, y2, x, y
    elif cmd in 'aA':
        cmd_list.append(cast_fn(arguments[0]))  # rx
        cmd_list.append(cast_fn(arguments[1]))  # ry
        # x-axis-rotation is always 0
        cmd_list.append(cast_fn('0'))
        # the following two flags are binary
        cmd_list.append(cast_fn(1 if arguments[2] > 0.5 else 0))  # large-arc
        cmd_list.append(cast_fn(1 if arguments[3] > 0.5 else 0))  # sweep
        slots = (8, 9)                  # x, y
    else:
        slots = ()                      # e.g. 'Z' takes no arguments

    for slot in slots:
        cmd_list.append(cast_fn(arguments[slot]))

    return cmd_list
|
742 |
+
|
743 |
+
|
744 |
+
############## UTILS FOR CONVERTING SVGS/VECTORS TO IMAGES ###################
|
745 |
+
|
746 |
+
# From Infer notebook
|
747 |
+
# Fixed markup that surrounds the path data of the SVGs handled by
# svg_html_to_path_string(): everything before / after the `d` attribute value.
start = ("""<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www."""
         """w3.org/1999/xlink" width="256px" height="256px" style="-ms-trans"""
         """form: rotate(360deg); -webkit-transform: rotate(360deg); transfo"""
         """rm: rotate(360deg);" preserveAspectRatio="xMidYMid meet" viewBox"""
         """="0 0 24 24"><path d=\"""")
end = """\" fill="currentColor"/></svg>"""

# A single SVG path command letter (captured, so re.split keeps it).
COMMAND_RX = re.compile("([MmLlHhVvCcSsQqTtAaZz])")
# A number: optional sign, optional integer part, optional dot, digits and an
# optional exponent.  Raw string so that `\.` is not an invalid str escape.
FLOAT_RX = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
|
756 |
+
|
757 |
+
|
758 |
+
def svg_html_to_path_string(svg):
    """Strip the module-level ``start`` and ``end`` markup from ``svg``.

    What remains is the path data when ``svg`` was built with exactly that
    prefix and suffix; text that does not contain them is returned unchanged.
    """
    without_prefix = svg.replace(start, '')
    return without_prefix.replace(end, '')
|
760 |
+
|
761 |
+
|
762 |
+
def _tokenize(pathdef):
    """Yield each SVG token (command letter or number) of a path string.

    E.g. ``'m0.1-.5c0,6'`` yields ``'m', '0.1', '-.5', 'c', '0', '6'``.
    """
    for chunk in COMMAND_RX.split(pathdef):
        # the split keeps the command letters as chunks of their own
        if chunk != '' and chunk in 'MmLlHhVvCcSsQqTtAaZz':
            yield chunk
        yield from FLOAT_RX.findall(chunk)
|
770 |
+
|
771 |
+
|
772 |
+
def path_string_to_tokenized_commands(path):
    """Group the tokens of a path string into commands.

    E.g.:
        Given  M 0.5 0.5 l 0.25 0.25 z
        Returns [['M', '0.5', '0.5'], ['l', '0.25', '0.25'], ['z']]

    A command letter closes the command being collected and opens a new one;
    every other token is appended to the current command.
    """
    commands = []
    current = []
    for token in _tokenize(path):
        if current and token in 'MmLlHhVvCcSsQqTtAaZz':
            # previous command is complete
            commands.append(current)
            current = [token]
        else:
            current.append(token)

    if current:
        # flush the command still being collected
        commands.append(current)

    return commands
|
799 |
+
|
800 |
+
|
801 |
+
def separate_substructures(tokenized_commands):
    """Return the list of SVG substructures of a tokenized path.

    Every moveTo command starts a new substructure.  An SVG substructure is
    a subpath that closes on itself, such as the outer and the inner edge of
    the character `o`.
    """
    substructures = []
    for cmd in tokenized_commands:
        is_move = cmd[0] in 'mM'
        if substructures and not is_move:
            substructures[-1].append(cmd)
        else:
            substructures.append([cmd])
    return substructures
|
816 |
+
|
817 |
+
|
818 |
+
def postprocess(svg, dist_thresh=2., skip=False):
    """Apply aesthetic clean-ups to the path of an SVG string.

    Two passes are run over the substructures (subpaths) of the path:

    1. If the point reached after replaying a substructure with
       ``_update_curr_pos`` is closer than ``dist_thresh`` to its start
       point, the last command is adjusted so that it ends exactly on the
       start point, and the start of the next substructure is shifted by the
       opposite amount.
    2. Unless ``skip`` is set: for two consecutive ``c`` commands (lower case
       only) whose adjoining control vectors are almost opposite
       (cosine in (-1, -0.95)), both vectors are rotated by half of the
       missing angle so that they become exactly aligned.

    Args:
        svg: SVG markup; the path data is what remains after
            ``svg_html_to_path_string``.
        dist_thresh: maximal start/end distance for closing a substructure.
        skip: return right after the first pass.

    Returns:
        ``svg`` with its path data replaced by the cleaned-up path.  An SVG
        with an empty path is returned unchanged.
    """
    path = svg_html_to_path_string(svg)
    if not path:
        # str.replace('') / str.partition('') cannot locate an empty path;
        # there is nothing to clean up anyway.
        return svg
    # Remember what surrounds the path.  Splitting (instead of building a
    # '{}' format template) is safe when the path text occurs more than once
    # or when the markup contains braces.
    prefix, _, suffix = svg.partition(path)
    tokenized_commands = path_string_to_tokenized_commands(path)

    def assemble(subs):
        return prefix + ' '.join([' '.join(' '.join(cmd) for cmd in s)
                                  for s in subs]) + suffix

    def dist(a, b):
        return np.sqrt((float(a[0]) - float(b[0]))**2 + (float(a[1]) - float(b[1]))**2)

    def are_close_together(a, b, t):
        return dist(a, b) < t

    # first, go through each start/end point and merge if they're close enough
    # together (that is, make end point the same as the start point).
    # TODO: there are better ways of doing this, in a way that propagates
    # errors back (so if total error is 0.2, go through all N commands in
    # this substructure and fix each by 0.2/N (unless they have 0 vertical
    # change))
    substructures = separate_substructures(tokenized_commands)

    previous_substructure_endpoint = (0., 0.,)
    for substructure in substructures:
        # if the last substructure's endpoint was updated, we must update
        # the start point of this one to reflect the opposite update
        head = substructure[0]
        head[-2] = str(float(head[-2]) - previous_substructure_endpoint[0])
        head[-1] = str(float(head[-1]) - previous_substructure_endpoint[1])

        start_pt = list(map(float, head[-2:]))
        curr_pos = (0., 0.)
        for cmd in substructure:
            curr_pos, _ = _update_curr_pos(curr_pos, cmd, (0., 0.))
        if are_close_together(start_pt, curr_pos, dist_thresh):
            new_point = np.array(start_pt)
            dx = new_point[0] - curr_pos[0]
            dy = new_point[1] - curr_pos[1]
            previous_substructure_endpoint = (dx, dy)
            tail = substructure[-1]
            tail[-2] = str(float(tail[-2]) + dx)
            tail[-1] = str(float(tail[-1]) + dy)
            if tail[0] in 'cC':
                # move the second control point along with the end point
                tail[-4] = str(float(tail[-4]) + dx)
                tail[-3] = str(float(tail[-3]) + dy)

    if skip:
        return assemble(substructures)

    def cosa(x, y):
        return (x[0] * y[0] + x[1] * y[1]) / ((np.sqrt(x[0]**2 + x[1]**2) * np.sqrt(y[0]**2 + y[1]**2)))

    def rotate(a, x, y):
        return (x * np.cos(a) - y * np.sin(a), y * np.cos(a) + x * np.sin(a))

    # second, gotta find adjacent bezier curves and, if their control points
    # are well enough aligned, fully align them
    for substructure in substructures:
        curr_pos = (0., 0.)
        new_curr_pos, _ = _update_curr_pos((0., 0.,), substructure[0], (0., 0.))

        for cmd_idx in range(1, len(substructure)):
            prev_cmd = substructure[cmd_idx - 1]
            cmd = substructure[cmd_idx]

            new_new_curr_pos, _ = _update_curr_pos(new_curr_pos, cmd, (0., 0.))

            if cmd[0] == 'c' and prev_cmd[0] == 'c':
                # previous control pt wrt new curr point
                prev_ctr_point = (curr_pos[0] + float(prev_cmd[3]) - new_curr_pos[0],
                                  curr_pos[1] + float(prev_cmd[4]) - new_curr_pos[1])
                ctr_point = (float(cmd[1]), float(cmd[2]))

                cos_angle = cosa(prev_ctr_point, ctr_point)
                if -1. < cos_angle < -0.95:
                    # half of what is missing to a straight (pi) angle
                    angle_diff = (np.pi - np.arccos(cos_angle)) / 2

                    # rotate each vector in the correct direction for each.
                    # z component of the cross product, written out because
                    # np.cross on 2-element vectors is deprecated
                    sign = np.sign(prev_ctr_point[0] * ctr_point[1] -
                                   prev_ctr_point[1] * ctr_point[0])
                    new_ctr_point = rotate(sign * angle_diff, *ctr_point)
                    new_prev_ctr_point = rotate(-sign * angle_diff, *prev_ctr_point)

                    # override the previous control points
                    # (which has to be wrt previous curr position)
                    substructure[cmd_idx - 1][3] = str(new_prev_ctr_point[0] -
                                                       curr_pos[0] + new_curr_pos[0])
                    substructure[cmd_idx - 1][4] = str(new_prev_ctr_point[1] -
                                                       curr_pos[1] + new_curr_pos[1])
                    substructure[cmd_idx][1] = str(new_ctr_point[0])
                    substructure[cmd_idx][2] = str(new_ctr_point[1])

            curr_pos = new_curr_pos
            new_curr_pos = new_new_curr_pos

    return assemble(substructures)
|
919 |
+
|
920 |
+
|
921 |
+
|
922 |
+
|
923 |
+
|
924 |
+
# def get_means_stdevs(data_dir):
|
925 |
+
# """Returns the means and stdev saved in data_dir."""
|
926 |
+
# if data_dir not in means_stdevs:
|
927 |
+
# with tf.gfile.Open(os.path.join(data_dir, 'mean.npz'), 'r') as f:
|
928 |
+
# mean_npz = np.load(f)
|
929 |
+
# with tf.gfile.Open(os.path.join(data_dir, 'stdev.npz'), 'r') as f:
|
930 |
+
# stdev_npz = np.load(f)
|
931 |
+
# means_stdevs[data_dir] = (mean_npz, stdev_npz)
|
932 |
+
# return means_stdevs[data_dir]
|
933 |
+
|
934 |
+
|
935 |
+
|
936 |
+
|
937 |
+
###############
|
938 |
+
|
939 |
+
|
940 |
+
def convert_to_svg(decoder_output, categorical=False):
    """Convert every example of ``decoder_output`` to an SVG string.

    Each example is rendered with ``_vector_to_svg`` (stopping at EOS).

    Returns:
        A numpy array holding one SVG string per example.
    """
    return np.array([
        _vector_to_svg(example, True, categorical=categorical)
        for example in decoder_output
    ])
|
945 |
+
|
946 |
+
|
947 |
+
def create_image_conversion_fn(max_outputs, categorical=False):
    """Bind the number of outputs to the image conversion fn (to svg or png).

    Returns:
        A function taking the decoder output and returning a numpy array
        with the SVG strings of its first ``max_outputs`` examples.
    """
    def convert_to_svg(decoder_output):
        svgs = []
        for example in decoder_output:
            if len(svgs) == max_outputs:
                # enough examples converted
                break
            svg = _vector_to_svg(example, True, categorical=categorical)
            svgs.append(svg)
        return np.array(svgs)

    return convert_to_svg
|
958 |
+
|
959 |
+
|
960 |
+
################### UTILS FOR CREATING TF SUMMARIES ##########################
|
961 |
+
def _make_encoded_image(img_tensor):
    """Encode an image tensor as PNG bytes.

    The tensor is multiplied by 255, squeezed, cast to uint8 and saved as an
    8-bit greyscale (mode ``'L'``) PNG.
    """
    pixels = np.squeeze(img_tensor * 255).astype(np.uint8)
    buff = io.BytesIO()
    Image.fromarray(pixels, mode='L').save(buff, format='png')
    return buff.getvalue()
|
967 |
+
|
968 |
+
|
969 |
+
################### CHECK GLYPH/PATH VALID ##############################################
|
970 |
+
def is_valid_glyph(g):
    """Return whether glyph ``g`` is an ASCII digit or letter with an extent.

    ``g['uni']`` must lie in 48-57 (0-9), 65-90 (A-Z) or 97-122 (a-z) and
    both ``g['width']`` and ``g['vwidth']`` must be non-zero.
    """
    code = g['uni']
    is_alphanumeric = (48 <= code <= 57) or (65 <= code <= 90) or (97 <= code <= 122)
    has_extent = g['width'] != 0 and g['vwidth'] != 0
    return is_alphanumeric and has_extent
|
976 |
+
|
977 |
+
|
978 |
+
def is_valid_path(pathunibfp, max_len=70):
    """Check that a path is non-empty and not longer than ``max_len``.

    Args:
        pathunibfp: sequence whose first element is the path (the list of
            commands).
        max_len: maximal number of commands accepted (70 by default).

    Returns:
        ``(valid, length)``.  ``valid`` is falsy when the path is empty (the
        empty path itself is returned in that case) or longer than
        ``max_len``; ``length`` is the number of commands.

    A notice is printed for paths that are too long.
    """
    path = pathunibfp[0]
    n_cmds = len(path)
    if n_cmds > max_len:
        # the message used to announce a limit of 400 that was never applied
        print("!!!more than {}".format(max_len), n_cmds)
    return path and n_cmds <= max_len, n_cmds
|
984 |
+
|
985 |
+
|
986 |
+
################### DATASET PROCESSING #######################################
|
987 |
+
def convert_to_path(g):
    """Convert the SplineSet of an SFD glyph to a normalised path.

    The glyph is converted with ``_sfd_to_path_list``, completed with
    ``_add_missing_cmds`` (``z`` commands are kept) and normalised against
    the viewbox ``0 0 <width> <vwidth>``.

    Returns:
        ``(path, g['uni'], g['binary_fp'])``.
    """
    raw_path = _sfd_to_path_list(g)
    complete_path = _add_missing_cmds(raw_path, remove_zs=False)
    viewbox = '0 0 {} {}'.format(g['width'], g['vwidth'])
    normalised = _normalize_based_on_viewbox(complete_path, viewbox)
    return normalised, g['uni'], g['binary_fp']
|
993 |
+
def convert_simple_vector_to_path(seq):
    """Convert a compact command sequence to a list of absolute commands.

    Each row holds a 4-way command block (arg-max: 0 = EOS, 1 = ``M``,
    2 = ``L``, 3 = ``C``) followed by three points
    ``p0 = row[4:6]``, ``p1 = row[6:8]``, ``p2 = row[8:10]``.  ``M`` and ``L``
    use ``p2`` only, ``C`` uses all three.  Conversion stops at the first EOS
    row.

    Args:
        seq: sequence of rows (e.g. a 2-D array with 10 columns).

    Returns:
        A list of ``[letter, str_arg, ...]`` commands.
    """
    letters = {1: 'M', 2: 'L', 3: 'C'}
    path = []
    for row in seq:
        # the arg-max over (at most) 4 entries is always one of 0..3, so no
        # "unknown command" branch is needed
        cmd = int(np.argmax(row[:4]))
        if cmd == 0:
            break
        # 'C' needs both control points and the end point, 'M'/'L' only the
        # end point
        points = (row[4:6], row[6:8], row[8:10]) if cmd == 3 else (row[8:10],)
        command = [letters[cmd]]
        for point in points:
            command.append(str(point[0]))
            command.append(str(point[1]))
        path.append(command)
    return path
|
1026 |
+
# print("jjj")
|
1027 |
+
def clockwise(seq):
    """Canonicalise a compact command sequence and re-encode it.

    The sequence is converted to a path, canonicalised, rendered, and turned
    back into the compact 10-value encoding (one-hot dims 0, 4, 5, 9 plus
    the last 6 arguments) with an EOS row and zero padding up to 71 rows.

    Returns:
        A dict with
          'rendered': the first channel of the 64x64 rendering, flattened to
              a list of 4096 float32 values;
          'seq_len': number of commands (without EOS);
          'sequence': array with the commands, the EOS row and the padding.
    """
    path = _canonicalize(convert_simple_vector_to_path(seq))
    canvas = _per_step_render(path, absolute=True)

    full = np.array(_path_to_vector(path, categorical=True))
    # keep one-hot dims 0, 4, 5, 9 and the last 6 arguments
    compact = np.concatenate(
        [np.take(full, [0, 4, 5, 9], axis=-1), full[..., -6:]], axis=-1)
    seq_len = np.shape(compact)[0]

    with_eos = _append_eos(compact.tolist(), True, 10)
    padding = np.zeros(((70 - seq_len), 10))
    sequence = np.concatenate((with_eos, padding), 0)

    return {
        'rendered': np.reshape(canvas[..., 0], [64 * 64]).astype(np.float32).tolist(),
        'seq_len': seq_len,
        'sequence': sequence,
    }
|
1043 |
+
|
1044 |
+
|
1045 |
+
def create_example(pathunibfp):
    """Bulk of dataset processing: convert a str path to plain-list features.

    Args:
        pathunibfp: ``(path, uni, binary_fp)`` as produced by
            ``convert_to_path``.

    Returns:
        A dict of plain lists / strings:
          'rendered': first channel of the 64x64 rendering (4096 floats);
          'seq_len': ``[number of commands]`` (without EOS);
          'class': ``[class index]`` from ``_map_uni_to_alpha``;
          'binary_fp': ``str(binary_fp)``;
          'sequence': 71 * 10 floats (commands, EOS row, zero padding).
    """
    path, uni, binary_fp = pathunibfp

    # zoom out, then make clockwise
    path = _canonicalize(_zoom_out(path))

    # render path for training
    canvas = _per_step_render(path, absolute=True)

    # convert to vector; the path is deliberately left in absolute
    # coordinates (the _make_relative step is disabled)
    full = np.array(_path_to_vector(path, categorical=True))
    # keep one-hot dims 0, 4, 5, 9 and the last 6 arguments
    compact = np.concatenate(
        [np.take(full, [0, 4, 5, 9], axis=-1), full[..., -6:]], axis=-1)

    # count some stats
    seq_len = np.shape(compact)[0]
    glyph_class = int(_map_uni_to_alpha(uni))
    fingerprint = str(binary_fp)

    # append eos, then pad the path to 71 rows (eos included)
    with_eos = _append_eos(compact.tolist(), True, 10)
    padded = np.concatenate((with_eos, np.zeros(((70 - seq_len), 10))), 0)

    # make pure lists; only the first channel of the rendering is used
    return {
        'rendered': np.reshape(canvas[..., 0], [64 * 64]).astype(np.float32).tolist(),
        'seq_len': np.reshape(seq_len, [1]).astype(np.int64).tolist(),
        'class': np.reshape(glyph_class, [1]).astype(np.int64).tolist(),
        'binary_fp': fingerprint,
        'sequence': np.reshape(padded, [71 * 10]).astype(np.float32).tolist(),
    }
|
1110 |
+
|
1111 |
+
|
1112 |
+
def mean_to_example(mean_stdev):
    """Turn the accumulated statistics into plain Python lists, in place.

    Args:
        mean_stdev: dict with keys 'mean', 'variance', 'stddev' (each holding
            10 values) and 'count' (a single value).

    Returns:
        The same dict object, with the three statistics stored as flat
        float32 lists of length 10 and 'count' as a one-element int64 list.
    """
    # The three per-feature statistics share the same conversion.
    for stat_name in ('mean', 'variance', 'stddev'):
        flat = np.reshape(mean_stdev[stat_name], [10])
        mean_stdev[stat_name] = flat.astype(np.float32).tolist()

    total = np.reshape(mean_stdev['count'], [1])
    mean_stdev['count'] = total.astype(np.int64).tolist()
    return mean_stdev
|
1120 |
+
|
1121 |
+
|
1122 |
+
################### CHECK VALID ##############################################
|
1123 |
+
class MeanStddev:
    """Accumulator to compute the mean/stdev of svg commands.

    The accumulator state is a tuple ``(sum, sum_of_squares, count)``: the
    first two entries are length-10 arrays (one slot per sequence feature)
    and ``count`` is the number of sequence steps folded in so far.
    """

    def create_accumulator(self):
        """Return an empty accumulator (zero sums, zero count)."""
        curr_sum = np.zeros([10])
        sum_sq = np.zeros([10])
        return (curr_sum, sum_sq, 0)  # x, x^2, count

    def add_input(self, sum_count, new_input):
        """Fold one glyph sequence into the accumulator.

        Args:
            sum_count: accumulator tuple ``(sum, sum_of_squares, count)``.
            new_input: dict with keys 'seq_len' (a list holding one int) and
                'sequence' (a flat list whose length is a multiple of 10).

        Returns:
            A new accumulator tuple including the first ``seq_len`` rows of
            the sequence; padding and EOS rows after that are ignored.

        Raises:
            AssertionError: if the input does not have the expected types or
                the sequence is shorter than ``seq_len`` rows.
        """
        (curr_sum, sum_sq, count) = sum_count
        # 'seq_len' is stored as a list of one int.
        new_seq_len = new_input['seq_len'][0]
        assert isinstance(new_seq_len, int), \
            f"seq_len must be an int, got {type(new_seq_len)}"

        # remove padding and eos from sequence
        assert isinstance(new_input['sequence'], list), \
            f"sequence must be a list, got {type(new_input['sequence'])}"
        new_input_np = np.reshape(np.array(new_input['sequence']), [-1, 10])
        assert new_input_np.shape[0] >= new_seq_len, \
            f"sequence has {new_input_np.shape[0]} rows, fewer than seq_len={new_seq_len}"
        new_input_np = new_input_np[:new_seq_len, :]

        # accumulate new_sum and new_sum_sq
        new_sum = curr_sum + np.sum(new_input_np, axis=0)
        new_sum_sq = sum_sq + np.sum(np.power(new_input_np, 2), axis=0)
        return new_sum, new_sum_sq, count + new_seq_len

    def merge_accumulators(self, accumulators):
        """Combine several accumulator tuples into one.

        An empty collection yields an empty accumulator instead of failing on
        tuple unpacking.
        """
        accumulators = list(accumulators)
        if not accumulators:
            return self.create_accumulator()
        curr_sums, sum_sqs, counts = zip(*accumulators)
        return np.sum(curr_sums, axis=0), np.sum(sum_sqs, axis=0), np.sum(counts)

    def extract_output(self, sum_count):
        """Compute mean, variance and stddev from an accumulator.

        Returns:
            dict with keys 'mean', 'variance', 'stddev' and 'count'. When the
            count is zero the three statistics are NaN floats.
        """
        (curr_sum, curr_sum_sq, count) = sum_count
        if count:
            mean = np.divide(curr_sum, count)
            variance = np.divide(curr_sum_sq, count) - np.power(mean, 2)
            # -ve value could happen due to rounding
            variance = np.max([variance, np.zeros(np.shape(variance))], axis=0)
            stddev = np.sqrt(variance)
            return {
                'mean': mean,
                'variance': variance,
                'stddev': stddev,
                'count': count
            }
        else:
            return {
                'mean': float('NaN'),
                'variance': float('NaN'),
                'stddev': float('NaN'),
                'count': 0
            }
|
data_utils/write_data_to_dirs.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import multiprocessing as mp
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
import numpy as np
|
6 |
+
from data_utils import svg_utils
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
def exist_empty_imgs(imgs_array, num_chars):
    """Return True if any of the first ``num_chars`` images has a maximum of 0.

    Args:
        imgs_array: indexable collection of glyph images.
        num_chars: number of leading entries of ``imgs_array`` to inspect.

    Returns:
        True as soon as one inspected image has a maximum value of 0,
        otherwise False (also when ``num_chars`` is 0).
    """
    # The debug print and blocking input() call were removed: they stalled
    # every iteration and raise EOFError when stdin is not interactive.
    return any(np.max(imgs_array[char_id]) == 0 for char_id in range(num_chars))
|
16 |
+
|
17 |
+
def create_db(opts, output_path, log_path):
    """Convert per-glyph .sfd/.txt files into per-font .npy arrays.

    Fonts are split evenly across worker processes. A font is written to
    ``output_path`` only if every character of the char set has a valid glyph
    and a valid path; otherwise the first failing glyph is logged and the
    font is dropped.

    Args:
        opts: options namespace; reads data_path, language, sfd_path, split
            and img_size.
        output_path: directory that receives one sub-directory per font.
        log_path: directory that receives one log file per worker process.
    """
    with open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r') as charset_f:
        charset = charset_f.read()
    print("Process sfd to npy files in dirs....")
    sdf_path = os.path.join(opts.sfd_path, opts.language, opts.split)
    all_font_ids = sorted(os.listdir(sdf_path))
    num_fonts = len(all_font_ids)
    num_fonts_w = len(str(num_fonts))
    print(f"Number {opts.split} fonts before processing", num_fonts)
    # Leave one core free, but never drop to zero workers on a single-core
    # host (that would divide by zero below).
    num_processes = max(1, mp.cpu_count() - 1)
    fonts_per_process = num_fonts // num_processes + 1
    num_chars = len(charset)
    num_chars_w = len(str(num_chars))
    imgs_fname = 'imgs_' + str(opts.img_size) + '.npy'

    def process(process_id):
        """Handle the slice of fonts assigned to worker ``process_id``."""
        valid_chars = []
        invalid_path = []
        invalid_glypts = []

        log_fpath = os.path.join(log_path, f'log_{opts.split}_{process_id}.txt')
        with open(log_fpath, 'w') as cur_process_log_file:
            for i in tqdm(range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process)):
                if i >= num_fonts:
                    break

                font_id = all_font_ids[i]
                cur_font_sfd_dir = os.path.join(sdf_path, font_id)
                cur_font_glyphs = []

                # Fonts without rendered images are skipped.
                if not os.path.exists(os.path.join(cur_font_sfd_dir, imgs_fname)):
                    continue

                # a whole font as an entry
                for char_id in range(num_chars):
                    glyph_stem = os.path.join(
                        cur_font_sfd_dir,
                        '{}_{num:0{width}}'.format(font_id, num=char_id, width=num_chars_w))
                    if not os.path.exists(glyph_stem + '.sfd'):
                        break

                    with open(glyph_stem + '.txt', 'r') as char_desp_f:
                        char_desp = char_desp_f.readlines()
                    with open(glyph_stem + '.sfd', 'r') as sfd_f:
                        sfd = sfd_f.read()

                    char_idx = char_desp[3].strip()
                    font_idx = char_desp[4].strip()

                    cur_glyph = {}
                    cur_glyph['uni'] = int(char_desp[0].strip())
                    cur_glyph['width'] = int(char_desp[1].strip())
                    cur_glyph['vwidth'] = int(char_desp[2].strip())
                    cur_glyph['sfd'] = sfd
                    cur_glyph['id'] = char_idx
                    cur_glyph['binary_fp'] = font_idx

                    glyph_record = [font_idx, int(char_idx), charset[int(char_idx)]]

                    if not svg_utils.is_valid_glyph(cur_glyph):
                        # Was `invalid_path.glypts(...)`, which raises
                        # AttributeError on a list.
                        invalid_glypts.append(glyph_record)
                        cur_process_log_file.write(
                            f"font {font_idx}, char {char_idx} is not a valid glyph\n")
                        # use the font whose all glyphs are valid
                        break
                    pathunibfp = svg_utils.convert_to_path(cur_glyph)

                    if not svg_utils.is_valid_path(pathunibfp):
                        invalid_path.append(glyph_record)
                        cur_process_log_file.write(
                            f"font {font_idx}, char {char_idx}'s sfd is not a valid path\n")
                        break
                    valid_chars.append(glyph_record)
                    cur_font_glyphs.append(svg_utils.create_example(pathunibfp))

                if len(cur_font_glyphs) == num_chars:
                    # use the font whose all glyphs are valid
                    # merge the whole font
                    rendered = np.load(os.path.join(cur_font_sfd_dir, imgs_fname))

                    # Skip fonts whose first two glyph images are identical.
                    if (rendered[0] == rendered[1]).all() == True:
                        continue

                    sequence = [example['sequence'] for example in cur_font_glyphs]
                    seq_len = [example['seq_len'] for example in cur_font_glyphs]
                    char_class = [example['class'] for example in cur_font_glyphs]
                    # The font id saved is the one of the last glyph.
                    binaryfp = cur_font_glyphs[-1]['binary_fp'] if cur_font_glyphs else []

                    font_out_dir = os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w))
                    if not os.path.exists(font_out_dir):
                        os.mkdir(font_out_dir)

                    np.save(os.path.join(font_out_dir, 'sequence.npy'), np.array(sequence))
                    np.save(os.path.join(font_out_dir, 'seq_len.npy'), np.array(seq_len))
                    np.save(os.path.join(font_out_dir, 'class.npy'), np.array(char_class))
                    np.save(os.path.join(font_out_dir, 'font_id.npy'), np.array(binaryfp))
                    np.save(os.path.join(font_out_dir, 'rendered_' + str(opts.img_size) + '.npy'), rendered)

        print("valid_chars", len(valid_chars))
        print("invalid_path:", invalid_path)
        print("invalid_glypts:", invalid_glypts)

    processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]

    for p in processes:
        p.start()
    for p in processes:
        p.join()

    print("Finished processing all sfd files, logs (invalid glyphs and paths) are saved to", log_path)
|
139 |
+
|
140 |
+
|
141 |
+
def cal_mean_stddev(opts, output_path):
    """Compute the per-feature mean and stddev over all glyph sequences.

    The font directories found directly under ``output_path`` are split
    across worker processes; each worker accumulates sums over its fonts and
    the partial results are merged in the parent. The results are written to
    ``<opts.output_path>/<opts.language>/mean.npz`` and ``stdev.npz``.

    Args:
        opts: options namespace; reads data_path, language and output_path.
        output_path: directory holding one sub-directory per font, each with
            'seq_len.npy' and 'sequence.npy'.
    """
    print("Calculating all glyphs' mean stddev ....")
    with open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r') as charset_f:
        charset = charset_f.read()
    font_paths = []
    # Only the first level of the walk holds font directories; deeper names
    # joined with output_path would not be valid paths.
    for _root, dirs, _files in os.walk(output_path):
        font_paths.extend(os.path.join(output_path, dir_name) for dir_name in dirs)
        break

    font_paths.sort()
    num_fonts = len(font_paths)
    # Never drop to zero workers on a single-core host.
    num_processes = max(1, mp.cpu_count() - 1)
    fonts_per_process = num_fonts // num_processes + 1
    num_chars = len(charset)
    manager = mp.Manager()
    return_dict = manager.dict()
    main_stddev_accum = svg_utils.MeanStddev()
    print(main_stddev_accum)

    def process(process_id, return_dict):
        """Accumulate sums over this worker's fonts into return_dict."""
        mean_stddev_accum = svg_utils.MeanStddev()
        cur_sum_count = mean_stddev_accum.create_accumulator()
        for i in range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process):
            if i >= num_fonts:
                break
            cur_font_path = font_paths[i]
            # Load each array once per font instead of once per character.
            seq_lens = np.load(os.path.join(cur_font_path, 'seq_len.npy')).tolist()
            sequences = np.load(os.path.join(cur_font_path, 'sequence.npy')).tolist()
            for charid in range(num_chars):
                cur_font_char = {
                    'seq_len': seq_lens[charid],
                    'sequence': sequences[charid],
                }
                cur_sum_count = mean_stddev_accum.add_input(cur_sum_count, cur_font_char)
        return_dict[process_id] = cur_sum_count

    processes = [mp.Process(target=process, args=[pid, return_dict]) for pid in range(num_processes)]

    for p in processes:
        p.start()
    for p in processes:
        p.join()

    merged_sum_count = main_stddev_accum.merge_accumulators(return_dict.values())
    output = main_stddev_accum.extract_output(merged_sum_count)
    print('output :', output)
    mean = output['mean']
    stdev = output['stddev']
    print('mean :', mean)
    # The first four features are left unnormalised (mean 0, stddev 1).
    mean = np.concatenate((np.zeros([4]), mean[4:]), axis=0)
    stdev = np.concatenate((np.ones([4]), stdev[4:]), axis=0)
    # finally, save the mean and stddev files
    output_path_ = os.path.join(opts.output_path, opts.language)
    np.save(os.path.join(output_path_, 'mean'), mean)
    np.save(os.path.join(output_path_, 'stdev'), stdev)

    # rename npy to npz, don't mind about it, just some legacy issue
    os.rename(os.path.join(output_path_, 'mean.npy'), os.path.join(output_path_, 'mean.npz'))
    os.rename(os.path.join(output_path_, 'stdev.npy'), os.path.join(output_path_, 'stdev.npz'))
|
196 |
+
|
197 |
+
|
198 |
+
def main():
    """Parse the command line and run the requested dataset-building phases.

    Phase 0 runs both steps, phase 1 only creates the per-font arrays and
    phase 2 only computes the mean/stddev. The mean/stddev step is only ever
    run for the 'train' split.
    """
    parser = argparse.ArgumentParser(description="LMDB creation")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
    parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
    parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
    parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
    parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images")
    parser.add_argument("--split", type=str, default='train')
    parser.add_argument("--phase", type=int, default=0, choices=[0, 1, 2],
                        help="0 all, 1 create db, 2 cal stddev")

    opts = parser.parse_args()
    assert os.path.exists(opts.sfd_path), "specified sfd glyphs path does not exist"

    output_path = os.path.join(opts.output_path, opts.language, opts.split)
    log_path = os.path.join(opts.sfd_path, opts.language, 'log')

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    if not os.path.exists(log_path):
        os.makedirs(log_path)

    if opts.phase in (0, 1):
        create_db(opts, output_path, log_path)

    # Previously `opts.phase <= 2`, which was true for every allowed value, so
    # phase 1 ("create db") also ran the stddev step against its help text.
    if opts.phase in (0, 2) and opts.split == 'train':
        cal_mean_stddev(opts, output_path)


if __name__ == "__main__":
    main()
|
data_utils/write_glyph_imgs.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from PIL import Image
|
3 |
+
from PIL import ImageDraw
|
4 |
+
from PIL import ImageFont
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import multiprocessing as mp
|
9 |
+
from tqdm import tqdm
|
10 |
+
# Counter of fonts whose rendering did not succeed. It is incremented inside
# the worker function that write_glyph_imgs_mp runs through mp.Process.
# NOTE(review): each worker process has its own copy of this global, so the
# parent's value presumably stays 0 — confirm before relying on it.
char_error = 0
|
11 |
+
|
12 |
+
def get_bbox(img):
    """Measure the extent of the non-white pixels of a glyph image.

    Args:
        img: 2-D image (anything ``np.array`` accepts) where 255 is
            background.

    Returns:
        ``(width, height)``: for each axis, the index of the last column/row
        containing a non-255 pixel minus the index of the first one.

    Raises:
        IndexError: if the image contains no non-255 pixel.
    """
    ink = 255 - np.array(img)
    extents = []
    # axis=0 collapses rows (column profile -> width),
    # axis=1 collapses columns (row profile -> height).
    for axis in (0, 1):
        occupied = np.where(np.sum(ink, axis=axis) > 0)[0]
        extents.append(occupied[-1] - occupied[0])
    width, height = extents
    return width, height
|
21 |
+
|
22 |
+
def write_glyph_imgs_mp(opts):
    """Render every character of every font to a glyph image, in parallel.

    For each font file directly inside ``<ttf_path>/<language>/<split>`` that
    also has a directory under ``<sfd_path>/<language>/<split>``, all
    characters of the char set are drawn onto white ``img_size`` x
    ``img_size`` canvases. If every character renders acceptably the stack is
    saved as ``imgs_<img_size>.npy`` in the font's sfd directory; otherwise
    nothing is saved for that font.

    Args:
        opts: options namespace; reads data_path, language, ttf_path,
            sfd_path, split, img_size, FONT_SIZE and debug.
    """
    with open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r') as charset_f:
        charset = charset_f.read()
    fonts_file_path = os.path.join(opts.ttf_path, opts.language)
    sfd_path = os.path.join(opts.sfd_path, opts.language)
    # Use only the files directly inside the split directory. Starting from an
    # empty list also covers a missing directory (os.walk then yields nothing).
    ttf_names = []
    for _root, _dirs, files in os.walk(os.path.join(fonts_file_path, opts.split)):
        ttf_names = files
        break
    ttf_names.sort()
    font_num = len(ttf_names)
    charset_lenw = len(str(len(charset)))
    # Never drop to zero workers on a single-core host.
    process_nums = max(1, mp.cpu_count() - 1)
    font_num_per_process = font_num // process_nums + 1
    # These characters get an extra upward shift by the font descent.
    thai_characters_long = ["ญ", "ฎ", "ฏ", "ฐ"]

    def process(process_id, font_num_p_process):
        """Render the slice of fonts assigned to worker ``process_id``."""
        # NOTE(review): this runs in a child process, so the increment below
        # presumably never reaches the parent's char_error — confirm.
        global char_error
        for i in tqdm(range(process_id * font_num_p_process, (process_id + 1) * font_num_p_process)):
            if i >= font_num:
                break

            fontname = ttf_names[i].split('.')[0]
            font_dir = os.path.join(sfd_path, opts.split, fontname)

            if not os.path.exists(font_dir):
                continue

            ttf_file_path = os.path.join(fonts_file_path, opts.split, ttf_names[i])

            try:
                font = ImageFont.truetype(ttf_file_path, int(opts.img_size*opts.FONT_SIZE), encoding="unic")
            except Exception:
                print('cant open ' + fontname)
                continue

            fontimgs_array = np.zeros((len(charset), opts.img_size, opts.img_size), np.uint8)
            fontimgs_array[:, :, :] = 255

            flag_success = True

            for charid in range(len(charset)):
                # read the meta file
                txt_fpath = os.path.join(font_dir, fontname + '_' + '{num:0{width}}'.format(num=charid, width=charset_lenw) + '.txt')
                try:
                    with open(txt_fpath, 'r') as txt_f:
                        txt_lines = txt_f.read().split('\n')
                except Exception:
                    print('cannot read text file')
                    flag_success = False
                    break
                if len(txt_lines) < 5:
                    flag_success = False
                    break  # should be empty file
                # the offsets are calculated according to the rules in data_utils/svg_utils.py
                vbox_w = float(txt_lines[1])
                vbox_h = float(txt_lines[2])
                norm = max(int(vbox_w), int(vbox_h))

                # Centre the shorter side of the glyph box inside the square canvas.
                if int(vbox_h) > int(vbox_w):
                    add_to_y = 0
                    add_to_x = abs(int(vbox_h) - int(vbox_w)) / 2
                    add_to_x = add_to_x * (float(opts.img_size) / norm)
                else:
                    add_to_y = abs(int(vbox_h) - int(vbox_w)) / 2
                    add_to_y = add_to_y * (float(opts.img_size) / norm)
                    add_to_x = 0

                char = charset[charid]

                array = np.ndarray((opts.img_size, opts.img_size), np.uint8)
                array[:, :] = 255
                image = Image.fromarray(array)
                draw = ImageDraw.Draw(image)
                try:
                    # NOTE(review): FreeTypeFont.getsize was removed in
                    # Pillow 10; on such versions every font lands in this
                    # except branch — confirm the pinned Pillow version.
                    font_width, font_height = font.getsize(char)
                except Exception:
                    print('cant calculate height and width ' + "%04d"%i + '_' + '{num:0{width}}'.format(num=charid, width=charset_lenw))
                    flag_success = False
                    break

                try:
                    ascent, descent = font.getmetrics()
                except Exception:
                    print('cannot get ascent, descent')
                    flag_success = False
                    break

                draw_pos_x = add_to_x
                baseline_shift = int((opts.img_size / 24.0) * (10.0 / 3.0))
                if char in thai_characters_long:
                    draw_pos_y = add_to_y + opts.img_size - ascent - descent - baseline_shift
                else:
                    draw_pos_y = add_to_y + opts.img_size - ascent - baseline_shift

                draw.text((draw_pos_x, draw_pos_y), char, (0), font=font)

                if opts.debug:
                    image.save(os.path.join(font_dir, str(charid) + '_' + str(opts.img_size) + '.png'))

                try:
                    char_w, char_h = get_bbox(image)
                except Exception as e:
                    print("cannot get bbox")
                    print(e)
                    flag_success = False
                    break

                # Detect large font
                problem = []
                if font_width > 59:
                    problem.append("width")

                if font_height > 93:
                    problem.append("height")

                if problem:
                    print(problem,fontname, charid, font_width, font_height, char_w, char_h)
                    flag_success = False
                    break

                # Detect Small Font
                if (char_w < opts.img_size * 0.15) and (char_h < opts.img_size * 0.15):
                    flag_success = False
                    break

                fontimgs_array[charid] = np.array(image)

            if flag_success:
                np.save(os.path.join(font_dir, 'imgs_' + str(opts.img_size) + '.npy'), fontimgs_array)
            else:
                char_error += 1  # Count char flag not success
                print("flag on", fontname, charid, 'imgs_' + str(opts.img_size) + '.npy', " Not Succeed")

    processes = [mp.Process(target=process, args=(pid, font_num_per_process)) for pid in range(process_nums)]

    for p in processes:
        p.start()
    for p in processes:
        p.join()
|
164 |
+
|
165 |
+
def main():
    """Parse the command line and render the glyph images."""

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn any non-empty string, including "False",
        into True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Write glyph images")
    parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
    parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
    parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
    parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
    parser.add_argument('--img_size', type=int, default=64)
    parser.add_argument('--split', type=str, default='train')
    parser.add_argument('--FONT_SIZE', type=float, default=1)
    parser.add_argument('--debug', type=_str2bool, default=False)
    opts = parser.parse_args()
    write_glyph_imgs_mp(opts)


if __name__ == "__main__":
    main()
|
dataloader.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# data loader for training main model
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import torch
|
5 |
+
import torch.utils.data as data
|
6 |
+
import torchvision.transforms as T
|
7 |
+
import sys
|
8 |
+
import numpy as np
|
9 |
+
# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# with many DataLoader workers — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
|
10 |
+
|
11 |
+
|
12 |
+
class SVGDataset(data.Dataset):
    """Dataset with one item per font directory under ``root_path/lang/mode``.

    Each font directory is expected to hold 'class.npy', 'seq_len.npy',
    'sequence_relaxed.npy', 'pts_aux.npy', 'rendered_<img_size>.npy' and
    'font_id.npy'.
    """

    def __init__(self, root_path, img_size=128, lang='eng', char_num=52, max_seq_len=51, dim_seq=10, transform=None, mode='train'):
        """Collect the font directories found directly under the split dir.

        Args:
            root_path: dataset root; fonts are read from root_path/lang/mode.
            img_size: side length of the rendered glyph images.
            lang: language sub-directory name.
            char_num: number of characters per font.
            max_seq_len: number of steps each sequence is reshaped to.
            dim_seq: number of features per sequence step.
            transform: optional callable applied to the rendered images;
                None leaves them unchanged.
            mode: split sub-directory name.
        """
        super().__init__()
        self.mode = mode
        self.img_size = img_size
        self.char_num = char_num
        self.max_seq_len = max_seq_len
        self.dim_seq = dim_seq
        self.trans = transform
        self.font_paths = []
        self.dir_path = os.path.join(root_path, lang, self.mode)
        # Only the first level of the walk is used. This replaces a depth
        # check that counted '/' and so depended on the path separator.
        for _root, dirs, _files in os.walk(self.dir_path):
            self.font_paths.extend(os.path.join(self.dir_path, dir_name) for dir_name in dirs)
            break
        self.font_paths.sort()
        print(f"Finished loading {mode} paths, number: {str(len(self.font_paths))}")

    def __getitem__(self, index):
        """Load all arrays of the font at ``index`` into a dict of tensors."""
        font_path = self.font_paths[index]

        def load(fname):
            return np.load(os.path.join(font_path, fname))

        item = {}
        item['class'] = torch.LongTensor(load('class.npy'))
        item['seq_len'] = torch.LongTensor(load('seq_len.npy'))
        item['sequence'] = torch.FloatTensor(load('sequence_relaxed.npy')).view(self.char_num, self.max_seq_len, self.dim_seq)
        item['pts_aux'] = torch.FloatTensor(load('pts_aux.npy'))
        # Scale the stored pixel values by 1/255.
        item['rendered'] = torch.FloatTensor(load('rendered_' + str(self.img_size) + '.npy')).view(self.char_num, self.img_size, self.img_size) / 255.
        # transform defaults to None; calling it unconditionally raised TypeError.
        if self.trans is not None:
            item['rendered'] = self.trans(item['rendered'])
        item['font_id'] = torch.FloatTensor(load('font_id.npy').astype(np.float32))
        return item

    def __len__(self):
        """Return the number of font directories found."""
        return len(self.font_paths)
|
46 |
+
|
47 |
+
|
48 |
+
def get_loader(root_path, img_size, lang, char_num, max_seq_len, dim_seq, batch_size, mode='train'):
    """Build a DataLoader over an SVGDataset for the given split.

    The rendered images are passed through ``1 - X``. Batches are shuffled
    only when ``mode`` is 'train', and the number of worker processes equals
    ``batch_size``.
    """
    invert = T.Lambda(lambda X: 1. - X)  # stays within [0, 1], intensities flipped
    pipeline = T.Compose([invert])
    font_dataset = SVGDataset(root_path, img_size, lang, char_num, max_seq_len, dim_seq, pipeline, mode)
    shuffle_batches = (mode == 'train')
    return data.DataLoader(font_dataset, batch_size, shuffle=shuffle_batches, num_workers=batch_size)
|
54 |
+
|
55 |
+
if __name__ == '__main__':
    # Manual check: iterate the training loader once and record every font id.
    root_path = 'data/new_data'
    # NOTE(review): img_size and lang were missing from the get_loader call,
    # which therefore raised TypeError. The values below are the SVGDataset
    # defaults — confirm they match the data under root_path.
    img_size = 128
    lang = 'eng'
    max_seq_len = 51
    dim_seq = 10
    batch_size = 1
    char_num = 52

    loader = get_loader(root_path, img_size, lang, char_num, max_seq_len, dim_seq, batch_size, 'train')
    with open('train_id_record_old.txt', 'w') as fout:
        for batch in loader:
            binary_fp = batch['font_id'].numpy()[0][0]
            fout.write("%05d" % int(binary_fp) + '\n')
|
67 |
+
|
font_sample/Athiti-Regular.ttf
ADDED
Binary file (187 kB). View file
|
|
font_sample/SaoChingcha-Bold.otf
ADDED
Binary file (39.6 kB). View file
|
|
font_sample/SaoChingcha-Light.otf
ADDED
Binary file (38 kB). View file
|
|
font_sample/SaoChingcha-Regular.otf
ADDED
Binary file (38.1 kB). View file
|
|
generate.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fontTools
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import typing
|
5 |
+
import PIL
|
6 |
+
from PIL import Image, ImageDraw, ImageFont
|
7 |
+
from data_utils.convert_ttf_to_sfd import convert_mp
|
8 |
+
from data_utils.write_glyph_imgs import write_glyph_imgs_mp
|
9 |
+
from data_utils.write_data_to_dirs import create_db
|
10 |
+
from data_utils.relax_rep import relax_rep
|
11 |
+
from test_few_shot import test_main_model
|
12 |
+
from options import get_parser_main_model
|
13 |
+
|
14 |
+
# Parse the main-model options from the command line at import time. The
# assignments below then overwrite them with fixed inference settings.
opts = get_parser_main_model().parse_args()

# Config on opts
# Inference opts
opts.mode = "test"
opts.language = "tha"
# presumably the number of Thai consonants written to the char set by
# preprocessing() — verify against the generated tha.txt
opts.char_num = 44
opts.ref_nshot = 8
opts.batch_size = 1 # inference rule
opts.img_size = 64
opts.max_seq_len = 121
opts.name_ckpt = ""
opts.model_path = "./inference_model/950_49452.ckpt"
opts.ref_char_ids = "0,1,2,3,4,5,6,7"
opts.dir_res = "./inference"
opts.data_root = "./inference/vecfont_dataset/"

# Data preprocessing opts
# All preprocessing paths below are derived from this working directory.
opts.data_path = './inference'
opts.sfd_path = f'{opts.data_path}/font_sfds'
opts.ttf_path = f'{opts.data_path}/font_ttfs'
opts.split = "test"
opts.debug = True # Save Image On write_glyph_imgs_mp
opts.output_path = f'{opts.data_path}/vecfont_dataset/'
opts.phase = 0
opts.FONT_SIZE = 1

opts.streamlit = True
|
42 |
+
|
43 |
+
# Glyph IDs:
|
44 |
+
# [(0, 'A'), (1, 'B'), (2, 'C'), (3, 'D'), (4, 'E')]
|
45 |
+
# [(5, 'F'), (6, 'G'), (7, 'H'), (8, 'I'), (9, 'J')]
|
46 |
+
# [(10, 'K'), (11, 'L'), (12, 'M'), (13, 'N'), (14, 'O')]
|
47 |
+
# [(15, 'P'), (16, 'Q'), (17, 'R'), (18, 'S'), (19, 'T')]
|
48 |
+
# [(20, 'U'), (21, 'V'), (22, 'W'), (23, 'X'), (24, 'Y')]
|
49 |
+
# [(25, 'Z'), (26, 'a'), (27, 'b'), (28, 'c'), (29, 'd')]
|
50 |
+
# [(30, 'e'), (31, 'f'), (32, 'g'), (33, 'h'), (34, 'i')]
|
51 |
+
# [(35, 'j'), (36, 'k'), (37, 'l'), (38, 'm'), (39, 'n')]
|
52 |
+
# [(40, 'o'), (41, 'p'), (42, 'q'), (43, 'r'), (44, 's')]
|
53 |
+
# [(45, 't'), (46, 'u'), (47, 'v'), (48, 'w'), (49, 'x')]
|
54 |
+
# [(50, 'y'), (51, 'z'), (52, 'ก'), (53, 'ข'), (54, 'ฃ')]
|
55 |
+
# [(55, 'ค'), (56, 'ฅ'), (57, 'ฆ'), (58, 'ง'), (59, 'จ')]
|
56 |
+
# [(60, 'ฉ'), (61, 'ช'), (62, 'ซ'), (63, 'ฌ'), (64, 'ญ')]
|
57 |
+
# [(65, 'ฎ'), (66, 'ฏ'), (67, 'ฐ'), (68, 'ฑ'), (69, 'ฒ')]
|
58 |
+
# [(70, 'ณ'), (71, 'ด'), (72, 'ต'), (73, 'ถ'), (74, 'ท')]
|
59 |
+
# [(75, 'ธ'), (76, 'น'), (77, 'บ'), (78, 'ป'), (79, 'ผ')]
|
60 |
+
# [(80, 'ฝ'), (81, 'พ'), (82, 'ฟ'), (83, 'ภ'), (84, 'ม')]
|
61 |
+
# [(85, 'ย'), (86, 'ร'), (87, 'ล'), (88, 'ว'), (89, 'ศ')]
|
62 |
+
# [(90, 'ษ'), (91, 'ส'), (92, 'ห'), (93, 'ฬ'), (94, 'อ')]
|
63 |
+
# [(95, 'ฮ')]
|
64 |
+
|
65 |
+
import string
|
66 |
+
import pythainlp
|
67 |
+
|
68 |
+
# Character groups unpacked into lists of single characters. Of these, only
# thai_characters is used by preprocessing() below to build the char set.
thai_digits = [*pythainlp.thai_digits]
thai_characters = [*pythainlp.thai_consonants]
eng_characters = [*string.ascii_letters]
thai_floating = [*pythainlp.thai_vowels]

# Working directories that preprocessing() recreates on every run.
directories = [
    "inference",
    "inference/char_set",
    "inference/font_sfds",
    "inference/font_ttfs",
    "inference/vecfont_dataset",
    "inference/font_ttfs/tha/test",
]
|
81 |
+
|
82 |
+
|
83 |
+
# Data Preprocessing
|
84 |
+
def preprocessing(ttf_file) -> str:
    """Rebuild the inference workspace from one font and run the data pipeline.

    The "inference" directory is wiped and recreated, the font is stored as
    ``font_ttfs/tha/test/0000.ttf``, the char set file is written, and then
    the conversion, rendering, database and relaxation steps are run in order.

    Args:
        ttf_file: either the raw font bytes (bytes, bytearray or memoryview)
            or a path to a font file.

    Returns:
        The directory the processed dataset was written to.

    Raises:
        TypeError: if ``ttf_file`` is neither font bytes nor a path.
    """
    is_raw_bytes = isinstance(ttf_file, (bytes, bytearray, memoryview))
    is_path = isinstance(ttf_file, (str, os.PathLike))
    # Reject unsupported input before the workspace is wiped; previously the
    # pipeline silently ran without any font.
    if not (is_raw_bytes or is_path):
        raise TypeError(f"ttf_file must be font bytes or a path, got {type(ttf_file)}")

    # The directory does not exist on the first run, so only remove it if present.
    if os.path.isdir("inference"):
        shutil.rmtree("inference")
    for directory in directories:
        os.makedirs(directory, exist_ok=True)

    # Save File / Copy File
    font_dst = f"{opts.data_path}/font_ttfs/tha/test/0000.ttf"
    if is_raw_bytes:
        with open(font_dst, 'wb') as f:
            f.write(ttf_file)
    else:
        shutil.copy(ttf_file, font_dst)

    glypts = sorted(set(thai_characters))
    print("Glypts:",len(glypts))
    print("".join(glypts))
    with open("inference/char_set/tha.txt", "w") as f:
        f.write("".join(glypts))

    # Preprocess Pipeline
    convert_mp(opts)
    write_glyph_imgs_mp(opts)
    output_path = os.path.join(opts.output_path, opts.language, opts.split)
    log_path = os.path.join(opts.sfd_path, opts.language, 'log')
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)
    create_db(opts, output_path, log_path)
    relax_rep(opts)

    print("Finished making a data", ttf_file)
    print("Saved at", output_path)
    return output_path
|
118 |
+
|
119 |
+
def inference_model(n_samples, ref_char_ids, version):
    """Run the few-shot test with the checkpoint selected by ``version``.

    Stores ``n_samples`` and ``ref_char_ids`` on the module-level options,
    picks the checkpoint path for "TH2TH" or "ENG2TH", and returns the
    result of ``test_main_model``.

    Raises:
        NotImplementedError: for any other ``version``.
    """
    opts.n_samples = n_samples
    opts.ref_char_ids = ref_char_ids

    # Select Model: both known versions currently map to the same checkpoint.
    known_checkpoints = (
        ("TH2TH", "./inference_model/950_49452.ckpt"),
        ("ENG2TH", "./inference_model/950_49452.ckpt"),
    )
    for known_version, ckpt_path in known_checkpoints:
        if version == known_version:
            opts.model_path = ckpt_path
            break
    else:
        raise NotImplementedError

    return test_main_model(opts)
|
132 |
+
|
133 |
+
def ttf_to_image(ttf_file, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
    """Preprocess ``ttf_file`` and return the result of running inference on it.

    ``ttf_file`` is handed to ``preprocessing``; the remaining arguments are
    forwarded unchanged to ``inference_model``.
    """
    preprocessing(ttf_file)  # Make Data
    return inference_model(n_samples, ref_char_ids, version)  # Inference
|
137 |
+
|
138 |
+
def main():
    """Print ``opts.mode`` and run the full pipeline on the bundled sample font."""
    print(opts.mode)
    sample_font = "font_sample/SaoChingcha-Regular.otf"
    ttf_to_image(sample_font)
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
main()
|
models/__init__.py
ADDED
File without changes
|
models/image_decoder.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
class ImageDecoder(nn.Module):
    """Transposed-convolution decoder that grows a 1x1 feature map into an image.

    ``log2(img_size)`` stride-2 ``ConvTranspose2d`` stages double the spatial
    size each time (1 -> 2 -> 4 -> ...); each stage is followed by
    ``norm_layer`` and ReLU. A final 7x7 convolution plus sigmoid produces
    ``output_nc`` channels.

    Args:
        img_size: side length of the generated image; the number of stages is
            ``int(log2(img_size))``.
        input_nc: channel count of the 1x1 input (latent size + char one-hot size).
        output_nc: channel count of the generated image.
        ngf: base channel width; stage ``i`` outputs ``2 ** (n - i - 1) * ngf`` channels.
        norm_layer: normalisation constructor called with ``[C, H, W]``.
    """

    def __init__(self, img_size, input_nc, output_nc, ngf=16, norm_layer=nn.LayerNorm):
        super().__init__()
        # math.log2 is the accurate way to take a base-2 logarithm;
        # math.log(x, 2) divides two rounded floats before int() truncates.
        n_upsampling = int(math.log2(img_size))
        n_small = n_upsampling // 3
        ks_list = [3] * n_small + [5] * (n_upsampling - n_small)
        stride_list = [2] * n_upsampling
        chn_mult = [2 ** (n_upsampling - i - 1) for i in range(n_upsampling)]

        # NOTE: the module order fixes the parameter names inside self.decode
        # (decode.0, decode.1, ...) and therefore checkpoint compatibility.
        decoder = []
        chn_prev = input_nc
        for i in range(n_upsampling):  # add upsampling layers
            chn_next = chn_mult[i] * ngf
            side = 2 ** (i + 1)  # spatial size after this stage
            decoder += [nn.ConvTranspose2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i],
                                           padding=ks_list[i] // 2, output_padding=stride_list[i] - 1),
                        norm_layer([chn_next, side, side]),
                        nn.ReLU(True)]
            chn_prev = chn_next

        decoder += [nn.Conv2d(chn_mult[-1] * ngf, output_nc, kernel_size=7, padding=7 // 2)]
        decoder += [nn.Sigmoid()]
        self.decode = nn.Sequential(*decoder)

    def forward(self, latent_feat, trg_char, trg_img=None):
        """Decode ``latent_feat`` conditioned on ``trg_char``.

        The two inputs are concatenated on the last axis and reshaped to
        ``(batch, channels, 1, 1)`` before decoding.

        Returns:
            dict with ``'gen_imgs'`` and, only when ``trg_img`` is given,
            ``'img_l1loss'`` (L1 loss between the output and ``trg_img``).
        """
        dec_input = torch.cat((latent_feat, trg_char), -1)
        dec_input = dec_input.view(dec_input.size(0), dec_input.size(1), 1, 1)
        dec_out = self.decode(dec_input)
        output = {}
        output['gen_imgs'] = dec_out
        if trg_img is not None:
            output['img_l1loss'] = F.l1_loss(dec_out, trg_img)

        return output
|
47 |
+
|
48 |
+
|
models/image_encoder.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
class ImageEncoder(nn.Module):
    """Convolutional encoder that shrinks an image to a flat feature vector.

    A 7x7 convolution is followed by ``log2(img_size)`` stride-2 convolutions
    that halve the spatial size each time while doubling the channel count;
    every convolution is followed by ``norm_layer`` and ReLU. The result is
    flattened.

    Args:
        img_size: side length of the input image; the number of downsampling
            stages is ``int(log2(img_size))``.
        input_nc: number of input channels.
        ngf: base channel width.
        norm_layer: normalisation constructor called with ``[C, H, W]``.
    """

    def __init__(self, img_size, input_nc, ngf=16, norm_layer=nn.LayerNorm):
        super().__init__()
        # math.log2 is the accurate way to take a base-2 logarithm;
        # math.log(x, 2) divides two rounded floats before int() truncates.
        n_downsampling = int(math.log2(img_size))
        n_small = n_downsampling // 3
        ks_list = [5] * (n_downsampling - n_small) + [3] * n_small
        stride_list = [2] * n_downsampling
        chn_mult = [2 ** (i + 1) for i in range(n_downsampling)]

        # NOTE: the module order fixes the parameter names inside self.encode
        # (encode.0, encode.1, ...) and therefore checkpoint compatibility.
        full_side = 2 ** n_downsampling
        encoder = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=7 // 2, bias=True, padding_mode='replicate'),
                   norm_layer([ngf, full_side, full_side]),
                   nn.ReLU(True)]
        chn_prev = ngf
        for i in range(n_downsampling):  # add downsampling layers
            chn_next = ngf * chn_mult[i]
            side = 2 ** (n_downsampling - 1 - i)  # spatial size after this stage
            encoder += [nn.Conv2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i],
                                  padding=ks_list[i] // 2, padding_mode='replicate'),
                        norm_layer([chn_next, side, side]),
                        nn.ReLU(True)]
            chn_prev = chn_next

        self.encode = nn.Sequential(*encoder)
        self.flatten = nn.Flatten()

    def forward(self, input):
        """Encode ``input`` and return ``{'img_feat': flattened features}``."""
        ret = self.encode(input)
        img_feat = self.flatten(ret)
        output = {}
        output['img_feat'] = img_feat
        return output
|
models/modality_fusion.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
from options import get_parser_main_model
|
6 |
+
opts = get_parser_main_model().parse_args()
|
7 |
+
|
8 |
+
def init_weights(m):
    """Overwrite every parameter of ``m`` in place with samples from U(-0.08, 0.08)."""
    low, high = -0.08, 0.08
    for _, weight in m.named_parameters():
        nn.init.uniform_(weight.data, low, high)
|
11 |
+
|
12 |
+
|
13 |
+
class ModalityFusion(nn.Module):
    """Fuse image features with sequence features into a latent code.

    Per-reference sequence features are merged across the ``ref_nshot``
    references by ``fc_merge``. The first merged token is concatenated with
    the image feature and mapped by ``fc_fusion`` to ``2 * bottleneck_bits``
    values, split into ``mu`` and ``log_sigma``.

    Args:
        img_size: image side length; ``fc_fusion`` expects an image feature of
            ``ngf * 2 ** int(log2(img_size))`` values.
        ref_nshot: number of reference glyphs per sample.
        bottleneck_bits: size of the latent code. It must equal
            ``seq_latent_dim`` because the latent is written back into the
            first token of the merged sequence.
        ngf: base channel width of the image encoder.
        seq_latent_dim: feature size of each sequence token.
        mode: ``'train'`` samples the latent and computes the KL term; any
            other value uses ``mu`` directly.
    """

    def __init__(self, img_size=64, ref_nshot=4, bottleneck_bits=512, ngf=32, seq_latent_dim=512, mode='train'):
        super().__init__()
        self.mode = mode
        self.bottleneck_bits = bottleneck_bits
        self.ref_nshot = ref_nshot
        # Layer sizes come from the constructor arguments (not from the global
        # options) so they always agree with what forward() slices and reshapes.
        # With the default seq_latent_dim=512 the shapes match the previous
        # hard-coded 512 output.
        self.fc_merge = nn.Linear(seq_latent_dim * ref_nshot, seq_latent_dim)
        n_downsampling = int(math.log2(img_size))
        mult_max = 2 ** n_downsampling  # the max multiplier for img feat channels
        self.fc_fusion = nn.Linear(ngf * mult_max + seq_latent_dim, bottleneck_bits * 2, bias=True)

    def forward(self, seq_feat, img_feat, ref_pad_mask=None):
        """Return ``(output, latent_feat_seq)``.

        Args:
            seq_feat: sequence features whose first axis is
                ``batch * ref_nshot`` and whose last two axes are
                (tokens, features).
            img_feat: image features, concatenated with the merged first token.
            ref_pad_mask: optional mask; a column of ones is prepended on the
                last axis (for the leading token) and the transposed mask
                multiplies ``seq_feat``. ``None`` skips the masking.

        Returns:
            ``output`` with keys ``'latent'`` and ``'kl_loss'`` (``0.0``
            outside train mode), and the merged sequence whose first token has
            been replaced by the latent.
        """
        if ref_pad_mask is not None:
            cls_one_pad = torch.ones((1, 1, 1)).to(seq_feat.device).repeat(seq_feat.size(0), 1, 1)
            ref_pad_mask = torch.cat([cls_one_pad, ref_pad_mask], dim=-1)
            seq_feat = seq_feat * (ref_pad_mask.transpose(1, 2))

        # (batch * nshot, tokens, feat) -> (batch, tokens, nshot * feat)
        seq_feat_ = seq_feat.view(seq_feat.size(0) // self.ref_nshot, self.ref_nshot, seq_feat.size(-2), seq_feat.size(-1))
        seq_feat_ = seq_feat_.transpose(1, 2)
        seq_feat_ = seq_feat_.contiguous().view(seq_feat_.size(0), seq_feat_.size(1), seq_feat_.size(2) * seq_feat_.size(3))
        seq_feat_ = self.fc_merge(seq_feat_)
        seq_feat_cls = seq_feat_[:, 0]

        feat_cat = torch.cat((img_feat, seq_feat_cls), -1)
        dist_param = self.fc_fusion(feat_cat)

        output = {}
        mu = dist_param[..., :self.bottleneck_bits]
        log_sigma = dist_param[..., self.bottleneck_bits:]

        if self.mode == 'train':
            # calculate the kl loss and reparameterize the latent code
            epsilon = torch.randn(*mu.size(), device=mu.device)
            z = mu + torch.exp(log_sigma / 2) * epsilon
            kl = 0.5 * torch.mean(torch.exp(log_sigma) + torch.square(mu) - 1. - log_sigma)
            output['latent'] = z
            output['kl_loss'] = kl
            seq_feat_[:, 0] = z
        else:
            output['latent'] = mu
            output['kl_loss'] = 0.0
            seq_feat_[:, 0] = mu
        latent_feat_seq = seq_feat_

        return output, latent_feat_seq
|
63 |
+
|
64 |
+
|
models/model_main.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.image_encoder import ImageEncoder
|
2 |
+
from models.image_decoder import ImageDecoder
|
3 |
+
from models.modality_fusion import ModalityFusion
|
4 |
+
from models.vgg_perceptual_loss import VGGPerceptualLoss
|
5 |
+
from models.transformers import *
|
6 |
+
from torch.autograd import Variable
|
7 |
+
|
8 |
+
class ModelMain(nn.Module):
    """Top-level model combining image and sequence branches.

    Holds an image encoder/decoder pair, a modality-fusion module, a main
    ``Transformer`` for the reference glyphs and a ``Transformer_decoder`` for
    the target sequence.

    NOTE(review): ``nn``, ``torch``, ``F``, ``util_funcs``, ``numericalize``,
    ``denumericalize``, ``subsequent_mask``, ``Transformer`` and the bare
    ``opts`` used in ``fetch_data`` are not imported explicitly in this module;
    presumably they arrive through ``from models.transformers import *`` -
    confirm.
    """

    def __init__(self, opts, mode='train'):
        """Build all sub-modules from ``opts``.

        ``mode`` is accepted but not read here; the fusion module is given
        ``opts.mode`` instead.
        """
        super().__init__()
        self.opts = opts
        self.img_encoder = ImageEncoder(img_size=opts.img_size, input_nc=opts.ref_nshot, ngf=opts.ngf, norm_layer=nn.LayerNorm)
        self.img_decoder = ImageDecoder(img_size=opts.img_size, input_nc=opts.bottleneck_bits + opts.char_num, output_nc=1, ngf=opts.ngf, norm_layer=nn.LayerNorm)
        self.vggptlossfunc = VGGPerceptualLoss()
        self.modality_fusion = ModalityFusion(img_size=opts.img_size, ref_nshot=opts.ref_nshot, bottleneck_bits=opts.bottleneck_bits, ngf=opts.ngf, mode=opts.mode)
        self.transformer_main = Transformer(
            input_channels = 1,
            input_axis = 2, # number of axis for input data (2 for images, 3 for video)
            num_freq_bands = 6, # number of freq bands, with original value (2 * K + 1)
            max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is
            depth = 6, # depth of net. The shape of the final attention mechanism will be:
            # depth * (cross attention -> self_per_cross_attn * self attention)
            num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names
            latent_dim = opts.dim_seq_latent, # latent dimension
            cross_heads = 1, # number of heads for cross attention. paper said 1
            latent_heads = 8, # number of heads for latent self attention, 8
            cross_dim_head = 64, # number of dimensions per cross attention head
            latent_dim_head = 64, # number of dimensions per latent self attention head
            num_classes = 1000, # output number of classes
            attn_dropout = 0.,
            ff_dropout = 0.,
            weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram)
            fourier_encode_data = True, # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
            self_per_cross_attn = 2 # number of self attention blocks per cross attention
        )

        self.transformer_seqdec = Transformer_decoder()

    def forward(self, data, mode='train'):
        """Run the model on one batch.

        Args:
            data: dict consumed by ``fetch_data``.
            mode: ``'train'`` / ``'val'`` compute losses with teacher forcing;
                any other value decodes the target sequence greedily.

        Returns:
            ``(ret_dict, loss_dict)``. ``ret_dict['img']`` always holds
            ``'out'``, ``'ref'`` and ``'trg'``. In train/val mode ``loss_dict``
            holds ``'img'`` (``'l1'``, ``'vggpt'``), ``'kl'``, ``'svg'`` and
            ``'svg_para'``; otherwise ``loss_dict`` stays empty and
            ``ret_dict['svg']`` holds ``'sampled_1'``, ``'sampled_2'`` and ``'trg'``.
        """

        imgs, seqs, scalars = self.fetch_data(data, mode)
        ref_img, trg_img = imgs
        ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux = seqs
        trg_char_onehot, trg_cls, trg_seqlen = scalars

        # image encoding
        img_encoder_out = self.img_encoder(ref_img)
        img_feat = img_encoder_out['img_feat'] # bs, ngf * (2 ** 6)

        # seq encoding
        # merge the first two axes of ref_img and append a trailing axis of size 1
        ref_img_ = ref_img.view(ref_img.size(0) * ref_img.size(1), ref_img.size(2), ref_img.size(3)).unsqueeze(-1)
        # original shape note: [n_bs * n_ref, max_seq_len + 1, 9] - TODO confirm against Transformer.forward
        seq_feat, _ = self.transformer_main(ref_img_, ref_seq_cat, mask=ref_pad_mask)

        # modality fusion
        mf_output, latent_feat_seq = self.modality_fusion(seq_feat, img_feat, ref_pad_mask=ref_pad_mask)
        latent_feat_seq = self.transformer_main.att_residual(latent_feat_seq) # [n_bs, max_seq_len + 1, bottleneck_bits]
        z = mf_output['latent']
        kl_loss = mf_output['kl_loss']

        # image decoding
        img_decoder_out = self.img_decoder(z, trg_char_onehot, trg_img)

        ret_dict = {}
        loss_dict = {}

        ret_dict['img'] = {}
        ret_dict['img']['out'] = img_decoder_out['gen_imgs']
        ret_dict['img']['ref'] = ref_img
        ret_dict['img']['trg'] = trg_img

        if mode in {'train', 'val'}:
            # seq decoding (training or val mode)
            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).cuda().float()
            command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
            # second, parallel refinement pass; the memory is detached for it
            command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)

            total_loss = self.transformer_main.loss(command_logits, args_logits,trg_seq, trg_seqlen, trg_pts_aux)
            total_loss_parallel = self.transformer_main.loss(command_logits_2, args_logits_2, trg_seq, trg_seqlen, trg_pts_aux)
            vggpt_loss = self.vggptlossfunc(img_decoder_out['gen_imgs'], trg_img)
            # loss and output
            loss_svg_items = ['total', 'cmd', 'args', 'smt', 'aux']
            # for image
            loss_dict['img'] = {}
            loss_dict['img']['l1'] = img_decoder_out['img_l1loss']
            loss_dict['img']['vggpt'] = vggpt_loss['pt_c_loss']
            # for latent
            loss_dict['kl'] = kl_loss
            # for svg
            loss_dict['svg'] = {}
            loss_dict['svg_para'] = {}
            for item in loss_svg_items:
                loss_dict['svg'][item] = total_loss[f'loss_{item}']
                loss_dict['svg_para'][item] = total_loss_parallel[f'loss_{item}']

        else: # testing (inference)

            trg_len = trg_seq_shifted.size(0)
            # start from a single all-zero step; it is dropped after the loop
            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).cuda()

            # greedy autoregressive decoding: append the argmax prediction each step
            for t in range(0, trg_len):
                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).cuda().float()
                command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
                prob_comand = F.softmax(command_logits[:, -1, :], -1)
                prob_args = F.softmax(args_logits[:, -1, :], -1)
                next_command = torch.argmax(prob_comand, -1).unsqueeze(-1)
                next_args = torch.argmax(prob_args, -1)
                predict_tmp = torch.cat((next_command, next_args),-1).unsqueeze(1).transpose(0,1)
                sampled_svg = torch.cat((sampled_svg, predict_tmp), dim=0)

            sampled_svg = sampled_svg[1:]
            cmd2 = sampled_svg[:,:,0].unsqueeze(-1)
            arg2 = sampled_svg[:,:,1:]
            # refine the sampled sequence with the parallel decoder
            command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(cmd_logits=cmd2, args_logits=arg2, memory=latent_feat_seq, trg_char=trg_cls)
            prob_comand = F.softmax(command_logits_2,-1)
            prob_args = F.softmax(args_logits_2,-1)
            update_command = torch.argmax(prob_comand,-1).unsqueeze(-1)
            update_args = torch.argmax(prob_args,-1)

            sampled_svg_parralel = torch.cat((update_command, update_args),-1).transpose(0,1)

            # convert both samples to one-hot commands + de-quantised arguments
            commands1 = F.one_hot(sampled_svg[:,:,:1].long(), 4).squeeze().transpose(0, 1)
            args1 = denumericalize(sampled_svg[:,:,1:]).transpose(0,1)
            sampled_svg_1 = torch.cat([commands1.cpu().detach(),args1[:, :, 2:].cpu().detach()],dim =-1)


            commands2 = F.one_hot(sampled_svg_parralel[:, :, :1].long(), 4).squeeze().transpose(0, 1)
            args2 = denumericalize(sampled_svg_parralel[:, :, 1:]).transpose(0,1)
            sampled_svg_2 = torch.cat([commands2.cpu().detach(),args2[:, :, 2:].cpu().detach()], dim =-1)

            ret_dict['svg'] = {}
            ret_dict['svg']['sampled_1'] = sampled_svg_1
            ret_dict['svg']['sampled_2'] = sampled_svg_2
            ret_dict['svg']['trg'] = trg_seq_gt

        return ret_dict, loss_dict

    def fetch_data(self, data, mode):
        """Select reference and target glyphs from ``data`` and prepare tensors.

        Reads ``data['rendered']``, ``data['sequence']``, ``data['seq_len']``
        and ``data['pts_aux']``. Reference classes are random in train mode,
        the first ``ref_nshot`` classes in val mode and ``opts.ref_char_ids``
        otherwise. All index tensors are moved to CUDA.

        Returns:
            ``[ref_img, trg_img]``,
            ``[ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux]``,
            ``[trg_char_onehot, trg_cls, trg_seqlen]``.
        """

        input_image = data['rendered'] # [bs, opts.char_num, opts.img_size, opts.img_size]
        input_sequence = data['sequence'] # [bs, opts.char_num, opts.max_seq_len]
        input_seqlen = data['seq_len']
        input_seqlen = input_seqlen + 1
        input_pts_aux = data['pts_aux']
        # last axis: the first 4 entries are the command (argmax -> class id),
        # the remaining entries are arguments that get quantised
        arg_quant = numericalize(input_sequence[:, :, :, 4:])
        cmd_cls = torch.argmax(input_sequence[:, :, :, :4], dim=-1).unsqueeze(-1)
        input_sequence = torch.cat([cmd_cls, arg_quant], dim=-1) # 1 + 8 = 9 dimension

        # choose reference classes and target classes


        if mode == 'train':
            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).cuda()
            if opts.ref_nshot == 52: # For ENG to TH
                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
                ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1)
        elif mode == 'val':
            ref_cls = torch.arange(0, self.opts.ref_nshot, 1).cuda().unsqueeze(0).expand(input_image.size(0), -1)
        else:
            ref_ids = self.opts.ref_char_ids.split(',')
            ref_ids = list(map(int, ref_ids))
            assert len(ref_ids) == self.opts.ref_nshot
            ref_cls = torch.tensor(ref_ids).cuda().unsqueeze(0).expand(self.opts.char_num, -1)



        if mode in {'train', 'val'}:
            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).cuda()
            if opts.ref_nshot == 52:
                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
        else:
            # inference: one target per character class, inputs expanded to match
            trg_cls = torch.arange(0, self.opts.char_num).cuda()
            if opts.ref_nshot == 52:
                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
            trg_cls = trg_cls.view(self.opts.char_num, 1)
            input_image = input_image.expand(self.opts.char_num, -1, -1, -1)
            input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1)
            input_pts_aux = input_pts_aux.expand(self.opts.char_num, -1, -1, -1)
            input_seqlen = input_seqlen.expand(self.opts.char_num, -1, -1)

        ref_img = util_funcs.select_imgs(input_image, ref_cls, self.opts)
        # select a target glyph image
        trg_img = util_funcs.select_imgs(input_image, trg_cls, self.opts)
        # randomly select ref vector glyphs
        ref_seq = util_funcs.select_seqs(input_sequence, ref_cls, self.opts, self.opts.dim_seq_short) # [opts.batch_size, opts.ref_nshot, opts.max_seq_len, opts.dim_seq_nmr]
        # randomly select a target vector glyph
        trg_seq = util_funcs.select_seqs(input_sequence, trg_cls, self.opts, self.opts.dim_seq_short)
        trg_seq = trg_seq.squeeze(1)
        trg_pts_aux = util_funcs.select_seqs(input_pts_aux, trg_cls, self.opts, opts.n_aux_pts)
        trg_pts_aux = trg_pts_aux.squeeze(1)
        # the one-hot target char class
        trg_char_onehot = util_funcs.trgcls_to_onehot(trg_cls, self.opts)
        # shift target sequence
        trg_seq_gt = trg_seq.clone().detach()
        # ground truth keeps column 0 and columns 3 onward (columns 1-2 are dropped)
        trg_seq_gt = torch.cat((trg_seq_gt[:, :, :1], trg_seq_gt[:, :, 3:]), -1)
        trg_seq = trg_seq.transpose(0, 1)
        trg_seq_shifted = util_funcs.shift_right(trg_seq)

        ref_seq_cat = ref_seq.view(ref_seq.size(0) * ref_seq.size(1), ref_seq.size(2), ref_seq.size(3))
        ref_seq_cat = ref_seq_cat.transpose(0,1)
        ref_seqlen = util_funcs.select_seqlens(input_seqlen, ref_cls, self.opts)
        ref_seqlen_cat = ref_seqlen.view(ref_seqlen.size(0) * ref_seqlen.size(1), ref_seqlen.size(2))
        # 1 marks the first ref_seqlen positions of each row, 0 marks the rest
        ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len)
        for i in range(ref_seqlen_cat.size(0)):
            ref_pad_mask[i,:ref_seqlen_cat[i]] = 1
        ref_pad_mask = ref_pad_mask.cuda().float().unsqueeze(1)
        trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts)
        trg_seqlen = trg_seqlen.squeeze()

        return [ref_img, trg_img], [ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux], [trg_char_onehot, trg_cls, trg_seqlen]
|
models/pos_enc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    NOTE(review): no import lines are visible in this module; ``math``,
    ``torch`` and ``nn`` must be in scope for it to load - confirm.

    Args:
        d_model: feature size of the encoded tensor (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # [ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one cosine column fewer than sine columns, so the
        # frequencies are trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [1, max_len, d_model] --> [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        :param x: [x_len, batch_size, emb_size]
        :return: [x_len, batch_size, emb_size]
        """
        x = x + self.pe[:x.size(0), :]  # [x_len, batch_size, d_model]
        return self.dropout(x)
|
models/transformers.py
ADDED
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import pi, log
|
2 |
+
from functools import wraps
|
3 |
+
from multiprocessing import context
|
4 |
+
from textwrap import indent
|
5 |
+
import models.util_funcs as util_funcs
|
6 |
+
import math, copy
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch import nn, einsum
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
from einops.layers.torch import Reduce
|
13 |
+
import pdb
|
14 |
+
from einops.layers.torch import Rearrange
|
15 |
+
from options import get_parser_main_model
|
16 |
+
opts = get_parser_main_model().parse_args()
|
17 |
+
|
18 |
+
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        d_model: feature size of the encoded tensor (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one cosine column fewer than sine columns, so the
        # frequencies are trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # stored as [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        :param x: [x_len, batch_size, emb_size]
        :return: [x_len, batch_size, emb_size]
        """
        x = x + self.pe[:x.size(0), :].to(x.device)
        return self.dropout(x)
|
37 |
+
|
38 |
+
def exists(val):
    """Return ``True`` for every value except ``None``."""
    if val is None:
        return False
    return True
|
40 |
+
|
41 |
+
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
|
43 |
+
|
44 |
+
def cache_fn(f):
    """Wrap ``f`` so that results are memoised under an explicit ``key``.

    The wrapper takes two extra keyword arguments: ``_cache`` (when false the
    call bypasses the store entirely) and ``key`` (the store key, ``None`` by
    default). The positional and keyword arguments passed to ``f`` are not
    part of the key, so all calls sharing a key share one result.
    """
    store = {}

    @wraps(f)
    def cached_fn(*args, _cache = True, key = None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        if key not in store:
            store[key] = f(*args, **kwargs)
        return store[key]

    return cached_fn
|
57 |
+
|
58 |
+
def fourier_encode(x, max_freq, num_bands = 4):
    '''
    Append sine/cosine features of ``x`` at ``num_bands`` frequencies.

    A trailing axis is added to ``x``; it is multiplied by ``num_bands``
    scales spaced linearly from 1 to ``max_freq / 2`` and by pi. The result
    concatenates, on the last axis, the sines, the cosines and the original
    value, giving ``2 * num_bands + 1`` features per input element.
    '''
    x = x.unsqueeze(-1)
    orig_x = x

    # e.g. max_freq=10, num_bands=6 -> [1.0, 1.8, 2.6, 3.4, 4.2, 5.0]
    freqs = torch.linspace(1., max_freq / 2, num_bands, device = x.device, dtype = x.dtype)
    # reshape to (1, ..., 1, num_bands) so it broadcasts against x
    freqs = freqs.reshape((1,) * (x.dim() - 1) + (num_bands,))

    angles = x * freqs * pi
    return torch.cat((angles.sin(), angles.cos(), orig_x), dim = -1)
|
76 |
+
|
77 |
+
class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally the context) before ``fn``.

    Args:
        dim: feature size normalised on ``x``.
        fn: callable invoked as ``fn(normed_x, **kwargs)``.
        context_dim: when given, ``kwargs['context']`` is layer-normalised over
            this feature size before the call.
    """

    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if context_dim is None:
            self.norm_context = None
        else:
            self.norm_context = nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        if self.norm_context is not None:
            # a context-normalising block requires a 'context' keyword
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed, **kwargs)
|
93 |
+
|
94 |
+
class GEGLU(nn.Module):
    """Gated GELU: split the last axis in half and gate the first half by GELU of the second."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        return F.gelu(gate) * value
|
98 |
+
|
99 |
+
class FeedForward(nn.Module):
    """Two-layer feed-forward block with a GEGLU gate.

    ``dim`` is expanded to ``2 * dim * mult`` (value and gate halves), gated
    down to ``dim * mult``, projected back to ``dim`` and passed through dropout.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
111 |
+
|
112 |
+
class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        query_dim: feature size of the query input and of the output.
        context_dim: feature size of the context; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability applied to the attention weights.
        cls_conv_dim: accepted for interface compatibility; not used.
    """

    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.,cls_conv_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        # keys and values come from one projection and are split in forward()
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context = None, mask = None, ref_cls_onehot=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when it is ``None``).

        Args:
            x: queries, ``(b, n, query_dim)``.
            context: keys/values source, ``(b, m, context_dim)``.
            mask: optional tensor laid out ``(b, j, k)``; it is tiled per head
                and transposed to ``(b * heads, k, j)``, and similarity entries
                where it equals 0 are set to ``-1e9`` before the softmax.
            ref_cls_onehot: accepted for interface compatibility; not used.

        Returns:
            ``(output, attn)``: the projected output and the attention weights.
        """
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = repeat(mask, 'b j k -> (b h) k j', h = h)
            # masked_fill is out-of-place: the result must be assigned back,
            # otherwise the mask has no effect on the attention weights.
            # NOTE(review): weights trained while the mask was silently
            # ignored may behave differently now - re-validate checkpoints.
            sim = sim.masked_fill(mask == 0, -1e9)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out), attn
|
147 |
+
|
148 |
+
|
149 |
+
class SVGEmbedding(nn.Module):
    """Embed drawing commands and their quantised arguments into 512-d tokens.

    Commands use a 4-entry embedding table. Each argument is embedded into
    128 values (index 0 is the padding index); the embeddings of the 8
    arguments are concatenated and projected to 512 values. The two parts are
    summed and a positional encoding is added.
    """

    def __init__(self):
        super().__init__()
        self.command_embed = nn.Embedding(4, 512)
        self.arg_embed = nn.Embedding(128, 128,padding_idx=0)
        self.embed_fcn = nn.Linear(128 * 8, 512)
        self.pos_encoding = PositionalEncoding(d_model=opts.hidden_size, max_len=opts.max_seq_len + 1)
        self._init_embeddings()

    def _init_embeddings(self):
        """Re-initialise all three weight matrices with Kaiming-normal (fan-in)."""
        nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
        nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")
        nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")

    def forward(self, commands, args, groups=None):
        """Return the embedded sequence.

        Args:
            commands: 3-D tensor ``(S, GN, 1)`` of command ids.
            args: tensor of argument ids, reshaped to ``(S, GN, -1)`` after
                embedding (8 arguments per step are expected by ``embed_fcn``).
            groups: accepted for interface compatibility; not used.
        """
        S, GN,_ = commands.shape
        # Only the singleton command axis is squeezed. A bare squeeze() would
        # also drop S or GN when they are 1, and the addition below would then
        # broadcast to the wrong shape.
        cmd_emb = self.command_embed(commands.long()).squeeze(2)
        arg_emb = self.embed_fcn(self.arg_embed((args).long()).view(S, GN, -1))
        src = cmd_emb + arg_emb

        src = self.pos_encoding(src)

        return src
|
173 |
+
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # expand -> dropout -> ReLU -> project back
        hidden = self.w_1(x)
        hidden = self.dropout(hidden)
        hidden = F.relu(hidden)
        return self.w_2(hidden)
|
184 |
+
|
185 |
+
class Transformer_decoder(nn.Module):
    """Sequence decoder with a 6-layer main stack and a 1-layer refinement stack.

    Both stacks share the output heads ``command_fcn`` (4 command classes) and
    ``args_fcn`` (8 arguments x 128 bins).
    """

    def __init__(self):
        super().__init__()
        self.SVG_embedding = SVGEmbedding()
        self.command_fcn = nn.Linear(512, 4)
        self.args_fcn = nn.Linear(512, 8 * 128)

        # Prototype sub-modules; every decoder layer receives deep copies.
        attn_proto = MultiHeadedAttention(h=8, d_model=512, dropout=0.0)
        ff_proto = PositionwiseFeedForward(d_model=512, d_ff=1024, dropout=0.0)

        def build_layer():
            return DecoderLayer(
                512,
                copy.deepcopy(attn_proto),
                copy.deepcopy(attn_proto),
                copy.deepcopy(ff_proto),
                dropout=0.0,
            )

        self.decoder_layers = clones(build_layer(), 6)
        self.decoder_norm = nn.LayerNorm(512)
        self.decoder_layers_parallel = clones(build_layer(), 1)
        self.decoder_norm_parallel = nn.LayerNorm(512)
        self.cls_embedding = nn.Embedding(52, 512)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))

    def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
        """Run the main decoder stack.

        ``x`` is split on its last axis into a command column (index 0) and
        argument columns (the rest). The embedding at sequence position 0 is
        overwritten with the embedding of ``trg_char``.

        NOTE(review): ``tgt_mask`` is squeezed unconditionally, so it must not
        be ``None`` despite the default; presumably callers always pass a mask
        -- confirm.

        Returns:
            ``(cmd_logits, args_logits, attn)`` where ``args_logits`` is
            reshaped to ``(N, S, 8, 128)`` and ``attn`` is the self-attention
            map of the last layer.
        """
        memory = memory.unsqueeze(1)
        commands, args = x[:, :, :1], x[:, :, 1:]
        x = self.SVG_embedding(commands, args).transpose(0, 1)

        # Replace the first step with the target-character embedding.
        x[:, 0:1, :] = self.cls_embedding(trg_char.long())

        tgt_mask = tgt_mask.squeeze()
        for layer in self.decoder_layers:
            x, attn = layer(x, memory, src_mask, tgt_mask)
        out = self.decoder_norm(x)

        N, S, _ = out.shape
        cmd_logits = self.command_fcn(out)
        args_logits = self.args_fcn(out).reshape(N, S, 8, 128)
        return cmd_logits, args_logits, attn

    def parallel_decoder(self, cmd_logits, args_logits, memory, trg_char):
        """Refine a first-pass prediction with the single-layer parallel stack.

        When ``opts.mode == 'train'`` the inputs are logits and are turned into
        discrete ids by arg-max; otherwise they are used as given. Commands
        after the detected sequence end are zeroed, and arguments that the
        per-command table below marks as unused are zeroed as well.

        Returns:
            ``(cmd_logits, args_logits)`` with ``args_logits`` reshaped to
            ``(N, S, 8, 128)``.
        """
        memory = memory.unsqueeze(1)
        # Row i lists which of the 8 argument slots command id i keeps.
        cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)

        if opts.mode == 'train':
            cmd2 = torch.argmax(cmd_logits, -1).unsqueeze(-1).transpose(0, 1)
            arg2 = torch.argmax(args_logits, -1).transpose(0, 1)
        else:
            cmd2, arg2 = cmd_logits, args_logits

        # Zero out everything behind the detected end of each sequence.
        pad_mask = _get_key_padding_mask(cmd2).transpose(0, 1).unsqueeze(-1).to(cmd2.device)
        cmd2 = cmd2 * pad_mask
        cmd_onehot = F.one_hot(cmd2.long(), 4).float()
        args_mask = torch.matmul(cmd_onehot, cmd_args_mask).transpose(-1, -2).squeeze(-1)
        arg2 = arg2 * args_mask

        x = self.SVG_embedding(cmd2, arg2).transpose(0, 1)

        B, S = x.size(0), x.size(1)
        # Full (non-causal) attention, restricted to non-padded key positions.
        tgt_mask = torch.ones(S, S).to(x.device).unsqueeze(0).repeat(B, 1, 1)
        tgt_mask = tgt_mask * pad_mask.transpose(0, 1).transpose(-1, -2)

        trg_char = self.cls_embedding(trg_char.long())

        # Prepend the character embedding, then crop back to max_seq_len.
        x = torch.cat([trg_char, x], 1)
        x[:, 0:1, :] = trg_char
        x = x[:, :opts.max_seq_len, :]

        for layer in self.decoder_layers_parallel:
            x, attn = layer(x, memory, src_mask=None, tgt_mask=tgt_mask)
        out = self.decoder_norm_parallel(x)

        N, S, _ = out.shape
        cmd_logits = self.command_fcn(out)
        args_logits = self.args_fcn(out).reshape(N, S, 8, 128)

        return cmd_logits, args_logits
|
271 |
+
|
272 |
+
|
273 |
+
def _get_key_padding_mask(commands, seq_dim=0):
    """Build a per-sample validity mask that ends one step after the first 0 command.

    For every sample the index of the first command equal to 0 is located; the
    mask is True for all positions up to and including that index. Samples
    without any 0 command get an all-True mask of width ``opts.max_seq_len``.

    Args:
        commands: Shape [S, ...]; it is transposed on its first two axes and
            its trailing axis is squeezed before scanning
            (assumes (S, batch, 1) -- TODO confirm against callers).
        seq_dim: unused; kept so existing call sites remain valid.

    Returns:
        Boolean tensor produced by ``util_funcs.sequence_mask`` with
        ``opts.max_seq_len`` columns and one row per sample.
    """
    with torch.no_grad():
        per_sample = commands.transpose(0, 1).squeeze(-1)  # bs, opts.max_seq_len
        lens = []
        for sample_idx in range(per_sample.size(0)):
            zero_positions = torch.where(per_sample[sample_idx] == 0)[0]
            # Explicit emptiness check instead of a bare ``except`` around the
            # indexing, so unrelated errors are no longer swallowed.
            if zero_positions.numel() > 0:
                lens.append(int(zero_positions[0]))
            else:
                lens.append(opts.max_seq_len)
        # +1 keeps the first 0 command itself inside the valid range.
        lens = torch.tensor(lens) + 1
        seqlen_mask = util_funcs.sequence_mask(lens, opts.max_seq_len)
    return seqlen_mask
|
294 |
+
|
295 |
+
class Transformer(nn.Module):
    # Sequence encoder plus the SVG training loss.
    #
    # The constructor builds cross-attention and latent self-attention stacks;
    # `forward` and `att_residual` only execute the latent self-attention /
    # feed-forward pairs of `self.layers` and `self.layers_cnnsvg`
    # respectively. The cross-attention modules, `self.latents`,
    # `self.to_logits`, `self.pre_lstm_fc`, `self.posr` and
    # `self.to_patch_embedding` are created but not used by any method
    # visible in this class.
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels = 1,
        input_axis = 2,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        weight_tie_layers = False,
        fourier_encode_data = True,
        self_per_cross_attn = 2,
        final_classifier_head = True
    ):
        """The shape of the final attention mechanism will be:
        depth * (cross attention -> self_per_cross_attn * self attention)

        Args:
          num_freq_bands: Number of freq bands, with original value (2 * K + 1)
          depth: Depth of net.
          max_freq: Maximum frequency, hyperparameter depending on how
              fine the data is.
          input_channels: Number of channels for each token of the input.
          input_axis: Number of axes for input data (2 for images, 3 for video)
          num_latents: Number of latents, or induced set points, or centroids.
              Different papers giving it different names.
          latent_dim: Latent dimension.
          cross_heads: Number of heads for cross attention. Paper said 1.
          latent_heads: Number of heads for latent self attention, 8.
          cross_dim_head: Number of dimensions per cross attention head.
          latent_dim_head: Number of dimensions per latent self attention head.
          num_classes: Output number of classes.
          attn_dropout: Attention dropout
          ff_dropout: Feedforward dropout
          weight_tie_layers: Whether to weight tie layers (optional).
          fourier_encode_data: Whether to auto-fourier encode the data, using
              the input_axis given. defaults to True, but can be turned off
              if you are fourier encoding the data yourself.
          self_per_cross_attn: Number of self attention blocks per cross attn.
          final_classifier_head: mean pool and project embeddings to number of classes (num_classes) at the end
        """
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands

        self.fourier_encode_data = fourier_encode_data
        fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0 # 26 for input_axis=2, num_freq_bands=6
        input_dim = fourier_channels + input_channels

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        # Factories for the four module kinds; wrapped by cache_fn below so that
        # layers can optionally share weights.
        get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, dropout=attn_dropout), context_dim=input_dim)
        get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, dropout=attn_dropout))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))

        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))

        # Main stack: `depth` entries of [cross_attn, cross_ff, self_attn blocks].
        self.layers = nn.ModuleList([])
        for i in range(depth):
            # Reuse cached modules from the first layer only when weight tying is on.
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self_attns = nn.ModuleList([])

            for block_ind in range(self_per_cross_attn): # BUG note from the author: this bound was previously hard-coded to 2
                self_attns.append(nn.ModuleList([
                    get_latent_attn(**cache_args, key = block_ind),
                    get_latent_ff(**cache_args, key = block_ind)
                ]))

            self.layers.append(nn.ModuleList([
                get_cross_attn(**cache_args),
                get_cross_ff(**cache_args),
                self_attns
            ]))

        # Second, independent set of factories for the single-entry stack used
        # by att_residual.
        get_cross_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
        get_cross_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        get_latent_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
        get_latent_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

        get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2 = map(cache_fn, (get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2))

        self.layers_cnnsvg = nn.ModuleList([])
        for i in range(1):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self_attns2 = nn.ModuleList([])

            for block_ind in range(self_per_cross_attn):
                self_attns2.append(nn.ModuleList([
                    get_latent_attn2(**cache_args, key = block_ind),
                    get_latent_ff2(**cache_args, key = block_ind)
                ]))

            self.layers_cnnsvg.append(nn.ModuleList([
                get_cross_attn2(**cache_args),
                get_cross_ff2(**cache_args),
                self_attns2
            ]))

        self.to_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes)
        ) if final_classifier_head else nn.Identity()
        self.pre_lstm_fc = nn.Linear(10,opts.hidden_size)
        self.posr = PositionalEncoding(d_model=opts.hidden_size,max_len=opts.max_seq_len)

        patch_height = 2
        patch_width = 2
        patch_dim = 1 * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, 16),
        )

        self.SVG_embedding = SVGEmbedding()
        # Learned token prepended to every embedded sequence in forward().
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))

    def forward(self, data, seq, ref_cls_onehot=None, mask=None, return_embeddings=True):
        # Encode `seq` with the latent self-attention blocks of self.layers.
        #
        # `data` is only used for the axis-count assertion; `ref_cls_onehot`
        # and `return_embeddings` are not read. `mask` is concatenated with a
        # column of ones for the class token, so it must be a tensor even
        # though the default is None.
        # Returns (x, self_atten): the encoded sequence (class token first) and
        # the list of attention maps, one per self-attention block.

        b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
        assert len(axis) == self.input_axis, 'input data must have the right number of axis' # img is 2
        x = seq
        # Column 0 of the last axis holds the command, the rest the arguments.
        commands=x[:, :, :1]
        args=x[:, :, 1:]
        x = self.SVG_embedding(commands, args).transpose(0,1)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = x.size(0))
        x = torch.cat([cls_tokens,x],dim = 1)
        # The class token is always attendable: extend the mask with a 1.
        cls_one_pad = torch.ones((1,1,1)).to(x.device).repeat(x.size(0),1,1)
        mask = torch.cat([cls_one_pad,mask],dim=-1)
        self_atten = []
        for cross_attn, cross_ff, self_attns in self.layers:
            # Only the self-attention sub-blocks are executed; cross_attn and
            # cross_ff are unpacked but unused.
            for self_attn, self_ff in self_attns:
                x_,atten = self_attn(x,mask=mask)
                x = x_ + x
                self_atten.append(atten)
                x = self_ff(x) + x
            # NOTE(review): indentation was lost in the copy this block was
            # reconstructed from; verify against the repository whether the
            # perturbation is added once per layer (as written here) or once
            # after the whole stack. It is applied regardless of train/eval mode.
            x = x + torch.randn_like(x) # add a perturbation
        return x, self_atten

    def att_residual(self, x, mask=None):
        # Run x through the self-attention / feed-forward pairs of
        # self.layers_cnnsvg with residual connections. `mask` is accepted but
        # not passed to the attention modules.

        for cross_attn, cross_ff, self_attns in self.layers_cnnsvg:
            for self_attn, self_ff in self_attns:
                x_, atten = self_attn(x)
                x = x_ + x
                x = self_ff(x) + x
        return x

    def loss(self, cmd_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux):
        '''
        Combine four terms into the SVG training loss:
        command cross-entropy, argument cross-entropy, a start/end-point
        continuity penalty, and an auxiliary-point regression term.

        Inputs:
            cmd_logits: logits over the 4 command classes
                (assumes [b, seq_len, 4] -- TODO confirm)
            args_logits: logits over 128 bins for each of 8 arguments
                (assumes [b, seq_len, 8, 128], matching the decoder reshape
                -- TODO confirm)
            trg_seq: target sequence; column 0 of the last axis is the
                command id, columns 1: are the argument bins. It is
                transposed on axes 0/1 before use.
            trg_seqlen: per-sample target lengths.
            trg_pts_aux: target auxiliary points compared against the
                predicted ones.

        Returns:
            dict with keys 'loss_total', 'loss_cmd', 'loss_args', 'loss_smt',
            'loss_aux'.
        '''
        # Row i: which of the 8 argument slots are supervised for command id i.
        cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)

        tgt_commands = trg_seq[:,:,:1].transpose(0,1)
        tgt_args = trg_seq[:,:,1:].transpose(0,1)

        # Validity masks derived from the target lengths. seqlen_mask2 and
        # seqlen_mask3 are computed but not used below.
        seqlen_mask = util_funcs.sequence_mask(trg_seqlen, opts.max_seq_len).unsqueeze(-1)
        seqlen_mask2 = seqlen_mask.repeat(1,1,4)# NOTE b,501,4
        seqlen_mask4 = seqlen_mask.repeat(1,1,8)
        seqlen_mask3 = seqlen_mask.unsqueeze(-1).repeat(1,1,8,128)

        tgt_commands_onehot = F.one_hot(tgt_commands, 4)
        tgt_args_onehot = F.one_hot(tgt_args, 128)

        # Per-step argument relevance, looked up from the target command.
        args_mask = torch.matmul(tgt_commands_onehot.float(),cmd_args_mask).squeeze()

        # Command cross-entropy, masked to valid steps and length-normalised.
        loss_cmd = torch.sum(- tgt_commands_onehot.squeeze() * F.log_softmax(cmd_logits, -1), -1)
        loss_cmd = torch.mul(loss_cmd, seqlen_mask.squeeze())
        loss_cmd = torch.mean(torch.sum(loss_cmd/trg_seqlen.unsqueeze(-1),-1))

        # Argument cross-entropy, masked by validity and by command relevance.
        loss_args = (torch.sum(-tgt_args_onehot*F.log_softmax(args_logits,-1),-1)*seqlen_mask4*args_mask)

        loss_args = torch.mean(loss_args,dim=-1,keepdim=False)
        loss_args = torch.mean(torch.sum(loss_args/trg_seqlen.unsqueeze(-1),-1))

        # Selects, per target command, whether the shifted end point (1) or the
        # unshifted one (0) enters the continuity term; command id 1 uses 0.
        SE_mask = torch.Tensor([[1, 1],
                                [0, 0],
                                [1, 1],
                                [1, 1]]).to(cmd_logits.device)

        SE_args_mask = torch.matmul(tgt_commands_onehot.float(),SE_mask).squeeze().unsqueeze(-1)

        args_prob = F.softmax(args_logits, -1)
        # Last two argument slots (indices 6, 7), shifted right by one step so
        # that step t sees the end point of step t-1.
        args_end = args_prob[:,:,6:]
        args_end_shifted = torch.cat((torch.zeros(args_end.size(0),1,args_end.size(2),args_end.size(3)).to(args_end.device),args_end),1)
        args_end_shifted = args_end_shifted[:,:opts.max_seq_len,:,:]
        args_end_shifted = args_end_shifted*SE_args_mask + args_end*(1-SE_args_mask)

        # First two argument slots (indices 0, 1).
        args_start = args_prob[:,:,:2]

        seqlen_mask5 = util_funcs.sequence_mask(trg_seqlen-1, opts.max_seq_len).unsqueeze(-1)
        seqlen_mask5 = seqlen_mask5.repeat(1,1,2)

        # Squared difference between the two bin distributions.
        # NOTE(review): the normaliser (trg_seqlen - 1) is zero for length-1 targets.
        smooth_constrained = torch.sum(torch.pow((args_end_shifted - args_start), 2), -1) * seqlen_mask5
        smooth_constrained = torch.mean(smooth_constrained, dim=-1, keepdim=False)
        smooth_constrained = torch.mean(torch.sum(smooth_constrained / (trg_seqlen - 1).unsqueeze(-1), -1))

        # Softmax with temperature 0.1.
        args_prob2 = F.softmax(args_logits / 0.1, -1)

        # Forward value is the arg-max bin index; the gradient path goes through
        # args_prob2 because the subtracted copy is detached.
        c = torch.argmax(args_prob2,-1).unsqueeze(-1).float() - args_prob2.detach()
        p_argmax = args_prob2 + c
        p_argmax = torch.mean(p_argmax,-1)
        control_pts = denumericalize(p_argmax)

        # The 8 arguments are read as four 2-D points.
        p0 = control_pts[:,:,:2]
        p1 = control_pts[:,:,2:4]
        p2 = control_pts[:,:,4:6]
        p3 = control_pts[:,:,6:8]

        # Command ids 1 and 2 use the straight-line formula, id 3 the cubic one.
        line_mask = (tgt_commands==2).float() + (tgt_commands==1).float()
        curve_mask = (tgt_commands==3).float()

        # Points at t = 0.25, 0.5, 0.75 on the segment p0 -> p3.
        t=0.25
        aux_pts_line = p0 + t*(p3-p0)
        for t in [0.5,0.75]:
            coord_t = p0 + t*(p3-p0)
            aux_pts_line = torch.cat((aux_pts_line,coord_t),-1)
        aux_pts_line = aux_pts_line*line_mask

        # Points at t = 0.25, 0.5, 0.75 on the cubic Bezier (p0, p1, p2, p3).
        t=0.25
        aux_pts_curve = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
        for t in [0.5, 0.75]:
            coord_t = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
            aux_pts_curve = torch.cat((aux_pts_curve,coord_t),-1)
        aux_pts_curve = aux_pts_curve * curve_mask

        # The two masks are disjoint, so the sum picks one formula per step.
        aux_pts_predict = aux_pts_curve + aux_pts_line
        seqlen_mask_aux = util_funcs.sequence_mask(trg_seqlen - 1, opts.max_seq_len).unsqueeze(-1)
        aux_pts_loss = torch.pow((aux_pts_predict - trg_pts_aux), 2) * seqlen_mask_aux

        loss_aux = torch.mean(aux_pts_loss, dim=-1, keepdim=False)
        loss_aux = torch.mean(torch.sum(loss_aux / trg_seqlen.unsqueeze(-1), -1))

        # Weighted sum using the coefficients configured in opts.
        loss = opts.loss_w_cmd * loss_cmd + opts.loss_w_args * loss_args + opts.loss_w_aux * loss_aux + opts.loss_w_smt * smooth_constrained

        svg_losses = {}
        svg_losses['loss_total'] = loss
        svg_losses["loss_cmd"] = loss_cmd
        svg_losses["loss_args"] = loss_args
        svg_losses["loss_smt"] = smooth_constrained
        svg_losses["loss_aux"] = loss_aux

        return svg_losses
|
570 |
+
|
571 |
+
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, source attention, then feed-forward.

    Each of the three stages is wrapped in its own ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Return ``(output, self_attention_map)`` for input ``x`` and ``memory``."""

        def attend_to_self(normed):
            return self.self_attn(normed, normed, normed, tgt_mask)

        def attend_to_memory(normed):
            return self.src_attn(normed, memory, memory, src_mask)

        wrap_self, wrap_src, wrap_ff = self.sublayer
        x = wrap_self(x, attend_to_self)
        x = wrap_src(x, attend_to_memory)
        # The attention module caches its latest map on ``.attn``.
        attn_map = self.self_attn.attn
        return wrap_ff(x, self.feed_forward), attn_map
|
589 |
+
|
590 |
+
def subsequent_mask(size):
    """Return a (1, size, size) boolean mask that hides later positions.

    Entry ``[0, i, j]`` is True when ``j <= i`` and False otherwise.
    """
    ones = np.ones((1, size, size))
    # Strict upper triangle marks the positions that must stay hidden.
    hidden = np.triu(ones, k=1).astype('uint8')
    return torch.from_numpy(hidden) == 0
|
595 |
+
|
596 |
+
def numericalize(cmd, n=128):
    """Quantise values into ``n`` integer bins, clipped to ``[0, n - 1]``.

    The input is scaled by ``n / 30`` before rounding.
    NOTE: shall only be called after normalization.
    """
    scaled = cmd / 30 * n
    binned = scaled.round().clip(min=0, max=n - 1)
    return binned.int()
|
601 |
+
|
602 |
+
def denumericalize(cmd, n=128):
    """Map bin indices back to coordinates: ``cmd / n * 30``."""
    unit = cmd / n
    return unit * 30
|
605 |
+
|
606 |
+
def attention(query, key, value, mask=None, trg_tri_mask=None, dropout=None, posr=None):
    """Compute scaled dot-product attention.

    Args:
        query, key, value: tensors whose last axis is the per-head feature
            size; scores are ``query @ key^T / sqrt(d_k)``.
        mask: optional mask; positions where ``mask == 0`` are set to -1e9
            before the softmax. Must be broadcastable to the score tensor.
        trg_tri_mask: optional second mask applied the same way.
        dropout: optional callable applied to the attention weights.
        posr: optional additive bias; a head axis is inserted at dim 1 before
            it is added to the scores.

    Returns:
        ``(weighted_values, attention_weights)``.

    Raises:
        RuntimeError: if ``mask`` cannot be broadcast to the score tensor.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if posr is not None:
        posr = posr.unsqueeze(1)
        scores = scores + posr

    if mask is not None:
        try:
            scores = scores.masked_fill(mask == 0, -1e9)
        except RuntimeError as err:
            # Fail loudly with both shapes. Dropping into an interactive
            # debugger here would hang any non-interactive run, and carrying
            # on afterwards would use unmasked scores.
            raise RuntimeError(
                "attention mask of shape {} cannot be applied to scores of shape {}".format(
                    tuple(mask.shape), tuple(scores.shape)
                )
            ) from err

    if trg_tri_mask is not None:
        scores = scores.masked_fill(trg_tri_mask == 0, -1e9)

    p_attn = F.softmax(scores, dim=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)

    return torch.matmul(p_attn, value), p_attn
|
632 |
+
|
633 |
+
|
634 |
+
class MultiHeadedAttention(nn.Module):
    """Multi-head attention with ``h`` heads over a ``d_model``-wide input."""

    def __init__(self, h, d_model, dropout):
        """Take in model size and number of heads."""
        super().__init__()
        assert d_model % h == 0
        # Per-head width; value and key widths are assumed equal.
        self.d_k = d_model // h
        self.h = h
        # Three input projections (query, key, value) plus the output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Latest attention map, refreshed on every forward call.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None, trg_tri_mask=None, posr=None):
        """Project, split into heads, attend, merge heads and project back."""
        if mask is not None:
            # Insert a head axis so the same mask applies to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        def to_heads(projection, tensor):
            projected = projection(tensor)
            return projected.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        query = to_heads(self.linears[0], query)
        key = to_heads(self.linears[1], key)
        value = to_heads(self.linears[2], value)

        x, self.attn = attention(
            query, key, value,
            mask=mask, trg_tri_mask=trg_tri_mask,
            dropout=self.dropout, posr=posr,
        )

        merged = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
|
665 |
+
|
666 |
+
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
|
669 |
+
|
670 |
+
class SublayerConnection(nn.Module):
    """
    Residual wrapper with pre-normalisation.

    The input is layer-normalised first, passed through the sublayer, run
    through dropout, and then added back to the un-normalised input.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``."""
        branch = sublayer(self.norm(x))
        return x + self.dropout(branch)
|
685 |
+
|
686 |
+
|
687 |
+
# Manual smoke test: build a Transformer with image-style settings and call it
# on one random tensor.
if __name__ == '__main__':
    model = Transformer(
        input_channels = 1,          # number of channels for each token of the input
        input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
        num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
        max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
        depth = 6,                   # depth of net. The shape of the final attention mechanism will be:
                                     #   depth * (cross attention -> self_per_cross_attn * self attention)
        num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
        latent_dim = 512,            # latent dimension
        cross_heads = 1,             # number of heads for cross attention. paper said 1
        latent_heads = 8,            # number of heads for latent self attention, 8
        cross_dim_head = 64,         # number of dimensions per cross attention head
        latent_dim_head = 64,        # number of dimensions per latent self attention head
        num_classes = 1000,          # output number of classes
        attn_dropout = 0.,
        ff_dropout = 0.,
        weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
        fourier_encode_data = True,  # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
        self_per_cross_attn = 2      # number of self attention blocks per cross attention
    )

    img = torch.randn(1, 224, 224, 3) # 1 imagenet image, pixelized

    # NOTE(review): Transformer.forward is declared as forward(data, seq, ...),
    # so this single-argument call does not match its signature; the
    # "(1, 1000)" shape below does not match what forward returns either.
    model(img) # (1, 1000)
|
models/util_funcs.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import cairosvg
|
4 |
+
from data_utils.common_utils import trans2_white_bg
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
def select_imgs(images_of_onefont, selected_cls, opts):
    """Pick the glyph images that belong to the selected character classes.

    Args:
        images_of_onefont: [bs, n_chars, opts.img_size, opts.img_size]
        selected_cls: [bs, nshot] class indices along dim 1
        opts: options object providing ``img_size``

    Returns:
        [bs, nshot, opts.img_size, opts.img_size]
    """
    batch = images_of_onefont.size(0)
    nshot = selected_cls.size(1)
    # Broadcast each class index over the full image plane for gather.
    gather_index = selected_cls[:, :, None, None].expand(batch, nshot, opts.img_size, opts.img_size)
    return torch.gather(images_of_onefont, 1, gather_index)
|
18 |
+
|
19 |
+
def select_seqs(seqs_of_onefont, selected_cls, opts, seq_dim):
    """Pick the sequences that belong to the selected character classes.

    Args:
        seqs_of_onefont: [bs, n_chars, opts.max_seq_len, seq_dim]
        selected_cls: [bs, nshot] class indices along dim 1
        opts: options object providing ``max_seq_len``
        seq_dim: size of the last axis of ``seqs_of_onefont``

    Returns:
        [bs, nshot, opts.max_seq_len, seq_dim]
    """
    batch = seqs_of_onefont.size(0)
    nshot = selected_cls.size(1)
    gather_index = selected_cls[:, :, None, None].expand(batch, nshot, opts.max_seq_len, seq_dim)
    return torch.gather(seqs_of_onefont, 1, gather_index)
|
27 |
+
|
28 |
+
def select_seqlens(seqlens_of_onefont, selected_cls, opts):
    """Pick the sequence lengths that belong to the selected character classes.

    Args:
        seqlens_of_onefont: [bs, n_chars, 1]
        selected_cls: [bs, nshot] class indices along dim 1
        opts: unused; kept for a uniform signature with the other selectors

    Returns:
        [bs, nshot, 1]
    """
    batch = seqlens_of_onefont.size(0)
    nshot = selected_cls.size(1)
    gather_index = selected_cls[:, :, None].expand(batch, nshot, 1)
    return torch.gather(seqlens_of_onefont, 1, gather_index)
|
35 |
+
|
36 |
+
def trgcls_to_onehot(trg_cls, opts):
    """One-hot encode class indices over ``opts.char_num`` classes.

    A size-1 axis at dim 1 of the encoded tensor is removed if present.
    """
    encoded = F.one_hot(trg_cls, num_classes=opts.char_num)
    return encoded.squeeze(dim=1)
|
39 |
+
|
40 |
+
|
41 |
+
def shift_right(x, pad_value=None):
    """Shift a 3-D tensor one step along dim 0 and drop its last step.

    With ``pad_value=None`` the new first step is zeros; otherwise
    ``pad_value`` is concatenated in front of ``x`` along dim 0.
    """
    if pad_value is not None:
        prefixed = torch.cat([pad_value, x], axis=0)
        return prefixed[:-1, :, :]
    # F.pad lists pads from the last axis backwards: the final pair (1, 0)
    # adds one leading zero step on dim 0.
    zero_prefixed = F.pad(x, (0, 0, 0, 0, 1, 0))
    return zero_prefixed[:-1, :, :]
|
47 |
+
|
48 |
+
|
49 |
+
def length_form_embedding(emb):
    """Count, per batch element, the steps whose embedding is not all-zero.

    Args:
        emb: [seq_len, batch, depth]
    Returns:
        a long tensor of lengths: [batch]
    """
    # A step counts as present when the L1 norm of its feature vector is non-zero.
    step_l1 = emb.abs().sum(dim=2, keepdim=True)
    present = step_l1 != 0
    return torch.sum(present, dim=(0, 2), dtype=torch.long)
|
61 |
+
|
62 |
+
|
63 |
+
def lognormal(y, mean, logstd, logsqrttwopi):
    """Return the Gaussian log-density of ``y`` given ``mean`` and log-std.

    Computes ``-0.5 * ((y - mean) / exp(logstd))**2 - logstd - logsqrttwopi``;
    the caller supplies the constant ``logsqrttwopi``.
    """
    centred = y - mean
    std = logstd.exp()
    standardised = centred / std
    return -0.5 * (standardised) ** 2 - logstd - logsqrttwopi
|
68 |
+
|
69 |
+
def sequence_mask(lengths, max_len=None):
    """Return a boolean mask with True at positions below each length.

    Args:
        lengths: tensor of sequence lengths (one row is produced per element).
        max_len: number of mask columns; falls back to ``lengths.max()`` when
            it is None or otherwise falsy.

    Returns:
        Boolean tensor of shape (lengths.numel(), max_len).
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    position_grid = positions.unsqueeze(0).expand(batch_size, max_len)
    return position_grid.lt(lengths.unsqueeze(1))
|
76 |
+
|
77 |
+
def svg2img(path_svg, path_img, img_size):
    """Rasterise an SVG file to a square PNG and return it on a white background.

    The PNG is written to ``path_img`` at ``img_size`` x ``img_size`` and then
    passed to ``trans2_white_bg``, whose result is returned.
    """
    cairosvg.svg2png(
        url=path_svg,
        write_to=path_img,
        output_width=img_size,
        output_height=img_size,
    )
    return trans2_white_bg(path_img)
|
81 |
+
|
82 |
+
def cal_img_l1_dist(path_img1, path_img2):
    """Return the mean absolute pixel difference between two image files.

    Args:
        path_img1: path of an image that is compared as loaded
            (presumably single-channel -- verify against callers).
        path_img2: path of a multi-channel image; only channel 0 is compared.

    Returns:
        Mean of ``|img1 - img2[..., 0]|`` as a NumPy float.
    """
    with Image.open(path_img1) as im1:
        img1 = np.array(im1)
    with Image.open(path_img2) as im2:
        img2 = np.array(im2)[:, :, 0]
    # Subtract in floating point: 8-bit unsigned arrays wrap modulo 256 on
    # subtraction (10 - 20 becomes 246), which would corrupt the distance.
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    dist = np.mean(np.abs(diff))
    return dist
|
87 |
+
|
88 |
+
def cal_iou(path_img1, path_img2):
    """Return ``(iou, l1_dist)`` between the dark regions of two image files.

    A pixel counts as foreground when its value is below 255 * 3 / 4. The
    first image is used as loaded; only channel 0 of the second is used.
    """
    threshold = 255 * 3 / 4
    first = np.array(Image.open(path_img1))
    second = np.array(Image.open(path_img2))[:, :, 0]

    fg_first = first < threshold
    fg_second = second < threshold

    # On boolean arrays "&" is the intersection and "|" the union.
    intersection = np.sum(fg_first & fg_second)
    union = np.sum(fg_first | fg_second)
    iou = intersection / (union)

    l1_dist = np.mean(np.abs(fg_first.astype(float) - fg_second.astype(float)))
    return iou, l1_dist
|
models/vgg_perceptual_loss.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
|
4 |
+
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class VGG19Feats(torch.nn.Module):
    """Expose five intermediate activations of a pretrained VGG-19.

    The ``features`` stack is cut into five consecutive slices covering layer
    indices [0, 3), [3, 8), [8, 13), [13, 22) and [22, 31).
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True).to(device)
        features = vgg.features.eval()
        self.requires_grad = requires_grad

        # (start, stop) layer indices of slice1 .. slice5. Sub-modules keep
        # their original index as name.
        boundaries = [(0, 3), (3, 8), (8, 13), (13, 22), (22, 31)]
        for slice_no, (start, stop) in enumerate(boundaries, start=1):
            block = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                block.add_module(str(layer_idx), features[layer_idx])
            setattr(self, "slice{}".format(slice_no), block)

        if not self.requires_grad:
            # Freeze the backbone.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, img):
        """Return the outputs of the five slices, shallowest first."""
        outputs = []
        activation = img
        for block in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            activation = block(activation)
            outputs.append(activation)
        return outputs
|
40 |
+
|
41 |
+
|
42 |
+
class VGGPerceptualLoss(torch.nn.Module):
    """Weighted L1 distance between VGG-19 activations of two images."""

    def __init__(self):
        super().__init__()
        self.vgg = VGG19Feats().to(device)
        self.criterion = torch.nn.functional.l1_loss
        # Per-channel normalisation constants, shaped for broadcasting over
        # (batch, 3, H, W) inputs.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        # One weight per returned VGG activation.
        self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 1.0*10/1.5]

    def forward(self, input_img, target_img):
        """Return ``{'pt_c_loss': weighted feature L1, 'pt_s_loss': 0.0}``.

        If ``input_img`` does not have 3 channels, both images are tiled
        three times along the channel axis before normalisation.
        """
        if input_img.shape[1] != 3:
            input_img = input_img.repeat(1, 3, 1, 1)
            target_img = target_img.repeat(1, 3, 1, 1)
        input_img = (input_img - self.mean) / self.std
        target_img = (target_img - self.mean) / self.std

        feats_in = self.vgg(input_img)
        feats_tgt = self.vgg(target_img)

        # Accumulate the five weighted terms from shallowest to deepest.
        content = None
        for level in range(5):
            term = self.weights[level] * self.criterion(feats_in[level], feats_tgt[level])
            content = term if content is None else content + term

        loss = {}
        loss['pt_c_loss'] = content
        loss['pt_s_loss'] = 0.0

        return loss
|
options.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is True for
    every non-empty string; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_parser_main_model():
    """Build the argument parser shared by the train and test scripts.

    Returns:
        argparse.ArgumentParser: parser with model, experiment, optimizer,
        logging and loss-weight options registered.
    """
    parser = argparse.ArgumentParser()
    # basic parameters training related
    parser.add_argument('--model_name', type=str, default='main_model', choices=['main_model', 'neural_raster'], help='current model_name')
    parser.add_argument("--language", type=str, default='tha', choices=['eng', 'chn', 'tha'])
    parser.add_argument('--bottleneck_bits', type=int, default=512, help='latent code number of bottleneck bits')
    parser.add_argument('--char_num', type=int, default=44, help='number of glyphs, original is 44 (Thai)')
    parser.add_argument('--seed', type=int, default=3712)
    parser.add_argument('--ref_nshot', type=int, default=8, help='reference number')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=8, help='batch size when do validation')
    parser.add_argument('--img_size', type=int, default=64, help='image size')
    parser.add_argument('--max_seq_len', type=int, default=121, help='maximum length of sequence')
    parser.add_argument('--dim_seq', type=int, default=12, help='the dim of each stroke in a sequence, 4 + 8, 4 is cmd, and 8 is args')
    parser.add_argument('--dim_seq_short', type=int, default=9, help='the short dim of each stroke in a sequence, 1 + 8, 1 is cmd class num, and 8 is args')
    parser.add_argument('--hidden_size', type=int, default=512, help='hidden_size')
    parser.add_argument('--dim_seq_latent', type=int, default=512, help='sequence encoder latent dim')
    parser.add_argument('--ngf', type=int, default=16, help='the basic num of channel in image encoder and decoder')
    parser.add_argument('--n_aux_pts', type=int, default=6, help='the number of aux pts in bezier curves for additional supervison')

    # experiment related
    parser.add_argument('--random_index', type=str, default='00')
    parser.add_argument('--name_ckpt', type=str, default='600_192921.ckpt')
    parser.add_argument('--model_path', type=str, default='.')
    parser.add_argument('--n_epochs', type=int, default=800, help='number of epochs')
    parser.add_argument('--n_samples', type=int, default=20, help='the number of samples for each glyph when testing')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--ref_char_ids', type=str, default='0,1,26,27', help='default is A, B, a, b')

    parser.add_argument('--mode', type=str, default='test', choices=['train', 'val', 'test'])
    # Boolean flags use _str2bool so that "--multi_gpu False" really means False.
    parser.add_argument('--multi_gpu', type=_str2bool, default=False)
    parser.add_argument('--name_exp', type=str, default='dvf')

    # continue training
    parser.add_argument('--continue_training', type=_str2bool, default=False, help='whether continue training from old checkpoint')
    parser.add_argument('--continue_ckpt', type=str, default='.', help='checkpoint model for continue training')
    parser.add_argument('--init_epoch', type=int, default=0, help='init epoch')

    # Manually Add
    parser.add_argument('--exp_path', type=str, default='.')
    parser.add_argument('--dir_res', type=str, default=None)
    parser.add_argument('--data_root', type=str, default='./data/vecfont_dataset/')
    parser.add_argument('--freq_ckpt', type=int, default=50, help='save checkpoint frequency of epoch')
    parser.add_argument('--threshold_ckpt', type=int, default=0, help='save checkpoint only when more than threshold epoch')

    parser.add_argument('--freq_sample', type=int, default=500, help='sample train output of steps')
    parser.add_argument('--freq_log', type=int, default=50, help='freq of showing logs')
    parser.add_argument('--freq_val', type=int, default=500, help='sample validate output of steps')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam optimizer')
    parser.add_argument('--eps', type=float, default=1e-8, help='Adam epsilon')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
    parser.add_argument('--wandb', type=_str2bool, default=True, help='whether use wandb to visulize loss')
    parser.add_argument('--wandb_project_name', type=str, default="DeepVecFontV2", help='wandb project name')

    # loss weight
    parser.add_argument('--kl_beta', type=float, default=0.01, help='latent code kl loss beta')
    parser.add_argument('--loss_w_pt_c', type=float, default=0.001 * 10, help='the weight of perceptual content loss')
    parser.add_argument('--loss_w_l1', type=float, default=1.0 * 10, help='the weight of image reconstruction l1 loss')
    parser.add_argument('--loss_w_cmd', type=float, default=1.0, help='the weight of cmd loss')
    parser.add_argument('--loss_w_args', type=float, default=1.0, help='the weight of args loss')
    parser.add_argument('--loss_w_aux', type=float, default=0.01, help='the weight of pts aux loss')
    parser.add_argument('--loss_w_smt', type=float, default=10., help='the weight of smooth loss')

    return parser
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python3-fontforge
|
requirements.txt
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1
|
2 |
+
altair==5.3.0
|
3 |
+
annotated-types==0.7.0
|
4 |
+
anyio==4.4.0
|
5 |
+
attrs==23.2.0
|
6 |
+
blinker==1.8.2
|
7 |
+
build==1.2.1
|
8 |
+
cachetools==5.3.3
|
9 |
+
cairocffi==1.7.0
|
10 |
+
CairoSVG==2.7.1
|
11 |
+
certifi==2024.2.2
|
12 |
+
cffi==1.16.0
|
13 |
+
charset-normalizer==3.3.2
|
14 |
+
click==8.1.7
|
15 |
+
contourpy==1.2.1
|
16 |
+
cssselect2==0.7.0
|
17 |
+
cycler==0.12.1
|
18 |
+
defusedxml==0.7.1
|
19 |
+
dnspython==2.6.1
|
20 |
+
docker-pycreds==0.4.0
|
21 |
+
einops==0.8.0
|
22 |
+
email_validator==2.1.1
|
23 |
+
exceptiongroup==1.2.1
|
24 |
+
fastapi==0.111.0
|
25 |
+
fastapi-cli==0.0.4
|
26 |
+
ffmpy==0.3.2
|
27 |
+
filelock==3.14.0
|
28 |
+
fonttools==4.52.4
|
29 |
+
fsspec==2024.5.0
|
30 |
+
gitdb==4.0.11
|
31 |
+
GitPython==3.1.43
|
32 |
+
Glances==4.0.7
|
33 |
+
gradio==4.31.5
|
34 |
+
gradio_client==0.16.4
|
35 |
+
h11==0.14.0
|
36 |
+
httpcore==1.0.5
|
37 |
+
httptools==0.6.1
|
38 |
+
httpx==0.27.0
|
39 |
+
huggingface-hub==0.23.2
|
40 |
+
idna==3.7
|
41 |
+
imageio==2.34.1
|
42 |
+
importlib_resources==6.4.0
|
43 |
+
Jinja2==3.1.4
|
44 |
+
jsonschema==4.22.0
|
45 |
+
jsonschema-specifications==2023.12.1
|
46 |
+
kiwisolver==1.4.5
|
47 |
+
lazy_loader==0.4
|
48 |
+
markdown-it-py==3.0.0
|
49 |
+
MarkupSafe==2.1.5
|
50 |
+
matplotlib==3.9.0
|
51 |
+
mdurl==0.1.2
|
52 |
+
networkx==3.3
|
53 |
+
numpy==1.26.4
|
54 |
+
orjson==3.10.3
|
55 |
+
packaging==24.0
|
56 |
+
pandas==2.2.2
|
57 |
+
Pillow==9.5.0
|
58 |
+
pip-tools==7.4.1
|
59 |
+
platformdirs==4.2.2
|
60 |
+
protobuf==4.25.3
|
61 |
+
psutil==5.9.8
|
62 |
+
py3nvml==0.2.7
|
63 |
+
pyarrow==16.1.0
|
64 |
+
pycparser==2.22
|
65 |
+
pydantic==2.7.2
|
66 |
+
pydantic_core==2.18.3
|
67 |
+
pydeck==0.9.1
|
68 |
+
pydub==0.25.1
|
69 |
+
Pygments==2.18.0
|
70 |
+
pyparsing==3.1.2
|
71 |
+
pyproject_hooks==1.1.0
|
72 |
+
pythainlp==5.0.3
|
73 |
+
python-dateutil==2.9.0.post0
|
74 |
+
python-dotenv==1.0.1
|
75 |
+
python-multipart==0.0.9
|
76 |
+
pytz==2024.1
|
77 |
+
PyYAML==6.0.1
|
78 |
+
referencing==0.35.1
|
79 |
+
requests==2.32.2
|
80 |
+
rich==13.7.1
|
81 |
+
rpds-py==0.18.1
|
82 |
+
ruff==0.4.6
|
83 |
+
safetensors==0.4.3
|
84 |
+
scikit-image==0.23.2
|
85 |
+
scipy==1.13.1
|
86 |
+
semantic-version==2.10.0
|
87 |
+
sentry-sdk==2.3.1
|
88 |
+
setproctitle==1.3.3
|
89 |
+
shellingham==1.5.4
|
90 |
+
six==1.16.0
|
91 |
+
smmap==5.0.1
|
92 |
+
sniffio==1.3.1
|
93 |
+
starlette==0.37.2
|
94 |
+
streamlit==1.35.0
|
95 |
+
tenacity==8.3.0
|
96 |
+
tensorboardX==2.6.2.2
|
97 |
+
tifffile==2024.5.22
|
98 |
+
timm==1.0.3
|
99 |
+
tinycss2==1.3.0
|
100 |
+
toml==0.10.2
|
101 |
+
tomli==2.0.1
|
102 |
+
tomlkit==0.12.0
|
103 |
+
toolz==0.12.1
|
104 |
+
torch==1.13.1+cu117
|
105 |
+
torchaudio==0.13.1+cu117
|
106 |
+
torchvision==0.14.1+cu117
|
107 |
+
tornado==6.4
|
108 |
+
tqdm==4.66.4
|
109 |
+
typer==0.12.3
|
110 |
+
typing_extensions==4.12.0
|
111 |
+
tzdata==2024.1
|
112 |
+
ujson==5.10.0
|
113 |
+
urllib3==2.2.1
|
114 |
+
uvicorn==0.30.0
|
115 |
+
uvloop==0.19.0
|
116 |
+
wandb==0.17.0
|
117 |
+
watchdog==4.0.1
|
118 |
+
watchfiles==0.22.0
|
119 |
+
webencodings==0.5.1
|
120 |
+
websockets==11.0.3
|
121 |
+
xmltodict==0.13.0
|
test.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.utils import save_image
|
6 |
+
from dataloader import get_loader
|
7 |
+
from models.model_main import ModelMain
|
8 |
+
from models.transformers import denumericalize
|
9 |
+
from options import get_parser_main_model
|
10 |
+
from data_utils.svg_utils import render
|
11 |
+
from models.util_funcs import svg2img, cal_iou
|
12 |
+
|
13 |
+
# Testing (Only accuracy)
|
14 |
+
|
15 |
+
def test_main_model(opts):
    """Evaluate a checkpoint on the test split and print the mean losses.

    Loads ``opts.model_path`` into a ``ModelMain``, runs it in ``'val'`` mode
    over every test batch on the GPU, and averages the ``img`` and ``svg``
    loss entries over the number of batches.
    """
    tracked_cats = ('img', 'svg')
    svg_keys = ('total', 'cmd', 'args', 'aux')

    test_loader = get_loader(
        opts.data_root, opts.img_size, opts.language, opts.char_num,
        opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test',
    )

    model_main = ModelMain(opts)
    checkpoint = torch.load(os.path.join(f"{opts.model_path}"))
    model_main.load_state_dict(checkpoint['model'])
    model_main.cuda()
    model_main.eval()  # Testing mode

    with torch.no_grad():
        loss_val = {
            'img': dict.fromkeys(('l1', 'vggpt'), 0.0),
            'svg': dict.fromkeys(svg_keys, 0.0),
            'svg_para': dict.fromkeys(svg_keys, 0.0),
        }

        for batch in test_loader:
            for field in batch:
                batch[field] = batch[field].cuda()
            _, batch_losses = model_main(batch, mode='val')
            for cat in tracked_cats:
                for metric in loss_val[cat]:
                    loss_val[cat][metric] += batch_losses[cat][metric]

        n_batches = len(test_loader)
        for cat in tracked_cats:
            for metric in loss_val[cat]:
                loss_val[cat][metric] /= n_batches

        img_losses = loss_val['img']
        svg_losses = loss_val['svg']
        val_msg = (
            f"Val loss img l1: {img_losses['l1']: .6f}, "
            f"Val loss img pt: {img_losses['vggpt']: .6f}, "
            f"Val loss total: {svg_losses['total']: .6f}, "
            f"Val loss cmd: {svg_losses['cmd']: .6f}, "
            f"Val loss args: {svg_losses['args']: .6f}, "
        )

        print(val_msg)
        print(f"l1: {img_losses['l1']: .6f}, pt: {img_losses['vggpt']: .6f}")
|
49 |
+
|
50 |
+
def main():
    """Parse command-line options and run the test-set evaluation."""
    opts = get_parser_main_model().parse_args()
    # The experiment name is suffixed with the model name.
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    print(f"Testing on experiment {opts.name_exp}...")
    test_main_model(opts)


if __name__ == "__main__":
    main()
|
test_few_shot.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.utils import save_image
|
6 |
+
from dataloader import get_loader
|
7 |
+
from models.model_main import ModelMain
|
8 |
+
from models.transformers import denumericalize
|
9 |
+
from options import get_parser_main_model
|
10 |
+
from data_utils.svg_utils import render
|
11 |
+
from models.util_funcs import svg2img, cal_iou
|
12 |
+
from tqdm import tqdm
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
def test_main_model(opts):
    """Sample every test font and write images, per-glyph SVGs and a merged page.

    For each batch of the test loader the model is run ``opts.n_samples``
    times in ``'test'`` mode. Every sample writes the target/output image
    grid, per-glyph PNGs and two SVGs per glyph (before and after refinement).
    The refined sample with the highest IoU against the generated raster is
    then copied, per glyph, into one merged HTML file, followed by the
    rendered ground-truth sequences.

    Args:
        opts: parsed options. ``opts.streamlit`` is optional; when truthy,
            progress and the merged image are also shown through streamlit.

    Returns:
        The last PIL image displayed through streamlit, or ``None`` when
        streamlit output is disabled.
    """
    # `streamlit` is not declared by get_parser_main_model(), so fall back to
    # False when the caller did not attach it to opts.
    use_streamlit = getattr(opts, "streamlit", False)
    if use_streamlit:
        import streamlit as st

    if opts.dir_res:
        dir_res = os.path.join(opts.dir_res, "results")
    else:
        dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
    # makedirs(exist_ok=True) survives reruns and missing parent directories.
    os.makedirs(dir_res, exist_ok=True)

    test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if use_streamlit:
        st.write("Loading Model Weight...")
    model_main = ModelMain(opts)
    path_ckpt = os.path.join(f"{opts.model_path}")
    # map_location lets a GPU-saved checkpoint load on the CPU fallback.
    model_main.load_state_dict(torch.load(path_ckpt, map_location=device)['model'])
    model_main.to(device)
    model_main.eval()

    im = None  # stays None unless streamlit output is enabled
    with torch.no_grad():
        for test_idx, test_data in enumerate(test_loader):
            for key in test_data:
                test_data[key] = test_data[key].to(device)

            print("testing font %04d ..." % test_idx)

            dir_save = os.path.join(dir_res, "%04d" % test_idx)
            dir_imgs = os.path.join(dir_save, "imgs")
            dir_svgs_single = os.path.join(dir_save, "svgs_single")
            svg_merge_dir = os.path.join(dir_save, "svgs_merge")
            for directory in (dir_imgs, dir_svgs_single, svg_merge_dir):
                os.makedirs(directory, exist_ok=True)

            iou_max = np.zeros(opts.char_num)
            idx_best_sample = np.zeros(opts.char_num)

            merge_path = os.path.join(svg_merge_dir, f"{opts.name_ckpt}_syn_merge_{test_idx}.html")
            with open(merge_path, 'w') as syn_svg_merge_f:
                for sample_idx in tqdm(range(opts.n_samples)):
                    ret_dict_test, loss_dict_test = model_main(test_data, mode='test')

                    svg_sampled = ret_dict_test['svg']['sampled_1']
                    sampled_svg_2 = ret_dict_test['svg']['sampled_2']

                    img_trg = ret_dict_test['img']['trg']
                    img_output = ret_dict_test['img']['out']
                    trg_seq_gt = ret_dict_test['svg']['trg']

                    img_sample_merge = torch.cat((img_trg.data, img_output.data), -2)
                    save_file_merge = os.path.join(dir_imgs, f"merge_{opts.img_size}.png")
                    save_image(img_sample_merge, save_file_merge, nrow=8, normalize=True)
                    if use_streamlit:
                        st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
                        im = Image.open(save_file_merge)
                        st.image(im, caption='img_sample_merge')

                    for char_idx in range(opts.char_num):
                        img_gt = (1.0 - img_trg[char_idx, ...]).data
                        save_file_gt = os.path.join(dir_imgs, f"{char_idx:02d}_gt.png")
                        save_image(img_gt, save_file_gt, normalize=True)

                        img_sample = (1.0 - img_output[char_idx, ...]).data
                        save_file = os.path.join(dir_imgs, f"{char_idx:02d}_{opts.img_size}.png")
                        save_image(img_sample, save_file, normalize=True)

                    # write results w/o parallel refinement
                    for i, one_seq in enumerate(svg_sampled.clone().detach()):
                        syn_svg_outfile = os.path.join(dir_svgs_single, f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
                        # `with` closes the file even when rendering fails.
                        with open(syn_svg_outfile, 'w') as syn_svg_f_:
                            try:
                                syn_svg_f_.write(render(one_seq.cpu().numpy()))
                                if i > 0 and i % 13 == 12:
                                    syn_svg_f_.write('<br>')
                            except Exception:
                                # Best effort: a sequence that cannot be rendered
                                # leaves an empty file and is skipped.
                                pass

                    # write results w/ parallel refinement
                    for i, one_seq in enumerate(sampled_svg_2.clone().detach()):
                        syn_svg_outfile = os.path.join(dir_svgs_single, f"syn_{i:02d}_{sample_idx}_refined.svg")
                        # The file is created before rendering so that the merge
                        # step below always finds it, even if it stays empty.
                        with open(syn_svg_outfile, 'w') as syn_svg_f:
                            try:
                                syn_svg_f.write(render(one_seq.cpu().numpy()))
                            except Exception:
                                # Skip rasterisation and IoU for this glyph.
                                continue
                        syn_img_outfile = syn_svg_outfile.replace('.svg', '.png')
                        svg2img(syn_svg_outfile, syn_img_outfile, img_size=opts.img_size)
                        iou_tmp, _ = cal_iou(syn_img_outfile, os.path.join(dir_imgs, f"{i:02d}_{opts.img_size}.png"))
                        if iou_tmp > iou_max[i]:
                            iou_max[i] = iou_tmp
                            idx_best_sample[i] = sample_idx

                # Merge, per glyph, the refined sample with the best IoU.
                for i in range(opts.char_num):
                    syn_svg_outfile_best = os.path.join(dir_svgs_single, f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
                    with open(syn_svg_outfile_best, 'r') as best_f:
                        syn_svg_merge_f.write(best_f.read())
                    if i > 0 and i % 13 == 12:
                        syn_svg_merge_f.write('<br>')

                # Append the ground-truth glyphs from the last sample's target.
                svg_target = trg_seq_gt.clone().detach()
                # squeeze(2) removes only the sliced command axis, so a size-1
                # leading axis is not dropped by accident.
                tgt_commands_onehot = F.one_hot(svg_target[:, :, :1].long(), 4).squeeze(2)
                tgt_args_denum = denumericalize(svg_target[:, :, 1:])
                svg_target = torch.cat([tgt_commands_onehot, tgt_args_denum], dim=-1)

                for i, one_gt_seq in enumerate(svg_target):
                    gt_svg = render(one_gt_seq.cpu().numpy())
                    syn_svg_merge_f.write(gt_svg)
                    if i > 0 and i % 13 == 12:
                        syn_svg_merge_f.write('<br>')

    return im
|
153 |
+
|
154 |
+
def main():
    """Parse command-line options and run few-shot sampling on the test set."""
    opts = get_parser_main_model().parse_args()
    # The experiment name is suffixed with the model name.
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    print(f"Testing on experiment {opts.name_exp}...")
    test_main_model(opts)


if __name__ == "__main__":
    main()
|
train.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
import shutil
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.optim import Adam, AdamW
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
import wandb
|
11 |
+
from dataloader import get_loader
|
12 |
+
from models import util_funcs
|
13 |
+
from models.model_main import ModelMain
|
14 |
+
from options import get_parser_main_model
|
15 |
+
from data_utils.svg_utils import render
|
16 |
+
from time import time
|
17 |
+
|
18 |
+
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
|
24 |
+
|
25 |
+
def train_main_model(opts):
    """Train ``ModelMain``, with periodic logging, sampling, validation and checkpoints.

    Args:
        opts: parsed options from ``get_parser_main_model``. The experiment
            sub-directories (``samples``, ``checkpoints``, ``logs``) must
            already exist under ``<exp_path>/experiments/<name_exp>``.
    """
    setup_seed(opts.seed)
    dir_exp = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
    dir_sample = os.path.join(dir_exp, "samples")
    dir_ckpt = os.path.join(dir_exp, "checkpoints")
    dir_log = os.path.join(dir_exp, "logs")

    train_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, opts.mode)
    val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')

    # Only start a wandb run when wandb logging is actually requested.
    run = None
    text_table = None
    if opts.wandb:
        run = wandb.init(project=opts.wandb_project_name, config=opts)  # initialize wandb project
        text_table = wandb.Table(columns=["epoch", "loss", "ref"])

    # Keep a handle on the bare model: DataParallel does not expose the
    # sub-modules as attributes, and checkpoints are stored without the
    # "module." prefix.
    model_core = ModelMain(opts)
    if opts.continue_training:
        model_core.load_state_dict(torch.load(opts.continue_ckpt)['model'])

    model_main = model_core
    if torch.cuda.is_available() and opts.multi_gpu:
        model_main = torch.nn.DataParallel(model_core)
    model_main.cuda()

    parameters_all = [
        {"params": module.parameters()}
        for module in (
            model_core.img_encoder,
            model_core.img_decoder,
            model_core.modality_fusion,
            model_core.transformer_main,
            model_core.transformer_seqdec,
        )
    ]

    optimizer = AdamW(parameters_all, lr=opts.lr, betas=(opts.beta1, opts.beta2), eps=opts.eps, weight_decay=opts.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)

    # `with` closes both log files even if training raises.
    with open(os.path.join(dir_log, "train_loss_log.txt"), 'w') as logfile_train, \
            open(os.path.join(dir_log, "val_loss_log.txt"), 'w') as logfile_val:
        for epoch in range(opts.init_epoch, opts.n_epochs):
            t0 = time()
            for idx, data in enumerate(train_loader):
                for key in data:
                    data[key] = data[key].cuda()
                ret_dict, loss_dict = model_main(data)

                loss = opts.loss_w_l1 * loss_dict['img']['l1'] + opts.loss_w_pt_c * loss_dict['img']['vggpt'] + opts.kl_beta * loss_dict['kl'] \
                    + loss_dict['svg']['total'] + loss_dict['svg_para']['total']

                # perform optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                batches_done = epoch * len(train_loader) + idx + 1

                if batches_done % opts.freq_log == 0:
                    # The message is only built when it is emitted.
                    message = (
                        f"Time: {'{} seconds'.format(time() - t0)}, "
                        f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
                        f"Loss: {loss.item():.6f}, "
                        f"img_l1_loss: {opts.loss_w_l1 * loss_dict['img']['l1'].item():.6f}, "
                        f"img_pt_c_loss: {opts.loss_w_pt_c * loss_dict['img']['vggpt']:.6f}, "
                        f"svg_total_loss: {loss_dict['svg']['total'].item():.6f}, "
                        f"svg_cmd_loss: {opts.loss_w_cmd * loss_dict['svg']['cmd'].item():.6f}, "
                        f"svg_args_loss: {opts.loss_w_args * loss_dict['svg']['args'].item():.6f}, "
                        f"svg_smooth_loss: {opts.loss_w_smt * loss_dict['svg']['smt'].item():.6f}, "
                        f"svg_aux_loss: {opts.loss_w_aux * loss_dict['svg']['aux'].item():.6f}, "
                        f"lr: {optimizer.param_groups[0]['lr']:.6f}, "
                        f"Step: {batches_done}"
                    )
                    logfile_train.write(message + '\n')
                    print(message)

                    # NOTE(review): wandb logging is assumed to share the
                    # freq_log cadence with the text log; confirm.
                    if opts.wandb:
                        # Define the items for image and SVG losses
                        loss_img_items = ['l1', 'vggpt']
                        loss_svg_items = ['total', 'cmd', 'args', 'aux', 'smt']

                        # Log image loss items
                        for item in loss_img_items:
                            wandb.log({f'Loss/img_{item}': loss_dict['img'][item].item()}, step=batches_done)

                        # Log SVG loss items
                        for item in loss_svg_items:
                            wandb.log({f'Loss/svg_{item}': loss_dict['svg'][item].item()}, step=batches_done)
                            wandb.log({f'Loss/svg_para_{item}': loss_dict['svg_para'][item].item()}, step=batches_done)

                        # Log KL loss
                        wandb.log({'Loss/img_kl_loss': opts.kl_beta * loss_dict['kl'].item()}, step=batches_done)

                        wandb.log({
                            'Images/trg_img': wandb.Image(ret_dict['img']['trg'][0], caption="Target"),
                            'Images/img_output': wandb.Image(ret_dict['img']['out'][0], caption="Output")
                        }, step=batches_done)

                        # Store a plain float rather than the loss tensor.
                        text_table.add_data(epoch, loss.item(), str(ret_dict['img']['ref'][0]))
                        wandb.log({"training_samples": text_table})

                if opts.freq_sample > 0 and batches_done % opts.freq_sample == 0:
                    img_sample = torch.cat((ret_dict['img']['trg'].data, ret_dict['img']['out'].data), -2)
                    save_file = os.path.join(dir_sample, f"train_epoch_{epoch}_batch_{batches_done}.png")
                    save_image(img_sample, save_file, nrow=8, normalize=True)

                if opts.freq_val > 0 and batches_done % opts.freq_val == 0:
                    model_main.eval()
                    with torch.no_grad():
                        loss_val = {'img': {'l1': 0.0, 'vggpt': 0.0}, 'svg': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0},
                                    'svg_para': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0}}

                        for val_idx, val_data in enumerate(val_loader):
                            for key in val_data:
                                val_data[key] = val_data[key].cuda()
                            ret_dict_val, loss_dict_val = model_main(val_data, mode='val')
                            for loss_cat in ['img', 'svg']:
                                for key in loss_val[loss_cat]:
                                    loss_val[loss_cat][key] += loss_dict_val[loss_cat][key]

                        for loss_cat in ['img', 'svg']:
                            for key in loss_val[loss_cat]:
                                loss_val[loss_cat][key] /= len(val_loader)

                        if opts.wandb:
                            for loss_cat in ['img', 'svg']:
                                for key, value in loss_val[loss_cat].items():
                                    # Log loss value to WandB
                                    wandb.log({f'VAL/loss_{loss_cat}_{key}': value})

                        val_msg = (
                            f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
                            f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
                            f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
                            f"Val loss total: {loss_val['svg']['total']: .6f}, "
                            f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
                            f"Val loss args: {loss_val['svg']['args']: .6f}, "
                        )

                        logfile_val.write(val_msg + "\n")
                        print(val_msg)
                    # Switch back to training mode; otherwise every step after
                    # the first validation pass would run in eval mode.
                    model_main.train()

            scheduler.step()

            if epoch % opts.freq_ckpt == 0 and epoch >= opts.threshold_ckpt:
                ckpt_path = f'{dir_ckpt}/{epoch}_{batches_done}.ckpt'
                # Always save the unwrapped model's weights so the checkpoint
                # loads into a plain ModelMain.
                torch.save({'model': model_core.state_dict(), 'opt': optimizer.state_dict(), 'n_epoch': epoch, 'n_iter': batches_done}, ckpt_path)
                print(f"Saved {ckpt_path}")
                if opts.wandb:
                    artifact = wandb.Artifact('model_main_checkpoints', type='model')
                    artifact.add_file(ckpt_path)
                    run.log_artifact(artifact)
|
176 |
+
|
177 |
+
def backup_code(name_exp, exp_path):
    """Copy the main source files into ``<exp_path>/experiments/<name_exp>/code``.

    Source paths are relative to the current working directory; the target
    directory is created if needed.
    """
    code_dir = os.path.join(exp_path, 'experiments', name_exp, 'code')
    os.makedirs(code_dir, exist_ok=True)

    sources = (
        'models/transformers.py',
        'models/model_main.py',
        'models/image_encoder.py',
        'models/image_decoder.py',
        './train.py',
        './options.py',
    )
    for src in sources:
        shutil.copy(src, os.path.join(code_dir, os.path.basename(src)))
|
185 |
+
|
186 |
+
def train(opts):
    """Dispatch training according to ``opts.model_name``.

    Only ``'main_model'`` has a training routine in this file.

    Raises:
        NotImplementedError: for any other model name.
    """
    if opts.model_name == 'main_model':
        train_main_model(opts)
    else:
        # The former 'others' branch called train_others, which is not defined
        # in this file, and therefore failed with NameError.
        raise NotImplementedError(
            f"training is not implemented for model_name={opts.model_name!r}"
        )
|
193 |
+
|
194 |
+
def main():
    """Parse options, prepare the experiment directory tree and start training."""
    opts = get_parser_main_model().parse_args()
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    os.makedirs(f"{opts.exp_path}/experiments", exist_ok=True)
    debug = True

    # Create directories
    experiment_dir = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
    # Create the experiment directory BEFORE backing up the code: backup_code
    # creates <experiment_dir>/code itself, which would make this call always
    # fail when exist_ok is False.
    os.makedirs(experiment_dir, exist_ok=debug)  # False to prevent multiple train run by mistake
    backup_code(opts.name_exp, opts.exp_path)
    for sub_dir in ("samples", "checkpoints", "results", "logs"):
        os.makedirs(os.path.join(experiment_dir, sub_dir), exist_ok=True)
    print(f"Training on experiment {opts.name_exp}...")

    # Dump options
    with open(os.path.join(experiment_dir, "opts.txt"), "w") as f:
        for key, value in vars(opts).items():
            f.write(str(key) + ": " + str(value) + "\n")
    train(opts)


if __name__ == "__main__":
    main()
|