Spaces:
Sleeping
Sleeping
import argparse | |
def _str2bool(value):
    """Convert a command-line token to a bool for use as an argparse ``type``.

    argparse's ``type=bool`` is a known pitfall: it calls ``bool(str)``, so
    every non-empty string -- including ``"False"`` and ``"0"`` -- becomes
    ``True``. This converter parses the usual spellings explicitly instead.

    Args:
        value: The raw token from the command line (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def get_parser_main_model():
    """Build the command-line argument parser for the main model.

    The parser covers model/data dimensions, experiment and checkpoint
    settings, optimizer hyper-parameters, logging/wandb options and loss
    weights.

    Boolean options (``--multi_gpu``, ``--continue_training``, ``--wandb``)
    accept an explicit value (``true``/``false``, ``yes``/``no``, ``1``/``0``,
    ``on``/``off``, case-insensitive). They may also be given bare, which
    means True.

    Returns:
        argparse.ArgumentParser: A new parser; call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser()
    # basic parameters training related
    parser.add_argument('--model_name', type=str, default='main_model', choices=['main_model', 'neural_raster'], help='current model_name')
    parser.add_argument("--language", type=str, default='tha', choices=['eng', 'chn', 'tha'])
    parser.add_argument('--bottleneck_bits', type=int, default=512, help='latent code number of bottleneck bits')
    parser.add_argument('--char_num', type=int, default=44, help='number of glyphs, original is 44 (Thai)')
    parser.add_argument('--seed', type=int, default=3712)
    parser.add_argument('--ref_nshot', type=int, default=8, help='reference number')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=8, help='batch size when do validation')
    parser.add_argument('--img_size', type=int, default=64, help='image size')
    parser.add_argument('--max_seq_len', type=int, default=121, help='maximum length of sequence')
    parser.add_argument('--dim_seq', type=int, default=12, help='the dim of each stroke in a sequence, 4 + 8, 4 is cmd, and 8 is args')
    parser.add_argument('--dim_seq_short', type=int, default=9, help='the short dim of each stroke in a sequence, 1 + 8, 1 is cmd class num, and 8 is args')
    parser.add_argument('--hidden_size', type=int, default=512, help='hidden_size')
    parser.add_argument('--dim_seq_latent', type=int, default=512, help='sequence encoder latent dim')
    parser.add_argument('--ngf', type=int, default=16, help='the basic num of channel in image encoder and decoder')
    parser.add_argument('--n_aux_pts', type=int, default=6, help='the number of aux pts in bezier curves for additional supervison')
    # experiment related
    parser.add_argument('--random_index', type=str, default='00')
    parser.add_argument('--name_ckpt', type=str, default='600_192921.ckpt')
    parser.add_argument('--model_path', type=str, default='.')
    parser.add_argument('--n_epochs', type=int, default=800, help='number of epochs')
    parser.add_argument('--n_samples', type=int, default=20, help='the number of samples for each glyph when testing')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    # NOTE(review): the help text names Latin letters, but the default
    # --language is 'tha'; confirm which glyphs indices 0,1,26,27 select for
    # the chosen language.
    parser.add_argument('--ref_char_ids', type=str, default='0,1,26,27', help='default is A, B, a, b')
    parser.add_argument('--mode', type=str, default='test', choices=['train', 'val', 'test'])
    # Boolean flags use _str2bool: with type=bool, "--multi_gpu False" parsed as True.
    parser.add_argument('--multi_gpu', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--name_exp', type=str, default='dvf')
    # continue training
    parser.add_argument('--continue_training', type=_str2bool, nargs='?', const=True, default=False, help='whether continue training from old checkpoint')
    parser.add_argument('--continue_ckpt', type=str, default='.', help='checkpoint model for continue training')
    parser.add_argument('--init_epoch', type=int, default=0, help='init epoch')
    # Manually Add
    parser.add_argument('--exp_path', type=str, default='.')
    parser.add_argument('--dir_res', type=str, default=None)
    parser.add_argument('--data_root', type=str, default='./data/vecfont_dataset/')
    parser.add_argument('--freq_ckpt', type=int, default=50, help='save checkpoint frequency of epoch')
    parser.add_argument('--threshold_ckpt', type=int, default=0, help='save checkpoint only when more than threshold epoch')
    parser.add_argument('--freq_sample', type=int, default=500, help='sample train output of steps')
    parser.add_argument('--freq_log', type=int, default=50, help='freq of showing logs')
    parser.add_argument('--freq_val', type=int, default=500, help='sample validate output of steps')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam optimizer')
    parser.add_argument('--eps', type=float, default=1e-8, help='Adam epsilon')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
    # Defaults to True; pass "--wandb False" to disable (impossible with type=bool).
    parser.add_argument('--wandb', type=_str2bool, nargs='?', const=True, default=True, help='whether use wandb to visulize loss')
    parser.add_argument('--wandb_project_name', type=str, default="DeepVecFontV2", help='wandb project name')
    # loss weight
    parser.add_argument('--kl_beta', type=float, default=0.01, help='latent code kl loss beta')
    parser.add_argument('--loss_w_pt_c', type=float, default=0.001 * 10, help='the weight of perceptual content loss')
    parser.add_argument('--loss_w_l1', type=float, default=1.0 * 10, help='the weight of image reconstruction l1 loss')
    parser.add_argument('--loss_w_cmd', type=float, default=1.0, help='the weight of cmd loss')
    parser.add_argument('--loss_w_args', type=float, default=1.0, help='the weight of args loss')
    parser.add_argument('--loss_w_aux', type=float, default=0.01, help='the weight of pts aux loss')
    parser.add_argument('--loss_w_smt', type=float, default=10., help='the weight of smooth loss')
    return parser