import os
#os.environ["TOKENIZERS_PARALLELISM"] = "False"
from aitextgen.tokenizers import train_tokenizer
from aitextgen import aitextgen
from aitextgen.utils import build_gpt2_config
def train_atg_tokenizer():
    # Train a new tokenizer on the flattened SVG corpus (writes aitextgen.tokenizer.json).
    train_tokenizer("svg_flat.txt", vocab_size=1000)
def prepare_model():
    # Build a GPT-2-style config sized for this task and save an untrained model to disk.
    config = build_gpt2_config(vocab_size=1000, max_length=4096, dropout=0.1, n_embd=768, n_layer=8, n_head=12)
    ai = aitextgen(tokenizer_file="aitextgen.tokenizer.json", config=config)
    ai.save_for_upload("./trained_model")
def do_train():
    ai = aitextgen(model_folder="./trained_model", tokenizer_file="aitextgen.tokenizer.json")
    # Two-stage training: a longer run at a higher learning rate, then a second pass at 1/10th the rate.
    ai.train("svg_flat.txt", batch_size=1, num_steps=60000, save_every=2500, fp16=False, generate_every=1000, learning_rate=0.001)
    ai.train("svg_flat.txt", batch_size=1, num_steps=40000, save_every=2500, fp16=False, generate_every=1000, learning_rate=0.0001)
def do_sample():
    ai = aitextgen(model_folder="./trained_model", tokenizer_file="./trained_model/tokenizer.json", to_gpu=True)
    ai.generate(prompt="\n", max_length=4000, seed=42, do_sample=True)
# svg_file_header = "<svg width=\"32\" height=\"32\" viewBox=\"0 0 32 32\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\">"
# TODO: Extract from the generated output, into separate .svg files, all sequences that start with
# svg_file_header and end with one of the following (a sketch of this step follows below):
# A. </svg>
# B. If the sequence does not end with </svg>, find the last > in the sequence and append </svg> after it
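# A minimal sketch of the extraction step described in the TODO above. The helper name
# (extract_svgs) and the output directory are assumptions, not part of the original script;
# the generated text is assumed to have been captured as a string (e.g. by calling
# ai.generate with return_as_list=True and taking the first element).
def extract_svgs(generated_text, out_dir="extracted_svgs"):
    svg_file_header = "<svg width=\"32\" height=\"32\" viewBox=\"0 0 32 32\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\">"
    os.makedirs(out_dir, exist_ok=True)
    count = 0
    # Everything after each occurrence of the header is a candidate SVG body.
    for chunk in generated_text.split(svg_file_header)[1:]:
        end = chunk.find("</svg>")
        if end != -1:
            body = chunk[:end + len("</svg>")]      # case A: sequence ends with </svg>
        else:
            last_gt = chunk.rfind(">")
            if last_gt == -1:
                continue                            # no usable tag in this chunk
            body = chunk[:last_gt + 1] + "</svg>"   # case B: close the last tag manually
        with open(os.path.join(out_dir, f"sample_{count}.svg"), "w") as f:
            f.write(svg_file_header + body)
        count += 1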
def main():
    #train_atg_tokenizer()
    #prepare_model()
    #do_train()
    do_sample()
if __name__ == "__main__":
    main()