initial commit

- .gitignore +24 -0
- app.py +121 -0
- checkpoint/.gitkeep +0 -0
- config/datasets/googlefont.yaml +37 -0
- config/lightning.yaml +16 -0
- config/logging.yaml +25 -0
- config/models/google-font.yaml +43 -0
- config/setting-google-font.yaml +5 -0
- data/.gitkeep +0 -0
- datasets/__init__.py +2 -0
- datasets/googlefont.py +286 -0
- docs/ml-font-style-transfer.md +9 -0
- font_list.json +427 -0
- font_list_noto_sans.json +41 -0
- inference.py +163 -0
- lightning.py +313 -0
- models/__init__.py +2 -0
- models/decoder.py +45 -0
- models/discriminator.py +54 -0
- models/encoder.py +65 -0
- models/generator.py +22 -0
- models/loss.py +16 -0
- models/module.py +159 -0
- pretrained/.gitkeep +0 -0
- requirements.txt +6 -0
- trainer.py +88 -0
- trainer.sh +1 -0
- utils/__init__.py +3 -0
- utils/logger.py +39 -0
- utils/tb.py +25 -0
- utils/util.py +28 -0
.gitignore
ADDED
@@ -0,0 +1,24 @@
+# block temp directory
+.idea/
+__pycache__/
+.ipynb_checkpoints/
+.vscode/
+.temp/
+lightning_logs/
+
+# block extension
+*.pkl
+*.png
+*.pth
+*.json
+*.ckpt
+
+# block logging directory
+logs/
+wandb/
+
+# custom
+font-image
+
+!font_list.json
+!font_list_noto_sans.json

app.py
ADDED
@@ -0,0 +1,121 @@
+import os
+import argparse
+from pathlib import Path
+from typing import Optional, Union, Tuple, List
+import subprocess
+
+import gradio as gr
+from PIL import Image
+from omegaconf import OmegaConf, DictConfig
+
+from inference import InferenceServicer
+
+PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
+MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")
+
+MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)
+NOTO_SANS_ZIP_PATH = os.getenv("NOTO_SANS_ZIP_PATH", default=None)
+
+LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
+LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"
+
+if MODEL_CHECKPOINT_PATH is not None:
+    subprocess.call(f"wget --no-check-certificate -O {LOCAL_CHECKPOINT_PATH} {MODEL_CHECKPOINT_PATH}", shell=True)
+if NOTO_SANS_ZIP_PATH is not None:
+    subprocess.call(f"wget --no-check-certificate -O {LOCAL_NOTO_ZIP_PATH} {NOTO_SANS_ZIP_PATH}", shell=True)
+    subprocess.call(f"unzip {LOCAL_NOTO_ZIP_PATH} -d {str(Path(LOCAL_NOTO_ZIP_PATH).parent)}", shell=True)
+
+assert Path(LOCAL_CHECKPOINT_PATH).exists()
+assert Path("data/NotoSans").exists()
+
+EXAMPLE_FONTS = sorted([
+    "example_fonts/BalooDa2-Bold.ttf",
+    "example_fonts/BalooDa2-Regular.ttf",
+    "example_fonts/Lalezar-Regular.ttf",
+    "example_fonts/MaShanZheng-Regular.ttf",
+])
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Multilingual font style transfer demo")
+
+    # -------- User arguments ----------------------------------------
+    parser.add_argument(
+        '--docs', type=Path, default=PATH_DOCS,
+        help="Docs string file")
+    parser.add_argument(
+        '--config', type=Path, default=MODEL_CONFIG,
+        help="Config for model")
+    parser.add_argument(
+        '--local', action='store_true',
+        help="Whether to run in local environment or not")
+    parser.add_argument(
+        '--port', type=int, default=50003,
+        help="Service port (only applicable when running on local server)")
+    args, _ = parser.parse_known_args()
+    return args
+
+
+class InferenceServiceResolver(InferenceServicer):
+    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
+        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)
+
+    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
+        try:
+            content_image, style_images, result = self.inference(content_char=content_char, style_font=style_font)
+            return [content_image, *style_images, result]
+        except Exception as e:
+            raise gr.Error(str(e))
+
+
+def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path, is_local: bool, port: Optional[int] = None):
+    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)
+    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
+        gr.Markdown(docs_path.read_text())
+        with gr.Row(equal_height=True):
+            character_input = gr.Textbox(max_lines=1, value="7", info="Only a single character is acceptable (e.g. '간', '7', or 'ជ')")
+            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=EXAMPLE_FONTS[0])
+            run_button = gr.Button(value="Generate", variant='primary')
+
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=1):
+                with gr.Group():
+                    gr.Markdown("<center><h3>Content character</h3></center>")
+                    content_char = gr.Image(label="Content character", show_label=False)
+            with gr.Column(scale=5):
+                with gr.Group():
+                    gr.Markdown("<center><h3>Style font images</h3></center>")
+                    with gr.Row(equal_height=True):
+                        style_char_1 = gr.Image(label="Style #1", show_label=False)
+                        style_char_2 = gr.Image(label="Style #2", show_label=False)
+                        style_char_3 = gr.Image(label="Style #3", show_label=False)
+                        style_char_4 = gr.Image(label="Style #4", show_label=False)
+                        style_char_5 = gr.Image(label="Style #5", show_label=False)
+            with gr.Column(scale=1):
+                with gr.Group():
+                    gr.Markdown("<center><h3>Generated font image</h3></center>")
+                    generated_font = gr.Image(label="Generated font image", show_label=False)
+
+        outputs = [content_char, style_char_1, style_char_2, style_char_3, style_char_4, style_char_5, generated_font]
+        run_inputs = [character_input, style_font]
+        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)
+
+    if is_local:
+        demo.launch(server_name="0.0.0.0", server_port=port)
+    else:
+        demo.launch()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    hp = OmegaConf.load(args.config)
+    checkpoint_path = Path(LOCAL_CHECKPOINT_PATH)
+    content_image_dir = Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")
+
+    launch_gradio(args.docs, hp, checkpoint_path, content_image_dir, args.local, args.port)

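For reference, a local run of this app might look like the sketch below. The two URLs are placeholders for whatever checkpoint and NotoSans archive you host yourself; the app downloads and unpacks them on startup.

import os
import subprocess

# Hypothetical local launch; the URLs are placeholders, not real endpoints.
env = dict(os.environ,
           MODEL_CHECKPOINT_PATH="https://example.com/checkpoint.ckpt",
           NOTO_SANS_ZIP_PATH="https://example.com/NotoSans.zip")
subprocess.run(["python", "app.py", "--local", "--port", "50003"], env=env, check=True)
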
checkpoint/.gitkeep
ADDED
File without changes
config/datasets/googlefont.yaml
ADDED
@@ -0,0 +1,37 @@
+datasets:
+  type: GoogleFontDataset
+  train:
+    split: auto
+    font_dir: &font_dir ../DATA/fonts-image-20230929
+    imsize: 64
+    reference_imgs:
+      replace: False
+      char: &reference_char 1
+      style: &reference_style 5
+
+    squeeze_gray: &squeeze_gray True
+    transform:
+      # TODO
+
+    # loader configs
+    shuffle: True
+    batch_size: 64
+    num_workers: 12
+
+  eval:
+    split: auto
+    font_dir: *font_dir
+    imsize: 64
+    reference_imgs:
+      replace: False
+      char: *reference_char
+      style: *reference_style
+
+    squeeze_gray: *squeeze_gray
+    transform:
+      # TODO
+
+    # loader configs
+    shuffle: True
+    batch_size: 1
+    num_workers: 4

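The `&`/`*` pairs above are plain YAML anchors and aliases, so the `train` and `eval` splits share values at parse time. A minimal check, assuming the config is loaded with OmegaConf as the rest of the repo does:

from omegaconf import OmegaConf

cfg = OmegaConf.load("config/datasets/googlefont.yaml")
# Anchors are resolved by the YAML parser, so both splits see identical values.
assert cfg.datasets.train.font_dir == cfg.datasets.eval.font_dir
assert cfg.datasets.eval.reference_imgs.style == 5
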
config/lightning.yaml
ADDED
@@ -0,0 +1,16 @@
+pl_config:
+  checkpoint:
+    callback:
+      save_top_k: -1
+      verbose: True
+      every_n_epochs: 5  # epochs
+
+  trainer:
+    gradient_clip_val: 0
+    max_epochs: 2000
+    num_sanity_val_steps: 1
+    fast_dev_run: False
+    check_val_every_n_epoch: 5
+    # distributed_backend: 'ddp'
+    accelerator: 'cuda'
+    benchmark: True

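trainer.py is not shown in this excerpt, so the exact wiring is an assumption, but a config shaped like this maps directly onto Lightning's `ModelCheckpoint` and `Trainer` keyword arguments:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from omegaconf import OmegaConf

# Hedged sketch: pass the two sub-dicts straight through as kwargs.
cfg = OmegaConf.load("config/lightning.yaml").pl_config
checkpoint_callback = ModelCheckpoint(**cfg.checkpoint.callback)
trainer = pl.Trainer(callbacks=[checkpoint_callback], **cfg.trainer)
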
config/logging.yaml
ADDED
@@ -0,0 +1,25 @@
+logging:
+  dry_run: False
+  device: cuda
+  log_dir: /ssd1/hksong/LOG/font
+  seed: ftgan-patch-full
+  freq:
+    train: 100  # step
+
+  nepochs_decay: 100
+
+  gan_loss: lsgan
+  lambda_L1: 100
+  lambda_classifier: ~
+
+  trainer: base
+
+  savefiles: [
+    '*.py',
+    'data/*.*',
+    'datasets/*.*',
+    'models/*.*',
+    'configs/*.*',
+    'utils/*.*',
+    'trainer/*.*',
+  ]

config/models/google-font.yaml
ADDED
@@ -0,0 +1,43 @@
+models:
+  G:
+    encoder:
+      content:
+        type: ContentVanillaEncoder
+        depth: 2
+      style:
+        type: StyleVanillaEncoder
+        depth: 2
+    decoder:
+      type: VanillaDecoder
+      residual_blocks: 6
+      depth: 2
+
+    optim:
+      class: torch.optim.Adam
+      betas: [ 0.5, 0.999 ]
+      lr: 0.0002
+      lr_policy: step
+      lr_decay_iters: 1000
+
+    init_type: normal
+    init_gain: 0.02
+
+  D_content:
+    in_channels: 2  # char + 1
+    class: models.discriminator.PatchGANDiscriminator
+    optim:
+      class: torch.optim.Adam
+      betas: [ 0.5, 0.999 ]
+      lr: 2e-4
+      lr_policy: step
+      lr_decay_iters: 1000
+
+  D_style:
+    in_channels: 6  # style + 1
+    class: models.discriminator.PatchGANDiscriminator
+    optim:
+      class: torch.optim.Adam
+      betas: [ 0.5, 0.999 ]
+      lr: 2e-4
+      lr_policy: step
+      lr_decay_iters: 1000

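The `class:` entries are dotted import strings. The training code that consumes them is not part of this excerpt, so the helper below is illustrative, but resolving such strings conventionally uses the standard importlib pattern (lightning.py does import importlib):

import importlib

def resolve_class(dotted_path: str):
    # Illustrative helper: "torch.optim.Adam" -> the Adam class object.
    module_name, _, attr = dotted_path.rpartition('.')
    return getattr(importlib.import_module(module_name), attr)

optim_cls = resolve_class("torch.optim.Adam")
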
config/setting-google-font.yaml
ADDED
@@ -0,0 +1,5 @@
+config:
+  dataset: 'config/datasets/googlefont.yaml'
+  model: 'config/models/google-font.yaml'
+  logging: 'config/logging.yaml'
+  lightning: 'config/lightning.yaml'

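A hedged sketch of how this top-level setting file plausibly stitches the four sub-configs together (trainer.py, which would do this for real, is not shown in this excerpt):

from omegaconf import OmegaConf

setting = OmegaConf.load("config/setting-google-font.yaml")
hp = OmegaConf.merge(*(OmegaConf.load(path) for path in setting.config.values()))
print(list(hp.keys()))  # ['datasets', 'models', 'logging', 'pl_config']
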
data/.gitkeep
ADDED
File without changes
datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .googlefont import GoogleFontDataset
+from .ftgan import FTGANDataset  # note: datasets/ftgan.py is not among the files added by this commit
datasets/googlefont.py
ADDED
@@ -0,0 +1,286 @@
+import os
+import pickle
+import random
+import string
+import json
+import logging
+from pathlib import Path
+
+from omegaconf import OmegaConf
+import numpy as np
+import PIL.Image as Image
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+REPEATE_NUM = 10000
+
+WHITE = 255
+
+MAX_TRIAL = 10
+
+_upper_case = set(map(lambda s: f"{ord(s):04X}", string.ascii_uppercase))
+_digits = set(map(lambda s: f"{ord(s):04X}", string.digits))
+english_set = list(_upper_case.union(_digits))
+
+NOTO_FONT_DIRNAME = "Noto"
+
+
+class GoogleFontDataset(Dataset):
+    def __init__(self, args, mode='train',
+                 metadata_path="./lang_set.json"):
+        super(GoogleFontDataset, self).__init__()
+        self.args = args
+        self.font_dir = Path(args.font_dir)
+        self.mode = mode
+        self.lang_list = sorted([x.stem for x in self.font_dir.iterdir() if x.is_dir()])
+        self.min_tight_bound = 10000
+        self.min_font_name = None
+
+        if self.mode == 'train':
+            self.lang_list = self.lang_list[:-2]
+        else:
+            self.lang_list = self.lang_list[-2:]
+        with open(metadata_path, "r") as json_f:
+            self.data = json.load(json_f)
+
+        self.num_lang = None
+        self.num_font = None
+        self.num_char = None
+        self.content_meta, self.style_meta, self.num_lang, self.num_font, self.num_char = self.get_meta()
+        logging.info(f"min_tight_bound: {self.min_tight_bound}")  # 20
+
+    @staticmethod
+    def center_align(bg_img, item_img, fit=False):
+        bg_img = bg_img.copy()
+        item_img = item_img.copy()
+        item_w, item_h = item_img.size
+        W, H = bg_img.size
+        if fit:
+            item_ratio = item_w / item_h
+            bg_ratio = W / H
+
+            if bg_ratio > item_ratio:
+                # height fitting
+                resize_ratio = H / item_h
+            else:
+                # width fitting
+                resize_ratio = W / item_w
+            item_img = item_img.resize((int(item_w * resize_ratio), int(item_h * resize_ratio)))
+            item_w, item_h = item_img.size
+
+        bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
+        return bg_img
+
+    def _get_content_image(self, png_path):
+        im = Image.open(png_path)
+        bg_img = Image.new('RGB', (self.args.imsize, self.args.imsize), color='white')
+        blend_img = self.center_align(bg_img, im, fit=True)
+        return blend_img
+
+    def _get_style_image(self, png_path):
+        im = Image.open(png_path)
+        w, h = im.size
+
+        # tight_bound_check & update
+        tight_bound = self.get_tight_bound_size(np.array(im))
+        if self.min_tight_bound > tight_bound:
+            self.min_tight_bound = tight_bound
+            self.min_font_name = png_path
+            logging.debug(f"min_tight_bound: {self.min_tight_bound}, min_font_name: {self.min_font_name}")
+
+        bg_img = Image.new('RGB', (max([w, h, self.args.imsize]), max([w, h, self.args.imsize])), color='white')
+        blend_img = self.center_align(bg_img, im)
+        return blend_img
+
+    def get_meta(self):
+        content_meta = dict()
+        style_meta = dict()
+
+        num_lang = 0
+        num_font = 0
+        num_char = 0
+        for lang_dir in tqdm(self.lang_list, total=len(self.lang_list)):
+            font_list = sorted([x for x in (self.font_dir / lang_dir).iterdir() if x.is_dir()])
+
+            font_content_dict = dict()
+            font_style_dict = dict()
+
+            for font_dir in font_list:
+                image_content_dict = dict()
+                image_style_dict = dict()
+
+                png_list = [x for x in font_dir.glob("*.png")]
+
+                for png_path in png_list:
+                    # image_content_dict[png_path.stem] = self._get_content_image(png_path)
+                    # image_style_dict[png_path.stem] = self._get_style_image(png_path)
+                    image_content_dict[png_path.stem] = png_path
+                    image_style_dict[png_path.stem] = png_path
+                    num_char += 1
+
+                font_content_dict[font_dir.stem] = image_content_dict
+                font_style_dict[font_dir.stem] = image_style_dict
+                num_font += 1
+
+            content_meta[lang_dir] = font_content_dict
+            style_meta[lang_dir] = font_style_dict
+            num_lang += 1
+
+        return content_meta, style_meta, num_lang, num_font, num_char
+
+    @staticmethod
+    def get_tight_bound_size(img):
+        contents_cell = np.where(img < WHITE)
+
+        if len(contents_cell[0]) == 0:
+            return 0
+
+        size = {
+            'xmin': np.min(contents_cell[1]),
+            'ymin': np.min(contents_cell[0]),
+            'xmax': np.max(contents_cell[1]) + 1,
+            'ymax': np.max(contents_cell[0]) + 1,
+        }
+        return max(size['xmax'] - size['xmin'], size['ymax'] - size['ymin'])
+
+    def get_patch_from_style_image(self, image, patch_per_image=1):
+        w, h = image.size
+        image_list = []
+        relative_patch_size = int(self.args.imsize * 2)
+        for _ in range(patch_per_image):
+            offset = w - relative_patch_size
+            if offset < relative_patch_size // 2:
+                # if image is too small, just resize
+                crop_candidate = np.array(image.resize((self.args.imsize, self.args.imsize)))
+            else:
+                # if image is sufficient to be cropped, randomly crop
+                x = np.random.randint(0, offset)
+                y = np.random.randint(0, offset)
+                crop_candidate = image.crop((x, y, x + relative_patch_size, y + relative_patch_size))
+
+                _trial = 0
+                while self.get_tight_bound_size(np.array(crop_candidate)) < relative_patch_size // 16 and _trial < MAX_TRIAL:
+                    x = np.random.randint(0, offset)
+                    y = np.random.randint(0, offset)
+                    crop_candidate = image.crop((x, y, x + relative_patch_size, y + relative_patch_size))
+                    _trial += 1
+
+                crop_candidate = np.array(crop_candidate.resize((self.args.imsize, self.args.imsize)))
+            image_list.append(crop_candidate)
+        return image_list
+
+    def get_pairs(self, content_english=False, style_english=False):
+        lang_content = random.choice(self.lang_list)
+
+        content_unicode_list = english_set if content_english else self.data[lang_content]
+        style_unicode_list = english_set if style_english else self.data[lang_content]
+
+        if content_english == style_english:
+            # content_unicode_list == style_unicode_list
+            chars = random.sample(content_unicode_list,
+                                  k=self.args.reference_imgs.style + 1)
+            content_char = chars[-1]
+            style_chars = chars[:self.args.reference_imgs.style]
+        else:
+            content_char = random.choice(content_unicode_list)
+            style_chars = random.sample(style_unicode_list, k=self.args.reference_imgs.style)
+
+        # fonts = random.sample(self.content_meta[lang_content].keys(),
+        #                       k=self.args.reference_imgs.char + 1)
+        # content_fonts = fonts[:self.args.reference_imgs.char]
+        # style_font = fonts[-1]
+
+        style_font_list = list(self.content_meta[lang_content].keys())
+        style_font_list.remove(NOTO_FONT_DIRNAME)
+        style_font = random.choice(style_font_list)
+        content_fonts = [NOTO_FONT_DIRNAME]
+
+        content_fonts_image = [self.content_meta[lang_content][x][content_char] for x in content_fonts]
+        style_chars_image = [self.content_meta[lang_content][style_font][x] for x in style_chars]
+
+        # style_chars_cropped = []
+        # for style_char_image in style_chars_image:
+        #     style_chars_cropped.extend(self.get_patch_from_style_image(style_char_image,
+        #         patch_per_image=self.args.reference_imgs.style // self.args.reference_imgs.char))
+
+        target_image = self.content_meta[lang_content][style_font][content_char]
+
+        content_fonts_image = [self._get_content_image(image_path) for image_path in content_fonts_image]
+        style_chars_image = [self._get_content_image(image_path) for image_path in style_chars_image]
+        target_image = self._get_content_image(target_image)
+
+        return content_char, content_fonts, content_fonts_image, style_font, style_chars, style_chars_image, target_image
+
+    def __getitem__(self, idx):
+        """__getitem__ of GoogleFontDataset.
+
+        Args:
+            idx (int): torch dataset index
+
+        Returns:
+            dict: dict with the following keys
+
+                gt_images: target_image,
+                content_images: same_chars_image,
+                style_images: same_fonts_image,
+                style_idx: font_idx,
+                char_idx: char_idx,
+                content_image_idxs: same_chars,
+                style_image_idxs: same_fonts,
+                image_paths: ''
+        """
+        use_eng_content, use_eng_style = random.choice([(True, False), (False, True), (False, False)])
+
+        if self.mode != 'train':
+            use_eng_content = False
+            use_eng_style = True
+
+        content_char, content_fonts, content_fonts_image, style_font, style_chars, style_chars_image, target_image = \
+            self.get_pairs(content_english=use_eng_content, style_english=use_eng_style)
+
+        content_fonts_image = np.array([np.mean(np.array(x), axis=-1) / WHITE
+                                        for x in content_fonts_image], dtype=np.float32)
+        style_chars_image = np.array([np.mean(np.array(x), axis=-1) / WHITE
+                                      for x in style_chars_image], dtype=np.float32)
+        target_image = np.mean(np.array(target_image, dtype=np.float32), axis=-1)[np.newaxis, ...] / WHITE
+
+        dict_return = {
+            # data for training
+            'gt_images': target_image,
+            'content_images': content_fonts_image,
+            'style_images': style_chars_image,  # TODO: crop style image with fixed size
+            # data for logging
+            'style_idx': style_font,
+            'char_idx': content_char,
+            'content_image_idxs': content_fonts,
+            'style_image_idxs': style_chars,
+            'image_paths': '',
+        }
+        return dict_return
+
+    def __len__(self):
+        return len(self.lang_list) * REPEATE_NUM
+
+
+if __name__ == '__main__':
+    hp = OmegaConf.load('config/datasets/googlefont.yaml').datasets.train
+    metadata_path = "./lang_set.json"
+    FONT_DIR = "/data2/hksong/DATA/fonts-image"
+    hp.font_dir = FONT_DIR  # the constructor reads font_dir from the config, not from a kwarg
+
+    _dataset = GoogleFontDataset(hp, metadata_path=metadata_path)
+    TEST_ITER_NUM = 4
+    for i in range(TEST_ITER_NUM):
+        data = _dataset[i]
+        print(data.keys())
+        print(data['gt_images'].shape,  # values are numpy arrays, so use .shape
+              data['content_images'][0].shape,
+              data['style_images'][0].shape,
+              data['style_idx'],
+              data['char_idx'],
+              data['content_image_idxs'],
+              data['style_image_idxs'])

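A quick shape sanity check for the batched output is sketched below. It assumes the directory layout the dataset expects (`<font_dir>/<lang>/<font>/<unicode hex>.png` with a `Noto` content-font directory per language) and that the `datasets` package imports cleanly:

from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from datasets.googlefont import GoogleFontDataset

hp = OmegaConf.load('config/datasets/googlefont.yaml').datasets.train
dataset = GoogleFontDataset(hp)
loader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=hp.shuffle)

batch = next(iter(loader))
print(batch['gt_images'].shape)       # (64, 1, 64, 64)  target glyphs
print(batch['content_images'].shape)  # (64, 1, 64, 64)  Noto content glyphs
print(batch['style_images'].shape)    # (64, 5, 64, 64)  five style references
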
docs/ml-font-style-transfer.md
ADDED
@@ -0,0 +1,9 @@
+<center><h1>Multilingual Font Style Transfer</h1></center>
+
+- Composition-free font style transfer across 13 different languages
+- Trained with [Google Fonts](https://github.com/google/fonts) (OFL fonts and Noto Sans)
+
+This is a personal proof-of-concept demo, so the quality of the output is not guaranteed.
+I hope that someday there will be an established model for a better multilingual society.
+
+I only used personal RTX 30 series GPU(s) for training the model. The model is heavily inspired by a model from a previous study, [FTransGAN](https://github.com/ligoudaner377/font_translator_gan) (Li et al.).

font_list.json
ADDED
@@ -0,0 +1,427 @@
+{
+    "arabic": [
+        "ofl/baloobhaijaan2/BalooBhaijaan2[wght].ttf",
+        "ofl/bonanova/BonaNova-Italic.ttf",
+        "ofl/bonanova/BonaNova-Regular.ttf",
+        "ofl/cairo/Cairo[wght].ttf",
+        "ofl/changa/Changa[wght].ttf",
+        "ofl/elmessiri/ElMessiri[wght].ttf",
+        "ofl/handjet/Handjet[EGRD,ESHP,wght].ttf",
+        "ofl/harmattan/Harmattan-Regular.ttf",
+        "ofl/ibmplexsansarabic/IBMPlexSansArabic-Regular.ttf",
+        "ofl/katibeh/Katibeh-Regular.ttf",
+        "ofl/kufam/Kufam-Italic[wght].ttf",
+        "ofl/kufam/Kufam[wght].ttf",
+        "ofl/lalezar/Lalezar-Regular.ttf",
+        "ofl/lemonada/Lemonada[wght].ttf",
+        "ofl/lemonadavfbeta/LemonadaVFBeta.ttf",
+        "ofl/markazitext/MarkaziText[wght].ttf",
+        "ofl/mirza/Mirza-Regular.ttf",
+        "ofl/qahiri/Qahiri-Regular.ttf",
+        "ofl/rakkas/Rakkas-Regular.ttf",
+        "ofl/readexpro/ReadexPro[wght].ttf",
+        "ofl/reemkufi/ReemKufi[wght].ttf",
+        "ofl/scheherazadenew/ScheherazadeNew-Regular.ttf",
+        "ofl/scheherazade/Scheherazade-Regular.ttf",
+        "ofl/tajawal/Tajawal-Regular.ttf",
+        "ofl/vibes/Vibes-Regular.ttf"
+    ],
+    "bengali": [
+        "ofl/atma/Atma-Regular.ttf",
+        "ofl/balooda2/BalooDa2-Regular.ttf",
+        "ofl/galada/Galada-Regular.ttf",
+        "ofl/hindsiliguri/HindSiliguri-Regular.ttf",
+        "ofl/mina/Mina-Regular.ttf"
+    ],
+    "gujarati": [
+        "ofl/baloobhai2/BalooBhai2[wght].ttf",
+        "ofl/farsan/Farsan-Regular.ttf",
+        "ofl/hindvadodara/HindVadodara-Regular.ttf",
+        "ofl/mogra/Mogra-Regular.ttf",
+        "ofl/muktavaani/MuktaVaani-Regular.ttf",
+        "ofl/rasa/Rasa[wght].ttf",
+        "ofl/shrikhand/Shrikhand-Regular.ttf"
+    ],
+    "hebrew": [
+        "ofl/adobeblank/AdobeBlank-Regular.ttf",
+        "ofl/alef/Alef-Regular.ttf",
+        "ofl/amaticsc/AmaticSC-Regular.ttf",
+        "ofl/bellefair/Bellefair-Regular.ttf",
+        "ofl/bonanova/BonaNova-Italic.ttf",
+        "ofl/bonanova/BonaNova-Regular.ttf",
+        "ofl/cardo/Cardo-Italic.ttf",
+        "ofl/cardo/Cardo-Regular.ttf",
+        "ofl/davidlibre/DavidLibre-Regular.ttf",
+        "ofl/frankruhllibre/FrankRuhlLibre-Regular.ttf",
+        "ofl/handjet/Handjet[EGRD,ESHP,wght].ttf",
+        "ofl/heebo/Heebo[wght].ttf",
+        "ofl/ibmplexsanshebrew/IBMPlexSansHebrew-Regular.ttf",
+        "ofl/karantina/Karantina-Regular.ttf",
+        "ofl/miriamlibre/MiriamLibre-Regular.ttf",
+        "ofl/mplus1p/Mplus1p-Regular.ttf",
+        "ofl/roundedmplus1c/RoundedMplus1c-Regular.ttf",
+        "ofl/rubikbeastly/RubikBeastly-Regular.ttf",
+        "ofl/rubikmonoone/RubikMonoOne-Regular.ttf",
+        "ofl/rubikone/RubikOne-Regular.ttf",
+        "ofl/rubik/Rubik-Italic[wght].ttf",
+        "ofl/secularone/SecularOne-Regular.ttf",
+        "ofl/suezone/SuezOne-Regular.ttf"
+    ],
+    "japanese": [
+        "ofl/delagothicone/DelaGothicOne-Regular.ttf",
+        "ofl/dotgothic16/DotGothic16-Regular.ttf",
+        "ofl/hachimarupop/HachiMaruPop-Regular.ttf",
+        "ofl/jejugothic/JejuGothic-Regular.ttf",
+        "ofl/jejuhallasan/JejuHallasan-Regular.ttf",
+        "ofl/jejumyeongjo/JejuMyeongjo-Regular.ttf",
+        "ofl/kaiseidecol/KaiseiDecol-Regular.ttf",
+        "ofl/kaiseiharunoumi/KaiseiHarunoUmi-Regular.ttf",
+        "ofl/kaiseiopti/KaiseiOpti-Regular.ttf",
+        "ofl/kaiseitokumin/KaiseiTokumin-Regular.ttf"
+    ],
+    "khmer": [
+        "ofl/angkor/Angkor-Regular.ttf",
+        "ofl/battambang/Battambang-Regular.ttf",
+        "ofl/bayon/Bayon-Regular.ttf",
+        "ofl/bokor/Bokor-Regular.ttf",
+        "ofl/dangrek/Dangrek-Regular.ttf",
+        "ofl/fasthand/Fasthand-Regular.ttf",
+        "ofl/freehand/Freehand-Regular.ttf",
+        "ofl/hanuman/Hanuman-Regular.ttf",
+        "ofl/kohsantepheap/KohSantepheap-Regular.ttf",
+        "ofl/koulen/Koulen-Regular.ttf",
+        "ofl/metal/Metal-Regular.ttf",
+        "ofl/moul/Moul-Regular.ttf",
+        "ofl/moulpali/Moulpali-Regular.ttf",
+        "ofl/nokora/Nokora-Regular.ttf",
+        "ofl/odormeanchey/OdorMeanChey-Regular.ttf",
+        "ofl/preahvihear/Preahvihear-Regular.ttf",
+        "ofl/suwannaphum/Suwannaphum-Regular.ttf",
+        "ofl/taprom/Taprom-Regular.ttf"
+    ],
+    "korean": [
+        "ofl/blackandwhitepicture/BlackAndWhitePicture-Regular.ttf",
+        "ofl/dongle/Dongle-Regular.ttf",
+        "ofl/gamjaflower/GamjaFlower-Regular.ttf",
+        "ofl/gothica1/GothicA1-Regular.ttf",
+        "ofl/gowunbatang/GowunBatang-Regular.ttf",
+        "ofl/gowundodum/GowunDodum-Regular.ttf",
+        "ofl/himelody/HiMelody-Regular.ttf",
+        "ofl/poorstory/PoorStory-Regular.ttf"
+    ],
+    "malayalam": [
+        "ofl/baloochettan2/BalooChettan2-Regular.ttf",
+        "ofl/chilanka/Chilanka-Regular.ttf",
+        "ofl/gayathri/Gayathri-Regular.ttf",
+        "ofl/hindkochi/HindKochi-Regular.ttf",
+        "ofl/manjari/Manjari-Regular.ttf"
+    ],
+    "cyrillic": [
+        "ofl/adobeblank/AdobeBlank-Regular.ttf",
+        "ofl/alegreya/Alegreya-Italic[wght].ttf",
+        "ofl/alegreya/Alegreya[wght].ttf",
+        "ofl/alegreyasans/AlegreyaSans-Italic.ttf",
+        "ofl/alegreyasans/AlegreyaSans-Regular.ttf",
+        "ofl/alegreyasanssc/AlegreyaSansSC-Italic.ttf",
+        "ofl/alegreyasanssc/AlegreyaSansSC-Regular.ttf",
+        "ofl/alegreyasc/AlegreyaSC-Italic.ttf",
+        "ofl/alegreyasc/AlegreyaSC-Regular.ttf",
+        "ofl/alice/Alice-Regular.ttf",
+        "ofl/alumnisans/AlumniSans-Italic[wght].ttf",
+        "ofl/amaticsc/AmaticSC-Regular.ttf",
+        "ofl/andika/Andika-Regular.ttf",
+        "ofl/anonymouspro/AnonymousPro-Italic.ttf",
+        "ofl/anonymouspro/AnonymousPro-Regular.ttf",
+        "ofl/arsenal/Arsenal-Italic.ttf",
+        "ofl/arsenal/Arsenal-Regular.ttf",
+        "ofl/badscript/BadScript-Regular.ttf",
+        "ofl/balsamiqsans/BalsamiqSans-Italic.ttf",
+        "ofl/balsamiqsans/BalsamiqSans-Regular.ttf",
+        "ofl/bellota/Bellota-Italic.ttf",
+        "ofl/bellota/Bellota-Regular.ttf",
+        "ofl/bellotatext/BellotaText-Italic.ttf",
+        "ofl/bellotatext/BellotaText-Regular.ttf",
+        "ofl/bitter/Bitter-Italic[wght].ttf",
+        "ofl/bonanova/BonaNova-Italic.ttf",
+        "ofl/bonanova/BonaNova-Regular.ttf",
+        "ofl/brygada1918/Brygada1918-Italic[wght].ttf",
+        "ofl/brygada1918/Brygada1918[wght].ttf",
+        "ofl/caveat/Caveat[wght].ttf",
+        "ofl/comfortaa/Comfortaa[wght].ttf",
+        "ofl/comforterbrush/ComforterBrush-Regular.ttf",
+        "ofl/comforter/Comforter-Regular.ttf",
+        "ofl/cormorant/Cormorant-Italic.ttf",
+        "ofl/cormorant/Cormorant-Regular.ttf",
+        "ofl/cormorantgaramond/CormorantGaramond-Italic.ttf",
+        "ofl/cormorantgaramond/CormorantGaramond-Regular.ttf",
+        "ofl/cormorantinfant/CormorantInfant-Italic.ttf",
+        "ofl/cormorantinfant/CormorantInfant-Regular.ttf",
+        "ofl/cormorantsc/CormorantSC-Regular.ttf",
+        "ofl/cormorantunicase/CormorantUnicase-Regular.ttf",
+        "ofl/crimsontext/CrimsonText-Regular.ttf",
+        "ofl/cuprum/Cuprum-Italic[wght].ttf",
+        "ofl/cuprum/Cuprum[wght].ttf",
+        "ofl/daysone/DaysOne-Regular.ttf",
+        "ofl/delagothicone/DelaGothicOne-Regular.ttf",
+        "ofl/didactgothic/DidactGothic-Regular.ttf",
+        "ofl/dotgothic16/DotGothic16-Regular.ttf",
+        "ofl/ebgaramond/EBGaramond-Italic[wght].ttf",
+        "ofl/ebgaramond/EBGaramond[wght].ttf",
+        "ofl/elmessiri/ElMessiri[wght].ttf",
+        "ofl/exo2/Exo2-Italic[wght].ttf",
+        "ofl/exo2/Exo2[wght].ttf",
+        "ofl/firasanscondensed/FiraSansCondensed-Italic.ttf",
+        "ofl/firasanscondensed/FiraSansCondensed-Regular.ttf",
+        "ofl/firasansextracondensed/FiraSansExtraCondensed-Italic.ttf",
+        "ofl/firasansextracondensed/FiraSansExtraCondensed-Regular.ttf",
+        "ofl/firasans/FiraSans-Italic.ttf",
+        "ofl/firasans/FiraSans-Regular.ttf",
+        "ofl/flowblock/FlowBlock-Regular.ttf",
+        "ofl/flowcircular/FlowCircular-Regular.ttf",
+        "ofl/flowrounded/FlowRounded-Regular.ttf",
+        "ofl/forum/Forum-Regular.ttf",
+        "ofl/gabriela/Gabriela-Regular.ttf",
+        "ofl/gothica1/GothicA1-Regular.ttf",
+        "ofl/hachimarupop/HachiMaruPop-Regular.ttf",
+        "ofl/handjet/Handjet[EGRD,ESHP,wght].ttf",
+        "ofl/hinamincho/HinaMincho-Regular.ttf",
+        "ofl/ibmplexmono/IBMPlexMono-Italic.ttf",
+        "ofl/ibmplexmono/IBMPlexMono-Regular.ttf",
+        "ofl/ibmplexsans/IBMPlexSans-Italic.ttf",
+        "ofl/ibmplexsans/IBMPlexSans-Regular.ttf",
+        "ofl/ibmplexserif/IBMPlexSerif-Italic.ttf",
+        "ofl/ibmplexserif/IBMPlexSerif-Regular.ttf",
+        "ofl/inter/Inter[slnt,wght].ttf",
+        "ofl/istokweb/IstokWeb-Italic.ttf",
+        "ofl/istokweb/IstokWeb-Regular.ttf",
+        "ofl/jejugothic/JejuGothic-Regular.ttf",
+        "ofl/jejuhallasan/JejuHallasan-Regular.ttf",
+        "ofl/jejumyeongjo/JejuMyeongjo-Regular.ttf",
+        "ofl/jetbrainsmono/JetBrainsMono-Italic[wght].ttf",
+        "ofl/jetbrainsmono/JetBrainsMono[wght].ttf",
+        "ofl/jost/Jost-Italic[wght].ttf",
+        "ofl/jost/Jost[wght].ttf",
+        "ofl/kaiseidecol/KaiseiDecol-Regular.ttf",
+        "ofl/kaiseiharunoumi/KaiseiHarunoUmi-Regular.ttf",
+        "ofl/kaiseiopti/KaiseiOpti-Regular.ttf",
+        "ofl/kaiseitokumin/KaiseiTokumin-Regular.ttf",
+        "ofl/kellyslab/KellySlab-Regular.ttf",
+        "ofl/kiwimaru/KiwiMaru-Regular.ttf",
+        "ofl/kleeone/KleeOne-Regular.ttf",
+        "ofl/kopubbatang/KoPubBatang-Regular.ttf",
+        "ofl/kurale/Kurale-Regular.ttf",
+        "ofl/lato/Lato-Italic.ttf",
+        "ofl/lato/Lato-Regular.ttf",
+        "ofl/ledger/Ledger-Regular.ttf",
+        "ofl/literata/Literata-Italic[opsz,wght].ttf",
+        "ofl/literata/Literata[opsz,wght].ttf",
+        "ofl/lobster/Lobster-Regular.ttf",
+        "ofl/lora/Lora-Italic[wght].ttf",
+        "ofl/lora/Lora[wght].ttf",
+        "ofl/marckscript/MarckScript-Regular.ttf",
+        "ofl/marmelad/Marmelad-Regular.ttf",
+        "ofl/merriweather/Merriweather-Italic.ttf",
+        "ofl/merriweather/Merriweather-Regular.ttf",
+        "ofl/montserratalternates/MontserratAlternates-Italic.ttf",
+        "ofl/montserratalternates/MontserratAlternates-Regular.ttf",
+        "ofl/montserrat/Montserrat-Italic.ttf",
+        "ofl/montserrat/Montserrat-Regular.ttf",
+        "ofl/mplus1p/Mplus1p-Regular.ttf",
+        "ofl/mulish/Mulish-Italic[wght].ttf",
+        "ofl/nanumgothiccoding/NanumGothicCoding-Regular.ttf",
+        "ofl/neucha/Neucha.ttf",
+        "ofl/newscycle/NewsCycle-Regular.ttf",
+        "ofl/nobile/Nobile-Italic.ttf",
+        "ofl/nobile/Nobile-Regular.ttf",
+        "ofl/nunito/Nunito-Italic[wght].ttf",
+        "ofl/nunitosans/NunitoSans-Italic.ttf",
+        "ofl/nunitosans/NunitoSans-Regular.ttf",
+        "ofl/oi/Oi-Regular.ttf",
+        "ofl/oranienbaum/Oranienbaum-Regular.ttf",
+        "ofl/orelegaone/OrelegaOne-Regular.ttf",
+        "ofl/oswald/Oswald[wght].ttf",
+        "ofl/overpass/Overpass-Italic[wght].ttf",
+        "ofl/overpass/Overpass[wght].ttf",
+        "ofl/pacifico/Pacifico-Regular.ttf",
+        "ofl/pangolin/Pangolin-Regular.ttf",
+        "ofl/pattaya/Pattaya-Regular.ttf",
+        "ofl/philosopher/Philosopher-Italic.ttf",
+        "ofl/philosopher/Philosopher-Regular.ttf",
+        "ofl/piazzolla/Piazzolla-Italic[opsz,wght].ttf",
+        "ofl/playfairdisplay/PlayfairDisplay-Italic[wght].ttf",
+        "ofl/playfairdisplay/PlayfairDisplay[wght].ttf",
+        "ofl/playfairdisplaysc/PlayfairDisplaySC-Italic.ttf",
+        "ofl/playfairdisplaysc/PlayfairDisplaySC-Regular.ttf",
+        "ofl/play/Play-Regular.ttf",
+        "ofl/podkova/Podkova[wght].ttf",
+        "ofl/podkovavfbeta/PodkovaVFBeta.ttf",
+        "ofl/poiretone/PoiretOne-Regular.ttf",
+        "ofl/prata/Prata-Regular.ttf",
+        "ofl/pressstart2p/PressStart2P-Regular.ttf",
+        "ofl/prostoone/ProstoOne-Regular.ttf",
+        "ofl/pushster/Pushster-Regular.ttf",
+        "ofl/raleway/Raleway-Italic[wght].ttf",
+        "ofl/rampartone/RampartOne-Regular.ttf",
+        "ofl/reggaeone/ReggaeOne-Regular.ttf",
+        "ofl/robotoflex/RobotoFlex[GRAD,XOPQ,XTRA,YOPQ,YTAS,YTDE,YTFI,YTLC,YTUC,opsz,slnt,wdth,wght].ttf",
+        "ofl/rocknrollone/RocknRollOne-Regular.ttf",
+        "ofl/roundedmplus1c/RoundedMplus1c-Regular.ttf",
+        "ofl/rubikbeastly/RubikBeastly-Regular.ttf",
+        "ofl/rubikmonoone/RubikMonoOne-Regular.ttf",
+        "ofl/rubikone/RubikOne-Regular.ttf",
+        "ofl/rubik/Rubik-Italic[wght].ttf",
+        "ofl/ruda/Ruda[wght].ttf",
+        "ofl/ruslandisplay/RuslanDisplay.ttf",
+        "ofl/russoone/RussoOne-Regular.ttf",
+        "ofl/sawarabigothic/SawarabiGothic-Regular.ttf",
+        "ofl/scada/Scada-Italic.ttf",
+        "ofl/scada/Scada-Regular.ttf",
+        "ofl/seoulhangangcondensed/SeoulHangangCondensed-BoldL.ttf",
+        "ofl/seoulhangangcondensed/SeoulHangangCondensed-Bold.ttf",
+        "ofl/seoulhangangcondensed/SeoulHangangCondensed-ExtraBold.ttf",
+        "ofl/seoulhangangcondensed/SeoulHangangCondensed-Medium.ttf",
+        "ofl/seoulhangang/SeoulHangang-Bold.ttf",
+        "ofl/seoulhangang/SeoulHangang-ExtraBold.ttf",
+        "ofl/seoulhangang/SeoulHangang-Medium.ttf",
+        "ofl/seoulnamsancondensed/SeoulNamsanCondensed-Black.ttf",
+        "ofl/seoulnamsancondensed/SeoulNamsanCondensed-Bold.ttf",
+        "ofl/seoulnamsancondensed/SeoulNamsanCondensed-ExtraBold.ttf",
+        "ofl/seoulnamsancondensed/SeoulNamsanCondensed-Medium.ttf",
+        "ofl/seoulnamsan/SeoulNamsan-Bold.ttf",
+        "ofl/seoulnamsan/SeoulNamsan-ExtraBold.ttf",
+        "ofl/seoulnamsan/SeoulNamsan-Medium.ttf",
+        "ofl/seoulnamsanvertical/SeoulNamsanVertical-Regular.ttf",
+        "ofl/seymourone/SeymourOne-Regular.ttf",
+        "ofl/sofiasans/SofiaSans-Italic[wdth,wght].ttf",
+        "ofl/sofiasans/SofiaSans[wdth,wght].ttf",
+        "ofl/sourcesans3/SourceSans3-Italic[wght].ttf",
+        "ofl/sourcesans3/SourceSans3[wght].ttf",
+        "ofl/sourcesanspro/SourceSansPro-Regular.ttf",
+        "ofl/sourceserifpro/SourceSerifPro-Italic.ttf",
+        "ofl/sourceserifpro/SourceSerifPro-Regular.ttf",
+        "ofl/spectralsc/SpectralSC-Italic.ttf",
+        "ofl/spectralsc/SpectralSC-Regular.ttf",
+        "ofl/spectral/Spectral-Italic.ttf",
+        "ofl/spectral/Spectral-Regular.ttf",
+        "ofl/stalinistone/StalinistOne-Regular.ttf",
+        "ofl/stick/Stick-Regular.ttf",
+        "ofl/stixtwomath/STIXTwoMath-Regular.ttf",
+        "ofl/stixtwotext/STIXTwoText-Italic[wght].ttf",
+        "ofl/stixtwotext/STIXTwoText[wght].ttf",
+        "ofl/strong/Strong-Regular.ttf",
+        "ofl/tenorsans/TenorSans-Regular.ttf",
+        "ofl/trainone/TrainOne-Regular.ttf",
+        "ofl/tuffy/Tuffy-Italic.ttf",
+        "ofl/tuffy/Tuffy-Regular.ttf",
+        "ofl/underdog/Underdog-Regular.ttf",
+        "ofl/viaodalibre/ViaodaLibre-Regular.ttf",
+        "ofl/vollkornsc/VollkornSC-Regular.ttf",
+        "ofl/vollkorn/Vollkorn-Italic[wght].ttf",
+        "ofl/vollkorn/Vollkorn[wght].ttf",
+        "ofl/yesevaone/YesevaOne-Regular.ttf",
+        "ofl/yomogi/Yomogi-Regular.ttf",
+        "ofl/yujiboku/YujiBoku-Regular.ttf",
+        "ofl/yujimai/YujiMai-Regular.ttf",
+        "ofl/yujisyuku/YujiSyuku-Regular.ttf",
+        "ofl/zenantiquesoft/ZenAntiqueSoft-Regular.ttf",
+        "ofl/zenantique/ZenAntique-Regular.ttf",
+        "ofl/zenkakugothicantique/ZenKakuGothicAntique-Regular.ttf",
+        "ofl/zenkakugothicnew/ZenKakuGothicNew-Regular.ttf",
+        "ofl/zenkurenaido/ZenKurenaido-Regular.ttf",
+        "ofl/zenmarugothic/ZenMaruGothic-Regular.ttf",
+        "ofl/zenoldmincho/ZenOldMincho-Regular.ttf"
+    ],
+    "tamil": [
+        "ofl/arimamadurai/ArimaMadurai-Regular.ttf",
+        "ofl/baloothambi2/BalooThambi2-Regular.ttf",
+        "ofl/coiny/Coiny-Regular.ttf",
+        "ofl/hindmadurai/HindMadurai-Regular.ttf",
+        "ofl/kavivanar/Kavivanar-Regular.ttf",
+        "ofl/meerainimai/MeeraInimai-Regular.ttf",
+        "ofl/muktamalar/MuktaMalar-Regular.ttf",
+        "ofl/oi/Oi-Regular.ttf",
+        "ofl/pavanam/Pavanam-Regular.ttf",
+        "ofl/postnobillsjaffna/PostNoBillsJaffna-Regular.ttf"
+    ],
+    "telugu": [
+        "ofl/akayatelivigala/AkayaTelivigala-Regular.ttf",
+        "ofl/balootammudu2/BalooTammudu2[wght].ttf",
+        "ofl/chathura/Chathura-Regular.ttf",
+        "ofl/dhurjati/Dhurjati-Regular.ttf",
+        "ofl/gidugu/Gidugu-Regular.ttf",
+        "ofl/gurajada/Gurajada-Regular.ttf",
+        "ofl/hindguntur/HindGuntur-Regular.ttf",
+        "ofl/lakkireddy/LakkiReddy-Regular.ttf",
+        "ofl/mallanna/Mallanna-Regular.ttf",
+        "ofl/mandali/Mandali-Regular.ttf",
+        "ofl/nats/NATS-Regular.ttf",
+        "ofl/ntr/NTR-Regular.ttf",
+        "ofl/peddana/Peddana-Regular.ttf",
+        "ofl/ramabhadra/Ramabhadra-Regular.ttf",
+        "ofl/ramaraja/Ramaraja-Regular.ttf",
+        "ofl/raviprakash/RaviPrakash-Regular.ttf",
+        "ofl/sreekrushnadevaraya/SreeKrushnadevaraya-Regular.ttf",
+        "ofl/suranna/Suranna-Regular.ttf",
+        "ofl/suravaram/Suravaram-Regular.ttf",
+        "ofl/tenaliramakrishna/TenaliRamakrishna-Regular.ttf",
+        "ofl/timmana/Timmana-Regular.ttf"
+    ],
+    "thai": [
+        "ofl/athiti/Athiti-Regular.ttf",
+        "ofl/baijamjuree/BaiJamjuree-Italic.ttf",
+        "ofl/baijamjuree/BaiJamjuree-Regular.ttf",
+        "ofl/chakrapetch/ChakraPetch-Italic.ttf",
+        "ofl/chakrapetch/ChakraPetch-Regular.ttf",
+        "ofl/charm/Charm-Regular.ttf",
+        "ofl/charmonman/Charmonman-Regular.ttf",
+        "ofl/chonburi/Chonburi-Regular.ttf",
+        "ofl/fahkwang/Fahkwang-Italic.ttf",
+        "ofl/fahkwang/Fahkwang-Regular.ttf",
+        "ofl/ibmplexsansthai/IBMPlexSansThai-Regular.ttf",
+        "ofl/ibmplexsansthailooped/IBMPlexSansThaiLooped-Regular.ttf",
+        "ofl/itim/Itim-Regular.ttf",
+        "ofl/k2d/K2D-Italic.ttf",
+        "ofl/k2d/K2D-Regular.ttf",
+        "ofl/kanit/Kanit-Italic.ttf",
+        "ofl/kanit/Kanit-Regular.ttf",
+        "ofl/kodchasan/Kodchasan-Italic.ttf",
+        "ofl/kodchasan/Kodchasan-Regular.ttf",
+        "ofl/koho/KoHo-Italic.ttf",
+        "ofl/koho/KoHo-Regular.ttf",
+        "ofl/krub/Krub-Italic.ttf",
+        "ofl/krub/Krub-Regular.ttf",
+        "ofl/maitree/Maitree-Regular.ttf",
+        "ofl/mali/Mali-Italic.ttf",
+        "ofl/mali/Mali-Regular.ttf",
+        "ofl/mitr/Mitr-Regular.ttf",
+        "ofl/niramit/Niramit-Italic.ttf",
+        "ofl/niramit/Niramit-Regular.ttf",
+        "ofl/pattaya/Pattaya-Regular.ttf",
+        "ofl/pridi/Pridi-Regular.ttf",
+        "ofl/prompt/Prompt-Italic.ttf",
+        "ofl/prompt/Prompt-Regular.ttf",
+        "ofl/sarabun/Sarabun-Italic.ttf",
+        "ofl/sarabun/Sarabun-Regular.ttf",
+        "ofl/sriracha/Sriracha-Regular.ttf",
+        "ofl/srisakdi/Srisakdi-Regular.ttf",
+        "ofl/taviraj/Taviraj-Italic.ttf",
+        "ofl/taviraj/Taviraj-Regular.ttf",
+        "ofl/thasadith/Thasadith-Italic.ttf",
+        "ofl/thasadith/Thasadith-Regular.ttf",
+        "ofl/trirong/Trirong-Italic.ttf",
+        "ofl/trirong/Trirong-Regular.ttf"
+    ],
+    "chinese": [
+        "ofl/liujianmaocao/LiuJianMaoCao-Regular.ttf",
+        "ofl/longcang/LongCang-Regular.ttf",
+        "ofl/mashanzheng/MaShanZheng-Regular.ttf",
+        "ofl/mochiypopone/MochiyPopOne-Regular.ttf",
+        "ofl/mochiypoppone/MochiyPopPOne-Regular.ttf",
+        "ofl/mplus1code/MPLUS1Code[wght].ttf",
+        "ofl/mplus1p/Mplus1p-Regular.ttf",
+        "ofl/newtegomin/NewTegomin-Regular.ttf",
+        "ofl/pottaone/PottaOne-Regular.ttf",
+        "ofl/rampartone/RampartOne-Regular.ttf",
+        "ofl/reggaeone/ReggaeOne-Regular.ttf"
+    ]
+}

font_list_noto_sans.json
ADDED
@@ -0,0 +1,41 @@
+{
+    "arabic": [
+        "notosans/notosansarabic/NotoSansArabic[wdth,wght].ttf"
+    ],
+    "bengali": [
+        "notosans/notosansbengali/NotoSansBengali[wdth,wght].ttf"
+    ],
+    "gujarati": [
+        "notosans/notosansgujarati/NotoSansGujarati-Regular.ttf"
+    ],
+    "hebrew": [
+        "notosans/notosanshebrew/NotoSansHebrew[wdth,wght].ttf"
+    ],
+    "japanese": [
+        "notosans/notosansjp/NotoSansJP-Regular.otf"
+    ],
+    "khmer": [
+        "notosans/notosanskhmer/NotoSansKhmer[wdth,wght].ttf"
+    ],
+    "korean": [
+        "notosans/notosanskr/NotoSansKR-Regular.otf"
+    ],
+    "malayalam": [
+        "notosans/notosansmalayalam/NotoSansMalayalam[wdth,wght].ttf"
+    ],
+    "cyrillic": [
+        "notosans/notosans/NotoSans-Regular.ttf"
+    ],
+    "tamil": [
+        "notosans/notosanstamil/NotoSansTamil[wdth,wght].ttf"
+    ],
+    "telugu": [
+        "notosans/notosanstelugu/NotoSansTelugu[wdth,wght].ttf"
+    ],
+    "thai": [
+        "notosans/notosansthai/NotoSansThai[wdth,wght].ttf"
+    ],
+    "chinese": [
+        "notosans/notosanssc/NotoSansSC-Regular.otf"
+    ]
+}

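The two lists play complementary roles: font_list.json enumerates the OFL style fonts per language, while font_list_noto_sans.json holds the single Noto content font per language. A quick consistency check over the two files:

import json

with open("font_list.json") as f:
    style_fonts = json.load(f)
with open("font_list_noto_sans.json") as f:
    content_fonts = json.load(f)

assert set(style_fonts) == set(content_fonts)  # same 13 language keys
print(sum(len(v) for v in style_fonts.values()), "style fonts")
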
inference.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Dict, List, Union, Tuple
|
3 |
+
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
|
10 |
+
import models
|
11 |
+
|
12 |
+
GENERATOR_PREFIX = "networks.g."
|
13 |
+
WHITE = 255
|
14 |
+
EXAMPLE_CHARACTERS = ['A', 'B', 'C', 'D', 'E']
|
15 |
+
|
16 |
+
class InferenceServicer:
|
17 |
+
def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
|
18 |
+
self.hp = hp
|
19 |
+
self.imsize = imsize
|
20 |
+
|
21 |
+
if gpu_id is None:
|
22 |
+
self.device = torch.device(f'cuda:0') if torch.cuda.is_available() else 'cpu'
|
23 |
+
else:
|
24 |
+
self.device = torch.device(f'cuda:{gpu_id}')
|
25 |
+
|
26 |
+
model_config = self.hp.models.G
|
27 |
+
self.model: nn.Module = models.Generator(model_config)
|
28 |
+
|
29 |
+
# Load Generator model weight
|
30 |
+
model_state_dict_pl = torch.load(checkpoint_path, map_location='cpu')
|
31 |
+
generator_state_dict = self.convert_generator_state_dict(model_state_dict_pl)
|
32 |
+
self.model.load_state_dict(generator_state_dict)
|
33 |
+
self.model.to(device=self.device)
|
34 |
+
self.model.eval()
|
35 |
+
|
36 |
+
# Setting Content font files
|
37 |
+
self.content_character_dict = self.load_content_character_dict(Path(content_image_dir))
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def convert_generator_state_dict(model_state_dict_pl):
|
41 |
+
generator_prefix = GENERATOR_PREFIX
|
42 |
+
generator_state_dict = {}
|
43 |
+
for module_name, module_state in model_state_dict_pl['state_dict'].items():
|
44 |
+
if module_name.startswith(generator_prefix):
|
45 |
+
generator_state_dict[module_name[len(generator_prefix):]] = module_state
|
46 |
+
|
47 |
+
return generator_state_dict
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def load_content_character_dict(content_image_dir: Path) -> Dict[str, Path]:
|
51 |
+
content_character_dict = {}
|
52 |
+
for filepath in content_image_dir.glob("**/*.png"):
|
53 |
+
content_character_dict[filepath.stem] = filepath
|
54 |
+
return content_character_dict
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def center_align(bg_img: Image.Image, item_img: Image.Image, fit=False) -> Image.Image:
|
58 |
+
bg_img = bg_img.copy()
|
59 |
+
item_img = item_img.copy()
|
60 |
+
item_w, item_h = item_img.size
|
61 |
+
W, H = bg_img.size
|
62 |
+
if fit:
|
63 |
+
item_ratio = item_w / item_h
|
64 |
+
bg_ratio = W / H
|
65 |
+
|
66 |
+
if bg_ratio > item_ratio:
|
67 |
+
# height fitting
|
68 |
+
resize_ratio = H / item_h
|
69 |
+
else:
|
70 |
+
# width fitting
|
71 |
+
resize_ratio = W / item_w
|
72 |
+
item_img = item_img.resize((int(item_w * resize_ratio), int(item_h * resize_ratio)))
|
73 |
+
item_w, item_h = item_img.size
|
74 |
+
|
75 |
+
bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
|
76 |
+
return bg_img
|
77 |
+
|
78 |
+
def set_image(self, image: Union[Path, Image.Image]) -> Image.Image:
|
79 |
+
if isinstance(image, (str, Path)):
|
80 |
+
image = Image.open(image)
|
81 |
+
assert isinstance(image, Image.Image)
|
82 |
+
|
83 |
+
bg_img = Image.new('RGB', (self.imsize, self.imsize), color='white')
|
84 |
+
blend_img = self.center_align(bg_img, image, fit=True)
|
85 |
+
return blend_img
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def pil_image_to_array(blend_img: Image.Image) -> np.ndarray:
|
89 |
+
normalized_array = np.mean(np.array(blend_img, dtype=np.float32), axis=-1) / WHITE # L-only image normalized to [0, 1]
|
90 |
+
return normalized_array
|
91 |
+
|
92 |
+
    def get_images_from_fontfile(self, font_file_path: Path, imgmode: str = 'RGB', position: tuple = (0, 0), font_size: int = 128, padding: int = 100) -> List[Image.Image]:
        imagefont = ImageFont.truetype(str(font_file_path), size=font_size)
        example_characters = EXAMPLE_CHARACTERS

        font_images: List[Image.Image] = []

        for character in example_characters:
            # A throwaway canvas is created only to obtain a Draw object for measuring the glyph.
            x, y, _, _ = imagefont.getbbox(character)
            img = Image.new(imgmode, (x + padding, y + padding), color='white')
            draw = ImageDraw.Draw(img)

            w, h = draw.textsize(character, font=imagefont)  # Pillow 8.x API (pinned in requirements.txt)

            # Re-create the canvas at the measured size and render the glyph.
            img = Image.new(imgmode, (w + padding, h + padding), color='white')
            draw = ImageDraw.Draw(img)
            draw.text(position, text=character, font=imagefont, fill='black')
            font_images.append(img)

        return font_images

    @staticmethod
    def get_hex_from_char(char: str) -> str:
        assert len(char) == 1
        return f"{ord(char):04X}"  # 4-digit hex string; :X already yields uppercase

    @torch.no_grad()
    def inference(self, content_char: str, style_font: Union[str, Path]) -> Tuple[Image.Image, List[Image.Image], Image.Image]:
        assert len(content_char) > 0
        content_char = content_char[:1]  # use only the first character if more than one is given
        char_hex = self.get_hex_from_char(content_char)

        if char_hex not in self.content_character_dict:
            raise ValueError(f"The character {content_char} (hex: {char_hex}) is not supported in this model!")

        content_image = self.set_image(self.content_character_dict[char_hex])
        style_images: List[Image.Image] = self.get_images_from_fontfile(Path(style_font))
        style_images = [self.set_image(image) for image in style_images]

        content_image_array = self.pil_image_to_array(content_image)[np.newaxis, np.newaxis, ...]  # 1 x C(=1) x H x W
        style_images_array: np.ndarray = np.array([self.pil_image_to_array(image) for image in style_images])[np.newaxis, ...]  # 1 x K(=5 shots) x H x W; the k shots share one batch entry

        content_input_tensor = torch.from_numpy(content_image_array).to(self.device)
        style_input_tensor = torch.from_numpy(style_images_array).to(self.device)

        generated_images: torch.Tensor = self.model((content_input_tensor, style_input_tensor))
        generated_images = torch.clip(generated_images, 0, 1)
        assert generated_images.size(0) == 1

        generated_image_numpy = (generated_images[0].cpu().numpy() * 255).astype(np.uint8)[0, ...]  # H x W
        return content_image, style_images, Image.fromarray(generated_image_numpy, mode='L')


if __name__ == '__main__':
    hp = OmegaConf.load("config/models/google-font.yaml")
    checkpoint_path = "epoch=199-step=257400.ckpt"
    content_image_dir = "../DATA/NotoSans"

    servicer = InferenceServicer(hp, checkpoint_path, content_image_dir)

    style_font = "example_fonts/MaShanZheng-Regular.ttf"
    content_image, style_images, result = servicer.inference("7", style_font)

    content_image.save("result_content.png")
    for idx, style_image in enumerate(style_images):
        style_image.save(f"result_style_{idx:02d}.png")
    result.save("result_generated.png")
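The content lookup above keys glyph images by 4-digit uppercase hex codepoints. A standalone sanity check of that convention (a sketch, not part of the commit):

# Sketch: the codepoint-to-key convention used by get_hex_from_char.
for ch in ["7", "A", "가"]:
    print(ch, f"{ord(ch):04X}")
# 7 0037
# A 0041
# 가 AC00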
lightning.py
ADDED
@@ -0,0 +1,313 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import importlib
import PIL.Image as Image

import models
import datasets
from evaluator.ssim import SSIM, MSSSIM
import lpips
from models.loss import GANHingeLoss
from utils import set_logger, magic_image_handler

NUM_TEST_SAVE_IMAGE = 10


class FontLightningModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.losses = {}
        self.metrics = {}
        self.networks = nn.ModuleDict(self.build_models())
        self.module_keys = list(self.networks.keys())

        self.losses = self.build_losses()
        self.metrics = self.build_metrics()

        self.opt_tag = {key: None for key in self.networks.keys()}
        self.sched_tag = {key: None for key in self.networks.keys()}
        self.sched_use = False

        self.train_d_content = True
        self.train_d_style = True

    def build_models(self):
        networks = {}
        for key, hp_model in self.args.models.items():
            key_ = key.lower()
            if key_[0] == 'g':
                model_ = models.Generator(hp_model)
            elif key_[0] == 'd':
                model_ = models.PatchGANDiscriminator(hp_model)  # TODO: add option for selecting discriminator
            else:
                raise ValueError(f"No key such as {key}")

            networks[key_] = model_
        return networks

    def build_losses(self):
        losses_dict = {}
        losses_dict['L1'] = torch.nn.L1Loss()

        if 'd_content' in self.module_keys:
            losses_dict['GANLoss_content'] = GANHingeLoss()
        if 'd_style' in self.module_keys:
            losses_dict['GANLoss_style'] = GANHingeLoss()

        return losses_dict

    def build_metrics(self):
        metrics_dict = nn.ModuleDict()
        metrics_dict['ssim'] = SSIM(val_range=1)  # image values are in [0, 1]
        metrics_dict['msssim'] = MSSSIM(weights=[0.45, 0.3, 0.25], val_range=1)  # imsize=64, so at most 3 scale weights
        metrics_dict['lpips'] = lpips.LPIPS(net='vgg')
        return metrics_dict

    def configure_optimizers(self):
        optims = {}
        for key, args_model in self.args.models.items():
            key = key.lower()
            if args_model['optim'] is not None:
                args_optim = args_model['optim']
                module, cls = args_optim['class'].rsplit(".", 1)
                O = getattr(importlib.import_module(module, package=None), cls)
                o = O([p for p in self.networks[key].parameters() if p.requires_grad],
                      lr=args_optim.lr, betas=args_optim.betas)

                optims[key] = o

        optim_module_keys = optims.keys()

        count = 0
        optim_list = []

        for _key in self.module_keys:
            if _key in optim_module_keys:
                optim_list.append(optims[_key])
                self.opt_tag[_key] = count
                count += 1

        return optim_list

    def forward(self, content_images, style_images):
        return self.networks['g']((content_images, style_images))

    def common_forward(self, batch, batch_idx):
        loss = {}
        logs = {}

        content_images = batch['content_images']
        style_images = batch['style_images']
        gt_images = batch['gt_images']
        image_paths = batch['image_paths']
        char_idx = batch['char_idx']

        generated_images = self(content_images, style_images)

        # L1 reconstruction loss
        loss['g_L1'] = self.losses['L1'](generated_images, gt_images)
        loss['g_backward'] = loss['g_L1'] * self.args.logging.lambda_L1

        # losses for training the generator
        if 'd_content' in self.module_keys:
            loss = self.d_content_loss_for_G(content_images, generated_images, loss)

        if 'd_style' in self.networks.keys():
            loss = self.d_style_loss_for_G(style_images, generated_images, loss)

        # losses for training the discriminators (generator gradients are cut here)
        generated_images = generated_images.detach()

        if 'd_content' in self.module_keys and self.train_d_content:
            loss = self.d_content_loss_for_D(content_images, generated_images, gt_images, loss)

        if 'd_style' in self.module_keys and self.train_d_style:
            loss = self.d_style_loss_for_D(style_images, generated_images, gt_images, loss)

        logs['content_images'] = content_images
        logs['style_images'] = style_images
        logs['gt_images'] = gt_images
        logs['generated_images'] = generated_images

        return loss, logs

    @property
    def automatic_optimization(self):
        return False

    def training_step(self, batch, batch_idx):
        metrics = {}
        # forward
        loss, logs = self.common_forward(batch, batch_idx)

        if self.global_step % self.args.logging.freq['train'] == 0:
            with torch.no_grad():
                metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))

        # backward (manual optimization: one optimizer per network)
        opts = self.optimizers()

        opts[self.opt_tag['g']].zero_grad()
        self.manual_backward(loss['g_backward'])

        if 'd_content' in self.module_keys and self.train_d_content:
            opts[self.opt_tag['d_content']].zero_grad()
            self.manual_backward(loss['dcontent_backward'])

        if 'd_style' in self.module_keys and self.train_d_style:
            opts[self.opt_tag['d_style']].zero_grad()
            self.manual_backward(loss['dstyle_backward'])

        opts[self.opt_tag['g']].step()

        if 'd_content' in self.module_keys and self.train_d_content:
            opts[self.opt_tag['d_content']].step()

        if 'd_style' in self.module_keys and self.train_d_style:
            opts[self.opt_tag['d_style']].step()

        if self.global_step % self.args.logging.freq['train'] == 0:
            self.custom_log(loss, metrics, logs, mode='train')

    def validation_step(self, batch, batch_idx):
        metrics = {}
        loss, logs = self.common_forward(batch, batch_idx)
        self.custom_log(loss, metrics, logs, mode='eval')

    def test_step(self, batch, batch_idx):
        metrics = {}
        loss, logs = self.common_forward(batch, batch_idx)
        metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))

        if batch_idx < NUM_TEST_SAVE_IMAGE:
            for key, value in logs.items():
                if 'image' in key:
                    sample_images = (magic_image_handler(value) * 255)[..., 0].astype(np.uint8)
                    Image.fromarray(sample_images).save(f"{batch_idx:02d}_{key}.png")

        return loss, logs, metrics

    def test_epoch_end(self, test_step_outputs):
        # aggregate the metrics of all test batches
        ssim_list = []
        msssim_list = []

        for _, test_output in enumerate(test_step_outputs):
            ssim_list.append(test_output[2]['SSIM'].cpu().numpy())
            msssim_list.append(test_output[2]['MSSSIM'].cpu().numpy())

        print(f"SSIM: {np.mean(ssim_list)}")
        print(f"MSSSIM: {np.mean(msssim_list)}")

    def common_dataloader(self, mode='train', batch_size=None):
        dataset_cls = getattr(datasets, self.args.datasets.type)
        dataset_config = getattr(self.args.datasets, mode)
        dataset = dataset_cls(dataset_config, mode=mode)
        _batch_size = batch_size if batch_size is not None else dataset_config.batch_size
        dataloader = DataLoader(dataset,
                                shuffle=dataset_config.shuffle,
                                batch_size=_batch_size,
                                num_workers=dataset_config.num_workers,
                                drop_last=True)

        return dataloader

    def train_dataloader(self):
        return self.common_dataloader(mode='train')

    def val_dataloader(self):
        return self.common_dataloader(mode='eval')

    def test_dataloader(self):
        return self.common_dataloader(mode='eval')

    def calc_metrics(self, gt_images, generated_images):
        """Compute SSIM, MS-SSIM, and LPIPS between ground-truth and generated images."""
        metrics = {}
        _gt = torch.clamp(gt_images.clone(), 0, 1)
        _gen = torch.clamp(generated_images.clone(), 0, 1)
        metrics['SSIM'] = self.metrics['ssim'](_gt, _gen)
        msssim_value = self.metrics['msssim'](_gt, _gen)
        metrics['MSSSIM'] = msssim_value if not torch.isnan(msssim_value) else torch.tensor(0.).type_as(_gt)
        metrics['LPIPS'] = self.metrics['lpips'](_gt * 2 - 1, _gen * 2 - 1).squeeze().mean()  # LPIPS expects inputs in [-1, 1]
        return metrics

    # region step
    def d_content_loss_for_G(self, content_images, generated_images, loss):
        pred_generated = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))
        loss['g_gan_content'] = self.losses['GANLoss_content'](pred_generated, True, for_discriminator=False)

        loss['g_backward'] += loss['g_gan_content']
        return loss

    def d_content_loss_for_D(self, content_images, generated_images, gt_images, loss):
        if 'd_content' in self.module_keys and self.train_d_content:
            pred_gt_images = self.networks['d_content'](torch.cat([content_images, gt_images], dim=1))
            pred_generated_images = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))

            loss['dcontent_gt'] = self.losses['GANLoss_content'](pred_gt_images, True, for_discriminator=True)
            loss['dcontent_gen'] = self.losses['GANLoss_content'](pred_generated_images, False, for_discriminator=True)
            loss['dcontent_backward'] = loss['dcontent_gt'] + loss['dcontent_gen']

        return loss

    def d_style_loss_for_G(self, style_images, generated_images, loss):
        pred_generated = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))
        loss['g_gan_style'] = self.losses['GANLoss_style'](pred_generated, True, for_discriminator=False)

        assert self.train_d_style
        loss['g_backward'] += loss['g_gan_style']
        return loss

    def d_style_loss_for_D(self, style_images, generated_images, gt_images, loss):
        pred_gt_images = self.networks['d_style'](torch.cat([style_images, gt_images], dim=1))
        pred_generated_images = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))

        loss['dstyle_gt'] = self.losses['GANLoss_style'](pred_gt_images, True, for_discriminator=True)
        loss['dstyle_gen'] = self.losses['GANLoss_style'](pred_generated_images, False, for_discriminator=True)
        loss['dstyle_backward'] = loss['dstyle_gt'] + loss['dstyle_gen']

        return loss

    def custom_log(self, loss, metrics, logs, mode):
        # log scalar values to tensorboard
        for loss_full_key, value in loss.items():
            model_type, loss_type = loss_full_key.split('_')[0], "_".join(loss_full_key.split('_')[1:])
            self.log(f'{model_type}/{mode}_{loss_type}', value)

        for metric_full_key, value in metrics.items():
            model_type, metric_type = metric_full_key.split('_')[0], "_".join(metric_full_key.split('_')[1:])
            self.log(f'{model_type}/{mode}_{metric_type}', value)

        # log images, params, etc.
        tensorboard = self.logger.experiment
        for key, value in logs.items():
            if 'image' in key:
                sample_images = magic_image_handler(value)
                tensorboard.add_image(f"{mode}/" + key, sample_images, self.global_step, dataformats='HWC')
            elif 'param' in key:
                tensorboard.add_histogram(f"{mode}" + key, value, self.global_step)
            else:
                raise RuntimeError(f"Only logging with one of keywords: image, param | current input: {key}")
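configure_optimizers instantiates each optimizer from a dotted class path stored in the model YAML. A minimal sketch of that dynamic-import pattern; the torch.optim.Adam path and the hyperparameter values here are illustrative assumptions, not values read from this repo's configs:

import importlib

import torch.nn as nn

net = nn.Linear(4, 4)
class_path = "torch.optim.Adam"  # illustrative; the real path comes from args_optim['class']

module, cls = class_path.rsplit(".", 1)
O = getattr(importlib.import_module(module, package=None), cls)
opt = O([p for p in net.parameters() if p.requires_grad], lr=2e-4, betas=(0.5, 0.999))
print(type(opt).__name__)  # Adam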
models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .generator import *
from .discriminator import *
models/decoder.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
from models.module import ResidualBlocks

_DECODER_CHANNEL_DEFAULT = 512


class Decoder(nn.Module):
    def __init__(self, hp, in_channels=_DECODER_CHANNEL_DEFAULT, out_channels=1):
        super().__init__()
        self.module = nn.ModuleList()

    def forward(self, x):
        for block in self.module:
            x = block(x)
        return x


class VanillaDecoder(Decoder):
    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.decoder.depth
        self.blocks = hp.decoder.residual_blocks

        self.module = nn.ModuleList()
        if self.blocks > 0:
            self.module.append(ResidualBlocks(in_channels, n_blocks=self.blocks))

        for layer_idx in range(1, self.depth + 1):  # add upsampling layers
            self.module.append(nn.Sequential(
                nn.ConvTranspose2d(in_channels // (2 ** (layer_idx - 1)),
                                   in_channels // (2 ** layer_idx),
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=False),
                nn.BatchNorm2d(in_channels // (2 ** layer_idx)),
                nn.ReLU(True)
            ))

        final = nn.Sequential(
            nn.Conv2d(in_channels // (2 ** self.depth), out_channels, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.Tanh()
        )

        self.module.append(final)
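A quick shape check for VanillaDecoder; the hp values below are assumptions for illustration, not the repo's actual configuration. Each upsampling stage halves the channel count and doubles the spatial size:

import torch
from omegaconf import OmegaConf

from models.decoder import VanillaDecoder

hp = OmegaConf.create({"decoder": {"depth": 2, "residual_blocks": 3}})  # assumed values
dec = VanillaDecoder(hp, in_channels=512, out_channels=1)
out = dec(torch.randn(1, 512, 16, 16))
print(out.shape)  # torch.Size([1, 1, 64, 64])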
models/discriminator.py
ADDED
@@ -0,0 +1,54 @@
import functools

import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F

# FIXME


class PatchGANDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, hp, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            ndf (int)      -- the number of filters in the last conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer     -- normalization layer
        """
        super().__init__()
        self.hp = hp
        in_channels = hp.in_channels

        if type(norm_layer) == functools.partial:  # no need to use bias since BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(in_channels, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # 1-channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        return self.model(x)
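The PatchGAN discriminator returns a map of per-patch logits rather than a single scalar. A sketch of the output shape; the in_channels value of 2 is an assumption matching the content discriminator's input in lightning.py (a 1-channel condition image concatenated with a 1-channel candidate):

import torch
from omegaconf import OmegaConf

from models.discriminator import PatchGANDiscriminator

hp = OmegaConf.create({"in_channels": 2})  # assumed: two 1-channel images concatenated along dim=1
d = PatchGANDiscriminator(hp)
pred = d(torch.randn(1, 2, 64, 64))
print(pred.shape)  # torch.Size([1, 1, 6, 6]): one hinge-loss logit per receptive-field patch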
models/encoder.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from models.module import Conv2d, StyleAttentionBlock

_ENCODER_CHANNEL_DEFAULT = 256


class Encoder(nn.Module):
    def __init__(self, hp, in_channels=1, out_channels=_ENCODER_CHANNEL_DEFAULT):
        super().__init__()
        self.hp = hp
        self.module = nn.ModuleList()

    def forward(self, x):
        for block in self.module:
            x = block(x)
        return x


class ContentVanillaEncoder(Encoder):
    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.encoder.content.depth
        assert out_channels // (2 ** self.depth) >= in_channels * 2, "Output channel should be increased"

        self.module = nn.ModuleList()
        self.module.append(
            Conv2d(in_channels, out_channels // (2 ** self.depth),
                   kernel_size=7, padding=3, padding_mode='reflect', bias=False)
        )

        for layer_idx in range(1, self.depth + 1):  # downsample
            self.module.append(
                Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
                       out_channels // (2 ** (self.depth - layer_idx)),
                       kernel_size=3, stride=2, padding=1, bias=False)
            )


class StyleVanillaEncoder(Encoder):
    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.encoder.style.depth
        assert out_channels // (2 ** self.depth) >= in_channels * 2, "Output channel should be increased"

        encoder_module = []
        encoder_module.append(
            Conv2d(in_channels, out_channels // (2 ** self.depth),
                   kernel_size=7, padding=3, padding_mode='reflect', bias=False)
        )

        for layer_idx in range(1, self.depth + 1):  # downsample
            encoder_module.append(
                Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
                       out_channels // (2 ** (self.depth - layer_idx)),
                       kernel_size=3, stride=2, padding=1, bias=False)
            )
        self.add_module("encoder_module", nn.Sequential(*encoder_module))
        self.add_module("attention_module", StyleAttentionBlock(out_channels))

    def forward(self, x):
        B, K, H, W = x.size()
        out = self.encoder_module(x.view(-1, 1, H, W))
        out = self.attention_module(out, B, K)
        return out
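StyleVanillaEncoder folds the K style shots into the batch dimension before the shared convolutional stem, then passes B and K to the attention module so the fold can be undone. A minimal sketch of the reshape itself:

import torch

B, K, H, W = 2, 5, 64, 64           # batch of 2, five 1-channel style shots each
style = torch.randn(B, K, H, W)
as_batch = style.view(-1, 1, H, W)  # (B*K) x 1 x H x W: every shot is encoded independently
print(as_batch.shape)               # torch.Size([10, 1, 64, 64])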
models/generator.py
ADDED
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
from . import encoder, decoder


class Generator(nn.Module):
    def __init__(self, hp, in_channels=1):
        super().__init__()
        self.hp = hp
        _ngf = 64
        hidden_dim = _ngf * 4
        self.content_encoder = getattr(encoder, self.hp.encoder.content.type)(self.hp, in_channels, hidden_dim)
        self.style_encoder = getattr(encoder, self.hp.encoder.style.type)(self.hp, in_channels, hidden_dim)
        self.decoder = getattr(decoder, self.hp.decoder.type)(self.hp, hidden_dim * 2, in_channels)

    def forward(self, images):
        content_images, style_images = images
        content_feature = self.content_encoder(content_images)
        style_images = style_images * 2 - 1  # rescale pixel values to [-1, 1]
        style_feature = self.style_encoder(style_images)  # K shots folded into the batch
        _, _, H, W = content_feature.size()
        out = self.decoder(torch.cat([content_feature, style_feature.expand(-1, -1, H, W)], dim=1))
        return out
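The generator fuses the two branches by concatenating the content feature map with the style vector broadcast over the same spatial grid. A standalone sketch of that step with illustrative tensor sizes:

import torch

content_feature = torch.randn(1, 256, 16, 16)  # B x C x H x W from the content encoder
style_feature = torch.randn(1, 256, 1, 1)      # B x C x 1 x 1 from the style attention
_, _, H, W = content_feature.size()
fused = torch.cat([content_feature, style_feature.expand(-1, -1, H, W)], dim=1)
print(fused.shape)  # torch.Size([1, 512, 16, 16])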
models/loss.py
ADDED
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn


class GANHingeLoss(nn.Module):
    def __init__(self):
        super(GANHingeLoss, self).__init__()
        self.relu = nn.ReLU()

    def __call__(self, pred, is_real, for_discriminator):
        if for_discriminator:
            if is_real:
                return self.relu(1 - pred).mean()
            return self.relu(1 + pred).mean()

        assert is_real, "The generator's hinge loss must be aiming for real"
        return -1.0 * pred.mean()
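A sketch of how the two players consume this loss, mirroring the for_discriminator flag; the logit tensors are random stand-ins for discriminator outputs:

import torch

from models.loss import GANHingeLoss

criterion = GANHingeLoss()
pred_real = torch.randn(4, 1, 6, 6)  # discriminator logits on real patches
pred_fake = torch.randn(4, 1, 6, 6)  # discriminator logits on generated patches

d_loss = (criterion(pred_real, True, for_discriminator=True)
          + criterion(pred_fake, False, for_discriminator=True))
g_loss = criterion(pred_fake, True, for_discriminator=False)  # generator pushes fake logits up
print(d_loss.item(), g_loss.item())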
models/module.py
ADDED
@@ -0,0 +1,159 @@
import torch
import torch.nn as nn


class Conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, padding_mode='zeros', bias=True, residual=False):
        super(Conv2d, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                      padding, padding_mode=padding_mode, bias=bias),
            nn.BatchNorm2d(out_channels)
        )
        self.residual = residual
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        out = self.act(out)
        return out


class ResnetBlock(nn.Module):
    def __init__(self, channel, padding_mode, norm_layer=nn.BatchNorm2d, bias=False):
        super().__init__()
        if padding_mode not in ['reflect', 'zero']:
            raise NotImplementedError(f"{padding_mode} is not supported!")

        self.block = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1, padding_mode=padding_mode, bias=bias),
            norm_layer(channel)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.block(x)
        out = out + x
        out = self.act(out)
        return out


class ResidualBlocks(nn.Module):
    def __init__(self, channel, n_blocks=6):
        super().__init__()
        model = []
        for i in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(channel, padding_mode='reflect')]

        self.module = nn.Sequential(*model)

    def forward(self, x):
        return self.module(x)


class SelfAttentionBlock(nn.Module):

    def __init__(self, in_dim):
        super().__init__()
        self.feature_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        _query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x (H*W) x C'
        _key = self.key_conv(x).view(B, -1, H * W)                       # B x C' x (H*W)
        attn_matrix = torch.bmm(_query, _key)
        attention = self.softmax(attn_matrix)                            # B x (H*W) x (H*W)
        _value = self.value_conv(x).view(B, -1, H * W)                   # B x C x (H*W)

        out = torch.bmm(_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)

        out = self.gamma * out + x  # learnable residual gate, initialized at zero
        return out


class ContextAwareAttentionBlock(nn.Module):

    def __init__(self, in_channels, hidden_dim=128):
        super().__init__()
        self.self_attn = SelfAttentionBlock(in_channels)
        self.fc = nn.Linear(in_channels, hidden_dim)
        self.context_vector = nn.Linear(hidden_dim, 1, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, style_features):
        B, C, H, W = style_features.size()
        h = self.self_attn(style_features)
        h = h.permute(0, 2, 3, 1).reshape(-1, C)
        h = torch.tanh(self.fc(h))  # (B*H*W) x hidden_dim
        h = self.context_vector(h)  # (B*H*W) x 1
        attention_score = self.softmax(h.view(B, H * W)).view(B, 1, H, W)  # B x 1 x H x W
        return torch.sum(style_features * attention_score, dim=[2, 3])  # B x C


class LayerAttentionBlock(nn.Module):
    """from FTransGAN"""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.width_feat = 4
        self.height_feat = 4
        self.fc = nn.Linear(self.in_channels * self.width_feat * self.height_feat, 3)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, style_features, style_features_1, style_features_2, style_features_3, B, K):
        style_features = torch.mean(style_features.view(B, K, self.in_channels, self.height_feat, self.width_feat), dim=1)
        style_features = style_features.view(B, -1)
        weight = self.softmax(self.fc(style_features))

        style_features_1 = torch.mean(style_features_1.view(B, K, self.in_channels), dim=1)
        style_features_2 = torch.mean(style_features_2.view(B, K, self.in_channels), dim=1)
        style_features_3 = torch.mean(style_features_3.view(B, K, self.in_channels), dim=1)

        style_features = (style_features_1 * weight.narrow(1, 0, 1) +
                          style_features_2 * weight.narrow(1, 1, 1) +
                          style_features_3 * weight.narrow(1, 2, 1))
        style_features = style_features.view(B, self.in_channels, 1, 1)
        return style_features


class StyleAttentionBlock(nn.Module):
    """from FTransGAN"""

    def __init__(self, in_channels):
        super().__init__()
        self.num_local_attention = 3
        for module_idx in range(1, self.num_local_attention + 1):
            self.add_module(f"local_attention_{module_idx}",
                            ContextAwareAttentionBlock(in_channels))

        for module_idx in range(1, self.num_local_attention):
            self.add_module(f"downsample_{module_idx}",
                            Conv2d(in_channels, in_channels,
                                   kernel_size=3, stride=2, padding=1, bias=False))

        self.add_module("layer_attention", LayerAttentionBlock(in_channels))

    def forward(self, x, B, K):
        feature_1 = self.local_attention_1(x)

        x = self.downsample_1(x)
        feature_2 = self.local_attention_2(x)

        x = self.downsample_2(x)
        feature_3 = self.local_attention_3(x)

        out = self.layer_attention(x, feature_1, feature_2, feature_3, B, K)

        return out
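Because gamma in SelfAttentionBlock is initialized to zero, the block starts as an identity mapping and the attention signal is blended in only as gamma is learned. A quick check:

import torch

from models.module import SelfAttentionBlock

blk = SelfAttentionBlock(in_dim=64)
x = torch.randn(2, 64, 8, 8)
with torch.no_grad():
    print(torch.allclose(blk(x), x))  # True: with gamma == 0 the output equals the input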
pretrained/.gitkeep
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,6 @@
pytorch-lightning==1.6.0
omegaconf
fire
lpips
tensorboard
pillow==8.4.0
trainer.py
ADDED
@@ -0,0 +1,88 @@
import argparse
import glob
from pathlib import Path

from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from lightning import FontLightningModule
from utils import save_files


def load_configuration(path_config):
    setting = OmegaConf.load(path_config)

    # load hyperparameters
    hp = OmegaConf.load(setting.config.dataset)
    hp = OmegaConf.merge(hp, OmegaConf.load(setting.config.model))
    hp = OmegaConf.merge(hp, OmegaConf.load(setting.config.logging))

    # with lightning setting
    if hasattr(setting.config, 'lightning'):
        pl_config = OmegaConf.load(setting.config.lightning)
        if hasattr(pl_config, 'pl_config'):
            return hp, pl_config.pl_config
        return hp, pl_config

    # without lightning setting
    return hp


def parse_args():
    parser = argparse.ArgumentParser(description='Code to train font style transfer')

    parser.add_argument("--config", type=str, default="./config/setting.yaml",
                        help="Config file for training")
    parser.add_argument('-g', '--gpus', type=str, default='0,1',
                        help="Comma-separated GPU ids to use (e.g. '0,1,2,3'). Will use all if not given.")
    parser.add_argument('-p', '--resume_checkpoint_path', type=str, default=None,
                        help="Path of a checkpoint to resume from")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    hp, pl_config = load_configuration(args.config)

    logging_dir = Path(hp.logging.log_dir)

    # call lightning module
    font_pl = FontLightningModule(hp)

    # set logging
    hp.logging['log_dir'] = logging_dir / 'tensorboard'
    savefiles = []
    for reg in hp.logging.savefiles:
        savefiles += glob.glob(reg)
    hp.logging['log_dir'].mkdir(exist_ok=True)
    save_files(str(logging_dir), savefiles)

    # set tensorboard logger
    logger = TensorBoardLogger(str(logging_dir), name=str(hp.logging.seed))

    # set checkpoint callback
    weights_save_path = logging_dir / 'checkpoint' / str(hp.logging.seed)
    weights_save_path.mkdir(exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(weights_save_path),
        **pl_config.checkpoint.callback
    )

    # set lightning trainer
    trainer = pl.Trainer(
        logger=logger,
        gpus=-1 if args.gpus is None else args.gpus,
        callbacks=[checkpoint_callback],
        **pl_config.trainer
    )

    # let's train
    trainer.fit(font_pl)


if __name__ == "__main__":
    main()
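load_configuration expects the file passed via --config to point at the dataset, model, and logging YAMLs, plus an optional lightning YAML. A sketch of an equivalent structure built in code; the paths follow the filenames under config/, but the actual contents of config/setting-google-font.yaml are not reproduced here:

from omegaconf import OmegaConf

# Illustrative stand-in for the setting file read by load_configuration.
setting = OmegaConf.create({
    "config": {
        "dataset": "config/datasets/googlefont.yaml",
        "model": "config/models/google-font.yaml",
        "logging": "config/logging.yaml",
        "lightning": "config/lightning.yaml",
    }
})
print(OmegaConf.to_yaml(setting))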
trainer.sh
ADDED
@@ -0,0 +1 @@
python trainer.py --config ./config/setting-google-font.yaml --gpus 0,1,2,3
utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .logger import *
from .tb import *
from .util import *
utils/logger.py
ADDED
@@ -0,0 +1,39 @@
import fire
import logging
import time


def _custom_logger(name):
    fmt = '[{}|%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >>> %(message)s'.format(name)
    fmt_date = '%Y-%m-%d_%T %Z'

    handler = logging.StreamHandler()

    formatter = logging.Formatter(fmt, fmt_date)
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)


def set_logger(logger_name, level):
    try:
        time.tzset()
    except AttributeError as e:
        print(e)
        print("Skipping timezone setting.")
    _custom_logger(name=logger_name)
    logger = logging.getLogger(logger_name)
    if level == 'DEBUG':
        logger.setLevel(logging.DEBUG)
    elif level == 'INFO':
        logger.setLevel(logging.INFO)
    elif level == 'WARNING':
        logger.setLevel(logging.WARNING)
    elif level == 'ERROR':
        logger.setLevel(logging.ERROR)
    elif level == 'CRITICAL':
        logger.setLevel(logging.CRITICAL)
    return logger


if __name__ == '__main__':
    set_logger("test", "DEBUG")
utils/tb.py
ADDED
@@ -0,0 +1,25 @@
import torch
import numpy as np


def magic_image_handler(img):
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    if img.ndim == 3:
        img = img.transpose((1, 2, 0))
    elif img.ndim == 2:
        img = np.repeat(img[..., np.newaxis], 3, axis=2)
    elif img.ndim == 4:
        img = img[:4]  # first 4 samples of the batch
        img = np.concatenate(img, axis=-1)
        img = img.transpose((1, 2, 0))
    elif img.ndim == 5:
        img = img[:4]  # first 4 samples of the batch
        img = np.concatenate(img, axis=-2)
        img = np.concatenate(img, axis=-1)
        img = img.transpose((1, 2, 0))
    else:
        raise ValueError(f'img ndim is {img.ndim}, should be 2~5')
    if img.shape[-1] != 1 and img.shape[-1] != 3:  # fixed: `or` here was always True
        img = np.expand_dims(np.concatenate([img[..., i] for i in range(img.shape[-1])], axis=0), -1)
    img = np.clip(img, a_min=0, a_max=255)
    return img
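A sketch of what the handler does with a batched 4D tensor: the first four samples are tiled along the width and returned as an HWC array, ready for add_image(..., dataformats='HWC'):

import torch

from utils.tb import magic_image_handler

batch = torch.rand(8, 1, 64, 64)  # B x C x H x W, values in [0, 1]
grid = magic_image_handler(batch)
print(grid.shape)  # (64, 256, 1): the first 4 samples side by side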
utils/util.py
ADDED
@@ -0,0 +1,28 @@
from pathlib import Path
import shutil


def save_files(path_save_, savefiles):
    path_save = Path(path_save_)
    path_save.mkdir(exist_ok=True)

    for savefile in savefiles:
        parents_dir = Path(savefile).parents
        if len(parents_dir) >= 1:
            for parent_dir in list(parents_dir)[::-1]:
                target_dir = path_save / parent_dir
                target_dir.mkdir(exist_ok=True)
        try:
            shutil.copy2(savefile, str(path_save / savefile))
        except Exception as e:
            # skip the file
            print(f'{e} occurred while saving {savefile}')

    return  # success


if __name__ == "__main__":
    import glob
    savefiles = glob.glob('config/*.yaml')
    savefiles += glob.glob('config/**/*.yaml')
    save_files(".temp", savefiles)