Gabor Cselle
commited on
Commit
•
7987245
1
Parent(s):
a2676e8
Refactor; move consts to consts.py
Browse files- .gitignore +3 -1
- arrange_train_test_images.py +7 -8
- consts.py +10 -0
- gen_sample_data.py +36 -39
- train_font_identifier.py +5 -7
.gitignore
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
-
|
2 |
train_test_images
|
3 |
.DS_Store
|
4 |
.ipynb_checkpoints/visualize-checkpoint.ipynb
|
5 |
font_identifier_model.pth
|
|
|
|
|
|
1 |
+
generated_images
|
2 |
train_test_images
|
3 |
.DS_Store
|
4 |
.ipynb_checkpoints/visualize-checkpoint.ipynb
|
5 |
font_identifier_model.pth
|
6 |
+
*.pyc
|
7 |
+
__pycache__
|
arrange_train_test_images.py
CHANGED
@@ -3,18 +3,17 @@
|
|
3 |
import os
|
4 |
import shutil
|
5 |
import random
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
train_dir = os.path.join(organized_dir, 'train')
|
10 |
-
test_dir = os.path.join(organized_dir, 'test')
|
11 |
|
12 |
# create directories if they don't exist
|
13 |
os.makedirs(train_dir, exist_ok=True)
|
14 |
os.makedirs(test_dir, exist_ok=True)
|
15 |
|
16 |
# make a list of all the font names
|
17 |
-
fonts = [f.split('_')[0] for f in os.listdir(
|
18 |
fonts = list(set(fonts)) # getting unique font names
|
19 |
|
20 |
for font in fonts:
|
@@ -23,7 +22,7 @@ for font in fonts:
|
|
23 |
os.makedirs(font_train_dir, exist_ok=True)
|
24 |
os.makedirs(font_test_dir, exist_ok=True)
|
25 |
|
26 |
-
font_files = [f for f in os.listdir(
|
27 |
random.shuffle(font_files)
|
28 |
|
29 |
train_files = font_files[:int(0.8 * len(font_files))]
|
@@ -31,8 +30,8 @@ for font in fonts:
|
|
31 |
|
32 |
# Move training files
|
33 |
for train_file in train_files:
|
34 |
-
shutil.move(os.path.join(
|
35 |
|
36 |
# Move test files
|
37 |
for test_file in test_files:
|
38 |
-
shutil.move(os.path.join(
|
|
|
3 |
import os
|
4 |
import shutil
|
5 |
import random
|
6 |
+
from consts import TRAIN_TEST_IMAGES_DIR, GEN_IMAGES_DIR
|
7 |
|
8 |
+
train_dir = os.path.join(TRAIN_TEST_IMAGES_DIR, 'train')
|
9 |
+
test_dir = os.path.join(TRAIN_TEST_IMAGES_DIR, 'test')
|
|
|
|
|
10 |
|
11 |
# create directories if they don't exist
|
12 |
os.makedirs(train_dir, exist_ok=True)
|
13 |
os.makedirs(test_dir, exist_ok=True)
|
14 |
|
15 |
# make a list of all the font names
|
16 |
+
fonts = [f.split('_')[0] for f in os.listdir(GEN_IMAGES_DIR) if f.endswith('.png')]
|
17 |
fonts = list(set(fonts)) # getting unique font names
|
18 |
|
19 |
for font in fonts:
|
|
|
22 |
os.makedirs(font_train_dir, exist_ok=True)
|
23 |
os.makedirs(font_test_dir, exist_ok=True)
|
24 |
|
25 |
+
font_files = [f for f in os.listdir(GEN_IMAGES_DIR) if f.startswith(font)]
|
26 |
random.shuffle(font_files)
|
27 |
|
28 |
train_files = font_files[:int(0.8 * len(font_files))]
|
|
|
30 |
|
31 |
# Move training files
|
32 |
for train_file in train_files:
|
33 |
+
shutil.move(os.path.join(GEN_IMAGES_DIR, train_file), font_train_dir)
|
34 |
|
35 |
# Move test files
|
36 |
for test_file in test_files:
|
37 |
+
shutil.move(os.path.join(GEN_IMAGES_DIR, test_file), font_test_dir)
|
consts.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# number of images to generate per font
|
2 |
+
IMAGES_PER_FONT = 50
|
3 |
+
# allowlist of fonts to use
|
4 |
+
FONT_ALLOWLIST = ["Arial", "Avenir", "Courier", "Helvetica", "Georgia", "Tahoma", "Times New Roman", "Verdana"]
|
5 |
+
# directory where to store the generated images
|
6 |
+
GEN_IMAGES_DIR = './generated_images'
|
7 |
+
# images organized into train and test directories
|
8 |
+
TRAIN_TEST_IMAGES_DIR = './train_test_images'
|
9 |
+
# where to grab the font files from
|
10 |
+
FONT_FILE_DIRS = ['/System/Library/Fonts/', '/System/Library/Fonts/Supplemental/']
|
gen_sample_data.py
CHANGED
@@ -6,22 +6,15 @@ from PIL import Image, ImageDraw, ImageFont
|
|
6 |
import nltk
|
7 |
from nltk.corpus import brown
|
8 |
import random
|
9 |
-
|
10 |
-
IMAGES_PER_FONT = 50
|
11 |
|
12 |
# Download the necessary data from nltk
|
13 |
nltk.download('brown')
|
14 |
|
15 |
-
|
16 |
-
font_dirs = ['/System/Library/Fonts/', '/System/Library/Fonts/Supplemental/']
|
17 |
-
output_dir = './font_images'
|
18 |
-
os.makedirs(output_dir, exist_ok=True)
|
19 |
|
20 |
all_brown_words = sorted(set(brown.words(categories='news')))
|
21 |
|
22 |
-
# This is a list of fonts that we want to use for our sample data
|
23 |
-
FONT_ALLOWLIST = ["Arial", "Avenir", "Courier", "Helvetica", "Georgia", "Tahoma", "Times New Roman", "Verdana"]
|
24 |
-
|
25 |
def wrap_text(text, line_length=10):
|
26 |
"""Wraps the provided text every 'line_length' words."""
|
27 |
words = text.split()
|
@@ -37,38 +30,42 @@ def random_code_text(base_code, num_lines=15):
|
|
37 |
lines = base_code.split("\n")
|
38 |
return "\n".join(random.sample(lines, min(num_lines, len(lines))))
|
39 |
|
40 |
-
|
41 |
-
for
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
# ttf fonts have only one font in the file
|
55 |
-
font_size = random.choice(range(32, 128)) # Increased minimum font size
|
56 |
-
font = ImageFont.truetype(font_path, font_size)
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
prose_sample = random_prose_text(all_brown_words)
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
draw.text((offset_x, offset_y), text, fill="black", font=font)
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
img.save(output_file)
|
|
|
6 |
import nltk
|
7 |
from nltk.corpus import brown
|
8 |
import random
|
9 |
+
from consts import FONT_ALLOWLIST, IMAGES_PER_FONT, GEN_IMAGES_DIR, FONT_FILE_DIRS
|
|
|
10 |
|
11 |
# Download the necessary data from nltk
|
12 |
nltk.download('brown')
|
13 |
|
14 |
+
os.makedirs(GEN_IMAGES_DIR, exist_ok=True)
|
|
|
|
|
|
|
15 |
|
16 |
all_brown_words = sorted(set(brown.words(categories='news')))
|
17 |
|
|
|
|
|
|
|
18 |
def wrap_text(text, line_length=10):
|
19 |
"""Wraps the provided text every 'line_length' words."""
|
20 |
words = text.split()
|
|
|
30 |
lines = base_code.split("\n")
|
31 |
return "\n".join(random.sample(lines, min(num_lines, len(lines))))
|
32 |
|
33 |
+
def main():
|
34 |
+
for font_dir in FONT_FILE_DIRS:
|
35 |
+
for font_file in os.listdir(font_dir):
|
36 |
+
if font_file.endswith('.ttf') or font_file.endswith('.ttc'):
|
37 |
+
font_path = os.path.join(font_dir, font_file)
|
38 |
+
font_name = font_file.split('.')[0]
|
39 |
+
if font_name not in FONT_ALLOWLIST:
|
40 |
+
continue
|
41 |
+
# Output the font name so we can see the progress
|
42 |
+
print(font_path, font_name)
|
43 |
+
|
44 |
+
if font_file.endswith('.ttc'):
|
45 |
+
# ttc fonts have multiple fonts in one file, so we need to specify which one we want
|
46 |
+
font = ImageFont.truetype(font_path, random.choice(range(32, 128)), index=0)
|
47 |
+
else:
|
48 |
+
# ttf fonts have only one font in the file
|
49 |
+
font_size = random.choice(range(32, 128)) # Increased minimum font size
|
50 |
+
font = ImageFont.truetype(font_path, font_size)
|
51 |
|
52 |
+
# Counter for the image filename
|
53 |
+
j = 0
|
54 |
+
for i in range(IMAGES_PER_FONT): # Generate 50 images per font - reduced to 10 for now to make things faster
|
55 |
+
prose_sample = random_prose_text(all_brown_words)
|
|
|
|
|
|
|
56 |
|
57 |
+
for text in [prose_sample]:
|
58 |
+
img = Image.new('RGB', (800, 400), color="white") # Canvas size
|
59 |
+
draw = ImageDraw.Draw(img)
|
|
|
60 |
|
61 |
+
# Random offsets, but ensuring that text isn't too far off the canvas
|
62 |
+
offset_x = random.randint(-20, 10)
|
63 |
+
offset_y = random.randint(-20, 10)
|
64 |
+
draw.text((offset_x, offset_y), text, fill="black", font=font)
|
65 |
|
66 |
+
j += 1
|
67 |
+
output_file = os.path.join(GEN_IMAGES_DIR, f"{font_name}_{j}.png")
|
68 |
+
img.save(output_file)
|
|
|
69 |
|
70 |
+
if __name__ == '__main__':
|
71 |
+
main()
|
|
train_font_identifier.py
CHANGED
@@ -1,15 +1,11 @@
|
|
1 |
-
import copy
|
2 |
import os
|
3 |
-
import time
|
4 |
import torch
|
5 |
import torch.optim as optim
|
6 |
import torch.nn as nn
|
7 |
-
from torch.optim import lr_scheduler
|
8 |
from torchvision import datasets, models, transforms
|
9 |
from tqdm import tqdm
|
10 |
-
|
11 |
-
|
12 |
-
data_dir = './train_test_images'
|
13 |
|
14 |
# Transformations for the image data
|
15 |
data_transforms = transforms.Compose([
|
@@ -21,7 +17,7 @@ data_transforms = transforms.Compose([
|
|
21 |
|
22 |
# Create datasets
|
23 |
image_datasets = {
|
24 |
-
x: datasets.ImageFolder(os.path.join(
|
25 |
for x in ['train', 'test']
|
26 |
}
|
27 |
|
@@ -92,3 +88,5 @@ for epoch in range(num_epochs):
|
|
92 |
|
93 |
# Save the model to disk
|
94 |
torch.save(model.state_dict(), 'font_identifier_model.pth')
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
import torch.optim as optim
|
4 |
import torch.nn as nn
|
|
|
5 |
from torchvision import datasets, models, transforms
|
6 |
from tqdm import tqdm
|
7 |
+
import torch
|
8 |
+
from consts import TRAIN_TEST_IMAGES_DIR
|
|
|
9 |
|
10 |
# Transformations for the image data
|
11 |
data_transforms = transforms.Compose([
|
|
|
17 |
|
18 |
# Create datasets
|
19 |
image_datasets = {
|
20 |
+
x: datasets.ImageFolder(os.path.join(TRAIN_TEST_IMAGES_DIR, x), data_transforms)
|
21 |
for x in ['train', 'test']
|
22 |
}
|
23 |
|
|
|
88 |
|
89 |
# Save the model to disk
|
90 |
torch.save(model.state_dict(), 'font_identifier_model.pth')
|
91 |
+
|
92 |
+
|