S5 / Essay_classifier /run_train.py
dbal0503's picture
Upload 693 files
2ce7b1a
import argparse
from s5.utils.util import str2bool
from s5.train import train
from s5.dataloading import Datasets
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--USE_WANDB", type=str2bool, default=False,
help="log with wandb?")
parser.add_argument("--wandb_project", type=str, default=None,
help="wandb project name")
parser.add_argument("--wandb_entity", type=str, default=None,
help="wandb entity name, e.g. username")
parser.add_argument("--dir_name", type=str, default='./cache_dir',
help="name of directory where data is cached")
parser.add_argument("--dataset", type=str, choices=Datasets.keys(),
default='mnist-classification',
help="dataset name")
# Model Parameters
parser.add_argument("--n_layers", type=int, default=6,
help="Number of layers in the network")
parser.add_argument("--d_model", type=int, default=128,
help="Number of features, i.e. H, "
"dimension of layer inputs/outputs")
parser.add_argument("--ssm_size_base", type=int, default=256,
help="SSM Latent size, i.e. P")
parser.add_argument("--blocks", type=int, default=8,
help="How many blocks, J, to initialize with")
parser.add_argument("--C_init", type=str, default="trunc_standard_normal",
choices=["trunc_standard_normal", "lecun_normal", "complex_normal"],
help="Options for initialization of C: \\"
"trunc_standard_normal: sample from trunc. std. normal then multiply by V \\ " \
"lecun_normal sample from lecun normal, then multiply by V\\ " \
"complex_normal: sample directly from complex standard normal")
parser.add_argument("--discretization", type=str, default="zoh", choices=["zoh", "bilinear"])
parser.add_argument("--mode", type=str, default="pool", choices=["pool", "last"],
help="options: (for classification tasks) \\" \
" pool: mean pooling \\" \
"last: take last element")
parser.add_argument("--activation_fn", default="half_glu1", type=str,
choices=["full_glu", "half_glu1", "half_glu2", "gelu"])
parser.add_argument("--conj_sym", type=str2bool, default=True,
help="whether to enforce conjugate symmetry")
parser.add_argument("--clip_eigs", type=str2bool, default=False,
help="whether to enforce the left-half plane condition")
parser.add_argument("--bidirectional", type=str2bool, default=False,
help="whether to use bidirectional model")
parser.add_argument("--dt_min", type=float, default=0.001,
help="min value to sample initial timescale params from")
parser.add_argument("--dt_max", type=float, default=0.1,
help="max value to sample initial timescale params from")
# Optimization Parameters
parser.add_argument("--prenorm", type=str2bool, default=True,
help="True: use prenorm, False: use postnorm")
parser.add_argument("--batchnorm", type=str2bool, default=True,
help="True: use batchnorm, False: use layernorm")
parser.add_argument("--bn_momentum", type=float, default=0.95,
help="batchnorm momentum")
parser.add_argument("--bsz", type=int, default=64,
help="batch size")
parser.add_argument("--epochs", type=int, default=100,
help="max number of epochs")
parser.add_argument("--early_stop_patience", type=int, default=1000,
help="number of epochs to continue training when val loss plateaus")
parser.add_argument("--ssm_lr_base", type=float, default=1e-3,
help="initial ssm learning rate")
parser.add_argument("--lr_factor", type=float, default=1,
help="global learning rate = lr_factor*ssm_lr_base")
parser.add_argument("--dt_global", type=str2bool, default=False,
help="Treat timescale parameter as global parameter or SSM parameter")
parser.add_argument("--lr_min", type=float, default=0,
help="minimum learning rate")
parser.add_argument("--cosine_anneal", type=str2bool, default=True,
help="whether to use cosine annealing schedule")
parser.add_argument("--warmup_end", type=int, default=1,
help="epoch to end linear warmup")
parser.add_argument("--lr_patience", type=int, default=1000000,
help="patience before decaying learning rate for lr_decay_on_val_plateau")
parser.add_argument("--reduce_factor", type=float, default=1.0,
help="factor to decay learning rate for lr_decay_on_val_plateau")
parser.add_argument("--p_dropout", type=float, default=0.0,
help="probability of dropout")
parser.add_argument("--weight_decay", type=float, default=0.05,
help="weight decay value")
parser.add_argument("--opt_config", type=str, default="standard", choices=['standard',
'BandCdecay',
'BfastandCdecay',
'noBCdecay'],
help="Opt configurations: \\ " \
"standard: no weight decay on B (ssm lr), weight decay on C (global lr) \\" \
"BandCdecay: weight decay on B (ssm lr), weight decay on C (global lr) \\" \
"BfastandCdecay: weight decay on B (global lr), weight decay on C (global lr) \\" \
"noBCdecay: no weight decay on B (ssm lr), no weight decay on C (ssm lr) \\")
parser.add_argument("--jax_seed", type=int, default=1919,
help="seed randomness")
train(parser.parse_args())