microhum commited on
Commit
667ab99
·
1 Parent(s): 34f1982

add dockerfile

Browse files
.gitattributes CHANGED
@@ -1,36 +1,38 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- *.gif filter=lfs diff=lfs merge=lfs -text
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.gif filter=lfs diff=lfs merge=lfs -text
37
+ *ckpt filter=lfs diff=lfs merge=lfs -text
38
+ inference_model filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,91 +1,91 @@
1
- data/
2
- experiments/
3
- inference_few_shot.py
4
- flagged
5
- inference
6
- Font_dataset
7
- venv
8
-
9
- # Byte-compiled / optimized / DLL files
10
- __pycache__/
11
- *.py[cod]
12
- *$py.class
13
-
14
- # C extensions
15
- *.so
16
-
17
- # Distribution / packaging
18
- .Python
19
- build/
20
- develop-eggs/
21
- dist/
22
- downloads/
23
- eggs/
24
- .eggs/
25
- lib/
26
- lib64/
27
- parts/
28
- sdist/
29
- var/
30
- wheels/
31
- share/python-wheels/
32
- *.egg-info/
33
- .installed.cfg
34
- *.egg
35
- MANIFEST
36
-
37
- # PyInstaller
38
- # Usually these files are written by a python script from a template
39
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
- *.manifest
41
- *.spec
42
-
43
- # Installer logs
44
- pip-log.txt
45
- pip-delete-this-directory.txt
46
-
47
- # Unit test / coverage reports
48
- htmlcov/
49
- .tox/
50
- .nox/
51
- .coverage
52
- .coverage.*
53
- .cache
54
- nosetests.xml
55
- coverage.xml
56
- *.cover
57
- *.py,cover
58
- .hypothesis/
59
- .pytest_cache/
60
- cover/
61
-
62
- # Translations
63
- *.mo
64
- *.pot
65
-
66
- # Django stuff:
67
- *.log
68
- local_settings.py
69
- db.sqlite3
70
- db.sqlite3-journal
71
-
72
- # Flask stuff:
73
- instance/
74
- .webassets-cache
75
-
76
- # Scrapy stuff:
77
- .scrapy
78
-
79
- # Sphinx documentation
80
- docs/_build/
81
-
82
- # PyBuilder
83
- .pybuilder/
84
- target/
85
-
86
- # Jupyter Notebook
87
- .ipynb_checkpoints
88
-
89
- # IPython
90
- profile_default/
91
- ipython_config.py
 
1
+ data/
2
+ experiments/
3
+ inference_few_shot.py
4
+ flagged
5
+ inference
6
+ Font_dataset
7
+ venv
8
+
9
+ # Byte-compiled / optimized / DLL files
10
+ __pycache__/
11
+ *.py[cod]
12
+ *$py.class
13
+
14
+ # C extensions
15
+ *.so
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+ cover/
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+ db.sqlite3-journal
71
+
72
+ # Flask stuff:
73
+ instance/
74
+ .webassets-cache
75
+
76
+ # Scrapy stuff:
77
+ .scrapy
78
+
79
+ # Sphinx documentation
80
+ docs/_build/
81
+
82
+ # PyBuilder
83
+ .pybuilder/
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
DockerFile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ RUN xargs apt-get install -y <packages.txt
10
+
11
+ COPY /home/usr/lib/python3/dist-packages/fontforge.cpython-310-x86_64-linux-gnu.so /home/usr/lib/python3.10/dist-packages/
12
+
13
+ COPY . .
14
+
15
+ CMD ["streamlit", "app.py"]
LICENSE CHANGED
@@ -1,21 +1,21 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 Yizhi Wang
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Yizhi Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,10 +1,10 @@
1
- ---
2
- title: ThaiVecFont
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.25.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
1
+ ---
2
+ title: ThaiVecFont
3
+ emoji: 🚀
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: docker
7
+ app_port: 8501
8
+ app_file: app.py
9
+ pinned: false
10
  ---
app.py CHANGED
@@ -1,113 +1,113 @@
1
- from typing import Optional
2
- import streamlit as st
3
- from generate import ttf_to_image
4
- from PIL import Image
5
- import os
6
-
7
- LOADED_TTF_KEY = "loaded_ttf"
8
- SET_IMG_KEY = "set_img"
9
- OUTPUT_IMG_KEY = "output_img"
10
-
11
- def get_ttf(key: str) -> Optional[any]:
12
- if key in st.session_state:
13
- return st.session_state[key]
14
- return None
15
-
16
- def get_img(key: str) -> Optional[Image.Image]:
17
- if key in st.session_state:
18
- return st.session_state[key]
19
- return None
20
-
21
- def set_img(key: str, img: Image.Image):
22
- st.session_state[key] = img
23
-
24
- def ttf_uploader(prefix):
25
- file = st.file_uploader("TTF, OTF", ["ttf", "otf"], key=f"{prefix}-uploader")
26
- if file:
27
- return file
28
-
29
- return get_ttf(LOADED_TTF_KEY)
30
-
31
- def generate_button(prefix, file_input, version, **kwargs):
32
-
33
- col1, col2 = st.columns(2)
34
- with col1:
35
- n_samples = st.slider(
36
- "Number of inference sample",
37
- min_value=1,
38
- max_value=200,
39
- value=20,
40
- key=f"{prefix}-inference-sample",
41
- )
42
- with col2:
43
- ref_char_ids = st.text_area(
44
- "ref_char_ids",
45
- value="1,2,3,4,5,6,7,8",
46
- key=f"{prefix}-ref_char_ids",
47
- )
48
- enable_attention_slicing = st.checkbox(
49
- "Enable attention slicing (enables higher resolutions but is slower)",
50
- key=f"{prefix}-attention-slicing",
51
- )
52
- enable_cpu_offload = st.checkbox(
53
- "Enable CPU offload (if you run out of memory, e.g. for XL model)",
54
- key=f"{prefix}-cpu-offload",
55
- value=False,
56
- )
57
-
58
- if st.button("Generate image", key=f"{prefix}-btn"):
59
- with st.spinner("⏳ Generating image..."):
60
- image = ttf_to_image(file_input, n_samples, ref_char_ids, version)
61
- set_img(OUTPUT_IMG_KEY, image.copy())
62
- st.image(image)
63
-
64
- test_font = st.text_area(
65
- "test font",
66
- value="กขคง",
67
- key=f"{prefix}-prompt",
68
- )
69
-
70
- def generate_tab():
71
- prefix = "ttf2img"
72
- col1, col2 = st.columns(2)
73
-
74
- with col1:
75
- sample_choose = st.selectbox(
76
- "Choose Sample", ["Custom"] + [i for i in os.listdir("font_sample/")], key=f"{prefix}-sample_choose"
77
- )
78
- if sample_choose == "Custom":
79
- uploaded_file = ttf_uploader(prefix)
80
- if uploaded_file:
81
- st.write("filename:", uploaded_file.name)
82
- uploaded_file = uploaded_file.getbuffer() # Send file as Buffer
83
-
84
- else:
85
- st.write("filename:", sample_choose)
86
- uploaded_file = os.path.join("font_sample", sample_choose)
87
-
88
- with col2:
89
- if uploaded_file:
90
- version = st.selectbox(
91
- "Model version", ["TH2TH", "ENG2TH"], key=f"{prefix}-version"
92
- )
93
- generate_button(
94
- prefix, file_input=uploaded_file, version=version
95
- )
96
-
97
- def main():
98
- st.set_page_config(layout="wide")
99
- st.title("ThaiVecFont Playground")
100
-
101
- generate_tab()
102
-
103
- with st.sidebar:
104
- st.header("Latest Output Image")
105
- output_image = get_img(OUTPUT_IMG_KEY)
106
- if output_image:
107
- st.image(output_image)
108
- else:
109
- st.markdown("No output generated yet")
110
-
111
-
112
- if __name__ == "__main__":
113
  main()
 
1
+ from typing import Optional
2
+ import streamlit as st
3
+ from generate import ttf_to_image
4
+ from PIL import Image
5
+ import os
6
+
7
+ LOADED_TTF_KEY = "loaded_ttf"
8
+ SET_IMG_KEY = "set_img"
9
+ OUTPUT_IMG_KEY = "output_img"
10
+
11
+ def get_ttf(key: str) -> Optional[any]:
12
+ if key in st.session_state:
13
+ return st.session_state[key]
14
+ return None
15
+
16
+ def get_img(key: str) -> Optional[Image.Image]:
17
+ if key in st.session_state:
18
+ return st.session_state[key]
19
+ return None
20
+
21
+ def set_img(key: str, img: Image.Image):
22
+ st.session_state[key] = img
23
+
24
+ def ttf_uploader(prefix):
25
+ file = st.file_uploader("TTF, OTF", ["ttf", "otf"], key=f"{prefix}-uploader")
26
+ if file:
27
+ return file
28
+
29
+ return get_ttf(LOADED_TTF_KEY)
30
+
31
+ def generate_button(prefix, file_input, version, **kwargs):
32
+
33
+ col1, col2 = st.columns(2)
34
+ with col1:
35
+ n_samples = st.slider(
36
+ "Number of inference sample",
37
+ min_value=1,
38
+ max_value=200,
39
+ value=20,
40
+ key=f"{prefix}-inference-sample",
41
+ )
42
+ with col2:
43
+ ref_char_ids = st.text_area(
44
+ "ref_char_ids",
45
+ value="1,2,3,4,5,6,7,8",
46
+ key=f"{prefix}-ref_char_ids",
47
+ )
48
+ enable_attention_slicing = st.checkbox(
49
+ "Enable attention slicing (enables higher resolutions but is slower)",
50
+ key=f"{prefix}-attention-slicing",
51
+ )
52
+ enable_cpu_offload = st.checkbox(
53
+ "Enable CPU offload (if you run out of memory, e.g. for XL model)",
54
+ key=f"{prefix}-cpu-offload",
55
+ value=False,
56
+ )
57
+
58
+ if st.button("Generate image", key=f"{prefix}-btn"):
59
+ with st.spinner("⏳ Generating image..."):
60
+ image = ttf_to_image(file_input, n_samples, ref_char_ids, version)
61
+ set_img(OUTPUT_IMG_KEY, image.copy())
62
+ st.image(image)
63
+
64
+ test_font = st.text_area(
65
+ "test font",
66
+ value="กขคง",
67
+ key=f"{prefix}-prompt",
68
+ )
69
+
70
+ def generate_tab():
71
+ prefix = "ttf2img"
72
+ col1, col2 = st.columns(2)
73
+
74
+ with col1:
75
+ sample_choose = st.selectbox(
76
+ "Choose Sample", ["Custom"] + [i for i in os.listdir("font_sample/")], key=f"{prefix}-sample_choose"
77
+ )
78
+ if sample_choose == "Custom":
79
+ uploaded_file = ttf_uploader(prefix)
80
+ if uploaded_file:
81
+ st.write("filename:", uploaded_file.name)
82
+ uploaded_file = uploaded_file.getbuffer() # Send file as Buffer
83
+
84
+ else:
85
+ st.write("filename:", sample_choose)
86
+ uploaded_file = os.path.join("font_sample", sample_choose)
87
+
88
+ with col2:
89
+ if uploaded_file:
90
+ version = st.selectbox(
91
+ "Model version", ["TH2TH", "ENG2TH"], key=f"{prefix}-version"
92
+ )
93
+ generate_button(
94
+ prefix, file_input=uploaded_file, version=version
95
+ )
96
+
97
+ def main():
98
+ st.set_page_config(layout="wide")
99
+ st.title("ThaiVecFont Playground")
100
+
101
+ generate_tab()
102
+
103
+ with st.sidebar:
104
+ st.header("Latest Output Image")
105
+ output_image = get_img(OUTPUT_IMG_KEY)
106
+ if output_image:
107
+ st.image(output_image)
108
+ else:
109
+ st.markdown("No output generated yet")
110
+
111
+
112
+ if __name__ == "__main__":
113
  main()
data_utils/augment.py CHANGED
@@ -1,112 +1,112 @@
1
- import argparse
2
- import multiprocessing as mp
3
- import os
4
- import numpy as np
5
- import math
6
- import cairosvg
7
- import shutil
8
- from data_utils.svg_utils import clockwise, render
9
- from common_utils import affine_shear, affine_rotate, affine_scale, trans2_white_bg
10
-
11
- def render_svg(svg_str, font_dir, char_idx, aug_idx, img_size):
12
- svg_html = render(svg_str)
13
- svg_path = open(f'{font_dir}/aug_svgs/{str(char_idx)}.svg', 'w')
14
- svg_path.write(svg_html)
15
- svg_path.close()
16
- cairosvg.svg2png(url=f'{font_dir}/aug_svgs/{str(char_idx)}.svg',
17
- write_to=f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png', output_width=img_size, output_height=img_size)
18
- img_arr = trans2_white_bg(f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png')
19
- return img_arr
20
-
21
- def aug_rules(char_seq, aug_idx):
22
- if aug_idx == 0:
23
- return clockwise(affine_shear(char_seq, dx=0.2))['sequence']
24
- elif aug_idx == 1:
25
- return clockwise(affine_shear(char_seq, dy=-0.1))['sequence']
26
- elif aug_idx == 2:
27
- return clockwise(affine_scale(char_seq, 0.8))['sequence']
28
- elif aug_idx == 3:
29
- return clockwise(affine_rotate(char_seq, theta=5))['sequence']
30
- else:
31
- return clockwise(affine_rotate(char_seq, theta=-5))['sequence']
32
-
33
- def copy_others(dir_src, dir_tgt):
34
- for item in ['class.npy', 'font_id.npy', 'seq_len.npy']:
35
- shutil.copy(f'{dir_src}/{item}', f'{dir_tgt}/{item}')
36
-
37
- def apply_aug(opts):
38
- """
39
- applying data augmentation for Chinese fonts
40
- """
41
- data_path = os.path.join(opts.output_path, opts.language, opts.split)
42
- font_dirs_ = os.listdir(data_path)
43
- font_dirs = []
44
- for idx in range(len(font_dirs_)):
45
- if '_' not in font_dirs_[idx].split('/')[-1]:
46
- font_dirs.append(font_dirs_[idx])
47
- font_dirs.sort()
48
- num_fonts = len(font_dirs)
49
- print(f"Number {opts.split} fonts before processing", num_fonts)
50
- num_processes = mp.cpu_count() - 2
51
- fonts_per_process = num_fonts // num_processes + 1
52
-
53
- def process(process_id):
54
- for i in range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process):
55
- if i >= num_fonts:
56
- break
57
- font_dir = os.path.join(data_path, font_dirs[i])
58
- font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.n_chars, opts.max_len, -1)
59
-
60
- ret_seq_list = []
61
- ret_img_list = []
62
- for k in range(opts.n_aug):
63
- os.makedirs(font_dir + '_' + str(k), exist_ok=True)
64
- ret_seq_list.append([])
65
- ret_img_list.append([])
66
-
67
- os.makedirs(f'{font_dir}/aug_svgs', exist_ok=True)
68
- os.makedirs(f'{font_dir}/aug_imgs', exist_ok=True)
69
-
70
- for j in range(opts.n_chars):
71
- char_seq = font_seq[j] # default as [71, 12]
72
- for k in range(opts.n_aug):
73
- char_seq_aug = aug_rules(char_seq, k)
74
- ret_seq_list[k].append(char_seq_aug)
75
- img_arr = render_svg(char_seq_aug, font_dir, j, aug_idx=k, img_size=opts.img_size)
76
- ret_img_list[k].append(img_arr)
77
-
78
- for k in range(opts.n_aug):
79
- ret_seq_list[k] = np.array(ret_seq_list[k]).reshape(opts.n_chars, opts.max_len * 10)
80
- ret_img_list[k] = np.array(ret_img_list[k]).reshape(opts.n_chars, opts.img_size, opts.img_size)
81
- np.save(os.path.join(font_dir + '_' + str(k), f'sequence.npy'), ret_seq_list[k])
82
- np.save(os.path.join(font_dir + '_' + str(k), f'rendered_{opts.img_size}.npy'), ret_img_list[k])
83
- copy_others(font_dir, font_dir + '_' + str(k))
84
-
85
- processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]
86
-
87
- for p in processes:
88
- p.start()
89
- for p in processes:
90
- p.join()
91
-
92
-
93
- def main():
94
- parser = argparse.ArgumentParser(description="relax representation")
95
- parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
96
- parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
97
- parser.add_argument('--max_len', type=int, default=71, help="by default, 51 for english and 71 for chinese")
98
- parser.add_argument('--n_aug', type=int, default=5, help="for each font, augment it for n_aug times")
99
- parser.add_argument('--n_chars', type=int, default=52)
100
- parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images")
101
- parser.add_argument("--split", type=str, default='train')
102
- parser.add_argument('--debug', type=bool, default=True)
103
- opts = parser.parse_args()
104
- apply_aug(opts)
105
-
106
- if __name__ == "__main__":
107
- main()
108
-
109
-
110
-
111
-
112
 
 
1
+ import argparse
2
+ import multiprocessing as mp
3
+ import os
4
+ import numpy as np
5
+ import math
6
+ import cairosvg
7
+ import shutil
8
+ from data_utils.svg_utils import clockwise, render
9
+ from common_utils import affine_shear, affine_rotate, affine_scale, trans2_white_bg
10
+
11
+ def render_svg(svg_str, font_dir, char_idx, aug_idx, img_size):
12
+ svg_html = render(svg_str)
13
+ svg_path = open(f'{font_dir}/aug_svgs/{str(char_idx)}.svg', 'w')
14
+ svg_path.write(svg_html)
15
+ svg_path.close()
16
+ cairosvg.svg2png(url=f'{font_dir}/aug_svgs/{str(char_idx)}.svg',
17
+ write_to=f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png', output_width=img_size, output_height=img_size)
18
+ img_arr = trans2_white_bg(f'{font_dir}/aug_imgs/{str(char_idx)}_{aug_idx}.png')
19
+ return img_arr
20
+
21
+ def aug_rules(char_seq, aug_idx):
22
+ if aug_idx == 0:
23
+ return clockwise(affine_shear(char_seq, dx=0.2))['sequence']
24
+ elif aug_idx == 1:
25
+ return clockwise(affine_shear(char_seq, dy=-0.1))['sequence']
26
+ elif aug_idx == 2:
27
+ return clockwise(affine_scale(char_seq, 0.8))['sequence']
28
+ elif aug_idx == 3:
29
+ return clockwise(affine_rotate(char_seq, theta=5))['sequence']
30
+ else:
31
+ return clockwise(affine_rotate(char_seq, theta=-5))['sequence']
32
+
33
+ def copy_others(dir_src, dir_tgt):
34
+ for item in ['class.npy', 'font_id.npy', 'seq_len.npy']:
35
+ shutil.copy(f'{dir_src}/{item}', f'{dir_tgt}/{item}')
36
+
37
+ def apply_aug(opts):
38
+ """
39
+ applying data augmentation for Chinese fonts
40
+ """
41
+ data_path = os.path.join(opts.output_path, opts.language, opts.split)
42
+ font_dirs_ = os.listdir(data_path)
43
+ font_dirs = []
44
+ for idx in range(len(font_dirs_)):
45
+ if '_' not in font_dirs_[idx].split('/')[-1]:
46
+ font_dirs.append(font_dirs_[idx])
47
+ font_dirs.sort()
48
+ num_fonts = len(font_dirs)
49
+ print(f"Number {opts.split} fonts before processing", num_fonts)
50
+ num_processes = mp.cpu_count() - 2
51
+ fonts_per_process = num_fonts // num_processes + 1
52
+
53
+ def process(process_id):
54
+ for i in range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process):
55
+ if i >= num_fonts:
56
+ break
57
+ font_dir = os.path.join(data_path, font_dirs[i])
58
+ font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.n_chars, opts.max_len, -1)
59
+
60
+ ret_seq_list = []
61
+ ret_img_list = []
62
+ for k in range(opts.n_aug):
63
+ os.makedirs(font_dir + '_' + str(k), exist_ok=True)
64
+ ret_seq_list.append([])
65
+ ret_img_list.append([])
66
+
67
+ os.makedirs(f'{font_dir}/aug_svgs', exist_ok=True)
68
+ os.makedirs(f'{font_dir}/aug_imgs', exist_ok=True)
69
+
70
+ for j in range(opts.n_chars):
71
+ char_seq = font_seq[j] # default as [71, 12]
72
+ for k in range(opts.n_aug):
73
+ char_seq_aug = aug_rules(char_seq, k)
74
+ ret_seq_list[k].append(char_seq_aug)
75
+ img_arr = render_svg(char_seq_aug, font_dir, j, aug_idx=k, img_size=opts.img_size)
76
+ ret_img_list[k].append(img_arr)
77
+
78
+ for k in range(opts.n_aug):
79
+ ret_seq_list[k] = np.array(ret_seq_list[k]).reshape(opts.n_chars, opts.max_len * 10)
80
+ ret_img_list[k] = np.array(ret_img_list[k]).reshape(opts.n_chars, opts.img_size, opts.img_size)
81
+ np.save(os.path.join(font_dir + '_' + str(k), f'sequence.npy'), ret_seq_list[k])
82
+ np.save(os.path.join(font_dir + '_' + str(k), f'rendered_{opts.img_size}.npy'), ret_img_list[k])
83
+ copy_others(font_dir, font_dir + '_' + str(k))
84
+
85
+ processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]
86
+
87
+ for p in processes:
88
+ p.start()
89
+ for p in processes:
90
+ p.join()
91
+
92
+
93
+ def main():
94
+ parser = argparse.ArgumentParser(description="relax representation")
95
+ parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
96
+ parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
97
+ parser.add_argument('--max_len', type=int, default=71, help="by default, 51 for english and 71 for chinese")
98
+ parser.add_argument('--n_aug', type=int, default=5, help="for each font, augment it for n_aug times")
99
+ parser.add_argument('--n_chars', type=int, default=52)
100
+ parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images")
101
+ parser.add_argument("--split", type=str, default='train')
102
+ parser.add_argument('--debug', type=bool, default=True)
103
+ opts = parser.parse_args()
104
+ apply_aug(opts)
105
+
106
+ if __name__ == "__main__":
107
+ main()
108
+
109
+
110
+
111
+
112
 
data_utils/common_utils.py CHANGED
@@ -1,75 +1,75 @@
1
- import math
2
- import numpy as np
3
- from PIL import Image
4
-
5
- def trans2_white_bg(img_path):
6
- img = Image.open(img_path)
7
- img_arr = 255 - np.array(img)[:, :, 3]
8
- img_ = Image.fromarray(img_arr)
9
- img_.save(img_path)
10
- return img_arr
11
-
12
- def affine_shear(seq, dx=-0.3, dy=0.0):
13
- mask = ~(seq == 0)
14
- seq_12 = seq.copy()
15
- seq_12[:,4] -= 12.0
16
- seq_12[:,5] = -seq_12[:,5] + 12
17
- seq_12[:,6] -= 12.0
18
- seq_12[:,7] = -seq_12[:,7] + 12
19
- seq_12[:,8] -= 12.0
20
- seq_12[:,9] = -seq_12[:,9] + 12
21
-
22
- seq_args = seq_12[:,4:]
23
- seq_args = np.concatenate([seq_args[:, :2], seq_args[:, 2:4], seq_args[:, 4:6]], 0).transpose(1,0)
24
- affine_matrix=np.array([[1, dx],
25
- [dy, 1]])
26
- rotated_args = np.dot(affine_matrix,seq_args)
27
- rotated_args = rotated_args.transpose(1,0)
28
- new_args = np.concatenate([rotated_args[:seq.shape[0]], rotated_args[seq.shape[0]:seq.shape[0]*2], rotated_args[seq.shape[0]*2:]],-1)
29
- new_args[:,0] += 12.0
30
- new_args[:,1] = -(new_args[:,1] - 12)
31
- new_args[:,2] += 12.0
32
- new_args[:,3] = -(new_args[:,3] - 12)
33
- new_args[:,4] += 12.0
34
- new_args[:,5] = -(new_args[:,5] - 12)
35
- new_seq = np.concatenate([seq[:, :4], new_args],1)
36
- new_seq = new_seq * mask
37
- return new_seq
38
-
39
- def affine_scale(seq, scale=0.8):
40
- mask = ~(seq==0)
41
- seq_args = seq[:, 4:] - 12.0
42
- seq_args *= scale
43
- seq_args = seq_args + 12.0
44
- new_seq = np.concatenate([seq[:, :4], seq_args], 1)
45
- new_seq = new_seq * mask
46
- return new_seq
47
-
48
- def affine_rotate(seq,theta=-5):
49
- mask = ~(seq==0)
50
- seq_12 = seq.copy()
51
- seq_12[:,4] -=12.0
52
- seq_12[:,5] = -seq_12[:,5] + 12
53
- seq_12[:,6] -=12.0
54
- seq_12[:,7] = -seq_12[:,7] + 12
55
- seq_12[:,8] -=12.0
56
- seq_12[:,9] = -seq_12[:,9] + 12
57
-
58
- seq_args =seq_12[:, 4:] # default as [71,6]
59
- seq_args = np.concatenate([seq_args[:,:2],seq_args[:,2:4],seq_args[:,4:6]],0).transpose(1,0)# note 2,213
60
- theta = math.radians(theta)
61
- affine_matrix=np.array([[np.cos(theta),-np.sin(theta)], [np.sin(theta), np.cos(theta)]])# note 2,2
62
- rotated_args = np.dot(affine_matrix,seq_args)# note 2,213
63
- rotated_args = rotated_args.transpose(1,0)# note 213,2
64
- new_args = np.concatenate([rotated_args[:seq.shape[0]],rotated_args[seq.shape[0]:seq.shape[0]*2],rotated_args[seq.shape[0]*2:]],-1)# note 2,213
65
- new_args[:,0] +=12.0
66
- new_args[:,1] = -(new_args[:,1]-12)
67
- new_args[:,2] +=12.0
68
- new_args[:,3] = -(new_args[:,3]-12)
69
- new_args[:,4] +=12.0
70
- new_args[:,5] = -(new_args[:,5]-12)
71
-
72
- new_seq = np.concatenate([seq[:,:4],new_args],1)
73
- new_seq =new_seq *mask
74
- return new_seq
75
-
 
1
+ import math
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+ def trans2_white_bg(img_path):
6
+ img = Image.open(img_path)
7
+ img_arr = 255 - np.array(img)[:, :, 3]
8
+ img_ = Image.fromarray(img_arr)
9
+ img_.save(img_path)
10
+ return img_arr
11
+
12
+ def affine_shear(seq, dx=-0.3, dy=0.0):
13
+ mask = ~(seq == 0)
14
+ seq_12 = seq.copy()
15
+ seq_12[:,4] -= 12.0
16
+ seq_12[:,5] = -seq_12[:,5] + 12
17
+ seq_12[:,6] -= 12.0
18
+ seq_12[:,7] = -seq_12[:,7] + 12
19
+ seq_12[:,8] -= 12.0
20
+ seq_12[:,9] = -seq_12[:,9] + 12
21
+
22
+ seq_args = seq_12[:,4:]
23
+ seq_args = np.concatenate([seq_args[:, :2], seq_args[:, 2:4], seq_args[:, 4:6]], 0).transpose(1,0)
24
+ affine_matrix=np.array([[1, dx],
25
+ [dy, 1]])
26
+ rotated_args = np.dot(affine_matrix,seq_args)
27
+ rotated_args = rotated_args.transpose(1,0)
28
+ new_args = np.concatenate([rotated_args[:seq.shape[0]], rotated_args[seq.shape[0]:seq.shape[0]*2], rotated_args[seq.shape[0]*2:]],-1)
29
+ new_args[:,0] += 12.0
30
+ new_args[:,1] = -(new_args[:,1] - 12)
31
+ new_args[:,2] += 12.0
32
+ new_args[:,3] = -(new_args[:,3] - 12)
33
+ new_args[:,4] += 12.0
34
+ new_args[:,5] = -(new_args[:,5] - 12)
35
+ new_seq = np.concatenate([seq[:, :4], new_args],1)
36
+ new_seq = new_seq * mask
37
+ return new_seq
38
+
39
+ def affine_scale(seq, scale=0.8):
40
+ mask = ~(seq==0)
41
+ seq_args = seq[:, 4:] - 12.0
42
+ seq_args *= scale
43
+ seq_args = seq_args + 12.0
44
+ new_seq = np.concatenate([seq[:, :4], seq_args], 1)
45
+ new_seq = new_seq * mask
46
+ return new_seq
47
+
48
+ def affine_rotate(seq,theta=-5):
49
+ mask = ~(seq==0)
50
+ seq_12 = seq.copy()
51
+ seq_12[:,4] -=12.0
52
+ seq_12[:,5] = -seq_12[:,5] + 12
53
+ seq_12[:,6] -=12.0
54
+ seq_12[:,7] = -seq_12[:,7] + 12
55
+ seq_12[:,8] -=12.0
56
+ seq_12[:,9] = -seq_12[:,9] + 12
57
+
58
+ seq_args =seq_12[:, 4:] # default as [71,6]
59
+ seq_args = np.concatenate([seq_args[:,:2],seq_args[:,2:4],seq_args[:,4:6]],0).transpose(1,0)# note 2,213
60
+ theta = math.radians(theta)
61
+ affine_matrix=np.array([[np.cos(theta),-np.sin(theta)], [np.sin(theta), np.cos(theta)]])# note 2,2
62
+ rotated_args = np.dot(affine_matrix,seq_args)# note 2,213
63
+ rotated_args = rotated_args.transpose(1,0)# note 213,2
64
+ new_args = np.concatenate([rotated_args[:seq.shape[0]],rotated_args[seq.shape[0]:seq.shape[0]*2],rotated_args[seq.shape[0]*2:]],-1)# note 2,213
65
+ new_args[:,0] +=12.0
66
+ new_args[:,1] = -(new_args[:,1]-12)
67
+ new_args[:,2] +=12.0
68
+ new_args[:,3] = -(new_args[:,3]-12)
69
+ new_args[:,4] +=12.0
70
+ new_args[:,5] = -(new_args[:,5]-12)
71
+
72
+ new_seq = np.concatenate([seq[:,:4],new_args],1)
73
+ new_seq =new_seq *mask
74
+ return new_seq
75
+
data_utils/convert_ttf_to_sfd.py CHANGED
@@ -1,103 +1,103 @@
1
- import fontforge # noqa
2
- import os
3
- import sys
4
- from tqdm import tqdm
5
- import multiprocessing as mp
6
- import argparse
7
-
8
-
9
- def convert_mp(opts):
10
- """Useing multiprocessing to convert all fonts to sfd files"""
11
-
12
- charset_th = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
13
- charset = charset_th
14
- if opts.ref_nshot == 52:
15
- charset_eng = open(f"{opts.data_path}/char_set/eng.txt", 'r').read()
16
- charset = charset_th + charset_eng
17
- charset_lenw = len(str(len(charset)))
18
- fonts_file_path = os.path.join(opts.ttf_path, opts.language) # opts.ttf_path,opts.language,
19
- sfd_path = os.path.join(opts.sfd_path, opts.language)
20
- print(os.path.join(fonts_file_path, opts.split))
21
- for root, dirs, files in os.walk(os.path.join(fonts_file_path, opts.split)):
22
- ttf_fnames = files
23
- print(ttf_fnames)
24
-
25
- font_num = len(ttf_fnames)
26
- process_num = mp.cpu_count() - 1
27
- font_num_per_process = font_num // process_num + 1
28
-
29
- def process(process_id, font_num_p_process):
30
- for i in tqdm(range(process_id * font_num_p_process, (process_id + 1) * font_num_p_process)):
31
- if i >= font_num:
32
- break
33
-
34
- font_id = ttf_fnames[i].split('.')[0]
35
- split = opts.split
36
- font_name = ttf_fnames[i]
37
-
38
- font_file_path = os.path.join(fonts_file_path, split, font_name)
39
- try:
40
- cur_font = fontforge.open(font_file_path)
41
- except Exception as e:
42
- print('Cannot open ', font_name)
43
- print(e)
44
- continue
45
-
46
- target_dir = os.path.join(sfd_path, split, "{}".format(font_id))
47
- if not os.path.exists(target_dir):
48
- os.makedirs(target_dir)
49
-
50
- for char_id, char in enumerate(charset):
51
- try:
52
- char_description = open(os.path.join(target_dir, '{}_{num:0{width}}.txt'.format(font_id, num=char_id, width=charset_lenw)), 'w')
53
- if char in charset_th:
54
- char = 'uni' + char.encode("unicode_escape")[2:].decode("utf-8")
55
-
56
- cur_font.selection.select(char)
57
- cur_font.copy()
58
-
59
- new_font_for_char = fontforge.font()
60
- # new_font_for_char.ascent = 750
61
- # new_font_for_char.descent = 250
62
- # new_font_for_char.em = new_font_for_char.ascent + new_font_for_char.descent
63
- char = 'A'
64
-
65
- new_font_for_char.selection.select(char)
66
- new_font_for_char.paste()
67
- new_font_for_char.fontname = "{}_".format(font_id) + font_name
68
-
69
- new_font_for_char.save(os.path.join(target_dir, '{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=charset_lenw)))
70
-
71
- char_description.write(str(ord(char)) + '\n')
72
- char_description.write(str(new_font_for_char[char].width) + '\n')
73
- char_description.write(str(new_font_for_char[char].vwidth) + '\n')
74
- char_description.write('{num:0{width}}'.format(num=char_id, width=charset_lenw) + '\n')
75
- char_description.write('{}'.format(font_id))
76
- # print('{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=charset_lenw))
77
- char_description.close()
78
- except Exception as e:
79
- print("Found Error:", font_id, font_name ,char_id, char)
80
- print(e)
81
-
82
- cur_font.close()
83
-
84
- processes = [mp.Process(target=process, args=(pid, font_num_per_process)) for pid in range(process_num)]
85
-
86
- for p in processes:
87
- p.start()
88
- for p in processes:
89
- p.join()
90
-
91
-
92
- def main():
93
- parser = argparse.ArgumentParser(description="Convert ttf fonts to sfd fonts")
94
- parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
95
- parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
96
- parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
97
- parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
98
- parser.add_argument('--split', type=str, default='train')
99
- opts = parser.parse_args()
100
- convert_mp(opts)
101
-
102
- if __name__ == "__main__":
103
  main()
 
1
+ import fontforge # noqa
2
+ import os
3
+ import sys
4
+ from tqdm import tqdm
5
+ import multiprocessing as mp
6
+ import argparse
7
+
8
+
9
+ def convert_mp(opts):
10
+ """Useing multiprocessing to convert all fonts to sfd files"""
11
+
12
+ charset_th = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
13
+ charset = charset_th
14
+ if opts.ref_nshot == 52:
15
+ charset_eng = open(f"{opts.data_path}/char_set/eng.txt", 'r').read()
16
+ charset = charset_th + charset_eng
17
+ charset_lenw = len(str(len(charset)))
18
+ fonts_file_path = os.path.join(opts.ttf_path, opts.language) # opts.ttf_path,opts.language,
19
+ sfd_path = os.path.join(opts.sfd_path, opts.language)
20
+ print(os.path.join(fonts_file_path, opts.split))
21
+ for root, dirs, files in os.walk(os.path.join(fonts_file_path, opts.split)):
22
+ ttf_fnames = files
23
+ print(ttf_fnames)
24
+
25
+ font_num = len(ttf_fnames)
26
+ process_num = mp.cpu_count() - 1
27
+ font_num_per_process = font_num // process_num + 1
28
+
29
+ def process(process_id, font_num_p_process):
30
+ for i in tqdm(range(process_id * font_num_p_process, (process_id + 1) * font_num_p_process)):
31
+ if i >= font_num:
32
+ break
33
+
34
+ font_id = ttf_fnames[i].split('.')[0]
35
+ split = opts.split
36
+ font_name = ttf_fnames[i]
37
+
38
+ font_file_path = os.path.join(fonts_file_path, split, font_name)
39
+ try:
40
+ cur_font = fontforge.open(font_file_path)
41
+ except Exception as e:
42
+ print('Cannot open ', font_name)
43
+ print(e)
44
+ continue
45
+
46
+ target_dir = os.path.join(sfd_path, split, "{}".format(font_id))
47
+ if not os.path.exists(target_dir):
48
+ os.makedirs(target_dir)
49
+
50
+ for char_id, char in enumerate(charset):
51
+ try:
52
+ char_description = open(os.path.join(target_dir, '{}_{num:0{width}}.txt'.format(font_id, num=char_id, width=charset_lenw)), 'w')
53
+ if char in charset_th:
54
+ char = 'uni' + char.encode("unicode_escape")[2:].decode("utf-8")
55
+
56
+ cur_font.selection.select(char)
57
+ cur_font.copy()
58
+
59
+ new_font_for_char = fontforge.font()
60
+ # new_font_for_char.ascent = 750
61
+ # new_font_for_char.descent = 250
62
+ # new_font_for_char.em = new_font_for_char.ascent + new_font_for_char.descent
63
+ char = 'A'
64
+
65
+ new_font_for_char.selection.select(char)
66
+ new_font_for_char.paste()
67
+ new_font_for_char.fontname = "{}_".format(font_id) + font_name
68
+
69
+ new_font_for_char.save(os.path.join(target_dir, '{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=charset_lenw)))
70
+
71
+ char_description.write(str(ord(char)) + '\n')
72
+ char_description.write(str(new_font_for_char[char].width) + '\n')
73
+ char_description.write(str(new_font_for_char[char].vwidth) + '\n')
74
+ char_description.write('{num:0{width}}'.format(num=char_id, width=charset_lenw) + '\n')
75
+ char_description.write('{}'.format(font_id))
76
+ # print('{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=charset_lenw))
77
+ char_description.close()
78
+ except Exception as e:
79
+ print("Found Error:", font_id, font_name ,char_id, char)
80
+ print(e)
81
+
82
+ cur_font.close()
83
+
84
+ processes = [mp.Process(target=process, args=(pid, font_num_per_process)) for pid in range(process_num)]
85
+
86
+ for p in processes:
87
+ p.start()
88
+ for p in processes:
89
+ p.join()
90
+
91
+
92
+ def main():
93
+ parser = argparse.ArgumentParser(description="Convert ttf fonts to sfd fonts")
94
+ parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
95
+ parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
96
+ parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
97
+ parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
98
+ parser.add_argument('--split', type=str, default='train')
99
+ opts = parser.parse_args()
100
+ convert_mp(opts)
101
+
102
+ if __name__ == "__main__":
103
  main()
data_utils/relax_rep.py CHANGED
@@ -1,136 +1,136 @@
1
- import argparse
2
- import multiprocessing as mp
3
- import os
4
- import numpy as np
5
- from dataclasses import dataclass
6
- from tqdm import tqdm
7
-
8
- def numericalize(cmd, n=64):
9
- """NOTE: shall only be called after normalization"""
10
- cmd = ((cmd) / 30 * n).round().clip(min=0, max=n-1).astype(int)
11
- return cmd
12
-
13
- def denumericalize(cmd, n=64):
14
- cmd = cmd / n * 30
15
- return cmd
16
-
17
- def cal_aux_bezier_pts(font_seq, opts):
18
- """
19
- calculate aux pts along bezier curves
20
- """
21
- pts_aux_all = []
22
-
23
- for j in range(opts.char_num):
24
- char_seq = font_seq[j] # shape: opts.max_len ,12
25
- pts_aux_char = []
26
- for k in range(opts.max_seq_len):
27
- stroke_seq = char_seq[k]
28
- stroke_cmd = np.argmax(stroke_seq[:4], -1)
29
- stroke_seq[4:] = denumericalize(numericalize(stroke_seq[4:]))
30
- p0, p1, p2, p3 = stroke_seq[4:6], stroke_seq[6:8], stroke_seq[8:10], stroke_seq[10:12]
31
- pts_aux_stroke = []
32
- if stroke_cmd == 0:
33
- for t in range(6):
34
- pts_aux_stroke.append(0)
35
- elif stroke_cmd == 1: # move
36
- for t in [0.25, 0.5, 0.75]:
37
- coord_t = p0 + t*(p3-p0)
38
- pts_aux_stroke.append(coord_t[0])
39
- pts_aux_stroke.append(coord_t[1])
40
- elif stroke_cmd == 2: # line
41
- for t in [0.25, 0.5, 0.75]:
42
- coord_t = p0 + t*(p3-p0)
43
- pts_aux_stroke.append(coord_t[0])
44
- pts_aux_stroke.append(coord_t[1])
45
- elif stroke_cmd == 3: # curve
46
- for t in [0.25, 0.5, 0.75]:
47
- coord_t = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
48
- pts_aux_stroke.append(coord_t[0])
49
- pts_aux_stroke.append(coord_t[1])
50
-
51
- pts_aux_stroke = np.array(pts_aux_stroke)
52
- pts_aux_char.append(pts_aux_stroke)
53
-
54
- pts_aux_char = np.array(pts_aux_char)
55
- pts_aux_all.append(pts_aux_char)
56
-
57
- pts_aux_all = np.array(pts_aux_all)
58
-
59
- return pts_aux_all
60
-
61
-
62
- def relax_rep(opts):
63
- """
64
- relaxing the sequence representation, details are shown in paper
65
- """
66
- data_path = os.path.join(opts.output_path, opts.language, opts.split)
67
- font_dirs = os.listdir(data_path)
68
- font_dirs.sort()
69
- num_fonts = len(font_dirs)
70
- print(f"Number {opts.split} fonts before processing", num_fonts)
71
- num_processes = mp.cpu_count() - 1
72
- # num_processes = 1
73
- fonts_per_process = num_fonts // num_processes + 1
74
-
75
- def process(process_id):
76
-
77
- for i in tqdm(range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process)):
78
- if i >= num_fonts:
79
- break
80
-
81
- font_dir = os.path.join(data_path, font_dirs[i])
82
- font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.char_num, opts.max_seq_len, -1)
83
- font_len = np.load(os.path.join(font_dir, 'seq_len.npy')).reshape(-1)
84
- cmd = font_seq[:, :, :4]
85
- args = font_seq[:, :, 4:]
86
-
87
- ret = []
88
- for j in range(opts.char_num):
89
-
90
- char_cmds = cmd[j]
91
- char_args = args[j]
92
- char_len = font_len[j]
93
- new_args = []
94
- for k in range(char_len):
95
- cur_cls = np.argmax(char_cmds[k], -1)
96
- cur_arg = char_args[k]
97
- if k - 1 > -1:
98
- pre_arg = char_args[k - 1]
99
- if cur_cls == 1: # when k == 0, cur_cls == 1
100
- cur_arg = np.concatenate((np.array([cur_arg[-2], cur_arg[-1]]), cur_arg), -1)
101
- else:
102
- cur_arg = np.concatenate((np.array([pre_arg[-2], pre_arg[-1]]), cur_arg), -1)
103
- new_args.append(cur_arg)
104
-
105
- while(len(new_args)) < opts.max_seq_len:
106
- new_args.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
107
-
108
- new_args = np.array(new_args)
109
- new_seq = np.concatenate((char_cmds, new_args),-1)
110
- ret.append(new_seq)
111
- ret = np.array(ret)
112
- # write relaxed version of sequence.npy
113
- np.save(os.path.join(font_dir, 'sequence_relaxed.npy'), ret.reshape(opts.char_num, -1))
114
-
115
- pts_aux = cal_aux_bezier_pts(ret, opts)
116
- np.save(os.path.join(font_dir, 'pts_aux.npy'), pts_aux)
117
-
118
- processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]
119
-
120
- for p in processes:
121
- p.start()
122
- for p in processes:
123
- p.join()
124
-
125
-
126
- def main():
127
- parser = argparse.ArgumentParser(description="relax representation")
128
- parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
129
- parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
130
- parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
131
- parser.add_argument("--split", type=str, default='train')
132
- opts = parser.parse_args()
133
- relax_rep(opts)
134
-
135
- if __name__ == "__main__":
136
  main()
 
1
+ import argparse
2
+ import multiprocessing as mp
3
+ import os
4
+ import numpy as np
5
+ from dataclasses import dataclass
6
+ from tqdm import tqdm
7
+
8
+ def numericalize(cmd, n=64):
9
+ """NOTE: shall only be called after normalization"""
10
+ cmd = ((cmd) / 30 * n).round().clip(min=0, max=n-1).astype(int)
11
+ return cmd
12
+
13
+ def denumericalize(cmd, n=64):
14
+ cmd = cmd / n * 30
15
+ return cmd
16
+
17
+ def cal_aux_bezier_pts(font_seq, opts):
18
+ """
19
+ calculate aux pts along bezier curves
20
+ """
21
+ pts_aux_all = []
22
+
23
+ for j in range(opts.char_num):
24
+ char_seq = font_seq[j] # shape: opts.max_len ,12
25
+ pts_aux_char = []
26
+ for k in range(opts.max_seq_len):
27
+ stroke_seq = char_seq[k]
28
+ stroke_cmd = np.argmax(stroke_seq[:4], -1)
29
+ stroke_seq[4:] = denumericalize(numericalize(stroke_seq[4:]))
30
+ p0, p1, p2, p3 = stroke_seq[4:6], stroke_seq[6:8], stroke_seq[8:10], stroke_seq[10:12]
31
+ pts_aux_stroke = []
32
+ if stroke_cmd == 0:
33
+ for t in range(6):
34
+ pts_aux_stroke.append(0)
35
+ elif stroke_cmd == 1: # move
36
+ for t in [0.25, 0.5, 0.75]:
37
+ coord_t = p0 + t*(p3-p0)
38
+ pts_aux_stroke.append(coord_t[0])
39
+ pts_aux_stroke.append(coord_t[1])
40
+ elif stroke_cmd == 2: # line
41
+ for t in [0.25, 0.5, 0.75]:
42
+ coord_t = p0 + t*(p3-p0)
43
+ pts_aux_stroke.append(coord_t[0])
44
+ pts_aux_stroke.append(coord_t[1])
45
+ elif stroke_cmd == 3: # curve
46
+ for t in [0.25, 0.5, 0.75]:
47
+ coord_t = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
48
+ pts_aux_stroke.append(coord_t[0])
49
+ pts_aux_stroke.append(coord_t[1])
50
+
51
+ pts_aux_stroke = np.array(pts_aux_stroke)
52
+ pts_aux_char.append(pts_aux_stroke)
53
+
54
+ pts_aux_char = np.array(pts_aux_char)
55
+ pts_aux_all.append(pts_aux_char)
56
+
57
+ pts_aux_all = np.array(pts_aux_all)
58
+
59
+ return pts_aux_all
60
+
61
+
62
+ def relax_rep(opts):
63
+ """
64
+ relaxing the sequence representation, details are shown in paper
65
+ """
66
+ data_path = os.path.join(opts.output_path, opts.language, opts.split)
67
+ font_dirs = os.listdir(data_path)
68
+ font_dirs.sort()
69
+ num_fonts = len(font_dirs)
70
+ print(f"Number {opts.split} fonts before processing", num_fonts)
71
+ num_processes = mp.cpu_count() - 1
72
+ # num_processes = 1
73
+ fonts_per_process = num_fonts // num_processes + 1
74
+
75
+ def process(process_id):
76
+
77
+ for i in tqdm(range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process)):
78
+ if i >= num_fonts:
79
+ break
80
+
81
+ font_dir = os.path.join(data_path, font_dirs[i])
82
+ font_seq = np.load(os.path.join(font_dir, 'sequence.npy')).reshape(opts.char_num, opts.max_seq_len, -1)
83
+ font_len = np.load(os.path.join(font_dir, 'seq_len.npy')).reshape(-1)
84
+ cmd = font_seq[:, :, :4]
85
+ args = font_seq[:, :, 4:]
86
+
87
+ ret = []
88
+ for j in range(opts.char_num):
89
+
90
+ char_cmds = cmd[j]
91
+ char_args = args[j]
92
+ char_len = font_len[j]
93
+ new_args = []
94
+ for k in range(char_len):
95
+ cur_cls = np.argmax(char_cmds[k], -1)
96
+ cur_arg = char_args[k]
97
+ if k - 1 > -1:
98
+ pre_arg = char_args[k - 1]
99
+ if cur_cls == 1: # when k == 0, cur_cls == 1
100
+ cur_arg = np.concatenate((np.array([cur_arg[-2], cur_arg[-1]]), cur_arg), -1)
101
+ else:
102
+ cur_arg = np.concatenate((np.array([pre_arg[-2], pre_arg[-1]]), cur_arg), -1)
103
+ new_args.append(cur_arg)
104
+
105
+ while(len(new_args)) < opts.max_seq_len:
106
+ new_args.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
107
+
108
+ new_args = np.array(new_args)
109
+ new_seq = np.concatenate((char_cmds, new_args),-1)
110
+ ret.append(new_seq)
111
+ ret = np.array(ret)
112
+ # write relaxed version of sequence.npy
113
+ np.save(os.path.join(font_dir, 'sequence_relaxed.npy'), ret.reshape(opts.char_num, -1))
114
+
115
+ pts_aux = cal_aux_bezier_pts(ret, opts)
116
+ np.save(os.path.join(font_dir, 'pts_aux.npy'), pts_aux)
117
+
118
+ processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]
119
+
120
+ for p in processes:
121
+ p.start()
122
+ for p in processes:
123
+ p.join()
124
+
125
+
126
+ def main():
127
+ parser = argparse.ArgumentParser(description="relax representation")
128
+ parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
129
+ parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
130
+ parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
131
+ parser.add_argument("--split", type=str, default='train')
132
+ opts = parser.parse_args()
133
+ relax_rep(opts)
134
+
135
+ if __name__ == "__main__":
136
  main()
data_utils/svg_utils.py CHANGED
@@ -1,1083 +1,1083 @@
1
- # Copyright 2020 The Magenta Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Lint as: python3
16
- """Defines the Material Design Icons Problem."""
17
- import io
18
- import numpy as np
19
- import re
20
-
21
- from PIL import Image
22
- from itertools import zip_longest
23
- from skimage import draw
24
-
25
-
26
- SVG_PREFIX_BIG = ('<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="'
27
- 'http://www.w3.org/1999/xlink" width="256px" height="256px"'
28
- ' style="-ms-transform: rotate(360deg); -webkit-transform:'
29
- ' rotate(360deg); transform: rotate(360deg);" '
30
- 'preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 30">')
31
- PATH_PREFIX_1 = '<path d="'
32
- PATH_POSFIX_1 = '" fill="currentColor"/>'
33
- SVG_POSFIX = '</svg>'
34
-
35
- NUM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
36
- 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
37
- 'q': 4, 'Q': 4, 'z': 0}
38
- # in order of arg complexity, with absolutes clustered
39
- # recall we don't handle all commands (see docstring)
40
- CMDS_LIST = 'zHVMLTSQCAhvmltsqca' # was zhvmltsqcaHVMLTSQCA
41
- CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
42
-
43
- FEATURE_DIM = 10
44
-
45
- MAX_SEQ_LEN = 120
46
-
47
- # Manually Change Max Sequence
48
- def change_max_seq_len(param):
49
- global MAX_SEQ_LEN
50
- MAX_SEQ_LEN = param
51
- return MAX_SEQ_LEN
52
-
53
- ############################### GENERAL UTILS #################################
54
- def grouper(iterable, batch_size, fill_value=None):
55
- """Helper method for returning batches of size batch_size of a dataset."""
56
- # grouper('ABCDEF', 3) -> 'ABC', 'DEF'
57
- args = [iter(iterable)] * batch_size
58
- return zip_longest(*args, fillvalue=fill_value)
59
-
60
-
61
- def _map_uni_to_alphanum(uni):
62
- """Maps [0-9 A-Z a-z] to numbers 0-62."""
63
- if 48 <= uni <= 57:
64
- return uni - 48
65
- elif 65 <= uni <= 90:
66
- return uni - 65 + 10
67
- return uni - 97 + 36
68
-
69
-
70
- def _map_uni_to_alpha(uni):
71
- """Maps [A-Z a-z] to numbers 0-52."""
72
- if 65 <= uni <= 90:
73
- return uni - 65
74
- return uni - 97 + 26
75
-
76
-
77
- ############# UTILS FOR CONVERTING SFD/SPLINESETS TO SVG PATHS ################
78
- def _get_spline(sfd):
79
- if 'SplineSet' not in sfd:
80
- return ''
81
- pro = sfd[sfd.index('SplineSet') + 10:] # 10 is the 'SplineSet'
82
- pro = pro[:pro.index('EndSplineSet')]
83
- return pro
84
-
85
-
86
- def _spline_to_path_list(spline, height, replace_with_prev=False):
87
- """Converts SplineSet to a list of tokenized commands in svg path."""
88
- path = []
89
- prev_xy = []
90
- for line in spline.splitlines():
91
- if not line:
92
- continue
93
- tokens = line.split(' ')
94
- cmd = tokens[-2]
95
- if cmd not in 'cml':
96
- # COMMAND NOT RECOGNIZED.
97
- return []
98
- # assert cmd in 'cml', 'Command not recognized: {}'.format(cmd)
99
- args = tokens[:-2]
100
- args = [float(x) for x in args if x]
101
-
102
- if replace_with_prev and cmd in 'c':
103
- args[:2] = prev_xy
104
- prev_xy = args[-2:]
105
-
106
- new_y_args = []
107
- for i, a in enumerate(args):
108
- if i % 2 == 1:
109
- new_y_args.append((height - a))
110
- else:
111
- new_y_args.append((a))
112
-
113
- path.append([cmd.upper()] + new_y_args)
114
- return path
115
-
116
-
117
- def _sfd_to_path_list(single, replace_with_prev=False):
118
- """Converts the given SFD glyph into a path."""
119
- return _spline_to_path_list(_get_spline(single['sfd']), single['vwidth'], replace_with_prev)
120
-
121
-
122
- #################### UTILS FOR PROCESSING TOKENIZED PATHS #####################
123
- def _add_missing_cmds(path, remove_zs=False):
124
- """Adds missing cmd tags to the commands in the svg."""
125
- # For instance, the command 'a' takes 7 arguments, but some SVGs declare:
126
- # a 1 2 3 4 5 6 7 8 9 10 11 12 13 14
127
- # Which is 14 arguments. This function converts the above to the equivalent:
128
- # a 1 2 3 4 5 6 7 a 8 9 10 11 12 13 14
129
- #
130
- # Note: if remove_zs is True, this also removes any occurences of z commands.
131
- new_path = []
132
- for cmd in path:
133
- if not remove_zs or cmd[0] not in 'Zz':
134
- for new_cmd in add_missing_cmd(cmd):
135
- new_path.append(new_cmd)
136
- return new_path
137
-
138
-
139
- def add_missing_cmd(command_list):
140
- """Adds missing cmd tags to the given command list."""
141
- # E.g.: given:
142
- # ['a', '0', '0', '0', '0', '0', '0', '0',
143
- # '0', '0', '0', '0', '0', '0', '0']
144
- # Converts to:
145
- # [['a', '0', '0', '0', '0', '0', '0', '0'],
146
- # ['a', '0', '0', '0', '0', '0', '0', '0']]
147
- # And returns a string that joins these elements with spaces.
148
- cmd_tag = command_list[0]
149
- args = command_list[1:]
150
-
151
- final_cmds = []
152
- for arg_batch in grouper(args, NUM_ARGS[cmd_tag]):
153
- final_cmds.append([cmd_tag] + list(arg_batch))
154
-
155
- if not final_cmds:
156
- # command has no args (e.g.: 'z')
157
- final_cmds = [[cmd_tag]]
158
-
159
- return final_cmds
160
-
161
-
162
- def _normalize_args(arglist, norm, add=None, flip=False):
163
- """Normalize the given args with the given norm value."""
164
- new_arglist = []
165
- for i, arg in enumerate(arglist):
166
- new_arg = float(arg)
167
-
168
- if add is not None:
169
- add_to_x, add_to_y = add
170
-
171
- # This argument is an x-coordinate if even, y-coordinate if odd
172
- # except when flip == True
173
- if i % 2 == 0:
174
- new_arg += add_to_y if flip else add_to_x
175
- else:
176
- new_arg += add_to_x if flip else add_to_y
177
-
178
- new_arglist.append(str(24 * new_arg / norm))
179
- return new_arglist
180
-
181
-
182
- def _normalize_based_on_viewbox(path, viewbox):
183
- """Normalizes all args in a path to a standard 24x24 viewbox."""
184
- # Each SVG lives in a 2D plane. The viewbox determines the region of that
185
- # plane that gets rendered. For instance, some designers may work with a
186
- # viewbox that's 24x24, others with one that's 100x100, etc.
187
-
188
- # Suppose I design the the letter "h" in the Arial style using a 100x100
189
- # viewbox (let's call it icon A). Let's suppose the icon has height 75. Then,
190
- # I design the same character using a 20x20 viewbox (call this icon B), with
191
- # height 15 (=75% of 20). This means that, when rendered, both icons with look
192
- # exactly the same, but the scale of the commands each icon is using is
193
- # different. For instance, if icon A has a command like "lineTo 100 100", the
194
- # equivalent command in icon B will be "lineTo 20 20".
195
-
196
- # In order to avoid this problem and bring all real values to the same scale,
197
- # I scale all icons' commands to use a 24x24 viewbox. This function does this:
198
- # it converts a path that exists in the given viewbox into a standard 24x24
199
- # viewbox.
200
- viewbox = viewbox.split(' ')
201
- norm = max(int(viewbox[-1]), int(viewbox[-2]))
202
-
203
- if int(viewbox[-1]) > int(viewbox[-2]):
204
- add_to_y = 0
205
- add_to_x = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
206
- else:
207
- add_to_y = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
208
- add_to_x = 0
209
-
210
- new_path = []
211
- for command in path:
212
- if command[0] == 'a':
213
- new_path.append([command[0]] + _normalize_args(command[1:3], norm)
214
- + command[3:6] + _normalize_args(command[6:], norm))
215
- elif command[0] == 'A':
216
- new_path.append([command[0]] + _normalize_args(command[1:3], norm)
217
- + command[3:6] + _normalize_args(command[6:], norm, add=(add_to_x, add_to_y)))
218
- elif command[0] == 'V':
219
- new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y), flip=True))
220
- elif command[0] == command[0].upper():
221
- new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y)))
222
- elif command[0] in 'zZ':
223
- new_path.append([command[0]])
224
- else:
225
- new_path.append([command[0]] + _normalize_args(command[1:], norm))
226
-
227
- return new_path
228
-
229
-
230
- def _convert_args(args, curr_pos, cmd):
231
- """Converts given args to relative values."""
232
- # NOTE: glyphs only use a very small subset of commands (L, C, M, and Z -- I
233
- # believe). So I'm not handling A and H for now.
234
- if cmd in 'AH':
235
- raise NotImplementedError('These commands have >6 args (not supported).')
236
-
237
- new_args = []
238
- for i, arg in enumerate(args):
239
- x_or_y = i % 2
240
- if cmd == 'H':
241
- x_or_y = (i + 1) % 2
242
- new_args.append(str(float(arg) - curr_pos[x_or_y]))
243
-
244
- return new_args
245
-
246
-
247
- def _update_curr_pos(curr_pos, cmd, start_of_path):
248
- """Calculate the position of the pen after cmd is applied."""
249
- if cmd[0] in 'ml':
250
- curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1] + float(cmd[2])]
251
- if cmd[0] == 'm':
252
- start_of_path = curr_pos
253
- elif cmd[0] in 'z':
254
- curr_pos = start_of_path
255
- elif cmd[0] in 'h':
256
- curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1]]
257
- elif cmd[0] in 'v':
258
- curr_pos = [curr_pos[0], curr_pos[1] + float(cmd[1])]
259
- elif cmd[0] in 'ctsqa':
260
- curr_pos = [curr_pos[0] + float(cmd[-2]), curr_pos[1] + float(cmd[-1])]
261
-
262
- return curr_pos, start_of_path
263
-
264
-
265
- def _make_relative(cmds):
266
- """Convert commands in a path to relative positioning."""
267
- curr_pos = (0.0, 0.0)
268
- start_of_path = (0.0, 0.0)
269
- new_cmds = []
270
- for cmd in cmds:
271
- if cmd[0].lower() == cmd[0]:
272
- new_cmd = cmd
273
- elif cmd[0].lower() == 'z':
274
- new_cmd = [cmd[0].lower()]
275
- else:
276
- new_cmd = [cmd[0].lower()] + _convert_args(cmd[1:], curr_pos, cmd=cmd[0])
277
- new_cmds.append(new_cmd)
278
- curr_pos, start_of_path = _update_curr_pos(curr_pos, new_cmd, start_of_path)
279
- return new_cmds
280
-
281
-
282
- def _is_to_left_of(pt1, pt2):
283
- pt1_norm = (pt1[0]**2 + pt1[1]**2)
284
- pt2_norm = (pt2[0]**2 + pt2[1]**2)
285
- return pt1[1] < pt2[1] or (pt1_norm == pt2_norm and pt1[0] < pt2[0])
286
-
287
-
288
- def _get_leftmost_point(path):
289
- """Returns the leftmost, topmost point of the path."""
290
- leftmost = (float('inf'), float('inf'))
291
- idx = -1
292
-
293
- for i, cmd in enumerate(path):
294
- if len(cmd) > 1:
295
- endpoint = cmd[-2:]
296
- if _is_to_left_of(endpoint, leftmost):
297
- leftmost = endpoint
298
- idx = i
299
-
300
- return leftmost, idx
301
-
302
-
303
- def _separate_substructures(path):
304
- """Returns a list of subpaths, each representing substructures the glyph."""
305
- substructures = []
306
- curr = []
307
- for cmd in path:
308
- if cmd[0] in 'mM' and curr:
309
- substructures.append(curr)
310
- curr = []
311
- curr.append(cmd)
312
- if curr:
313
- substructures.append(curr)
314
- return substructures
315
-
316
-
317
- def _is_clockwise(subpath):
318
- """Returns whether the given subpath is clockwise-oriented."""
319
- pts = [cmd[-2:] for cmd in subpath]
320
- det = 0
321
- for i in range(len(pts) - 1):
322
- det += np.linalg.det(pts[i:i + 2])
323
- return det > 0
324
-
325
-
326
- def _make_clockwise(subpath):
327
- """Inverts the cardinality of the given subpath."""
328
- new_path = [subpath[0]]
329
- other_cmds = list(reversed(subpath[1:]))
330
- for i, cmd in enumerate(other_cmds):
331
- if i + 1 == len(other_cmds):
332
- where_we_were = subpath[0][-2:]
333
- else:
334
- where_we_were = other_cmds[i + 1][-2:]
335
-
336
- if len(cmd) > 3:
337
- new_cmd = [cmd[0], cmd[3], cmd[4], cmd[1], cmd[2],
338
- where_we_were[0], where_we_were[1]]
339
- else:
340
- new_cmd = [cmd[0], where_we_were[0], where_we_were[1]]
341
-
342
- new_path.append(new_cmd)
343
- return new_path
344
-
345
-
346
- def _canonicalize(path):
347
- """Makes all paths start at top left, and go clockwise first."""
348
- # convert args to floats
349
- path = [[x[0]] + list(map(float, x[1:])) for x in path]
350
-
351
- # _canonicalize each subpath separately
352
- new_substructures = []
353
- for subpath in _separate_substructures(path):
354
- leftmost_point, leftmost_idx = _get_leftmost_point(subpath)
355
- reordered = ([['M', leftmost_point[0], leftmost_point[1]]] + subpath[leftmost_idx + 1:] + subpath[1:leftmost_idx + 1])
356
- new_substructures.append((reordered, leftmost_point))
357
-
358
- new_path = []
359
- first_substructure_done = False
360
- should_flip_cardinality = False
361
- for sp, _ in sorted(new_substructures, key=lambda x: (x[1][1], x[1][0])):
362
- if not first_substructure_done:
363
- # we're looking at the first substructure now, we can determine whether we
364
- # will flip the cardniality of the whole icon or not
365
- should_flip_cardinality = not _is_clockwise(sp)
366
- first_substructure_done = True
367
-
368
- if should_flip_cardinality:
369
- sp = _make_clockwise(sp)
370
-
371
- new_path.extend(sp)
372
-
373
- # convert args to strs
374
- path = [[x[0]] + list(map(str, x[1:])) for x in new_path]
375
- return path
376
-
377
-
378
- # ######### UTILS FOR CONVERTING TOKENIZED PATHS TO VECTORS ###########
379
- def _path_to_vector(path, categorical=False):
380
- """Converts path's commands to a series of vectors."""
381
- # Notes:
382
- # - The SimpleSVG dataset does not have any 't', 'q', 'Z', 'T', or 'Q'.
383
- # Thus, we don't handle those here.
384
- # - We also removed all 'z's.
385
- # - The x-axis-rotation argument to a commands is always 0 in this
386
- # dataset, so we ignore it
387
-
388
- # Many commands have args that correspond to args in other commands.
389
- # v __,__ _______________ ______________,_________ __,__ __,__ _,y
390
- # h __,__ _______________ ______________,_________ __,__ __,__ x,_
391
- # z __,__ _______________ ______________,_________ __,__ __,__ _,_
392
- # a rx,ry x-axis-rotation large-arc-flag,sweepflag __,__ __,__ x,y
393
- # l __,__ _______________ ______________,_________ __,__ __,__ x,y
394
- # c __,__ _______________ ______________,_________ x1,y1 x2,y2 x,y
395
- # m __,__ _______________ ______________,_________ __,__ __,__ x,y
396
- # s __,__ _______________ ______________,_________ __,__ x2,y2 x,y
397
-
398
- # So each command will be converted to a vector where the dimension is the
399
- # minimal number of arguments to all commands:
400
- # [rx, ry, large-arc-flag, sweepflag, x1, y1, x2, y2, x, y]
401
- # If a command does not output a certain arg, it is set to 0.
402
- # "l 5,5" becomes [0, 0, 0, 0, 0, 0, 0, 0, 5, 5]
403
-
404
- # Also note, as of now we also output an extra dimension at index 0, which
405
- # indicates which command is being outputted (integer).
406
- new_path = []
407
- for cmd in path:
408
- new_path.append(_cmd_to_vector(cmd, categorical=categorical))
409
- return new_path
410
-
411
-
412
- def _cmd_to_vector(cmd_list, categorical=False):
413
- """Converts the given command (given as a list) into a vector."""
414
- # For description of how this conversion happens, see
415
- # _path_to_vector docstring.
416
- cmd = cmd_list[0]
417
- args = cmd_list[1:]
418
-
419
- if not categorical:
420
- # integer, for MSE
421
- command = [float(CMD_MAPPING[cmd])]
422
- else:
423
- # one hot + 1 dim for EOS.
424
- command = [0.0] * (len(CMDS_LIST) + 1)
425
- command[CMD_MAPPING[cmd] + 1] = 1.0
426
-
427
- arguments = [0.0] * 10
428
- if cmd in 'hH':
429
- arguments[8] = float(args[0]) # x
430
- elif cmd in 'vV':
431
- arguments[9] = float(args[0]) # y
432
- elif cmd in 'mMlLtT':
433
- arguments[8] = float(args[0]) # x
434
- arguments[9] = float(args[1]) # y
435
- elif cmd in 'sSqQ':
436
- arguments[6] = float(args[0]) # x2
437
- arguments[7] = float(args[1]) # y2
438
- arguments[8] = float(args[2]) # x
439
- arguments[9] = float(args[3]) # y
440
- elif cmd in 'cC':
441
- arguments[4] = float(args[0]) # x1
442
- arguments[5] = float(args[1]) # y1
443
- arguments[6] = float(args[2]) # x2
444
- arguments[7] = float(args[3]) # y2
445
- arguments[8] = float(args[4]) # x
446
- arguments[9] = float(args[5]) # y
447
- elif cmd in 'aA':
448
- arguments[0] = float(args[0]) # rx
449
- arguments[1] = float(args[1]) # ry
450
- # we skip x-axis-rotation
451
- arguments[2] = float(args[3]) # large-arc-flag
452
- arguments[3] = float(args[4]) # sweep-flag
453
- # a does not have x1, y1, x2, y2 args
454
- arguments[8] = float(args[5]) # x
455
- arguments[9] = float(args[6]) # y
456
-
457
- return command + arguments
458
-
459
-
460
- ################## UTILS FOR RENDERING PATH INTO IMAGE #################
461
- def _cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=40):
462
- """Return n points along cubiz bezier with given control points."""
463
- # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
464
- pts = []
465
- for i in range(n + 1):
466
- t = float(i) / float(n)
467
- a = (1. - t)**3
468
- b = 3. * t * (1. - t)**2
469
- c = 3.0 * t**2 * (1.0 - t)
470
- d = t**3
471
-
472
- x = float(a * x0 + b * x1 + c * x2 + d * x3)
473
- y = float(a * y0 + b * y1 + c * y2 + d * y3)
474
- pts.append((x, y))
475
- return list(zip(*pts))
476
-
477
-
478
- def _update_pos(curr_pos, end_pos, absolute):
479
- if absolute:
480
- return end_pos
481
- return curr_pos[0] + end_pos[0], curr_pos[1] + end_pos[1]
482
-
483
-
484
- def constant_color(*unused_args):
485
- return np.array([255, 255, 255])
486
-
487
-
488
- def _render_cubic(canvas, curr_pos, c_args, absolute, color):
489
- """Renders a cubic bezier curve in the given canvas."""
490
- if not absolute:
491
- c_args[0] += curr_pos[0]
492
- c_args[1] += curr_pos[1]
493
- c_args[2] += curr_pos[0]
494
- c_args[3] += curr_pos[1]
495
- c_args[4] += curr_pos[0]
496
- c_args[5] += curr_pos[1]
497
- x, y = _cubicbezier(curr_pos[0], curr_pos[1],
498
- c_args[0], c_args[1],
499
- c_args[2], c_args[3],
500
- c_args[4], c_args[5])
501
- max_possible = len(canvas)
502
- x = [int(round(x_)) for x_ in x]
503
- y = [int(round(y_)) for y_ in y]
504
-
505
- def within_range(x):
506
- return 0 <= x < max_possible
507
-
508
- filtered = [(x_, y_) for x_, y_ in zip(x, y)
509
- if within_range(x_) and within_range(y_)]
510
- if not filtered:
511
- return
512
- x, y = list(zip(*filtered))
513
- canvas[y, x, :] = color
514
-
515
-
516
- def _render_line(canvas, curr_pos, l_args, absolute, color):
517
- """Renders a line in the given canvas."""
518
- end_point = l_args
519
- if not absolute:
520
- end_point[0] += curr_pos[0]
521
- end_point[1] += curr_pos[1]
522
- rr, cc, val = draw.line_aa(int(curr_pos[0]), int(curr_pos[1]),
523
- int(end_point[0]), int(end_point[1]))
524
-
525
- max_possible = len(canvas)
526
-
527
- def within_range(x):
528
- return 0 <= x < max_possible
529
-
530
- filtered = [(x, y, v) for x, y, v in zip(rr, cc, val)
531
- if within_range(x) and within_range(y)]
532
- if not filtered:
533
- return
534
- rr, cc, val = list(zip(*filtered))
535
- val = [(v * color) for v in val]
536
- canvas[cc, rr, :] = val
537
-
538
-
539
- def _per_step_render(path, absolute=False, color=constant_color):
540
- """Render the icon's edges, given its path."""
541
- def to_canvas_size(l):
542
- return [float(f) * (64. / 24.) for f in l]
543
-
544
- canvas = np.zeros((64, 64, 3))
545
- curr_pos = (0.0, 0.0)
546
- for i, cmd in enumerate(path):
547
- if not cmd:
548
- continue
549
- if cmd[0] in 'mM':
550
- curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
551
- elif cmd[0] in 'cC':
552
- _render_cubic(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
553
- curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
554
- elif cmd[0] in 'lL':
555
- _render_line(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
556
- curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[1:]), absolute)
557
-
558
- return canvas
559
-
560
-
561
- def _zoom_out(path_list, add_baseline=0., per=22):
562
- """Makes glyph slightly smaller in viewbox, makes some descenders visible."""
563
- # assumes tensor is already unnormalized, and in long form
564
- new_path = []
565
- for command in path_list:
566
- args = []
567
- is_even = False
568
- for arg in command[1:]:
569
- if is_even:
570
- args.append(str(float(arg) - ((24. - per) / 24.) * 64. / 4.))
571
- is_even = False
572
- else:
573
- args.append(str(float(arg) - add_baseline))
574
- is_even = True
575
- new_path.append([command[0]] + args)
576
- return new_path
577
-
578
-
579
- ##################### UTILS FOR PROCESSING VECTORS ################
580
- def _append_eos(sample, categorical, feature_dim):
581
- if not categorical:
582
- eos = -1 * np.ones(feature_dim)
583
- else:
584
- eos = np.zeros(feature_dim)
585
- eos[0] = 1.0
586
- sample.append(eos)
587
- return sample
588
-
589
-
590
- def _make_simple_cmds_long(out):
591
- """Converts svg decoder output to format required by some render functions."""
592
- # out has 10 dims
593
- # the first 4 are respectively dims 0, 4, 5, 9 of the full 20-dim onehot vec
594
- # the latter 6 are the 6 last dims of the 10-dim arg vec
595
- shape_minus_dim = list(np.shape(out))[:-1]
596
- return np.concatenate([out[..., :1],
597
- np.zeros(shape_minus_dim + [3]),
598
- out[..., 1:3],
599
- np.zeros(shape_minus_dim + [3]),
600
- out[..., 3:4],
601
- np.zeros(shape_minus_dim + [14]),
602
- out[..., 4:]], -1)
603
-
604
-
605
- ################# UTILS FOR CONVERTING VECTORS TO SVGS ########################
606
- def _vector_to_svg(vectors, stop_at_eos=False, categorical=False):
607
- """Tranforms a given vector to an svg string."""
608
- new_path = []
609
- for vector in vectors:
610
- if stop_at_eos:
611
- if categorical:
612
- try:
613
- is_eos = np.argmax(vector[:len(CMDS_LIST) + 1]) == 0
614
- except Exception:
615
- raise Exception(vector)
616
- else:
617
- is_eos = vector[0] < -0.5
618
-
619
- if is_eos:
620
- break
621
- new_path.append(' '.join(_vector_to_cmd(vector, categorical=categorical)))
622
- new_path = ' '.join(new_path)
623
- return SVG_PREFIX_BIG + PATH_PREFIX_1 + new_path + PATH_POSFIX_1 + SVG_POSFIX
624
-
625
-
626
- def _vector_to_cmd(vector, categorical=False, return_floats=False):
627
- """Does the inverse transformation as _cmd_to_vector()."""
628
- cast_fn = float if return_floats else str
629
- if categorical:
630
- command = vector[:len(CMDS_LIST) + 1],
631
- arguments = vector[len(CMDS_LIST) + 1:]
632
- cmd_idx = np.argmax(command) - 1
633
- else:
634
- command, arguments = vector[:1], vector[1:]
635
- cmd_idx = int(round(command[0]))
636
-
637
- if cmd_idx < -0.5:
638
- # EOS
639
- return []
640
- if cmd_idx >= len(CMDS_LIST):
641
- cmd_idx = len(CMDS_LIST) - 1
642
-
643
- cmd = CMDS_LIST[cmd_idx]
644
- cmd = cmd.upper()
645
- cmd_list = [cmd]
646
-
647
- if cmd in 'hH':
648
- cmd_list.append(cast_fn(arguments[8])) # x
649
- elif cmd in 'vV':
650
- cmd_list.append(cast_fn(arguments[9])) # y
651
- elif cmd in 'mMlLtT':
652
- cmd_list.append(cast_fn(arguments[8])) # x
653
- cmd_list.append(cast_fn(arguments[9])) # y
654
- elif cmd in 'sSqQ':
655
- cmd_list.append(cast_fn(arguments[6])) # x2
656
- cmd_list.append(cast_fn(arguments[7])) # y2
657
- cmd_list.append(cast_fn(arguments[8])) # x
658
- cmd_list.append(cast_fn(arguments[9])) # y
659
- elif cmd in 'cC':
660
- cmd_list.append(cast_fn(arguments[4])) # x1
661
- cmd_list.append(cast_fn(arguments[5])) # y1
662
- cmd_list.append(cast_fn(arguments[6])) # x2
663
- cmd_list.append(cast_fn(arguments[7])) # y2
664
- cmd_list.append(cast_fn(arguments[8])) # x
665
- cmd_list.append(cast_fn(arguments[9])) # y
666
- elif cmd in 'aA':
667
- cmd_list.append(cast_fn(arguments[0])) # rx
668
- cmd_list.append(cast_fn(arguments[1])) # ry
669
- # x-axis-rotation is always 0
670
- cmd_list.append(cast_fn('0'))
671
- # the following two flags are binary.
672
- cmd_list.append(cast_fn(1 if arguments[2] > 0.5 else 0)) # large-arc-flag
673
- cmd_list.append(cast_fn(1 if arguments[3] > 0.5 else 0)) # sweep-flag
674
- cmd_list.append(cast_fn(arguments[8])) # x
675
- cmd_list.append(cast_fn(arguments[9])) # y
676
-
677
- return cmd_list
678
-
679
-
680
- ############## UTILS FOR CONVERTING SVGS/VECTORS TO IMAGES ###################
681
-
682
- # From Infer notebook
683
- start = ("""<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www."""
684
- """w3.org/1999/xlink" width="256px" height="256px" style="-ms-trans"""
685
- """form: rotate(360deg); -webkit-transform: rotate(360deg); transfo"""
686
- """rm: rotate(360deg);" preserveAspectRatio="xMidYMid meet" viewBox"""
687
- """="0 0 24 30"><path d=\"""")
688
- end = """\" fill="currentColor"/></svg>"""
689
-
690
- COMMAND_RX = re.compile("([MmLlHhVvCcSsQqTtAaZz])")
691
- FLOAT_RX = re.compile("[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") # noqa
692
-
693
-
694
- def svg_html_to_path_string(svg):
695
- return svg.replace(start, '').replace(end, '')
696
-
697
-
698
- def _tokenize(pathdef):
699
- """Returns each svg token from path list."""
700
- # e.g.: 'm0.1-.5c0,6' -> m', '0.1, '-.5', 'c', '0', '6'
701
- for x in COMMAND_RX.split(pathdef):
702
- if x != '' and x in 'MmLlHhVvCcSsQqTtAaZz':
703
- yield x
704
- for token in FLOAT_RX.findall(x):
705
- yield token
706
-
707
-
708
- def path_string_to_tokenized_commands(path):
709
- """Tokenizes the given path string.
710
-
711
- E.g.:
712
- Given M 0.5 0.5 l 0.25 0.25 z
713
- Returns [['M', '0.5', '0.5'], ['l', '0.25', '0.25'], ['z']]
714
- """
715
- new_path = []
716
- current_cmd = []
717
- for token in _tokenize(path):
718
- if len(current_cmd) > 0:
719
- if token in 'MmLlHhVvCcSsQqTtAaZz':
720
- # cmd ended, convert to vector and add to new_path
721
- new_path.append(current_cmd)
722
- current_cmd = [token]
723
- else:
724
- # add arg to command
725
- current_cmd.append(token)
726
- else:
727
- # add to start new cmd
728
- current_cmd.append(token)
729
-
730
- if current_cmd:
731
- # process command still unprocessed
732
- new_path.append(current_cmd)
733
-
734
- return new_path
735
-
736
-
737
- def separate_substructures(tokenized_commands):
738
- """Returns a list of SVG substructures."""
739
- # every moveTo command starts a new substructure
740
- # an SVG substructure is a subpath that closes on itself
741
- # such as the outter and the inner edge of the character `o`
742
- substructures = []
743
- curr = []
744
- for cmd in tokenized_commands:
745
- if cmd[0] in 'mM' and len(curr) > 0:
746
- substructures.append(curr)
747
- curr = []
748
- curr.append(cmd)
749
- if len(curr) > 0:
750
- substructures.append(curr)
751
- return substructures
752
-
753
-
754
- def postprocess(svg, dist_thresh=2., skip=False):
755
- path = svg_html_to_path_string(svg)
756
- svg_template = svg.replace(path, '{}')
757
- tokenized_commands = path_string_to_tokenized_commands(path)
758
-
759
- def dist(a, b):
760
- return np.sqrt((float(a[0]) - float(b[0]))**2 + (float(a[1]) - float(b[1]))**2)
761
-
762
- def are_close_together(a, b, t):
763
- return dist(a, b) < t
764
-
765
- # first, go through each start/end point and merge if they're close enough
766
- # together (that is, make end point the same as the start point).
767
- # TODO: there are better ways of doing this, in a way that propagates error
768
- # back (so if total error is 0.2, go through all N commands in this
769
- # substructure and fix each by 0.2/N (unless they have 0 vertical change))
770
- substructures = separate_substructures(tokenized_commands)
771
-
772
- previous_substructure_endpoint = (0., 0.,)
773
- for substructure in substructures:
774
- # first, if the last substructure's endpoint was updated, we must update
775
- # the start point of this one to reflect the opposite update
776
- substructure[0][-2] = str(float(substructure[0][-2]) -
777
- previous_substructure_endpoint[0])
778
- substructure[0][-1] = str(float(substructure[0][-1]) -
779
- previous_substructure_endpoint[1])
780
-
781
- start = list(map(float, substructure[0][-2:]))
782
- curr_pos = (0., 0.)
783
- for cmd in substructure:
784
- curr_pos, _ = _update_curr_pos(curr_pos, cmd, (0., 0.))
785
- if are_close_together(start, curr_pos, dist_thresh):
786
- new_point = np.array(start)
787
- previous_substructure_endpoint = ((new_point[0] - curr_pos[0]),
788
- (new_point[1] - curr_pos[1]))
789
- substructure[-1][-2] = str(float(substructure[-1][-2]) +
790
- (new_point[0] - curr_pos[0]))
791
- substructure[-1][-1] = str(float(substructure[-1][-1]) +
792
- (new_point[1] - curr_pos[1]))
793
- if substructure[-1][0] in 'cC':
794
- substructure[-1][-4] = str(float(substructure[-1][-4]) +
795
- (new_point[0] - curr_pos[0]))
796
- substructure[-1][-3] = str(float(substructure[-1][-3]) +
797
- (new_point[1] - curr_pos[1]))
798
-
799
- if skip:
800
- return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
801
- for s in substructures]))
802
-
803
- def cosa(x, y):
804
- return (x[0] * y[0] + x[1] * y[1]) / ((np.sqrt(x[0]**2 + x[1]**2) * np.sqrt(y[0]**2 + y[1]**2)))
805
-
806
- def rotate(a, x, y):
807
- return (x * np.cos(a) - y * np.sin(a), y * np.cos(a) + x * np.sin(a))
808
- # second, gotta find adjacent bezier curves and, if their control points
809
- # are well enough aligned, fully align them
810
- for substructure in substructures:
811
- curr_pos = (0., 0.)
812
- new_curr_pos, _ = _update_curr_pos((0., 0.,), substructure[0], (0., 0.))
813
-
814
- for cmd_idx in range(1, len(substructure)):
815
- prev_cmd = substructure[cmd_idx-1]
816
- cmd = substructure[cmd_idx]
817
-
818
- new_new_curr_pos, _ = _update_curr_pos(
819
- new_curr_pos, cmd, (0., 0.))
820
-
821
- if cmd[0] == 'c':
822
- if prev_cmd[0] == 'c':
823
- # check the vectors and update if needed
824
- # previous control pt wrt new curr point
825
- prev_ctr_point = (curr_pos[0] + float(prev_cmd[3]) - new_curr_pos[0],
826
- curr_pos[1] + float(prev_cmd[4]) - new_curr_pos[1])
827
- ctr_point = (float(cmd[1]), float(cmd[2]))
828
-
829
- if -1. < cosa(prev_ctr_point, ctr_point) < -0.95:
830
- # calculate exact angle between the two vectors
831
- angle_diff = (np.pi - np.arccos(cosa(prev_ctr_point, ctr_point)))/2
832
-
833
- # rotate each vector by angle/2 in the correct direction for each.
834
- sign = np.sign(np.cross(prev_ctr_point, ctr_point))
835
- new_ctr_point = rotate(sign * angle_diff, *ctr_point)
836
- new_prev_ctr_point = rotate(-sign * angle_diff, *prev_ctr_point)
837
-
838
- # override the previous control points
839
- # (which has to be wrt previous curr position)
840
- substructure[cmd_idx-1][3] = str(new_prev_ctr_point[0] -
841
- curr_pos[0] + new_curr_pos[0])
842
- substructure[cmd_idx-1][4] = str(new_prev_ctr_point[1] -
843
- curr_pos[1] + new_curr_pos[1])
844
- substructure[cmd_idx][1] = str(new_ctr_point[0])
845
- substructure[cmd_idx][2] = str(new_ctr_point[1])
846
-
847
- curr_pos = new_curr_pos
848
- new_curr_pos = new_new_curr_pos
849
-
850
- return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
851
- for s in substructures]))
852
-
853
-
854
- # def get_means_stdevs(data_dir):
855
- # """Returns the means and stdev saved in data_dir."""
856
- # if data_dir not in means_stdevs:
857
- # with tf.gfile.Open(os.path.join(data_dir, 'mean.npz'), 'r') as f:
858
- # mean_npz = np.load(f)
859
- # with tf.gfile.Open(os.path.join(data_dir, 'stdev.npz'), 'r') as f:
860
- # stdev_npz = np.load(f)
861
- # means_stdevs[data_dir] = (mean_npz, stdev_npz)
862
- # return means_stdevs[data_dir]
863
-
864
-
865
- def render(tensor, data_dir=None):
866
- """Converts SVG decoder output into HTML svg."""
867
- # undo normalization
868
- # mean_npz, stdev_npz = get_means_stdevs(data_dir)
869
- # tensor = (tensor * stdev_npz) + mean_npz
870
-
871
- # convert to html
872
- tensor = _make_simple_cmds_long(tensor)
873
- # vector = np.squeeze(np.squeeze(tensor, 0), 2)
874
- html = _vector_to_svg(tensor, stop_at_eos=True, categorical=True)
875
-
876
- # some aesthetic postprocessing
877
- html = postprocess(html)
878
- html = html.replace('256px', '50px')
879
-
880
- return html
881
-
882
- ###############
883
-
884
-
885
- def convert_to_svg(decoder_output, categorical=False):
886
- converted = []
887
- for example in decoder_output:
888
- converted.append(_vector_to_svg(example, True, categorical=categorical))
889
- return np.array(converted)
890
-
891
-
892
- def create_image_conversion_fn(max_outputs, categorical=False):
893
- """Binds the number of outputs to the image conversion fn (to svg or png)."""
894
- def convert_to_svg(decoder_output):
895
- converted = []
896
- for example in decoder_output:
897
- if len(converted) == max_outputs:
898
- break
899
- converted.append(_vector_to_svg(example, True, categorical=categorical))
900
- return np.array(converted)
901
-
902
- return convert_to_svg
903
-
904
-
905
- ################### UTILS FOR CREATING TF SUMMARIES ##########################
906
- def _make_encoded_image(img_tensor):
907
- pil_img = Image.fromarray(np.squeeze(img_tensor * 255).astype(np.uint8), mode='L')
908
- buff = io.BytesIO()
909
- pil_img.save(buff, format='png')
910
- encoded_image = buff.getvalue()
911
- return encoded_image
912
-
913
-
914
- ################### CHECK GLYPH/PATH VALID ##############################################
915
- def is_valid_glyph(g):
916
- is_09 = 48 <= g['uni'] <= 57
917
- is_capital_az = 65 <= g['uni'] <= 90
918
- is_az = 97 <= g['uni'] <= 122
919
- is_valid_dims = g['width'] != 0 and g['vwidth'] != 0
920
- return (is_09 or is_capital_az or is_az) and is_valid_dims
921
-
922
-
923
- def is_valid_path(pathunibfp):
924
- return pathunibfp[0] and len(pathunibfp[0]) <= MAX_SEQ_LEN
925
-
926
-
927
- ################### DATASET PROCESSING #######################################
928
- def convert_to_path(g):
929
- """Converts SplineSet in SFD font to str path."""
930
- path = _sfd_to_path_list(g)
931
- path = _add_missing_cmds(path, remove_zs=False)
932
- path = _normalize_based_on_viewbox(path, '0 0 {} {}'.format(g['width'], g['vwidth']))
933
- return path, g['uni'], g['binary_fp']
934
-
935
-
936
- def create_example(pathunibfp):
937
- """Bulk of dataset processing. Converts str path to np array"""
938
- path, uni, binary_fp = pathunibfp
939
- final = {}
940
-
941
- # zoom out
942
- path = _zoom_out(path)
943
- # make clockwise
944
- path = _canonicalize(path)
945
-
946
- # render path for training
947
- final['rendered'] = _per_step_render(path, absolute=True)
948
-
949
- # make path relative
950
- # path = _make_relative(path)
951
- # convert to vector
952
- vector = _path_to_vector(path, categorical=True)
953
- # make simple vector
954
- vector = np.array(vector)
955
- vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
956
-
957
- # count some stats
958
- final['seq_len'] = np.shape(vector)[0]
959
- # final['class'] = int(_map_uni_to_alphanum(uni))
960
- final['class'] = int(_map_uni_to_alpha(uni)) # be advised that the class is useless bcz it is all 0
961
- final['binary_fp'] = str(binary_fp)
962
-
963
- # append eos
964
- vector = _append_eos(vector.tolist(), True, 10)
965
-
966
- # pad path to MAX_SEQ_LEN + 1 (with eos)
967
- final['sequence'] = np.concatenate((vector, np.zeros(((MAX_SEQ_LEN - final['seq_len']), 10))), 0)
968
-
969
- # make pure list:
970
- # use last channel only
971
- final['rendered'] = np.reshape(final['rendered'][..., 0], [64 * 64]).astype(np.float32).tolist()
972
- final['sequence'] = np.reshape(final['sequence'], [(MAX_SEQ_LEN + 1) * 10]).astype(np.float32).tolist()
973
- final['class'] = np.reshape(final['class'], [1]).astype(np.int64).tolist()
974
- final['seq_len'] = np.reshape(final['seq_len'], [1]).astype(np.int64).tolist()
975
- return final
976
-
977
-
978
- def mean_to_example(mean_stdev):
979
- """Converts the found mean and stdev to example."""
980
- # mean_stdev is a dict
981
- mean_stdev['mean'] = np.reshape(mean_stdev['mean'], [10]).astype(np.float32).tolist()
982
- mean_stdev['variance'] = np.reshape(mean_stdev['variance'], [10]).astype(np.float32).tolist()
983
- mean_stdev['stddev'] = np.reshape(mean_stdev['stddev'], [10]).astype(np.float32).tolist()
984
- mean_stdev['count'] = np.reshape(mean_stdev['count'], [1]).astype(np.int64).tolist()
985
- return mean_stdev
986
-
987
-
988
- def convert_simple_vector_to_path(seq):
989
- path=[]
990
- for i in range(seq.shape[0]):
991
- path_i=[]
992
- cmd = np.argmax(seq[i][:4])
993
- p0 = seq[i][4:6]
994
- p1 = seq[i][6:8]
995
- p2 = seq[i][8:10]
996
- if cmd == 0:
997
- break
998
- elif cmd == 1:
999
- path_i.append('M')
1000
- path_i.append(str(p2[0]))
1001
- path_i.append(str(p2[1]))
1002
- elif cmd == 2:
1003
- path_i.append('L')
1004
- path_i.append(str(p2[0]))
1005
- path_i.append(str(p2[1]))
1006
- elif cmd == 3:
1007
- path_i.append('C')
1008
- path_i.append(str(p0[0]))
1009
- path_i.append(str(p0[1]))
1010
- path_i.append(str(p1[0]))
1011
- path_i.append(str(p1[1]))
1012
- path_i.append(str(p2[0]))
1013
- path_i.append(str(p2[1]))
1014
- else:
1015
- print("wrong!!! to path")
1016
- path.append(path_i)
1017
- return path
1018
-
1019
- def clockwise(seq):
1020
- path = convert_simple_vector_to_path(seq)
1021
- path = _canonicalize(path)
1022
- ret = {}
1023
- vector = _path_to_vector(path, categorical=True)
1024
- vector = np.array(vector)
1025
- vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
1026
- ret['seq_len'] = np.shape(vector)[0]
1027
- vector = _append_eos(vector.tolist(), True, 10)
1028
- ret['sequence'] = np.concatenate((vector, np.zeros(((MAX_SEQ_LEN - ret['seq_len']), 10))), 0)
1029
- return ret
1030
-
1031
- ################### CHECK VALID ##############################################
1032
- class MeanStddev:
1033
- """Accumulator to compute the mean/stdev of svg commands."""
1034
-
1035
- def create_accumulator(self):
1036
- curr_sum = np.zeros([10])
1037
- sum_sq = np.zeros([10])
1038
- return (curr_sum, sum_sq, 0) # x, x^2, count
1039
-
1040
- def add_input(self, sum_count, new_input):
1041
- (curr_sum, sum_sq, count) = sum_count
1042
- # new_input is a dict with keys = ['seq_len', 'sequence']
1043
- new_seq_len = new_input['seq_len'][0] # Line #754 'seq_len' is a list of one int
1044
- assert isinstance(new_seq_len, int), print(type(new_seq_len))
1045
-
1046
- # remove padding and eos from sequence
1047
- assert isinstance(new_input['sequence'], list), print(type(new_input['sequence']))
1048
- new_input_np = np.reshape(np.array(new_input['sequence']), [-1, 10])
1049
- assert isinstance(new_input_np, np.ndarray), print(type())
1050
- assert new_input_np.shape[0] >= new_seq_len
1051
- new_input_np = new_input_np[:new_seq_len, :]
1052
-
1053
- # accumulate new_sum and new_sum_sq
1054
- new_sum = np.sum([curr_sum, np.sum(new_input_np, axis=0)], axis=0)
1055
- new_sum_sq = np.sum([sum_sq, np.sum(np.power(new_input_np, 2), axis=0)],
1056
- axis=0)
1057
- return new_sum, new_sum_sq, count + new_seq_len
1058
-
1059
- def merge_accumulators(self, accumulators):
1060
- curr_sums, sum_sqs, counts = list(zip(*accumulators))
1061
- return np.sum(curr_sums, axis=0), np.sum(sum_sqs, axis=0), np.sum(counts)
1062
-
1063
- def extract_output(self, sum_count):
1064
- (curr_sum, curr_sum_sq, count) = sum_count
1065
- if count:
1066
- mean = np.divide(curr_sum, count)
1067
- variance = np.divide(curr_sum_sq, count) - np.power(mean, 2)
1068
- # -ve value could happen due to rounding
1069
- variance = np.max([variance, np.zeros(np.shape(variance))], axis=0)
1070
- stddev = np.sqrt(variance)
1071
- return {
1072
- 'mean': mean,
1073
- 'variance': variance,
1074
- 'stddev': stddev,
1075
- 'count': count
1076
- }
1077
- else:
1078
- return {
1079
- 'mean': float('NaN'),
1080
- 'variance': float('NaN'),
1081
- 'stddev': float('NaN'),
1082
- 'count': 0
1083
  }
 
1
+ # Copyright 2020 The Magenta Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Lint as: python3
16
+ """Defines the Material Design Icons Problem."""
17
+ import io
18
+ import numpy as np
19
+ import re
20
+
21
+ from PIL import Image
22
+ from itertools import zip_longest
23
+ from skimage import draw
24
+
25
+
26
+ SVG_PREFIX_BIG = ('<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="'
27
+ 'http://www.w3.org/1999/xlink" width="256px" height="256px"'
28
+ ' style="-ms-transform: rotate(360deg); -webkit-transform:'
29
+ ' rotate(360deg); transform: rotate(360deg);" '
30
+ 'preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 30">')
31
+ PATH_PREFIX_1 = '<path d="'
32
+ PATH_POSFIX_1 = '" fill="currentColor"/>'
33
+ SVG_POSFIX = '</svg>'
34
+
35
+ NUM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
36
+ 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
37
+ 'q': 4, 'Q': 4, 'z': 0}
38
+ # in order of arg complexity, with absolutes clustered
39
+ # recall we don't handle all commands (see docstring)
40
+ CMDS_LIST = 'zHVMLTSQCAhvmltsqca' # was zhvmltsqcaHVMLTSQCA
41
+ CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
42
+
43
+ FEATURE_DIM = 10
44
+
45
+ MAX_SEQ_LEN = 120
46
+
47
+ # Manually Change Max Sequence
48
+ def change_max_seq_len(param):
49
+ global MAX_SEQ_LEN
50
+ MAX_SEQ_LEN = param
51
+ return MAX_SEQ_LEN
52
+
53
+ ############################### GENERAL UTILS #################################
54
+ def grouper(iterable, batch_size, fill_value=None):
55
+ """Helper method for returning batches of size batch_size of a dataset."""
56
+ # grouper('ABCDEF', 3) -> 'ABC', 'DEF'
57
+ args = [iter(iterable)] * batch_size
58
+ return zip_longest(*args, fillvalue=fill_value)
59
+
60
+
61
+ def _map_uni_to_alphanum(uni):
62
+ """Maps [0-9 A-Z a-z] to numbers 0-62."""
63
+ if 48 <= uni <= 57:
64
+ return uni - 48
65
+ elif 65 <= uni <= 90:
66
+ return uni - 65 + 10
67
+ return uni - 97 + 36
68
+
69
+
70
+ def _map_uni_to_alpha(uni):
71
+ """Maps [A-Z a-z] to numbers 0-52."""
72
+ if 65 <= uni <= 90:
73
+ return uni - 65
74
+ return uni - 97 + 26
75
+
76
+
77
+ ############# UTILS FOR CONVERTING SFD/SPLINESETS TO SVG PATHS ################
78
+ def _get_spline(sfd):
79
+ if 'SplineSet' not in sfd:
80
+ return ''
81
+ pro = sfd[sfd.index('SplineSet') + 10:] # 10 is the 'SplineSet'
82
+ pro = pro[:pro.index('EndSplineSet')]
83
+ return pro
84
+
85
+
86
+ def _spline_to_path_list(spline, height, replace_with_prev=False):
87
+ """Converts SplineSet to a list of tokenized commands in svg path."""
88
+ path = []
89
+ prev_xy = []
90
+ for line in spline.splitlines():
91
+ if not line:
92
+ continue
93
+ tokens = line.split(' ')
94
+ cmd = tokens[-2]
95
+ if cmd not in 'cml':
96
+ # COMMAND NOT RECOGNIZED.
97
+ return []
98
+ # assert cmd in 'cml', 'Command not recognized: {}'.format(cmd)
99
+ args = tokens[:-2]
100
+ args = [float(x) for x in args if x]
101
+
102
+ if replace_with_prev and cmd in 'c':
103
+ args[:2] = prev_xy
104
+ prev_xy = args[-2:]
105
+
106
+ new_y_args = []
107
+ for i, a in enumerate(args):
108
+ if i % 2 == 1:
109
+ new_y_args.append((height - a))
110
+ else:
111
+ new_y_args.append((a))
112
+
113
+ path.append([cmd.upper()] + new_y_args)
114
+ return path
115
+
116
+
117
+ def _sfd_to_path_list(single, replace_with_prev=False):
118
+ """Converts the given SFD glyph into a path."""
119
+ return _spline_to_path_list(_get_spline(single['sfd']), single['vwidth'], replace_with_prev)
120
+
121
+
122
+ #################### UTILS FOR PROCESSING TOKENIZED PATHS #####################
123
+ def _add_missing_cmds(path, remove_zs=False):
124
+ """Adds missing cmd tags to the commands in the svg."""
125
+ # For instance, the command 'a' takes 7 arguments, but some SVGs declare:
126
+ # a 1 2 3 4 5 6 7 8 9 10 11 12 13 14
127
+ # Which is 14 arguments. This function converts the above to the equivalent:
128
+ # a 1 2 3 4 5 6 7 a 8 9 10 11 12 13 14
129
+ #
130
+ # Note: if remove_zs is True, this also removes any occurences of z commands.
131
+ new_path = []
132
+ for cmd in path:
133
+ if not remove_zs or cmd[0] not in 'Zz':
134
+ for new_cmd in add_missing_cmd(cmd):
135
+ new_path.append(new_cmd)
136
+ return new_path
137
+
138
+
139
+ def add_missing_cmd(command_list):
140
+ """Adds missing cmd tags to the given command list."""
141
+ # E.g.: given:
142
+ # ['a', '0', '0', '0', '0', '0', '0', '0',
143
+ # '0', '0', '0', '0', '0', '0', '0']
144
+ # Converts to:
145
+ # [['a', '0', '0', '0', '0', '0', '0', '0'],
146
+ # ['a', '0', '0', '0', '0', '0', '0', '0']]
147
+ # And returns a string that joins these elements with spaces.
148
+ cmd_tag = command_list[0]
149
+ args = command_list[1:]
150
+
151
+ final_cmds = []
152
+ for arg_batch in grouper(args, NUM_ARGS[cmd_tag]):
153
+ final_cmds.append([cmd_tag] + list(arg_batch))
154
+
155
+ if not final_cmds:
156
+ # command has no args (e.g.: 'z')
157
+ final_cmds = [[cmd_tag]]
158
+
159
+ return final_cmds
160
+
161
+
162
+ def _normalize_args(arglist, norm, add=None, flip=False):
163
+ """Normalize the given args with the given norm value."""
164
+ new_arglist = []
165
+ for i, arg in enumerate(arglist):
166
+ new_arg = float(arg)
167
+
168
+ if add is not None:
169
+ add_to_x, add_to_y = add
170
+
171
+ # This argument is an x-coordinate if even, y-coordinate if odd
172
+ # except when flip == True
173
+ if i % 2 == 0:
174
+ new_arg += add_to_y if flip else add_to_x
175
+ else:
176
+ new_arg += add_to_x if flip else add_to_y
177
+
178
+ new_arglist.append(str(24 * new_arg / norm))
179
+ return new_arglist
180
+
181
+
182
+ def _normalize_based_on_viewbox(path, viewbox):
183
+ """Normalizes all args in a path to a standard 24x24 viewbox."""
184
+ # Each SVG lives in a 2D plane. The viewbox determines the region of that
185
+ # plane that gets rendered. For instance, some designers may work with a
186
+ # viewbox that's 24x24, others with one that's 100x100, etc.
187
+
188
+ # Suppose I design the the letter "h" in the Arial style using a 100x100
189
+ # viewbox (let's call it icon A). Let's suppose the icon has height 75. Then,
190
+ # I design the same character using a 20x20 viewbox (call this icon B), with
191
+ # height 15 (=75% of 20). This means that, when rendered, both icons with look
192
+ # exactly the same, but the scale of the commands each icon is using is
193
+ # different. For instance, if icon A has a command like "lineTo 100 100", the
194
+ # equivalent command in icon B will be "lineTo 20 20".
195
+
196
+ # In order to avoid this problem and bring all real values to the same scale,
197
+ # I scale all icons' commands to use a 24x24 viewbox. This function does this:
198
+ # it converts a path that exists in the given viewbox into a standard 24x24
199
+ # viewbox.
200
+ viewbox = viewbox.split(' ')
201
+ norm = max(int(viewbox[-1]), int(viewbox[-2]))
202
+
203
+ if int(viewbox[-1]) > int(viewbox[-2]):
204
+ add_to_y = 0
205
+ add_to_x = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
206
+ else:
207
+ add_to_y = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
208
+ add_to_x = 0
209
+
210
+ new_path = []
211
+ for command in path:
212
+ if command[0] == 'a':
213
+ new_path.append([command[0]] + _normalize_args(command[1:3], norm)
214
+ + command[3:6] + _normalize_args(command[6:], norm))
215
+ elif command[0] == 'A':
216
+ new_path.append([command[0]] + _normalize_args(command[1:3], norm)
217
+ + command[3:6] + _normalize_args(command[6:], norm, add=(add_to_x, add_to_y)))
218
+ elif command[0] == 'V':
219
+ new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y), flip=True))
220
+ elif command[0] == command[0].upper():
221
+ new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y)))
222
+ elif command[0] in 'zZ':
223
+ new_path.append([command[0]])
224
+ else:
225
+ new_path.append([command[0]] + _normalize_args(command[1:], norm))
226
+
227
+ return new_path
228
+
229
+
230
+ def _convert_args(args, curr_pos, cmd):
231
+ """Converts given args to relative values."""
232
+ # NOTE: glyphs only use a very small subset of commands (L, C, M, and Z -- I
233
+ # believe). So I'm not handling A and H for now.
234
+ if cmd in 'AH':
235
+ raise NotImplementedError('These commands have >6 args (not supported).')
236
+
237
+ new_args = []
238
+ for i, arg in enumerate(args):
239
+ x_or_y = i % 2
240
+ if cmd == 'H':
241
+ x_or_y = (i + 1) % 2
242
+ new_args.append(str(float(arg) - curr_pos[x_or_y]))
243
+
244
+ return new_args
245
+
246
+
247
+ def _update_curr_pos(curr_pos, cmd, start_of_path):
248
+ """Calculate the position of the pen after cmd is applied."""
249
+ if cmd[0] in 'ml':
250
+ curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1] + float(cmd[2])]
251
+ if cmd[0] == 'm':
252
+ start_of_path = curr_pos
253
+ elif cmd[0] in 'z':
254
+ curr_pos = start_of_path
255
+ elif cmd[0] in 'h':
256
+ curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1]]
257
+ elif cmd[0] in 'v':
258
+ curr_pos = [curr_pos[0], curr_pos[1] + float(cmd[1])]
259
+ elif cmd[0] in 'ctsqa':
260
+ curr_pos = [curr_pos[0] + float(cmd[-2]), curr_pos[1] + float(cmd[-1])]
261
+
262
+ return curr_pos, start_of_path
263
+
264
+
265
+ def _make_relative(cmds):
266
+ """Convert commands in a path to relative positioning."""
267
+ curr_pos = (0.0, 0.0)
268
+ start_of_path = (0.0, 0.0)
269
+ new_cmds = []
270
+ for cmd in cmds:
271
+ if cmd[0].lower() == cmd[0]:
272
+ new_cmd = cmd
273
+ elif cmd[0].lower() == 'z':
274
+ new_cmd = [cmd[0].lower()]
275
+ else:
276
+ new_cmd = [cmd[0].lower()] + _convert_args(cmd[1:], curr_pos, cmd=cmd[0])
277
+ new_cmds.append(new_cmd)
278
+ curr_pos, start_of_path = _update_curr_pos(curr_pos, new_cmd, start_of_path)
279
+ return new_cmds
280
+
281
+
282
+ def _is_to_left_of(pt1, pt2):
283
+ pt1_norm = (pt1[0]**2 + pt1[1]**2)
284
+ pt2_norm = (pt2[0]**2 + pt2[1]**2)
285
+ return pt1[1] < pt2[1] or (pt1_norm == pt2_norm and pt1[0] < pt2[0])
286
+
287
+
288
+ def _get_leftmost_point(path):
289
+ """Returns the leftmost, topmost point of the path."""
290
+ leftmost = (float('inf'), float('inf'))
291
+ idx = -1
292
+
293
+ for i, cmd in enumerate(path):
294
+ if len(cmd) > 1:
295
+ endpoint = cmd[-2:]
296
+ if _is_to_left_of(endpoint, leftmost):
297
+ leftmost = endpoint
298
+ idx = i
299
+
300
+ return leftmost, idx
301
+
302
+
303
+ def _separate_substructures(path):
304
+ """Returns a list of subpaths, each representing substructures the glyph."""
305
+ substructures = []
306
+ curr = []
307
+ for cmd in path:
308
+ if cmd[0] in 'mM' and curr:
309
+ substructures.append(curr)
310
+ curr = []
311
+ curr.append(cmd)
312
+ if curr:
313
+ substructures.append(curr)
314
+ return substructures
315
+
316
+
317
+ def _is_clockwise(subpath):
318
+ """Returns whether the given subpath is clockwise-oriented."""
319
+ pts = [cmd[-2:] for cmd in subpath]
320
+ det = 0
321
+ for i in range(len(pts) - 1):
322
+ det += np.linalg.det(pts[i:i + 2])
323
+ return det > 0
324
+
325
+
326
+ def _make_clockwise(subpath):
327
+ """Inverts the cardinality of the given subpath."""
328
+ new_path = [subpath[0]]
329
+ other_cmds = list(reversed(subpath[1:]))
330
+ for i, cmd in enumerate(other_cmds):
331
+ if i + 1 == len(other_cmds):
332
+ where_we_were = subpath[0][-2:]
333
+ else:
334
+ where_we_were = other_cmds[i + 1][-2:]
335
+
336
+ if len(cmd) > 3:
337
+ new_cmd = [cmd[0], cmd[3], cmd[4], cmd[1], cmd[2],
338
+ where_we_were[0], where_we_were[1]]
339
+ else:
340
+ new_cmd = [cmd[0], where_we_were[0], where_we_were[1]]
341
+
342
+ new_path.append(new_cmd)
343
+ return new_path
344
+
345
+
346
+ def _canonicalize(path):
347
+ """Makes all paths start at top left, and go clockwise first."""
348
+ # convert args to floats
349
+ path = [[x[0]] + list(map(float, x[1:])) for x in path]
350
+
351
+ # _canonicalize each subpath separately
352
+ new_substructures = []
353
+ for subpath in _separate_substructures(path):
354
+ leftmost_point, leftmost_idx = _get_leftmost_point(subpath)
355
+ reordered = ([['M', leftmost_point[0], leftmost_point[1]]] + subpath[leftmost_idx + 1:] + subpath[1:leftmost_idx + 1])
356
+ new_substructures.append((reordered, leftmost_point))
357
+
358
+ new_path = []
359
+ first_substructure_done = False
360
+ should_flip_cardinality = False
361
+ for sp, _ in sorted(new_substructures, key=lambda x: (x[1][1], x[1][0])):
362
+ if not first_substructure_done:
363
+ # we're looking at the first substructure now, we can determine whether we
364
+ # will flip the cardniality of the whole icon or not
365
+ should_flip_cardinality = not _is_clockwise(sp)
366
+ first_substructure_done = True
367
+
368
+ if should_flip_cardinality:
369
+ sp = _make_clockwise(sp)
370
+
371
+ new_path.extend(sp)
372
+
373
+ # convert args to strs
374
+ path = [[x[0]] + list(map(str, x[1:])) for x in new_path]
375
+ return path
376
+
377
+
378
+ # ######### UTILS FOR CONVERTING TOKENIZED PATHS TO VECTORS ###########
379
+ def _path_to_vector(path, categorical=False):
380
+ """Converts path's commands to a series of vectors."""
381
+ # Notes:
382
+ # - The SimpleSVG dataset does not have any 't', 'q', 'Z', 'T', or 'Q'.
383
+ # Thus, we don't handle those here.
384
+ # - We also removed all 'z's.
385
+ # - The x-axis-rotation argument to a commands is always 0 in this
386
+ # dataset, so we ignore it
387
+
388
+ # Many commands have args that correspond to args in other commands.
389
+ # v __,__ _______________ ______________,_________ __,__ __,__ _,y
390
+ # h __,__ _______________ ______________,_________ __,__ __,__ x,_
391
+ # z __,__ _______________ ______________,_________ __,__ __,__ _,_
392
+ # a rx,ry x-axis-rotation large-arc-flag,sweepflag __,__ __,__ x,y
393
+ # l __,__ _______________ ______________,_________ __,__ __,__ x,y
394
+ # c __,__ _______________ ______________,_________ x1,y1 x2,y2 x,y
395
+ # m __,__ _______________ ______________,_________ __,__ __,__ x,y
396
+ # s __,__ _______________ ______________,_________ __,__ x2,y2 x,y
397
+
398
+ # So each command will be converted to a vector where the dimension is the
399
+ # minimal number of arguments to all commands:
400
+ # [rx, ry, large-arc-flag, sweepflag, x1, y1, x2, y2, x, y]
401
+ # If a command does not output a certain arg, it is set to 0.
402
+ # "l 5,5" becomes [0, 0, 0, 0, 0, 0, 0, 0, 5, 5]
403
+
404
+ # Also note, as of now we also output an extra dimension at index 0, which
405
+ # indicates which command is being outputted (integer).
406
+ new_path = []
407
+ for cmd in path:
408
+ new_path.append(_cmd_to_vector(cmd, categorical=categorical))
409
+ return new_path
410
+
411
+
412
+ def _cmd_to_vector(cmd_list, categorical=False):
413
+ """Converts the given command (given as a list) into a vector."""
414
+ # For description of how this conversion happens, see
415
+ # _path_to_vector docstring.
416
+ cmd = cmd_list[0]
417
+ args = cmd_list[1:]
418
+
419
+ if not categorical:
420
+ # integer, for MSE
421
+ command = [float(CMD_MAPPING[cmd])]
422
+ else:
423
+ # one hot + 1 dim for EOS.
424
+ command = [0.0] * (len(CMDS_LIST) + 1)
425
+ command[CMD_MAPPING[cmd] + 1] = 1.0
426
+
427
+ arguments = [0.0] * 10
428
+ if cmd in 'hH':
429
+ arguments[8] = float(args[0]) # x
430
+ elif cmd in 'vV':
431
+ arguments[9] = float(args[0]) # y
432
+ elif cmd in 'mMlLtT':
433
+ arguments[8] = float(args[0]) # x
434
+ arguments[9] = float(args[1]) # y
435
+ elif cmd in 'sSqQ':
436
+ arguments[6] = float(args[0]) # x2
437
+ arguments[7] = float(args[1]) # y2
438
+ arguments[8] = float(args[2]) # x
439
+ arguments[9] = float(args[3]) # y
440
+ elif cmd in 'cC':
441
+ arguments[4] = float(args[0]) # x1
442
+ arguments[5] = float(args[1]) # y1
443
+ arguments[6] = float(args[2]) # x2
444
+ arguments[7] = float(args[3]) # y2
445
+ arguments[8] = float(args[4]) # x
446
+ arguments[9] = float(args[5]) # y
447
+ elif cmd in 'aA':
448
+ arguments[0] = float(args[0]) # rx
449
+ arguments[1] = float(args[1]) # ry
450
+ # we skip x-axis-rotation
451
+ arguments[2] = float(args[3]) # large-arc-flag
452
+ arguments[3] = float(args[4]) # sweep-flag
453
+ # a does not have x1, y1, x2, y2 args
454
+ arguments[8] = float(args[5]) # x
455
+ arguments[9] = float(args[6]) # y
456
+
457
+ return command + arguments
458
+
459
+
460
+ ################## UTILS FOR RENDERING PATH INTO IMAGE #################
461
+ def _cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=40):
462
+ """Return n points along cubiz bezier with given control points."""
463
+ # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
464
+ pts = []
465
+ for i in range(n + 1):
466
+ t = float(i) / float(n)
467
+ a = (1. - t)**3
468
+ b = 3. * t * (1. - t)**2
469
+ c = 3.0 * t**2 * (1.0 - t)
470
+ d = t**3
471
+
472
+ x = float(a * x0 + b * x1 + c * x2 + d * x3)
473
+ y = float(a * y0 + b * y1 + c * y2 + d * y3)
474
+ pts.append((x, y))
475
+ return list(zip(*pts))
476
+
477
+
478
+ def _update_pos(curr_pos, end_pos, absolute):
479
+ if absolute:
480
+ return end_pos
481
+ return curr_pos[0] + end_pos[0], curr_pos[1] + end_pos[1]
482
+
483
+
484
+ def constant_color(*unused_args):
485
+ return np.array([255, 255, 255])
486
+
487
+
488
+ def _render_cubic(canvas, curr_pos, c_args, absolute, color):
489
+ """Renders a cubic bezier curve in the given canvas."""
490
+ if not absolute:
491
+ c_args[0] += curr_pos[0]
492
+ c_args[1] += curr_pos[1]
493
+ c_args[2] += curr_pos[0]
494
+ c_args[3] += curr_pos[1]
495
+ c_args[4] += curr_pos[0]
496
+ c_args[5] += curr_pos[1]
497
+ x, y = _cubicbezier(curr_pos[0], curr_pos[1],
498
+ c_args[0], c_args[1],
499
+ c_args[2], c_args[3],
500
+ c_args[4], c_args[5])
501
+ max_possible = len(canvas)
502
+ x = [int(round(x_)) for x_ in x]
503
+ y = [int(round(y_)) for y_ in y]
504
+
505
+ def within_range(x):
506
+ return 0 <= x < max_possible
507
+
508
+ filtered = [(x_, y_) for x_, y_ in zip(x, y)
509
+ if within_range(x_) and within_range(y_)]
510
+ if not filtered:
511
+ return
512
+ x, y = list(zip(*filtered))
513
+ canvas[y, x, :] = color
514
+
515
+
516
+ def _render_line(canvas, curr_pos, l_args, absolute, color):
517
+ """Renders a line in the given canvas."""
518
+ end_point = l_args
519
+ if not absolute:
520
+ end_point[0] += curr_pos[0]
521
+ end_point[1] += curr_pos[1]
522
+ rr, cc, val = draw.line_aa(int(curr_pos[0]), int(curr_pos[1]),
523
+ int(end_point[0]), int(end_point[1]))
524
+
525
+ max_possible = len(canvas)
526
+
527
+ def within_range(x):
528
+ return 0 <= x < max_possible
529
+
530
+ filtered = [(x, y, v) for x, y, v in zip(rr, cc, val)
531
+ if within_range(x) and within_range(y)]
532
+ if not filtered:
533
+ return
534
+ rr, cc, val = list(zip(*filtered))
535
+ val = [(v * color) for v in val]
536
+ canvas[cc, rr, :] = val
537
+
538
+
539
+ def _per_step_render(path, absolute=False, color=constant_color):
540
+ """Render the icon's edges, given its path."""
541
+ def to_canvas_size(l):
542
+ return [float(f) * (64. / 24.) for f in l]
543
+
544
+ canvas = np.zeros((64, 64, 3))
545
+ curr_pos = (0.0, 0.0)
546
+ for i, cmd in enumerate(path):
547
+ if not cmd:
548
+ continue
549
+ if cmd[0] in 'mM':
550
+ curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
551
+ elif cmd[0] in 'cC':
552
+ _render_cubic(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
553
+ curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
554
+ elif cmd[0] in 'lL':
555
+ _render_line(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
556
+ curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[1:]), absolute)
557
+
558
+ return canvas
559
+
560
+
561
+ def _zoom_out(path_list, add_baseline=0., per=22):
562
+ """Makes glyph slightly smaller in viewbox, makes some descenders visible."""
563
+ # assumes tensor is already unnormalized, and in long form
564
+ new_path = []
565
+ for command in path_list:
566
+ args = []
567
+ is_even = False
568
+ for arg in command[1:]:
569
+ if is_even:
570
+ args.append(str(float(arg) - ((24. - per) / 24.) * 64. / 4.))
571
+ is_even = False
572
+ else:
573
+ args.append(str(float(arg) - add_baseline))
574
+ is_even = True
575
+ new_path.append([command[0]] + args)
576
+ return new_path
577
+
578
+
579
+ ##################### UTILS FOR PROCESSING VECTORS ################
580
+ def _append_eos(sample, categorical, feature_dim):
581
+ if not categorical:
582
+ eos = -1 * np.ones(feature_dim)
583
+ else:
584
+ eos = np.zeros(feature_dim)
585
+ eos[0] = 1.0
586
+ sample.append(eos)
587
+ return sample
588
+
589
+
590
+ def _make_simple_cmds_long(out):
591
+ """Converts svg decoder output to format required by some render functions."""
592
+ # out has 10 dims
593
+ # the first 4 are respectively dims 0, 4, 5, 9 of the full 20-dim onehot vec
594
+ # the latter 6 are the 6 last dims of the 10-dim arg vec
595
+ shape_minus_dim = list(np.shape(out))[:-1]
596
+ return np.concatenate([out[..., :1],
597
+ np.zeros(shape_minus_dim + [3]),
598
+ out[..., 1:3],
599
+ np.zeros(shape_minus_dim + [3]),
600
+ out[..., 3:4],
601
+ np.zeros(shape_minus_dim + [14]),
602
+ out[..., 4:]], -1)
603
+
604
+
605
+ ################# UTILS FOR CONVERTING VECTORS TO SVGS ########################
606
+ def _vector_to_svg(vectors, stop_at_eos=False, categorical=False):
607
+ """Tranforms a given vector to an svg string."""
608
+ new_path = []
609
+ for vector in vectors:
610
+ if stop_at_eos:
611
+ if categorical:
612
+ try:
613
+ is_eos = np.argmax(vector[:len(CMDS_LIST) + 1]) == 0
614
+ except Exception:
615
+ raise Exception(vector)
616
+ else:
617
+ is_eos = vector[0] < -0.5
618
+
619
+ if is_eos:
620
+ break
621
+ new_path.append(' '.join(_vector_to_cmd(vector, categorical=categorical)))
622
+ new_path = ' '.join(new_path)
623
+ return SVG_PREFIX_BIG + PATH_PREFIX_1 + new_path + PATH_POSFIX_1 + SVG_POSFIX
624
+
625
+
626
+ def _vector_to_cmd(vector, categorical=False, return_floats=False):
627
+ """Does the inverse transformation as _cmd_to_vector()."""
628
+ cast_fn = float if return_floats else str
629
+ if categorical:
630
+ command = vector[:len(CMDS_LIST) + 1],
631
+ arguments = vector[len(CMDS_LIST) + 1:]
632
+ cmd_idx = np.argmax(command) - 1
633
+ else:
634
+ command, arguments = vector[:1], vector[1:]
635
+ cmd_idx = int(round(command[0]))
636
+
637
+ if cmd_idx < -0.5:
638
+ # EOS
639
+ return []
640
+ if cmd_idx >= len(CMDS_LIST):
641
+ cmd_idx = len(CMDS_LIST) - 1
642
+
643
+ cmd = CMDS_LIST[cmd_idx]
644
+ cmd = cmd.upper()
645
+ cmd_list = [cmd]
646
+
647
+ if cmd in 'hH':
648
+ cmd_list.append(cast_fn(arguments[8])) # x
649
+ elif cmd in 'vV':
650
+ cmd_list.append(cast_fn(arguments[9])) # y
651
+ elif cmd in 'mMlLtT':
652
+ cmd_list.append(cast_fn(arguments[8])) # x
653
+ cmd_list.append(cast_fn(arguments[9])) # y
654
+ elif cmd in 'sSqQ':
655
+ cmd_list.append(cast_fn(arguments[6])) # x2
656
+ cmd_list.append(cast_fn(arguments[7])) # y2
657
+ cmd_list.append(cast_fn(arguments[8])) # x
658
+ cmd_list.append(cast_fn(arguments[9])) # y
659
+ elif cmd in 'cC':
660
+ cmd_list.append(cast_fn(arguments[4])) # x1
661
+ cmd_list.append(cast_fn(arguments[5])) # y1
662
+ cmd_list.append(cast_fn(arguments[6])) # x2
663
+ cmd_list.append(cast_fn(arguments[7])) # y2
664
+ cmd_list.append(cast_fn(arguments[8])) # x
665
+ cmd_list.append(cast_fn(arguments[9])) # y
666
+ elif cmd in 'aA':
667
+ cmd_list.append(cast_fn(arguments[0])) # rx
668
+ cmd_list.append(cast_fn(arguments[1])) # ry
669
+ # x-axis-rotation is always 0
670
+ cmd_list.append(cast_fn('0'))
671
+ # the following two flags are binary.
672
+ cmd_list.append(cast_fn(1 if arguments[2] > 0.5 else 0)) # large-arc-flag
673
+ cmd_list.append(cast_fn(1 if arguments[3] > 0.5 else 0)) # sweep-flag
674
+ cmd_list.append(cast_fn(arguments[8])) # x
675
+ cmd_list.append(cast_fn(arguments[9])) # y
676
+
677
+ return cmd_list
678
+
679
+
680
+ ############## UTILS FOR CONVERTING SVGS/VECTORS TO IMAGES ###################
681
+
682
+ # From Infer notebook
683
+ start = ("""<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www."""
684
+ """w3.org/1999/xlink" width="256px" height="256px" style="-ms-trans"""
685
+ """form: rotate(360deg); -webkit-transform: rotate(360deg); transfo"""
686
+ """rm: rotate(360deg);" preserveAspectRatio="xMidYMid meet" viewBox"""
687
+ """="0 0 24 30"><path d=\"""")
688
+ end = """\" fill="currentColor"/></svg>"""
689
+
690
+ COMMAND_RX = re.compile("([MmLlHhVvCcSsQqTtAaZz])")
691
+ FLOAT_RX = re.compile("[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") # noqa
692
+
693
+
694
+ def svg_html_to_path_string(svg):
695
+ return svg.replace(start, '').replace(end, '')
696
+
697
+
698
+ def _tokenize(pathdef):
699
+ """Returns each svg token from path list."""
700
+ # e.g.: 'm0.1-.5c0,6' -> m', '0.1, '-.5', 'c', '0', '6'
701
+ for x in COMMAND_RX.split(pathdef):
702
+ if x != '' and x in 'MmLlHhVvCcSsQqTtAaZz':
703
+ yield x
704
+ for token in FLOAT_RX.findall(x):
705
+ yield token
706
+
707
+
708
+ def path_string_to_tokenized_commands(path):
709
+ """Tokenizes the given path string.
710
+
711
+ E.g.:
712
+ Given M 0.5 0.5 l 0.25 0.25 z
713
+ Returns [['M', '0.5', '0.5'], ['l', '0.25', '0.25'], ['z']]
714
+ """
715
+ new_path = []
716
+ current_cmd = []
717
+ for token in _tokenize(path):
718
+ if len(current_cmd) > 0:
719
+ if token in 'MmLlHhVvCcSsQqTtAaZz':
720
+ # cmd ended, convert to vector and add to new_path
721
+ new_path.append(current_cmd)
722
+ current_cmd = [token]
723
+ else:
724
+ # add arg to command
725
+ current_cmd.append(token)
726
+ else:
727
+ # add to start new cmd
728
+ current_cmd.append(token)
729
+
730
+ if current_cmd:
731
+ # process command still unprocessed
732
+ new_path.append(current_cmd)
733
+
734
+ return new_path
735
+
736
+
737
+ def separate_substructures(tokenized_commands):
738
+ """Returns a list of SVG substructures."""
739
+ # every moveTo command starts a new substructure
740
+ # an SVG substructure is a subpath that closes on itself
741
+ # such as the outter and the inner edge of the character `o`
742
+ substructures = []
743
+ curr = []
744
+ for cmd in tokenized_commands:
745
+ if cmd[0] in 'mM' and len(curr) > 0:
746
+ substructures.append(curr)
747
+ curr = []
748
+ curr.append(cmd)
749
+ if len(curr) > 0:
750
+ substructures.append(curr)
751
+ return substructures
752
+
753
+
754
+ def postprocess(svg, dist_thresh=2., skip=False):
755
+ path = svg_html_to_path_string(svg)
756
+ svg_template = svg.replace(path, '{}')
757
+ tokenized_commands = path_string_to_tokenized_commands(path)
758
+
759
+ def dist(a, b):
760
+ return np.sqrt((float(a[0]) - float(b[0]))**2 + (float(a[1]) - float(b[1]))**2)
761
+
762
+ def are_close_together(a, b, t):
763
+ return dist(a, b) < t
764
+
765
+ # first, go through each start/end point and merge if they're close enough
766
+ # together (that is, make end point the same as the start point).
767
+ # TODO: there are better ways of doing this, in a way that propagates error
768
+ # back (so if total error is 0.2, go through all N commands in this
769
+ # substructure and fix each by 0.2/N (unless they have 0 vertical change))
770
+ substructures = separate_substructures(tokenized_commands)
771
+
772
+ previous_substructure_endpoint = (0., 0.,)
773
+ for substructure in substructures:
774
+ # first, if the last substructure's endpoint was updated, we must update
775
+ # the start point of this one to reflect the opposite update
776
+ substructure[0][-2] = str(float(substructure[0][-2]) -
777
+ previous_substructure_endpoint[0])
778
+ substructure[0][-1] = str(float(substructure[0][-1]) -
779
+ previous_substructure_endpoint[1])
780
+
781
+ start = list(map(float, substructure[0][-2:]))
782
+ curr_pos = (0., 0.)
783
+ for cmd in substructure:
784
+ curr_pos, _ = _update_curr_pos(curr_pos, cmd, (0., 0.))
785
+ if are_close_together(start, curr_pos, dist_thresh):
786
+ new_point = np.array(start)
787
+ previous_substructure_endpoint = ((new_point[0] - curr_pos[0]),
788
+ (new_point[1] - curr_pos[1]))
789
+ substructure[-1][-2] = str(float(substructure[-1][-2]) +
790
+ (new_point[0] - curr_pos[0]))
791
+ substructure[-1][-1] = str(float(substructure[-1][-1]) +
792
+ (new_point[1] - curr_pos[1]))
793
+ if substructure[-1][0] in 'cC':
794
+ substructure[-1][-4] = str(float(substructure[-1][-4]) +
795
+ (new_point[0] - curr_pos[0]))
796
+ substructure[-1][-3] = str(float(substructure[-1][-3]) +
797
+ (new_point[1] - curr_pos[1]))
798
+
799
+ if skip:
800
+ return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
801
+ for s in substructures]))
802
+
803
+ def cosa(x, y):
804
+ return (x[0] * y[0] + x[1] * y[1]) / ((np.sqrt(x[0]**2 + x[1]**2) * np.sqrt(y[0]**2 + y[1]**2)))
805
+
806
+ def rotate(a, x, y):
807
+ return (x * np.cos(a) - y * np.sin(a), y * np.cos(a) + x * np.sin(a))
808
+ # second, gotta find adjacent bezier curves and, if their control points
809
+ # are well enough aligned, fully align them
810
+ for substructure in substructures:
811
+ curr_pos = (0., 0.)
812
+ new_curr_pos, _ = _update_curr_pos((0., 0.,), substructure[0], (0., 0.))
813
+
814
+ for cmd_idx in range(1, len(substructure)):
815
+ prev_cmd = substructure[cmd_idx-1]
816
+ cmd = substructure[cmd_idx]
817
+
818
+ new_new_curr_pos, _ = _update_curr_pos(
819
+ new_curr_pos, cmd, (0., 0.))
820
+
821
+ if cmd[0] == 'c':
822
+ if prev_cmd[0] == 'c':
823
+ # check the vectors and update if needed
824
+ # previous control pt wrt new curr point
825
+ prev_ctr_point = (curr_pos[0] + float(prev_cmd[3]) - new_curr_pos[0],
826
+ curr_pos[1] + float(prev_cmd[4]) - new_curr_pos[1])
827
+ ctr_point = (float(cmd[1]), float(cmd[2]))
828
+
829
+ if -1. < cosa(prev_ctr_point, ctr_point) < -0.95:
830
+ # calculate exact angle between the two vectors
831
+ angle_diff = (np.pi - np.arccos(cosa(prev_ctr_point, ctr_point)))/2
832
+
833
+ # rotate each vector by angle/2 in the correct direction for each.
834
+ sign = np.sign(np.cross(prev_ctr_point, ctr_point))
835
+ new_ctr_point = rotate(sign * angle_diff, *ctr_point)
836
+ new_prev_ctr_point = rotate(-sign * angle_diff, *prev_ctr_point)
837
+
838
+ # override the previous control points
839
+ # (which has to be wrt previous curr position)
840
+ substructure[cmd_idx-1][3] = str(new_prev_ctr_point[0] -
841
+ curr_pos[0] + new_curr_pos[0])
842
+ substructure[cmd_idx-1][4] = str(new_prev_ctr_point[1] -
843
+ curr_pos[1] + new_curr_pos[1])
844
+ substructure[cmd_idx][1] = str(new_ctr_point[0])
845
+ substructure[cmd_idx][2] = str(new_ctr_point[1])
846
+
847
+ curr_pos = new_curr_pos
848
+ new_curr_pos = new_new_curr_pos
849
+
850
+ return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
851
+ for s in substructures]))
852
+
853
+
854
+ # def get_means_stdevs(data_dir):
855
+ # """Returns the means and stdev saved in data_dir."""
856
+ # if data_dir not in means_stdevs:
857
+ # with tf.gfile.Open(os.path.join(data_dir, 'mean.npz'), 'r') as f:
858
+ # mean_npz = np.load(f)
859
+ # with tf.gfile.Open(os.path.join(data_dir, 'stdev.npz'), 'r') as f:
860
+ # stdev_npz = np.load(f)
861
+ # means_stdevs[data_dir] = (mean_npz, stdev_npz)
862
+ # return means_stdevs[data_dir]
863
+
864
+
865
+ def render(tensor, data_dir=None):
866
+ """Converts SVG decoder output into HTML svg."""
867
+ # undo normalization
868
+ # mean_npz, stdev_npz = get_means_stdevs(data_dir)
869
+ # tensor = (tensor * stdev_npz) + mean_npz
870
+
871
+ # convert to html
872
+ tensor = _make_simple_cmds_long(tensor)
873
+ # vector = np.squeeze(np.squeeze(tensor, 0), 2)
874
+ html = _vector_to_svg(tensor, stop_at_eos=True, categorical=True)
875
+
876
+ # some aesthetic postprocessing
877
+ html = postprocess(html)
878
+ html = html.replace('256px', '50px')
879
+
880
+ return html
881
+
882
+ ###############
883
+
884
+
885
+ def convert_to_svg(decoder_output, categorical=False):
886
+ converted = []
887
+ for example in decoder_output:
888
+ converted.append(_vector_to_svg(example, True, categorical=categorical))
889
+ return np.array(converted)
890
+
891
+
892
+ def create_image_conversion_fn(max_outputs, categorical=False):
893
+ """Binds the number of outputs to the image conversion fn (to svg or png)."""
894
+ def convert_to_svg(decoder_output):
895
+ converted = []
896
+ for example in decoder_output:
897
+ if len(converted) == max_outputs:
898
+ break
899
+ converted.append(_vector_to_svg(example, True, categorical=categorical))
900
+ return np.array(converted)
901
+
902
+ return convert_to_svg
903
+
904
+
905
+ ################### UTILS FOR CREATING TF SUMMARIES ##########################
906
+ def _make_encoded_image(img_tensor):
907
+ pil_img = Image.fromarray(np.squeeze(img_tensor * 255).astype(np.uint8), mode='L')
908
+ buff = io.BytesIO()
909
+ pil_img.save(buff, format='png')
910
+ encoded_image = buff.getvalue()
911
+ return encoded_image
912
+
913
+
914
+ ################### CHECK GLYPH/PATH VALID ##############################################
915
+ def is_valid_glyph(g):
916
+ is_09 = 48 <= g['uni'] <= 57
917
+ is_capital_az = 65 <= g['uni'] <= 90
918
+ is_az = 97 <= g['uni'] <= 122
919
+ is_valid_dims = g['width'] != 0 and g['vwidth'] != 0
920
+ return (is_09 or is_capital_az or is_az) and is_valid_dims
921
+
922
+
923
+ def is_valid_path(pathunibfp):
924
+ return pathunibfp[0] and len(pathunibfp[0]) <= MAX_SEQ_LEN
925
+
926
+
927
+ ################### DATASET PROCESSING #######################################
928
+ def convert_to_path(g):
929
+ """Converts SplineSet in SFD font to str path."""
930
+ path = _sfd_to_path_list(g)
931
+ path = _add_missing_cmds(path, remove_zs=False)
932
+ path = _normalize_based_on_viewbox(path, '0 0 {} {}'.format(g['width'], g['vwidth']))
933
+ return path, g['uni'], g['binary_fp']
934
+
935
+
936
+ def create_example(pathunibfp):
937
+ """Bulk of dataset processing. Converts str path to np array"""
938
+ path, uni, binary_fp = pathunibfp
939
+ final = {}
940
+
941
+ # zoom out
942
+ path = _zoom_out(path)
943
+ # make clockwise
944
+ path = _canonicalize(path)
945
+
946
+ # render path for training
947
+ final['rendered'] = _per_step_render(path, absolute=True)
948
+
949
+ # make path relative
950
+ # path = _make_relative(path)
951
+ # convert to vector
952
+ vector = _path_to_vector(path, categorical=True)
953
+ # make simple vector
954
+ vector = np.array(vector)
955
+ vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
956
+
957
+ # count some stats
958
+ final['seq_len'] = np.shape(vector)[0]
959
+ # final['class'] = int(_map_uni_to_alphanum(uni))
960
+ final['class'] = int(_map_uni_to_alpha(uni)) # be advised that the class is useless bcz it is all 0
961
+ final['binary_fp'] = str(binary_fp)
962
+
963
+ # append eos
964
+ vector = _append_eos(vector.tolist(), True, 10)
965
+
966
+ # pad path to MAX_SEQ_LEN + 1 (with eos)
967
+ final['sequence'] = np.concatenate((vector, np.zeros(((MAX_SEQ_LEN - final['seq_len']), 10))), 0)
968
+
969
+ # make pure list:
970
+ # use last channel only
971
+ final['rendered'] = np.reshape(final['rendered'][..., 0], [64 * 64]).astype(np.float32).tolist()
972
+ final['sequence'] = np.reshape(final['sequence'], [(MAX_SEQ_LEN + 1) * 10]).astype(np.float32).tolist()
973
+ final['class'] = np.reshape(final['class'], [1]).astype(np.int64).tolist()
974
+ final['seq_len'] = np.reshape(final['seq_len'], [1]).astype(np.int64).tolist()
975
+ return final
976
+
977
+
978
+ def mean_to_example(mean_stdev):
979
+ """Converts the found mean and stdev to example."""
980
+ # mean_stdev is a dict
981
+ mean_stdev['mean'] = np.reshape(mean_stdev['mean'], [10]).astype(np.float32).tolist()
982
+ mean_stdev['variance'] = np.reshape(mean_stdev['variance'], [10]).astype(np.float32).tolist()
983
+ mean_stdev['stddev'] = np.reshape(mean_stdev['stddev'], [10]).astype(np.float32).tolist()
984
+ mean_stdev['count'] = np.reshape(mean_stdev['count'], [1]).astype(np.int64).tolist()
985
+ return mean_stdev
986
+
987
+
988
+ def convert_simple_vector_to_path(seq):
989
+ path=[]
990
+ for i in range(seq.shape[0]):
991
+ path_i=[]
992
+ cmd = np.argmax(seq[i][:4])
993
+ p0 = seq[i][4:6]
994
+ p1 = seq[i][6:8]
995
+ p2 = seq[i][8:10]
996
+ if cmd == 0:
997
+ break
998
+ elif cmd == 1:
999
+ path_i.append('M')
1000
+ path_i.append(str(p2[0]))
1001
+ path_i.append(str(p2[1]))
1002
+ elif cmd == 2:
1003
+ path_i.append('L')
1004
+ path_i.append(str(p2[0]))
1005
+ path_i.append(str(p2[1]))
1006
+ elif cmd == 3:
1007
+ path_i.append('C')
1008
+ path_i.append(str(p0[0]))
1009
+ path_i.append(str(p0[1]))
1010
+ path_i.append(str(p1[0]))
1011
+ path_i.append(str(p1[1]))
1012
+ path_i.append(str(p2[0]))
1013
+ path_i.append(str(p2[1]))
1014
+ else:
1015
+ print("wrong!!! to path")
1016
+ path.append(path_i)
1017
+ return path
1018
+
1019
+ def clockwise(seq):
1020
+ path = convert_simple_vector_to_path(seq)
1021
+ path = _canonicalize(path)
1022
+ ret = {}
1023
+ vector = _path_to_vector(path, categorical=True)
1024
+ vector = np.array(vector)
1025
+ vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
1026
+ ret['seq_len'] = np.shape(vector)[0]
1027
+ vector = _append_eos(vector.tolist(), True, 10)
1028
+ ret['sequence'] = np.concatenate((vector, np.zeros(((MAX_SEQ_LEN - ret['seq_len']), 10))), 0)
1029
+ return ret
1030
+
1031
+ ################### CHECK VALID ##############################################
1032
+ class MeanStddev:
1033
+ """Accumulator to compute the mean/stdev of svg commands."""
1034
+
1035
+ def create_accumulator(self):
1036
+ curr_sum = np.zeros([10])
1037
+ sum_sq = np.zeros([10])
1038
+ return (curr_sum, sum_sq, 0) # x, x^2, count
1039
+
1040
+ def add_input(self, sum_count, new_input):
1041
+ (curr_sum, sum_sq, count) = sum_count
1042
+ # new_input is a dict with keys = ['seq_len', 'sequence']
1043
+ new_seq_len = new_input['seq_len'][0] # Line #754 'seq_len' is a list of one int
1044
+ assert isinstance(new_seq_len, int), print(type(new_seq_len))
1045
+
1046
+ # remove padding and eos from sequence
1047
+ assert isinstance(new_input['sequence'], list), print(type(new_input['sequence']))
1048
+ new_input_np = np.reshape(np.array(new_input['sequence']), [-1, 10])
1049
+ assert isinstance(new_input_np, np.ndarray), print(type())
1050
+ assert new_input_np.shape[0] >= new_seq_len
1051
+ new_input_np = new_input_np[:new_seq_len, :]
1052
+
1053
+ # accumulate new_sum and new_sum_sq
1054
+ new_sum = np.sum([curr_sum, np.sum(new_input_np, axis=0)], axis=0)
1055
+ new_sum_sq = np.sum([sum_sq, np.sum(np.power(new_input_np, 2), axis=0)],
1056
+ axis=0)
1057
+ return new_sum, new_sum_sq, count + new_seq_len
1058
+
1059
+ def merge_accumulators(self, accumulators):
1060
+ curr_sums, sum_sqs, counts = list(zip(*accumulators))
1061
+ return np.sum(curr_sums, axis=0), np.sum(sum_sqs, axis=0), np.sum(counts)
1062
+
1063
+ def extract_output(self, sum_count):
1064
+ (curr_sum, curr_sum_sq, count) = sum_count
1065
+ if count:
1066
+ mean = np.divide(curr_sum, count)
1067
+ variance = np.divide(curr_sum_sq, count) - np.power(mean, 2)
1068
+ # -ve value could happen due to rounding
1069
+ variance = np.max([variance, np.zeros(np.shape(variance))], axis=0)
1070
+ stddev = np.sqrt(variance)
1071
+ return {
1072
+ 'mean': mean,
1073
+ 'variance': variance,
1074
+ 'stddev': stddev,
1075
+ 'count': count
1076
+ }
1077
+ else:
1078
+ return {
1079
+ 'mean': float('NaN'),
1080
+ 'variance': float('NaN'),
1081
+ 'stddev': float('NaN'),
1082
+ 'count': 0
1083
  }
data_utils/svg_utils_backup.py CHANGED
@@ -1,1174 +1,1174 @@
1
-
2
- # Copyright 2020 The Magenta Authors.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import pdb
16
- # Lint as: python3
17
- """Defines the Material Design Icons Problem."""
18
- import io
19
- import numpy as np
20
- import re
21
-
22
- from PIL import Image
23
- from itertools import zip_longest
24
- from skimage import draw
25
- import sys
26
-
27
- SVG_PREFIX_BIG = ('<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="'
28
- 'http://www.w3.org/1999/xlink" width="256px" height="256px"'
29
- ' style="-ms-transform: rotate(360deg); -webkit-transform:'
30
- ' rotate(360deg); transform: rotate(360deg);" '
31
- 'preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24">')
32
- PATH_PREFIX_1 = '<path d="'
33
- PATH_POSFIX_1 = '" fill="currentColor"/>'
34
- SVG_POSFIX = '</svg>'
35
-
36
- NUM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
37
- 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
38
- 'q': 4, 'Q': 4, 'z': 0}
39
- # in order of arg complexity, with absolutes clustered
40
- # recall we don't handle all commands (see docstring)
41
-
42
- #note args:
43
- # v, h: vertical horizental lines
44
- # a: elliptical Arc 椭圆
45
- # l: lineto
46
- # t: smooth quadratic Bézier curveto 2次贝塞尔曲线
47
- # c: curveto
48
- # m: moveto
49
- # s: smooth curveto
50
- # Q: quadratic Bézier curve 2次贝塞尔曲线
51
- # z: closepath
52
- #CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'
53
- CMDS_LIST = 'zHVMLTSQCAhvmltsqca'
54
- CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
55
-
56
- FEATURE_DIM = 10
57
-
58
-
59
- ############################### GENERAL UTILS #################################
60
- def grouper(iterable, batch_size, fill_value=None):
61
- """Helper method for returning batches of size batch_size of a dataset."""
62
- # grouper('ABCDEF', 3) -> 'ABC', 'DEF'
63
- args = [iter(iterable)] * batch_size
64
- return zip_longest(*args, fillvalue=fill_value)
65
-
66
-
67
- def _map_uni_to_alphanum(uni):
68
- """Maps [0-9 A-Z a-z] to numbers 0-62."""
69
- if 48 <= uni <= 57:
70
- return uni - 48
71
- elif 65 <= uni <= 90:
72
- return uni - 65 + 10
73
- return uni - 97 + 36
74
-
75
-
76
- def _map_uni_to_alpha(uni):
77
- """Maps [A-Z a-z] to numbers 0-52."""
78
- if 65 <= uni <= 90:
79
- return uni - 65
80
- return uni - 97 + 26
81
-
82
-
83
- ############# UTILS FOR CONVERTING SFD/SPLINESETS TO SVG PATHS ################
84
- def _get_spline(sfd):
85
- if 'SplineSet' not in sfd:
86
- return ''
87
- pro = sfd[sfd.index('SplineSet') + 10:] # 10 is the 'SplineSet'
88
- pro = pro[:pro.index('EndSplineSet')]
89
- return pro
90
-
91
-
92
- def _spline_to_path_list(spline, height, replace_with_prev=False):
93
- """Converts SplineSet to a list of tokenized commands in svg path."""
94
- path = []
95
- prev_xy = []
96
- for line in spline.splitlines():
97
- if not line:
98
- continue
99
- tokens = line.split(' ')
100
- cmd = tokens[-2]
101
- if cmd not in 'cml':
102
- # COMMAND NOT RECOGNIZED.
103
- return []
104
- # assert cmd in 'cml', 'Command not recognized: {}'.format(cmd)
105
- args = tokens[:-2]
106
- args = [float(x) for x in args if x]
107
-
108
- if replace_with_prev and cmd in 'c':
109
- args[:2] = prev_xy
110
- prev_xy = args[-2:]
111
-
112
- new_y_args = []
113
- for i, a in enumerate(args):
114
- if i % 2 == 1:
115
- new_y_args.append((height - a))
116
- else:
117
- new_y_args.append((a))
118
-
119
- path.append([cmd.upper()] + new_y_args)
120
- return path
121
-
122
-
123
- def _sfd_to_path_list(single, replace_with_prev=False):
124
- """Converts the given SFD glyph into a path."""
125
- return _spline_to_path_list(_get_spline(single['sfd']), single['vwidth'], replace_with_prev)
126
-
127
-
128
- #################### UTILS FOR PROCESSING TOKENIZED PATHS #####################
129
- def _add_missing_cmds(path, remove_zs=False):
130
- """Adds missing cmd tags to the commands in the svg."""
131
- # For instance, the command 'a' takes 7 arguments, but some SVGs declare:
132
- # a 1 2 3 4 5 6 7 8 9 10 11 12 13 14
133
- # Which is 14 arguments. This function converts the above to the equivalent:
134
- # a 1 2 3 4 5 6 7 a 8 9 10 11 12 13 14
135
- #
136
- # Note: if remove_zs is True, this also removes any occurences of z commands.
137
- new_path = []
138
- for cmd in path:
139
- if not remove_zs or cmd[0] not in 'Zz':
140
- for new_cmd in add_missing_cmd(cmd):
141
- new_path.append(new_cmd)
142
- return new_path
143
-
144
-
145
- def add_missing_cmd(command_list):
146
- """Adds missing cmd tags to the given command list."""
147
- # E.g.: given:
148
- # ['a', '0', '0', '0', '0', '0', '0', '0',
149
- # '0', '0', '0', '0', '0', '0', '0']
150
- # Converts to:
151
- # [['a', '0', '0', '0', '0', '0', '0', '0'],
152
- # ['a', '0', '0', '0', '0', '0', '0', '0']]
153
- # And returns a string that joins these elements with spaces.
154
- cmd_tag = command_list[0]
155
- args = command_list[1:]
156
-
157
- final_cmds = []
158
- for arg_batch in grouper(args, NUM_ARGS[cmd_tag]):
159
- final_cmds.append([cmd_tag] + list(arg_batch))
160
-
161
- if not final_cmds:
162
- # command has no args (e.g.: 'z')
163
- final_cmds = [[cmd_tag]]
164
-
165
- return final_cmds
166
-
167
-
168
- def _normalize_args(arglist, norm, add=None, flip=False):
169
- """Normalize the given args with the given norm value."""
170
- new_arglist = []
171
- for i, arg in enumerate(arglist):
172
- new_arg = float(arg)
173
-
174
- if add is not None:
175
- add_to_x, add_to_y = add
176
-
177
- # This argument is an x-coordinate if even, y-coordinate if odd
178
- # except when flip == True
179
- if i % 2 == 0:
180
- new_arg += add_to_y if flip else add_to_x
181
- else:
182
- new_arg += add_to_x if flip else add_to_y
183
-
184
- new_arglist.append(str(24 * new_arg / norm))
185
- return new_arglist
186
-
187
-
188
- def _normalize_based_on_viewbox(path, viewbox):
189
- """Normalizes all args in a path to a standard 24x24 viewbox."""
190
- # Each SVG lives in a 2D plane. The viewbox determines the region of that
191
- # plane that gets rendered. For instance, some designers may work with a
192
- # viewbox that's 24x24, others with one that's 100x100, etc.
193
-
194
- # Suppose I design the the letter "h" in the Arial style using a 100x100
195
- # viewbox (let's call it icon A). Let's suppose the icon has height 75. Then,
196
- # I design the same character using a 20x20 viewbox (call this icon B), with
197
- # height 15 (=75% of 20). This means that, when rendered, both icons with look
198
- # exactly the same, but the scale of the commands each icon is using is
199
- # different. For instance, if icon A has a command like "lineTo 100 100", the
200
- # equivalent command in icon B will be "lineTo 20 20".
201
-
202
- # In order to avoid this problem and bring all real values to the same scale,
203
- # I scale all icons' commands to use a 24x24 viewbox. This function does this:
204
- # it converts a path that exists in the given viewbox into a standard 24x24
205
- # viewbox.
206
- viewbox = viewbox.split(' ')
207
- norm = max(int(viewbox[-1]), int(viewbox[-2]))
208
-
209
- if int(viewbox[-1]) > int(viewbox[-2]):
210
- add_to_y = 0
211
- add_to_x = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
212
- else:
213
- add_to_y = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
214
- add_to_x = 0
215
-
216
- new_path = []
217
- for command in path:
218
- if command[0] == 'a':
219
- new_path.append([command[0]] + _normalize_args(command[1:3], norm)
220
- + command[3:6] + _normalize_args(command[6:], norm))
221
- elif command[0] == 'A':
222
- new_path.append([command[0]] + _normalize_args(command[1:3], norm)
223
- + command[3:6] + _normalize_args(command[6:], norm, add=(add_to_x, add_to_y)))
224
- elif command[0] == 'V':
225
- new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y), flip=True))
226
- elif command[0] == command[0].upper():
227
- new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y)))
228
- elif command[0] in 'zZ':
229
- new_path.append([command[0]])
230
- else:
231
- new_path.append([command[0]] + _normalize_args(command[1:], norm))
232
-
233
- return new_path
234
-
235
-
236
- def _convert_args(args, curr_pos, cmd):
237
- """Converts given args to relative values."""
238
- # NOTE: glyphs only use a very small subset of commands (L, C, M, and Z -- I
239
- # believe). So I'm not handling A and H for now.
240
- if cmd in 'AH':
241
- raise NotImplementedError('These commands have >6 args (not supported).')
242
-
243
- new_args = []
244
- for i, arg in enumerate(args):
245
- x_or_y = i % 2
246
- if cmd == 'H':
247
- x_or_y = (i + 1) % 2
248
- new_args.append(str(float(arg) - curr_pos[x_or_y]))
249
-
250
- return new_args
251
-
252
-
253
- def _update_curr_pos(curr_pos, cmd, start_of_path):
254
- """Calculate the position of the pen after cmd is applied."""
255
- if cmd[0] in 'ml':
256
- curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1] + float(cmd[2])]
257
- if cmd[0] == 'm':
258
- start_of_path = curr_pos
259
- elif cmd[0] in 'z':
260
- curr_pos = start_of_path
261
- elif cmd[0] in 'h':
262
- curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1]]
263
- elif cmd[0] in 'v':
264
- curr_pos = [curr_pos[0], curr_pos[1] + float(cmd[1])]
265
- elif cmd[0] in 'ctsqa':
266
- curr_pos = [curr_pos[0] + float(cmd[-2]), curr_pos[1] + float(cmd[-1])]
267
-
268
- return curr_pos, start_of_path
269
-
270
-
271
- def _make_relative(cmds):
272
- """Convert commands in a path to relative positioning."""
273
- curr_pos = (0.0, 0.0)
274
- start_of_path = (0.0, 0.0)
275
- new_cmds = []
276
- for cmd in cmds:
277
- if cmd[0].lower() == cmd[0]:
278
- new_cmd = cmd
279
- elif cmd[0].lower() == 'z':
280
- new_cmd = [cmd[0].lower()]
281
- else:
282
- new_cmd = [cmd[0].lower()] + _convert_args(cmd[1:], curr_pos, cmd=cmd[0])
283
- new_cmds.append(new_cmd)
284
- curr_pos, start_of_path = _update_curr_pos(curr_pos, new_cmd, start_of_path)
285
- return new_cmds
286
-
287
-
288
- def _is_to_left_of(pt1, pt2):
289
- pt1_norm = (pt1[0]**2 + pt1[1]**2)
290
- pt2_norm = (pt2[0]**2 + pt2[1]**2)
291
- return pt1[1] < pt2[1] or (pt1_norm == pt2_norm and pt1[0] < pt2[0])
292
-
293
-
294
- def _get_leftmost_point(path):
295
- """Returns the leftmost, topmost point of the path."""
296
- leftmost = (float('inf'), float('inf'))
297
- idx = -1
298
-
299
- for i, cmd in enumerate(path):
300
- if len(cmd) > 1:
301
- endpoint = cmd[-2:]
302
- if _is_to_left_of(endpoint, leftmost):
303
- leftmost = endpoint
304
- idx = i
305
-
306
- return leftmost, idx
307
-
308
-
309
- def _separate_substructures(path):
310
- """Returns a list of subpaths, each representing substructures the glyph."""
311
- substructures = []
312
- curr = []
313
- for cmd in path:
314
- if cmd[0] in 'mM' and curr:
315
- substructures.append(curr)
316
- curr = []
317
- curr.append(cmd)
318
- if curr:
319
- substructures.append(curr)
320
- return substructures
321
-
322
-
323
- def _is_clockwise(subpath):
324
- """Returns whether the given subpath is clockwise-oriented."""
325
- pts = [cmd[-2:] for cmd in subpath]
326
- det = 0
327
- for i in range(len(pts) - 1):
328
- det += np.linalg.det(pts[i:i + 2])
329
- return det > 0
330
-
331
-
332
- def _make_clockwise(subpath):
333
- """Inverts the cardinality of the given subpath."""
334
- new_path = [subpath[0]]
335
- other_cmds = list(reversed(subpath[1:]))
336
- for i, cmd in enumerate(other_cmds):
337
- if i + 1 == len(other_cmds):
338
- where_we_were = subpath[0][-2:]
339
- else:
340
- where_we_were = other_cmds[i + 1][-2:]
341
-
342
- if len(cmd) > 3:
343
- new_cmd = [cmd[0], cmd[3], cmd[4], cmd[1], cmd[2],
344
- where_we_were[0], where_we_were[1]]
345
- else:
346
- new_cmd = [cmd[0], where_we_were[0], where_we_were[1]]
347
-
348
- new_path.append(new_cmd)
349
- return new_path
350
-
351
-
352
- def _canonicalize(path):
353
- """Makes all paths start at top left, and go clockwise first."""
354
- # convert args to floats
355
- #print(len(path),path)
356
-
357
- path = [[x[0]] + list(map(float, x[1:])) for x in path]
358
- # print(len(path),path)
359
-
360
- # _canonicalize each subpath separately
361
- #pdb.set_trace()
362
-
363
- new_substructures = []
364
- for subpath in _separate_substructures(path):
365
- # print(subpath,"\n")
366
- leftmost_point, leftmost_idx = _get_leftmost_point(subpath)
367
- reordered = ([['M', leftmost_point[0], leftmost_point[1]]] + subpath[leftmost_idx + 1:] + subpath[1:leftmost_idx + 1])
368
- new_substructures.append((reordered, leftmost_point))
369
-
370
- # sys.exit()
371
- new_path = []
372
- first_substructure_done = False
373
- should_flip_cardinality = False
374
- for sp, _ in sorted(new_substructures, key=lambda x: (x[1][1], x[1][0])):
375
- if not first_substructure_done:
376
- # we're looking at the first substructure now, we can determine whether we
377
- # will flip the cardniality of the whole icon or not
378
- should_flip_cardinality = not _is_clockwise(sp)
379
- first_substructure_done = True
380
-
381
- if should_flip_cardinality:
382
- sp = _make_clockwise(sp)
383
-
384
- new_path.extend(sp)
385
-
386
- # convert args to strs
387
- path = [[x[0]] + list(map(str, x[1:])) for x in new_path]
388
- return path
389
-
390
-
391
- # ######### UTILS FOR CONVERTING TOKENIZED PATHS TO VECTORS ###########
392
- def _path_to_vector(path, categorical=False):
393
- """Converts path's commands to a series of vectors."""
394
- # Notes:
395
- # - The SimpleSVG dataset does not have any 't', 'q', 'Z', 'T', or 'Q'.
396
- # Thus, we don't handle those here.
397
- # - We also removed all 'z's.
398
- # - The x-axis-rotation argument to a commands is always 0 in this
399
- # dataset, so we ignore it
400
-
401
- # Many commands have args that correspond to args in other commands.
402
- # v __,__ _______________ ______________,_________ __,__ __,__ _,y
403
- # h __,__ _______________ ______________,_________ __,__ __,__ x,_
404
- # z __,__ _______________ ______________,_________ __,__ __,__ _,_
405
- # a rx,ry x-axis-rotation large-arc-flag,sweepflag __,__ __,__ x,y
406
- # l __,__ _______________ ______________,_________ __,__ __,__ x,y
407
- # c __,__ _______________ ______________,_________ x1,y1 x2,y2 x,y
408
- # m __,__ _______________ ______________,_________ __,__ __,__ x,y
409
- # s __,__ _______________ ______________,_________ __,__ x2,y2 x,y
410
-
411
- # So each command will be converted to a vector where the dimension is the
412
- # minimal number of arguments to all commands:
413
- # [rx, ry, large-arc-flag, sweepflag, x1, y1, x2, y2, x, y]
414
- # If a command does not output a certain arg, it is set to 0.
415
- # "l 5,5" becomes [0, 0, 0, 0, 0, 0, 0, 0, 5, 5]
416
-
417
- # Also note, as of now we also output an extra dimension at index 0, which
418
- # indicates which command is being outputted (integer).
419
- new_path = []
420
- for cmd in path:
421
- new_path.append(_cmd_to_vector(cmd, categorical=categorical))
422
- return new_path
423
-
424
-
425
- def _cmd_to_vector(cmd_list, categorical=False):
426
- """Converts the given command (given as a list) into a vector.
427
- UM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
428
- 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
429
- 'q': 4, 'Q': 4, 'z': 0}
430
-
431
- CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'
432
- CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
433
- """
434
- # For description of how this conversion happens, see
435
- # _path_to_vector docstring.
436
- cmd = cmd_list[0]
437
- args = cmd_list[1:]
438
-
439
- if not categorical:
440
- # integer, for MSE
441
- command = [float(CMD_MAPPING[cmd])]
442
- else:
443
- # one hot + 1 dim for EOS.
444
- command = [0.0] * (len(CMDS_LIST) + 1) # 大概有19个commands?
445
- command[CMD_MAPPING[cmd] + 1] = 1.0
446
-
447
- arguments = [0.0] * 10
448
- if cmd in 'hH':
449
- arguments[8] = float(args[0]) # x
450
- elif cmd in 'vV':
451
- arguments[9] = float(args[0]) # y
452
- elif cmd in 'mMlLtT':
453
- arguments[8] = float(args[0]) # x
454
- arguments[9] = float(args[1]) # y
455
- elif cmd in 'sSqQ':
456
- arguments[6] = float(args[0]) # x2
457
- arguments[7] = float(args[1]) # y2
458
- arguments[8] = float(args[2]) # x
459
- arguments[9] = float(args[3]) # y
460
- elif cmd in 'cC':
461
- arguments[4] = float(args[0]) # x1
462
- arguments[5] = float(args[1]) # y1
463
- arguments[6] = float(args[2]) # x2
464
- arguments[7] = float(args[3]) # y2
465
- arguments[8] = float(args[4]) # x
466
- arguments[9] = float(args[5]) # y
467
- elif cmd in 'aA':
468
- arguments[0] = float(args[0]) # rx
469
- arguments[1] = float(args[1]) # ry
470
- # we skip x-axis-rotation
471
- arguments[2] = float(args[3]) # large-arc-flag
472
- arguments[3] = float(args[4]) # sweep-flag
473
- # a does not have x1, y1, x2, y2 args
474
- arguments[8] = float(args[5]) # x
475
- arguments[9] = float(args[6]) # y
476
-
477
- return command + arguments
478
-
479
-
480
- ################## UTILS FOR RENDERING PATH INTO IMAGE #################
481
- def _cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=40):
482
- """Return n points along cubiz bezier with given control points."""
483
- # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
484
- pts = []
485
- for i in range(n + 1):
486
- t = float(i) / float(n)
487
- a = (1. - t)**3
488
- b = 3. * t * (1. - t)**2
489
- c = 3.0 * t**2 * (1.0 - t)
490
- d = t**3
491
-
492
- x = float(a * x0 + b * x1 + c * x2 + d * x3)
493
- y = float(a * y0 + b * y1 + c * y2 + d * y3)
494
- pts.append((x, y))
495
- return list(zip(*pts))
496
-
497
-
498
- def _update_pos(curr_pos, end_pos, absolute):
499
- if absolute:
500
- return end_pos
501
- return curr_pos[0] + end_pos[0], curr_pos[1] + end_pos[1]
502
-
503
-
504
- def constant_color(*unused_args):
505
- return np.array([255, 255, 255])
506
-
507
-
508
- def _render_cubic(canvas, curr_pos, c_args, absolute, color):
509
- """Renders a cubic bezier curve in the given canvas."""
510
- if not absolute:
511
- c_args[0] += curr_pos[0]
512
- c_args[1] += curr_pos[1]
513
- c_args[2] += curr_pos[0]
514
- c_args[3] += curr_pos[1]
515
- c_args[4] += curr_pos[0]
516
- c_args[5] += curr_pos[1]
517
- x, y = _cubicbezier(curr_pos[0], curr_pos[1],
518
- c_args[0], c_args[1],
519
- c_args[2], c_args[3],
520
- c_args[4], c_args[5])
521
- max_possible = len(canvas)
522
- x = [int(round(x_)) for x_ in x]
523
- y = [int(round(y_)) for y_ in y]
524
-
525
- def within_range(x):
526
- return 0 <= x < max_possible
527
-
528
- filtered = [(x_, y_) for x_, y_ in zip(x, y)
529
- if within_range(x_) and within_range(y_)]
530
- if not filtered:
531
- return
532
- x, y = list(zip(*filtered))
533
- canvas[y, x, :] = color
534
-
535
-
536
- def _render_line(canvas, curr_pos, l_args, absolute, color):
537
- """Renders a line in the given canvas."""
538
- end_point = l_args
539
- if not absolute:
540
- end_point[0] += curr_pos[0]
541
- end_point[1] += curr_pos[1]
542
- rr, cc, val = draw.line_aa(int(curr_pos[0]), int(curr_pos[1]),
543
- int(end_point[0]), int(end_point[1]))
544
-
545
- max_possible = len(canvas)
546
-
547
- def within_range(x):
548
- return 0 <= x < max_possible
549
-
550
- filtered = [(x, y, v) for x, y, v in zip(rr, cc, val)
551
- if within_range(x) and within_range(y)]
552
- if not filtered:
553
- return
554
- rr, cc, val = list(zip(*filtered))
555
- val = [(v * color) for v in val]
556
- canvas[cc, rr, :] = val
557
-
558
-
559
- def _per_step_render(path, absolute=False, color=constant_color):
560
- """Render the icon's edges, given its path."""
561
- def to_canvas_size(l):
562
- return [float(f) * (64. / 24.) for f in l]
563
-
564
- canvas = np.zeros((64, 64, 3))
565
- curr_pos = (0.0, 0.0)
566
- for i, cmd in enumerate(path):
567
- if not cmd:
568
- continue
569
- if cmd[0] in 'mM':
570
- curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
571
- elif cmd[0] in 'cC':
572
- _render_cubic(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
573
- curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
574
- elif cmd[0] in 'lL':
575
- _render_line(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
576
- curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[1:]), absolute)
577
-
578
- return canvas
579
-
580
-
581
- def _zoom_out(path_list, add_baseline=0., per=22):
582
- """Makes glyph slightly smaller in viewbox, makes some descenders visible."""
583
- # assumes tensor is already unnormalized, and in long form
584
- new_path = []
585
- for command in path_list:
586
- args = []
587
- is_even = False
588
- for arg in command[1:]:
589
- if is_even:
590
- args.append(str(float(arg) - ((24. - per) / 24.) * 64. / 4.))
591
- is_even = False
592
- else:
593
- args.append(str(float(arg) - add_baseline))
594
- is_even = True
595
- new_path.append([command[0]] + args)
596
- return new_path
597
-
598
-
599
- ##################### UTILS FOR PROCESSING VECTORS ################
600
- def _append_eos(sample, categorical, feature_dim):
601
- if not categorical:
602
- eos = -1 * np.ones(feature_dim)
603
- else:
604
- eos = np.zeros(feature_dim)
605
- eos[0] = 1.0
606
- sample.append(eos)
607
- return sample
608
-
609
-
610
- def _make_simple_cmds_long(out):
611
- """Converts svg decoder output to format required by some render functions."""
612
- # out has 10 dims
613
- # the first 4 are respectively dims 0, 4, 5, 9 of the full 20-dim onehot vec
614
- # the latter 6 are the 6 last dims of the 10-dim arg vec
615
- shape_minus_dim = list(np.shape(out))[:-1]
616
- # print("make? ",shape_minus_dim ) # [51]
617
-
618
- return np.concatenate([out[..., :1], # [51,1] 51个steps的第1维特征
619
- np.zeros(shape_minus_dim + [3]),# [51,3]
620
- out[..., 1:3], #[51,2]
621
- np.zeros(shape_minus_dim + [3]),# [51,3]
622
- out[..., 3:4],# [51,1]
623
- np.zeros(shape_minus_dim + [14]),# [51,14]
624
- out[..., 4:]], -1)# [51,6] # 最后的6个绘制参数
625
-
626
- def render(tensor, data_dir=None):
627
- """Converts SVG decoder output into HTML svg."""
628
- # undo normalization
629
- # mean_npz, stdev_npz = get_means_stdevs(data_dir)
630
- # tensor = (tensor * stdev_npz) + mean_npz
631
-
632
- # convert to html
633
- #print("before",tensor.shape)# 51, 10)
634
- tensor = _make_simple_cmds_long(tensor)
635
- # print("after",tensor.shape)#(51, 30)
636
- # vector = np.squeeze(np.squeeze(tensor, 0), 2)
637
- # print("1",tensor[0,:5])# (51, 30)
638
- html = _vector_to_svg(tensor, stop_at_eos=True, categorical=True)
639
- # print(html.shape)
640
- # some aesthetic postprocessing
641
- html = postprocess(html)
642
- html = html.replace('256px', '50px')
643
-
644
- return html
645
-
646
- ################# UTILS FOR CONVERTING VECTORS TO SVGS ########################
647
- #note: transform the decoded trg_seq into the common svg format.把decode出来的seq转成html的svg,命令有前后关系,也都是相对位置。
648
- def _vector_to_svg(vectors, stop_at_eos=False, categorical=False):
649
- """Tranforms a given vector to an svg string.
650
-
651
- """
652
- new_path = []
653
- for vector in vectors:
654
- if stop_at_eos:
655
- if categorical:
656
- try:
657
- is_eos = np.argmax(vector[:len(CMDS_LIST) + 1]) == 0
658
- except Exception:
659
- raise Exception(vector)
660
- else:
661
- is_eos = vector[0] < -0.5
662
-
663
- if is_eos:
664
- break
665
- new_path.append(' '.join(_vector_to_cmd(vector, categorical=categorical))) #
666
- new_path = ' '.join(new_path) # 加入new_path,每个path都以空格分隔
667
- return SVG_PREFIX_BIG + PATH_PREFIX_1 + new_path + PATH_POSFIX_1 + SVG_POSFIX
668
-
669
- def _vector_to_path(vectors):
670
- new_path = []
671
- for vector in vectors:
672
- #print(vector,"???")
673
- new_path.append(_vector_to_cmd(vector,categorical=True)) #
674
- #print(_vector_to_cmd(vector),"hhh")
675
- # new_path = ' '.join(new_path) # 加入new_path,每个path都以空格分隔
676
- return new_path
677
-
678
- def _vector_to_cmd(vector, categorical=False, return_floats=False):
679
- """Does the inverse transformation as _cmd_to_vector().
680
- UM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
681
- 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
682
- 'q': 4, 'Q': 4, 'z': 0}
683
-
684
- CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'
685
- CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
686
-
687
- """
688
- cast_fn = float if return_floats else str
689
- if categorical:
690
- # print(vector.shape,vector)# 30
691
- #print("??",len(CMDS_LIST)) # 19
692
- command = vector[:len(CMDS_LIST) + 1],# 前20维
693
- arguments = vector[len(CMDS_LIST) + 1:]# 后10维
694
- cmd_idx = np.argmax(command) - 1 # 看当前绘制命令属于哪一类
695
-
696
- else:
697
-
698
- command, arguments = vector[:1], vector[1:]
699
- cmd_idx = int(round(command[0]))
700
-
701
- if cmd_idx < -0.5:
702
- # EOS
703
- return []
704
- if cmd_idx >= len(CMDS_LIST):
705
- cmd_idx = len(CMDS_LIST) - 1
706
-
707
- cmd = CMDS_LIST[cmd_idx]
708
- cmd = cmd.upper()
709
- cmd_list = [cmd]
710
-
711
- if cmd in 'hH': # 如果是画线,而且是x轴
712
- cmd_list.append(cast_fn(arguments[8])) # x
713
- elif cmd in 'vV': # 如果是画线,而且是y轴
714
- cmd_list.append(cast_fn(arguments[9])) # y
715
- elif cmd in 'mMlLtT':
716
- cmd_list.append(cast_fn(arguments[8])) # x
717
- cmd_list.append(cast_fn(arguments[9])) # y
718
- elif cmd in 'sSqQ':
719
- cmd_list.append(cast_fn(arguments[6])) # x2
720
- cmd_list.append(cast_fn(arguments[7])) # y2
721
- cmd_list.append(cast_fn(arguments[8])) # x
722
- cmd_list.append(cast_fn(arguments[9])) # y
723
- elif cmd in 'cC':
724
- cmd_list.append(cast_fn(arguments[4])) # x1
725
- cmd_list.append(cast_fn(arguments[5])) # y1
726
- cmd_list.append(cast_fn(arguments[6])) # x2
727
- cmd_list.append(cast_fn(arguments[7])) # y2
728
- cmd_list.append(cast_fn(arguments[8])) # x
729
- cmd_list.append(cast_fn(arguments[9])) # y
730
- elif cmd in 'aA':
731
- cmd_list.append(cast_fn(arguments[0])) # rx
732
- cmd_list.append(cast_fn(arguments[1])) # ry
733
- # x-axis-rotation is always 0
734
- cmd_list.append(cast_fn('0'))
735
- # the following two flags are binary.
736
- cmd_list.append(cast_fn(1 if arguments[2] > 0.5 else 0)) # large-arc-flag
737
- cmd_list.append(cast_fn(1 if arguments[3] > 0.5 else 0)) # sweep-flag
738
- cmd_list.append(cast_fn(arguments[8])) # x
739
- cmd_list.append(cast_fn(arguments[9])) # y
740
-
741
- return cmd_list
742
-
743
-
744
- ############## UTILS FOR CONVERTING SVGS/VECTORS TO IMAGES ###################
745
-
746
- # From Infer notebook
747
- start = ("""<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www."""
748
- """w3.org/1999/xlink" width="256px" height="256px" style="-ms-trans"""
749
- """form: rotate(360deg); -webkit-transform: rotate(360deg); transfo"""
750
- """rm: rotate(360deg);" preserveAspectRatio="xMidYMid meet" viewBox"""
751
- """="0 0 24 24"><path d=\"""")
752
- end = """\" fill="currentColor"/></svg>"""
753
-
754
- COMMAND_RX = re.compile("([MmLlHhVvCcSsQqTtAaZz])")
755
- FLOAT_RX = re.compile("[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") # noqa
756
-
757
-
758
- def svg_html_to_path_string(svg):
759
- return svg.replace(start, '').replace(end, '')
760
-
761
-
762
- def _tokenize(pathdef):
763
- """Returns each svg token from path list."""
764
- # e.g.: 'm0.1-.5c0,6' -> m', '0.1, '-.5', 'c', '0', '6'
765
- for x in COMMAND_RX.split(pathdef):
766
- if x != '' and x in 'MmLlHhVvCcSsQqTtAaZz':
767
- yield x
768
- for token in FLOAT_RX.findall(x):
769
- yield token
770
-
771
-
772
- def path_string_to_tokenized_commands(path):
773
- """Tokenizes the given path string.
774
-
775
- E.g.:
776
- Given M 0.5 0.5 l 0.25 0.25 z
777
- Returns [['M', '0.5', '0.5'], ['l', '0.25', '0.25'], ['z']]
778
- """
779
- new_path = []
780
- current_cmd = []
781
- for token in _tokenize(path):
782
- if len(current_cmd) > 0:
783
- if token in 'MmLlHhVvCcSsQqTtAaZz':
784
- # cmd ended, convert to vector and add to new_path
785
- new_path.append(current_cmd)
786
- current_cmd = [token]
787
- else:
788
- # add arg to command
789
- current_cmd.append(token)
790
- else:
791
- # add to start new cmd
792
- current_cmd.append(token)
793
-
794
- if current_cmd:
795
- # process command still unprocessed
796
- new_path.append(current_cmd)
797
-
798
- return new_path
799
-
800
-
801
- def separate_substructures(tokenized_commands):
802
- """Returns a list of SVG substructures."""
803
- # every moveTo command starts a new substructure
804
- # an SVG substructure is a subpath that closes on itself
805
- # such as the outter and the inner edge of the character `o`
806
- substructures = []
807
- curr = []
808
- for cmd in tokenized_commands:
809
- if cmd[0] in 'mM' and len(curr) > 0:
810
- substructures.append(curr)
811
- curr = []
812
- curr.append(cmd)
813
- if len(curr) > 0:
814
- substructures.append(curr)
815
- return substructures
816
-
817
-
818
- def postprocess(svg, dist_thresh=2., skip=False):
819
- path = svg_html_to_path_string(svg)
820
- #print(svg)
821
- svg_template = svg.replace(path, '{}')
822
- tokenized_commands = path_string_to_tokenized_commands(path)
823
-
824
- def dist(a, b):
825
- return np.sqrt((float(a[0]) - float(b[0]))**2 + (float(a[1]) - float(b[1]))**2)
826
-
827
- def are_close_together(a, b, t):
828
- return dist(a, b) < t
829
-
830
- # first, go through each start/end point and merge if they're close enough
831
- # together (that is, make end point the same as the start point).
832
- # TODO: there are better ways of doing this, in a way that propagates errors.
833
- # back (so if total error is 0.2, go through all N commands in this
834
- # substructure and fix each by 0.2/N (unless they have 0 vertical change))
835
- # NOTE: this is the same.
836
- substructures = separate_substructures(tokenized_commands)
837
- # print(len(substructures))# 7578
838
-
839
- previous_substructure_endpoint = (0., 0.,)
840
- for substructure in substructures:
841
- # first, if the last substructure's endpoint was updated, we must update
842
- # the start point of this one to reflect the opposite update
843
- substructure[0][-2] = str(float(substructure[0][-2]) -
844
- previous_substructure_endpoint[0])
845
- substructure[0][-1] = str(float(substructure[0][-1]) -
846
- previous_substructure_endpoint[1])
847
-
848
- start = list(map(float, substructure[0][-2:]))
849
- curr_pos = (0., 0.)
850
- for cmd in substructure:
851
- curr_pos, _ = _update_curr_pos(curr_pos, cmd, (0., 0.))
852
- if are_close_together(start, curr_pos, dist_thresh):
853
- new_point = np.array(start)
854
- previous_substructure_endpoint = ((new_point[0] - curr_pos[0]),
855
- (new_point[1] - curr_pos[1]))
856
- substructure[-1][-2] = str(float(substructure[-1][-2]) +
857
- (new_point[0] - curr_pos[0]))
858
- substructure[-1][-1] = str(float(substructure[-1][-1]) +
859
- (new_point[1] - curr_pos[1]))
860
- if substructure[-1][0] in 'cC':
861
- substructure[-1][-4] = str(float(substructure[-1][-4]) +
862
- (new_point[0] - curr_pos[0]))
863
- substructure[-1][-3] = str(float(substructure[-1][-3]) +
864
- (new_point[1] - curr_pos[1]))
865
-
866
- if skip:
867
- return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
868
- for s in substructures]))
869
-
870
- def cosa(x, y):
871
- return (x[0] * y[0] + x[1] * y[1]) / ((np.sqrt(x[0]**2 + x[1]**2) * np.sqrt(y[0]**2 + y[1]**2)))
872
-
873
- def rotate(a, x, y):
874
- return (x * np.cos(a) - y * np.sin(a), y * np.cos(a) + x * np.sin(a))
875
- # second, gotta find adjacent bezier curves and, if their control points
876
- # are well enough aligned, fully align them
877
- for substructure in substructures:
878
- curr_pos = (0., 0.)
879
- new_curr_pos, _ = _update_curr_pos((0., 0.,), substructure[0], (0., 0.))
880
-
881
- for cmd_idx in range(1, len(substructure)):
882
- prev_cmd = substructure[cmd_idx-1]
883
- cmd = substructure[cmd_idx]
884
-
885
- new_new_curr_pos, _ = _update_curr_pos(new_curr_pos, cmd, (0., 0.))
886
-
887
- if cmd[0] == 'c':
888
- if prev_cmd[0] == 'c':
889
- # check the vectors and update if needed
890
- # previous control pt wrt new curr point
891
- prev_ctr_point = (curr_pos[0] + float(prev_cmd[3]) - new_curr_pos[0],
892
- curr_pos[1] + float(prev_cmd[4]) - new_curr_pos[1])
893
- ctr_point = (float(cmd[1]), float(cmd[2]))
894
-
895
- if -1. < cosa(prev_ctr_point, ctr_point) < -0.95:
896
- # calculate exact angle between the two vectors
897
- angle_diff = (np.pi - np.arccos(cosa(prev_ctr_point, ctr_point)))/2
898
-
899
- # rotate each vector by angle/2 in the correct direction for each.
900
- sign = np.sign(np.cross(prev_ctr_point, ctr_point))
901
- new_ctr_point = rotate(sign * angle_diff, *ctr_point)
902
- new_prev_ctr_point = rotate(-sign * angle_diff, *prev_ctr_point)
903
-
904
- # override the previous control points
905
- # (which has to be wrt previous curr position)
906
- substructure[cmd_idx-1][3] = str(new_prev_ctr_point[0] -
907
- curr_pos[0] + new_curr_pos[0])
908
- substructure[cmd_idx-1][4] = str(new_prev_ctr_point[1] -
909
- curr_pos[1] + new_curr_pos[1])
910
- substructure[cmd_idx][1] = str(new_ctr_point[0])
911
- substructure[cmd_idx][2] = str(new_ctr_point[1])
912
-
913
- curr_pos = new_curr_pos
914
- new_curr_pos = new_new_curr_pos
915
-
916
- # print('0',substructures)
917
- return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
918
- for s in substructures]))
919
-
920
-
921
-
922
-
923
-
924
- # def get_means_stdevs(data_dir):
925
- # """Returns the means and stdev saved in data_dir."""
926
- # if data_dir not in means_stdevs:
927
- # with tf.gfile.Open(os.path.join(data_dir, 'mean.npz'), 'r') as f:
928
- # mean_npz = np.load(f)
929
- # with tf.gfile.Open(os.path.join(data_dir, 'stdev.npz'), 'r') as f:
930
- # stdev_npz = np.load(f)
931
- # means_stdevs[data_dir] = (mean_npz, stdev_npz)
932
- # return means_stdevs[data_dir]
933
-
934
-
935
-
936
-
937
- ###############
938
-
939
-
940
- def convert_to_svg(decoder_output, categorical=False):
941
- converted = []
942
- for example in decoder_output:
943
- converted.append(_vector_to_svg(example, True, categorical=categorical))
944
- return np.array(converted)
945
-
946
-
947
- def create_image_conversion_fn(max_outputs, categorical=False):
948
- """Binds the number of outputs to the image conversion fn (to svg or png)."""
949
- def convert_to_svg(decoder_output):
950
- converted = []
951
- for example in decoder_output:
952
- if len(converted) == max_outputs:
953
- break
954
- converted.append(_vector_to_svg(example, True, categorical=categorical))
955
- return np.array(converted)
956
-
957
- return convert_to_svg
958
-
959
-
960
- ################### UTILS FOR CREATING TF SUMMARIES ##########################
961
- def _make_encoded_image(img_tensor):
962
- pil_img = Image.fromarray(np.squeeze(img_tensor * 255).astype(np.uint8), mode='L')
963
- buff = io.BytesIO()
964
- pil_img.save(buff, format='png')
965
- encoded_image = buff.getvalue()
966
- return encoded_image
967
-
968
-
969
- ################### CHECK GLYPH/PATH VALID ##############################################
970
- def is_valid_glyph(g):
971
- is_09 = 48 <= g['uni'] <= 57
972
- is_capital_az = 65 <= g['uni'] <= 90
973
- is_az = 97 <= g['uni'] <= 122
974
- is_valid_dims = g['width'] != 0 and g['vwidth'] != 0
975
- return (is_09 or is_capital_az or is_az) and is_valid_dims
976
-
977
-
978
- def is_valid_path(pathunibfp):
979
- # print(len(pathunibfp[0]))
980
- if len(pathunibfp[0])>70:
981
- print("!!!more than 400",len(pathunibfp[0]))
982
- # sys.exit()
983
- return pathunibfp[0] and len(pathunibfp[0]) <= 70,len(pathunibfp[0])
984
-
985
-
986
- ################### DATASET PROCESSING #######################################
987
- def convert_to_path(g):
988
- """Converts SplineSet in SFD font to str path."""
989
- path = _sfd_to_path_list(g)
990
- path = _add_missing_cmds(path, remove_zs=False)
991
- path = _normalize_based_on_viewbox(path, '0 0 {} {}'.format(g['width'], g['vwidth']))
992
- return path, g['uni'], g['binary_fp']
993
- def convert_simple_vector_to_path(seq):
994
- path=[]
995
- for i in range(seq.shape[0]):
996
- # seq_i = seq[i]
997
- path_i=[]
998
- cmd = np.argmax(seq[i][:4])
999
- # args = seq[i][4:]
1000
- p0 = seq[i][4:6]
1001
- p1 = seq[i][6:8]
1002
- p2 = seq[i][8:10]
1003
- if cmd == 0:
1004
- break
1005
- elif cmd==1:
1006
- path_i.append('M')
1007
- path_i.append(str(p2[0]))
1008
- path_i.append(str(p2[1]))
1009
- elif cmd==2:
1010
- path_i.append('L')
1011
- path_i.append(str(p2[0]))
1012
- path_i.append(str(p2[1]))
1013
- elif cmd==3:
1014
- path_i.append('C')
1015
- path_i.append(str(p0[0]))
1016
- path_i.append(str(p0[1]))
1017
- path_i.append(str(p1[0]))
1018
- path_i.append(str(p1[1]))
1019
- path_i.append(str(p2[0]))
1020
- path_i.append(str(p2[1]))
1021
- else:
1022
- print("wrong!!! to path")
1023
- sys.exit()
1024
- path.append(path_i)
1025
- return path
1026
- # print("jjj")
1027
- def clockwise(seq):
1028
- #pdb.set_trace()
1029
- path=convert_simple_vector_to_path(seq)
1030
- path = _canonicalize(path)
1031
- final = {}
1032
- final['rendered'] = _per_step_render(path, absolute=True)
1033
- vector = _path_to_vector(path, categorical=True)
1034
- vector = np.array(vector)
1035
- # print(vector.shape,vector[:,9])# note vector: 12,30
1036
-
1037
- vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
1038
- final['seq_len'] = np.shape(vector)[0]
1039
- vector = _append_eos(vector.tolist(), True, 10)
1040
- final['sequence'] = np.concatenate((vector, np.zeros(((70 - final['seq_len']), 10))), 0)
1041
- final['rendered'] = np.reshape(final['rendered'][..., 0], [64 * 64]).astype(np.float32).tolist()
1042
- return final
1043
-
1044
-
1045
- def create_example(pathunibfp):
1046
- """Bulk of dataset processing. Converts str path to np array"""
1047
- path, uni, binary_fp = pathunibfp
1048
- final = {}
1049
-
1050
- # zoom out
1051
- path = _zoom_out(path)
1052
- # make clockwise
1053
- path = _canonicalize(path)
1054
-
1055
- # render path for training
1056
- final['rendered'] = _per_step_render(path, absolute=True)
1057
-
1058
- # make path relative
1059
- #path = _make_relative(path) # note 不rela 直接是绝对的
1060
- # convert to vector
1061
- vector = _path_to_vector(path, categorical=True)
1062
-
1063
-
1064
-
1065
- # path2 = _vector_to_path(vector)# note vector转成path
1066
- #print(path2)
1067
- #print(path==path2)
1068
-
1069
- vector = np.array(vector)
1070
- # print(vector.shape,vector[:,9])# note vector: 12,30
1071
- vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
1072
-
1073
-
1074
-
1075
- #path2 = _vector_to_path(vector)
1076
- #print(path,"\nhhh",path2)
1077
-
1078
-
1079
- # print("hhh",vector)
1080
- # print(render(vector))
1081
- # sys.exit()
1082
- # count some stats
1083
- final['seq_len'] = np.shape(vector)[0]
1084
- # final['class'] = int(_map_uni_to_alphanum(uni))
1085
- final['class'] = int(_map_uni_to_alpha(uni))
1086
- final['binary_fp'] = str(binary_fp)
1087
-
1088
- # append eos
1089
- vector = _append_eos(vector.tolist(), True, 10)
1090
-
1091
- # pad path to 51 (with eos)
1092
- # pdb.set_trace()
1093
- # if final['seq_len']>50:
1094
-
1095
- # print( final['seq_len'])
1096
- final['sequence'] = np.concatenate((vector, np.zeros(((70 - final['seq_len']), 10))), 0)
1097
- #seq = final['sequence']
1098
-
1099
- # new_path = convert_simple_vector_to_path(seq)
1100
- # print(new_path,path2==new_path)
1101
-
1102
- # sys.exit()
1103
- # make pure list:
1104
- # use last channel only
1105
- final['rendered'] = np.reshape(final['rendered'][..., 0], [64 * 64]).astype(np.float32).tolist()
1106
- final['sequence'] = np.reshape(final['sequence'], [71 * 10]).astype(np.float32).tolist()
1107
- final['class'] = np.reshape(final['class'], [1]).astype(np.int64).tolist()
1108
- final['seq_len'] = np.reshape(final['seq_len'], [1]).astype(np.int64).tolist()
1109
- return final
1110
-
1111
-
1112
- def mean_to_example(mean_stdev):
1113
- """Converts the found mean and stdev to example."""
1114
- # mean_stdev is a dict
1115
- mean_stdev['mean'] = np.reshape(mean_stdev['mean'], [10]).astype(np.float32).tolist()
1116
- mean_stdev['variance'] = np.reshape(mean_stdev['variance'], [10]).astype(np.float32).tolist()
1117
- mean_stdev['stddev'] = np.reshape(mean_stdev['stddev'], [10]).astype(np.float32).tolist()
1118
- mean_stdev['count'] = np.reshape(mean_stdev['count'], [1]).astype(np.int64).tolist()
1119
- return mean_stdev
1120
-
1121
-
1122
- ################### CHECK VALID ##############################################
1123
- class MeanStddev:
1124
- """Accumulator to compute the mean/stdev of svg commands."""
1125
-
1126
- def create_accumulator(self):
1127
- curr_sum = np.zeros([10])
1128
- sum_sq = np.zeros([10])
1129
- return (curr_sum, sum_sq, 0) # x, x^2, count
1130
-
1131
- def add_input(self, sum_count, new_input):
1132
- (curr_sum, sum_sq, count) = sum_count
1133
- # new_input is a dict with keys = ['seq_len', 'sequence']
1134
- new_seq_len = new_input['seq_len'][0] # Line #754 'seq_len' is a list of one int
1135
- assert isinstance(new_seq_len, int), print(type(new_seq_len))
1136
-
1137
- # remove padding and eos from sequence
1138
- assert isinstance(new_input['sequence'], list), print(type(new_input['sequence']))
1139
- new_input_np = np.reshape(np.array(new_input['sequence']), [-1, 10])
1140
- assert isinstance(new_input_np, np.ndarray), print(type())
1141
- assert new_input_np.shape[0] >= new_seq_len
1142
- new_input_np = new_input_np[:new_seq_len, :]
1143
-
1144
- # accumulate new_sum and new_sum_sq
1145
- new_sum = np.sum([curr_sum, np.sum(new_input_np, axis=0)], axis=0)
1146
- new_sum_sq = np.sum([sum_sq, np.sum(np.power(new_input_np, 2), axis=0)],
1147
- axis=0)
1148
- return new_sum, new_sum_sq, count + new_seq_len
1149
-
1150
- def merge_accumulators(self, accumulators):
1151
- curr_sums, sum_sqs, counts = list(zip(*accumulators))
1152
- return np.sum(curr_sums, axis=0), np.sum(sum_sqs, axis=0), np.sum(counts)
1153
-
1154
- def extract_output(self, sum_count):
1155
- (curr_sum, curr_sum_sq, count) = sum_count
1156
- if count:
1157
- mean = np.divide(curr_sum, count)
1158
- variance = np.divide(curr_sum_sq, count) - np.power(mean, 2)
1159
- # -ve value could happen due to rounding
1160
- variance = np.max([variance, np.zeros(np.shape(variance))], axis=0)
1161
- stddev = np.sqrt(variance)
1162
- return {
1163
- 'mean': mean,
1164
- 'variance': variance,
1165
- 'stddev': stddev,
1166
- 'count': count
1167
- }
1168
- else:
1169
- return {
1170
- 'mean': float('NaN'),
1171
- 'variance': float('NaN'),
1172
- 'stddev': float('NaN'),
1173
- 'count': 0
1174
- }
 
1
+
2
+ # Copyright 2020 The Magenta Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import pdb
16
+ # Lint as: python3
17
+ """Defines the Material Design Icons Problem."""
18
+ import io
19
+ import numpy as np
20
+ import re
21
+
22
+ from PIL import Image
23
+ from itertools import zip_longest
24
+ from skimage import draw
25
+ import sys
26
+
27
+ SVG_PREFIX_BIG = ('<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="'
28
+ 'http://www.w3.org/1999/xlink" width="256px" height="256px"'
29
+ ' style="-ms-transform: rotate(360deg); -webkit-transform:'
30
+ ' rotate(360deg); transform: rotate(360deg);" '
31
+ 'preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24">')
32
+ PATH_PREFIX_1 = '<path d="'
33
+ PATH_POSFIX_1 = '" fill="currentColor"/>'
34
+ SVG_POSFIX = '</svg>'
35
+
36
+ NUM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
37
+ 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
38
+ 'q': 4, 'Q': 4, 'z': 0}
39
+ # in order of arg complexity, with absolutes clustered
40
+ # recall we don't handle all commands (see docstring)
41
+
42
+ #note args:
43
+ # v, h: vertical horizental lines
44
+ # a: elliptical Arc 椭圆
45
+ # l: lineto
46
+ # t: smooth quadratic Bézier curveto 2次贝塞尔曲线
47
+ # c: curveto
48
+ # m: moveto
49
+ # s: smooth curveto
50
+ # Q: quadratic Bézier curve 2次贝塞尔曲线
51
+ # z: closepath
52
+ #CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'
53
+ CMDS_LIST = 'zHVMLTSQCAhvmltsqca'
54
+ CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
55
+
56
+ FEATURE_DIM = 10
57
+
58
+
59
+ ############################### GENERAL UTILS #################################
60
+ def grouper(iterable, batch_size, fill_value=None):
61
+ """Helper method for returning batches of size batch_size of a dataset."""
62
+ # grouper('ABCDEF', 3) -> 'ABC', 'DEF'
63
+ args = [iter(iterable)] * batch_size
64
+ return zip_longest(*args, fillvalue=fill_value)
65
+
66
+
67
+ def _map_uni_to_alphanum(uni):
68
+ """Maps [0-9 A-Z a-z] to numbers 0-62."""
69
+ if 48 <= uni <= 57:
70
+ return uni - 48
71
+ elif 65 <= uni <= 90:
72
+ return uni - 65 + 10
73
+ return uni - 97 + 36
74
+
75
+
76
+ def _map_uni_to_alpha(uni):
77
+ """Maps [A-Z a-z] to numbers 0-52."""
78
+ if 65 <= uni <= 90:
79
+ return uni - 65
80
+ return uni - 97 + 26
81
+
82
+
83
+ ############# UTILS FOR CONVERTING SFD/SPLINESETS TO SVG PATHS ################
84
+ def _get_spline(sfd):
85
+ if 'SplineSet' not in sfd:
86
+ return ''
87
+ pro = sfd[sfd.index('SplineSet') + 10:] # 10 is the 'SplineSet'
88
+ pro = pro[:pro.index('EndSplineSet')]
89
+ return pro
90
+
91
+
92
+ def _spline_to_path_list(spline, height, replace_with_prev=False):
93
+ """Converts SplineSet to a list of tokenized commands in svg path."""
94
+ path = []
95
+ prev_xy = []
96
+ for line in spline.splitlines():
97
+ if not line:
98
+ continue
99
+ tokens = line.split(' ')
100
+ cmd = tokens[-2]
101
+ if cmd not in 'cml':
102
+ # COMMAND NOT RECOGNIZED.
103
+ return []
104
+ # assert cmd in 'cml', 'Command not recognized: {}'.format(cmd)
105
+ args = tokens[:-2]
106
+ args = [float(x) for x in args if x]
107
+
108
+ if replace_with_prev and cmd in 'c':
109
+ args[:2] = prev_xy
110
+ prev_xy = args[-2:]
111
+
112
+ new_y_args = []
113
+ for i, a in enumerate(args):
114
+ if i % 2 == 1:
115
+ new_y_args.append((height - a))
116
+ else:
117
+ new_y_args.append((a))
118
+
119
+ path.append([cmd.upper()] + new_y_args)
120
+ return path
121
+
122
+
123
+ def _sfd_to_path_list(single, replace_with_prev=False):
124
+ """Converts the given SFD glyph into a path."""
125
+ return _spline_to_path_list(_get_spline(single['sfd']), single['vwidth'], replace_with_prev)
126
+
127
+
128
+ #################### UTILS FOR PROCESSING TOKENIZED PATHS #####################
129
+ def _add_missing_cmds(path, remove_zs=False):
130
+ """Adds missing cmd tags to the commands in the svg."""
131
+ # For instance, the command 'a' takes 7 arguments, but some SVGs declare:
132
+ # a 1 2 3 4 5 6 7 8 9 10 11 12 13 14
133
+ # Which is 14 arguments. This function converts the above to the equivalent:
134
+ # a 1 2 3 4 5 6 7 a 8 9 10 11 12 13 14
135
+ #
136
+ # Note: if remove_zs is True, this also removes any occurences of z commands.
137
+ new_path = []
138
+ for cmd in path:
139
+ if not remove_zs or cmd[0] not in 'Zz':
140
+ for new_cmd in add_missing_cmd(cmd):
141
+ new_path.append(new_cmd)
142
+ return new_path
143
+
144
+
145
+ def add_missing_cmd(command_list):
146
+ """Adds missing cmd tags to the given command list."""
147
+ # E.g.: given:
148
+ # ['a', '0', '0', '0', '0', '0', '0', '0',
149
+ # '0', '0', '0', '0', '0', '0', '0']
150
+ # Converts to:
151
+ # [['a', '0', '0', '0', '0', '0', '0', '0'],
152
+ # ['a', '0', '0', '0', '0', '0', '0', '0']]
153
+ # And returns a string that joins these elements with spaces.
154
+ cmd_tag = command_list[0]
155
+ args = command_list[1:]
156
+
157
+ final_cmds = []
158
+ for arg_batch in grouper(args, NUM_ARGS[cmd_tag]):
159
+ final_cmds.append([cmd_tag] + list(arg_batch))
160
+
161
+ if not final_cmds:
162
+ # command has no args (e.g.: 'z')
163
+ final_cmds = [[cmd_tag]]
164
+
165
+ return final_cmds
166
+
167
+
168
+ def _normalize_args(arglist, norm, add=None, flip=False):
169
+ """Normalize the given args with the given norm value."""
170
+ new_arglist = []
171
+ for i, arg in enumerate(arglist):
172
+ new_arg = float(arg)
173
+
174
+ if add is not None:
175
+ add_to_x, add_to_y = add
176
+
177
+ # This argument is an x-coordinate if even, y-coordinate if odd
178
+ # except when flip == True
179
+ if i % 2 == 0:
180
+ new_arg += add_to_y if flip else add_to_x
181
+ else:
182
+ new_arg += add_to_x if flip else add_to_y
183
+
184
+ new_arglist.append(str(24 * new_arg / norm))
185
+ return new_arglist
186
+
187
+
188
+ def _normalize_based_on_viewbox(path, viewbox):
189
+ """Normalizes all args in a path to a standard 24x24 viewbox."""
190
+ # Each SVG lives in a 2D plane. The viewbox determines the region of that
191
+ # plane that gets rendered. For instance, some designers may work with a
192
+ # viewbox that's 24x24, others with one that's 100x100, etc.
193
+
194
+ # Suppose I design the the letter "h" in the Arial style using a 100x100
195
+ # viewbox (let's call it icon A). Let's suppose the icon has height 75. Then,
196
+ # I design the same character using a 20x20 viewbox (call this icon B), with
197
+ # height 15 (=75% of 20). This means that, when rendered, both icons with look
198
+ # exactly the same, but the scale of the commands each icon is using is
199
+ # different. For instance, if icon A has a command like "lineTo 100 100", the
200
+ # equivalent command in icon B will be "lineTo 20 20".
201
+
202
+ # In order to avoid this problem and bring all real values to the same scale,
203
+ # I scale all icons' commands to use a 24x24 viewbox. This function does this:
204
+ # it converts a path that exists in the given viewbox into a standard 24x24
205
+ # viewbox.
206
+ viewbox = viewbox.split(' ')
207
+ norm = max(int(viewbox[-1]), int(viewbox[-2]))
208
+
209
+ if int(viewbox[-1]) > int(viewbox[-2]):
210
+ add_to_y = 0
211
+ add_to_x = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
212
+ else:
213
+ add_to_y = abs(int(viewbox[-1]) - int(viewbox[-2])) / 2
214
+ add_to_x = 0
215
+
216
+ new_path = []
217
+ for command in path:
218
+ if command[0] == 'a':
219
+ new_path.append([command[0]] + _normalize_args(command[1:3], norm)
220
+ + command[3:6] + _normalize_args(command[6:], norm))
221
+ elif command[0] == 'A':
222
+ new_path.append([command[0]] + _normalize_args(command[1:3], norm)
223
+ + command[3:6] + _normalize_args(command[6:], norm, add=(add_to_x, add_to_y)))
224
+ elif command[0] == 'V':
225
+ new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y), flip=True))
226
+ elif command[0] == command[0].upper():
227
+ new_path.append([command[0]] + _normalize_args(command[1:], norm, add=(add_to_x, add_to_y)))
228
+ elif command[0] in 'zZ':
229
+ new_path.append([command[0]])
230
+ else:
231
+ new_path.append([command[0]] + _normalize_args(command[1:], norm))
232
+
233
+ return new_path
234
+
235
+
236
+ def _convert_args(args, curr_pos, cmd):
237
+ """Converts given args to relative values."""
238
+ # NOTE: glyphs only use a very small subset of commands (L, C, M, and Z -- I
239
+ # believe). So I'm not handling A and H for now.
240
+ if cmd in 'AH':
241
+ raise NotImplementedError('These commands have >6 args (not supported).')
242
+
243
+ new_args = []
244
+ for i, arg in enumerate(args):
245
+ x_or_y = i % 2
246
+ if cmd == 'H':
247
+ x_or_y = (i + 1) % 2
248
+ new_args.append(str(float(arg) - curr_pos[x_or_y]))
249
+
250
+ return new_args
251
+
252
+
253
+ def _update_curr_pos(curr_pos, cmd, start_of_path):
254
+ """Calculate the position of the pen after cmd is applied."""
255
+ if cmd[0] in 'ml':
256
+ curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1] + float(cmd[2])]
257
+ if cmd[0] == 'm':
258
+ start_of_path = curr_pos
259
+ elif cmd[0] in 'z':
260
+ curr_pos = start_of_path
261
+ elif cmd[0] in 'h':
262
+ curr_pos = [curr_pos[0] + float(cmd[1]), curr_pos[1]]
263
+ elif cmd[0] in 'v':
264
+ curr_pos = [curr_pos[0], curr_pos[1] + float(cmd[1])]
265
+ elif cmd[0] in 'ctsqa':
266
+ curr_pos = [curr_pos[0] + float(cmd[-2]), curr_pos[1] + float(cmd[-1])]
267
+
268
+ return curr_pos, start_of_path
269
+
270
+
271
+ def _make_relative(cmds):
272
+ """Convert commands in a path to relative positioning."""
273
+ curr_pos = (0.0, 0.0)
274
+ start_of_path = (0.0, 0.0)
275
+ new_cmds = []
276
+ for cmd in cmds:
277
+ if cmd[0].lower() == cmd[0]:
278
+ new_cmd = cmd
279
+ elif cmd[0].lower() == 'z':
280
+ new_cmd = [cmd[0].lower()]
281
+ else:
282
+ new_cmd = [cmd[0].lower()] + _convert_args(cmd[1:], curr_pos, cmd=cmd[0])
283
+ new_cmds.append(new_cmd)
284
+ curr_pos, start_of_path = _update_curr_pos(curr_pos, new_cmd, start_of_path)
285
+ return new_cmds
286
+
287
+
288
+ def _is_to_left_of(pt1, pt2):
289
+ pt1_norm = (pt1[0]**2 + pt1[1]**2)
290
+ pt2_norm = (pt2[0]**2 + pt2[1]**2)
291
+ return pt1[1] < pt2[1] or (pt1_norm == pt2_norm and pt1[0] < pt2[0])
292
+
293
+
294
+ def _get_leftmost_point(path):
295
+ """Returns the leftmost, topmost point of the path."""
296
+ leftmost = (float('inf'), float('inf'))
297
+ idx = -1
298
+
299
+ for i, cmd in enumerate(path):
300
+ if len(cmd) > 1:
301
+ endpoint = cmd[-2:]
302
+ if _is_to_left_of(endpoint, leftmost):
303
+ leftmost = endpoint
304
+ idx = i
305
+
306
+ return leftmost, idx
307
+
308
+
309
+ def _separate_substructures(path):
310
+ """Returns a list of subpaths, each representing substructures the glyph."""
311
+ substructures = []
312
+ curr = []
313
+ for cmd in path:
314
+ if cmd[0] in 'mM' and curr:
315
+ substructures.append(curr)
316
+ curr = []
317
+ curr.append(cmd)
318
+ if curr:
319
+ substructures.append(curr)
320
+ return substructures
321
+
322
+
323
+ def _is_clockwise(subpath):
324
+ """Returns whether the given subpath is clockwise-oriented."""
325
+ pts = [cmd[-2:] for cmd in subpath]
326
+ det = 0
327
+ for i in range(len(pts) - 1):
328
+ det += np.linalg.det(pts[i:i + 2])
329
+ return det > 0
330
+
331
+
332
+ def _make_clockwise(subpath):
333
+ """Inverts the cardinality of the given subpath."""
334
+ new_path = [subpath[0]]
335
+ other_cmds = list(reversed(subpath[1:]))
336
+ for i, cmd in enumerate(other_cmds):
337
+ if i + 1 == len(other_cmds):
338
+ where_we_were = subpath[0][-2:]
339
+ else:
340
+ where_we_were = other_cmds[i + 1][-2:]
341
+
342
+ if len(cmd) > 3:
343
+ new_cmd = [cmd[0], cmd[3], cmd[4], cmd[1], cmd[2],
344
+ where_we_were[0], where_we_were[1]]
345
+ else:
346
+ new_cmd = [cmd[0], where_we_were[0], where_we_were[1]]
347
+
348
+ new_path.append(new_cmd)
349
+ return new_path
350
+
351
+
352
+ def _canonicalize(path):
353
+ """Makes all paths start at top left, and go clockwise first."""
354
+ # convert args to floats
355
+ #print(len(path),path)
356
+
357
+ path = [[x[0]] + list(map(float, x[1:])) for x in path]
358
+ # print(len(path),path)
359
+
360
+ # _canonicalize each subpath separately
361
+ #pdb.set_trace()
362
+
363
+ new_substructures = []
364
+ for subpath in _separate_substructures(path):
365
+ # print(subpath,"\n")
366
+ leftmost_point, leftmost_idx = _get_leftmost_point(subpath)
367
+ reordered = ([['M', leftmost_point[0], leftmost_point[1]]] + subpath[leftmost_idx + 1:] + subpath[1:leftmost_idx + 1])
368
+ new_substructures.append((reordered, leftmost_point))
369
+
370
+ # sys.exit()
371
+ new_path = []
372
+ first_substructure_done = False
373
+ should_flip_cardinality = False
374
+ for sp, _ in sorted(new_substructures, key=lambda x: (x[1][1], x[1][0])):
375
+ if not first_substructure_done:
376
+ # we're looking at the first substructure now, we can determine whether we
377
+ # will flip the cardniality of the whole icon or not
378
+ should_flip_cardinality = not _is_clockwise(sp)
379
+ first_substructure_done = True
380
+
381
+ if should_flip_cardinality:
382
+ sp = _make_clockwise(sp)
383
+
384
+ new_path.extend(sp)
385
+
386
+ # convert args to strs
387
+ path = [[x[0]] + list(map(str, x[1:])) for x in new_path]
388
+ return path
389
+
390
+
391
+ # ######### UTILS FOR CONVERTING TOKENIZED PATHS TO VECTORS ###########
392
+ def _path_to_vector(path, categorical=False):
393
+ """Converts path's commands to a series of vectors."""
394
+ # Notes:
395
+ # - The SimpleSVG dataset does not have any 't', 'q', 'Z', 'T', or 'Q'.
396
+ # Thus, we don't handle those here.
397
+ # - We also removed all 'z's.
398
+ # - The x-axis-rotation argument to a commands is always 0 in this
399
+ # dataset, so we ignore it
400
+
401
+ # Many commands have args that correspond to args in other commands.
402
+ # v __,__ _______________ ______________,_________ __,__ __,__ _,y
403
+ # h __,__ _______________ ______________,_________ __,__ __,__ x,_
404
+ # z __,__ _______________ ______________,_________ __,__ __,__ _,_
405
+ # a rx,ry x-axis-rotation large-arc-flag,sweepflag __,__ __,__ x,y
406
+ # l __,__ _______________ ______________,_________ __,__ __,__ x,y
407
+ # c __,__ _______________ ______________,_________ x1,y1 x2,y2 x,y
408
+ # m __,__ _______________ ______________,_________ __,__ __,__ x,y
409
+ # s __,__ _______________ ______________,_________ __,__ x2,y2 x,y
410
+
411
+ # So each command will be converted to a vector where the dimension is the
412
+ # minimal number of arguments to all commands:
413
+ # [rx, ry, large-arc-flag, sweepflag, x1, y1, x2, y2, x, y]
414
+ # If a command does not output a certain arg, it is set to 0.
415
+ # "l 5,5" becomes [0, 0, 0, 0, 0, 0, 0, 0, 5, 5]
416
+
417
+ # Also note, as of now we also output an extra dimension at index 0, which
418
+ # indicates which command is being outputted (integer).
419
+ new_path = []
420
+ for cmd in path:
421
+ new_path.append(_cmd_to_vector(cmd, categorical=categorical))
422
+ return new_path
423
+
424
+
425
+ def _cmd_to_vector(cmd_list, categorical=False):
426
+ """Converts the given command (given as a list) into a vector.
427
+ UM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
428
+ 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
429
+ 'q': 4, 'Q': 4, 'z': 0}
430
+
431
+ CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'
432
+ CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
433
+ """
434
+ # For description of how this conversion happens, see
435
+ # _path_to_vector docstring.
436
+ cmd = cmd_list[0]
437
+ args = cmd_list[1:]
438
+
439
+ if not categorical:
440
+ # integer, for MSE
441
+ command = [float(CMD_MAPPING[cmd])]
442
+ else:
443
+ # one hot + 1 dim for EOS.
444
+ command = [0.0] * (len(CMDS_LIST) + 1) # 大概有19个commands?
445
+ command[CMD_MAPPING[cmd] + 1] = 1.0
446
+
447
+ arguments = [0.0] * 10
448
+ if cmd in 'hH':
449
+ arguments[8] = float(args[0]) # x
450
+ elif cmd in 'vV':
451
+ arguments[9] = float(args[0]) # y
452
+ elif cmd in 'mMlLtT':
453
+ arguments[8] = float(args[0]) # x
454
+ arguments[9] = float(args[1]) # y
455
+ elif cmd in 'sSqQ':
456
+ arguments[6] = float(args[0]) # x2
457
+ arguments[7] = float(args[1]) # y2
458
+ arguments[8] = float(args[2]) # x
459
+ arguments[9] = float(args[3]) # y
460
+ elif cmd in 'cC':
461
+ arguments[4] = float(args[0]) # x1
462
+ arguments[5] = float(args[1]) # y1
463
+ arguments[6] = float(args[2]) # x2
464
+ arguments[7] = float(args[3]) # y2
465
+ arguments[8] = float(args[4]) # x
466
+ arguments[9] = float(args[5]) # y
467
+ elif cmd in 'aA':
468
+ arguments[0] = float(args[0]) # rx
469
+ arguments[1] = float(args[1]) # ry
470
+ # we skip x-axis-rotation
471
+ arguments[2] = float(args[3]) # large-arc-flag
472
+ arguments[3] = float(args[4]) # sweep-flag
473
+ # a does not have x1, y1, x2, y2 args
474
+ arguments[8] = float(args[5]) # x
475
+ arguments[9] = float(args[6]) # y
476
+
477
+ return command + arguments
478
+
479
+
480
+ ################## UTILS FOR RENDERING PATH INTO IMAGE #################
481
+ def _cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=40):
482
+ """Return n points along cubiz bezier with given control points."""
483
+ # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
484
+ pts = []
485
+ for i in range(n + 1):
486
+ t = float(i) / float(n)
487
+ a = (1. - t)**3
488
+ b = 3. * t * (1. - t)**2
489
+ c = 3.0 * t**2 * (1.0 - t)
490
+ d = t**3
491
+
492
+ x = float(a * x0 + b * x1 + c * x2 + d * x3)
493
+ y = float(a * y0 + b * y1 + c * y2 + d * y3)
494
+ pts.append((x, y))
495
+ return list(zip(*pts))
496
+
497
+
498
+ def _update_pos(curr_pos, end_pos, absolute):
499
+ if absolute:
500
+ return end_pos
501
+ return curr_pos[0] + end_pos[0], curr_pos[1] + end_pos[1]
502
+
503
+
504
+ def constant_color(*unused_args):
505
+ return np.array([255, 255, 255])
506
+
507
+
508
+ def _render_cubic(canvas, curr_pos, c_args, absolute, color):
509
+ """Renders a cubic bezier curve in the given canvas."""
510
+ if not absolute:
511
+ c_args[0] += curr_pos[0]
512
+ c_args[1] += curr_pos[1]
513
+ c_args[2] += curr_pos[0]
514
+ c_args[3] += curr_pos[1]
515
+ c_args[4] += curr_pos[0]
516
+ c_args[5] += curr_pos[1]
517
+ x, y = _cubicbezier(curr_pos[0], curr_pos[1],
518
+ c_args[0], c_args[1],
519
+ c_args[2], c_args[3],
520
+ c_args[4], c_args[5])
521
+ max_possible = len(canvas)
522
+ x = [int(round(x_)) for x_ in x]
523
+ y = [int(round(y_)) for y_ in y]
524
+
525
+ def within_range(x):
526
+ return 0 <= x < max_possible
527
+
528
+ filtered = [(x_, y_) for x_, y_ in zip(x, y)
529
+ if within_range(x_) and within_range(y_)]
530
+ if not filtered:
531
+ return
532
+ x, y = list(zip(*filtered))
533
+ canvas[y, x, :] = color
534
+
535
+
536
+ def _render_line(canvas, curr_pos, l_args, absolute, color):
537
+ """Renders a line in the given canvas."""
538
+ end_point = l_args
539
+ if not absolute:
540
+ end_point[0] += curr_pos[0]
541
+ end_point[1] += curr_pos[1]
542
+ rr, cc, val = draw.line_aa(int(curr_pos[0]), int(curr_pos[1]),
543
+ int(end_point[0]), int(end_point[1]))
544
+
545
+ max_possible = len(canvas)
546
+
547
+ def within_range(x):
548
+ return 0 <= x < max_possible
549
+
550
+ filtered = [(x, y, v) for x, y, v in zip(rr, cc, val)
551
+ if within_range(x) and within_range(y)]
552
+ if not filtered:
553
+ return
554
+ rr, cc, val = list(zip(*filtered))
555
+ val = [(v * color) for v in val]
556
+ canvas[cc, rr, :] = val
557
+
558
+
559
+ def _per_step_render(path, absolute=False, color=constant_color):
560
+ """Render the icon's edges, given its path."""
561
+ def to_canvas_size(l):
562
+ return [float(f) * (64. / 24.) for f in l]
563
+
564
+ canvas = np.zeros((64, 64, 3))
565
+ curr_pos = (0.0, 0.0)
566
+ for i, cmd in enumerate(path):
567
+ if not cmd:
568
+ continue
569
+ if cmd[0] in 'mM':
570
+ curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
571
+ elif cmd[0] in 'cC':
572
+ _render_cubic(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
573
+ curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[-2:]), absolute)
574
+ elif cmd[0] in 'lL':
575
+ _render_line(canvas, curr_pos, to_canvas_size(cmd[1:]), absolute, color(i, 55))
576
+ curr_pos = _update_pos(curr_pos, to_canvas_size(cmd[1:]), absolute)
577
+
578
+ return canvas
579
+
580
+
581
+ def _zoom_out(path_list, add_baseline=0., per=22):
582
+ """Makes glyph slightly smaller in viewbox, makes some descenders visible."""
583
+ # assumes tensor is already unnormalized, and in long form
584
+ new_path = []
585
+ for command in path_list:
586
+ args = []
587
+ is_even = False
588
+ for arg in command[1:]:
589
+ if is_even:
590
+ args.append(str(float(arg) - ((24. - per) / 24.) * 64. / 4.))
591
+ is_even = False
592
+ else:
593
+ args.append(str(float(arg) - add_baseline))
594
+ is_even = True
595
+ new_path.append([command[0]] + args)
596
+ return new_path
597
+
598
+
599
+ ##################### UTILS FOR PROCESSING VECTORS ################
600
+ def _append_eos(sample, categorical, feature_dim):
601
+ if not categorical:
602
+ eos = -1 * np.ones(feature_dim)
603
+ else:
604
+ eos = np.zeros(feature_dim)
605
+ eos[0] = 1.0
606
+ sample.append(eos)
607
+ return sample
608
+
609
+
610
+ def _make_simple_cmds_long(out):
611
+ """Converts svg decoder output to format required by some render functions."""
612
+ # out has 10 dims
613
+ # the first 4 are respectively dims 0, 4, 5, 9 of the full 20-dim onehot vec
614
+ # the latter 6 are the 6 last dims of the 10-dim arg vec
615
+ shape_minus_dim = list(np.shape(out))[:-1]
616
+ # print("make? ",shape_minus_dim ) # [51]
617
+
618
+ return np.concatenate([out[..., :1], # [51,1] 51个steps的第1维特征
619
+ np.zeros(shape_minus_dim + [3]),# [51,3]
620
+ out[..., 1:3], #[51,2]
621
+ np.zeros(shape_minus_dim + [3]),# [51,3]
622
+ out[..., 3:4],# [51,1]
623
+ np.zeros(shape_minus_dim + [14]),# [51,14]
624
+ out[..., 4:]], -1)# [51,6] # 最后的6个绘制参数
625
+
626
+ def render(tensor, data_dir=None):
627
+ """Converts SVG decoder output into HTML svg."""
628
+ # undo normalization
629
+ # mean_npz, stdev_npz = get_means_stdevs(data_dir)
630
+ # tensor = (tensor * stdev_npz) + mean_npz
631
+
632
+ # convert to html
633
+ #print("before",tensor.shape)# 51, 10)
634
+ tensor = _make_simple_cmds_long(tensor)
635
+ # print("after",tensor.shape)#(51, 30)
636
+ # vector = np.squeeze(np.squeeze(tensor, 0), 2)
637
+ # print("1",tensor[0,:5])# (51, 30)
638
+ html = _vector_to_svg(tensor, stop_at_eos=True, categorical=True)
639
+ # print(html.shape)
640
+ # some aesthetic postprocessing
641
+ html = postprocess(html)
642
+ html = html.replace('256px', '50px')
643
+
644
+ return html
645
+
646
+ ################# UTILS FOR CONVERTING VECTORS TO SVGS ########################
647
+ #note: transform the decoded trg_seq into the common svg format.把decode出来的seq转成html的svg,命令有前后关系,也都是相对位置。
648
+ def _vector_to_svg(vectors, stop_at_eos=False, categorical=False):
649
+ """Tranforms a given vector to an svg string.
650
+
651
+ """
652
+ new_path = []
653
+ for vector in vectors:
654
+ if stop_at_eos:
655
+ if categorical:
656
+ try:
657
+ is_eos = np.argmax(vector[:len(CMDS_LIST) + 1]) == 0
658
+ except Exception:
659
+ raise Exception(vector)
660
+ else:
661
+ is_eos = vector[0] < -0.5
662
+
663
+ if is_eos:
664
+ break
665
+ new_path.append(' '.join(_vector_to_cmd(vector, categorical=categorical))) #
666
+ new_path = ' '.join(new_path) # 加入new_path,每个path都以空格分隔
667
+ return SVG_PREFIX_BIG + PATH_PREFIX_1 + new_path + PATH_POSFIX_1 + SVG_POSFIX
668
+
669
+ def _vector_to_path(vectors):
670
+ new_path = []
671
+ for vector in vectors:
672
+ #print(vector,"???")
673
+ new_path.append(_vector_to_cmd(vector,categorical=True)) #
674
+ #print(_vector_to_cmd(vector),"hhh")
675
+ # new_path = ' '.join(new_path) # 加入new_path,每个path都以空格分隔
676
+ return new_path
677
+
678
+ def _vector_to_cmd(vector, categorical=False, return_floats=False):
679
+ """Does the inverse transformation as _cmd_to_vector().
680
+ UM_ARGS = {'v': 1, 'V': 1, 'h': 1, 'H': 1, 'a': 7, 'A': 7, 'l': 2, 'L': 2,
681
+ 't': 2, 'T': 2, 'c': 6, 'C': 6, 'm': 2, 'M': 2, 's': 4, 'S': 4,
682
+ 'q': 4, 'Q': 4, 'z': 0}
683
+
684
+ CMDS_LIST = 'zhvmltsqcaHVMLTSQCA'
685
+ CMD_MAPPING = {cmd: i for i, cmd in enumerate(CMDS_LIST)}
686
+
687
+ """
688
+ cast_fn = float if return_floats else str
689
+ if categorical:
690
+ # print(vector.shape,vector)# 30
691
+ #print("??",len(CMDS_LIST)) # 19
692
+ command = vector[:len(CMDS_LIST) + 1],# 前20维
693
+ arguments = vector[len(CMDS_LIST) + 1:]# 后10维
694
+ cmd_idx = np.argmax(command) - 1 # 看当前绘制命令属于哪一类
695
+
696
+ else:
697
+
698
+ command, arguments = vector[:1], vector[1:]
699
+ cmd_idx = int(round(command[0]))
700
+
701
+ if cmd_idx < -0.5:
702
+ # EOS
703
+ return []
704
+ if cmd_idx >= len(CMDS_LIST):
705
+ cmd_idx = len(CMDS_LIST) - 1
706
+
707
+ cmd = CMDS_LIST[cmd_idx]
708
+ cmd = cmd.upper()
709
+ cmd_list = [cmd]
710
+
711
+ if cmd in 'hH': # 如果是画线,而且是x轴
712
+ cmd_list.append(cast_fn(arguments[8])) # x
713
+ elif cmd in 'vV': # 如果是画线,而且是y轴
714
+ cmd_list.append(cast_fn(arguments[9])) # y
715
+ elif cmd in 'mMlLtT':
716
+ cmd_list.append(cast_fn(arguments[8])) # x
717
+ cmd_list.append(cast_fn(arguments[9])) # y
718
+ elif cmd in 'sSqQ':
719
+ cmd_list.append(cast_fn(arguments[6])) # x2
720
+ cmd_list.append(cast_fn(arguments[7])) # y2
721
+ cmd_list.append(cast_fn(arguments[8])) # x
722
+ cmd_list.append(cast_fn(arguments[9])) # y
723
+ elif cmd in 'cC':
724
+ cmd_list.append(cast_fn(arguments[4])) # x1
725
+ cmd_list.append(cast_fn(arguments[5])) # y1
726
+ cmd_list.append(cast_fn(arguments[6])) # x2
727
+ cmd_list.append(cast_fn(arguments[7])) # y2
728
+ cmd_list.append(cast_fn(arguments[8])) # x
729
+ cmd_list.append(cast_fn(arguments[9])) # y
730
+ elif cmd in 'aA':
731
+ cmd_list.append(cast_fn(arguments[0])) # rx
732
+ cmd_list.append(cast_fn(arguments[1])) # ry
733
+ # x-axis-rotation is always 0
734
+ cmd_list.append(cast_fn('0'))
735
+ # the following two flags are binary.
736
+ cmd_list.append(cast_fn(1 if arguments[2] > 0.5 else 0)) # large-arc-flag
737
+ cmd_list.append(cast_fn(1 if arguments[3] > 0.5 else 0)) # sweep-flag
738
+ cmd_list.append(cast_fn(arguments[8])) # x
739
+ cmd_list.append(cast_fn(arguments[9])) # y
740
+
741
+ return cmd_list
742
+
743
+
744
+ ############## UTILS FOR CONVERTING SVGS/VECTORS TO IMAGES ###################
745
+
746
+ # From Infer notebook
747
+ start = ("""<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www."""
748
+ """w3.org/1999/xlink" width="256px" height="256px" style="-ms-trans"""
749
+ """form: rotate(360deg); -webkit-transform: rotate(360deg); transfo"""
750
+ """rm: rotate(360deg);" preserveAspectRatio="xMidYMid meet" viewBox"""
751
+ """="0 0 24 24"><path d=\"""")
752
+ end = """\" fill="currentColor"/></svg>"""
753
+
754
+ COMMAND_RX = re.compile("([MmLlHhVvCcSsQqTtAaZz])")
755
+ FLOAT_RX = re.compile("[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") # noqa
756
+
757
+
758
+ def svg_html_to_path_string(svg):
759
+ return svg.replace(start, '').replace(end, '')
760
+
761
+
762
+ def _tokenize(pathdef):
763
+ """Returns each svg token from path list."""
764
+ # e.g.: 'm0.1-.5c0,6' -> m', '0.1, '-.5', 'c', '0', '6'
765
+ for x in COMMAND_RX.split(pathdef):
766
+ if x != '' and x in 'MmLlHhVvCcSsQqTtAaZz':
767
+ yield x
768
+ for token in FLOAT_RX.findall(x):
769
+ yield token
770
+
771
+
772
+ def path_string_to_tokenized_commands(path):
773
+ """Tokenizes the given path string.
774
+
775
+ E.g.:
776
+ Given M 0.5 0.5 l 0.25 0.25 z
777
+ Returns [['M', '0.5', '0.5'], ['l', '0.25', '0.25'], ['z']]
778
+ """
779
+ new_path = []
780
+ current_cmd = []
781
+ for token in _tokenize(path):
782
+ if len(current_cmd) > 0:
783
+ if token in 'MmLlHhVvCcSsQqTtAaZz':
784
+ # cmd ended, convert to vector and add to new_path
785
+ new_path.append(current_cmd)
786
+ current_cmd = [token]
787
+ else:
788
+ # add arg to command
789
+ current_cmd.append(token)
790
+ else:
791
+ # add to start new cmd
792
+ current_cmd.append(token)
793
+
794
+ if current_cmd:
795
+ # process command still unprocessed
796
+ new_path.append(current_cmd)
797
+
798
+ return new_path
799
+
800
+
801
+ def separate_substructures(tokenized_commands):
802
+ """Returns a list of SVG substructures."""
803
+ # every moveTo command starts a new substructure
804
+ # an SVG substructure is a subpath that closes on itself
805
+ # such as the outter and the inner edge of the character `o`
806
+ substructures = []
807
+ curr = []
808
+ for cmd in tokenized_commands:
809
+ if cmd[0] in 'mM' and len(curr) > 0:
810
+ substructures.append(curr)
811
+ curr = []
812
+ curr.append(cmd)
813
+ if len(curr) > 0:
814
+ substructures.append(curr)
815
+ return substructures
816
+
817
+
818
+ def postprocess(svg, dist_thresh=2., skip=False):
819
+ path = svg_html_to_path_string(svg)
820
+ #print(svg)
821
+ svg_template = svg.replace(path, '{}')
822
+ tokenized_commands = path_string_to_tokenized_commands(path)
823
+
824
+ def dist(a, b):
825
+ return np.sqrt((float(a[0]) - float(b[0]))**2 + (float(a[1]) - float(b[1]))**2)
826
+
827
+ def are_close_together(a, b, t):
828
+ return dist(a, b) < t
829
+
830
+ # first, go through each start/end point and merge if they're close enough
831
+ # together (that is, make end point the same as the start point).
832
+ # TODO: there are better ways of doing this, in a way that propagates errors.
833
+ # back (so if total error is 0.2, go through all N commands in this
834
+ # substructure and fix each by 0.2/N (unless they have 0 vertical change))
835
+ # NOTE: this is the same.
836
+ substructures = separate_substructures(tokenized_commands)
837
+ # print(len(substructures))# 7578
838
+
839
+ previous_substructure_endpoint = (0., 0.,)
840
+ for substructure in substructures:
841
+ # first, if the last substructure's endpoint was updated, we must update
842
+ # the start point of this one to reflect the opposite update
843
+ substructure[0][-2] = str(float(substructure[0][-2]) -
844
+ previous_substructure_endpoint[0])
845
+ substructure[0][-1] = str(float(substructure[0][-1]) -
846
+ previous_substructure_endpoint[1])
847
+
848
+ start = list(map(float, substructure[0][-2:]))
849
+ curr_pos = (0., 0.)
850
+ for cmd in substructure:
851
+ curr_pos, _ = _update_curr_pos(curr_pos, cmd, (0., 0.))
852
+ if are_close_together(start, curr_pos, dist_thresh):
853
+ new_point = np.array(start)
854
+ previous_substructure_endpoint = ((new_point[0] - curr_pos[0]),
855
+ (new_point[1] - curr_pos[1]))
856
+ substructure[-1][-2] = str(float(substructure[-1][-2]) +
857
+ (new_point[0] - curr_pos[0]))
858
+ substructure[-1][-1] = str(float(substructure[-1][-1]) +
859
+ (new_point[1] - curr_pos[1]))
860
+ if substructure[-1][0] in 'cC':
861
+ substructure[-1][-4] = str(float(substructure[-1][-4]) +
862
+ (new_point[0] - curr_pos[0]))
863
+ substructure[-1][-3] = str(float(substructure[-1][-3]) +
864
+ (new_point[1] - curr_pos[1]))
865
+
866
+ if skip:
867
+ return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
868
+ for s in substructures]))
869
+
870
+ def cosa(x, y):
871
+ return (x[0] * y[0] + x[1] * y[1]) / ((np.sqrt(x[0]**2 + x[1]**2) * np.sqrt(y[0]**2 + y[1]**2)))
872
+
873
+ def rotate(a, x, y):
874
+ return (x * np.cos(a) - y * np.sin(a), y * np.cos(a) + x * np.sin(a))
875
+ # second, gotta find adjacent bezier curves and, if their control points
876
+ # are well enough aligned, fully align them
877
+ for substructure in substructures:
878
+ curr_pos = (0., 0.)
879
+ new_curr_pos, _ = _update_curr_pos((0., 0.,), substructure[0], (0., 0.))
880
+
881
+ for cmd_idx in range(1, len(substructure)):
882
+ prev_cmd = substructure[cmd_idx-1]
883
+ cmd = substructure[cmd_idx]
884
+
885
+ new_new_curr_pos, _ = _update_curr_pos(new_curr_pos, cmd, (0., 0.))
886
+
887
+ if cmd[0] == 'c':
888
+ if prev_cmd[0] == 'c':
889
+ # check the vectors and update if needed
890
+ # previous control pt wrt new curr point
891
+ prev_ctr_point = (curr_pos[0] + float(prev_cmd[3]) - new_curr_pos[0],
892
+ curr_pos[1] + float(prev_cmd[4]) - new_curr_pos[1])
893
+ ctr_point = (float(cmd[1]), float(cmd[2]))
894
+
895
+ if -1. < cosa(prev_ctr_point, ctr_point) < -0.95:
896
+ # calculate exact angle between the two vectors
897
+ angle_diff = (np.pi - np.arccos(cosa(prev_ctr_point, ctr_point)))/2
898
+
899
+ # rotate each vector by angle/2 in the correct direction for each.
900
+ sign = np.sign(np.cross(prev_ctr_point, ctr_point))
901
+ new_ctr_point = rotate(sign * angle_diff, *ctr_point)
902
+ new_prev_ctr_point = rotate(-sign * angle_diff, *prev_ctr_point)
903
+
904
+ # override the previous control points
905
+ # (which has to be wrt previous curr position)
906
+ substructure[cmd_idx-1][3] = str(new_prev_ctr_point[0] -
907
+ curr_pos[0] + new_curr_pos[0])
908
+ substructure[cmd_idx-1][4] = str(new_prev_ctr_point[1] -
909
+ curr_pos[1] + new_curr_pos[1])
910
+ substructure[cmd_idx][1] = str(new_ctr_point[0])
911
+ substructure[cmd_idx][2] = str(new_ctr_point[1])
912
+
913
+ curr_pos = new_curr_pos
914
+ new_curr_pos = new_new_curr_pos
915
+
916
+ # print('0',substructures)
917
+ return svg_template.format(' '.join([' '.join(' '.join(cmd) for cmd in s)
918
+ for s in substructures]))
919
+
920
+
921
+
922
+
923
+
924
+ # def get_means_stdevs(data_dir):
925
+ # """Returns the means and stdev saved in data_dir."""
926
+ # if data_dir not in means_stdevs:
927
+ # with tf.gfile.Open(os.path.join(data_dir, 'mean.npz'), 'r') as f:
928
+ # mean_npz = np.load(f)
929
+ # with tf.gfile.Open(os.path.join(data_dir, 'stdev.npz'), 'r') as f:
930
+ # stdev_npz = np.load(f)
931
+ # means_stdevs[data_dir] = (mean_npz, stdev_npz)
932
+ # return means_stdevs[data_dir]
933
+
934
+
935
+
936
+
937
+ ###############
938
+
939
+
940
+ def convert_to_svg(decoder_output, categorical=False):
941
+ converted = []
942
+ for example in decoder_output:
943
+ converted.append(_vector_to_svg(example, True, categorical=categorical))
944
+ return np.array(converted)
945
+
946
+
947
+ def create_image_conversion_fn(max_outputs, categorical=False):
948
+ """Binds the number of outputs to the image conversion fn (to svg or png)."""
949
+ def convert_to_svg(decoder_output):
950
+ converted = []
951
+ for example in decoder_output:
952
+ if len(converted) == max_outputs:
953
+ break
954
+ converted.append(_vector_to_svg(example, True, categorical=categorical))
955
+ return np.array(converted)
956
+
957
+ return convert_to_svg
958
+
959
+
960
+ ################### UTILS FOR CREATING TF SUMMARIES ##########################
961
+ def _make_encoded_image(img_tensor):
962
+ pil_img = Image.fromarray(np.squeeze(img_tensor * 255).astype(np.uint8), mode='L')
963
+ buff = io.BytesIO()
964
+ pil_img.save(buff, format='png')
965
+ encoded_image = buff.getvalue()
966
+ return encoded_image
967
+
968
+
969
+ ################### CHECK GLYPH/PATH VALID ##############################################
970
+ def is_valid_glyph(g):
971
+ is_09 = 48 <= g['uni'] <= 57
972
+ is_capital_az = 65 <= g['uni'] <= 90
973
+ is_az = 97 <= g['uni'] <= 122
974
+ is_valid_dims = g['width'] != 0 and g['vwidth'] != 0
975
+ return (is_09 or is_capital_az or is_az) and is_valid_dims
976
+
977
+
978
+ def is_valid_path(pathunibfp):
979
+ # print(len(pathunibfp[0]))
980
+ if len(pathunibfp[0])>70:
981
+ print("!!!more than 400",len(pathunibfp[0]))
982
+ # sys.exit()
983
+ return pathunibfp[0] and len(pathunibfp[0]) <= 70,len(pathunibfp[0])
984
+
985
+
986
+ ################### DATASET PROCESSING #######################################
987
+ def convert_to_path(g):
988
+ """Converts SplineSet in SFD font to str path."""
989
+ path = _sfd_to_path_list(g)
990
+ path = _add_missing_cmds(path, remove_zs=False)
991
+ path = _normalize_based_on_viewbox(path, '0 0 {} {}'.format(g['width'], g['vwidth']))
992
+ return path, g['uni'], g['binary_fp']
993
+ def convert_simple_vector_to_path(seq):
994
+ path=[]
995
+ for i in range(seq.shape[0]):
996
+ # seq_i = seq[i]
997
+ path_i=[]
998
+ cmd = np.argmax(seq[i][:4])
999
+ # args = seq[i][4:]
1000
+ p0 = seq[i][4:6]
1001
+ p1 = seq[i][6:8]
1002
+ p2 = seq[i][8:10]
1003
+ if cmd == 0:
1004
+ break
1005
+ elif cmd==1:
1006
+ path_i.append('M')
1007
+ path_i.append(str(p2[0]))
1008
+ path_i.append(str(p2[1]))
1009
+ elif cmd==2:
1010
+ path_i.append('L')
1011
+ path_i.append(str(p2[0]))
1012
+ path_i.append(str(p2[1]))
1013
+ elif cmd==3:
1014
+ path_i.append('C')
1015
+ path_i.append(str(p0[0]))
1016
+ path_i.append(str(p0[1]))
1017
+ path_i.append(str(p1[0]))
1018
+ path_i.append(str(p1[1]))
1019
+ path_i.append(str(p2[0]))
1020
+ path_i.append(str(p2[1]))
1021
+ else:
1022
+ print("wrong!!! to path")
1023
+ sys.exit()
1024
+ path.append(path_i)
1025
+ return path
1026
+ # print("jjj")
1027
+ def clockwise(seq):
1028
+ #pdb.set_trace()
1029
+ path=convert_simple_vector_to_path(seq)
1030
+ path = _canonicalize(path)
1031
+ final = {}
1032
+ final['rendered'] = _per_step_render(path, absolute=True)
1033
+ vector = _path_to_vector(path, categorical=True)
1034
+ vector = np.array(vector)
1035
+ # print(vector.shape,vector[:,9])# note vector: 12,30
1036
+
1037
+ vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
1038
+ final['seq_len'] = np.shape(vector)[0]
1039
+ vector = _append_eos(vector.tolist(), True, 10)
1040
+ final['sequence'] = np.concatenate((vector, np.zeros(((70 - final['seq_len']), 10))), 0)
1041
+ final['rendered'] = np.reshape(final['rendered'][..., 0], [64 * 64]).astype(np.float32).tolist()
1042
+ return final
1043
+
1044
+
1045
+ def create_example(pathunibfp):
1046
+ """Bulk of dataset processing. Converts str path to np array"""
1047
+ path, uni, binary_fp = pathunibfp
1048
+ final = {}
1049
+
1050
+ # zoom out
1051
+ path = _zoom_out(path)
1052
+ # make clockwise
1053
+ path = _canonicalize(path)
1054
+
1055
+ # render path for training
1056
+ final['rendered'] = _per_step_render(path, absolute=True)
1057
+
1058
+ # make path relative
1059
+ #path = _make_relative(path) # note 不rela 直接是绝对的
1060
+ # convert to vector
1061
+ vector = _path_to_vector(path, categorical=True)
1062
+
1063
+
1064
+
1065
+ # path2 = _vector_to_path(vector)# note vector转成path
1066
+ #print(path2)
1067
+ #print(path==path2)
1068
+
1069
+ vector = np.array(vector)
1070
+ # print(vector.shape,vector[:,9])# note vector: 12,30
1071
+ vector = np.concatenate([np.take(vector, [0, 4, 5, 9], axis=-1), vector[..., -6:]], axis=-1)
1072
+
1073
+
1074
+
1075
+ #path2 = _vector_to_path(vector)
1076
+ #print(path,"\nhhh",path2)
1077
+
1078
+
1079
+ # print("hhh",vector)
1080
+ # print(render(vector))
1081
+ # sys.exit()
1082
+ # count some stats
1083
+ final['seq_len'] = np.shape(vector)[0]
1084
+ # final['class'] = int(_map_uni_to_alphanum(uni))
1085
+ final['class'] = int(_map_uni_to_alpha(uni))
1086
+ final['binary_fp'] = str(binary_fp)
1087
+
1088
+ # append eos
1089
+ vector = _append_eos(vector.tolist(), True, 10)
1090
+
1091
+ # pad path to 51 (with eos)
1092
+ # pdb.set_trace()
1093
+ # if final['seq_len']>50:
1094
+
1095
+ # print( final['seq_len'])
1096
+ final['sequence'] = np.concatenate((vector, np.zeros(((70 - final['seq_len']), 10))), 0)
1097
+ #seq = final['sequence']
1098
+
1099
+ # new_path = convert_simple_vector_to_path(seq)
1100
+ # print(new_path,path2==new_path)
1101
+
1102
+ # sys.exit()
1103
+ # make pure list:
1104
+ # use last channel only
1105
+ final['rendered'] = np.reshape(final['rendered'][..., 0], [64 * 64]).astype(np.float32).tolist()
1106
+ final['sequence'] = np.reshape(final['sequence'], [71 * 10]).astype(np.float32).tolist()
1107
+ final['class'] = np.reshape(final['class'], [1]).astype(np.int64).tolist()
1108
+ final['seq_len'] = np.reshape(final['seq_len'], [1]).astype(np.int64).tolist()
1109
+ return final
1110
+
1111
+
1112
+ def mean_to_example(mean_stdev):
1113
+ """Converts the found mean and stdev to example."""
1114
+ # mean_stdev is a dict
1115
+ mean_stdev['mean'] = np.reshape(mean_stdev['mean'], [10]).astype(np.float32).tolist()
1116
+ mean_stdev['variance'] = np.reshape(mean_stdev['variance'], [10]).astype(np.float32).tolist()
1117
+ mean_stdev['stddev'] = np.reshape(mean_stdev['stddev'], [10]).astype(np.float32).tolist()
1118
+ mean_stdev['count'] = np.reshape(mean_stdev['count'], [1]).astype(np.int64).tolist()
1119
+ return mean_stdev
1120
+
1121
+
1122
+ ################### CHECK VALID ##############################################
1123
+ class MeanStddev:
1124
+ """Accumulator to compute the mean/stdev of svg commands."""
1125
+
1126
+ def create_accumulator(self):
1127
+ curr_sum = np.zeros([10])
1128
+ sum_sq = np.zeros([10])
1129
+ return (curr_sum, sum_sq, 0) # x, x^2, count
1130
+
1131
+ def add_input(self, sum_count, new_input):
1132
+ (curr_sum, sum_sq, count) = sum_count
1133
+ # new_input is a dict with keys = ['seq_len', 'sequence']
1134
+ new_seq_len = new_input['seq_len'][0] # Line #754 'seq_len' is a list of one int
1135
+ assert isinstance(new_seq_len, int), print(type(new_seq_len))
1136
+
1137
+ # remove padding and eos from sequence
1138
+ assert isinstance(new_input['sequence'], list), print(type(new_input['sequence']))
1139
+ new_input_np = np.reshape(np.array(new_input['sequence']), [-1, 10])
1140
+ assert isinstance(new_input_np, np.ndarray), print(type())
1141
+ assert new_input_np.shape[0] >= new_seq_len
1142
+ new_input_np = new_input_np[:new_seq_len, :]
1143
+
1144
+ # accumulate new_sum and new_sum_sq
1145
+ new_sum = np.sum([curr_sum, np.sum(new_input_np, axis=0)], axis=0)
1146
+ new_sum_sq = np.sum([sum_sq, np.sum(np.power(new_input_np, 2), axis=0)],
1147
+ axis=0)
1148
+ return new_sum, new_sum_sq, count + new_seq_len
1149
+
1150
+ def merge_accumulators(self, accumulators):
1151
+ curr_sums, sum_sqs, counts = list(zip(*accumulators))
1152
+ return np.sum(curr_sums, axis=0), np.sum(sum_sqs, axis=0), np.sum(counts)
1153
+
1154
+ def extract_output(self, sum_count):
1155
+ (curr_sum, curr_sum_sq, count) = sum_count
1156
+ if count:
1157
+ mean = np.divide(curr_sum, count)
1158
+ variance = np.divide(curr_sum_sq, count) - np.power(mean, 2)
1159
+ # -ve value could happen due to rounding
1160
+ variance = np.max([variance, np.zeros(np.shape(variance))], axis=0)
1161
+ stddev = np.sqrt(variance)
1162
+ return {
1163
+ 'mean': mean,
1164
+ 'variance': variance,
1165
+ 'stddev': stddev,
1166
+ 'count': count
1167
+ }
1168
+ else:
1169
+ return {
1170
+ 'mean': float('NaN'),
1171
+ 'variance': float('NaN'),
1172
+ 'stddev': float('NaN'),
1173
+ 'count': 0
1174
+ }
data_utils/write_data_to_dirs.py CHANGED
@@ -1,231 +1,231 @@
1
- import argparse
2
- import multiprocessing as mp
3
- import os
4
- import pickle
5
- import numpy as np
6
- from data_utils import svg_utils
7
- from tqdm import tqdm
8
-
9
- def exist_empty_imgs(imgs_array, num_chars):
10
- for char_id in range(num_chars):
11
- print(np.max(imgs_array[char_id]))
12
- input()
13
- if np.max(imgs_array[char_id]) == 0:
14
- return True
15
- return False
16
-
17
- def create_db(opts, output_path, log_path):
18
- charset = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
19
- print("Process sfd to npy files in dirs....")
20
- sdf_path = os.path.join(opts.sfd_path, opts.language, opts.split)
21
- all_font_ids = sorted(os.listdir(sdf_path))
22
- num_fonts = len(all_font_ids)
23
- num_fonts_w = len(str(num_fonts))
24
- print(f"Number {opts.split} fonts before processing", num_fonts)
25
- num_processes = mp.cpu_count() - 1
26
- fonts_per_process = num_fonts // num_processes + 1
27
- num_chars = len(charset)
28
- num_chars_w = len(str(num_chars))
29
-
30
-
31
- # import ipdb; ipdb.set_trace()
32
-
33
- def process(process_id):
34
- valid_chars = []
35
- invalid_path = []
36
- invalid_glypts = []
37
-
38
- cur_process_log_file = open(os.path.join(log_path, f'log_{opts.split}_{process_id}.txt'), 'w')
39
- for i in tqdm(range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process)):
40
- if i >= num_fonts:
41
- break
42
-
43
- font_id = all_font_ids[i]
44
- cur_font_sfd_dir = os.path.join(sdf_path, font_id)
45
- cur_font_glyphs = []
46
-
47
- if not os.path.exists(os.path.join(cur_font_sfd_dir, 'imgs_' + str(opts.img_size) + '.npy')):
48
- continue
49
-
50
- # a whole font as an entry
51
- for char_id in range(num_chars):
52
- # print('char_id :',char_id)
53
- if not os.path.exists(os.path.join(cur_font_sfd_dir, '{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=num_chars_w))):
54
- break
55
-
56
- char_desp_f = open(os.path.join(cur_font_sfd_dir, '{}_{num:0{width}}.txt'.format(font_id, num=char_id, width=num_chars_w)), 'r')
57
- char_desp = char_desp_f.readlines()
58
- sfd_f = open(os.path.join(cur_font_sfd_dir, '{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=num_chars_w)), 'r')
59
- sfd = sfd_f.read()
60
-
61
- uni = int(char_desp[0].strip())
62
- width = int(char_desp[1].strip())
63
- vwidth = int(char_desp[2].strip())
64
- char_idx = char_desp[3].strip()
65
- font_idx = char_desp[4].strip()
66
-
67
- cur_glyph = {}
68
- cur_glyph['uni'] = uni
69
- cur_glyph['width'] = width
70
- cur_glyph['vwidth'] = vwidth
71
- cur_glyph['sfd'] = sfd
72
- cur_glyph['id'] = char_idx
73
- cur_glyph['binary_fp'] = font_idx
74
-
75
- if not svg_utils.is_valid_glyph(cur_glyph):
76
- msg = f"font {font_idx}, char {char_idx} is not a valid glyph\n"
77
- invalid_path.glypts([font_idx, int(char_idx), charset[int(char_idx)]])
78
- cur_process_log_file.write(msg)
79
- char_desp_f.close()
80
- sfd_f.close()
81
- # use the font whose all glyphs are valid
82
- break
83
- pathunibfp = svg_utils.convert_to_path(cur_glyph)
84
-
85
- if not svg_utils.is_valid_path(pathunibfp):
86
- msg = f"font {font_idx}, char {char_idx}'s sfd is not a valid path\n"
87
- invalid_path.append([font_idx, int(char_idx), charset[int(char_idx)]])
88
- cur_process_log_file.write(msg)
89
- char_desp_f.close()
90
- sfd_f.close()
91
- break
92
- valid_chars.append([font_idx, int(char_idx), charset[int(char_idx)]])
93
- example = svg_utils.create_example(pathunibfp)
94
-
95
- cur_font_glyphs.append(example)
96
- char_desp_f.close()
97
- sfd_f.close()
98
-
99
- if len(cur_font_glyphs) == num_chars:
100
- # use the font whose all glyphs are valid
101
- # merge the whole font
102
-
103
- rendered = np.load(os.path.join(cur_font_sfd_dir, 'imgs_' + str(opts.img_size) + '.npy'))
104
-
105
- if (rendered[0] == rendered[1]).all() == True:
106
- continue
107
-
108
- sequence = []
109
- seq_len = []
110
- binaryfp = []
111
- char_class = []
112
- for char_id in range(num_chars):
113
- example = cur_font_glyphs[char_id]
114
- sequence.append(example['sequence'])
115
- seq_len.append(example['seq_len'])
116
- char_class.append(example['class'])
117
- binaryfp = example['binary_fp']
118
- if not os.path.exists(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w))):
119
- os.mkdir(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w)))
120
-
121
- np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'sequence.npy'), np.array(sequence))
122
- np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'seq_len.npy'), np.array(seq_len))
123
- np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'class.npy'), np.array(char_class))
124
- np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'font_id.npy'), np.array(binaryfp))
125
- np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'rendered_' + str(opts.img_size) + '.npy'), rendered)
126
-
127
- print("valid_chars", len(valid_chars))
128
- print("invalid_path:", invalid_path)
129
- print("invalid_glypts:",invalid_glypts)
130
-
131
- processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]
132
-
133
- for p in processes:
134
- p.start()
135
- for p in processes:
136
- p.join()
137
-
138
- print("Finished processing all sfd files, logs (invalid glyphs and paths) are saved to", log_path)
139
-
140
-
141
- def cal_mean_stddev(opts, output_path):
142
- print("Calculating all glyphs' mean stddev ....")
143
- charset = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
144
- font_paths = []
145
- for root, dirs, files in os.walk(output_path):
146
- for dir_name in dirs:
147
- font_paths.append(os.path.join(output_path, dir_name))
148
-
149
- font_paths.sort()
150
- num_fonts = len(font_paths)
151
- num_processes = mp.cpu_count() - 1
152
- fonts_per_process = num_fonts // num_processes + 1
153
- num_chars = len(charset)
154
- manager = mp.Manager()
155
- return_dict = manager.dict()
156
- main_stddev_accum = svg_utils.MeanStddev()
157
- print(main_stddev_accum)
158
-
159
- def process(process_id, return_dict):
160
- mean_stddev_accum = svg_utils.MeanStddev()
161
- cur_sum_count = mean_stddev_accum.create_accumulator()
162
- for i in range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process):
163
- if i >= num_fonts:
164
- break
165
- cur_font_path = font_paths[i]
166
- for charid in range(num_chars):
167
- cur_font_char = {}
168
- cur_font_char['seq_len'] = np.load(os.path.join(cur_font_path, 'seq_len.npy')).tolist()[charid]
169
- cur_font_char['sequence'] = np.load(os.path.join(cur_font_path, 'sequence.npy')).tolist()[charid]
170
- # print(cur_font_char)
171
- cur_sum_count = mean_stddev_accum.add_input(cur_sum_count, cur_font_char)
172
- return_dict[process_id] = cur_sum_count
173
- processes = [mp.Process(target=process, args=[pid, return_dict]) for pid in range(num_processes)]
174
-
175
- for p in processes:
176
- p.start()
177
- for p in processes:
178
- p.join()
179
-
180
- merged_sum_count = main_stddev_accum.merge_accumulators(return_dict.values())
181
- output = main_stddev_accum.extract_output(merged_sum_count)
182
- print('output :', output)
183
- mean = output['mean']
184
- stdev = output['stddev']
185
- print('mean :', mean)
186
- mean = np.concatenate((np.zeros([4]), mean[4:]), axis=0)
187
- stdev = np.concatenate((np.ones([4]), stdev[4:]), axis=0)
188
- # finally, save the mean and stddev files
189
- output_path_ = os.path.join(opts.output_path, opts.language)
190
- np.save(os.path.join(output_path_, 'mean'), mean)
191
- np.save(os.path.join(output_path_, 'stdev'), stdev)
192
-
193
- # rename npy to npz, don't mind about it, just some legacy issue
194
- os.rename(os.path.join(output_path_, 'mean.npy'), os.path.join(output_path_, 'mean.npz'))
195
- os.rename(os.path.join(output_path_, 'stdev.npy'), os.path.join(output_path_, 'stdev.npz'))
196
-
197
-
198
- def main():
199
- parser = argparse.ArgumentParser(description="LMDB creation")
200
- parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
201
- parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
202
- parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
203
- parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
204
- parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
205
- parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images")
206
- parser.add_argument("--split", type=str, default='train')
207
- # parser.add_argument("--log_dir", type=str, default='../data/font_sfds/log/')
208
- parser.add_argument("--phase", type=int, default=0, choices=[0, 1, 2],
209
- help="0 all, 1 create db, 2 cal stddev")
210
-
211
- opts = parser.parse_args()
212
- assert os.path.exists(opts.sfd_path), "specified sfd glyphs path does not exist"
213
-
214
- output_path = os.path.join(opts.output_path, opts.language, opts.split)
215
- log_path = os.path.join(opts.sfd_path, opts.language, 'log')
216
-
217
- if not os.path.exists(output_path):
218
- os.makedirs(output_path)
219
-
220
- if not os.path.exists(log_path):
221
- os.makedirs(log_path)
222
-
223
- if opts.phase <= 1:
224
- create_db(opts, output_path, log_path)
225
-
226
- if opts.phase <= 2 and opts.split == 'train':
227
- cal_mean_stddev(opts, output_path)
228
-
229
-
230
- if __name__ == "__main__":
231
  main()
 
1
+ import argparse
2
+ import multiprocessing as mp
3
+ import os
4
+ import pickle
5
+ import numpy as np
6
+ from data_utils import svg_utils
7
+ from tqdm import tqdm
8
+
9
+ def exist_empty_imgs(imgs_array, num_chars):
10
+ for char_id in range(num_chars):
11
+ print(np.max(imgs_array[char_id]))
12
+ input()
13
+ if np.max(imgs_array[char_id]) == 0:
14
+ return True
15
+ return False
16
+
17
+ def create_db(opts, output_path, log_path):
18
+ charset = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
19
+ print("Process sfd to npy files in dirs....")
20
+ sdf_path = os.path.join(opts.sfd_path, opts.language, opts.split)
21
+ all_font_ids = sorted(os.listdir(sdf_path))
22
+ num_fonts = len(all_font_ids)
23
+ num_fonts_w = len(str(num_fonts))
24
+ print(f"Number {opts.split} fonts before processing", num_fonts)
25
+ num_processes = mp.cpu_count() - 1
26
+ fonts_per_process = num_fonts // num_processes + 1
27
+ num_chars = len(charset)
28
+ num_chars_w = len(str(num_chars))
29
+
30
+
31
+ # import ipdb; ipdb.set_trace()
32
+
33
+ def process(process_id):
34
+ valid_chars = []
35
+ invalid_path = []
36
+ invalid_glypts = []
37
+
38
+ cur_process_log_file = open(os.path.join(log_path, f'log_{opts.split}_{process_id}.txt'), 'w')
39
+ for i in tqdm(range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process)):
40
+ if i >= num_fonts:
41
+ break
42
+
43
+ font_id = all_font_ids[i]
44
+ cur_font_sfd_dir = os.path.join(sdf_path, font_id)
45
+ cur_font_glyphs = []
46
+
47
+ if not os.path.exists(os.path.join(cur_font_sfd_dir, 'imgs_' + str(opts.img_size) + '.npy')):
48
+ continue
49
+
50
+ # a whole font as an entry
51
+ for char_id in range(num_chars):
52
+ # print('char_id :',char_id)
53
+ if not os.path.exists(os.path.join(cur_font_sfd_dir, '{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=num_chars_w))):
54
+ break
55
+
56
+ char_desp_f = open(os.path.join(cur_font_sfd_dir, '{}_{num:0{width}}.txt'.format(font_id, num=char_id, width=num_chars_w)), 'r')
57
+ char_desp = char_desp_f.readlines()
58
+ sfd_f = open(os.path.join(cur_font_sfd_dir, '{}_{num:0{width}}.sfd'.format(font_id, num=char_id, width=num_chars_w)), 'r')
59
+ sfd = sfd_f.read()
60
+
61
+ uni = int(char_desp[0].strip())
62
+ width = int(char_desp[1].strip())
63
+ vwidth = int(char_desp[2].strip())
64
+ char_idx = char_desp[3].strip()
65
+ font_idx = char_desp[4].strip()
66
+
67
+ cur_glyph = {}
68
+ cur_glyph['uni'] = uni
69
+ cur_glyph['width'] = width
70
+ cur_glyph['vwidth'] = vwidth
71
+ cur_glyph['sfd'] = sfd
72
+ cur_glyph['id'] = char_idx
73
+ cur_glyph['binary_fp'] = font_idx
74
+
75
+ if not svg_utils.is_valid_glyph(cur_glyph):
76
+ msg = f"font {font_idx}, char {char_idx} is not a valid glyph\n"
77
+ invalid_path.glypts([font_idx, int(char_idx), charset[int(char_idx)]])
78
+ cur_process_log_file.write(msg)
79
+ char_desp_f.close()
80
+ sfd_f.close()
81
+ # use the font whose all glyphs are valid
82
+ break
83
+ pathunibfp = svg_utils.convert_to_path(cur_glyph)
84
+
85
+ if not svg_utils.is_valid_path(pathunibfp):
86
+ msg = f"font {font_idx}, char {char_idx}'s sfd is not a valid path\n"
87
+ invalid_path.append([font_idx, int(char_idx), charset[int(char_idx)]])
88
+ cur_process_log_file.write(msg)
89
+ char_desp_f.close()
90
+ sfd_f.close()
91
+ break
92
+ valid_chars.append([font_idx, int(char_idx), charset[int(char_idx)]])
93
+ example = svg_utils.create_example(pathunibfp)
94
+
95
+ cur_font_glyphs.append(example)
96
+ char_desp_f.close()
97
+ sfd_f.close()
98
+
99
+ if len(cur_font_glyphs) == num_chars:
100
+ # use the font whose all glyphs are valid
101
+ # merge the whole font
102
+
103
+ rendered = np.load(os.path.join(cur_font_sfd_dir, 'imgs_' + str(opts.img_size) + '.npy'))
104
+
105
+ if (rendered[0] == rendered[1]).all() == True:
106
+ continue
107
+
108
+ sequence = []
109
+ seq_len = []
110
+ binaryfp = []
111
+ char_class = []
112
+ for char_id in range(num_chars):
113
+ example = cur_font_glyphs[char_id]
114
+ sequence.append(example['sequence'])
115
+ seq_len.append(example['seq_len'])
116
+ char_class.append(example['class'])
117
+ binaryfp = example['binary_fp']
118
+ if not os.path.exists(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w))):
119
+ os.mkdir(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w)))
120
+
121
+ np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'sequence.npy'), np.array(sequence))
122
+ np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'seq_len.npy'), np.array(seq_len))
123
+ np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'class.npy'), np.array(char_class))
124
+ np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'font_id.npy'), np.array(binaryfp))
125
+ np.save(os.path.join(output_path, '{num:0{width}}'.format(num=i, width=num_fonts_w), 'rendered_' + str(opts.img_size) + '.npy'), rendered)
126
+
127
+ print("valid_chars", len(valid_chars))
128
+ print("invalid_path:", invalid_path)
129
+ print("invalid_glypts:",invalid_glypts)
130
+
131
+ processes = [mp.Process(target=process, args=[pid]) for pid in range(num_processes)]
132
+
133
+ for p in processes:
134
+ p.start()
135
+ for p in processes:
136
+ p.join()
137
+
138
+ print("Finished processing all sfd files, logs (invalid glyphs and paths) are saved to", log_path)
139
+
140
+
141
+ def cal_mean_stddev(opts, output_path):
142
+ print("Calculating all glyphs' mean stddev ....")
143
+ charset = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
144
+ font_paths = []
145
+ for root, dirs, files in os.walk(output_path):
146
+ for dir_name in dirs:
147
+ font_paths.append(os.path.join(output_path, dir_name))
148
+
149
+ font_paths.sort()
150
+ num_fonts = len(font_paths)
151
+ num_processes = mp.cpu_count() - 1
152
+ fonts_per_process = num_fonts // num_processes + 1
153
+ num_chars = len(charset)
154
+ manager = mp.Manager()
155
+ return_dict = manager.dict()
156
+ main_stddev_accum = svg_utils.MeanStddev()
157
+ print(main_stddev_accum)
158
+
159
+ def process(process_id, return_dict):
160
+ mean_stddev_accum = svg_utils.MeanStddev()
161
+ cur_sum_count = mean_stddev_accum.create_accumulator()
162
+ for i in range(process_id * fonts_per_process, (process_id + 1) * fonts_per_process):
163
+ if i >= num_fonts:
164
+ break
165
+ cur_font_path = font_paths[i]
166
+ for charid in range(num_chars):
167
+ cur_font_char = {}
168
+ cur_font_char['seq_len'] = np.load(os.path.join(cur_font_path, 'seq_len.npy')).tolist()[charid]
169
+ cur_font_char['sequence'] = np.load(os.path.join(cur_font_path, 'sequence.npy')).tolist()[charid]
170
+ # print(cur_font_char)
171
+ cur_sum_count = mean_stddev_accum.add_input(cur_sum_count, cur_font_char)
172
+ return_dict[process_id] = cur_sum_count
173
+ processes = [mp.Process(target=process, args=[pid, return_dict]) for pid in range(num_processes)]
174
+
175
+ for p in processes:
176
+ p.start()
177
+ for p in processes:
178
+ p.join()
179
+
180
+ merged_sum_count = main_stddev_accum.merge_accumulators(return_dict.values())
181
+ output = main_stddev_accum.extract_output(merged_sum_count)
182
+ print('output :', output)
183
+ mean = output['mean']
184
+ stdev = output['stddev']
185
+ print('mean :', mean)
186
+ mean = np.concatenate((np.zeros([4]), mean[4:]), axis=0)
187
+ stdev = np.concatenate((np.ones([4]), stdev[4:]), axis=0)
188
+ # finally, save the mean and stddev files
189
+ output_path_ = os.path.join(opts.output_path, opts.language)
190
+ np.save(os.path.join(output_path_, 'mean'), mean)
191
+ np.save(os.path.join(output_path_, 'stdev'), stdev)
192
+
193
+ # rename npy to npz, don't mind about it, just some legacy issue
194
+ os.rename(os.path.join(output_path_, 'mean.npy'), os.path.join(output_path_, 'mean.npz'))
195
+ os.rename(os.path.join(output_path_, 'stdev.npy'), os.path.join(output_path_, 'stdev.npz'))
196
+
197
+
198
+ def main():
199
+ parser = argparse.ArgumentParser(description="LMDB creation")
200
+ parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
201
+ parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
202
+ parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
203
+ parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
204
+ parser.add_argument("--output_path", type=str, default='../data/vecfont_dataset_/', help="Path to write the database to")
205
+ parser.add_argument('--img_size', type=int, default=64, help="the height and width of glyph images")
206
+ parser.add_argument("--split", type=str, default='train')
207
+ # parser.add_argument("--log_dir", type=str, default='../data/font_sfds/log/')
208
+ parser.add_argument("--phase", type=int, default=0, choices=[0, 1, 2],
209
+ help="0 all, 1 create db, 2 cal stddev")
210
+
211
+ opts = parser.parse_args()
212
+ assert os.path.exists(opts.sfd_path), "specified sfd glyphs path does not exist"
213
+
214
+ output_path = os.path.join(opts.output_path, opts.language, opts.split)
215
+ log_path = os.path.join(opts.sfd_path, opts.language, 'log')
216
+
217
+ if not os.path.exists(output_path):
218
+ os.makedirs(output_path)
219
+
220
+ if not os.path.exists(log_path):
221
+ os.makedirs(log_path)
222
+
223
+ if opts.phase <= 1:
224
+ create_db(opts, output_path, log_path)
225
+
226
+ if opts.phase <= 2 and opts.split == 'train':
227
+ cal_mean_stddev(opts, output_path)
228
+
229
+
230
+ if __name__ == "__main__":
231
  main()
data_utils/write_glyph_imgs.py CHANGED
@@ -1,180 +1,180 @@
1
- from PIL import Image
2
- from PIL import Image
3
- from PIL import ImageDraw
4
- from PIL import ImageFont
5
- import argparse
6
- import numpy as np
7
- import os
8
- import multiprocessing as mp
9
- from tqdm import tqdm
10
- char_error = 0
11
-
12
- def get_bbox(img):
13
- img = 255 - np.array(img)
14
- sum_x = np.sum(img, axis=0)
15
- sum_y = np.sum(img, axis=1)
16
- range_x = np.where(sum_x > 0)
17
- width = range_x[0][-1] - range_x[0][0]
18
- range_y = np.where(sum_y > 0)
19
- height = range_y[0][-1] - range_y[0][0]
20
- return width, height
21
-
22
- def write_glyph_imgs_mp(opts):
23
- """Useing multiprocessing to render glyph images"""
24
- charset = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
25
- fonts_file_path = os.path.join(opts.ttf_path, opts.language)
26
- sfd_path = os.path.join(opts.sfd_path, opts.language)
27
- for root, dirs, files in os.walk(os.path.join(fonts_file_path, opts.split)):
28
- ttf_names = files
29
- # ttf_names = ['08343.aspx_id=299524532']
30
- ttf_names.sort()
31
- font_num = len(ttf_names)
32
- charset_lenw = len(str(len(charset)))
33
- process_nums = mp.cpu_count() - 1
34
- font_num_per_process = font_num // process_nums + 1
35
-
36
- def process(process_id, font_num_p_process):
37
- for i in tqdm(range(process_id * font_num_p_process, (process_id + 1) * font_num_p_process)):
38
- if i >= font_num:
39
- break
40
-
41
- fontname = ttf_names[i].split('.')[0]
42
- # print(fontname)
43
-
44
- if not os.path.exists(os.path.join(sfd_path, opts.split, fontname)):
45
- continue
46
-
47
- ttf_file_path = os.path.join(fonts_file_path, opts.split, ttf_names[i])
48
-
49
- try:
50
- font = ImageFont.truetype(ttf_file_path, int(opts.img_size*opts.FONT_SIZE), encoding="unic")
51
- except:
52
- print('cant open ' + fontname)
53
- continue
54
-
55
- fontimgs_array = np.zeros((len(charset), opts.img_size, opts.img_size), np.uint8)
56
- fontimgs_array[:, :, :] = 255
57
-
58
- flag_success = True
59
-
60
- for charid in range(len(charset)):
61
- # read the meta file
62
- txt_fpath = os.path.join(sfd_path, opts.split, fontname, fontname + '_' + '{num:0{width}}'.format(num=charid, width=charset_lenw) + '.txt')
63
- try:
64
- txt_lines = open(txt_fpath,'r').read().split('\n')
65
- except:
66
- print('cannot read text file')
67
- flag_success = False
68
- break
69
- if len(txt_lines) < 5:
70
- flag_success = False
71
- break # should be empty file
72
- # the offsets are calculated according to the rules in data_utils/svg_utils.py
73
- vbox_w = float(txt_lines[1])
74
- vbox_h = float(txt_lines[2])
75
- norm = max(int(vbox_w), int(vbox_h))
76
-
77
- if int(vbox_h) > int(vbox_w):
78
- add_to_y = 0
79
- add_to_x = abs(int(vbox_h) - int(vbox_w)) / 2
80
- add_to_x = add_to_x * (float(opts.img_size) / norm)
81
- else:
82
- add_to_y = abs(int(vbox_h) - int(vbox_w)) / 2
83
- add_to_y = add_to_y * (float(opts.img_size) / norm)
84
- add_to_x = 0
85
-
86
- char = charset[charid]
87
-
88
- array = np.ndarray((opts.img_size, opts.img_size), np.uint8)
89
- array[:, :] = 255
90
- image = Image.fromarray(array)
91
- draw = ImageDraw.Draw(image)
92
- try:
93
- font_width, font_height = font.getsize(char)
94
- except Exception as e:
95
- print('cant calculate height and width ' + "%04d"%i + '_' + '{num:0{width}}'.format(num=charid, width=charset_lenw))
96
- flag_success = False
97
- break
98
-
99
- try:
100
- ascent, descent = font.getmetrics()
101
- except:
102
- print('cannot get ascent, descent')
103
- flag_success = False
104
- break
105
-
106
- draw_pos_x = add_to_x
107
- #if opts.language == 'eng':
108
- thai_characters_long = ["ญ","ฎ","ฏ","ฐ"]
109
-
110
- if char in thai_characters_long:
111
- draw_pos_y = add_to_y + opts.img_size - ascent - descent - int((opts.img_size / 24.0) * (10.0 / 3.0))
112
- else:
113
- draw_pos_y = add_to_y + opts.img_size - ascent - int((opts.img_size / 24.0) * (10.0 / 3.0))
114
- #else:
115
- # draw_pos_y = add_to_y + opts.img_size - ascent - int((opts.img_size / 24.0) * (10.0 / 3.0))
116
-
117
- draw.text((draw_pos_x, draw_pos_y), char, (0), font=font)
118
-
119
- if opts.debug:
120
- image.save(os.path.join(sfd_path, opts.split, fontname, str(charid) + '_' + str(opts.img_size) + '.png'))
121
-
122
- try:
123
- char_w, char_h = get_bbox(image)
124
- # print(charid, char_w, char_h)
125
- except Exception as e:
126
- print("cannot get bbox")
127
- print(e)
128
- flag_success = False
129
- break
130
-
131
- # Detect large font
132
- problem = []
133
- if font_width > 59:
134
- problem.append("width")
135
-
136
- if font_height > 93:
137
- problem.append("height")
138
-
139
- if problem:
140
- print(problem,fontname, charid, font_width, font_height, char_w, char_h)
141
- flag_success = False
142
- break
143
-
144
- # Detect Small Font
145
- if (char_w < opts.img_size * 0.15) and (char_h < opts.img_size * 0.15):
146
- flag_success = False
147
- break
148
-
149
- fontimgs_array[charid] = np.array(image)
150
-
151
- if flag_success:
152
- np.save(os.path.join(sfd_path, opts.split, fontname, 'imgs_' + str(opts.img_size) + '.npy'), fontimgs_array)
153
- else:
154
- global char_error # Count char flag not success
155
- char_error += 1
156
- print("flag on", fontname, charid, 'imgs_' + str(opts.img_size) + '.npy', " Not Succeed")
157
-
158
- processes = [mp.Process(target=process, args=(pid, font_num_per_process)) for pid in range(process_nums)]
159
-
160
- for p in processes:
161
- p.start()
162
- for p in processes:
163
- p.join()
164
-
165
- def main():
166
- parser = argparse.ArgumentParser(description="Write glyph images")
167
- parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
168
- parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
169
- parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
170
- parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
171
- parser.add_argument('--img_size', type=int, default=64)
172
- parser.add_argument('--split', type=str, default='train')
173
- parser.add_argument('--FONT_SIZE', type=float, default=1)
174
- parser.add_argument('--debug', type=bool, default=False)
175
- opts = parser.parse_args()
176
- write_glyph_imgs_mp(opts)
177
-
178
-
179
- if __name__ == "__main__":
180
- main()
 
1
+ from PIL import Image
2
+ from PIL import Image
3
+ from PIL import ImageDraw
4
+ from PIL import ImageFont
5
+ import argparse
6
+ import numpy as np
7
+ import os
8
+ import multiprocessing as mp
9
+ from tqdm import tqdm
10
+ char_error = 0
11
+
12
+ def get_bbox(img):
13
+ img = 255 - np.array(img)
14
+ sum_x = np.sum(img, axis=0)
15
+ sum_y = np.sum(img, axis=1)
16
+ range_x = np.where(sum_x > 0)
17
+ width = range_x[0][-1] - range_x[0][0]
18
+ range_y = np.where(sum_y > 0)
19
+ height = range_y[0][-1] - range_y[0][0]
20
+ return width, height
21
+
22
+ def write_glyph_imgs_mp(opts):
23
+ """Useing multiprocessing to render glyph images"""
24
+ charset = open(f"{opts.data_path}/char_set/{opts.language}.txt", 'r').read()
25
+ fonts_file_path = os.path.join(opts.ttf_path, opts.language)
26
+ sfd_path = os.path.join(opts.sfd_path, opts.language)
27
+ for root, dirs, files in os.walk(os.path.join(fonts_file_path, opts.split)):
28
+ ttf_names = files
29
+ # ttf_names = ['08343.aspx_id=299524532']
30
+ ttf_names.sort()
31
+ font_num = len(ttf_names)
32
+ charset_lenw = len(str(len(charset)))
33
+ process_nums = mp.cpu_count() - 1
34
+ font_num_per_process = font_num // process_nums + 1
35
+
36
+ def process(process_id, font_num_p_process):
37
+ for i in tqdm(range(process_id * font_num_p_process, (process_id + 1) * font_num_p_process)):
38
+ if i >= font_num:
39
+ break
40
+
41
+ fontname = ttf_names[i].split('.')[0]
42
+ # print(fontname)
43
+
44
+ if not os.path.exists(os.path.join(sfd_path, opts.split, fontname)):
45
+ continue
46
+
47
+ ttf_file_path = os.path.join(fonts_file_path, opts.split, ttf_names[i])
48
+
49
+ try:
50
+ font = ImageFont.truetype(ttf_file_path, int(opts.img_size*opts.FONT_SIZE), encoding="unic")
51
+ except:
52
+ print('cant open ' + fontname)
53
+ continue
54
+
55
+ fontimgs_array = np.zeros((len(charset), opts.img_size, opts.img_size), np.uint8)
56
+ fontimgs_array[:, :, :] = 255
57
+
58
+ flag_success = True
59
+
60
+ for charid in range(len(charset)):
61
+ # read the meta file
62
+ txt_fpath = os.path.join(sfd_path, opts.split, fontname, fontname + '_' + '{num:0{width}}'.format(num=charid, width=charset_lenw) + '.txt')
63
+ try:
64
+ txt_lines = open(txt_fpath,'r').read().split('\n')
65
+ except:
66
+ print('cannot read text file')
67
+ flag_success = False
68
+ break
69
+ if len(txt_lines) < 5:
70
+ flag_success = False
71
+ break # should be empty file
72
+ # the offsets are calculated according to the rules in data_utils/svg_utils.py
73
+ vbox_w = float(txt_lines[1])
74
+ vbox_h = float(txt_lines[2])
75
+ norm = max(int(vbox_w), int(vbox_h))
76
+
77
+ if int(vbox_h) > int(vbox_w):
78
+ add_to_y = 0
79
+ add_to_x = abs(int(vbox_h) - int(vbox_w)) / 2
80
+ add_to_x = add_to_x * (float(opts.img_size) / norm)
81
+ else:
82
+ add_to_y = abs(int(vbox_h) - int(vbox_w)) / 2
83
+ add_to_y = add_to_y * (float(opts.img_size) / norm)
84
+ add_to_x = 0
85
+
86
+ char = charset[charid]
87
+
88
+ array = np.ndarray((opts.img_size, opts.img_size), np.uint8)
89
+ array[:, :] = 255
90
+ image = Image.fromarray(array)
91
+ draw = ImageDraw.Draw(image)
92
+ try:
93
+ font_width, font_height = font.getsize(char)
94
+ except Exception as e:
95
+ print('cant calculate height and width ' + "%04d"%i + '_' + '{num:0{width}}'.format(num=charid, width=charset_lenw))
96
+ flag_success = False
97
+ break
98
+
99
+ try:
100
+ ascent, descent = font.getmetrics()
101
+ except:
102
+ print('cannot get ascent, descent')
103
+ flag_success = False
104
+ break
105
+
106
+ draw_pos_x = add_to_x
107
+ #if opts.language == 'eng':
108
+ thai_characters_long = ["ญ","ฎ","ฏ","ฐ"]
109
+
110
+ if char in thai_characters_long:
111
+ draw_pos_y = add_to_y + opts.img_size - ascent - descent - int((opts.img_size / 24.0) * (10.0 / 3.0))
112
+ else:
113
+ draw_pos_y = add_to_y + opts.img_size - ascent - int((opts.img_size / 24.0) * (10.0 / 3.0))
114
+ #else:
115
+ # draw_pos_y = add_to_y + opts.img_size - ascent - int((opts.img_size / 24.0) * (10.0 / 3.0))
116
+
117
+ draw.text((draw_pos_x, draw_pos_y), char, (0), font=font)
118
+
119
+ if opts.debug:
120
+ image.save(os.path.join(sfd_path, opts.split, fontname, str(charid) + '_' + str(opts.img_size) + '.png'))
121
+
122
+ try:
123
+ char_w, char_h = get_bbox(image)
124
+ # print(charid, char_w, char_h)
125
+ except Exception as e:
126
+ print("cannot get bbox")
127
+ print(e)
128
+ flag_success = False
129
+ break
130
+
131
+ # Detect large font
132
+ problem = []
133
+ if font_width > 59:
134
+ problem.append("width")
135
+
136
+ if font_height > 93:
137
+ problem.append("height")
138
+
139
+ if problem:
140
+ print(problem,fontname, charid, font_width, font_height, char_w, char_h)
141
+ flag_success = False
142
+ break
143
+
144
+ # Detect Small Font
145
+ if (char_w < opts.img_size * 0.15) and (char_h < opts.img_size * 0.15):
146
+ flag_success = False
147
+ break
148
+
149
+ fontimgs_array[charid] = np.array(image)
150
+
151
+ if flag_success:
152
+ np.save(os.path.join(sfd_path, opts.split, fontname, 'imgs_' + str(opts.img_size) + '.npy'), fontimgs_array)
153
+ else:
154
+ global char_error # Count char flag not success
155
+ char_error += 1
156
+ print("flag on", fontname, charid, 'imgs_' + str(opts.img_size) + '.npy', " Not Succeed")
157
+
158
+ processes = [mp.Process(target=process, args=(pid, font_num_per_process)) for pid in range(process_nums)]
159
+
160
+ for p in processes:
161
+ p.start()
162
+ for p in processes:
163
+ p.join()
164
+
165
+ def main():
166
+ parser = argparse.ArgumentParser(description="Write glyph images")
167
+ parser.add_argument("--language", type=str, default='eng', choices=['eng', 'chn', 'tha'])
168
+ parser.add_argument("--data_path", type=str, default='./Font_Dataset', help="Path to Dataset")
169
+ parser.add_argument("--ttf_path", type=str, default='../data/font_ttfs')
170
+ parser.add_argument('--sfd_path', type=str, default='../data/font_sfds')
171
+ parser.add_argument('--img_size', type=int, default=64)
172
+ parser.add_argument('--split', type=str, default='train')
173
+ parser.add_argument('--FONT_SIZE', type=float, default=1)
174
+ parser.add_argument('--debug', type=bool, default=False)
175
+ opts = parser.parse_args()
176
+ write_glyph_imgs_mp(opts)
177
+
178
+
179
+ if __name__ == "__main__":
180
+ main()
dataloader.py CHANGED
@@ -1,67 +1,67 @@
1
- # data loader for training main model
2
- import os
3
- import pickle
4
- import torch
5
- import torch.utils.data as data
6
- import torchvision.transforms as T
7
- import sys
8
- import numpy as np
9
- torch.multiprocessing.set_sharing_strategy('file_system')
10
-
11
-
12
- class SVGDataset(data.Dataset):
13
- def __init__(self, root_path, img_size=128, lang='eng', char_num=52, max_seq_len=51, dim_seq=10, transform=None, mode='train'):
14
- super().__init__()
15
- self.mode = mode
16
- self.img_size = img_size
17
- self.char_num = char_num
18
- self.max_seq_len = max_seq_len
19
- self.dim_seq = dim_seq
20
- self.trans = transform
21
- self.font_paths = []
22
- self.dir_path = os.path.join(root_path, lang, self.mode)
23
- for root, dirs, files in os.walk(self.dir_path):
24
- depth = root.count('/') - self.dir_path.count('/')
25
- if depth == 0:
26
- for dir_name in dirs:
27
- self.font_paths.append(os.path.join(self.dir_path, dir_name))
28
- self.font_paths.sort()
29
- print(f"Finished loading {mode} paths, number: {str(len(self.font_paths))}")
30
-
31
- def __getitem__(self, index):
32
- item = {}
33
- font_path = self.font_paths[index]
34
- item = {}
35
- item['class'] = torch.LongTensor(np.load(os.path.join(font_path, 'class.npy')))
36
- item['seq_len'] = torch.LongTensor(np.load(os.path.join(font_path, 'seq_len.npy')))
37
- item['sequence'] = torch.FloatTensor(np.load(os.path.join(font_path, 'sequence_relaxed.npy'))).view(self.char_num, self.max_seq_len, self.dim_seq)
38
- item['pts_aux'] = torch.FloatTensor(np.load(os.path.join(font_path, 'pts_aux.npy')))
39
- item['rendered'] = torch.FloatTensor(np.load(os.path.join(font_path, 'rendered_' + str(self.img_size) + '.npy'))).view(self.char_num, self.img_size, self.img_size) / 255.
40
- item['rendered'] = self.trans(item['rendered'])
41
- item['font_id'] = torch.FloatTensor(np.load(os.path.join(font_path, 'font_id.npy')).astype(np.float32))
42
- return item
43
-
44
- def __len__(self):
45
- return len(self.font_paths)
46
-
47
-
48
- def get_loader(root_path, img_size, lang, char_num, max_seq_len, dim_seq, batch_size, mode='train'):
49
- SetRange = T.Lambda(lambda X: 1. - X ) # convert [0, 1] -> [0, 1]
50
- transform = T.Compose([SetRange])
51
- dataset = SVGDataset(root_path, img_size, lang, char_num, max_seq_len, dim_seq, transform, mode)
52
- dataloader = data.DataLoader(dataset, batch_size, shuffle=(mode == 'train'), num_workers=batch_size)
53
- return dataloader
54
-
55
- if __name__ == '__main__':
56
- root_path = 'data/new_data'
57
- max_seq_len = 51
58
- dim_seq = 10
59
- batch_size = 1
60
- char_num = 52
61
-
62
- loader = get_loader(root_path, char_num, max_seq_len, dim_seq, batch_size, 'train')
63
- fout = open('train_id_record_old.txt','w')
64
- for idx, batch in enumerate(loader):
65
- binary_fp = batch['font_id'].numpy()[0][0]
66
- fout.write("%05d"%int(binary_fp) + '\n')
67
-
 
1
+ # data loader for training main model
2
+ import os
3
+ import pickle
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torchvision.transforms as T
7
+ import sys
8
+ import numpy as np
9
+ torch.multiprocessing.set_sharing_strategy('file_system')
10
+
11
+
12
+ class SVGDataset(data.Dataset):
13
+ def __init__(self, root_path, img_size=128, lang='eng', char_num=52, max_seq_len=51, dim_seq=10, transform=None, mode='train'):
14
+ super().__init__()
15
+ self.mode = mode
16
+ self.img_size = img_size
17
+ self.char_num = char_num
18
+ self.max_seq_len = max_seq_len
19
+ self.dim_seq = dim_seq
20
+ self.trans = transform
21
+ self.font_paths = []
22
+ self.dir_path = os.path.join(root_path, lang, self.mode)
23
+ for root, dirs, files in os.walk(self.dir_path):
24
+ depth = root.count('/') - self.dir_path.count('/')
25
+ if depth == 0:
26
+ for dir_name in dirs:
27
+ self.font_paths.append(os.path.join(self.dir_path, dir_name))
28
+ self.font_paths.sort()
29
+ print(f"Finished loading {mode} paths, number: {str(len(self.font_paths))}")
30
+
31
+ def __getitem__(self, index):
32
+ item = {}
33
+ font_path = self.font_paths[index]
34
+ item = {}
35
+ item['class'] = torch.LongTensor(np.load(os.path.join(font_path, 'class.npy')))
36
+ item['seq_len'] = torch.LongTensor(np.load(os.path.join(font_path, 'seq_len.npy')))
37
+ item['sequence'] = torch.FloatTensor(np.load(os.path.join(font_path, 'sequence_relaxed.npy'))).view(self.char_num, self.max_seq_len, self.dim_seq)
38
+ item['pts_aux'] = torch.FloatTensor(np.load(os.path.join(font_path, 'pts_aux.npy')))
39
+ item['rendered'] = torch.FloatTensor(np.load(os.path.join(font_path, 'rendered_' + str(self.img_size) + '.npy'))).view(self.char_num, self.img_size, self.img_size) / 255.
40
+ item['rendered'] = self.trans(item['rendered'])
41
+ item['font_id'] = torch.FloatTensor(np.load(os.path.join(font_path, 'font_id.npy')).astype(np.float32))
42
+ return item
43
+
44
+ def __len__(self):
45
+ return len(self.font_paths)
46
+
47
+
48
+ def get_loader(root_path, img_size, lang, char_num, max_seq_len, dim_seq, batch_size, mode='train'):
49
+ SetRange = T.Lambda(lambda X: 1. - X ) # convert [0, 1] -> [0, 1]
50
+ transform = T.Compose([SetRange])
51
+ dataset = SVGDataset(root_path, img_size, lang, char_num, max_seq_len, dim_seq, transform, mode)
52
+ dataloader = data.DataLoader(dataset, batch_size, shuffle=(mode == 'train'), num_workers=batch_size)
53
+ return dataloader
54
+
55
+ if __name__ == '__main__':
56
+ root_path = 'data/new_data'
57
+ max_seq_len = 51
58
+ dim_seq = 10
59
+ batch_size = 1
60
+ char_num = 52
61
+
62
+ loader = get_loader(root_path, char_num, max_seq_len, dim_seq, batch_size, 'train')
63
+ fout = open('train_id_record_old.txt','w')
64
+ for idx, batch in enumerate(loader):
65
+ binary_fp = batch['font_id'].numpy()[0][0]
66
+ fout.write("%05d"%int(binary_fp) + '\n')
67
+
generate.py CHANGED
@@ -1,143 +1,143 @@
1
- import fontTools
2
- import os
3
- import shutil
4
- import typing
5
- import PIL
6
- from PIL import Image, ImageDraw, ImageFont
7
- from data_utils.convert_ttf_to_sfd import convert_mp
8
- from data_utils.write_glyph_imgs import write_glyph_imgs_mp
9
- from data_utils.write_data_to_dirs import create_db
10
- from data_utils.relax_rep import relax_rep
11
- from test_few_shot import test_main_model
12
- from options import get_parser_main_model
13
-
14
- opts = get_parser_main_model().parse_args()
15
-
16
- # Config on opts
17
- # Inference opts
18
- opts.mode = "test"
19
- opts.language = "tha"
20
- opts.char_num = 44
21
- opts.ref_nshot = 8
22
- opts.batch_size = 1 # inference rule
23
- opts.img_size = 64
24
- opts.max_seq_len = 121
25
- opts.name_ckpt = ""
26
- opts.model_path = "./inference_model/950_49452.ckpt"
27
- opts.ref_char_ids = "0,1,2,3,4,5,6,7"
28
- opts.dir_res = "./inference"
29
- opts.data_root = "./inference/vecfont_dataset/"
30
-
31
- # Data preprocessing opts
32
- opts.data_path = './inference'
33
- opts.sfd_path = f'{opts.data_path}/font_sfds'
34
- opts.ttf_path = f'{opts.data_path}/font_ttfs'
35
- opts.split = "test"
36
- opts.debug = True # Save Image On write_glyph_imgs_mp
37
- opts.output_path = f'{opts.data_path}/vecfont_dataset/'
38
- opts.phase = 0
39
- opts.FONT_SIZE = 1
40
-
41
- opts.streamlit = True
42
-
43
- # Glypts ID :
44
- # [(0, 'A'), (1, 'B'), (2, 'C'), (3, 'D'), (4, 'E')]
45
- # [(5, 'F'), (6, 'G'), (7, 'H'), (8, 'I'), (9, 'J')]
46
- # [(10, 'K'), (11, 'L'), (12, 'M'), (13, 'N'), (14, 'O')]
47
- # [(15, 'P'), (16, 'Q'), (17, 'R'), (18, 'S'), (19, 'T')]
48
- # [(20, 'U'), (21, 'V'), (22, 'W'), (23, 'X'), (24, 'Y')]
49
- # [(25, 'Z'), (26, 'a'), (27, 'b'), (28, 'c'), (29, 'd')]
50
- # [(30, 'e'), (31, 'f'), (32, 'g'), (33, 'h'), (34, 'i')]
51
- # [(35, 'j'), (36, 'k'), (37, 'l'), (38, 'm'), (39, 'n')]
52
- # [(40, 'o'), (41, 'p'), (42, 'q'), (43, 'r'), (44, 's')]
53
- # [(45, 't'), (46, 'u'), (47, 'v'), (48, 'w'), (49, 'x')]
54
- # [(50, 'y'), (51, 'z'), (52, 'ก'), (53, 'ข'), (54, 'ฃ')]
55
- # [(55, 'ค'), (56, 'ฅ'), (57, 'ฆ'), (58, 'ง'), (59, 'จ')]
56
- # [(60, 'ฉ'), (61, 'ช'), (62, 'ซ'), (63, 'ฌ'), (64, 'ญ')]
57
- # [(65, 'ฎ'), (66, 'ฏ'), (67, 'ฐ'), (68, 'ฑ'), (69, 'ฒ')]
58
- # [(70, 'ณ'), (71, 'ด'), (72, 'ต'), (73, 'ถ'), (74, 'ท')]
59
- # [(75, 'ธ'), (76, 'น'), (77, 'บ'), (78, 'ป'), (79, 'ผ')]
60
- # [(80, 'ฝ'), (81, 'พ'), (82, 'ฟ'), (83, 'ภ'), (84, 'ม')]
61
- # [(85, 'ย'), (86, 'ร'), (87, 'ล'), (88, 'ว'), (89, 'ศ')]
62
- # [(90, 'ษ'), (91, 'ส'), (92, 'ห'), (93, 'ฬ'), (94, 'อ')]
63
- # [(95, 'ฮ')]
64
-
65
- import string
66
- import pythainlp
67
-
68
- thai_digits = [*pythainlp.thai_digits]
69
- thai_characters = [*pythainlp.thai_consonants]
70
- eng_characters = [*string.ascii_letters]
71
- thai_floating = [*pythainlp.thai_vowels]
72
-
73
- directories = [
74
- "inference",
75
- "inference/char_set",
76
- "inference/font_sfds",
77
- "inference/font_ttfs",
78
- "inference/vecfont_dataset",
79
- "inference/font_ttfs/tha/test",
80
- ]
81
-
82
-
83
- # Data Preprocessing
84
- def preprocessing(ttf_file) -> str:
85
- shutil.rmtree("inference")
86
- for directory in directories:
87
- os.makedirs(directory, exist_ok=True)
88
-
89
- # Save File / Copy File
90
- if isinstance(ttf_file, memoryview):
91
- with open(f"{opts.data_path}/font_ttfs/tha/test/0000.ttf", 'wb') as f:
92
- f.write(ttf_file)
93
- elif isinstance(ttf_file, str):
94
- shutil.copy(ttf_file, f"{opts.data_path}/font_ttfs/tha/test/0000.ttf")
95
-
96
- glypts = sorted(set(thai_characters))
97
- print("Glypts:",len(glypts))
98
- print("".join(glypts))
99
- f = open("inference/char_set/tha.txt", "w")
100
- f.write("".join(glypts))
101
- f.close()
102
-
103
- # Preprocess Pipeline
104
- convert_mp(opts)
105
- write_glyph_imgs_mp(opts)
106
- output_path = os.path.join(opts.output_path, opts.language, opts.split)
107
- log_path = os.path.join(opts.sfd_path, opts.language, 'log')
108
- if not os.path.exists(output_path):
109
- os.makedirs(output_path)
110
- if not os.path.exists(log_path):
111
- os.makedirs(log_path)
112
- create_db(opts, output_path, log_path)
113
- relax_rep(opts)
114
-
115
- print("Finished making a data", ttf_file)
116
- print("Saved at", output_path)
117
- return output_path
118
-
119
- def inference_model(n_samples, ref_char_ids, version):
120
- opts.n_samples = n_samples
121
- opts.ref_char_ids = ref_char_ids
122
-
123
- # Select Model
124
- if version == "TH2TH":
125
- opts.model_path = "./inference_model/950_49452.ckpt"
126
- elif version == "ENG2TH":
127
- opts.model_path = "./inference_model/950_49452.ckpt"
128
- else:
129
- raise NotImplementedError
130
-
131
- return test_main_model(opts)
132
-
133
- def ttf_to_image(ttf_file, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
134
- preprocessing(ttf_file) # Make Data
135
- merge_svg_img = inference_model(n_samples, ref_char_ids, version) # Inference
136
- return merge_svg_img
137
-
138
- def main():
139
- print(opts.mode)
140
- ttf_to_image("font_sample/SaoChingcha-Regular.otf")
141
-
142
- if __name__ == "__main__":
143
- main()
 
1
+ import fontTools
2
+ import os
3
+ import shutil
4
+ import typing
5
+ import PIL
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ from data_utils.convert_ttf_to_sfd import convert_mp
8
+ from data_utils.write_glyph_imgs import write_glyph_imgs_mp
9
+ from data_utils.write_data_to_dirs import create_db
10
+ from data_utils.relax_rep import relax_rep
11
+ from test_few_shot import test_main_model
12
+ from options import get_parser_main_model
13
+
14
+ opts = get_parser_main_model().parse_args()
15
+
16
+ # Config on opts
17
+ # Inference opts
18
+ opts.mode = "test"
19
+ opts.language = "tha"
20
+ opts.char_num = 44
21
+ opts.ref_nshot = 8
22
+ opts.batch_size = 1 # inference rule
23
+ opts.img_size = 64
24
+ opts.max_seq_len = 121
25
+ opts.name_ckpt = ""
26
+ opts.model_path = "./inference_model/950_49452.ckpt"
27
+ opts.ref_char_ids = "0,1,2,3,4,5,6,7"
28
+ opts.dir_res = "./inference"
29
+ opts.data_root = "./inference/vecfont_dataset/"
30
+
31
+ # Data preprocessing opts
32
+ opts.data_path = './inference'
33
+ opts.sfd_path = f'{opts.data_path}/font_sfds'
34
+ opts.ttf_path = f'{opts.data_path}/font_ttfs'
35
+ opts.split = "test"
36
+ opts.debug = True # Save Image On write_glyph_imgs_mp
37
+ opts.output_path = f'{opts.data_path}/vecfont_dataset/'
38
+ opts.phase = 0
39
+ opts.FONT_SIZE = 1
40
+
41
+ opts.streamlit = True
42
+
43
+ # Glypts ID :
44
+ # [(0, 'A'), (1, 'B'), (2, 'C'), (3, 'D'), (4, 'E')]
45
+ # [(5, 'F'), (6, 'G'), (7, 'H'), (8, 'I'), (9, 'J')]
46
+ # [(10, 'K'), (11, 'L'), (12, 'M'), (13, 'N'), (14, 'O')]
47
+ # [(15, 'P'), (16, 'Q'), (17, 'R'), (18, 'S'), (19, 'T')]
48
+ # [(20, 'U'), (21, 'V'), (22, 'W'), (23, 'X'), (24, 'Y')]
49
+ # [(25, 'Z'), (26, 'a'), (27, 'b'), (28, 'c'), (29, 'd')]
50
+ # [(30, 'e'), (31, 'f'), (32, 'g'), (33, 'h'), (34, 'i')]
51
+ # [(35, 'j'), (36, 'k'), (37, 'l'), (38, 'm'), (39, 'n')]
52
+ # [(40, 'o'), (41, 'p'), (42, 'q'), (43, 'r'), (44, 's')]
53
+ # [(45, 't'), (46, 'u'), (47, 'v'), (48, 'w'), (49, 'x')]
54
+ # [(50, 'y'), (51, 'z'), (52, 'ก'), (53, 'ข'), (54, 'ฃ')]
55
+ # [(55, 'ค'), (56, 'ฅ'), (57, 'ฆ'), (58, 'ง'), (59, 'จ')]
56
+ # [(60, 'ฉ'), (61, 'ช'), (62, 'ซ'), (63, 'ฌ'), (64, 'ญ')]
57
+ # [(65, 'ฎ'), (66, 'ฏ'), (67, 'ฐ'), (68, 'ฑ'), (69, 'ฒ')]
58
+ # [(70, 'ณ'), (71, 'ด'), (72, 'ต'), (73, 'ถ'), (74, 'ท')]
59
+ # [(75, 'ธ'), (76, 'น'), (77, 'บ'), (78, 'ป'), (79, 'ผ')]
60
+ # [(80, 'ฝ'), (81, 'พ'), (82, 'ฟ'), (83, 'ภ'), (84, 'ม')]
61
+ # [(85, 'ย'), (86, 'ร'), (87, 'ล'), (88, 'ว'), (89, 'ศ')]
62
+ # [(90, 'ษ'), (91, 'ส'), (92, 'ห'), (93, 'ฬ'), (94, 'อ')]
63
+ # [(95, 'ฮ')]
64
+
65
+ import string
66
+ import pythainlp
67
+
68
+ thai_digits = [*pythainlp.thai_digits]
69
+ thai_characters = [*pythainlp.thai_consonants]
70
+ eng_characters = [*string.ascii_letters]
71
+ thai_floating = [*pythainlp.thai_vowels]
72
+
73
+ directories = [
74
+ "inference",
75
+ "inference/char_set",
76
+ "inference/font_sfds",
77
+ "inference/font_ttfs",
78
+ "inference/vecfont_dataset",
79
+ "inference/font_ttfs/tha/test",
80
+ ]
81
+
82
+
83
+ # Data Preprocessing
84
+ def preprocessing(ttf_file) -> str:
85
+ shutil.rmtree("inference")
86
+ for directory in directories:
87
+ os.makedirs(directory, exist_ok=True)
88
+
89
+ # Save File / Copy File
90
+ if isinstance(ttf_file, memoryview):
91
+ with open(f"{opts.data_path}/font_ttfs/tha/test/0000.ttf", 'wb') as f:
92
+ f.write(ttf_file)
93
+ elif isinstance(ttf_file, str):
94
+ shutil.copy(ttf_file, f"{opts.data_path}/font_ttfs/tha/test/0000.ttf")
95
+
96
+ glypts = sorted(set(thai_characters))
97
+ print("Glypts:",len(glypts))
98
+ print("".join(glypts))
99
+ f = open("inference/char_set/tha.txt", "w")
100
+ f.write("".join(glypts))
101
+ f.close()
102
+
103
+ # Preprocess Pipeline
104
+ convert_mp(opts)
105
+ write_glyph_imgs_mp(opts)
106
+ output_path = os.path.join(opts.output_path, opts.language, opts.split)
107
+ log_path = os.path.join(opts.sfd_path, opts.language, 'log')
108
+ if not os.path.exists(output_path):
109
+ os.makedirs(output_path)
110
+ if not os.path.exists(log_path):
111
+ os.makedirs(log_path)
112
+ create_db(opts, output_path, log_path)
113
+ relax_rep(opts)
114
+
115
+ print("Finished making a data", ttf_file)
116
+ print("Saved at", output_path)
117
+ return output_path
118
+
119
+ def inference_model(n_samples, ref_char_ids, version):
120
+ opts.n_samples = n_samples
121
+ opts.ref_char_ids = ref_char_ids
122
+
123
+ # Select Model
124
+ if version == "TH2TH":
125
+ opts.model_path = "./inference_model/950_49452.ckpt"
126
+ elif version == "ENG2TH":
127
+ opts.model_path = "./inference_model/950_49452.ckpt"
128
+ else:
129
+ raise NotImplementedError
130
+
131
+ return test_main_model(opts)
132
+
133
+ def ttf_to_image(ttf_file, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
134
+ preprocessing(ttf_file) # Make Data
135
+ merge_svg_img = inference_model(n_samples, ref_char_ids, version) # Inference
136
+ return merge_svg_img
137
+
138
+ def main():
139
+ print(opts.mode)
140
+ ttf_to_image("font_sample/SaoChingcha-Regular.otf")
141
+
142
+ if __name__ == "__main__":
143
+ main()
models/image_decoder.py CHANGED
@@ -1,48 +1,48 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
-
6
- class ImageDecoder(nn.Module):
7
- def __init__(self, img_size, input_nc, output_nc, ngf=16, norm_layer=nn.LayerNorm):
8
-
9
- super(ImageDecoder, self).__init__()
10
- n_upsampling = int(math.log(img_size, 2))
11
- ks_list = [3] * (n_upsampling // 3) + [5] * (n_upsampling - n_upsampling // 3)
12
- stride_list = [2] * n_upsampling
13
- decoder = []
14
-
15
- chn_mult = []
16
- for i in range(n_upsampling):
17
- chn_mult.append(2 ** (n_upsampling - i - 1))
18
-
19
- decoder += [nn.ConvTranspose2d(input_nc, chn_mult[0] * ngf,
20
- kernel_size=ks_list[0], stride=stride_list[0],
21
- padding=ks_list[0] // 2, output_padding=stride_list[0]-1),
22
- norm_layer([chn_mult[0] * ngf, 2, 2]),
23
- nn.ReLU(True)]
24
-
25
- for i in range(1, n_upsampling): # add upsampling layers
26
- chn_prev = chn_mult[i - 1] * ngf
27
- chn_next = chn_mult[i] * ngf
28
- decoder += [nn.ConvTranspose2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i], padding=ks_list[i] // 2, output_padding=stride_list[i]-1),
29
- norm_layer([chn_next, 2 ** (i+1) , 2 ** (i+1)]),
30
- nn.ReLU(True)]
31
-
32
- decoder += [nn.Conv2d(chn_mult[-1] * ngf, output_nc, kernel_size=7, padding=7 // 2)]
33
- decoder += [nn.Sigmoid()]
34
- self.decode = nn.Sequential(*decoder)
35
-
36
- def forward(self, latent_feat, trg_char, trg_img=None):
37
- """Standard forward"""
38
- dec_input = torch.cat((latent_feat, trg_char),-1)
39
- dec_input = dec_input.view(dec_input.size(0), dec_input.size(1), 1, 1)
40
- dec_out = self.decode(dec_input)
41
- output = {}
42
- output['gen_imgs'] = dec_out
43
- if trg_img is not None:
44
- output['img_l1loss'] = F.l1_loss(dec_out, trg_img)
45
-
46
- return output
47
-
48
-
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ class ImageDecoder(nn.Module):
7
+ def __init__(self, img_size, input_nc, output_nc, ngf=16, norm_layer=nn.LayerNorm):
8
+
9
+ super(ImageDecoder, self).__init__()
10
+ n_upsampling = int(math.log(img_size, 2))
11
+ ks_list = [3] * (n_upsampling // 3) + [5] * (n_upsampling - n_upsampling // 3)
12
+ stride_list = [2] * n_upsampling
13
+ decoder = []
14
+
15
+ chn_mult = []
16
+ for i in range(n_upsampling):
17
+ chn_mult.append(2 ** (n_upsampling - i - 1))
18
+
19
+ decoder += [nn.ConvTranspose2d(input_nc, chn_mult[0] * ngf,
20
+ kernel_size=ks_list[0], stride=stride_list[0],
21
+ padding=ks_list[0] // 2, output_padding=stride_list[0]-1),
22
+ norm_layer([chn_mult[0] * ngf, 2, 2]),
23
+ nn.ReLU(True)]
24
+
25
+ for i in range(1, n_upsampling): # add upsampling layers
26
+ chn_prev = chn_mult[i - 1] * ngf
27
+ chn_next = chn_mult[i] * ngf
28
+ decoder += [nn.ConvTranspose2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i], padding=ks_list[i] // 2, output_padding=stride_list[i]-1),
29
+ norm_layer([chn_next, 2 ** (i+1) , 2 ** (i+1)]),
30
+ nn.ReLU(True)]
31
+
32
+ decoder += [nn.Conv2d(chn_mult[-1] * ngf, output_nc, kernel_size=7, padding=7 // 2)]
33
+ decoder += [nn.Sigmoid()]
34
+ self.decode = nn.Sequential(*decoder)
35
+
36
+ def forward(self, latent_feat, trg_char, trg_img=None):
37
+ """Standard forward"""
38
+ dec_input = torch.cat((latent_feat, trg_char),-1)
39
+ dec_input = dec_input.view(dec_input.size(0), dec_input.size(1), 1, 1)
40
+ dec_out = self.decode(dec_input)
41
+ output = {}
42
+ output['gen_imgs'] = dec_out
43
+ if trg_img is not None:
44
+ output['img_l1loss'] = F.l1_loss(dec_out, trg_img)
45
+
46
+ return output
47
+
48
+
models/image_encoder.py CHANGED
@@ -1,42 +1,42 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
-
6
- class ImageEncoder(nn.Module):
7
-
8
- def __init__(self, img_size, input_nc, ngf=16, norm_layer=nn.LayerNorm):
9
-
10
- super(ImageEncoder, self).__init__()
11
- n_downsampling = int(math.log(img_size, 2))
12
- ks_list = [5] * (n_downsampling - n_downsampling // 3) + [3] * (n_downsampling // 3)
13
- stride_list = [2] * n_downsampling
14
-
15
- chn_mult = []
16
- for i in range(n_downsampling):
17
- chn_mult.append(2 ** (i + 1))
18
-
19
- encoder = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=7 // 2, bias=True, padding_mode='replicate'),
20
- norm_layer([ngf, 2 ** n_downsampling, 2 ** n_downsampling]),
21
- nn.ReLU(True)]
22
- for i in range(n_downsampling): # add downsampling layers
23
- if i == 0:
24
- chn_prev = ngf
25
- else:
26
- chn_prev = ngf * chn_mult[i - 1]
27
- chn_next = ngf * chn_mult[i]
28
-
29
- encoder += [nn.Conv2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i], padding=ks_list[i] // 2, padding_mode='replicate'),
30
- norm_layer([chn_next, 2 ** (n_downsampling - 1 - i), 2 ** (n_downsampling - 1 - i)]),
31
- nn.ReLU(True)]
32
-
33
- self.encode = nn.Sequential(*encoder)
34
- self.flatten = nn.Flatten()
35
-
36
- def forward(self, input):
37
- """Standard forward"""
38
- ret = self.encode(input)
39
- img_feat = self.flatten(ret)
40
- output = {}
41
- output['img_feat'] = img_feat
42
- return output
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ class ImageEncoder(nn.Module):
7
+
8
+ def __init__(self, img_size, input_nc, ngf=16, norm_layer=nn.LayerNorm):
9
+
10
+ super(ImageEncoder, self).__init__()
11
+ n_downsampling = int(math.log(img_size, 2))
12
+ ks_list = [5] * (n_downsampling - n_downsampling // 3) + [3] * (n_downsampling // 3)
13
+ stride_list = [2] * n_downsampling
14
+
15
+ chn_mult = []
16
+ for i in range(n_downsampling):
17
+ chn_mult.append(2 ** (i + 1))
18
+
19
+ encoder = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=7 // 2, bias=True, padding_mode='replicate'),
20
+ norm_layer([ngf, 2 ** n_downsampling, 2 ** n_downsampling]),
21
+ nn.ReLU(True)]
22
+ for i in range(n_downsampling): # add downsampling layers
23
+ if i == 0:
24
+ chn_prev = ngf
25
+ else:
26
+ chn_prev = ngf * chn_mult[i - 1]
27
+ chn_next = ngf * chn_mult[i]
28
+
29
+ encoder += [nn.Conv2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i], padding=ks_list[i] // 2, padding_mode='replicate'),
30
+ norm_layer([chn_next, 2 ** (n_downsampling - 1 - i), 2 ** (n_downsampling - 1 - i)]),
31
+ nn.ReLU(True)]
32
+
33
+ self.encode = nn.Sequential(*encoder)
34
+ self.flatten = nn.Flatten()
35
+
36
+ def forward(self, input):
37
+ """Standard forward"""
38
+ ret = self.encode(input)
39
+ img_feat = self.flatten(ret)
40
+ output = {}
41
+ output['img_feat'] = img_feat
42
+ return output
models/modality_fusion.py CHANGED
@@ -1,64 +1,64 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
- from options import get_parser_main_model
6
- opts = get_parser_main_model().parse_args()
7
-
8
- def init_weights(m):
9
- for name, param in m.named_parameters():
10
- nn.init.uniform_(param.data, -0.08, 0.08)
11
-
12
-
13
- class ModalityFusion(nn.Module):
14
- def __init__(self, img_size=64, ref_nshot=4, bottleneck_bits=512, ngf=32, seq_latent_dim=512, mode='train'):
15
- super().__init__()
16
- self.mode = mode
17
- self.bottleneck_bits = bottleneck_bits
18
- self.ref_nshot = ref_nshot
19
- self.mode = mode
20
- self.fc_merge = nn.Linear(seq_latent_dim * opts.ref_nshot, 512)
21
- n_downsampling = int(math.log(img_size, 2))
22
- mult_max = 2 ** (n_downsampling)
23
- self.fc_fusion = nn.Linear(ngf * mult_max + seq_latent_dim, opts.bottleneck_bits * 2, bias=True) # the max multiplier for img feat channels is
24
-
25
- def forward(self, seq_feat, img_feat, ref_pad_mask=None):
26
-
27
-
28
- cls_one_pad = torch.ones((1,1,1)).to(seq_feat.device).repeat(seq_feat.size(0),1,1)
29
- ref_pad_mask = torch.cat([cls_one_pad,ref_pad_mask],dim=-1)
30
-
31
- seq_feat = seq_feat * (ref_pad_mask.transpose(1, 2))
32
- seq_feat_ = seq_feat.view(seq_feat.size(0) // self.ref_nshot, self.ref_nshot,seq_feat.size(-2) , seq_feat.size(-1))
33
- seq_feat_ = seq_feat_.transpose(1, 2)
34
- seq_feat_ = seq_feat_.contiguous().view(seq_feat_.size(0), seq_feat_.size(1), seq_feat_.size(2) * seq_feat_.size(3))
35
- seq_feat_ = self.fc_merge(seq_feat_)
36
- seq_feat_cls = seq_feat_[:, 0]
37
-
38
- feat_cat = torch.cat((img_feat, seq_feat_cls),-1)
39
- dist_param = self.fc_fusion(feat_cat)
40
-
41
- output = {}
42
- mu = dist_param[..., :self.bottleneck_bits]
43
- log_sigma = dist_param[..., self.bottleneck_bits:]
44
-
45
- if self.mode == 'train':
46
- # calculate the kl loss and reparamerize latent code
47
- epsilon = torch.randn(*mu.size(), device=mu.device)
48
- z = mu + torch.exp(log_sigma / 2) * epsilon
49
- kl = 0.5 * torch.mean(torch.exp(log_sigma) + torch.square(mu) - 1. - log_sigma)
50
- output['latent'] = z
51
- output['kl_loss'] = kl
52
- seq_feat_[:, 0] = z
53
- latent_feat_seq = seq_feat_
54
-
55
- else:
56
- output['latent'] = mu
57
- output['kl_loss'] = 0.0
58
- seq_feat_[:, 0] = mu
59
- latent_feat_seq = seq_feat_
60
-
61
-
62
- return output, latent_feat_seq
63
-
64
-
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from options import get_parser_main_model
6
+ opts = get_parser_main_model().parse_args()
7
+
8
+ def init_weights(m):
9
+ for name, param in m.named_parameters():
10
+ nn.init.uniform_(param.data, -0.08, 0.08)
11
+
12
+
13
+ class ModalityFusion(nn.Module):
14
+ def __init__(self, img_size=64, ref_nshot=4, bottleneck_bits=512, ngf=32, seq_latent_dim=512, mode='train'):
15
+ super().__init__()
16
+ self.mode = mode
17
+ self.bottleneck_bits = bottleneck_bits
18
+ self.ref_nshot = ref_nshot
19
+ self.mode = mode
20
+ self.fc_merge = nn.Linear(seq_latent_dim * opts.ref_nshot, 512)
21
+ n_downsampling = int(math.log(img_size, 2))
22
+ mult_max = 2 ** (n_downsampling)
23
+ self.fc_fusion = nn.Linear(ngf * mult_max + seq_latent_dim, opts.bottleneck_bits * 2, bias=True) # the max multiplier for img feat channels is
24
+
25
+ def forward(self, seq_feat, img_feat, ref_pad_mask=None):
26
+
27
+
28
+ cls_one_pad = torch.ones((1,1,1)).to(seq_feat.device).repeat(seq_feat.size(0),1,1)
29
+ ref_pad_mask = torch.cat([cls_one_pad,ref_pad_mask],dim=-1)
30
+
31
+ seq_feat = seq_feat * (ref_pad_mask.transpose(1, 2))
32
+ seq_feat_ = seq_feat.view(seq_feat.size(0) // self.ref_nshot, self.ref_nshot,seq_feat.size(-2) , seq_feat.size(-1))
33
+ seq_feat_ = seq_feat_.transpose(1, 2)
34
+ seq_feat_ = seq_feat_.contiguous().view(seq_feat_.size(0), seq_feat_.size(1), seq_feat_.size(2) * seq_feat_.size(3))
35
+ seq_feat_ = self.fc_merge(seq_feat_)
36
+ seq_feat_cls = seq_feat_[:, 0]
37
+
38
+ feat_cat = torch.cat((img_feat, seq_feat_cls),-1)
39
+ dist_param = self.fc_fusion(feat_cat)
40
+
41
+ output = {}
42
+ mu = dist_param[..., :self.bottleneck_bits]
43
+ log_sigma = dist_param[..., self.bottleneck_bits:]
44
+
45
+ if self.mode == 'train':
46
+ # calculate the kl loss and reparamerize latent code
47
+ epsilon = torch.randn(*mu.size(), device=mu.device)
48
+ z = mu + torch.exp(log_sigma / 2) * epsilon
49
+ kl = 0.5 * torch.mean(torch.exp(log_sigma) + torch.square(mu) - 1. - log_sigma)
50
+ output['latent'] = z
51
+ output['kl_loss'] = kl
52
+ seq_feat_[:, 0] = z
53
+ latent_feat_seq = seq_feat_
54
+
55
+ else:
56
+ output['latent'] = mu
57
+ output['kl_loss'] = 0.0
58
+ seq_feat_[:, 0] = mu
59
+ latent_feat_seq = seq_feat_
60
+
61
+
62
+ return output, latent_feat_seq
63
+
64
+
models/model_main.py CHANGED
@@ -1,212 +1,212 @@
1
- from models.image_encoder import ImageEncoder
2
- from models.image_decoder import ImageDecoder
3
- from models.modality_fusion import ModalityFusion
4
- from models.vgg_perceptual_loss import VGGPerceptualLoss
5
- from models.transformers import *
6
- from torch.autograd import Variable
7
-
8
- class ModelMain(nn.Module):
9
-
10
- def __init__(self, opts, mode='train'):
11
- super().__init__()
12
- self.opts = opts
13
- self.img_encoder = ImageEncoder(img_size=opts.img_size, input_nc=opts.ref_nshot, ngf=opts.ngf, norm_layer=nn.LayerNorm)
14
- self.img_decoder = ImageDecoder(img_size=opts.img_size, input_nc=opts.bottleneck_bits + opts.char_num, output_nc=1, ngf=opts.ngf, norm_layer=nn.LayerNorm)
15
- self.vggptlossfunc = VGGPerceptualLoss()
16
- self.modality_fusion = ModalityFusion(img_size=opts.img_size, ref_nshot=opts.ref_nshot, bottleneck_bits=opts.bottleneck_bits, ngf=opts.ngf, mode=opts.mode)
17
- self.transformer_main = Transformer(
18
- input_channels = 1,
19
- input_axis = 2, # number of axis for input data (2 for images, 3 for video)
20
- num_freq_bands = 6, # number of freq bands, with original value (2 * K + 1)
21
- max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is
22
- depth = 6, # depth of net. The shape of the final attention mechanism will be:
23
- # depth * (cross attention -> self_per_cross_attn * self attention)
24
- num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names
25
- latent_dim = opts.dim_seq_latent, # latent dimension
26
- cross_heads = 1, # number of heads for cross attention. paper said 1
27
- latent_heads = 8, # number of heads for latent self attention, 8
28
- cross_dim_head = 64, # number of dimensions per cross attention head
29
- latent_dim_head = 64, # number of dimensions per latent self attention head
30
- num_classes = 1000, # output number of classes
31
- attn_dropout = 0.,
32
- ff_dropout = 0.,
33
- weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram)
34
- fourier_encode_data = True, # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
35
- self_per_cross_attn = 2 # number of self attention blocks per cross attention
36
- )
37
-
38
- self.transformer_seqdec = Transformer_decoder()
39
-
40
-
41
- def forward(self, data, mode='train'):
42
-
43
- imgs, seqs, scalars = self.fetch_data(data, mode)
44
- ref_img, trg_img = imgs
45
- ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux = seqs
46
- trg_char_onehot, trg_cls, trg_seqlen = scalars
47
-
48
- # image encoding
49
- img_encoder_out = self.img_encoder(ref_img)
50
- img_feat = img_encoder_out['img_feat'] # bs, ngf * (2 ** 6)
51
-
52
- # seq encoding
53
- ref_img_ = ref_img.view(ref_img.size(0) * ref_img.size(1), ref_img.size(2), ref_img.size(3)).unsqueeze(-1) # [max_seq_len, n_bs * n_ref, 9]
54
- seq_feat, _ = self.transformer_main(ref_img_, ref_seq_cat, mask=ref_pad_mask) # [n_bs * n_ref, max_seq_len + 1, 9]
55
-
56
- # modality funsion
57
- mf_output, latent_feat_seq = self.modality_fusion(seq_feat, img_feat, ref_pad_mask=ref_pad_mask)
58
- latent_feat_seq = self.transformer_main.att_residual(latent_feat_seq) # [n_bs, max_seq_len + 1, bottleneck_bits]
59
- z = mf_output['latent']
60
- kl_loss = mf_output['kl_loss']
61
-
62
- # image decoding
63
- img_decoder_out = self.img_decoder(z, trg_char_onehot, trg_img)
64
-
65
- ret_dict = {}
66
- loss_dict = {}
67
-
68
- ret_dict['img'] = {}
69
- ret_dict['img']['out'] = img_decoder_out['gen_imgs']
70
- ret_dict['img']['ref'] = ref_img
71
- ret_dict['img']['trg'] = trg_img
72
-
73
- if mode in {'train', 'val'}:
74
- # seq decoding (training or val mode)
75
- tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).cuda().float()
76
- command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
77
- command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)
78
-
79
- total_loss = self.transformer_main.loss(command_logits, args_logits,trg_seq, trg_seqlen, trg_pts_aux)
80
- total_loss_parallel = self.transformer_main.loss(command_logits_2, args_logits_2, trg_seq, trg_seqlen, trg_pts_aux)
81
- vggpt_loss = self.vggptlossfunc(img_decoder_out['gen_imgs'], trg_img)
82
- # loss and output
83
- loss_svg_items = ['total', 'cmd', 'args', 'smt', 'aux']
84
- # for image
85
- loss_dict['img'] = {}
86
- loss_dict['img']['l1'] = img_decoder_out['img_l1loss']
87
- loss_dict['img']['vggpt'] = vggpt_loss['pt_c_loss']
88
- # for latent
89
- loss_dict['kl'] = kl_loss
90
- # for svg
91
- loss_dict['svg'] = {}
92
- loss_dict['svg_para'] = {}
93
- for item in loss_svg_items:
94
- loss_dict['svg'][item] = total_loss[f'loss_{item}']
95
- loss_dict['svg_para'][item] = total_loss_parallel[f'loss_{item}']
96
-
97
- else: # testing (inference)
98
-
99
- trg_len = trg_seq_shifted.size(0)
100
- sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).cuda()
101
-
102
- for t in range(0, trg_len):
103
- tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).cuda().float()
104
- command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
105
- prob_comand = F.softmax(command_logits[:, -1, :], -1)
106
- prob_args = F.softmax(args_logits[:, -1, :], -1)
107
- next_command = torch.argmax(prob_comand, -1).unsqueeze(-1)
108
- next_args = torch.argmax(prob_args, -1)
109
- predict_tmp = torch.cat((next_command, next_args),-1).unsqueeze(1).transpose(0,1)
110
- sampled_svg = torch.cat((sampled_svg, predict_tmp), dim=0)
111
-
112
- sampled_svg = sampled_svg[1:]
113
- cmd2 = sampled_svg[:,:,0].unsqueeze(-1)
114
- arg2 = sampled_svg[:,:,1:]
115
- command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(cmd_logits=cmd2, args_logits=arg2, memory=latent_feat_seq, trg_char=trg_cls)
116
- prob_comand = F.softmax(command_logits_2,-1)
117
- prob_args = F.softmax(args_logits_2,-1)
118
- update_command = torch.argmax(prob_comand,-1).unsqueeze(-1)
119
- update_args = torch.argmax(prob_args,-1)
120
-
121
- sampled_svg_parralel = torch.cat((update_command, update_args),-1).transpose(0,1)
122
-
123
- commands1 = F.one_hot(sampled_svg[:,:,:1].long(), 4).squeeze().transpose(0, 1)
124
- args1 = denumericalize(sampled_svg[:,:,1:]).transpose(0,1)
125
- sampled_svg_1 = torch.cat([commands1.cpu().detach(),args1[:, :, 2:].cpu().detach()],dim =-1)
126
-
127
-
128
- commands2 = F.one_hot(sampled_svg_parralel[:, :, :1].long(), 4).squeeze().transpose(0, 1)
129
- args2 = denumericalize(sampled_svg_parralel[:, :, 1:]).transpose(0,1)
130
- sampled_svg_2 = torch.cat([commands2.cpu().detach(),args2[:, :, 2:].cpu().detach()], dim =-1)
131
-
132
- ret_dict['svg'] = {}
133
- ret_dict['svg']['sampled_1'] = sampled_svg_1
134
- ret_dict['svg']['sampled_2'] = sampled_svg_2
135
- ret_dict['svg']['trg'] = trg_seq_gt
136
-
137
- return ret_dict, loss_dict
138
-
139
- def fetch_data(self, data, mode):
140
-
141
- input_image = data['rendered'] # [bs, opts.char_num, opts.img_size, opts.img_size]
142
- input_sequence = data['sequence'] # [bs, opts.char_num, opts.max_seq_len]
143
- input_seqlen = data['seq_len']
144
- input_seqlen = input_seqlen + 1
145
- input_pts_aux = data['pts_aux']
146
- arg_quant = numericalize(input_sequence[:, :, :, 4:])
147
- cmd_cls = torch.argmax(input_sequence[:, :, :, :4], dim=-1).unsqueeze(-1)
148
- input_sequence = torch.cat([cmd_cls, arg_quant], dim=-1) # 1 + 8 = 9 dimension
149
-
150
- # choose reference classes and target classes
151
-
152
-
153
- if mode == 'train':
154
- ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).cuda()
155
- if opts.ref_nshot == 52: # For ENG to TH
156
- ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
157
- ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
158
- ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1)
159
- elif mode == 'val':
160
- ref_cls = torch.arange(0, self.opts.ref_nshot, 1).cuda().unsqueeze(0).expand(input_image.size(0), -1)
161
- else:
162
- ref_ids = self.opts.ref_char_ids.split(',')
163
- ref_ids = list(map(int, ref_ids))
164
- assert len(ref_ids) == self.opts.ref_nshot
165
- ref_cls = torch.tensor(ref_ids).cuda().unsqueeze(0).expand(self.opts.char_num, -1)
166
-
167
-
168
-
169
- if mode in {'train', 'val'}:
170
- trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).cuda()
171
- if opts.ref_nshot == 52:
172
- trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
173
- else:
174
- trg_cls = torch.arange(0, self.opts.char_num).cuda()
175
- if opts.ref_nshot == 52:
176
- trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
177
- trg_cls = trg_cls.view(self.opts.char_num, 1)
178
- input_image = input_image.expand(self.opts.char_num, -1, -1, -1)
179
- input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1)
180
- input_pts_aux = input_pts_aux.expand(self.opts.char_num, -1, -1, -1)
181
- input_seqlen = input_seqlen.expand(self.opts.char_num, -1, -1)
182
-
183
- ref_img = util_funcs.select_imgs(input_image, ref_cls, self.opts)
184
- # select a target glyph image
185
- trg_img = util_funcs.select_imgs(input_image, trg_cls, self.opts)
186
- # randomly select ref vector glyphs
187
- ref_seq = util_funcs.select_seqs(input_sequence, ref_cls, self.opts, self.opts.dim_seq_short) # [opts.batch_size, opts.ref_nshot, opts.max_seq_len, opts.dim_seq_nmr]
188
- # randomly select a target vector glyph
189
- trg_seq = util_funcs.select_seqs(input_sequence, trg_cls, self.opts, self.opts.dim_seq_short)
190
- trg_seq = trg_seq.squeeze(1)
191
- trg_pts_aux = util_funcs.select_seqs(input_pts_aux, trg_cls, self.opts, opts.n_aux_pts)
192
- trg_pts_aux = trg_pts_aux.squeeze(1)
193
- # the one-hot target char class
194
- trg_char_onehot = util_funcs.trgcls_to_onehot(trg_cls, self.opts)
195
- # shift target sequence
196
- trg_seq_gt = trg_seq.clone().detach()
197
- trg_seq_gt = torch.cat((trg_seq_gt[:, :, :1], trg_seq_gt[:, :, 3:]), -1)
198
- trg_seq = trg_seq.transpose(0, 1)
199
- trg_seq_shifted = util_funcs.shift_right(trg_seq)
200
-
201
- ref_seq_cat = ref_seq.view(ref_seq.size(0) * ref_seq.size(1), ref_seq.size(2), ref_seq.size(3))
202
- ref_seq_cat = ref_seq_cat.transpose(0,1)
203
- ref_seqlen = util_funcs.select_seqlens(input_seqlen, ref_cls, self.opts)
204
- ref_seqlen_cat = ref_seqlen.view(ref_seqlen.size(0) * ref_seqlen.size(1), ref_seqlen.size(2))
205
- ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len) # value = 1 means pos to be masked
206
- for i in range(ref_seqlen_cat.size(0)):
207
- ref_pad_mask[i,:ref_seqlen_cat[i]] = 1
208
- ref_pad_mask = ref_pad_mask.cuda().float().unsqueeze(1)
209
- trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts)
210
- trg_seqlen = trg_seqlen.squeeze()
211
-
212
  return [ref_img, trg_img], [ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux], [trg_char_onehot, trg_cls, trg_seqlen]
 
1
+ from models.image_encoder import ImageEncoder
2
+ from models.image_decoder import ImageDecoder
3
+ from models.modality_fusion import ModalityFusion
4
+ from models.vgg_perceptual_loss import VGGPerceptualLoss
5
+ from models.transformers import *
6
+ from torch.autograd import Variable
7
+
8
+ class ModelMain(nn.Module):
9
+
10
+ def __init__(self, opts, mode='train'):
11
+ super().__init__()
12
+ self.opts = opts
13
+ self.img_encoder = ImageEncoder(img_size=opts.img_size, input_nc=opts.ref_nshot, ngf=opts.ngf, norm_layer=nn.LayerNorm)
14
+ self.img_decoder = ImageDecoder(img_size=opts.img_size, input_nc=opts.bottleneck_bits + opts.char_num, output_nc=1, ngf=opts.ngf, norm_layer=nn.LayerNorm)
15
+ self.vggptlossfunc = VGGPerceptualLoss()
16
+ self.modality_fusion = ModalityFusion(img_size=opts.img_size, ref_nshot=opts.ref_nshot, bottleneck_bits=opts.bottleneck_bits, ngf=opts.ngf, mode=opts.mode)
17
+ self.transformer_main = Transformer(
18
+ input_channels = 1,
19
+ input_axis = 2, # number of axis for input data (2 for images, 3 for video)
20
+ num_freq_bands = 6, # number of freq bands, with original value (2 * K + 1)
21
+ max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is
22
+ depth = 6, # depth of net. The shape of the final attention mechanism will be:
23
+ # depth * (cross attention -> self_per_cross_attn * self attention)
24
+ num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names
25
+ latent_dim = opts.dim_seq_latent, # latent dimension
26
+ cross_heads = 1, # number of heads for cross attention. paper said 1
27
+ latent_heads = 8, # number of heads for latent self attention, 8
28
+ cross_dim_head = 64, # number of dimensions per cross attention head
29
+ latent_dim_head = 64, # number of dimensions per latent self attention head
30
+ num_classes = 1000, # output number of classes
31
+ attn_dropout = 0.,
32
+ ff_dropout = 0.,
33
+ weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram)
34
+ fourier_encode_data = True, # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
35
+ self_per_cross_attn = 2 # number of self attention blocks per cross attention
36
+ )
37
+
38
+ self.transformer_seqdec = Transformer_decoder()
39
+
40
+
41
+ def forward(self, data, mode='train'):
42
+
43
+ imgs, seqs, scalars = self.fetch_data(data, mode)
44
+ ref_img, trg_img = imgs
45
+ ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux = seqs
46
+ trg_char_onehot, trg_cls, trg_seqlen = scalars
47
+
48
+ # image encoding
49
+ img_encoder_out = self.img_encoder(ref_img)
50
+ img_feat = img_encoder_out['img_feat'] # bs, ngf * (2 ** 6)
51
+
52
+ # seq encoding
53
+ ref_img_ = ref_img.view(ref_img.size(0) * ref_img.size(1), ref_img.size(2), ref_img.size(3)).unsqueeze(-1) # [max_seq_len, n_bs * n_ref, 9]
54
+ seq_feat, _ = self.transformer_main(ref_img_, ref_seq_cat, mask=ref_pad_mask) # [n_bs * n_ref, max_seq_len + 1, 9]
55
+
56
+ # modality funsion
57
+ mf_output, latent_feat_seq = self.modality_fusion(seq_feat, img_feat, ref_pad_mask=ref_pad_mask)
58
+ latent_feat_seq = self.transformer_main.att_residual(latent_feat_seq) # [n_bs, max_seq_len + 1, bottleneck_bits]
59
+ z = mf_output['latent']
60
+ kl_loss = mf_output['kl_loss']
61
+
62
+ # image decoding
63
+ img_decoder_out = self.img_decoder(z, trg_char_onehot, trg_img)
64
+
65
+ ret_dict = {}
66
+ loss_dict = {}
67
+
68
+ ret_dict['img'] = {}
69
+ ret_dict['img']['out'] = img_decoder_out['gen_imgs']
70
+ ret_dict['img']['ref'] = ref_img
71
+ ret_dict['img']['trg'] = trg_img
72
+
73
+ if mode in {'train', 'val'}:
74
+ # seq decoding (training or val mode)
75
+ tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).cuda().float()
76
+ command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
77
+ command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)
78
+
79
+ total_loss = self.transformer_main.loss(command_logits, args_logits,trg_seq, trg_seqlen, trg_pts_aux)
80
+ total_loss_parallel = self.transformer_main.loss(command_logits_2, args_logits_2, trg_seq, trg_seqlen, trg_pts_aux)
81
+ vggpt_loss = self.vggptlossfunc(img_decoder_out['gen_imgs'], trg_img)
82
+ # loss and output
83
+ loss_svg_items = ['total', 'cmd', 'args', 'smt', 'aux']
84
+ # for image
85
+ loss_dict['img'] = {}
86
+ loss_dict['img']['l1'] = img_decoder_out['img_l1loss']
87
+ loss_dict['img']['vggpt'] = vggpt_loss['pt_c_loss']
88
+ # for latent
89
+ loss_dict['kl'] = kl_loss
90
+ # for svg
91
+ loss_dict['svg'] = {}
92
+ loss_dict['svg_para'] = {}
93
+ for item in loss_svg_items:
94
+ loss_dict['svg'][item] = total_loss[f'loss_{item}']
95
+ loss_dict['svg_para'][item] = total_loss_parallel[f'loss_{item}']
96
+
97
+ else: # testing (inference)
98
+
99
+ trg_len = trg_seq_shifted.size(0)
100
+ sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).cuda()
101
+
102
+ for t in range(0, trg_len):
103
+ tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).cuda().float()
104
+ command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
105
+ prob_comand = F.softmax(command_logits[:, -1, :], -1)
106
+ prob_args = F.softmax(args_logits[:, -1, :], -1)
107
+ next_command = torch.argmax(prob_comand, -1).unsqueeze(-1)
108
+ next_args = torch.argmax(prob_args, -1)
109
+ predict_tmp = torch.cat((next_command, next_args),-1).unsqueeze(1).transpose(0,1)
110
+ sampled_svg = torch.cat((sampled_svg, predict_tmp), dim=0)
111
+
112
+ sampled_svg = sampled_svg[1:]
113
+ cmd2 = sampled_svg[:,:,0].unsqueeze(-1)
114
+ arg2 = sampled_svg[:,:,1:]
115
+ command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(cmd_logits=cmd2, args_logits=arg2, memory=latent_feat_seq, trg_char=trg_cls)
116
+ prob_comand = F.softmax(command_logits_2,-1)
117
+ prob_args = F.softmax(args_logits_2,-1)
118
+ update_command = torch.argmax(prob_comand,-1).unsqueeze(-1)
119
+ update_args = torch.argmax(prob_args,-1)
120
+
121
+ sampled_svg_parralel = torch.cat((update_command, update_args),-1).transpose(0,1)
122
+
123
+ commands1 = F.one_hot(sampled_svg[:,:,:1].long(), 4).squeeze().transpose(0, 1)
124
+ args1 = denumericalize(sampled_svg[:,:,1:]).transpose(0,1)
125
+ sampled_svg_1 = torch.cat([commands1.cpu().detach(),args1[:, :, 2:].cpu().detach()],dim =-1)
126
+
127
+
128
+ commands2 = F.one_hot(sampled_svg_parralel[:, :, :1].long(), 4).squeeze().transpose(0, 1)
129
+ args2 = denumericalize(sampled_svg_parralel[:, :, 1:]).transpose(0,1)
130
+ sampled_svg_2 = torch.cat([commands2.cpu().detach(),args2[:, :, 2:].cpu().detach()], dim =-1)
131
+
132
+ ret_dict['svg'] = {}
133
+ ret_dict['svg']['sampled_1'] = sampled_svg_1
134
+ ret_dict['svg']['sampled_2'] = sampled_svg_2
135
+ ret_dict['svg']['trg'] = trg_seq_gt
136
+
137
+ return ret_dict, loss_dict
138
+
139
+ def fetch_data(self, data, mode):
140
+
141
+ input_image = data['rendered'] # [bs, opts.char_num, opts.img_size, opts.img_size]
142
+ input_sequence = data['sequence'] # [bs, opts.char_num, opts.max_seq_len]
143
+ input_seqlen = data['seq_len']
144
+ input_seqlen = input_seqlen + 1
145
+ input_pts_aux = data['pts_aux']
146
+ arg_quant = numericalize(input_sequence[:, :, :, 4:])
147
+ cmd_cls = torch.argmax(input_sequence[:, :, :, :4], dim=-1).unsqueeze(-1)
148
+ input_sequence = torch.cat([cmd_cls, arg_quant], dim=-1) # 1 + 8 = 9 dimension
149
+
150
+ # choose reference classes and target classes
151
+
152
+
153
+ if mode == 'train':
154
+ ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).cuda()
155
+ if opts.ref_nshot == 52: # For ENG to TH
156
+ ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
157
+ ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
158
+ ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1)
159
+ elif mode == 'val':
160
+ ref_cls = torch.arange(0, self.opts.ref_nshot, 1).cuda().unsqueeze(0).expand(input_image.size(0), -1)
161
+ else:
162
+ ref_ids = self.opts.ref_char_ids.split(',')
163
+ ref_ids = list(map(int, ref_ids))
164
+ assert len(ref_ids) == self.opts.ref_nshot
165
+ ref_cls = torch.tensor(ref_ids).cuda().unsqueeze(0).expand(self.opts.char_num, -1)
166
+
167
+
168
+
169
+ if mode in {'train', 'val'}:
170
+ trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).cuda()
171
+ if opts.ref_nshot == 52:
172
+ trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
173
+ else:
174
+ trg_cls = torch.arange(0, self.opts.char_num).cuda()
175
+ if opts.ref_nshot == 52:
176
+ trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
177
+ trg_cls = trg_cls.view(self.opts.char_num, 1)
178
+ input_image = input_image.expand(self.opts.char_num, -1, -1, -1)
179
+ input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1)
180
+ input_pts_aux = input_pts_aux.expand(self.opts.char_num, -1, -1, -1)
181
+ input_seqlen = input_seqlen.expand(self.opts.char_num, -1, -1)
182
+
183
+ ref_img = util_funcs.select_imgs(input_image, ref_cls, self.opts)
184
+ # select a target glyph image
185
+ trg_img = util_funcs.select_imgs(input_image, trg_cls, self.opts)
186
+ # randomly select ref vector glyphs
187
+ ref_seq = util_funcs.select_seqs(input_sequence, ref_cls, self.opts, self.opts.dim_seq_short) # [opts.batch_size, opts.ref_nshot, opts.max_seq_len, opts.dim_seq_nmr]
188
+ # randomly select a target vector glyph
189
+ trg_seq = util_funcs.select_seqs(input_sequence, trg_cls, self.opts, self.opts.dim_seq_short)
190
+ trg_seq = trg_seq.squeeze(1)
191
+ trg_pts_aux = util_funcs.select_seqs(input_pts_aux, trg_cls, self.opts, opts.n_aux_pts)
192
+ trg_pts_aux = trg_pts_aux.squeeze(1)
193
+ # the one-hot target char class
194
+ trg_char_onehot = util_funcs.trgcls_to_onehot(trg_cls, self.opts)
195
+ # shift target sequence
196
+ trg_seq_gt = trg_seq.clone().detach()
197
+ trg_seq_gt = torch.cat((trg_seq_gt[:, :, :1], trg_seq_gt[:, :, 3:]), -1)
198
+ trg_seq = trg_seq.transpose(0, 1)
199
+ trg_seq_shifted = util_funcs.shift_right(trg_seq)
200
+
201
+ ref_seq_cat = ref_seq.view(ref_seq.size(0) * ref_seq.size(1), ref_seq.size(2), ref_seq.size(3))
202
+ ref_seq_cat = ref_seq_cat.transpose(0,1)
203
+ ref_seqlen = util_funcs.select_seqlens(input_seqlen, ref_cls, self.opts)
204
+ ref_seqlen_cat = ref_seqlen.view(ref_seqlen.size(0) * ref_seqlen.size(1), ref_seqlen.size(2))
205
+ ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len) # value = 1 means pos to be masked
206
+ for i in range(ref_seqlen_cat.size(0)):
207
+ ref_pad_mask[i,:ref_seqlen_cat[i]] = 1
208
+ ref_pad_mask = ref_pad_mask.cuda().float().unsqueeze(1)
209
+ trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts)
210
+ trg_seqlen = trg_seqlen.squeeze()
211
+
212
  return [ref_img, trg_img], [ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux], [trg_char_onehot, trg_cls, trg_seqlen]
models/pos_enc.py CHANGED
@@ -1,21 +1,21 @@
1
-
2
- class PositionalEncoding(nn.Module):
3
- def __init__(self, d_model, dropout=0.1, max_len=5000):
4
- super(PositionalEncoding, self).__init__()
5
- self.dropout = nn.Dropout(p=dropout)
6
- pe = torch.zeros(max_len, d_model) # [max_len, d_model]
7
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # [max_len, 1]
8
- div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
9
- # [d_model/2]
10
- pe[:, 0::2] = torch.sin(position * div_term) # [max_len, d_model/2]
11
- pe[:, 1::2] = torch.cos(position * div_term)
12
- pe = pe.unsqueeze(0).transpose(0, 1) # 1,51,512 --> [51, 1, d_model]
13
- self.register_buffer('pe', pe)
14
-
15
- def forward(self, x):
16
- """
17
- :param x: [x_len, batch_size, emb_size]
18
- :return: [x_len, batch_size, emb_size]
19
- """
20
- x = x + self.pe[:x.size(0), :] # [x_len, batch_size, d_model]
21
  return self.dropout(x)
 
1
+
2
+ class PositionalEncoding(nn.Module):
3
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
4
+ super(PositionalEncoding, self).__init__()
5
+ self.dropout = nn.Dropout(p=dropout)
6
+ pe = torch.zeros(max_len, d_model) # [max_len, d_model]
7
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # [max_len, 1]
8
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
9
+ # [d_model/2]
10
+ pe[:, 0::2] = torch.sin(position * div_term) # [max_len, d_model/2]
11
+ pe[:, 1::2] = torch.cos(position * div_term)
12
+ pe = pe.unsqueeze(0).transpose(0, 1) # 1,51,512 --> [51, 1, d_model]
13
+ self.register_buffer('pe', pe)
14
+
15
+ def forward(self, x):
16
+ """
17
+ :param x: [x_len, batch_size, emb_size]
18
+ :return: [x_len, batch_size, emb_size]
19
+ """
20
+ x = x + self.pe[:x.size(0), :] # [x_len, batch_size, d_model]
21
  return self.dropout(x)
models/transformers.py CHANGED
@@ -1,711 +1,711 @@
1
- from math import pi, log
2
- from functools import wraps
3
- from multiprocessing import context
4
- from textwrap import indent
5
- import models.util_funcs as util_funcs
6
- import math, copy
7
- import numpy as np
8
- import torch
9
- from torch import nn, einsum
10
- import torch.nn.functional as F
11
- from einops import rearrange, repeat
12
- from einops.layers.torch import Reduce
13
- import pdb
14
- from einops.layers.torch import Rearrange
15
- from options import get_parser_main_model
16
- opts = get_parser_main_model().parse_args()
17
-
18
- class PositionalEncoding(nn.Module):
19
- def __init__(self, d_model, dropout=0.1, max_len=5000):
20
- super(PositionalEncoding, self).__init__()
21
- self.dropout = nn.Dropout(p=dropout)
22
- pe = torch.zeros(max_len, d_model)
23
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
24
- div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
25
- pe[:, 0::2] = torch.sin(position * div_term)
26
- pe[:, 1::2] = torch.cos(position * div_term)
27
- pe = pe.unsqueeze(0).transpose(0, 1)
28
- self.register_buffer('pe', pe)
29
-
30
- def forward(self, x):
31
- """
32
- :param x: [x_len, batch_size, emb_size]
33
- :return: [x_len, batch_size, emb_size]
34
- """
35
- x = x + self.pe[:x.size(0), :].to(x.device)
36
- return self.dropout(x)
37
-
38
- def exists(val):
39
- return val is not None
40
-
41
- def default(val, d):
42
- return val if exists(val) else d
43
-
44
- def cache_fn(f):
45
- cache = dict()
46
- @wraps(f)
47
- def cached_fn(*args, _cache = True, key = None, **kwargs):
48
- if not _cache:
49
- return f(*args, **kwargs)
50
- nonlocal cache
51
- if key in cache:
52
- return cache[key]
53
- result = f(*args, **kwargs)
54
- cache[key] = result
55
- return result
56
- return cached_fn
57
-
58
- def fourier_encode(x, max_freq, num_bands = 4):
59
- '''
60
- x: ([64, 64, 2, 1]) is between [-1,1]
61
- max_feq is 10
62
- num_bands is 6
63
- '''
64
-
65
- x = x.unsqueeze(-1)
66
- device, dtype, orig_x = x.device, x.dtype, x
67
-
68
- scales = torch.linspace(1., max_freq / 2, num_bands, device = device, dtype = dtype) # tensor([1.0000, 1.8000, 2.6000, 3.4000, 4.2000, 5.0000]
69
- scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] # r([[[[1.0000, 1.8000, 2.6000, 3.4000, 4.2000, 5.0000]]]],
70
-
71
- x = x * scales * pi
72
- x = torch.cat([x.sin(), x.cos()], dim = -1)
73
-
74
- x = torch.cat((x, orig_x), dim = -1)
75
- return x
76
-
77
- class PreNorm(nn.Module):
78
- def __init__(self, dim, fn, context_dim = None):
79
- super().__init__()
80
- self.fn = fn
81
- self.norm = nn.LayerNorm(dim)
82
- self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
83
-
84
- def forward(self, x, **kwargs):
85
- x = self.norm(x)
86
-
87
- if exists(self.norm_context):
88
- context = kwargs['context']
89
- normed_context = self.norm_context(context)
90
- kwargs.update(context = normed_context)
91
-
92
- return self.fn(x, **kwargs)
93
-
94
- class GEGLU(nn.Module):
95
- def forward(self, x):
96
- x, gates = x.chunk(2, dim = -1)
97
- return x * F.gelu(gates)
98
-
99
- class FeedForward(nn.Module):
100
- def __init__(self, dim, mult = 4, dropout = 0.):
101
- super().__init__()
102
- self.net = nn.Sequential(
103
- nn.Linear(dim, dim * mult * 2),
104
- GEGLU(),
105
- nn.Linear(dim * mult, dim),
106
- nn.Dropout(dropout)
107
- )
108
-
109
- def forward(self, x):
110
- return self.net(x)
111
-
112
- class Attention(nn.Module):
113
- def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.,cls_conv_dim=None):
114
- super().__init__()
115
- inner_dim = dim_head * heads
116
- context_dim = default(context_dim, query_dim)
117
-
118
- self.scale = dim_head ** -0.5
119
- self.heads = heads
120
-
121
- self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
122
- self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) # 27 to 5012*2 = 1024
123
-
124
- self.dropout = nn.Dropout(dropout)
125
- self.to_out = nn.Linear(inner_dim, query_dim)
126
- #self.cls_dim_adjust = nn.Linear(context_dim,cls_conv_dim)
127
-
128
- def forward(self, x, context = None, mask = None, ref_cls_onehot=None):
129
-
130
- h = self.heads
131
- q = self.to_q(x)
132
- context = default(context, x)
133
- k, v = self.to_kv(context).chunk(2, dim = -1)
134
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
135
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
136
-
137
- if exists(mask):
138
- mask = repeat(mask, 'b j k -> (b h) k j', h = h)
139
- sim.masked_fill(mask == 0, -1e9)
140
-
141
- # attention, what we cannot get enough of
142
- attn = sim.softmax(dim = -1)
143
- attn = self.dropout(attn)
144
- out = einsum('b i j, b j d -> b i d', attn, v)
145
- out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
146
- return self.to_out(out), attn
147
-
148
-
149
- class SVGEmbedding(nn.Module):
150
- def __init__(self):
151
- super().__init__()
152
- self.command_embed = nn.Embedding(4, 512)
153
- self.arg_embed = nn.Embedding(128, 128,padding_idx=0)
154
- self.embed_fcn = nn.Linear(128 * 8, 512)
155
- self.pos_encoding = PositionalEncoding(d_model=opts.hidden_size, max_len=opts.max_seq_len + 1)
156
- self._init_embeddings()
157
-
158
- def _init_embeddings(self):
159
- nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
160
- nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")
161
- nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")
162
-
163
-
164
- def forward(self, commands, args, groups=None):
165
-
166
- S, GN,_ = commands.shape
167
- src = self.command_embed(commands.long()).squeeze() + \
168
- self.embed_fcn(self.arg_embed((args).long()).view(S, GN, -1)) # shift due to -1 PAD_VAL
169
-
170
- src = self.pos_encoding(src)
171
-
172
- return src
173
- class PositionwiseFeedForward(nn.Module):
174
- "Implements FFN equation."
175
-
176
- def __init__(self, d_model, d_ff, dropout):
177
- super(PositionwiseFeedForward, self).__init__()
178
- self.w_1 = nn.Linear(d_model, d_ff)
179
- self.w_2 = nn.Linear(d_ff, d_model)
180
- self.dropout = nn.Dropout(dropout)
181
-
182
- def forward(self, x):
183
- return self.w_2(F.relu(self.dropout(self.w_1(x))))
184
-
185
- class Transformer_decoder(nn.Module):
186
- def __init__(self):
187
- super().__init__()
188
- self.SVG_embedding = SVGEmbedding()
189
- self.command_fcn = nn.Linear(512, 4)
190
- self.args_fcn = nn.Linear(512, 8 * 128)
191
- c = copy.deepcopy
192
- attn = MultiHeadedAttention(h=8, d_model=512, dropout=0.0)
193
- ff = PositionwiseFeedForward(d_model=512, d_ff=1024, dropout=0.0)
194
- self.decoder_layers = clones(DecoderLayer(512, c(attn), c(attn),c(ff), dropout=0.0), 6)
195
- self.decoder_norm = nn.LayerNorm(512)
196
- self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
197
- self.decoder_norm_parallel = nn.LayerNorm(512)
198
- self.cls_embedding = nn.Embedding(52,512)
199
- self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
200
-
201
- def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
202
-
203
- memory = memory.unsqueeze(1)
204
- commands = x[:, :, :1]
205
- args = x[:, :, 1:]
206
- x = self.SVG_embedding(commands, args).transpose(0,1)
207
- trg_char = trg_char.long()
208
- trg_char = self.cls_embedding(trg_char)
209
- x[:, 0:1, :] = trg_char
210
- tgt_mask = tgt_mask.squeeze()
211
- for layer in self.decoder_layers:
212
- x,attn = layer(x, memory, src_mask, tgt_mask)
213
- out = self.decoder_norm(x)
214
- N, S, _ = out.shape
215
- cmd_logits = self.command_fcn(out)
216
- args_logits = self.args_fcn(out) # shape: bs, max_len, 8, 256
217
- args_logits = args_logits.reshape(N, S, 8, 128)
218
- return cmd_logits,args_logits,attn
219
-
220
- def parallel_decoder(self, cmd_logits, args_logits, memory, trg_char):
221
-
222
- memory = memory.unsqueeze(1)
223
- cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
224
- [1, 1, 0., 0., 0., 0., 1., 1.],
225
- [1, 1, 0., 0., 0., 0., 1., 1.],
226
- [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)
227
- if opts.mode == 'train':
228
- cmd2 = torch.argmax(cmd_logits, -1).unsqueeze(-1).transpose(0, 1)
229
- arg2 = torch.argmax(args_logits, -1).transpose(0, 1)
230
-
231
- cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0,1).unsqueeze(-1).to(cmd2.device)
232
- cmd2 = cmd2 * cmd2paddingmask
233
- args_mask = torch.matmul(F.one_hot(cmd2.long(),4).float(), cmd_args_mask).transpose(-1,-2).squeeze(-1)
234
- arg2 = arg2 * args_mask
235
-
236
- x = self.SVG_embedding(cmd2, arg2).transpose(0, 1)
237
- else:
238
- cmd2 = cmd_logits
239
- arg2 = args_logits
240
-
241
- cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0, 1).unsqueeze(-1).to(cmd2.device)
242
- cmd2 = cmd2 * cmd2paddingmask
243
- args_mask = torch.matmul(F.one_hot(cmd2.long(),4).float(), cmd_args_mask).transpose(-1, -2).squeeze(-1)
244
- arg2 = arg2 * args_mask
245
-
246
- x = self.SVG_embedding(cmd2, arg2).transpose(0,1)
247
-
248
- S = x.size(1)
249
- B = x.size(0)
250
- tgt_mask = torch.ones(S,S).to(x.device).unsqueeze(0).repeat(B, 1, 1)
251
- cmd2paddingmask = cmd2paddingmask.transpose(0, 1).transpose(-1, -2)
252
- tgt_mask = tgt_mask * cmd2paddingmask
253
-
254
- trg_char = trg_char.long()
255
- trg_char = self.cls_embedding(trg_char)
256
-
257
- x = torch.cat([trg_char, x],1)
258
- x[:, 0:1, :] = trg_char
259
- x = x[:,:opts.max_seq_len,:]
260
- tgt_mask = tgt_mask #*tri
261
- for layer in self.decoder_layers_parallel:
262
- x, attn = layer(x, memory, src_mask=None, tgt_mask=tgt_mask)
263
- out = self.decoder_norm_parallel(x)
264
-
265
- N, S, _ = out.shape
266
- cmd_logits = self.command_fcn(out)
267
- args_logits = self.args_fcn(out)
268
- args_logits = args_logits.reshape(N, S, 8, 128)
269
-
270
- return cmd_logits, args_logits
271
-
272
-
273
- def _get_key_padding_mask(commands, seq_dim=0):
274
- """
275
- Args:
276
- commands: Shape [S, ...]
277
- """
278
- lens =[]
279
- with torch.no_grad():
280
- key_padding_mask = (commands == 0).cumsum(dim=seq_dim) > 0
281
- commands=commands.transpose(0,1).squeeze(-1) #bs, opts.max_seq_len
282
- for i in range(commands.size(0)):
283
- try:
284
- seqi = commands[i]#blue opts.max_seq_len
285
- index = torch.where(seqi==0)[0][0]
286
-
287
- except:
288
- index=opts.max_seq_len
289
-
290
- lens.append(index)
291
- lens = torch.tensor(lens)+1#blue b
292
- seqlen_mask = util_funcs.sequence_mask(lens, opts.max_seq_len)#blue b,opts.max_seq_len
293
- return seqlen_mask
294
-
295
- class Transformer(nn.Module):
296
- def __init__(
297
- self,
298
- *,
299
- num_freq_bands,
300
- depth,
301
- max_freq,
302
- input_channels = 1,
303
- input_axis = 2,
304
- num_latents = 512,
305
- latent_dim = 512,
306
- cross_heads = 1,
307
- latent_heads = 8,
308
- cross_dim_head = 64,
309
- latent_dim_head = 64,
310
- num_classes = 1000,
311
- attn_dropout = 0.,
312
- ff_dropout = 0.,
313
- weight_tie_layers = False,
314
- fourier_encode_data = True,
315
- self_per_cross_attn = 2,
316
- final_classifier_head = True
317
- ):
318
- """The shape of the final attention mechanism will be:
319
- depth * (cross attention -> self_per_cross_attn * self attention)
320
-
321
- Args:
322
- num_freq_bands: Number of freq bands, with original value (2 * K + 1)
323
- depth: Depth of net.
324
- max_freq: Maximum frequency, hyperparameter depending on how
325
- fine the data is.
326
- freq_base: Base for the frequency
327
- input_channels: Number of channels for each token of the input.
328
- input_axis: Number of axes for input data (2 for images, 3 for video)
329
- num_latents: Number of latents, or induced set points, or centroids.
330
- Different papers giving it different names.
331
- latent_dim: Latent dimension.
332
- cross_heads: Number of heads for cross attention. Paper said 1.
333
- latent_heads: Number of heads for latent self attention, 8.
334
- cross_dim_head: Number of dimensions per cross attention head.
335
- latent_dim_head: Number of dimensions per latent self attention head.
336
- num_classes: Output number of classes.
337
- attn_dropout: Attention dropout
338
- ff_dropout: Feedforward dropout
339
- weight_tie_layers: Whether to weight tie layers (optional).
340
- fourier_encode_data: Whether to auto-fourier encode the data, using
341
- the input_axis given. defaults to True, but can be turned off
342
- if you are fourier encoding the data yourself.
343
- self_per_cross_attn: Number of self attention blocks per cross attn.
344
- final_classifier_head: mean pool and project embeddings to number of classes (num_classes) at the end
345
- """
346
- super().__init__()
347
- self.input_axis = input_axis
348
- self.max_freq = max_freq
349
- self.num_freq_bands = num_freq_bands
350
-
351
- self.fourier_encode_data = fourier_encode_data
352
- fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0 # 26
353
- input_dim = fourier_channels + input_channels
354
-
355
- self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
356
-
357
- get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, dropout=attn_dropout), context_dim=input_dim)
358
- get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
359
- get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, dropout=attn_dropout))
360
- get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
361
-
362
- get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
363
-
364
-
365
- #self_per_cross_attn=1
366
- self.layers = nn.ModuleList([])
367
- for i in range(depth):
368
- should_cache = i > 0 and weight_tie_layers
369
- cache_args = {'_cache': should_cache}
370
-
371
- self_attns = nn.ModuleList([])
372
-
373
- for block_ind in range(self_per_cross_attn): #BUG 之前是2 self_per_cross_attn
374
- self_attns.append(nn.ModuleList([
375
- get_latent_attn(**cache_args, key = block_ind),
376
- get_latent_ff(**cache_args, key = block_ind)
377
- ]))
378
-
379
- self.layers.append(nn.ModuleList([
380
- get_cross_attn(**cache_args),
381
- get_cross_ff(**cache_args),
382
- self_attns
383
- ]))
384
-
385
-
386
- get_cross_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
387
- get_cross_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
388
- get_latent_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
389
- get_latent_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
390
-
391
- get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2 = map(cache_fn, (get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2))
392
-
393
- self.layers_cnnsvg = nn.ModuleList([])
394
- for i in range(1):
395
- should_cache = i > 0 and weight_tie_layers
396
- cache_args = {'_cache': should_cache}
397
-
398
- self_attns2 = nn.ModuleList([])
399
-
400
- for block_ind in range(self_per_cross_attn):
401
- self_attns2.append(nn.ModuleList([
402
- get_latent_attn2(**cache_args, key = block_ind),
403
- get_latent_ff2(**cache_args, key = block_ind)
404
- ]))
405
-
406
- self.layers_cnnsvg.append(nn.ModuleList([
407
- get_cross_attn2(**cache_args),
408
- get_cross_ff2(**cache_args),
409
- self_attns2
410
- ]))
411
-
412
- self.to_logits = nn.Sequential(
413
- Reduce('b n d -> b d', 'mean'),
414
- nn.LayerNorm(latent_dim),
415
- nn.Linear(latent_dim, num_classes)
416
- ) if final_classifier_head else nn.Identity()
417
- self.pre_lstm_fc = nn.Linear(10,opts.hidden_size)
418
- self.posr = PositionalEncoding(d_model=opts.hidden_size,max_len=opts.max_seq_len)
419
-
420
- patch_height = 2
421
- patch_width = 2
422
- patch_dim = 1 * patch_height * patch_width
423
- self.to_patch_embedding = nn.Sequential(
424
- Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
425
- nn.Linear(patch_dim, 16),
426
- )
427
-
428
- self.SVG_embedding = SVGEmbedding()
429
- self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
430
-
431
- def forward(self, data, seq, ref_cls_onehot=None, mask=None, return_embeddings=True):
432
-
433
- b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
434
- assert len(axis) == self.input_axis, 'input data must have the right number of axis' # img is 2
435
- x = seq
436
- commands=x[:, :, :1]
437
- args=x[:, :, 1:]
438
- x = self.SVG_embedding(commands, args).transpose(0,1)
439
- cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = x.size(0))
440
- x = torch.cat([cls_tokens,x],dim = 1)
441
- cls_one_pad = torch.ones((1,1,1)).to(x.device).repeat(x.size(0),1,1)
442
- mask = torch.cat([cls_one_pad,mask],dim=-1)
443
- self_atten = []
444
- for cross_attn, cross_ff, self_attns in self.layers:
445
- for self_attn, self_ff in self_attns:
446
- x_,atten = self_attn(x,mask=mask)
447
- x = x_ + x
448
- self_atten.append(atten)
449
- x = self_ff(x) + x
450
- x = x + torch.randn_like(x) # add a perturbation
451
- return x, self_atten
452
-
453
- def att_residual(self, x, mask=None):
454
-
455
- for cross_attn, cross_ff, self_attns in self.layers_cnnsvg:
456
- for self_attn, self_ff in self_attns:
457
- x_, atten = self_attn(x)
458
- x = x_ + x
459
- x = self_ff(x) + x
460
- return x
461
-
462
-
463
-
464
- def loss(self, cmd_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux):
465
- '''
466
- Inputs:
467
- cmd_logits: [b, 51, 4]
468
- args_logits: [b, 51, 6]
469
- '''
470
- cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
471
- [1, 1, 0., 0., 0., 0., 1., 1.],
472
- [1, 1, 0., 0., 0., 0., 1., 1.],
473
- [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)
474
-
475
- tgt_commands = trg_seq[:,:,:1].transpose(0,1)
476
- tgt_args = trg_seq[:,:,1:].transpose(0,1)
477
-
478
- seqlen_mask = util_funcs.sequence_mask(trg_seqlen, opts.max_seq_len).unsqueeze(-1)
479
- seqlen_mask2 = seqlen_mask.repeat(1,1,4)# NOTE b,501,4
480
- seqlen_mask4 = seqlen_mask.repeat(1,1,8)
481
- seqlen_mask3 = seqlen_mask.unsqueeze(-1).repeat(1,1,8,128)
482
-
483
-
484
- tgt_commands_onehot = F.one_hot(tgt_commands, 4)
485
- tgt_args_onehot = F.one_hot(tgt_args, 128)
486
-
487
- args_mask = torch.matmul(tgt_commands_onehot.float(),cmd_args_mask).squeeze()
488
-
489
-
490
- loss_cmd = torch.sum(- tgt_commands_onehot.squeeze() * F.log_softmax(cmd_logits, -1), -1)
491
- loss_cmd = torch.mul(loss_cmd, seqlen_mask.squeeze())
492
- loss_cmd = torch.mean(torch.sum(loss_cmd/trg_seqlen.unsqueeze(-1),-1))
493
-
494
- loss_args = (torch.sum(-tgt_args_onehot*F.log_softmax(args_logits,-1),-1)*seqlen_mask4*args_mask)
495
-
496
- loss_args = torch.mean(loss_args,dim=-1,keepdim=False)
497
- loss_args = torch.mean(torch.sum(loss_args/trg_seqlen.unsqueeze(-1),-1))
498
-
499
- SE_mask = torch.Tensor([[1, 1],
500
- [0, 0],
501
- [1, 1],
502
- [1, 1]]).to(cmd_logits.device)
503
-
504
- SE_args_mask = torch.matmul(tgt_commands_onehot.float(),SE_mask).squeeze().unsqueeze(-1)
505
-
506
-
507
- args_prob = F.softmax(args_logits, -1)
508
- args_end = args_prob[:,:,6:]
509
- args_end_shifted = torch.cat((torch.zeros(args_end.size(0),1,args_end.size(2),args_end.size(3)).to(args_end.device),args_end),1)
510
- args_end_shifted = args_end_shifted[:,:opts.max_seq_len,:,:]
511
- args_end_shifted = args_end_shifted*SE_args_mask + args_end*(1-SE_args_mask)
512
-
513
- args_start = args_prob[:,:,:2]
514
-
515
- seqlen_mask5 = util_funcs.sequence_mask(trg_seqlen-1, opts.max_seq_len).unsqueeze(-1)
516
- seqlen_mask5 = seqlen_mask5.repeat(1,1,2)
517
-
518
- smooth_constrained = torch.sum(torch.pow((args_end_shifted - args_start), 2), -1) * seqlen_mask5
519
- smooth_constrained = torch.mean(smooth_constrained, dim=-1, keepdim=False)
520
- smooth_constrained = torch.mean(torch.sum(smooth_constrained / (trg_seqlen - 1).unsqueeze(-1), -1))
521
-
522
- args_prob2 = F.softmax(args_logits / 0.1, -1)
523
-
524
- c = torch.argmax(args_prob2,-1).unsqueeze(-1).float() - args_prob2.detach()
525
- p_argmax = args_prob2 + c
526
- p_argmax = torch.mean(p_argmax,-1)
527
- control_pts = denumericalize(p_argmax)
528
-
529
- p0 = control_pts[:,:,:2]
530
- p1 = control_pts[:,:,2:4]
531
- p2 = control_pts[:,:,4:6]
532
- p3 = control_pts[:,:,6:8]
533
-
534
- line_mask = (tgt_commands==2).float() + (tgt_commands==1).float()
535
- curve_mask = (tgt_commands==3).float()
536
-
537
- t=0.25
538
- aux_pts_line = p0 + t*(p3-p0)
539
- for t in [0.5,0.75]:
540
- coord_t = p0 + t*(p3-p0)
541
- aux_pts_line = torch.cat((aux_pts_line,coord_t),-1)
542
- aux_pts_line = aux_pts_line*line_mask
543
-
544
- t=0.25
545
- aux_pts_curve = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
546
- for t in [0.5, 0.75]:
547
- coord_t = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
548
- aux_pts_curve = torch.cat((aux_pts_curve,coord_t),-1)
549
- aux_pts_curve = aux_pts_curve * curve_mask
550
-
551
-
552
- aux_pts_predict = aux_pts_curve + aux_pts_line
553
- seqlen_mask_aux = util_funcs.sequence_mask(trg_seqlen - 1, opts.max_seq_len).unsqueeze(-1)
554
- aux_pts_loss = torch.pow((aux_pts_predict - trg_pts_aux), 2) * seqlen_mask_aux
555
-
556
- loss_aux = torch.mean(aux_pts_loss, dim=-1, keepdim=False)
557
- loss_aux = torch.mean(torch.sum(loss_aux / trg_seqlen.unsqueeze(-1), -1))
558
-
559
-
560
- loss = opts.loss_w_cmd * loss_cmd + opts.loss_w_args * loss_args + opts.loss_w_aux * loss_aux + opts.loss_w_smt * smooth_constrained
561
-
562
- svg_losses = {}
563
- svg_losses['loss_total'] = loss
564
- svg_losses["loss_cmd"] = loss_cmd
565
- svg_losses["loss_args"] = loss_args
566
- svg_losses["loss_smt"] = smooth_constrained
567
- svg_losses["loss_aux"] = loss_aux
568
-
569
- return svg_losses
570
-
571
- class DecoderLayer(nn.Module):
572
- "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
573
-
574
- def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
575
- super(DecoderLayer, self).__init__()
576
- self.size = size
577
- self.self_attn = self_attn
578
- self.src_attn = src_attn
579
- self.feed_forward = feed_forward
580
- self.sublayer = clones(SublayerConnection(size, dropout), 3)
581
-
582
- def forward(self, x, memory, src_mask, tgt_mask):
583
- "Follow Figure 1 (right) for connections."
584
- m = memory
585
- x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
586
- x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
587
- attn = self.self_attn.attn
588
- return self.sublayer[2](x, self.feed_forward),attn
589
-
590
- def subsequent_mask(size):
591
- "Mask out subsequent positions."
592
- attn_shape = (1, size, size)
593
- subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
594
- return torch.from_numpy(subsequent_mask) == 0
595
-
596
- def numericalize(cmd, n=128):
597
- """NOTE: shall only be called after normalization"""
598
- # assert np.max(cmd.origin) <= 1.0 and np.min(cmd.origin) >= -1.0
599
- cmd = (cmd / 30 * n).round().clip(min=0, max=n-1).int()
600
- return cmd
601
-
602
- def denumericalize(cmd, n=128):
603
- cmd = cmd / n * 30
604
- return cmd
605
-
606
- def attention(query, key, value, mask=None, trg_tri_mask=None,dropout=None, posr=None):
607
- "Compute 'Scaled Dot Product Attention'"
608
- d_k = query.size(-1)
609
- scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
610
-
611
- if posr is not None:
612
- posr = posr.unsqueeze(1)
613
- scores = scores + posr
614
-
615
- if mask is not None:
616
- try:
617
- scores = scores.masked_fill(mask == 0, -1e9) # note mask: b,1,501,501 scores: b, head, 501,501
618
- except Exception as e:
619
- print("Shape: ",scores.shape)
620
- print("Error: ",e)
621
- import pdb; pdb.set_trace()
622
-
623
- if trg_tri_mask is not None:
624
- scores = scores.masked_fill(trg_tri_mask == 0, -1e9)
625
-
626
- p_attn = F.softmax(scores, dim=-1)
627
-
628
- if dropout is not None:
629
- p_attn = dropout(p_attn)
630
-
631
- return torch.matmul(p_attn, value), p_attn
632
-
633
-
634
- class MultiHeadedAttention(nn.Module):
635
- def __init__(self, h, d_model, dropout):
636
- "Take in model size and number of heads."
637
- super(MultiHeadedAttention, self).__init__()
638
- assert d_model % h == 0
639
- # We assume d_v always equals d_k
640
- self.d_k = d_model // h #32
641
- self.h = h #8
642
- self.linears = clones(nn.Linear(d_model, d_model), 4)
643
- self.attn = None
644
- self.dropout = nn.Dropout(p=dropout)
645
-
646
- def forward(self, query, key, value, mask=None,trg_tri_mask=None, posr=None):
647
- "Implements Figure 2"
648
-
649
- if mask is not None:
650
- # Same mask applied to all h heads.
651
- mask = mask.unsqueeze(1)
652
- nbatches = query.size(0) #16
653
-
654
- query, key, value = \
655
- [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
656
- for l, x in zip(self.linears, (query, key, value))]
657
-
658
- x, self.attn = attention(query, key, value, mask=mask,trg_tri_mask=trg_tri_mask,
659
- dropout=self.dropout, posr=posr)
660
-
661
- x = x.transpose(1, 2).contiguous() \
662
- .view(nbatches, -1, self.h * self.d_k)
663
-
664
- return self.linears[-1](x)
665
-
666
- def clones(module, N):
667
- "Produce N identical layers."
668
- return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
669
-
670
- class SublayerConnection(nn.Module):
671
- """
672
- A residual connection followed by a layer norm.
673
- Note for code simplicity the norm is first as opposed to last.
674
- """
675
-
676
- def __init__(self, size, dropout):
677
- super(SublayerConnection, self).__init__()
678
- self.norm = nn.LayerNorm(size)
679
- self.dropout = nn.Dropout(dropout)
680
-
681
- def forward(self, x, sublayer):
682
- "Apply residual connection to any sublayer with the same size."
683
- x_norm=self.norm(x)
684
- return x + self.dropout(sublayer(x_norm))#+ self.augs(x_norm)
685
-
686
-
687
- if __name__ == '__main__':
688
- model = Transformer(
689
- input_channels = 1, # number of channels for each token of the input
690
- input_axis = 2, # number of axis for input data (2 for images, 3 for video)
691
- num_freq_bands = 6, # number of freq bands, with original value (2 * K + 1)
692
- max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is
693
- depth = 6, # depth of net. The shape of the final attention mechanism will be:
694
- # depth * (cross attention -> self_per_cross_attn * self attention)
695
- num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names
696
- latent_dim = 512, # latent dimension
697
- cross_heads = 1, # number of heads for cross attention. paper said 1
698
- latent_heads = 8, # number of heads for latent self attention, 8
699
- cross_dim_head = 64, # number of dimensions per cross attention head
700
- latent_dim_head = 64, # number of dimensions per latent self attention head
701
- num_classes = 1000, # output number of classes
702
- attn_dropout = 0.,
703
- ff_dropout = 0.,
704
- weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram)
705
- fourier_encode_data = True, # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
706
- self_per_cross_attn = 2 # number of self attention blocks per cross attention
707
- )
708
-
709
- img = torch.randn(1, 224, 224, 3) # 1 imagenet image, pixelized
710
-
711
  model(img) # (1, 1000)
 
1
+ from math import pi, log
2
+ from functools import wraps
3
+ from multiprocessing import context
4
+ from textwrap import indent
5
+ import models.util_funcs as util_funcs
6
+ import math, copy
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn, einsum
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from einops.layers.torch import Reduce
13
+ import pdb
14
+ from einops.layers.torch import Rearrange
15
+ from options import get_parser_main_model
16
+ opts = get_parser_main_model().parse_args()
17
+
18
+ class PositionalEncoding(nn.Module):
19
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
20
+ super(PositionalEncoding, self).__init__()
21
+ self.dropout = nn.Dropout(p=dropout)
22
+ pe = torch.zeros(max_len, d_model)
23
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
24
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
25
+ pe[:, 0::2] = torch.sin(position * div_term)
26
+ pe[:, 1::2] = torch.cos(position * div_term)
27
+ pe = pe.unsqueeze(0).transpose(0, 1)
28
+ self.register_buffer('pe', pe)
29
+
30
+ def forward(self, x):
31
+ """
32
+ :param x: [x_len, batch_size, emb_size]
33
+ :return: [x_len, batch_size, emb_size]
34
+ """
35
+ x = x + self.pe[:x.size(0), :].to(x.device)
36
+ return self.dropout(x)
37
+
38
+ def exists(val):
39
+ return val is not None
40
+
41
+ def default(val, d):
42
+ return val if exists(val) else d
43
+
44
+ def cache_fn(f):
45
+ cache = dict()
46
+ @wraps(f)
47
+ def cached_fn(*args, _cache = True, key = None, **kwargs):
48
+ if not _cache:
49
+ return f(*args, **kwargs)
50
+ nonlocal cache
51
+ if key in cache:
52
+ return cache[key]
53
+ result = f(*args, **kwargs)
54
+ cache[key] = result
55
+ return result
56
+ return cached_fn
57
+
58
+ def fourier_encode(x, max_freq, num_bands = 4):
59
+ '''
60
+ x: ([64, 64, 2, 1]) is between [-1,1]
61
+ max_feq is 10
62
+ num_bands is 6
63
+ '''
64
+
65
+ x = x.unsqueeze(-1)
66
+ device, dtype, orig_x = x.device, x.dtype, x
67
+
68
+ scales = torch.linspace(1., max_freq / 2, num_bands, device = device, dtype = dtype) # tensor([1.0000, 1.8000, 2.6000, 3.4000, 4.2000, 5.0000]
69
+ scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] # r([[[[1.0000, 1.8000, 2.6000, 3.4000, 4.2000, 5.0000]]]],
70
+
71
+ x = x * scales * pi
72
+ x = torch.cat([x.sin(), x.cos()], dim = -1)
73
+
74
+ x = torch.cat((x, orig_x), dim = -1)
75
+ return x
76
+
77
+ class PreNorm(nn.Module):
78
+ def __init__(self, dim, fn, context_dim = None):
79
+ super().__init__()
80
+ self.fn = fn
81
+ self.norm = nn.LayerNorm(dim)
82
+ self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
83
+
84
+ def forward(self, x, **kwargs):
85
+ x = self.norm(x)
86
+
87
+ if exists(self.norm_context):
88
+ context = kwargs['context']
89
+ normed_context = self.norm_context(context)
90
+ kwargs.update(context = normed_context)
91
+
92
+ return self.fn(x, **kwargs)
93
+
94
+ class GEGLU(nn.Module):
95
+ def forward(self, x):
96
+ x, gates = x.chunk(2, dim = -1)
97
+ return x * F.gelu(gates)
98
+
99
+ class FeedForward(nn.Module):
100
+ def __init__(self, dim, mult = 4, dropout = 0.):
101
+ super().__init__()
102
+ self.net = nn.Sequential(
103
+ nn.Linear(dim, dim * mult * 2),
104
+ GEGLU(),
105
+ nn.Linear(dim * mult, dim),
106
+ nn.Dropout(dropout)
107
+ )
108
+
109
+ def forward(self, x):
110
+ return self.net(x)
111
+
112
+ class Attention(nn.Module):
113
+ def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.,cls_conv_dim=None):
114
+ super().__init__()
115
+ inner_dim = dim_head * heads
116
+ context_dim = default(context_dim, query_dim)
117
+
118
+ self.scale = dim_head ** -0.5
119
+ self.heads = heads
120
+
121
+ self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
122
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) # 27 to 5012*2 = 1024
123
+
124
+ self.dropout = nn.Dropout(dropout)
125
+ self.to_out = nn.Linear(inner_dim, query_dim)
126
+ #self.cls_dim_adjust = nn.Linear(context_dim,cls_conv_dim)
127
+
128
+ def forward(self, x, context = None, mask = None, ref_cls_onehot=None):
129
+
130
+ h = self.heads
131
+ q = self.to_q(x)
132
+ context = default(context, x)
133
+ k, v = self.to_kv(context).chunk(2, dim = -1)
134
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
135
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
136
+
137
+ if exists(mask):
138
+ mask = repeat(mask, 'b j k -> (b h) k j', h = h)
139
+ sim.masked_fill(mask == 0, -1e9)
140
+
141
+ # attention, what we cannot get enough of
142
+ attn = sim.softmax(dim = -1)
143
+ attn = self.dropout(attn)
144
+ out = einsum('b i j, b j d -> b i d', attn, v)
145
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
146
+ return self.to_out(out), attn
147
+
148
+
149
+ class SVGEmbedding(nn.Module):
150
+ def __init__(self):
151
+ super().__init__()
152
+ self.command_embed = nn.Embedding(4, 512)
153
+ self.arg_embed = nn.Embedding(128, 128,padding_idx=0)
154
+ self.embed_fcn = nn.Linear(128 * 8, 512)
155
+ self.pos_encoding = PositionalEncoding(d_model=opts.hidden_size, max_len=opts.max_seq_len + 1)
156
+ self._init_embeddings()
157
+
158
+ def _init_embeddings(self):
159
+ nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
160
+ nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")
161
+ nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")
162
+
163
+
164
+ def forward(self, commands, args, groups=None):
165
+
166
+ S, GN,_ = commands.shape
167
+ src = self.command_embed(commands.long()).squeeze() + \
168
+ self.embed_fcn(self.arg_embed((args).long()).view(S, GN, -1)) # shift due to -1 PAD_VAL
169
+
170
+ src = self.pos_encoding(src)
171
+
172
+ return src
173
+ class PositionwiseFeedForward(nn.Module):
174
+ "Implements FFN equation."
175
+
176
+ def __init__(self, d_model, d_ff, dropout):
177
+ super(PositionwiseFeedForward, self).__init__()
178
+ self.w_1 = nn.Linear(d_model, d_ff)
179
+ self.w_2 = nn.Linear(d_ff, d_model)
180
+ self.dropout = nn.Dropout(dropout)
181
+
182
+ def forward(self, x):
183
+ return self.w_2(F.relu(self.dropout(self.w_1(x))))
184
+
185
+ class Transformer_decoder(nn.Module):
186
+ def __init__(self):
187
+ super().__init__()
188
+ self.SVG_embedding = SVGEmbedding()
189
+ self.command_fcn = nn.Linear(512, 4)
190
+ self.args_fcn = nn.Linear(512, 8 * 128)
191
+ c = copy.deepcopy
192
+ attn = MultiHeadedAttention(h=8, d_model=512, dropout=0.0)
193
+ ff = PositionwiseFeedForward(d_model=512, d_ff=1024, dropout=0.0)
194
+ self.decoder_layers = clones(DecoderLayer(512, c(attn), c(attn),c(ff), dropout=0.0), 6)
195
+ self.decoder_norm = nn.LayerNorm(512)
196
+ self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
197
+ self.decoder_norm_parallel = nn.LayerNorm(512)
198
+ self.cls_embedding = nn.Embedding(52,512)
199
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
200
+
201
+ def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
202
+
203
+ memory = memory.unsqueeze(1)
204
+ commands = x[:, :, :1]
205
+ args = x[:, :, 1:]
206
+ x = self.SVG_embedding(commands, args).transpose(0,1)
207
+ trg_char = trg_char.long()
208
+ trg_char = self.cls_embedding(trg_char)
209
+ x[:, 0:1, :] = trg_char
210
+ tgt_mask = tgt_mask.squeeze()
211
+ for layer in self.decoder_layers:
212
+ x,attn = layer(x, memory, src_mask, tgt_mask)
213
+ out = self.decoder_norm(x)
214
+ N, S, _ = out.shape
215
+ cmd_logits = self.command_fcn(out)
216
+ args_logits = self.args_fcn(out) # shape: bs, max_len, 8, 256
217
+ args_logits = args_logits.reshape(N, S, 8, 128)
218
+ return cmd_logits,args_logits,attn
219
+
220
+ def parallel_decoder(self, cmd_logits, args_logits, memory, trg_char):
221
+
222
+ memory = memory.unsqueeze(1)
223
+ cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
224
+ [1, 1, 0., 0., 0., 0., 1., 1.],
225
+ [1, 1, 0., 0., 0., 0., 1., 1.],
226
+ [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)
227
+ if opts.mode == 'train':
228
+ cmd2 = torch.argmax(cmd_logits, -1).unsqueeze(-1).transpose(0, 1)
229
+ arg2 = torch.argmax(args_logits, -1).transpose(0, 1)
230
+
231
+ cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0,1).unsqueeze(-1).to(cmd2.device)
232
+ cmd2 = cmd2 * cmd2paddingmask
233
+ args_mask = torch.matmul(F.one_hot(cmd2.long(),4).float(), cmd_args_mask).transpose(-1,-2).squeeze(-1)
234
+ arg2 = arg2 * args_mask
235
+
236
+ x = self.SVG_embedding(cmd2, arg2).transpose(0, 1)
237
+ else:
238
+ cmd2 = cmd_logits
239
+ arg2 = args_logits
240
+
241
+ cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0, 1).unsqueeze(-1).to(cmd2.device)
242
+ cmd2 = cmd2 * cmd2paddingmask
243
+ args_mask = torch.matmul(F.one_hot(cmd2.long(),4).float(), cmd_args_mask).transpose(-1, -2).squeeze(-1)
244
+ arg2 = arg2 * args_mask
245
+
246
+ x = self.SVG_embedding(cmd2, arg2).transpose(0,1)
247
+
248
+ S = x.size(1)
249
+ B = x.size(0)
250
+ tgt_mask = torch.ones(S,S).to(x.device).unsqueeze(0).repeat(B, 1, 1)
251
+ cmd2paddingmask = cmd2paddingmask.transpose(0, 1).transpose(-1, -2)
252
+ tgt_mask = tgt_mask * cmd2paddingmask
253
+
254
+ trg_char = trg_char.long()
255
+ trg_char = self.cls_embedding(trg_char)
256
+
257
+ x = torch.cat([trg_char, x],1)
258
+ x[:, 0:1, :] = trg_char
259
+ x = x[:,:opts.max_seq_len,:]
260
+ tgt_mask = tgt_mask #*tri
261
+ for layer in self.decoder_layers_parallel:
262
+ x, attn = layer(x, memory, src_mask=None, tgt_mask=tgt_mask)
263
+ out = self.decoder_norm_parallel(x)
264
+
265
+ N, S, _ = out.shape
266
+ cmd_logits = self.command_fcn(out)
267
+ args_logits = self.args_fcn(out)
268
+ args_logits = args_logits.reshape(N, S, 8, 128)
269
+
270
+ return cmd_logits, args_logits
271
+
272
+
273
+ def _get_key_padding_mask(commands, seq_dim=0):
274
+ """
275
+ Args:
276
+ commands: Shape [S, ...]
277
+ """
278
+ lens =[]
279
+ with torch.no_grad():
280
+ key_padding_mask = (commands == 0).cumsum(dim=seq_dim) > 0
281
+ commands=commands.transpose(0,1).squeeze(-1) #bs, opts.max_seq_len
282
+ for i in range(commands.size(0)):
283
+ try:
284
+ seqi = commands[i]#blue opts.max_seq_len
285
+ index = torch.where(seqi==0)[0][0]
286
+
287
+ except:
288
+ index=opts.max_seq_len
289
+
290
+ lens.append(index)
291
+ lens = torch.tensor(lens)+1#blue b
292
+ seqlen_mask = util_funcs.sequence_mask(lens, opts.max_seq_len)#blue b,opts.max_seq_len
293
+ return seqlen_mask
294
+
295
+ class Transformer(nn.Module):
296
+ def __init__(
297
+ self,
298
+ *,
299
+ num_freq_bands,
300
+ depth,
301
+ max_freq,
302
+ input_channels = 1,
303
+ input_axis = 2,
304
+ num_latents = 512,
305
+ latent_dim = 512,
306
+ cross_heads = 1,
307
+ latent_heads = 8,
308
+ cross_dim_head = 64,
309
+ latent_dim_head = 64,
310
+ num_classes = 1000,
311
+ attn_dropout = 0.,
312
+ ff_dropout = 0.,
313
+ weight_tie_layers = False,
314
+ fourier_encode_data = True,
315
+ self_per_cross_attn = 2,
316
+ final_classifier_head = True
317
+ ):
318
+ """The shape of the final attention mechanism will be:
319
+ depth * (cross attention -> self_per_cross_attn * self attention)
320
+
321
+ Args:
322
+ num_freq_bands: Number of freq bands, with original value (2 * K + 1)
323
+ depth: Depth of net.
324
+ max_freq: Maximum frequency, hyperparameter depending on how
325
+ fine the data is.
326
+ freq_base: Base for the frequency
327
+ input_channels: Number of channels for each token of the input.
328
+ input_axis: Number of axes for input data (2 for images, 3 for video)
329
+ num_latents: Number of latents, or induced set points, or centroids.
330
+ Different papers giving it different names.
331
+ latent_dim: Latent dimension.
332
+ cross_heads: Number of heads for cross attention. Paper said 1.
333
+ latent_heads: Number of heads for latent self attention, 8.
334
+ cross_dim_head: Number of dimensions per cross attention head.
335
+ latent_dim_head: Number of dimensions per latent self attention head.
336
+ num_classes: Output number of classes.
337
+ attn_dropout: Attention dropout
338
+ ff_dropout: Feedforward dropout
339
+ weight_tie_layers: Whether to weight tie layers (optional).
340
+ fourier_encode_data: Whether to auto-fourier encode the data, using
341
+ the input_axis given. defaults to True, but can be turned off
342
+ if you are fourier encoding the data yourself.
343
+ self_per_cross_attn: Number of self attention blocks per cross attn.
344
+ final_classifier_head: mean pool and project embeddings to number of classes (num_classes) at the end
345
+ """
346
+ super().__init__()
347
+ self.input_axis = input_axis
348
+ self.max_freq = max_freq
349
+ self.num_freq_bands = num_freq_bands
350
+
351
+ self.fourier_encode_data = fourier_encode_data
352
+ fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0 # 26
353
+ input_dim = fourier_channels + input_channels
354
+
355
+ self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
356
+
357
+ get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, dropout=attn_dropout), context_dim=input_dim)
358
+ get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
359
+ get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, dropout=attn_dropout))
360
+ get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
361
+
362
+ get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
363
+
364
+
365
+ #self_per_cross_attn=1
366
+ self.layers = nn.ModuleList([])
367
+ for i in range(depth):
368
+ should_cache = i > 0 and weight_tie_layers
369
+ cache_args = {'_cache': should_cache}
370
+
371
+ self_attns = nn.ModuleList([])
372
+
373
+ for block_ind in range(self_per_cross_attn): #BUG 之前是2 self_per_cross_attn
374
+ self_attns.append(nn.ModuleList([
375
+ get_latent_attn(**cache_args, key = block_ind),
376
+ get_latent_ff(**cache_args, key = block_ind)
377
+ ]))
378
+
379
+ self.layers.append(nn.ModuleList([
380
+ get_cross_attn(**cache_args),
381
+ get_cross_ff(**cache_args),
382
+ self_attns
383
+ ]))
384
+
385
+
386
+ get_cross_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
387
+ get_cross_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
388
+ get_latent_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
389
+ get_latent_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
390
+
391
+ get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2 = map(cache_fn, (get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2))
392
+
393
+ self.layers_cnnsvg = nn.ModuleList([])
394
+ for i in range(1):
395
+ should_cache = i > 0 and weight_tie_layers
396
+ cache_args = {'_cache': should_cache}
397
+
398
+ self_attns2 = nn.ModuleList([])
399
+
400
+ for block_ind in range(self_per_cross_attn):
401
+ self_attns2.append(nn.ModuleList([
402
+ get_latent_attn2(**cache_args, key = block_ind),
403
+ get_latent_ff2(**cache_args, key = block_ind)
404
+ ]))
405
+
406
+ self.layers_cnnsvg.append(nn.ModuleList([
407
+ get_cross_attn2(**cache_args),
408
+ get_cross_ff2(**cache_args),
409
+ self_attns2
410
+ ]))
411
+
412
+ self.to_logits = nn.Sequential(
413
+ Reduce('b n d -> b d', 'mean'),
414
+ nn.LayerNorm(latent_dim),
415
+ nn.Linear(latent_dim, num_classes)
416
+ ) if final_classifier_head else nn.Identity()
417
+ self.pre_lstm_fc = nn.Linear(10,opts.hidden_size)
418
+ self.posr = PositionalEncoding(d_model=opts.hidden_size,max_len=opts.max_seq_len)
419
+
420
+ patch_height = 2
421
+ patch_width = 2
422
+ patch_dim = 1 * patch_height * patch_width
423
+ self.to_patch_embedding = nn.Sequential(
424
+ Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
425
+ nn.Linear(patch_dim, 16),
426
+ )
427
+
428
+ self.SVG_embedding = SVGEmbedding()
429
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
430
+
431
+ def forward(self, data, seq, ref_cls_onehot=None, mask=None, return_embeddings=True):
432
+
433
+ b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
434
+ assert len(axis) == self.input_axis, 'input data must have the right number of axis' # img is 2
435
+ x = seq
436
+ commands=x[:, :, :1]
437
+ args=x[:, :, 1:]
438
+ x = self.SVG_embedding(commands, args).transpose(0,1)
439
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = x.size(0))
440
+ x = torch.cat([cls_tokens,x],dim = 1)
441
+ cls_one_pad = torch.ones((1,1,1)).to(x.device).repeat(x.size(0),1,1)
442
+ mask = torch.cat([cls_one_pad,mask],dim=-1)
443
+ self_atten = []
444
+ for cross_attn, cross_ff, self_attns in self.layers:
445
+ for self_attn, self_ff in self_attns:
446
+ x_,atten = self_attn(x,mask=mask)
447
+ x = x_ + x
448
+ self_atten.append(atten)
449
+ x = self_ff(x) + x
450
+ x = x + torch.randn_like(x) # add a perturbation
451
+ return x, self_atten
452
+
453
+ def att_residual(self, x, mask=None):
454
+
455
+ for cross_attn, cross_ff, self_attns in self.layers_cnnsvg:
456
+ for self_attn, self_ff in self_attns:
457
+ x_, atten = self_attn(x)
458
+ x = x_ + x
459
+ x = self_ff(x) + x
460
+ return x
461
+
462
+
463
+
464
+ def loss(self, cmd_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux):
465
+ '''
466
+ Inputs:
467
+ cmd_logits: [b, 51, 4]
468
+ args_logits: [b, 51, 6]
469
+ '''
470
+ cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
471
+ [1, 1, 0., 0., 0., 0., 1., 1.],
472
+ [1, 1, 0., 0., 0., 0., 1., 1.],
473
+ [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)
474
+
475
+ tgt_commands = trg_seq[:,:,:1].transpose(0,1)
476
+ tgt_args = trg_seq[:,:,1:].transpose(0,1)
477
+
478
+ seqlen_mask = util_funcs.sequence_mask(trg_seqlen, opts.max_seq_len).unsqueeze(-1)
479
+ seqlen_mask2 = seqlen_mask.repeat(1,1,4)# NOTE b,501,4
480
+ seqlen_mask4 = seqlen_mask.repeat(1,1,8)
481
+ seqlen_mask3 = seqlen_mask.unsqueeze(-1).repeat(1,1,8,128)
482
+
483
+
484
+ tgt_commands_onehot = F.one_hot(tgt_commands, 4)
485
+ tgt_args_onehot = F.one_hot(tgt_args, 128)
486
+
487
+ args_mask = torch.matmul(tgt_commands_onehot.float(),cmd_args_mask).squeeze()
488
+
489
+
490
+ loss_cmd = torch.sum(- tgt_commands_onehot.squeeze() * F.log_softmax(cmd_logits, -1), -1)
491
+ loss_cmd = torch.mul(loss_cmd, seqlen_mask.squeeze())
492
+ loss_cmd = torch.mean(torch.sum(loss_cmd/trg_seqlen.unsqueeze(-1),-1))
493
+
494
+ loss_args = (torch.sum(-tgt_args_onehot*F.log_softmax(args_logits,-1),-1)*seqlen_mask4*args_mask)
495
+
496
+ loss_args = torch.mean(loss_args,dim=-1,keepdim=False)
497
+ loss_args = torch.mean(torch.sum(loss_args/trg_seqlen.unsqueeze(-1),-1))
498
+
499
+ SE_mask = torch.Tensor([[1, 1],
500
+ [0, 0],
501
+ [1, 1],
502
+ [1, 1]]).to(cmd_logits.device)
503
+
504
+ SE_args_mask = torch.matmul(tgt_commands_onehot.float(),SE_mask).squeeze().unsqueeze(-1)
505
+
506
+
507
+ args_prob = F.softmax(args_logits, -1)
508
+ args_end = args_prob[:,:,6:]
509
+ args_end_shifted = torch.cat((torch.zeros(args_end.size(0),1,args_end.size(2),args_end.size(3)).to(args_end.device),args_end),1)
510
+ args_end_shifted = args_end_shifted[:,:opts.max_seq_len,:,:]
511
+ args_end_shifted = args_end_shifted*SE_args_mask + args_end*(1-SE_args_mask)
512
+
513
+ args_start = args_prob[:,:,:2]
514
+
515
+ seqlen_mask5 = util_funcs.sequence_mask(trg_seqlen-1, opts.max_seq_len).unsqueeze(-1)
516
+ seqlen_mask5 = seqlen_mask5.repeat(1,1,2)
517
+
518
+ smooth_constrained = torch.sum(torch.pow((args_end_shifted - args_start), 2), -1) * seqlen_mask5
519
+ smooth_constrained = torch.mean(smooth_constrained, dim=-1, keepdim=False)
520
+ smooth_constrained = torch.mean(torch.sum(smooth_constrained / (trg_seqlen - 1).unsqueeze(-1), -1))
521
+
522
+ args_prob2 = F.softmax(args_logits / 0.1, -1)
523
+
524
+ c = torch.argmax(args_prob2,-1).unsqueeze(-1).float() - args_prob2.detach()
525
+ p_argmax = args_prob2 + c
526
+ p_argmax = torch.mean(p_argmax,-1)
527
+ control_pts = denumericalize(p_argmax)
528
+
529
+ p0 = control_pts[:,:,:2]
530
+ p1 = control_pts[:,:,2:4]
531
+ p2 = control_pts[:,:,4:6]
532
+ p3 = control_pts[:,:,6:8]
533
+
534
+ line_mask = (tgt_commands==2).float() + (tgt_commands==1).float()
535
+ curve_mask = (tgt_commands==3).float()
536
+
537
+ t=0.25
538
+ aux_pts_line = p0 + t*(p3-p0)
539
+ for t in [0.5,0.75]:
540
+ coord_t = p0 + t*(p3-p0)
541
+ aux_pts_line = torch.cat((aux_pts_line,coord_t),-1)
542
+ aux_pts_line = aux_pts_line*line_mask
543
+
544
+ t=0.25
545
+ aux_pts_curve = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
546
+ for t in [0.5, 0.75]:
547
+ coord_t = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
548
+ aux_pts_curve = torch.cat((aux_pts_curve,coord_t),-1)
549
+ aux_pts_curve = aux_pts_curve * curve_mask
550
+
551
+
552
+ aux_pts_predict = aux_pts_curve + aux_pts_line
553
+ seqlen_mask_aux = util_funcs.sequence_mask(trg_seqlen - 1, opts.max_seq_len).unsqueeze(-1)
554
+ aux_pts_loss = torch.pow((aux_pts_predict - trg_pts_aux), 2) * seqlen_mask_aux
555
+
556
+ loss_aux = torch.mean(aux_pts_loss, dim=-1, keepdim=False)
557
+ loss_aux = torch.mean(torch.sum(loss_aux / trg_seqlen.unsqueeze(-1), -1))
558
+
559
+
560
+ loss = opts.loss_w_cmd * loss_cmd + opts.loss_w_args * loss_args + opts.loss_w_aux * loss_aux + opts.loss_w_smt * smooth_constrained
561
+
562
+ svg_losses = {}
563
+ svg_losses['loss_total'] = loss
564
+ svg_losses["loss_cmd"] = loss_cmd
565
+ svg_losses["loss_args"] = loss_args
566
+ svg_losses["loss_smt"] = smooth_constrained
567
+ svg_losses["loss_aux"] = loss_aux
568
+
569
+ return svg_losses
570
+
571
+ class DecoderLayer(nn.Module):
572
+ "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
573
+
574
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
575
+ super(DecoderLayer, self).__init__()
576
+ self.size = size
577
+ self.self_attn = self_attn
578
+ self.src_attn = src_attn
579
+ self.feed_forward = feed_forward
580
+ self.sublayer = clones(SublayerConnection(size, dropout), 3)
581
+
582
+ def forward(self, x, memory, src_mask, tgt_mask):
583
+ "Follow Figure 1 (right) for connections."
584
+ m = memory
585
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
586
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
587
+ attn = self.self_attn.attn
588
+ return self.sublayer[2](x, self.feed_forward),attn
589
+
590
+ def subsequent_mask(size):
591
+ "Mask out subsequent positions."
592
+ attn_shape = (1, size, size)
593
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
594
+ return torch.from_numpy(subsequent_mask) == 0
595
+
596
+ def numericalize(cmd, n=128):
597
+ """NOTE: shall only be called after normalization"""
598
+ # assert np.max(cmd.origin) <= 1.0 and np.min(cmd.origin) >= -1.0
599
+ cmd = (cmd / 30 * n).round().clip(min=0, max=n-1).int()
600
+ return cmd
601
+
602
+ def denumericalize(cmd, n=128):
603
+ cmd = cmd / n * 30
604
+ return cmd
605
+
606
+ def attention(query, key, value, mask=None, trg_tri_mask=None,dropout=None, posr=None):
607
+ "Compute 'Scaled Dot Product Attention'"
608
+ d_k = query.size(-1)
609
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
610
+
611
+ if posr is not None:
612
+ posr = posr.unsqueeze(1)
613
+ scores = scores + posr
614
+
615
+ if mask is not None:
616
+ try:
617
+ scores = scores.masked_fill(mask == 0, -1e9) # note mask: b,1,501,501 scores: b, head, 501,501
618
+ except Exception as e:
619
+ print("Shape: ",scores.shape)
620
+ print("Error: ",e)
621
+ import pdb; pdb.set_trace()
622
+
623
+ if trg_tri_mask is not None:
624
+ scores = scores.masked_fill(trg_tri_mask == 0, -1e9)
625
+
626
+ p_attn = F.softmax(scores, dim=-1)
627
+
628
+ if dropout is not None:
629
+ p_attn = dropout(p_attn)
630
+
631
+ return torch.matmul(p_attn, value), p_attn
632
+
633
+
634
+ class MultiHeadedAttention(nn.Module):
635
+ def __init__(self, h, d_model, dropout):
636
+ "Take in model size and number of heads."
637
+ super(MultiHeadedAttention, self).__init__()
638
+ assert d_model % h == 0
639
+ # We assume d_v always equals d_k
640
+ self.d_k = d_model // h #32
641
+ self.h = h #8
642
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
643
+ self.attn = None
644
+ self.dropout = nn.Dropout(p=dropout)
645
+
646
+ def forward(self, query, key, value, mask=None,trg_tri_mask=None, posr=None):
647
+ "Implements Figure 2"
648
+
649
+ if mask is not None:
650
+ # Same mask applied to all h heads.
651
+ mask = mask.unsqueeze(1)
652
+ nbatches = query.size(0) #16
653
+
654
+ query, key, value = \
655
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
656
+ for l, x in zip(self.linears, (query, key, value))]
657
+
658
+ x, self.attn = attention(query, key, value, mask=mask,trg_tri_mask=trg_tri_mask,
659
+ dropout=self.dropout, posr=posr)
660
+
661
+ x = x.transpose(1, 2).contiguous() \
662
+ .view(nbatches, -1, self.h * self.d_k)
663
+
664
+ return self.linears[-1](x)
665
+
666
+ def clones(module, N):
667
+ "Produce N identical layers."
668
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
669
+
670
+ class SublayerConnection(nn.Module):
671
+ """
672
+ A residual connection followed by a layer norm.
673
+ Note for code simplicity the norm is first as opposed to last.
674
+ """
675
+
676
+ def __init__(self, size, dropout):
677
+ super(SublayerConnection, self).__init__()
678
+ self.norm = nn.LayerNorm(size)
679
+ self.dropout = nn.Dropout(dropout)
680
+
681
+ def forward(self, x, sublayer):
682
+ "Apply residual connection to any sublayer with the same size."
683
+ x_norm=self.norm(x)
684
+ return x + self.dropout(sublayer(x_norm))#+ self.augs(x_norm)
685
+
686
+
687
+ if __name__ == '__main__':
688
+ model = Transformer(
689
+ input_channels = 1, # number of channels for each token of the input
690
+ input_axis = 2, # number of axis for input data (2 for images, 3 for video)
691
+ num_freq_bands = 6, # number of freq bands, with original value (2 * K + 1)
692
+ max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is
693
+ depth = 6, # depth of net. The shape of the final attention mechanism will be:
694
+ # depth * (cross attention -> self_per_cross_attn * self attention)
695
+ num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names
696
+ latent_dim = 512, # latent dimension
697
+ cross_heads = 1, # number of heads for cross attention. paper said 1
698
+ latent_heads = 8, # number of heads for latent self attention, 8
699
+ cross_dim_head = 64, # number of dimensions per cross attention head
700
+ latent_dim_head = 64, # number of dimensions per latent self attention head
701
+ num_classes = 1000, # output number of classes
702
+ attn_dropout = 0.,
703
+ ff_dropout = 0.,
704
+ weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram)
705
+ fourier_encode_data = True, # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
706
+ self_per_cross_attn = 2 # number of self attention blocks per cross attention
707
+ )
708
+
709
+ img = torch.randn(1, 224, 224, 3) # 1 imagenet image, pixelized
710
+
711
  model(img) # (1, 1000)
models/util_funcs.py CHANGED
@@ -1,96 +1,96 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import cairosvg
4
- from data_utils.common_utils import trans2_white_bg
5
- from PIL import Image
6
- import numpy as np
7
-
8
- def select_imgs(images_of_onefont, selected_cls, opts):
9
- # given selected char classes, return selected imgs
10
- # images_of_onefont: [bs, 52, opts.img_size, opts.img_size]
11
- # selected_cls: [bs, nshot]
12
- nums = selected_cls.size(1)
13
- selected_cls_ = selected_cls.unsqueeze(2)
14
- selected_cls_ = selected_cls_.unsqueeze(3)
15
- selected_cls_ = selected_cls_.expand(images_of_onefont.size(0), nums, opts.img_size, opts.img_size)
16
- selected_img = torch.gather(images_of_onefont, 1, selected_cls_)
17
- return selected_img
18
-
19
- def select_seqs(seqs_of_onefont, selected_cls, opts, seq_dim):
20
-
21
- nums = selected_cls.size(1)
22
- selected_cls_ = selected_cls.unsqueeze(2)
23
- selected_cls_ = selected_cls_.unsqueeze(3)
24
- selected_cls_ = selected_cls_.expand(seqs_of_onefont.size(0), nums, opts.max_seq_len, seq_dim)
25
- selected_seqs = torch.gather(seqs_of_onefont, 1, selected_cls_)
26
- return selected_seqs
27
-
28
- def select_seqlens(seqlens_of_onefont, selected_cls, opts):
29
-
30
- nums = selected_cls.size(1)
31
- selected_cls_ = selected_cls.unsqueeze(2)
32
- selected_cls_ = selected_cls_.expand(seqlens_of_onefont.size(0), nums, 1) # 64, nums, 1
33
- selected_seqlens = torch.gather(seqlens_of_onefont, 1, selected_cls_)
34
- return selected_seqlens
35
-
36
- def trgcls_to_onehot(trg_cls, opts):
37
- trg_char = F.one_hot(trg_cls, num_classes=opts.char_num).squeeze(dim=1)
38
- return trg_char
39
-
40
-
41
- def shift_right(x, pad_value=None):
42
- if pad_value is None:
43
- shifted = F.pad(x, (0, 0, 0, 0, 1, 0))[:-1, :, :]
44
- else:
45
- shifted = torch.cat([pad_value, x], axis=0)[:-1, :, :]
46
- return shifted
47
-
48
-
49
- def length_form_embedding(emb):
50
- """Compute the length of each sequence in the batch
51
- Args:
52
- emb: [seq_len, batch, depth]
53
- Returns:
54
- a 0/1 tensor: [batch]
55
- """
56
- absed = torch.abs(emb)
57
- sum_last = torch.sum(absed, dim=2, keepdim=True)
58
- mask = sum_last != 0
59
- sum_except_batch = torch.sum(mask, dim=(0, 2), dtype=torch.long)
60
- return sum_except_batch
61
-
62
-
63
- def lognormal(y, mean, logstd, logsqrttwopi):
64
- y_mean = y - mean # NOTE y:[b*51*6, 1] mean: [b*51*6, 50]
65
- logstd_exp = logstd.exp() # NOTE [b*51*6, 50]
66
- y_mean_divide_exp = y_mean / logstd_exp
67
- return -0.5 * (y_mean_divide_exp) ** 2 - logstd - logsqrttwopi
68
-
69
- def sequence_mask(lengths, max_len=None):
70
- batch_size=lengths.numel()
71
- max_len=max_len or lengths.max()
72
- return (torch.arange(0, max_len, device=lengths.device)
73
- .type_as(lengths)
74
- .unsqueeze(0).expand(batch_size,max_len)
75
- .lt(lengths.unsqueeze(1)))
76
-
77
- def svg2img(path_svg, path_img, img_size):
78
- cairosvg.svg2png(url=path_svg, write_to=path_img, output_width=img_size, output_height=img_size)
79
- img_arr = trans2_white_bg(path_img)
80
- return img_arr
81
-
82
- def cal_img_l1_dist(path_img1, path_img2):
83
- img1 = np.array(Image.open(path_img1))
84
- img2 = np.array(Image.open(path_img2))
85
- dist = np.mean(np.abs(img1 - img2[:, :, 0]))
86
- return dist
87
-
88
- def cal_iou(path_img1, path_img2):
89
-
90
- img1 = np.array(Image.open(path_img1))
91
- img2 = np.array(Image.open(path_img2))[:, :, 0]
92
- mask_img1 = img1 < (255 * 3 / 4)
93
- mask_img2 = img2 < (255 * 3 / 4)
94
- iou = np.sum(mask_img1 * mask_img2) / (np.sum(mask_img1 + mask_img2))
95
- l1_dist = np.mean(np.abs(mask_img1.astype(float) - mask_img2.astype(float)))
96
  return iou, l1_dist
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import cairosvg
4
+ from data_utils.common_utils import trans2_white_bg
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ def select_imgs(images_of_onefont, selected_cls, opts):
9
+ # given selected char classes, return selected imgs
10
+ # images_of_onefont: [bs, 52, opts.img_size, opts.img_size]
11
+ # selected_cls: [bs, nshot]
12
+ nums = selected_cls.size(1)
13
+ selected_cls_ = selected_cls.unsqueeze(2)
14
+ selected_cls_ = selected_cls_.unsqueeze(3)
15
+ selected_cls_ = selected_cls_.expand(images_of_onefont.size(0), nums, opts.img_size, opts.img_size)
16
+ selected_img = torch.gather(images_of_onefont, 1, selected_cls_)
17
+ return selected_img
18
+
19
+ def select_seqs(seqs_of_onefont, selected_cls, opts, seq_dim):
20
+
21
+ nums = selected_cls.size(1)
22
+ selected_cls_ = selected_cls.unsqueeze(2)
23
+ selected_cls_ = selected_cls_.unsqueeze(3)
24
+ selected_cls_ = selected_cls_.expand(seqs_of_onefont.size(0), nums, opts.max_seq_len, seq_dim)
25
+ selected_seqs = torch.gather(seqs_of_onefont, 1, selected_cls_)
26
+ return selected_seqs
27
+
28
+ def select_seqlens(seqlens_of_onefont, selected_cls, opts):
29
+
30
+ nums = selected_cls.size(1)
31
+ selected_cls_ = selected_cls.unsqueeze(2)
32
+ selected_cls_ = selected_cls_.expand(seqlens_of_onefont.size(0), nums, 1) # 64, nums, 1
33
+ selected_seqlens = torch.gather(seqlens_of_onefont, 1, selected_cls_)
34
+ return selected_seqlens
35
+
36
+ def trgcls_to_onehot(trg_cls, opts):
37
+ trg_char = F.one_hot(trg_cls, num_classes=opts.char_num).squeeze(dim=1)
38
+ return trg_char
39
+
40
+
41
+ def shift_right(x, pad_value=None):
42
+ if pad_value is None:
43
+ shifted = F.pad(x, (0, 0, 0, 0, 1, 0))[:-1, :, :]
44
+ else:
45
+ shifted = torch.cat([pad_value, x], axis=0)[:-1, :, :]
46
+ return shifted
47
+
48
+
49
+ def length_form_embedding(emb):
50
+ """Compute the length of each sequence in the batch
51
+ Args:
52
+ emb: [seq_len, batch, depth]
53
+ Returns:
54
+ a 0/1 tensor: [batch]
55
+ """
56
+ absed = torch.abs(emb)
57
+ sum_last = torch.sum(absed, dim=2, keepdim=True)
58
+ mask = sum_last != 0
59
+ sum_except_batch = torch.sum(mask, dim=(0, 2), dtype=torch.long)
60
+ return sum_except_batch
61
+
62
+
63
+ def lognormal(y, mean, logstd, logsqrttwopi):
64
+ y_mean = y - mean # NOTE y:[b*51*6, 1] mean: [b*51*6, 50]
65
+ logstd_exp = logstd.exp() # NOTE [b*51*6, 50]
66
+ y_mean_divide_exp = y_mean / logstd_exp
67
+ return -0.5 * (y_mean_divide_exp) ** 2 - logstd - logsqrttwopi
68
+
69
+ def sequence_mask(lengths, max_len=None):
70
+ batch_size=lengths.numel()
71
+ max_len=max_len or lengths.max()
72
+ return (torch.arange(0, max_len, device=lengths.device)
73
+ .type_as(lengths)
74
+ .unsqueeze(0).expand(batch_size,max_len)
75
+ .lt(lengths.unsqueeze(1)))
76
+
77
+ def svg2img(path_svg, path_img, img_size):
78
+ cairosvg.svg2png(url=path_svg, write_to=path_img, output_width=img_size, output_height=img_size)
79
+ img_arr = trans2_white_bg(path_img)
80
+ return img_arr
81
+
82
+ def cal_img_l1_dist(path_img1, path_img2):
83
+ img1 = np.array(Image.open(path_img1))
84
+ img2 = np.array(Image.open(path_img2))
85
+ dist = np.mean(np.abs(img1 - img2[:, :, 0]))
86
+ return dist
87
+
88
+ def cal_iou(path_img1, path_img2):
89
+
90
+ img1 = np.array(Image.open(path_img1))
91
+ img2 = np.array(Image.open(path_img2))[:, :, 0]
92
+ mask_img1 = img1 < (255 * 3 / 4)
93
+ mask_img2 = img2 < (255 * 3 / 4)
94
+ iou = np.sum(mask_img1 * mask_img2) / (np.sum(mask_img1 + mask_img2))
95
+ l1_dist = np.mean(np.abs(mask_img1.astype(float) - mask_img2.astype(float)))
96
  return iou, l1_dist
models/vgg_perceptual_loss.py CHANGED
@@ -1,69 +1,69 @@
1
- import torch
2
- import torchvision
3
-
4
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
-
6
- class VGG19Feats(torch.nn.Module):
7
- def __init__(self, requires_grad=False):
8
- super(VGG19Feats, self).__init__()
9
- vgg = torchvision.models.vgg19(pretrained=True).to(device) #.cuda()
10
- # vgg.eval()
11
- vgg_pretrained_features = vgg.features.eval()
12
- self.requires_grad = requires_grad
13
- self.slice1 = torch.nn.Sequential()
14
- self.slice2 = torch.nn.Sequential()
15
- self.slice3 = torch.nn.Sequential()
16
- self.slice4 = torch.nn.Sequential()
17
- self.slice5 = torch.nn.Sequential()
18
- for x in range(3):
19
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
20
- for x in range(3, 8):
21
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
22
- for x in range(8, 13):
23
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
24
- for x in range(13, 22):
25
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
26
- for x in range(22, 31):
27
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
28
- if not self.requires_grad:
29
- for param in self.parameters():
30
- param.requires_grad = False
31
-
32
- def forward(self, img):
33
- conv1_2 = self.slice1(img)
34
- conv2_2 = self.slice2(conv1_2)
35
- conv3_2 = self.slice3(conv2_2)
36
- conv4_2 = self.slice4(conv3_2)
37
- conv5_2 = self.slice5(conv4_2)
38
- out = [conv1_2, conv2_2, conv3_2, conv4_2, conv5_2]
39
- return out
40
-
41
-
42
- class VGGPerceptualLoss(torch.nn.Module):
43
- def __init__(self):
44
- super(VGGPerceptualLoss, self).__init__()
45
- self.vgg = VGG19Feats().to(device)
46
- self.criterion = torch.nn.functional.l1_loss
47
- self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
48
- self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
49
- self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 1.0*10/1.5]
50
-
51
- def forward(self, input_img, target_img):
52
-
53
- if input_img.shape[1] != 3:
54
- input_img = input_img.repeat(1, 3, 1, 1)
55
- target_img = target_img.repeat(1, 3, 1, 1)
56
- input_img = (input_img - self.mean) / self.std
57
- target_img = (target_img - self.mean) / self.std
58
-
59
- x_vgg, y_vgg = self.vgg(input_img), self.vgg(target_img)
60
-
61
- loss = {}
62
- loss['pt_c_loss'] = self.weights[0] * self.criterion(x_vgg[0], y_vgg[0])+\
63
- self.weights[1] * self.criterion(x_vgg[1], y_vgg[1])+\
64
- self.weights[2] * self.criterion(x_vgg[2], y_vgg[2])+\
65
- self.weights[3] * self.criterion(x_vgg[3], y_vgg[3])+\
66
- self.weights[4] * self.criterion(x_vgg[4], y_vgg[4])
67
- loss['pt_s_loss'] = 0.0
68
-
69
  return loss
 
1
+ import torch
2
+ import torchvision
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+
6
+ class VGG19Feats(torch.nn.Module):
7
+ def __init__(self, requires_grad=False):
8
+ super(VGG19Feats, self).__init__()
9
+ vgg = torchvision.models.vgg19(pretrained=True).to(device) #.cuda()
10
+ # vgg.eval()
11
+ vgg_pretrained_features = vgg.features.eval()
12
+ self.requires_grad = requires_grad
13
+ self.slice1 = torch.nn.Sequential()
14
+ self.slice2 = torch.nn.Sequential()
15
+ self.slice3 = torch.nn.Sequential()
16
+ self.slice4 = torch.nn.Sequential()
17
+ self.slice5 = torch.nn.Sequential()
18
+ for x in range(3):
19
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
20
+ for x in range(3, 8):
21
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
22
+ for x in range(8, 13):
23
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
24
+ for x in range(13, 22):
25
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
26
+ for x in range(22, 31):
27
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
28
+ if not self.requires_grad:
29
+ for param in self.parameters():
30
+ param.requires_grad = False
31
+
32
+ def forward(self, img):
33
+ conv1_2 = self.slice1(img)
34
+ conv2_2 = self.slice2(conv1_2)
35
+ conv3_2 = self.slice3(conv2_2)
36
+ conv4_2 = self.slice4(conv3_2)
37
+ conv5_2 = self.slice5(conv4_2)
38
+ out = [conv1_2, conv2_2, conv3_2, conv4_2, conv5_2]
39
+ return out
40
+
41
+
42
+ class VGGPerceptualLoss(torch.nn.Module):
43
+ def __init__(self):
44
+ super(VGGPerceptualLoss, self).__init__()
45
+ self.vgg = VGG19Feats().to(device)
46
+ self.criterion = torch.nn.functional.l1_loss
47
+ self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
48
+ self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
49
+ self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 1.0*10/1.5]
50
+
51
+ def forward(self, input_img, target_img):
52
+
53
+ if input_img.shape[1] != 3:
54
+ input_img = input_img.repeat(1, 3, 1, 1)
55
+ target_img = target_img.repeat(1, 3, 1, 1)
56
+ input_img = (input_img - self.mean) / self.std
57
+ target_img = (target_img - self.mean) / self.std
58
+
59
+ x_vgg, y_vgg = self.vgg(input_img), self.vgg(target_img)
60
+
61
+ loss = {}
62
+ loss['pt_c_loss'] = self.weights[0] * self.criterion(x_vgg[0], y_vgg[0])+\
63
+ self.weights[1] * self.criterion(x_vgg[1], y_vgg[1])+\
64
+ self.weights[2] * self.criterion(x_vgg[2], y_vgg[2])+\
65
+ self.weights[3] * self.criterion(x_vgg[3], y_vgg[3])+\
66
+ self.weights[4] * self.criterion(x_vgg[4], y_vgg[4])
67
+ loss['pt_s_loss'] = 0.0
68
+
69
  return loss
options.py CHANGED
@@ -1,67 +1,67 @@
1
- import argparse
2
-
3
- def get_parser_main_model():
4
- parser = argparse.ArgumentParser()
5
- # basic parameters training related
6
- parser.add_argument('--model_name', type=str, default='main_model', choices=['main_model', 'neural_raster'], help='current model_name')
7
- parser.add_argument("--language", type=str, default='tha', choices=['eng', 'chn', 'tha'])
8
- parser.add_argument('--bottleneck_bits', type=int, default=512, help='latent code number of bottleneck bits')
9
- parser.add_argument('--char_num', type=int, default=44, help='number of glyphs, original is 44 (Thai)')
10
- parser.add_argument('--seed', type=int, default=3712)
11
- parser.add_argument('--ref_nshot', type=int, default=8, help='reference number')
12
- parser.add_argument('--batch_size', type=int, default=64, help='batch size')
13
- parser.add_argument('--batch_size_val', type=int, default=8, help='batch size when do validation')
14
- parser.add_argument('--img_size', type=int, default=64, help='image size')
15
- parser.add_argument('--max_seq_len', type=int, default=121, help='maximum length of sequence')
16
- parser.add_argument('--dim_seq', type=int, default=12, help='the dim of each stroke in a sequence, 4 + 8, 4 is cmd, and 8 is args')
17
- parser.add_argument('--dim_seq_short', type=int, default=9, help='the short dim of each stroke in a sequence, 1 + 8, 1 is cmd class num, and 8 is args')
18
- parser.add_argument('--hidden_size', type=int, default=512, help='hidden_size')
19
- parser.add_argument('--dim_seq_latent', type=int, default=512, help='sequence encoder latent dim')
20
- parser.add_argument('--ngf', type=int, default=16, help='the basic num of channel in image encoder and decoder')
21
- parser.add_argument('--n_aux_pts', type=int, default=6, help='the number of aux pts in bezier curves for additional supervison')
22
- # experiment related
23
-
24
- parser.add_argument('--random_index', type=str, default='00')
25
- parser.add_argument('--name_ckpt', type=str, default='600_192921.ckpt')
26
- parser.add_argument('--model_path', type=str, default='.')
27
- parser.add_argument('--n_epochs', type=int, default=800, help='number of epochs')
28
- parser.add_argument('--n_samples', type=int, default=20, help='the number of samples for each glyph when testing')
29
- parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
30
- parser.add_argument('--ref_char_ids', type=str, default='0,1,26,27', help='default is A, B, a, b')
31
-
32
- parser.add_argument('--mode', type=str, default='test', choices=['train', 'val', 'test'])
33
- parser.add_argument('--multi_gpu', type=bool, default=False)
34
- parser.add_argument('--name_exp', type=str, default='dvf')
35
-
36
- # continue training'
37
- parser.add_argument('--continue_training', type=bool, default=False, help='whether continue training from old checkpoint')
38
- parser.add_argument('--continue_ckpt', type=str, default='.', help='checkpoint model for continue training')
39
- parser.add_argument('--init_epoch', type=int, default=0, help='init epoch')
40
-
41
- # Manually Add
42
- parser.add_argument('--exp_path', type=str, default='.')
43
- parser.add_argument('--dir_res', type=str, default=None)
44
- parser.add_argument('--data_root', type=str, default='./data/vecfont_dataset/')
45
- parser.add_argument('--freq_ckpt', type=int, default=50, help='save checkpoint frequency of epoch')
46
- parser.add_argument('--threshold_ckpt', type=int, default=0, help='save checkpoint only when more than threshold epoch')
47
-
48
- parser.add_argument('--freq_sample', type=int, default=500, help='sample train output of steps')
49
- parser.add_argument('--freq_log', type=int, default=50, help='freq of showing logs')
50
- parser.add_argument('--freq_val', type=int, default=500, help='sample validate output of steps')
51
- parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam optimizer')
52
- parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam optimizer')
53
- parser.add_argument('--eps', type=float, default=1e-8, help='Adam epsilon')
54
- parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
55
- parser.add_argument('--wandb', type=bool, default=True, help='whether use wandb to visulize loss')
56
- parser.add_argument('--wandb_project_name', type=str, default="DeepVecFontV2", help='wandb project name')
57
-
58
- # loss weight
59
- parser.add_argument('--kl_beta', type=float, default=0.01, help='latent code kl loss beta')
60
- parser.add_argument('--loss_w_pt_c', type=float, default=0.001 * 10, help='the weight of perceptual content loss')
61
- parser.add_argument('--loss_w_l1', type=float, default=1.0 * 10, help='the weight of image reconstruction l1 loss')
62
- parser.add_argument('--loss_w_cmd', type=float, default=1.0, help='the weight of cmd loss')
63
- parser.add_argument('--loss_w_args', type=float, default=1.0, help='the weight of args loss')
64
- parser.add_argument('--loss_w_aux', type=float, default=0.01, help='the weight of pts aux loss')
65
- parser.add_argument('--loss_w_smt', type=float, default=10., help='the weight of smooth loss')
66
-
67
- return parser
 
1
+ import argparse
2
+
3
+ def get_parser_main_model():
4
+ parser = argparse.ArgumentParser()
5
+ # basic parameters training related
6
+ parser.add_argument('--model_name', type=str, default='main_model', choices=['main_model', 'neural_raster'], help='current model_name')
7
+ parser.add_argument("--language", type=str, default='tha', choices=['eng', 'chn', 'tha'])
8
+ parser.add_argument('--bottleneck_bits', type=int, default=512, help='latent code number of bottleneck bits')
9
+ parser.add_argument('--char_num', type=int, default=44, help='number of glyphs, original is 44 (Thai)')
10
+ parser.add_argument('--seed', type=int, default=3712)
11
+ parser.add_argument('--ref_nshot', type=int, default=8, help='reference number')
12
+ parser.add_argument('--batch_size', type=int, default=64, help='batch size')
13
+ parser.add_argument('--batch_size_val', type=int, default=8, help='batch size when do validation')
14
+ parser.add_argument('--img_size', type=int, default=64, help='image size')
15
+ parser.add_argument('--max_seq_len', type=int, default=121, help='maximum length of sequence')
16
+ parser.add_argument('--dim_seq', type=int, default=12, help='the dim of each stroke in a sequence, 4 + 8, 4 is cmd, and 8 is args')
17
+ parser.add_argument('--dim_seq_short', type=int, default=9, help='the short dim of each stroke in a sequence, 1 + 8, 1 is cmd class num, and 8 is args')
18
+ parser.add_argument('--hidden_size', type=int, default=512, help='hidden_size')
19
+ parser.add_argument('--dim_seq_latent', type=int, default=512, help='sequence encoder latent dim')
20
+ parser.add_argument('--ngf', type=int, default=16, help='the basic num of channel in image encoder and decoder')
21
+ parser.add_argument('--n_aux_pts', type=int, default=6, help='the number of aux pts in bezier curves for additional supervison')
22
+ # experiment related
23
+
24
+ parser.add_argument('--random_index', type=str, default='00')
25
+ parser.add_argument('--name_ckpt', type=str, default='600_192921.ckpt')
26
+ parser.add_argument('--model_path', type=str, default='.')
27
+ parser.add_argument('--n_epochs', type=int, default=800, help='number of epochs')
28
+ parser.add_argument('--n_samples', type=int, default=20, help='the number of samples for each glyph when testing')
29
+ parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
30
+ parser.add_argument('--ref_char_ids', type=str, default='0,1,26,27', help='default is A, B, a, b')
31
+
32
+ parser.add_argument('--mode', type=str, default='test', choices=['train', 'val', 'test'])
33
+ parser.add_argument('--multi_gpu', type=bool, default=False)
34
+ parser.add_argument('--name_exp', type=str, default='dvf')
35
+
36
+ # continue training'
37
+ parser.add_argument('--continue_training', type=bool, default=False, help='whether continue training from old checkpoint')
38
+ parser.add_argument('--continue_ckpt', type=str, default='.', help='checkpoint model for continue training')
39
+ parser.add_argument('--init_epoch', type=int, default=0, help='init epoch')
40
+
41
+ # Manually Add
42
+ parser.add_argument('--exp_path', type=str, default='.')
43
+ parser.add_argument('--dir_res', type=str, default=None)
44
+ parser.add_argument('--data_root', type=str, default='./data/vecfont_dataset/')
45
+ parser.add_argument('--freq_ckpt', type=int, default=50, help='save checkpoint frequency of epoch')
46
+ parser.add_argument('--threshold_ckpt', type=int, default=0, help='save checkpoint only when more than threshold epoch')
47
+
48
+ parser.add_argument('--freq_sample', type=int, default=500, help='sample train output of steps')
49
+ parser.add_argument('--freq_log', type=int, default=50, help='freq of showing logs')
50
+ parser.add_argument('--freq_val', type=int, default=500, help='sample validate output of steps')
51
+ parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam optimizer')
52
+ parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam optimizer')
53
+ parser.add_argument('--eps', type=float, default=1e-8, help='Adam epsilon')
54
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
55
+ parser.add_argument('--wandb', type=bool, default=True, help='whether use wandb to visulize loss')
56
+ parser.add_argument('--wandb_project_name', type=str, default="DeepVecFontV2", help='wandb project name')
57
+
58
+ # loss weight
59
+ parser.add_argument('--kl_beta', type=float, default=0.01, help='latent code kl loss beta')
60
+ parser.add_argument('--loss_w_pt_c', type=float, default=0.001 * 10, help='the weight of perceptual content loss')
61
+ parser.add_argument('--loss_w_l1', type=float, default=1.0 * 10, help='the weight of image reconstruction l1 loss')
62
+ parser.add_argument('--loss_w_cmd', type=float, default=1.0, help='the weight of cmd loss')
63
+ parser.add_argument('--loss_w_args', type=float, default=1.0, help='the weight of args loss')
64
+ parser.add_argument('--loss_w_aux', type=float, default=0.01, help='the weight of pts aux loss')
65
+ parser.add_argument('--loss_w_smt', type=float, default=10., help='the weight of smooth loss')
66
+
67
+ return parser
requirements.txt CHANGED
@@ -1,120 +1,120 @@
1
- aiofiles==23.2.1
2
- altair==5.3.0
3
- annotated-types==0.7.0
4
- anyio==4.4.0
5
- attrs==23.2.0
6
- blinker==1.8.2
7
- build==1.2.1
8
- cachetools==5.3.3
9
- cairocffi==1.7.0
10
- CairoSVG==2.7.1
11
- certifi==2024.2.2
12
- cffi==1.16.0
13
- charset-normalizer==3.3.2
14
- click==8.1.7
15
- contourpy==1.2.1
16
- cssselect2==0.7.0
17
- cycler==0.12.1
18
- defusedxml==0.7.1
19
- dnspython==2.6.1
20
- docker-pycreds==0.4.0
21
- einops==0.8.0
22
- email_validator==2.1.1
23
- exceptiongroup==1.2.1
24
- fastapi==0.111.0
25
- fastapi-cli==0.0.4
26
- ffmpy==0.3.2
27
- filelock==3.14.0
28
- fonttools==4.52.4
29
- fsspec==2024.5.0
30
- gitdb==4.0.11
31
- GitPython==3.1.43
32
- Glances==4.0.7
33
- gradio==4.31.5
34
- gradio_client==0.16.4
35
- h11==0.14.0
36
- httpcore==1.0.5
37
- httptools==0.6.1
38
- httpx==0.27.0
39
- huggingface-hub==0.23.2
40
- idna==3.7
41
- imageio==2.34.1
42
- importlib_resources==6.4.0
43
- Jinja2==3.1.4
44
- jsonschema==4.22.0
45
- jsonschema-specifications==2023.12.1
46
- kiwisolver==1.4.5
47
- lazy_loader==0.4
48
- markdown-it-py==3.0.0
49
- MarkupSafe==2.1.5
50
- matplotlib==3.9.0
51
- mdurl==0.1.2
52
- networkx==3.3
53
- numpy==1.26.4
54
- orjson==3.10.3
55
- packaging==24.0
56
- pandas==2.2.2
57
- Pillow==9.5.0
58
- pip-tools==7.4.1
59
- platformdirs==4.2.2
60
- protobuf==4.25.3
61
- psutil==5.9.8
62
- py3nvml==0.2.7
63
- pyarrow==16.1.0
64
- pycparser==2.22
65
- pydantic==2.7.2
66
- pydantic_core==2.18.3
67
- pydeck==0.9.1
68
- pydub==0.25.1
69
- Pygments==2.18.0
70
- pyparsing==3.1.2
71
- pyproject_hooks==1.1.0
72
- pythainlp==5.0.3
73
- python-dateutil==2.9.0.post0
74
- python-dotenv==1.0.1
75
- python-multipart==0.0.9
76
- pytz==2024.1
77
- PyYAML==6.0.1
78
- referencing==0.35.1
79
- requests==2.32.2
80
- rich==13.7.1
81
- rpds-py==0.18.1
82
- ruff==0.4.6
83
- safetensors==0.4.3
84
- scikit-image==0.23.2
85
- scipy==1.13.1
86
- semantic-version==2.10.0
87
- sentry-sdk==2.3.1
88
- setproctitle==1.3.3
89
- shellingham==1.5.4
90
- six==1.16.0
91
- smmap==5.0.1
92
- sniffio==1.3.1
93
- starlette==0.37.2
94
- streamlit==1.35.0
95
- tenacity==8.3.0
96
- tifffile==2024.5.22
97
- timm==1.0.3
98
- tinycss2==1.3.0
99
- toml==0.10.2
100
- tomli==2.0.1
101
- tomlkit==0.12.0
102
- toolz==0.12.1
103
- torch==1.13.1
104
- torchaudio==0.13.1
105
- torchvision==0.14.1
106
- tornado==6.4
107
- tqdm==4.66.4
108
- typer==0.12.3
109
- typing_extensions==4.12.0
110
- tzdata==2024.1
111
- ujson==5.10.0
112
- urllib3==2.2.1
113
- uvicorn==0.30.0
114
- uvloop==0.19.0
115
- wandb==0.17.0
116
- watchdog==4.0.1
117
- watchfiles==0.22.0
118
- webencodings==0.5.1
119
- websockets==11.0.3
120
- xmltodict==0.13.0
 
1
+ aiofiles==23.2.1
2
+ altair==5.3.0
3
+ annotated-types==0.7.0
4
+ anyio==4.4.0
5
+ attrs==23.2.0
6
+ blinker==1.8.2
7
+ build==1.2.1
8
+ cachetools==5.3.3
9
+ cairocffi==1.7.0
10
+ CairoSVG==2.7.1
11
+ certifi==2024.2.2
12
+ cffi==1.16.0
13
+ charset-normalizer==3.3.2
14
+ click==8.1.7
15
+ contourpy==1.2.1
16
+ cssselect2==0.7.0
17
+ cycler==0.12.1
18
+ defusedxml==0.7.1
19
+ dnspython==2.6.1
20
+ docker-pycreds==0.4.0
21
+ einops==0.8.0
22
+ email_validator==2.1.1
23
+ exceptiongroup==1.2.1
24
+ fastapi==0.111.0
25
+ fastapi-cli==0.0.4
26
+ ffmpy==0.3.2
27
+ filelock==3.14.0
28
+ fonttools==4.52.4
29
+ fsspec==2024.5.0
30
+ gitdb==4.0.11
31
+ GitPython==3.1.43
32
+ Glances==4.0.7
33
+ gradio==4.31.5
34
+ gradio_client==0.16.4
35
+ h11==0.14.0
36
+ httpcore==1.0.5
37
+ httptools==0.6.1
38
+ httpx==0.27.0
39
+ huggingface-hub==0.23.2
40
+ idna==3.7
41
+ imageio==2.34.1
42
+ importlib_resources==6.4.0
43
+ Jinja2==3.1.4
44
+ jsonschema==4.22.0
45
+ jsonschema-specifications==2023.12.1
46
+ kiwisolver==1.4.5
47
+ lazy_loader==0.4
48
+ markdown-it-py==3.0.0
49
+ MarkupSafe==2.1.5
50
+ matplotlib==3.9.0
51
+ mdurl==0.1.2
52
+ networkx==3.3
53
+ numpy==1.26.4
54
+ orjson==3.10.3
55
+ packaging==24.0
56
+ pandas==2.2.2
57
+ Pillow==9.5.0
58
+ pip-tools==7.4.1
59
+ platformdirs==4.2.2
60
+ protobuf==4.25.3
61
+ psutil==5.9.8
62
+ py3nvml==0.2.7
63
+ pyarrow==16.1.0
64
+ pycparser==2.22
65
+ pydantic==2.7.2
66
+ pydantic_core==2.18.3
67
+ pydeck==0.9.1
68
+ pydub==0.25.1
69
+ Pygments==2.18.0
70
+ pyparsing==3.1.2
71
+ pyproject_hooks==1.1.0
72
+ pythainlp==5.0.3
73
+ python-dateutil==2.9.0.post0
74
+ python-dotenv==1.0.1
75
+ python-multipart==0.0.9
76
+ pytz==2024.1
77
+ PyYAML==6.0.1
78
+ referencing==0.35.1
79
+ requests==2.32.2
80
+ rich==13.7.1
81
+ rpds-py==0.18.1
82
+ ruff==0.4.6
83
+ safetensors==0.4.3
84
+ scikit-image==0.23.2
85
+ scipy==1.13.1
86
+ semantic-version==2.10.0
87
+ sentry-sdk==2.3.1
88
+ setproctitle==1.3.3
89
+ shellingham==1.5.4
90
+ six==1.16.0
91
+ smmap==5.0.1
92
+ sniffio==1.3.1
93
+ starlette==0.37.2
94
+ streamlit==1.35.0
95
+ tenacity==8.3.0
96
+ tifffile==2024.5.22
97
+ timm==1.0.3
98
+ tinycss2==1.3.0
99
+ toml==0.10.2
100
+ tomli==2.0.1
101
+ tomlkit==0.12.0
102
+ toolz==0.12.1
103
+ torch==1.13.1
104
+ torchaudio==0.13.1
105
+ torchvision==0.14.1
106
+ tornado==6.4
107
+ tqdm==4.66.4
108
+ typer==0.12.3
109
+ typing_extensions==4.12.0
110
+ tzdata==2024.1
111
+ ujson==5.10.0
112
+ urllib3==2.2.1
113
+ uvicorn==0.30.0
114
+ uvloop==0.19.0
115
+ wandb==0.17.0
116
+ watchdog==4.0.1
117
+ watchfiles==0.22.0
118
+ webencodings==0.5.1
119
+ websockets==11.0.3
120
+ xmltodict==0.13.0
test.py CHANGED
@@ -1,60 +1,60 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import torch.nn.functional as F
5
- from torchvision.utils import save_image
6
- from dataloader import get_loader
7
- from models.model_main import ModelMain
8
- from models.transformers import denumericalize
9
- from options import get_parser_main_model
10
- from data_utils.svg_utils import render
11
- from models.util_funcs import svg2img, cal_iou
12
-
13
- # Testing (Only accuracy)
14
-
15
- def test_main_model(opts):
16
- test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
17
-
18
- model_main = ModelMain(opts)
19
- path_ckpt = os.path.join(f"{opts.model_path}")
20
- model_main.load_state_dict(torch.load(path_ckpt)['model'])
21
- model_main.cuda()
22
- model_main.eval() # Testing mode
23
-
24
- with torch.no_grad():
25
- loss_val = {'img':{'l1':0.0, 'vggpt':0.0}, 'svg':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0},
26
- 'svg_para':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0}}
27
-
28
- for val_idx, val_data in enumerate(test_loader):
29
- for key in val_data: val_data[key] = val_data[key].cuda()
30
- ret_dict_val, loss_dict_val = model_main(val_data, mode='val')
31
- for loss_cat in ['img', 'svg']:
32
- for key, _ in loss_val[loss_cat].items():
33
- loss_val[loss_cat][key] += loss_dict_val[loss_cat][key]
34
-
35
- for loss_cat in ['img', 'svg']:
36
- for key, _ in loss_val[loss_cat].items():
37
- loss_val[loss_cat][key] /= len(test_loader)
38
-
39
- val_msg = (
40
- f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
41
- f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
42
- f"Val loss total: {loss_val['svg']['total']: .6f}, "
43
- f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
44
- f"Val loss args: {loss_val['svg']['args']: .6f}, "
45
- )
46
-
47
- print(val_msg)
48
- print(f"l1: {loss_val['img']['l1']: .6f}, pt: {loss_val['img']['vggpt']: .6f}")
49
-
50
- def main():
51
-
52
- opts = get_parser_main_model().parse_args()
53
- opts.name_exp = opts.name_exp + '_' + opts.model_name
54
- experiment_dir = os.path.join(f"{opts.exp_path}","experiments", opts.name_exp)
55
- print(f"Testing on experiment {opts.name_exp}...")
56
- # Dump options
57
- test_main_model(opts)
58
-
59
- if __name__ == "__main__":
60
  main()
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.utils import save_image
6
+ from dataloader import get_loader
7
+ from models.model_main import ModelMain
8
+ from models.transformers import denumericalize
9
+ from options import get_parser_main_model
10
+ from data_utils.svg_utils import render
11
+ from models.util_funcs import svg2img, cal_iou
12
+
13
+ # Testing (Only accuracy)
14
+
15
+ def test_main_model(opts):
16
+ test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
17
+
18
+ model_main = ModelMain(opts)
19
+ path_ckpt = os.path.join(f"{opts.model_path}")
20
+ model_main.load_state_dict(torch.load(path_ckpt)['model'])
21
+ model_main.cuda()
22
+ model_main.eval() # Testing mode
23
+
24
+ with torch.no_grad():
25
+ loss_val = {'img':{'l1':0.0, 'vggpt':0.0}, 'svg':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0},
26
+ 'svg_para':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0}}
27
+
28
+ for val_idx, val_data in enumerate(test_loader):
29
+ for key in val_data: val_data[key] = val_data[key].cuda()
30
+ ret_dict_val, loss_dict_val = model_main(val_data, mode='val')
31
+ for loss_cat in ['img', 'svg']:
32
+ for key, _ in loss_val[loss_cat].items():
33
+ loss_val[loss_cat][key] += loss_dict_val[loss_cat][key]
34
+
35
+ for loss_cat in ['img', 'svg']:
36
+ for key, _ in loss_val[loss_cat].items():
37
+ loss_val[loss_cat][key] /= len(test_loader)
38
+
39
+ val_msg = (
40
+ f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
41
+ f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
42
+ f"Val loss total: {loss_val['svg']['total']: .6f}, "
43
+ f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
44
+ f"Val loss args: {loss_val['svg']['args']: .6f}, "
45
+ )
46
+
47
+ print(val_msg)
48
+ print(f"l1: {loss_val['img']['l1']: .6f}, pt: {loss_val['img']['vggpt']: .6f}")
49
+
50
+ def main():
51
+
52
+ opts = get_parser_main_model().parse_args()
53
+ opts.name_exp = opts.name_exp + '_' + opts.model_name
54
+ experiment_dir = os.path.join(f"{opts.exp_path}","experiments", opts.name_exp)
55
+ print(f"Testing on experiment {opts.name_exp}...")
56
+ # Dump options
57
+ test_main_model(opts)
58
+
59
+ if __name__ == "__main__":
60
  main()
test_few_shot.py CHANGED
@@ -1,164 +1,164 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import torch.nn.functional as F
5
- from torchvision.utils import save_image
6
- from dataloader import get_loader
7
- from models.model_main import ModelMain
8
- from models.transformers import denumericalize
9
- from options import get_parser_main_model
10
- from data_utils.svg_utils import render
11
- from models.util_funcs import svg2img, cal_iou
12
- from tqdm import tqdm
13
- from PIL import Image
14
-
15
- def test_main_model(opts):
16
- if opts.streamlit:
17
- import streamlit as st
18
-
19
- if opts.dir_res:
20
- os.mkdir(os.path.join(opts.dir_res, "results"))
21
- dir_res = os.path.join(opts.dir_res, "results")
22
- else:
23
- dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
24
-
25
- test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
26
- if torch.cuda.is_available():
27
- device = torch.device("cuda")
28
- else:
29
- device = torch.device("cpu")
30
- if opts.streamlit:
31
- st.write("Loading Model Weight...")
32
- model_main = ModelMain(opts)
33
- path_ckpt = os.path.join(f"{opts.model_path}")
34
- model_main.load_state_dict(torch.load(path_ckpt)['model'])
35
- model_main.to(device)
36
- model_main.eval()
37
- with torch.no_grad():
38
-
39
- for test_idx, test_data in enumerate(test_loader):
40
- for key in test_data: test_data[key] = test_data[key].to(device)
41
-
42
- print("testing font %04d ..."%test_idx)
43
-
44
- dir_save = os.path.join(dir_res, "%04d"%test_idx)
45
- if not os.path.exists(dir_save):
46
- os.mkdir(dir_save)
47
- os.mkdir(os.path.join(dir_save, "imgs"))
48
- os.mkdir(os.path.join(dir_save, "svgs_single"))
49
- os.mkdir(os.path.join(dir_save, "svgs_merge"))
50
- svg_merge_dir = os.path.join(dir_save, "svgs_merge")
51
-
52
- iou_max = np.zeros(opts.char_num)
53
- idx_best_sample = np.zeros(opts.char_num)
54
-
55
- # syn_svg_merge_f = open(os.path.join(svg_merge_dir, f"{opts.name_ckpt}_syn_merge_{test_idx}_rand_{sample_idx}.html"), 'w')
56
- syn_svg_merge_f = open(os.path.join(svg_merge_dir, f"{opts.name_ckpt}_syn_merge_{test_idx}.html"), 'w')
57
-
58
- for sample_idx in tqdm(range(opts.n_samples)):
59
-
60
- ret_dict_test, loss_dict_test = model_main(test_data, mode='test')
61
-
62
- svg_sampled = ret_dict_test['svg']['sampled_1']
63
- sampled_svg_2 = ret_dict_test['svg']['sampled_2']
64
-
65
- img_trg = ret_dict_test['img']['trg']
66
- img_output = ret_dict_test['img']['out']
67
- trg_seq_gt = ret_dict_test['svg']['trg']
68
-
69
- img_sample_merge = torch.cat((img_trg.data, img_output.data), -2)
70
- save_file_merge = os.path.join(dir_save, "imgs", f"merge_{opts.img_size}.png")
71
- save_image(img_sample_merge, save_file_merge, nrow=8, normalize=True)
72
- if opts.streamlit:
73
- st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
74
- im = Image.open(save_file_merge)
75
- st.image(im, caption='img_sample_merge')
76
-
77
- for char_idx in range(opts.char_num):
78
- img_gt = (1.0 - img_trg[char_idx,...]).data
79
- save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
80
- save_image(img_gt, save_file_gt, normalize=True)
81
-
82
- img_sample = (1.0 - img_output[char_idx,...]).data
83
- save_file = os.path.join(dir_save,"imgs", f"{char_idx:02d}_{opts.img_size}.png")
84
- save_image(img_sample, save_file, normalize=True)
85
-
86
- # write results w/o parallel refinement
87
- svg_dec_out = svg_sampled.clone().detach()
88
- for i, one_seq in enumerate(svg_dec_out):
89
- syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
90
-
91
- syn_svg_f_ = open(syn_svg_outfile, 'w')
92
- try:
93
- svg = render(one_seq.cpu().numpy())
94
- syn_svg_f_.write(svg)
95
- # syn_svg_merge_f.write(svg)
96
- if i > 0 and i % 13 == 12:
97
- syn_svg_f_.write('<br>')
98
- # syn_svg_merge_f.write('<br>')
99
-
100
- except:
101
- continue
102
- syn_svg_f_.close()
103
-
104
- # write results w/ parallel refinement
105
- svg_dec_out = sampled_svg_2.clone().detach()
106
- for i, one_seq in enumerate(svg_dec_out):
107
- syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
108
-
109
- syn_svg_f = open(syn_svg_outfile, 'w')
110
- try:
111
- svg = render(one_seq.cpu().numpy())
112
- syn_svg_f.write(svg)
113
- #syn_svg_merge_f.write(svg)
114
-
115
- #if i > 0 and i % 13 == 12:
116
- # syn_svg_merge_f.write('<br>')
117
- except:
118
- continue
119
- syn_svg_f.close()
120
- syn_img_outfile = syn_svg_outfile.replace('.svg', '.png')
121
- svg2img(syn_svg_outfile, syn_img_outfile, img_size=opts.img_size)
122
- iou_tmp, l1_tmp = cal_iou(syn_img_outfile, os.path.join(dir_save, "imgs", f"{i:02d}_{opts.img_size}.png"))
123
- iou_tmp = iou_tmp
124
- if iou_tmp > iou_max[i]:
125
- iou_max[i] = iou_tmp
126
- idx_best_sample[i] = sample_idx
127
-
128
- for i in range(opts.char_num):
129
- # print(idx_best_sample[i])
130
- syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
131
- syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())
132
- if i > 0 and i % 13 == 12:
133
- syn_svg_merge_f.write('<br>')
134
-
135
- svg_target = trg_seq_gt.clone().detach()
136
- tgt_commands_onehot = F.one_hot(svg_target[:, :, :1].long(), 4).squeeze()
137
- tgt_args_denum = denumericalize(svg_target[:, :, 1:])
138
- svg_target = torch.cat([tgt_commands_onehot, tgt_args_denum], dim=-1)
139
-
140
- for i, one_gt_seq in enumerate(svg_target):
141
- # gt_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"gt_{i:02d}.svg")
142
- # gt_svg_f = open(gt_svg_outfile, 'w')
143
- gt_svg = render(one_gt_seq.cpu().numpy())
144
- # gt_svg_f.write(gt_svg)
145
- syn_svg_merge_f.write(gt_svg)
146
- # gt_svg_f.close()
147
- if i > 0 and i % 13 == 12:
148
- syn_svg_merge_f.write('<br>')
149
-
150
- syn_svg_merge_f.close()
151
-
152
- return im
153
-
154
- def main():
155
-
156
- opts = get_parser_main_model().parse_args()
157
- opts.name_exp = opts.name_exp + '_' + opts.model_name
158
- experiment_dir = os.path.join(f"{opts.exp_path}","experiments", opts.name_exp)
159
- print(f"Testing on experiment {opts.name_exp}...")
160
- # Dump options
161
- test_main_model(opts)
162
-
163
- if __name__ == "__main__":
164
  main()
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.utils import save_image
6
+ from dataloader import get_loader
7
+ from models.model_main import ModelMain
8
+ from models.transformers import denumericalize
9
+ from options import get_parser_main_model
10
+ from data_utils.svg_utils import render
11
+ from models.util_funcs import svg2img, cal_iou
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+
15
+ def test_main_model(opts):
16
+ if opts.streamlit:
17
+ import streamlit as st
18
+
19
+ if opts.dir_res:
20
+ os.mkdir(os.path.join(opts.dir_res, "results"))
21
+ dir_res = os.path.join(opts.dir_res, "results")
22
+ else:
23
+ dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
24
+
25
+ test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
26
+ if torch.cuda.is_available():
27
+ device = torch.device("cuda")
28
+ else:
29
+ device = torch.device("cpu")
30
+ if opts.streamlit:
31
+ st.write("Loading Model Weight...")
32
+ model_main = ModelMain(opts)
33
+ path_ckpt = os.path.join(f"{opts.model_path}")
34
+ model_main.load_state_dict(torch.load(path_ckpt)['model'])
35
+ model_main.to(device)
36
+ model_main.eval()
37
+ with torch.no_grad():
38
+
39
+ for test_idx, test_data in enumerate(test_loader):
40
+ for key in test_data: test_data[key] = test_data[key].to(device)
41
+
42
+ print("testing font %04d ..."%test_idx)
43
+
44
+ dir_save = os.path.join(dir_res, "%04d"%test_idx)
45
+ if not os.path.exists(dir_save):
46
+ os.mkdir(dir_save)
47
+ os.mkdir(os.path.join(dir_save, "imgs"))
48
+ os.mkdir(os.path.join(dir_save, "svgs_single"))
49
+ os.mkdir(os.path.join(dir_save, "svgs_merge"))
50
+ svg_merge_dir = os.path.join(dir_save, "svgs_merge")
51
+
52
+ iou_max = np.zeros(opts.char_num)
53
+ idx_best_sample = np.zeros(opts.char_num)
54
+
55
+ # syn_svg_merge_f = open(os.path.join(svg_merge_dir, f"{opts.name_ckpt}_syn_merge_{test_idx}_rand_{sample_idx}.html"), 'w')
56
+ syn_svg_merge_f = open(os.path.join(svg_merge_dir, f"{opts.name_ckpt}_syn_merge_{test_idx}.html"), 'w')
57
+
58
+ for sample_idx in tqdm(range(opts.n_samples)):
59
+
60
+ ret_dict_test, loss_dict_test = model_main(test_data, mode='test')
61
+
62
+ svg_sampled = ret_dict_test['svg']['sampled_1']
63
+ sampled_svg_2 = ret_dict_test['svg']['sampled_2']
64
+
65
+ img_trg = ret_dict_test['img']['trg']
66
+ img_output = ret_dict_test['img']['out']
67
+ trg_seq_gt = ret_dict_test['svg']['trg']
68
+
69
+ img_sample_merge = torch.cat((img_trg.data, img_output.data), -2)
70
+ save_file_merge = os.path.join(dir_save, "imgs", f"merge_{opts.img_size}.png")
71
+ save_image(img_sample_merge, save_file_merge, nrow=8, normalize=True)
72
+ if opts.streamlit:
73
+ st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
74
+ im = Image.open(save_file_merge)
75
+ st.image(im, caption='img_sample_merge')
76
+
77
+ for char_idx in range(opts.char_num):
78
+ img_gt = (1.0 - img_trg[char_idx,...]).data
79
+ save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
80
+ save_image(img_gt, save_file_gt, normalize=True)
81
+
82
+ img_sample = (1.0 - img_output[char_idx,...]).data
83
+ save_file = os.path.join(dir_save,"imgs", f"{char_idx:02d}_{opts.img_size}.png")
84
+ save_image(img_sample, save_file, normalize=True)
85
+
86
+ # write results w/o parallel refinement
87
+ svg_dec_out = svg_sampled.clone().detach()
88
+ for i, one_seq in enumerate(svg_dec_out):
89
+ syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
90
+
91
+ syn_svg_f_ = open(syn_svg_outfile, 'w')
92
+ try:
93
+ svg = render(one_seq.cpu().numpy())
94
+ syn_svg_f_.write(svg)
95
+ # syn_svg_merge_f.write(svg)
96
+ if i > 0 and i % 13 == 12:
97
+ syn_svg_f_.write('<br>')
98
+ # syn_svg_merge_f.write('<br>')
99
+
100
+ except:
101
+ continue
102
+ syn_svg_f_.close()
103
+
104
+ # write results w/ parallel refinement
105
+ svg_dec_out = sampled_svg_2.clone().detach()
106
+ for i, one_seq in enumerate(svg_dec_out):
107
+ syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
108
+
109
+ syn_svg_f = open(syn_svg_outfile, 'w')
110
+ try:
111
+ svg = render(one_seq.cpu().numpy())
112
+ syn_svg_f.write(svg)
113
+ #syn_svg_merge_f.write(svg)
114
+
115
+ #if i > 0 and i % 13 == 12:
116
+ # syn_svg_merge_f.write('<br>')
117
+ except:
118
+ continue
119
+ syn_svg_f.close()
120
+ syn_img_outfile = syn_svg_outfile.replace('.svg', '.png')
121
+ svg2img(syn_svg_outfile, syn_img_outfile, img_size=opts.img_size)
122
+ iou_tmp, l1_tmp = cal_iou(syn_img_outfile, os.path.join(dir_save, "imgs", f"{i:02d}_{opts.img_size}.png"))
123
+ iou_tmp = iou_tmp
124
+ if iou_tmp > iou_max[i]:
125
+ iou_max[i] = iou_tmp
126
+ idx_best_sample[i] = sample_idx
127
+
128
+ for i in range(opts.char_num):
129
+ # print(idx_best_sample[i])
130
+ syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
131
+ syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())
132
+ if i > 0 and i % 13 == 12:
133
+ syn_svg_merge_f.write('<br>')
134
+
135
+ svg_target = trg_seq_gt.clone().detach()
136
+ tgt_commands_onehot = F.one_hot(svg_target[:, :, :1].long(), 4).squeeze()
137
+ tgt_args_denum = denumericalize(svg_target[:, :, 1:])
138
+ svg_target = torch.cat([tgt_commands_onehot, tgt_args_denum], dim=-1)
139
+
140
+ for i, one_gt_seq in enumerate(svg_target):
141
+ # gt_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"gt_{i:02d}.svg")
142
+ # gt_svg_f = open(gt_svg_outfile, 'w')
143
+ gt_svg = render(one_gt_seq.cpu().numpy())
144
+ # gt_svg_f.write(gt_svg)
145
+ syn_svg_merge_f.write(gt_svg)
146
+ # gt_svg_f.close()
147
+ if i > 0 and i % 13 == 12:
148
+ syn_svg_merge_f.write('<br>')
149
+
150
+ syn_svg_merge_f.close()
151
+
152
+ return im
153
+
154
+ def main():
155
+
156
+ opts = get_parser_main_model().parse_args()
157
+ opts.name_exp = opts.name_exp + '_' + opts.model_name
158
+ experiment_dir = os.path.join(f"{opts.exp_path}","experiments", opts.name_exp)
159
+ print(f"Testing on experiment {opts.name_exp}...")
160
+ # Dump options
161
+ test_main_model(opts)
162
+
163
+ if __name__ == "__main__":
164
  main()
train.py CHANGED
@@ -1,216 +1,216 @@
1
- import os
2
- import random
3
- import numpy as np
4
- import shutil
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch.optim import Adam, AdamW
9
- from torchvision.utils import save_image
10
- import wandb
11
- from dataloader import get_loader
12
- from models import util_funcs
13
- from models.model_main import ModelMain
14
- from options import get_parser_main_model
15
- from data_utils.svg_utils import render
16
- from time import time
17
-
18
- def setup_seed(seed):
19
- torch.manual_seed(seed)
20
- torch.cuda.manual_seed_all(seed)
21
- np.random.seed(seed)
22
- random.seed(seed)
23
- torch.backends.cudnn.deterministic = True
24
-
25
- def train_main_model(opts):
26
- setup_seed(opts.seed)
27
- dir_exp = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
28
- dir_sample = os.path.join(dir_exp, "samples")
29
- dir_ckpt = os.path.join(dir_exp, "checkpoints")
30
- dir_log = os.path.join(dir_exp, "logs")
31
- logfile_train = open(os.path.join(dir_log, "train_loss_log.txt"), 'w')
32
- logfile_val = open(os.path.join(dir_log, "val_loss_log.txt"), 'w')
33
-
34
- train_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, opts.mode)
35
- val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')
36
-
37
- run = wandb.init(project=opts.wandb_project_name, config=opts) # initialize wandb project
38
- text_table = wandb.Table(columns=["epoch", "loss", "ref"])
39
-
40
- model_main = ModelMain(opts)
41
- if torch.cuda.is_available() and opts.multi_gpu:
42
- model_main = torch.nn.DataParallel(model_main)
43
-
44
- if opts.continue_training:
45
- model_main.load_state_dict(torch.load(opts.continue_ckpt)['model'])
46
-
47
- model_main.cuda()
48
-
49
- parameters_all = [{"params": model_main.img_encoder.parameters()}, {"params": model_main.img_decoder.parameters()},
50
- {"params": model_main.modality_fusion.parameters()}, {"params": model_main.transformer_main.parameters()},
51
- {"params": model_main.transformer_seqdec.parameters()}]
52
-
53
- optimizer = AdamW(parameters_all, lr=opts.lr, betas=(opts.beta1, opts.beta2), eps=opts.eps, weight_decay=opts.weight_decay)
54
- scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
55
-
56
- for epoch in range(opts.init_epoch, opts.n_epochs):
57
- t0 = time()
58
- for idx, data in enumerate(train_loader):
59
- for key in data: data[key] = data[key].cuda()
60
- ret_dict, loss_dict = model_main(data)
61
-
62
- loss = opts.loss_w_l1 * loss_dict['img']['l1'] + opts.loss_w_pt_c * loss_dict['img']['vggpt'] + opts.kl_beta * loss_dict['kl'] \
63
- + loss_dict['svg']['total'] + loss_dict['svg_para']['total']
64
-
65
- # perform optimization
66
- optimizer.zero_grad()
67
- loss.backward()
68
- optimizer.step()
69
- batches_done = epoch * len(train_loader) + idx + 1
70
- message = (
71
- f"Time: {'{} seconds'.format(time() - t0)}, "
72
- f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
73
- f"Loss: {loss.item():.6f}, "
74
- f"img_l1_loss: {opts.loss_w_l1 * loss_dict['img']['l1'].item():.6f}, "
75
- f"img_pt_c_loss: {opts.loss_w_pt_c * loss_dict['img']['vggpt']:.6f}, "
76
- f"svg_total_loss: {loss_dict['svg']['total'].item():.6f}, "
77
- f"svg_cmd_loss: {opts.loss_w_cmd * loss_dict['svg']['cmd'].item():.6f}, "
78
- f"svg_args_loss: {opts.loss_w_args * loss_dict['svg']['args'].item():.6f}, "
79
- f"svg_smooth_loss: {opts.loss_w_smt * loss_dict['svg']['smt'].item():.6f}, "
80
- f"svg_aux_loss: {opts.loss_w_aux * loss_dict['svg']['aux'].item():.6f}, "
81
- f"lr: {optimizer.param_groups[0]['lr']:.6f}, "
82
- f"Step: {batches_done}"
83
- )
84
- if batches_done % opts.freq_log == 0:
85
- logfile_train.write(message + '\n')
86
- print(message)
87
-
88
- if opts.wandb:
89
- # print("Running With Wandb")
90
- # Define the items for image and SVG losses
91
- loss_img_items = ['l1', 'vggpt']
92
- loss_svg_items = ['total', 'cmd', 'args', 'aux', 'smt']
93
-
94
- # Log image loss items
95
- for item in loss_img_items:
96
- wandb.log({f'Loss/img_{item}': loss_dict['img'][item].item()}, step=batches_done)
97
-
98
- # Log SVG loss items
99
- for item in loss_svg_items:
100
- wandb.log({f'Loss/svg_{item}': loss_dict['svg'][item].item()}, step=batches_done)
101
- wandb.log({f'Loss/svg_para_{item}': loss_dict['svg_para'][item].item()}, step=batches_done)
102
-
103
- # Log KL loss
104
- wandb.log({'Loss/img_kl_loss': opts.kl_beta * loss_dict['kl'].item()}, step=batches_done)
105
-
106
- wandb.log({
107
- 'Images/trg_img': wandb.Image(ret_dict['img']['trg'][0], caption="Target"),
108
- 'Images/img_output': wandb.Image(ret_dict['img']['out'][0], caption="Output")
109
- }, step=batches_done)
110
-
111
- text_table.add_data(epoch, loss, str(ret_dict['img']['ref'][0]))
112
- wandb.log({"training_samples" : text_table})
113
-
114
-
115
-
116
- if opts.freq_sample > 0 and batches_done % opts.freq_sample == 0:
117
-
118
- img_sample = torch.cat((ret_dict['img']['trg'].data, ret_dict['img']['out'].data), -2)
119
- save_file = os.path.join(dir_sample, f"train_epoch_{epoch}_batch_{batches_done}.png")
120
- save_image(img_sample, save_file, nrow=8, normalize=True)
121
-
122
- if opts.freq_val > 0 and batches_done % opts.freq_val == 0:
123
-
124
- with torch.no_grad():
125
- model_main.eval()
126
- loss_val = {'img':{'l1':0.0, 'vggpt':0.0}, 'svg':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0},
127
- 'svg_para':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0}}
128
-
129
- for val_idx, val_data in enumerate(val_loader):
130
- for key in val_data: val_data[key] = val_data[key].cuda()
131
- ret_dict_val, loss_dict_val = model_main(val_data, mode='val')
132
- for loss_cat in ['img', 'svg']:
133
- for key, _ in loss_val[loss_cat].items():
134
- loss_val[loss_cat][key] += loss_dict_val[loss_cat][key]
135
-
136
- for loss_cat in ['img', 'svg']:
137
- for key, _ in loss_val[loss_cat].items():
138
- loss_val[loss_cat][key] /= len(val_loader)
139
-
140
- if opts.wandb:
141
- for loss_cat in ['img', 'svg']:
142
- # Iterate over keys and values in the loss dictionary
143
- for key, value in loss_val[loss_cat].items():
144
- # Log loss value to WandB
145
- wandb.log({f'VAL/loss_{loss_cat}_{key}': value})
146
-
147
- val_msg = (
148
- f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
149
- f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
150
- f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
151
- f"Val loss total: {loss_val['svg']['total']: .6f}, "
152
- f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
153
- f"Val loss args: {loss_val['svg']['args']: .6f}, "
154
- )
155
-
156
- logfile_val.write(val_msg + "\n")
157
- print(val_msg)
158
-
159
-
160
- scheduler.step()
161
-
162
- if epoch % opts.freq_ckpt == 0 and epoch >= opts.threshold_ckpt:
163
- if opts.multi_gpu:
164
- print(f"Saved {dir_ckpt}/{epoch}_{batches_done}.ckpt")
165
- torch.save({'model':model_main.module.state_dict(), 'opt':optimizer.state_dict(), 'n_epoch':epoch, 'n_iter':batches_done}, f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
166
- else:
167
- print(f"Saved {dir_ckpt}/{epoch}_{batches_done}.ckpt")
168
- torch.save({'model':model_main.state_dict(), 'opt':optimizer.state_dict(), 'n_epoch':epoch, 'n_iter':batches_done}, f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
169
- if opts.wandb:
170
- artifact = wandb.Artifact('model_main_checkpoints', type='model')
171
- artifact.add_file(f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
172
- run.log_artifact(artifact)
173
-
174
- logfile_train.close()
175
- logfile_val.close()
176
-
177
- def backup_code(name_exp, exp_path):
178
- os.makedirs(os.path.join(exp_path,'experiments', name_exp, 'code'), exist_ok=True)
179
- shutil.copy('models/transformers.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'transformers.py') )
180
- shutil.copy('models/model_main.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'model_main.py'))
181
- shutil.copy('models/image_encoder.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'image_encoder.py'))
182
- shutil.copy('models/image_decoder.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'image_decoder.py'))
183
- shutil.copy('./train.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'train.py'))
184
- shutil.copy('./options.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'options.py'))
185
-
186
- def train(opts):
187
- if opts.model_name == 'main_model':
188
- train_main_model(opts)
189
- elif opts.model_name == 'others':
190
- train_others(opts)
191
- else:
192
- raise NotImplementedError
193
-
194
- def main():
195
-
196
- opts = get_parser_main_model().parse_args()
197
- opts.name_exp = opts.name_exp + '_' + opts.model_name
198
- os.makedirs(f"{opts.exp_path}/experiments", exist_ok=True)
199
- debug = True
200
- # Create directories
201
- experiment_dir = os.path.join(f"{opts.exp_path}","experiments", opts.name_exp)
202
- backup_code(opts.name_exp, opts.exp_path)
203
- os.makedirs(experiment_dir, exist_ok=debug) # False to prevent multiple train run by mistake
204
- os.makedirs(os.path.join(experiment_dir, "samples"), exist_ok=True)
205
- os.makedirs(os.path.join(experiment_dir, "checkpoints"), exist_ok=True)
206
- os.makedirs(os.path.join(experiment_dir, "results"), exist_ok=True)
207
- os.makedirs(os.path.join(experiment_dir, "logs"), exist_ok=True)
208
- print(f"Training on experiment {opts.name_exp}...")
209
- # Dump options
210
- with open(os.path.join(experiment_dir, "opts.txt"), "w") as f:
211
- for key, value in vars(opts).items():
212
- f.write(str(key) + ": " + str(value) + "\n")
213
- train(opts)
214
-
215
- if __name__ == "__main__":
216
- main()
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import shutil
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.optim import Adam, AdamW
9
+ from torchvision.utils import save_image
10
+ import wandb
11
+ from dataloader import get_loader
12
+ from models import util_funcs
13
+ from models.model_main import ModelMain
14
+ from options import get_parser_main_model
15
+ from data_utils.svg_utils import render
16
+ from time import time
17
+
18
+ def setup_seed(seed):
19
+ torch.manual_seed(seed)
20
+ torch.cuda.manual_seed_all(seed)
21
+ np.random.seed(seed)
22
+ random.seed(seed)
23
+ torch.backends.cudnn.deterministic = True
24
+
25
+ def train_main_model(opts):
26
+ setup_seed(opts.seed)
27
+ dir_exp = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
28
+ dir_sample = os.path.join(dir_exp, "samples")
29
+ dir_ckpt = os.path.join(dir_exp, "checkpoints")
30
+ dir_log = os.path.join(dir_exp, "logs")
31
+ logfile_train = open(os.path.join(dir_log, "train_loss_log.txt"), 'w')
32
+ logfile_val = open(os.path.join(dir_log, "val_loss_log.txt"), 'w')
33
+
34
+ train_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, opts.mode)
35
+ val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')
36
+
37
+ run = wandb.init(project=opts.wandb_project_name, config=opts) # initialize wandb project
38
+ text_table = wandb.Table(columns=["epoch", "loss", "ref"])
39
+
40
+ model_main = ModelMain(opts)
41
+ if torch.cuda.is_available() and opts.multi_gpu:
42
+ model_main = torch.nn.DataParallel(model_main)
43
+
44
+ if opts.continue_training:
45
+ model_main.load_state_dict(torch.load(opts.continue_ckpt)['model'])
46
+
47
+ model_main.cuda()
48
+
49
+ parameters_all = [{"params": model_main.img_encoder.parameters()}, {"params": model_main.img_decoder.parameters()},
50
+ {"params": model_main.modality_fusion.parameters()}, {"params": model_main.transformer_main.parameters()},
51
+ {"params": model_main.transformer_seqdec.parameters()}]
52
+
53
+ optimizer = AdamW(parameters_all, lr=opts.lr, betas=(opts.beta1, opts.beta2), eps=opts.eps, weight_decay=opts.weight_decay)
54
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
55
+
56
+ for epoch in range(opts.init_epoch, opts.n_epochs):
57
+ t0 = time()
58
+ for idx, data in enumerate(train_loader):
59
+ for key in data: data[key] = data[key].cuda()
60
+ ret_dict, loss_dict = model_main(data)
61
+
62
+ loss = opts.loss_w_l1 * loss_dict['img']['l1'] + opts.loss_w_pt_c * loss_dict['img']['vggpt'] + opts.kl_beta * loss_dict['kl'] \
63
+ + loss_dict['svg']['total'] + loss_dict['svg_para']['total']
64
+
65
+ # perform optimization
66
+ optimizer.zero_grad()
67
+ loss.backward()
68
+ optimizer.step()
69
+ batches_done = epoch * len(train_loader) + idx + 1
70
+ message = (
71
+ f"Time: {'{} seconds'.format(time() - t0)}, "
72
+ f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
73
+ f"Loss: {loss.item():.6f}, "
74
+ f"img_l1_loss: {opts.loss_w_l1 * loss_dict['img']['l1'].item():.6f}, "
75
+ f"img_pt_c_loss: {opts.loss_w_pt_c * loss_dict['img']['vggpt']:.6f}, "
76
+ f"svg_total_loss: {loss_dict['svg']['total'].item():.6f}, "
77
+ f"svg_cmd_loss: {opts.loss_w_cmd * loss_dict['svg']['cmd'].item():.6f}, "
78
+ f"svg_args_loss: {opts.loss_w_args * loss_dict['svg']['args'].item():.6f}, "
79
+ f"svg_smooth_loss: {opts.loss_w_smt * loss_dict['svg']['smt'].item():.6f}, "
80
+ f"svg_aux_loss: {opts.loss_w_aux * loss_dict['svg']['aux'].item():.6f}, "
81
+ f"lr: {optimizer.param_groups[0]['lr']:.6f}, "
82
+ f"Step: {batches_done}"
83
+ )
84
+ if batches_done % opts.freq_log == 0:
85
+ logfile_train.write(message + '\n')
86
+ print(message)
87
+
88
+ if opts.wandb:
89
+ # print("Running With Wandb")
90
+ # Define the items for image and SVG losses
91
+ loss_img_items = ['l1', 'vggpt']
92
+ loss_svg_items = ['total', 'cmd', 'args', 'aux', 'smt']
93
+
94
+ # Log image loss items
95
+ for item in loss_img_items:
96
+ wandb.log({f'Loss/img_{item}': loss_dict['img'][item].item()}, step=batches_done)
97
+
98
+ # Log SVG loss items
99
+ for item in loss_svg_items:
100
+ wandb.log({f'Loss/svg_{item}': loss_dict['svg'][item].item()}, step=batches_done)
101
+ wandb.log({f'Loss/svg_para_{item}': loss_dict['svg_para'][item].item()}, step=batches_done)
102
+
103
+ # Log KL loss
104
+ wandb.log({'Loss/img_kl_loss': opts.kl_beta * loss_dict['kl'].item()}, step=batches_done)
105
+
106
+ wandb.log({
107
+ 'Images/trg_img': wandb.Image(ret_dict['img']['trg'][0], caption="Target"),
108
+ 'Images/img_output': wandb.Image(ret_dict['img']['out'][0], caption="Output")
109
+ }, step=batches_done)
110
+
111
+ text_table.add_data(epoch, loss, str(ret_dict['img']['ref'][0]))
112
+ wandb.log({"training_samples" : text_table})
113
+
114
+
115
+
116
+ if opts.freq_sample > 0 and batches_done % opts.freq_sample == 0:
117
+
118
+ img_sample = torch.cat((ret_dict['img']['trg'].data, ret_dict['img']['out'].data), -2)
119
+ save_file = os.path.join(dir_sample, f"train_epoch_{epoch}_batch_{batches_done}.png")
120
+ save_image(img_sample, save_file, nrow=8, normalize=True)
121
+
122
+ if opts.freq_val > 0 and batches_done % opts.freq_val == 0:
123
+
124
+ with torch.no_grad():
125
+ model_main.eval()
126
+ loss_val = {'img':{'l1':0.0, 'vggpt':0.0}, 'svg':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0},
127
+ 'svg_para':{'total':0.0, 'cmd':0.0, 'args':0.0, 'aux':0.0}}
128
+
129
+ for val_idx, val_data in enumerate(val_loader):
130
+ for key in val_data: val_data[key] = val_data[key].cuda()
131
+ ret_dict_val, loss_dict_val = model_main(val_data, mode='val')
132
+ for loss_cat in ['img', 'svg']:
133
+ for key, _ in loss_val[loss_cat].items():
134
+ loss_val[loss_cat][key] += loss_dict_val[loss_cat][key]
135
+
136
+ for loss_cat in ['img', 'svg']:
137
+ for key, _ in loss_val[loss_cat].items():
138
+ loss_val[loss_cat][key] /= len(val_loader)
139
+
140
+ if opts.wandb:
141
+ for loss_cat in ['img', 'svg']:
142
+ # Iterate over keys and values in the loss dictionary
143
+ for key, value in loss_val[loss_cat].items():
144
+ # Log loss value to WandB
145
+ wandb.log({f'VAL/loss_{loss_cat}_{key}': value})
146
+
147
+ val_msg = (
148
+ f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
149
+ f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
150
+ f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
151
+ f"Val loss total: {loss_val['svg']['total']: .6f}, "
152
+ f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
153
+ f"Val loss args: {loss_val['svg']['args']: .6f}, "
154
+ )
155
+
156
+ logfile_val.write(val_msg + "\n")
157
+ print(val_msg)
158
+
159
+
160
+ scheduler.step()
161
+
162
+ if epoch % opts.freq_ckpt == 0 and epoch >= opts.threshold_ckpt:
163
+ if opts.multi_gpu:
164
+ print(f"Saved {dir_ckpt}/{epoch}_{batches_done}.ckpt")
165
+ torch.save({'model':model_main.module.state_dict(), 'opt':optimizer.state_dict(), 'n_epoch':epoch, 'n_iter':batches_done}, f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
166
+ else:
167
+ print(f"Saved {dir_ckpt}/{epoch}_{batches_done}.ckpt")
168
+ torch.save({'model':model_main.state_dict(), 'opt':optimizer.state_dict(), 'n_epoch':epoch, 'n_iter':batches_done}, f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
169
+ if opts.wandb:
170
+ artifact = wandb.Artifact('model_main_checkpoints', type='model')
171
+ artifact.add_file(f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
172
+ run.log_artifact(artifact)
173
+
174
+ logfile_train.close()
175
+ logfile_val.close()
176
+
177
+ def backup_code(name_exp, exp_path):
178
+ os.makedirs(os.path.join(exp_path,'experiments', name_exp, 'code'), exist_ok=True)
179
+ shutil.copy('models/transformers.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'transformers.py') )
180
+ shutil.copy('models/model_main.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'model_main.py'))
181
+ shutil.copy('models/image_encoder.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'image_encoder.py'))
182
+ shutil.copy('models/image_decoder.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'image_decoder.py'))
183
+ shutil.copy('./train.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'train.py'))
184
+ shutil.copy('./options.py', os.path.join(exp_path,'experiments', name_exp, 'code', 'options.py'))
185
+
186
+ def train(opts):
187
+ if opts.model_name == 'main_model':
188
+ train_main_model(opts)
189
+ elif opts.model_name == 'others':
190
+ train_others(opts)
191
+ else:
192
+ raise NotImplementedError
193
+
194
+ def main():
195
+
196
+ opts = get_parser_main_model().parse_args()
197
+ opts.name_exp = opts.name_exp + '_' + opts.model_name
198
+ os.makedirs(f"{opts.exp_path}/experiments", exist_ok=True)
199
+ debug = True
200
+ # Create directories
201
+ experiment_dir = os.path.join(f"{opts.exp_path}","experiments", opts.name_exp)
202
+ backup_code(opts.name_exp, opts.exp_path)
203
+ os.makedirs(experiment_dir, exist_ok=debug) # False to prevent multiple train run by mistake
204
+ os.makedirs(os.path.join(experiment_dir, "samples"), exist_ok=True)
205
+ os.makedirs(os.path.join(experiment_dir, "checkpoints"), exist_ok=True)
206
+ os.makedirs(os.path.join(experiment_dir, "results"), exist_ok=True)
207
+ os.makedirs(os.path.join(experiment_dir, "logs"), exist_ok=True)
208
+ print(f"Training on experiment {opts.name_exp}...")
209
+ # Dump options
210
+ with open(os.path.join(experiment_dir, "opts.txt"), "w") as f:
211
+ for key, value in vars(opts).items():
212
+ f.write(str(key) + ": " + str(value) + "\n")
213
+ train(opts)
214
+
215
+ if __name__ == "__main__":
216
+ main()