ThaiVecFont / options.py
microhum's picture
rm dockerfile
024aa56
raw
history blame
5.12 kB
import argparse
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value (including ``"False"``) is truthy.

    Args:
        value: the raw option value; an actual ``bool`` is passed through.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept as "false": with the former ``type=bool`` it was
    # the only value that yielded False, so existing scripts may rely on it.
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


def get_parser_main_model() -> argparse.ArgumentParser:
    """Build the argument parser holding every option of the main model.

    Boolean options (``--multi_gpu``, ``--continue_training``, ``--wandb``)
    take an explicit value such as ``--wandb False``; the value is parsed with
    ``_str2bool`` so that "False"/"0"/"no" really disable the option.

    Returns:
        The configured ``argparse.ArgumentParser``; the caller is responsible
        for invoking ``parse_args`` on it.
    """
    parser = argparse.ArgumentParser()

    # basic parameters training related
    parser.add_argument('--model_name', type=str, default='main_model', choices=['main_model', 'neural_raster'], help='current model_name')
    parser.add_argument("--language", type=str, default='tha', choices=['eng', 'chn', 'tha'])
    parser.add_argument('--bottleneck_bits', type=int, default=512, help='latent code number of bottleneck bits')
    parser.add_argument('--char_num', type=int, default=44, help='number of glyphs, original is 44 (Thai)')
    parser.add_argument('--seed', type=int, default=3712)
    parser.add_argument('--ref_nshot', type=int, default=8, help='reference number')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=8, help='batch size when do validation')
    parser.add_argument('--img_size', type=int, default=64, help='image size')
    parser.add_argument('--max_seq_len', type=int, default=121, help='maximum length of sequence')
    parser.add_argument('--dim_seq', type=int, default=12, help='the dim of each stroke in a sequence, 4 + 8, 4 is cmd, and 8 is args')
    parser.add_argument('--dim_seq_short', type=int, default=9, help='the short dim of each stroke in a sequence, 1 + 8, 1 is cmd class num, and 8 is args')
    parser.add_argument('--hidden_size', type=int, default=512, help='hidden_size')
    parser.add_argument('--dim_seq_latent', type=int, default=512, help='sequence encoder latent dim')
    parser.add_argument('--ngf', type=int, default=16, help='the basic num of channel in image encoder and decoder')
    parser.add_argument('--n_aux_pts', type=int, default=6, help='the number of aux pts in bezier curves for additional supervison')

    # experiment related
    parser.add_argument('--random_index', type=str, default='00')
    parser.add_argument('--name_ckpt', type=str, default='600_192921.ckpt')
    parser.add_argument('--model_path', type=str, default='.')
    parser.add_argument('--n_epochs', type=int, default=800, help='number of epochs')
    parser.add_argument('--n_samples', type=int, default=20, help='the number of samples for each glyph when testing')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--ref_char_ids', type=str, default='0,1,26,27', help='default is A, B, a, b')
    parser.add_argument('--mode', type=str, default='test', choices=['train', 'val', 'test'])
    # Parsed with _str2bool (not bool) so that "--multi_gpu False" is False.
    parser.add_argument('--multi_gpu', type=_str2bool, default=False)
    parser.add_argument('--name_exp', type=str, default='dvf')

    # continue training
    parser.add_argument('--continue_training', type=_str2bool, default=False, help='whether continue training from old checkpoint')
    parser.add_argument('--continue_ckpt', type=str, default='.', help='checkpoint model for continue training')
    parser.add_argument('--init_epoch', type=int, default=0, help='init epoch')

    # Manually Add
    parser.add_argument('--exp_path', type=str, default='.')
    parser.add_argument('--dir_res', type=str, default=None)
    parser.add_argument('--data_root', type=str, default='./data/vecfont_dataset/')
    parser.add_argument('--freq_ckpt', type=int, default=50, help='save checkpoint frequency of epoch')
    parser.add_argument('--threshold_ckpt', type=int, default=0, help='save checkpoint only when more than threshold epoch')
    parser.add_argument('--freq_sample', type=int, default=500, help='sample train output of steps')
    parser.add_argument('--freq_log', type=int, default=50, help='freq of showing logs')
    parser.add_argument('--freq_val', type=int, default=500, help='sample validate output of steps')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam optimizer')
    parser.add_argument('--eps', type=float, default=1e-8, help='Adam epsilon')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
    parser.add_argument('--wandb', type=_str2bool, default=True, help='whether use wandb to visulize loss')
    parser.add_argument('--wandb_project_name', type=str, default="DeepVecFontV2", help='wandb project name')

    # loss weight
    parser.add_argument('--kl_beta', type=float, default=0.01, help='latent code kl loss beta')
    parser.add_argument('--loss_w_pt_c', type=float, default=0.001 * 10, help='the weight of perceptual content loss')
    parser.add_argument('--loss_w_l1', type=float, default=1.0 * 10, help='the weight of image reconstruction l1 loss')
    parser.add_argument('--loss_w_cmd', type=float, default=1.0, help='the weight of cmd loss')
    parser.add_argument('--loss_w_args', type=float, default=1.0, help='the weight of args loss')
    parser.add_argument('--loss_w_aux', type=float, default=0.01, help='the weight of pts aux loss')
    parser.add_argument('--loss_w_smt', type=float, default=10., help='the weight of smooth loss')
    return parser